// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.

package com.android.tools.r8.graph;

import com.android.tools.r8.errors.Unreachable;
import com.google.common.collect.Sets;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Predicate;

/**
 * Holds whole program information about the usage of a given field.
 *
 * <p>The information is generated by the {@link com.android.tools.r8.shaking.Enqueuer}.
 */
public class FieldAccessInfoImpl implements FieldAccessInfo {

  public static final FieldAccessInfoImpl MISSING_FIELD_ACCESS_INFO = new FieldAccessInfoImpl(null);

  // Bit flags stored in `flags`. These are constants: reassigning them would silently change the
  // meaning of already-recorded flag bits, hence `final`.
  public static final int FLAG_IS_READ_FROM_METHOD_HANDLE = 1 << 0;
  public static final int FLAG_IS_WRITTEN_FROM_METHOD_HANDLE = 1 << 1;
  public static final int FLAG_HAS_REFLECTIVE_ACCESS = 1 << 2;

  // A direct reference to the definition of the field.
  private final DexField field;

  // If this field is accessed from a method handle or has a reflective access.
  private int flags;

  // Maps every direct and indirect reference in a read-context to the set of methods in which that
  // reference appears.
  private Map<DexField, Set<DexEncodedMethod>> readsWithContexts;

  // Maps every direct and indirect reference in a write-context to the set of methods in which that
  // reference appears.
  private Map<DexField, Set<DexEncodedMethod>> writesWithContexts;

  public FieldAccessInfoImpl(DexField field) {
    this.field = field;
  }

  /**
   * Merges the contexts of all direct and indirect reads (and, separately, writes) into a single
   * entry keyed by {@link #field}. Afterwards each access map has at most one entry.
   */
  void flattenAccessContexts() {
    flattenAccessContexts(readsWithContexts);
    flattenAccessContexts(writesWithContexts);
  }

  private void flattenAccessContexts(Map<DexField, Set<DexEncodedMethod>> accessesWithContexts) {
    if (accessesWithContexts != null) {
      // Reuse the set of the direct reference (if any) as the target of the merge.
      Set<DexEncodedMethod> flattenedAccessContexts =
          accessesWithContexts.computeIfAbsent(field, ignore -> Sets.newIdentityHashSet());
      accessesWithContexts.forEach(
          (access, contexts) -> {
            if (access != field) {
              flattenedAccessContexts.addAll(contexts);
            }
          });
      accessesWithContexts.clear();
      // Only keep the merged entry when there is at least one context, so that an access without
      // contexts does not make the field look read/written.
      if (!flattenedAccessContexts.isEmpty()) {
        accessesWithContexts.put(field, flattenedAccessContexts);
      }
      assert accessesWithContexts.size() <= 1;
    }
  }

  @Override
  public FieldAccessInfoImpl asMutable() {
    return this;
  }

  @Override
  public DexField getField() {
    return field;
  }

  @Override
  public int getNumberOfReadContexts() {
    return getNumberOfAccessContexts(readsWithContexts);
  }

  @Override
  public int getNumberOfWriteContexts() {
    return getNumberOfAccessContexts(writesWithContexts);
  }

  /**
   * Returns the number of contexts in a flattened access map.
   *
   * <p>A missing or empty map has zero contexts. An empty map is a legitimate flattened state: see
   * {@link #flattenAccessContexts(Map)}.
   *
   * @throws Unreachable if the map has more than one entry, i.e., it has not been flattened.
   */
  private int getNumberOfAccessContexts(Map<DexField, Set<DexEncodedMethod>> accessesWithContexts) {
    if (accessesWithContexts == null || accessesWithContexts.isEmpty()) {
      return 0;
    }
    if (accessesWithContexts.size() == 1) {
      return accessesWithContexts.values().iterator().next().size();
    }
    throw new Unreachable("Should only be querying the number of access contexts after flattening");
  }

  /**
   * Returns the single method that reads this field, or null if there is no such unique method.
   * This only inspects the read map when it has exactly one entry.
   */
  @Override
  public DexEncodedMethod getUniqueReadContext() {
    if (readsWithContexts != null && readsWithContexts.size() == 1) {
      Set<DexEncodedMethod> contexts = readsWithContexts.values().iterator().next();
      if (contexts.size() == 1) {
        return contexts.iterator().next();
      }
    }
    return null;
  }

  /** Reports every field reference, other than {@link #field}, that is read or written, once. */
  @Override
  public void forEachIndirectAccess(Consumer<DexField> consumer) {
    // There can be indirect reads and writes of the same field reference, so we need to keep track
    // of the previously-seen indirect accesses to avoid reporting duplicates.
    Set<DexField> visited = Sets.newIdentityHashSet();
    forEachAccessInMap(
        readsWithContexts, access -> access != field && visited.add(access), consumer);
    forEachAccessInMap(
        writesWithContexts, access -> access != field && visited.add(access), consumer);
  }

