// Copyright (c) 2018, the R8 project authors. Please see the AUTHORS file
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
package com.android.tools.r8;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import com.android.tools.r8.origin.Origin;
import com.android.tools.r8.origin.PathOrigin;
import com.android.tools.r8.position.Position;
import com.android.tools.r8.position.TextPosition;
import com.android.tools.r8.position.TextRange;
import com.android.tools.r8.utils.ListUtils;
import com.google.common.collect.ImmutableList;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.function.Consumer;

/**
 * Test helper that records every {@link Diagnostic} reported to it and offers assertions over the
 * recorded errors, warnings and infos.
 *
 * <p>The static helpers create a fresh checker, pass it as the {@link DiagnosticsHandler} to a
 * {@link FailingRunner}, and then assert on the diagnostics that were collected.
 */
public class DiagnosticsChecker implements DiagnosticsHandler {

  // Diagnostics recorded by this handler, in the order they were reported.
  public List<Diagnostic> errors = new ArrayList<>();
  public List<Diagnostic> warnings = new ArrayList<>();
  public List<Diagnostic> infos = new ArrayList<>();

  @Override
  public void error(Diagnostic error) {
    errors.add(error);
  }

  @Override
  public void warning(Diagnostic warning) {
    warnings.add(warning);
  }

  @Override
  public void info(Diagnostic info) {
    infos.add(info);
  }

  /** A compilation (or other action) that reports its diagnostics to the given handler. */
  @FunctionalInterface
  public interface FailingRunner {

    void run(DiagnosticsHandler handler) throws CompilationFailedException;
  }

  /** Asserts that at least one diagnostic message in {@code diagnostics} contains {@code snippet}. */
  private static void checkContains(String snippet, List<Diagnostic> diagnostics) {
    List<String> messages = ListUtils.map(diagnostics, Diagnostic::getDiagnosticMessage);
    System.out.println("Expecting match for '" + snippet + "'");
    System.out.println("Diagnostics messages:\n" + messages);
    assertTrue(
        "Expected to find snippet '"
            + snippet
            + "' in diagnostic messages:\n"
            + String.join("\n", messages),
        messages.stream().anyMatch(message -> message.contains(snippet)));
  }

  /** Asserts that no diagnostic message in {@code diagnostics} contains {@code snippet}. */
  private static void checkNotContains(String snippet, List<Diagnostic> diagnostics) {
    List<String> messages = ListUtils.map(diagnostics, Diagnostic::getDiagnosticMessage);
    System.out.println("Expecting no match for '" + snippet + "'");
    System.out.println("Diagnostics messages:\n" + messages);
    assertTrue(
        "Expected to *not* find snippet '"
            + snippet
            + "' in diagnostic messages:\n"
            + String.join("\n", messages),
        messages.stream().noneMatch(message -> message.contains(snippet)));
  }

  /** Asserts that every snippet occurs in at least one of the given diagnostics' messages. */
  public static void checkContains(Collection<String> snippets, List<Diagnostic> diagnostics) {
    snippets.forEach(snippet -> checkContains(snippet, diagnostics));
  }

  /** Asserts that none of the snippets occurs in any of the given diagnostics' messages. */
  public static void checkNotContains(Collection<String> snippets, List<Diagnostic> diagnostics) {
    snippets.forEach(snippet -> checkNotContains(snippet, diagnostics));
  }

  /** Asserts that some recorded error message contains {@code snippet}. */
  public void checkErrorsContains(String snippet) {
    checkContains(snippet, errors);
  }

  /** Asserts that some recorded warning message contains {@code snippet}. */
  public void checkWarningsContains(String snippet) {
    checkContains(snippet, warnings);
  }

  /** Asserts that some recorded info message contains {@code snippet}. */
  public void checkInfosContains(String snippet) {
    checkContains(snippet, infos);
  }

  /**
   * Runs {@code runner}, which must fail, and checks that an error message contains {@code
   * snippet}.
   *
   * @throws CompilationFailedException always rethrown after the check, so callers can inspect it
   */
  public static void checkErrorsContains(String snippet, FailingRunner runner)
      throws CompilationFailedException {
    checkErrorsContains(ImmutableList.of(snippet), runner);
  }

  /**
   * Runs {@code runner}, which must fail, and checks that each snippet occurs in some error
   * message.
   *
   * @throws CompilationFailedException always rethrown after the check, so callers can inspect it
   */
  public static void checkErrorsContains(Collection<String> snippets, FailingRunner runner)
      throws CompilationFailedException {
    DiagnosticsChecker handler = new DiagnosticsChecker();
    try {
      runner.run(handler);
      // fail() throws AssertionError, which is not caught by the catch clause below.
      fail("Failure expected");
    } catch (CompilationFailedException e) {
      checkContains(snippets, handler.errors);
      throw e;
    }
  }

  /**
   * Runs {@code runner}, which must fail, and then hands the collected diagnostics to {@code
   * checker}.
   *
   * @throws CompilationFailedException always rethrown after the check, so callers can inspect it
   */
  public static void checkErrorDiagnostics(
      Consumer<DiagnosticsChecker> checker, FailingRunner runner)
      throws CompilationFailedException {
    DiagnosticsChecker handler = new DiagnosticsChecker();
    try {
      runner.run(handler);
      fail("Failure expected");
    } catch (CompilationFailedException e) {
      checker.accept(handler);
      throw e;
    }
  }

