Remove single target lookup cache due to inconsistency

We are updating instantiated classes during IR in the
AppInfoWithLivenessModifier that pulls out instantiated classes. As
such, since the lookup is dependent on instantiated classes, we cannot
trust the cache in the current form.

Bug: 151070908
Change-Id: I6de5daf969206992e75fd3325c32e0065994a7b1
diff --git a/src/main/java/com/android/tools/r8/graph/DexMethod.java b/src/main/java/com/android/tools/r8/graph/DexMethod.java
index b15ee67..80605e5 100644
--- a/src/main/java/com/android/tools/r8/graph/DexMethod.java
+++ b/src/main/java/com/android/tools/r8/graph/DexMethod.java
@@ -6,17 +6,12 @@
 import com.android.tools.r8.dex.IndexedItemCollection;
 import com.android.tools.r8.errors.CompilationError;
 import com.android.tools.r8.naming.NamingLens;
-import com.google.common.collect.Maps;
-import java.util.Map;
 
 public class DexMethod extends DexMember<DexEncodedMethod, DexMethod> {
 
   public final DexProto proto;
   public final DexString name;
 
-  // Caches used during processing.
-  private Map<DexType, DexEncodedMethod> singleTargetCache;
-
   DexMethod(DexType holder, DexProto proto, DexString name, boolean skipNameValidationForTesting) {
     super(holder);
     this.proto = proto;
@@ -174,21 +169,4 @@
     return name == dexItemFactory.deserializeLambdaMethodName
         && proto == dexItemFactory.deserializeLambdaMethodProto;
   }
-
-  synchronized public void setSingleVirtualMethodCache(
-      DexType receiverType, DexEncodedMethod method) {
-    if (singleTargetCache == null) {
-      singleTargetCache = Maps.newIdentityHashMap();
-    }
-    singleTargetCache.put(receiverType, method);
-  }
-
-  synchronized public boolean isSingleVirtualMethodCached(DexType receiverType) {
-    return singleTargetCache != null && singleTargetCache.containsKey(receiverType);
-  }
-
-  synchronized public DexEncodedMethod getSingleVirtualMethodCache(DexType receiverType) {
-    assert isSingleVirtualMethodCached(receiverType);
-    return singleTargetCache.get(receiverType);
-  }
 }
diff --git a/src/main/java/com/android/tools/r8/optimize/ClassAndMemberPublicizer.java b/src/main/java/com/android/tools/r8/optimize/ClassAndMemberPublicizer.java
index b319c6d..4816759 100644
--- a/src/main/java/com/android/tools/r8/optimize/ClassAndMemberPublicizer.java
+++ b/src/main/java/com/android/tools/r8/optimize/ClassAndMemberPublicizer.java
@@ -142,9 +142,7 @@
       lenseBuilder.add(encodedMethod.method);
       accessFlags.promoteToFinal();
       accessFlags.promoteToPublic();
-      // Although the current method became public, it surely has the single virtual target.
-      encodedMethod.method.setSingleVirtualMethodCache(
-          encodedMethod.method.holder, encodedMethod);
+      // The method just became public and is therefore not a library override.
       encodedMethod.setLibraryMethodOverride(OptionalBool.FALSE);
       return true;
     }
diff --git a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
index eef80a6..ed4c3e3 100644
--- a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
+++ b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
@@ -186,6 +186,9 @@
 
   final Set<DexType> instantiatedLambdas;
 
+  /* A cache to improve the lookup performance of lookupSingleVirtualTarget */
+  private final SingleTargetLookupCache singleTargetLookupCache = new SingleTargetLookupCache();
+
   // TODO(zerny): Clean up the constructors so we have just one.
   AppInfoWithLiveness(
       DirectMappedDexApplication application,
@@ -732,6 +735,10 @@
     return objectAllocationInfoCollection;
   }
 
+  void removeFromSingleTargetLookupCache(DexClass clazz) {
+    singleTargetLookupCache.removeInstantiatedType(clazz.type, this);
+  }
+
   private boolean assertNoItemRemoved(Collection<DexReference> items, Collection<DexType> types) {
     Set<DexType> typeSet = ImmutableSet.copyOf(types);
     for (DexReference item : items) {
@@ -1113,31 +1120,6 @@
     }
   }
 