  /** Passes each key of {@code accessesWithContexts} that satisfies {@code predicate} on. */
  private static void forEachAccessInMap(
      Map<DexField, Set<DexEncodedMethod>> accessesWithContexts,
      Predicate<DexField> predicate,
      Consumer<DexField> consumer) {
    if (accessesWithContexts != null) {
      accessesWithContexts.forEach(
          (access, contexts) -> {
            if (predicate.test(access)) {
              consumer.accept(access);
            }
          });
    }
  }

  /**
   * Reports every field reference, other than {@link #field}, together with the union of the read
   * and write contexts in which it appears.
   */
  @Override
  public void forEachIndirectAccessWithContexts(
      BiConsumer<DexField, Set<DexEncodedMethod>> consumer) {
    Map<DexField, Set<DexEncodedMethod>> indirectAccessesWithContexts = new IdentityHashMap<>();
    extendAccessesWithContexts(
        indirectAccessesWithContexts, access -> access != field, readsWithContexts);
    extendAccessesWithContexts(
        indirectAccessesWithContexts, access -> access != field, writesWithContexts);
    indirectAccessesWithContexts.forEach(consumer);
  }

  /**
   * Adds the entries of {@code extension} whose key satisfies {@code predicate} into {@code
   * accessesWithContexts}, unioning the context sets of equal keys. Fresh sets are allocated, so
   * the sets in {@code extension} are never aliased.
   */
  private static void extendAccessesWithContexts(
      Map<DexField, Set<DexEncodedMethod>> accessesWithContexts,
      Predicate<DexField> predicate,
      Map<DexField, Set<DexEncodedMethod>> extension) {
    if (extension != null) {
      extension.forEach(
          (access, contexts) -> {
            if (predicate.test(access)) {
              accessesWithContexts
                  .computeIfAbsent(access, ignore -> Sets.newIdentityHashSet())
                  .addAll(contexts);
            }
          });
    }
  }

  @Override
  public void forEachReadContext(Consumer<DexEncodedMethod> consumer) {
    forEachAccessContext(readsWithContexts, consumer);
  }

  @Override
  public void forEachWriteContext(Consumer<DexEncodedMethod> consumer) {
    forEachAccessContext(writesWithContexts, consumer);
  }

  /** Reports each distinct method that appears as a context in {@code accessesWithContexts}. */
  private void forEachAccessContext(
      Map<DexField, Set<DexEncodedMethod>> accessesWithContexts,
      Consumer<DexEncodedMethod> consumer) {
    if (accessesWithContexts == null) {
      return;
    }
    // The same method can access the field through several (direct and indirect) references, so
    // it can occur in several context sets. Track the reported methods to avoid duplicates.
    Set<DexEncodedMethod> visited = Sets.newIdentityHashSet();
    for (Set<DexEncodedMethod> encodedAccessContexts : accessesWithContexts.values()) {
      for (DexEncodedMethod encodedAccessContext : encodedAccessContexts) {
        if (visited.add(encodedAccessContext)) {
          consumer.accept(encodedAccessContext);
        }
      }
    }
  }

  @Override
  public boolean hasReflectiveAccess() {
    return (flags & FLAG_HAS_REFLECTIVE_ACCESS) != 0;
  }

  public void setHasReflectiveAccess() {
    flags |= FLAG_HAS_REFLECTIVE_ACCESS;
  }

  /** Returns true if this field is read by the program. */
  @Override
  public boolean isRead() {
    return readsWithContexts != null && !readsWithContexts.isEmpty();
  }

  @Override
  public boolean isReadFromMethodHandle() {
    return (flags & FLAG_IS_READ_FROM_METHOD_HANDLE) != 0;
  }

  public void setReadFromMethodHandle() {
    flags |= FLAG_IS_READ_FROM_METHOD_HANDLE;
  }

  /** Returns true if {@code method} is the unique method that reads this field. */
  @Override
  public boolean isReadOnlyIn(DexEncodedMethod method) {
    assert isRead();
    assert method != null;
    DexEncodedMethod uniqueReadContext = getUniqueReadContext();
    return uniqueReadContext != null && uniqueReadContext == method;
  }

  /** Returns true if this field is written by the program. */
  @Override
  public boolean isWritten() {
    return writesWithContexts != null && !writesWithContexts.isEmpty();
  }