  /** Runs {@code runner} and then hands the collected diagnostics to {@code checker}. */
  public static void checkDiagnostics(Consumer<DiagnosticsChecker> checker, FailingRunner runner)
      throws CompilationFailedException {
    DiagnosticsChecker handler = new DiagnosticsChecker();
    runner.run(handler);
    checker.accept(handler);
  }

  /** Runs {@code runner} and checks that some warning message contains {@code snippet}. */
  public static void checkWarningsContains(String snippet, FailingRunner runner)
      throws CompilationFailedException {
    DiagnosticsChecker handler = new DiagnosticsChecker();
    runner.run(handler);
    checkContains(snippet, handler.warnings);
  }

  /** Runs {@code runner} and checks that some info message contains {@code snippet}. */
  public static void checkInfosContains(String snippet, FailingRunner runner)
      throws CompilationFailedException {
    DiagnosticsChecker handler = new DiagnosticsChecker();
    runner.run(handler);
    checkContains(snippet, handler.infos);
  }

  /**
   * Checks the origin, start position and message of a diagnostic.
   *
   * @param originChecker applied to the diagnostic's origin; skipped when {@code null}
   * @param lineStart expected start line; not checked when {@code <= 0}
   * @param columnStart expected start column; not checked when {@code <= 0}
   * @param messageParts substrings that must all occur in the diagnostic message
   * @return {@code diagnostic}, for chaining
   */
  public static Diagnostic checkDiagnostic(Diagnostic diagnostic, Consumer<Origin> originChecker,
      int lineStart, int columnStart, String... messageParts) {
    if (originChecker != null) {
      originChecker.accept(diagnostic.getOrigin());
    }
    // Only require a textual position when the caller actually asks for a line/column check.
    if (lineStart > 0 || columnStart > 0) {
      TextPosition position = getStartPosition(diagnostic);
      if (lineStart > 0) {
        assertEquals(lineStart, position.getLine());
      }
      if (columnStart > 0) {
        assertEquals(columnStart, position.getColumn());
      }
    }
    checkMessageContains(diagnostic, messageParts);
    return diagnostic;
  }

  /**
   * Checks the origin and message of a diagnostic whose position must be {@link
   * Position#UNKNOWN}.
   *
   * @param originChecker applied to the diagnostic's origin; skipped when {@code null}
   * @param messageParts substrings that must all occur in the diagnostic message
   * @return {@code diagnostic}, for chaining
   */
  public static Diagnostic checkDiagnostic(Diagnostic diagnostic, Consumer<Origin> originChecker,
      String... messageParts) {
    if (originChecker != null) {
      originChecker.accept(diagnostic.getOrigin());
    }
    assertEquals(Position.UNKNOWN, diagnostic.getPosition());
    checkMessageContains(diagnostic, messageParts);
    return diagnostic;
  }

  /**
   * Returns the start of the diagnostic's position: the range start for a {@link TextRange}, or
   * the position itself for a {@link TextPosition}. Fails the test for any other position kind.
   */
  private static TextPosition getStartPosition(Diagnostic diagnostic) {
    Position position = diagnostic.getPosition();
    if (position instanceof TextRange) {
      return ((TextRange) position).getStart();
    }
    assertTrue(
        "Expected a text position for diagnostic '"
            + diagnostic.getDiagnosticMessage()
            + "', got: "
            + position,
        position instanceof TextPosition);
    return (TextPosition) position;
  }

  /** Asserts that the diagnostic message contains every one of {@code messageParts}. */
  private static void checkMessageContains(Diagnostic diagnostic, String... messageParts) {
    String message = diagnostic.getDiagnosticMessage();
    for (String part : messageParts) {
      assertTrue(message + " doesn't contain \"" + part + "\"", message.contains(part));
    }
  }

  /**
   * Origin check that expects a {@link PathOrigin} with the given path, or {@link
   * Origin#unknown()} when the path is {@code null}.
   */
  static class PathOriginChecker implements Consumer<Origin> {
    private final Path path;

    PathOriginChecker(Path path) {
      this.path = path;
    }

    @Override
    public void accept(Origin origin) {
      if (path == null) {
        assertSame(Origin.unknown(), origin);
        return;
      }
      // Fail with a readable message instead of a ClassCastException for non-path origins.
      assertTrue("Expected a PathOrigin, got: " + origin, origin instanceof PathOrigin);
      assertEquals(path, ((PathOrigin) origin).getPath());
    }
  }

  /**
   * Checks a diagnostic whose origin should be {@code path} (or unknown if {@code null}).
   *
   * @see #checkDiagnostic(Diagnostic, Consumer, int, int, String...)
   */
  public static Diagnostic checkDiagnostic(Diagnostic diagnostic, Path path,
      int lineStart, int columnStart, String... messageParts) {
    return checkDiagnostic(diagnostic, new PathOriginChecker(path), lineStart, columnStart,
        messageParts);
  }

  /** Checks the diagnostic at {@code index} in {@code diagnostics}. */
  public static Diagnostic checkDiagnostics(List<Diagnostic> diagnostics, int index, Path path,
      int lineStart, int columnStart, String... messageParts) {
    return checkDiagnostic(diagnostics.get(index), path, lineStart, columnStart, messageParts);
  }

  /** Asserts that {@code diagnostics} holds exactly one diagnostic, and checks it. */
  public static Diagnostic checkDiagnostics(List<Diagnostic> diagnostics, Path path,
      int lineStart, int columnStart, String... messageParts) {
    assertEquals(1, diagnostics.size());
    return checkDiagnostics(diagnostics, 0, path, lineStart, columnStart, messageParts);
  }
}