-  private DexEncodedMethod validateSingleVirtualTarget(
-      DexEncodedMethod singleTarget, DexEncodedMethod resolutionResult) {
-    assert resolutionResult.isVirtualMethod();
-
-    if (singleTarget == null || singleTarget == DexEncodedMethod.SENTINEL) {
-      return null;
-    }
-
-    // Art978_virtual_interfaceTest correctly expects an IncompatibleClassChangeError exception
-    // at runtime.
-    if (isInvalidSingleVirtualTarget(singleTarget, resolutionResult)) {
-      return null;
-    }
-
-    return singleTarget;
-  }
-
-  private boolean isInvalidSingleVirtualTarget(
-      DexEncodedMethod singleTarget, DexEncodedMethod resolutionResult) {
-    assert resolutionResult.isVirtualMethod();
-    // Art978_virtual_interfaceTest correctly expects an IncompatibleClassChangeError exception
-    // at runtime.
-    return !singleTarget.accessFlags.isAtLeastAsVisibleAs(resolutionResult.accessFlags);
-  }
-
   /** For mapping invoke virtual instruction to single target method. */
   public DexEncodedMethod lookupSingleVirtualTarget(
       DexMethod method, DexType invocationContext, boolean isInterface) {
@@ -1175,25 +1157,87 @@
       // (it is either primitive or array).
       return null;
     }
+    DexClass initialResolutionHolder = definitionFor(method.holder);
+    if (initialResolutionHolder == null || initialResolutionHolder.isInterface() != isInterface) {
+      return null;
+    }
     DexClass refinedReceiverClass = definitionFor(refinedReceiverType);
     if (refinedReceiverClass == null) {
       // The refined receiver is not defined in the program and we cannot determine the target.
       return null;
     }
+    if (receiverLowerBoundType == null
+        && singleTargetLookupCache.hasCachedItem(refinedReceiverType, method)) {
+      DexEncodedMethod cachedItem =
+          singleTargetLookupCache.getCachedItem(refinedReceiverType, method);
+      return cachedItem;
+    }
     SingleResolutionResult resolution =
-        resolveMethod(method.holder, method, isInterface).asSingleResolution();
+        resolveMethod(initialResolutionHolder, method).asSingleResolution();
     if (resolution == null
         || !resolution.isAccessibleForVirtualDispatchFrom(invocationClass, this)) {
       return null;
     }
     // If the method is modeled, return the resolution.
+    DexEncodedMethod resolvedMethod = resolution.getResolvedMethod();
     if (modeledPredicate.isModeled(resolution.getResolvedHolder().type)) {
       if (resolution.getResolvedHolder().isFinal()
-          || (resolution.getResolvedMethod().isFinal()
-              && resolution.getResolvedMethod().accessFlags.isPublic())) {
-        return resolution.getResolvedMethod();
+          || (resolvedMethod.isFinal() && resolvedMethod.accessFlags.isPublic())) {
+        singleTargetLookupCache.addToCache(refinedReceiverType, method, resolvedMethod);
+        return resolvedMethod;
       }
     }