  @Override
  public boolean isWrittenFromMethodHandle() {
    return (flags & FLAG_IS_WRITTEN_FROM_METHOD_HANDLE) != 0;
  }

  public void setWrittenFromMethodHandle() {
    flags |= FLAG_IS_WRITTEN_FROM_METHOD_HANDLE;
  }

  /** Returns true if this field is written by a method for which {@code predicate} returns true. */
  @Override
  public boolean isWrittenInMethodSatisfying(Predicate<DexEncodedMethod> predicate) {
    if (writesWithContexts != null) {
      for (Set<DexEncodedMethod> encodedWriteContexts : writesWithContexts.values()) {
        for (DexEncodedMethod encodedWriteContext : encodedWriteContexts) {
          if (predicate.test(encodedWriteContext)) {
            return true;
          }
        }
      }
    }
    return false;
  }

  /**
   * Returns true if this field is only written by methods for which {@code predicate} returns true.
   * Vacuously true when the field is never written.
   */
  @Override
  public boolean isWrittenOnlyInMethodSatisfying(Predicate<DexEncodedMethod> predicate) {
    if (writesWithContexts != null) {
      for (Set<DexEncodedMethod> encodedWriteContexts : writesWithContexts.values()) {
        for (DexEncodedMethod encodedWriteContext : encodedWriteContexts) {
          if (!predicate.test(encodedWriteContext)) {
            return false;
          }
        }
      }
    }
    return true;
  }

  /** Returns true if this field is written by a method in the program other than {@code method}. */
  @Override
  public boolean isWrittenOutside(DexEncodedMethod method) {
    if (writesWithContexts != null) {
      for (Set<DexEncodedMethod> encodedWriteContexts : writesWithContexts.values()) {
        for (DexEncodedMethod encodedWriteContext : encodedWriteContexts) {
          if (encodedWriteContext != method) {
            return true;
          }
        }
      }
    }
    return false;
  }

  /**
   * Records that {@code access} is read in {@code context}.
   *
   * @return true if this (access, context) pair was not recorded before.
   */
  public boolean recordRead(DexField access, DexEncodedMethod context) {
    if (readsWithContexts == null) {
      readsWithContexts = new IdentityHashMap<>();
    }
    return readsWithContexts
        .computeIfAbsent(access, ignore -> Sets.newIdentityHashSet())
        .add(context);
  }

  /**
   * Records that {@code access} is written in {@code context}.
   *
   * @return true if this (access, context) pair was not recorded before.
   */
  public boolean recordWrite(DexField access, DexEncodedMethod context) {
    if (writesWithContexts == null) {
      writesWithContexts = new IdentityHashMap<>();
    }
    return writesWithContexts
        .computeIfAbsent(access, ignore -> Sets.newIdentityHashSet())
        .add(context);
  }

  public void clearReads() {
    readsWithContexts = null;
  }

  public void clearWrites() {
    writesWithContexts = null;
  }

  /**
   * Returns a new instance in which the field, every access reference and every context has been
   * mapped through {@code lens}. Flags are copied unchanged.
   */
  public FieldAccessInfoImpl rewrittenWithLens(DexDefinitionSupplier definitions, GraphLense lens) {
    FieldAccessInfoImpl rewritten = new FieldAccessInfoImpl(lens.lookupField(field));
    rewritten.flags = flags;
    rewritten.readsWithContexts =
        rewriteAccessesWithContexts(readsWithContexts, definitions, lens);
    rewritten.writesWithContexts =
        rewriteAccessesWithContexts(writesWithContexts, definitions, lens);
    return rewritten;
  }

  /**
   * Maps every key and every context of {@code accessesWithContexts} through {@code lens}. Keys
   * that map to the same field have their context sets merged. Returns null for a null input.
   */
  private static Map<DexField, Set<DexEncodedMethod>> rewriteAccessesWithContexts(
      Map<DexField, Set<DexEncodedMethod>> accessesWithContexts,
      DexDefinitionSupplier definitions,
      GraphLense lens) {
    if (accessesWithContexts == null) {
      return null;
    }
    Map<DexField, Set<DexEncodedMethod>> rewrittenAccessesWithContexts = new IdentityHashMap<>();
    accessesWithContexts.forEach(
        (access, contexts) -> {
          Set<DexEncodedMethod> newContexts =
              rewrittenAccessesWithContexts.computeIfAbsent(
                  lens.lookupField(access), ignore -> Sets.newIdentityHashSet());
          for (DexEncodedMethod context : contexts) {
            newContexts.add(lens.mapDexEncodedMethod(context, definitions));
          }
        });
    return rewrittenAccessesWithContexts;
  }
}