+    DexEncodedMethod exactTarget =
+        getMethodTargetFromExactRuntimeInformation(
+            refinedReceiverType, receiverLowerBoundType, resolution, refinedReceiverClass);
+    if (exactTarget != null) {
+      // We are not caching single targets here because the cache does not include the
+      // lower bound dimension.
+      return exactTarget == DexEncodedMethod.SENTINEL ? null : exactTarget;
+    }
+    if (refinedReceiverClass.isNotProgramClass()) {
+      // The refined receiver is not defined in the program and we cannot determine the target.
+      singleTargetLookupCache.addToCache(refinedReceiverType, method, null);
+      return null;
+    }
+    DexClass resolvedHolder = resolution.getResolvedHolder();
+    // TODO(b/148769279): Disable lookup single target on lambda's for now.
+    if (resolvedHolder.isInterface()
+        && resolvedHolder.isProgramClass()
+        && hasAnyInstantiatedLambdas(resolvedHolder.asProgramClass())) {
+      singleTargetLookupCache.addToCache(refinedReceiverType, method, null);
+      return null;
+    }
+    DexEncodedMethod singleMethodTarget = null;
+    DexProgramClass refinedLowerBound = null;
+    if (receiverLowerBoundType != null) {
+      DexClass refinedLowerBoundClass = definitionFor(receiverLowerBoundType.getClassType());
+      if (refinedLowerBoundClass != null) {
+        refinedLowerBound = refinedLowerBoundClass.asProgramClass();
+      }
+    }
+    LookupResultSuccess lookupResult =
+        resolution
+            .lookupVirtualDispatchTargets(
+                invocationClass, this, refinedReceiverClass.asProgramClass(), refinedLowerBound)
+            .asLookupResultSuccess();
+    if (lookupResult != null && !lookupResult.isIncomplete()) {
+      LookupTarget singleTarget = lookupResult.getSingleLookupTarget();
+      if (singleTarget != null && singleTarget.isMethodTarget()) {
+        singleMethodTarget = singleTarget.asMethodTarget().getMethod();
+      }
+    }
+    if (receiverLowerBoundType == null) {
+      singleTargetLookupCache.addToCache(refinedReceiverType, method, singleMethodTarget);
+    }
+    return singleMethodTarget;
+  }
+
+  private DexEncodedMethod getMethodTargetFromExactRuntimeInformation(
+      DexType refinedReceiverType,
+      ClassTypeLatticeElement receiverLowerBoundType,
+      SingleResolutionResult resolution,
+      DexClass refinedReceiverClass) {
     // If the lower-bound on the receiver type is the same as the upper-bound, then we have exact
     // runtime type information. In this case, the invoke will dispatch to the resolution result
     // from the runtime type of the receiver.
@@ -1204,7 +1248,7 @@
             resolution.lookupVirtualDispatchTarget(refinedReceiverClass.asProgramClass(), this);
         if (clazzAndMethod == null || isPinned(clazzAndMethod.getMethod().method)) {
           // TODO(b/150640456): We should maybe only consider program methods.
-          return null;
+          return DexEncodedMethod.SENTINEL;
         }
         return clazzAndMethod.getMethod();
       } else {
@@ -1212,56 +1256,16 @@
         // If we resolved to a method on the refined receiver in the library, then we report the
         // method as a single target as well. This is a bit iffy since the library could change
         // implementation, but we use this for library modelling.
-        DexEncodedMethod targetOnReceiver = refinedReceiverClass.lookupVirtualMethod(method);
-        if (targetOnReceiver != null
-            && isOverriding(resolution.getResolvedMethod(), targetOnReceiver)) {
+        DexEncodedMethod resolvedMethod = resolution.getResolvedMethod();
+        DexEncodedMethod targetOnReceiver =
+            refinedReceiverClass.lookupVirtualMethod(resolvedMethod.method);
+        if (targetOnReceiver != null && isOverriding(resolvedMethod, targetOnReceiver)) {
           return targetOnReceiver;
         }
-        return null;
+        return DexEncodedMethod.SENTINEL;
       }
     }
-    if (refinedReceiverClass.isNotProgramClass()) {
-      // The refined receiver is not defined in the program and we cannot determine the target.
-      return null;
-    }
-    DexClass resolvedHolder = resolution.getResolvedHolder();
-    // TODO(b/148769279): Disable lookup single target on lambda's for now.
-    if (resolvedHolder.isInterface()
-        && resolvedHolder.isProgramClass()
-        && hasAnyInstantiatedLambdas(resolvedHolder.asProgramClass())) {
-      return null;
-    }
-
-    if (method.isSingleVirtualMethodCached(refinedReceiverType)) {
-      return method.getSingleVirtualMethodCache(refinedReceiverType);
-    }
-
-    DexProgramClass refinedLowerBound = null;
-    if (receiverLowerBoundType != null) {
-      assert receiverLowerBoundType.isClassType();
-      DexClass refinedLowerBoundClass = definitionFor(receiverLowerBoundType.getClassType());
-      if (refinedLowerBoundClass != null) {
-        refinedLowerBound = refinedLowerBoundClass.asProgramClass();
-      }
-    }
-
-    LookupResultSuccess lookupResult =
-        resolution
-            .lookupVirtualDispatchTargets(
-                invocationClass, this, refinedReceiverClass.asProgramClass(), refinedLowerBound)
-            .asLookupResultSuccess();
-
-    if (lookupResult == null || lookupResult.isIncomplete()) {
-      return null;
-    }
-
-    LookupTarget singleTarget = lookupResult.getSingleLookupTarget();
-    DexEncodedMethod singleMethodTarget = null;
-    if (singleTarget != null && singleTarget.isMethodTarget()) {
-      singleMethodTarget = singleTarget.asMethodTarget().getMethod();
-    }
-    method.setSingleVirtualMethodCache(refinedReceiverType, singleMethodTarget);
-    return singleMethodTarget;
+    return null;
   }
 
   public AppInfoWithLiveness withSwitchMaps(Map<DexField, Int2ReferenceMap<DexField>> switchMaps) {
@@ -1276,12 +1280,6 @@
     return new AppInfoWithLiveness(this, switchMaps, enumValueInfoMaps);
   }
 
-  public void forEachLiveProgramClass(Consumer<DexProgramClass> fn) {
-    for (DexType type : liveTypes) {
-      fn.accept(definitionFor(type).asProgramClass());
-    }
-  }
-
   /**
    * Visit all class definitions of classpath classes that are referenced in the compilation unit.
    *
diff --git a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLivenessModifier.java b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLivenessModifier.java
index 75c27fd..99f2d84 100644
--- a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLivenessModifier.java
+++ b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLivenessModifier.java
@@ -36,8 +36,11 @@
     // Instantiated classes.
     ObjectAllocationInfoCollectionImpl objectAllocationInfoCollection =
         appInfo.getMutableObjectAllocationInfoCollection();
-    noLongerInstantiatedClasses.forEach(objectAllocationInfoCollection::markNoLongerInstantiated);
-
+    noLongerInstantiatedClasses.forEach(
+        clazz -> {
+          objectAllocationInfoCollection.markNoLongerInstantiated(clazz);
+          appInfo.removeFromSingleTargetLookupCache(clazz);
+        });
     // Written fields.
     FieldAccessInfoCollectionImpl fieldAccessInfoCollection =
         appInfo.getMutableFieldAccessInfoCollection();
diff --git a/src/main/java/com/android/tools/r8/shaking/SingleTargetLookupCache.java b/src/main/java/com/android/tools/r8/shaking/SingleTargetLookupCache.java
new file mode 100644
index 0000000..1691758
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/shaking/SingleTargetLookupCache.java
@@ -0,0 +1,58 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.shaking;
+
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexType;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+public class SingleTargetLookupCache {
+
+  private Map<DexType, Map<DexMethod, DexEncodedMethod>> cache = new ConcurrentHashMap<>();
+
+  public void addToCache(DexType refinedReceiverType, DexMethod method, DexEncodedMethod target) {
+    assert target != DexEncodedMethod.SENTINEL;
+    Map<DexMethod, DexEncodedMethod> methodCache =
+        cache.computeIfAbsent(refinedReceiverType, ignored -> new ConcurrentHashMap<>());
+    target = target == null ? DexEncodedMethod.SENTINEL : target;
+    assert methodCache.getOrDefault(method, target) == target;
+    methodCache.putIfAbsent(method, target);
+  }
+
+  public void removeInstantiatedType(DexType instantiatedType, AppInfoWithLiveness appInfo) {
+    // Remove all types in the hierarchy related to this type.
+    cache.remove(instantiatedType);
+    DexClass clazz = appInfo.definitionFor(instantiatedType);
+    if (clazz == null) {
+      return;
+    }
+    appInfo.forEachSuperType(clazz, (type, ignore) -> cache.remove(type));
+    appInfo.subtypes(instantiatedType).forEach(cache::remove);
+  }
+
+  public DexEncodedMethod getCachedItem(DexType receiverType, DexMethod method) {
+    Map<DexMethod, DexEncodedMethod> cachedMethods = cache.get(receiverType);
+    if (cachedMethods == null) {
+      return null;
+    }
+    DexEncodedMethod target = cachedMethods.get(method);
+    return target == DexEncodedMethod.SENTINEL ? null : target;
+  }
+
+  public boolean hasCachedItem(DexType receiverType, DexMethod method) {
+    Map<DexMethod, DexEncodedMethod> cachedMethods = cache.get(receiverType);
+    if (cachedMethods == null) {
+      return false;
+    }
+    return cachedMethods.containsKey(method);
+  }
+
+  public void clear() {
+    cache = new ConcurrentHashMap<>();
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/resolution/virtualtargets/InstantiatedLowerBoundTest.java b/src/test/java/com/android/tools/r8/resolution/singletarget/InstantiatedLowerBoundTest.java
similarity index 98%
rename from src/test/java/com/android/tools/r8/resolution/virtualtargets/InstantiatedLowerBoundTest.java
rename to src/test/java/com/android/tools/r8/resolution/singletarget/InstantiatedLowerBoundTest.java
index 6001370..deec1e3 100644
--- a/src/test/java/com/android/tools/r8/resolution/virtualtargets/InstantiatedLowerBoundTest.java
+++ b/src/test/java/com/android/tools/r8/resolution/singletarget/InstantiatedLowerBoundTest.java
@@ -2,7 +2,7 @@
 // for details. All rights reserved. Use of this source code is governed by a
 // BSD-style license that can be found in the LICENSE file.
 
-package com.android.tools.r8.resolution.virtualtargets;
+package com.android.tools.r8.resolution.singletarget;
 
 import static junit.framework.TestCase.assertEquals;
 import static junit.framework.TestCase.assertNotNull;
diff --git a/src/test/java/com/android/tools/r8/resolution/singletarget/SuccessAndInvalidLookupTest.java b/src/test/java/com/android/tools/r8/resolution/singletarget/SuccessAndInvalidLookupTest.java
new file mode 100644
index 0000000..870c9cc
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/resolution/singletarget/SuccessAndInvalidLookupTest.java
@@ -0,0 +1,95 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.resolution.singletarget;
+
+import static junit.framework.TestCase.assertEquals;
+import static junit.framework.TestCase.assertNotNull;
+import static junit.framework.TestCase.assertNull;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import java.util.ArrayList;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class SuccessAndInvalidLookupTest extends TestBase {
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withNoneRuntime().build();
+  }
+
+  public SuccessAndInvalidLookupTest(TestParameters parameters) {
+    // Empty to satisfy construction of none-runtime.
+  }
+
+  @Test
+  public void testSingleTargetWithInvalidInvokeInterfaceInvoke() throws Exception {
+    AppView<AppInfoWithLiveness> appView =
+        computeAppViewWithLiveness(
+            buildClasses(I.class, A.class, Main.class).build(),
+            factory -> new ArrayList<>(buildKeepRuleForClassAndMethods(Main.class, factory)));
+    AppInfoWithLiveness appInfo = appView.appInfo();
+    DexType typeMain = buildType(Main.class, appInfo.dexItemFactory());
+    DexType typeA = buildType(A.class, appInfo.dexItemFactory());
+    DexMethod fooA = buildNullaryVoidMethod(A.class, "foo", appInfo.dexItemFactory());
+    DexEncodedMethod singleTarget =
+        appInfo.lookupSingleVirtualTarget(fooA, typeMain, false, t -> false, typeA, null);
+    assertNotNull(singleTarget);
+    assertEquals(fooA, singleTarget.method);
+    DexEncodedMethod invalidSingleTarget =
+        appInfo.lookupSingleVirtualTarget(fooA, typeMain, true, t -> false, typeA, null);
+    assertNull(invalidSingleTarget);
+  }
+
+  @Test
+  public void testSingleTargetWithInvalidInvokeVirtualInvoke() throws Exception {
+    AppView<AppInfoWithLiveness> appView =
+        computeAppViewWithLiveness(
+            buildClasses(I.class, A.class, Main.class).build(),
+            factory -> new ArrayList<>(buildKeepRuleForClassAndMethods(Main.class, factory)));
+    AppInfoWithLiveness appInfo = appView.appInfo();
+    DexType typeMain = buildType(Main.class, appInfo.dexItemFactory());
+    DexType typeA = buildType(I.class, appInfo.dexItemFactory());
+    DexMethod fooI = buildNullaryVoidMethod(I.class, "foo", appInfo.dexItemFactory());
+    DexMethod fooA = buildNullaryVoidMethod(A.class, "foo", appInfo.dexItemFactory());
+    DexEncodedMethod singleTarget =
+        appInfo.lookupSingleVirtualTarget(fooI, typeMain, true, t -> false, typeA, null);
+    assertNotNull(singleTarget);
+    assertEquals(fooA, singleTarget.method);
+    DexEncodedMethod invalidSingleTarget =
+        appInfo.lookupSingleVirtualTarget(fooI, typeMain, false, t -> false, typeA, null);
+    assertNull(invalidSingleTarget);
+  }
+
+  public interface I {
+
+    void foo();
+  }
+
+  public static class A implements I {
+
+    @Override
+    public void foo() {
+      System.out.println("A.foo");
+    }
+  }
+
+  public static class Main {
+
+    public static void main(String[] args) {
+      new A();
+    }
+  }
+}