Add rewriting for class merging enqueuer extension.

Some attempts to get the bootstrap current equality test to work.

Bug: 167385752
Bug: 163311975
Change-Id: Iddc27260e8d06c5e412bd5aadc485ec59c412676
diff --git a/src/main/java/com/android/tools/r8/R8.java b/src/main/java/com/android/tools/r8/R8.java
index ab78b10..34b73a8 100644
--- a/src/main/java/com/android/tools/r8/R8.java
+++ b/src/main/java/com/android/tools/r8/R8.java
@@ -85,7 +85,6 @@
 import com.android.tools.r8.shaking.AnnotationRemover;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.shaking.ClassInitFieldSynthesizer;
-import com.android.tools.r8.shaking.ClassMergingEnqueuerExtension;
 import com.android.tools.r8.shaking.DefaultTreePrunerConfiguration;
 import com.android.tools.r8.shaking.DiscardedChecker;
 import com.android.tools.r8.shaking.Enqueuer;
@@ -99,6 +98,7 @@
 import com.android.tools.r8.shaking.ProguardConfigurationUtils;
 import com.android.tools.r8.shaking.RootSetBuilder;
 import com.android.tools.r8.shaking.RootSetBuilder.RootSet;
+import com.android.tools.r8.shaking.RuntimeTypeCheckInfo;
 import com.android.tools.r8.shaking.StaticClassMerger;
 import com.android.tools.r8.shaking.TreePruner;
 import com.android.tools.r8.shaking.TreePrunerConfiguration;
@@ -324,8 +324,8 @@
       timing.begin("Strip unused code");
       Set<DexType> classesToRetainInnerClassAttributeFor = null;
       Set<DexType> missingClasses = null;
-      ClassMergingEnqueuerExtension classMergingEnqueuerExtension =
-          new ClassMergingEnqueuerExtension(appView.dexItemFactory());
+      RuntimeTypeCheckInfo.Builder classMergingEnqueuerExtensionBuilder =
+          new RuntimeTypeCheckInfo.Builder(appView.dexItemFactory());
       try {
         // TODO(b/154849103): Find a better way to determine missing classes.
         missingClasses = new SubtypingInfo(appView).getMissingClasses();
@@ -379,7 +379,8 @@
                 executorService,
                 appView,
                 subtypingInfo,
-                classMergingEnqueuerExtension);
+                classMergingEnqueuerExtensionBuilder);
+
         assert appView.rootSet().verifyKeptFieldsAreAccessedAndLive(appViewWithLiveness.appInfo());
         assert appView.rootSet().verifyKeptMethodsAreTargetedAndLive(appViewWithLiveness.appInfo());
         assert appView.rootSet().verifyKeptTypesAreLive(appViewWithLiveness.appInfo());
@@ -506,6 +507,7 @@
       boolean isKotlinLibraryCompilationWithInlinePassThrough =
           options.enableCfByteCodePassThrough && appView.hasCfByteCodePassThroughMethods();
 
+      RuntimeTypeCheckInfo runtimeTypeCheckInfo = classMergingEnqueuerExtensionBuilder.build();
       if (!isKotlinLibraryCompilationWithInlinePassThrough
           && options.getProguardConfiguration().isOptimizing()) {
         if (options.enableStaticClassMerging) {
@@ -529,6 +531,7 @@
           if (lens != null) {
             appView.setVerticallyMergedClasses(verticalClassMerger.getMergedClasses());
             appView.rewriteWithLens(lens);
+            runtimeTypeCheckInfo = runtimeTypeCheckInfo.rewriteWithLens(lens);
           }
           timing.end();
         }
@@ -564,7 +567,7 @@
           timing.begin("HorizontalClassMerger");
           HorizontalClassMerger merger =
               new HorizontalClassMerger(
-                  appViewWithLiveness, mainDexTracingResult, classMergingEnqueuerExtension);
+                  appViewWithLiveness, mainDexTracingResult, runtimeTypeCheckInfo);
           DirectMappedDexApplication.Builder appBuilder =
               appView.appInfo().app().asDirect().builder();
           HorizontalClassMergerGraphLens lens = merger.run(appBuilder);
@@ -572,12 +575,13 @@
             DirectMappedDexApplication app = appBuilder.build();
             appView.removePrunedClasses(app, appView.horizontallyMergedClasses().getSources());
             appView.rewriteWithLens(lens);
+
+            // Only required for class merging, clear instance to save memory.
+            runtimeTypeCheckInfo = null;
           }
           timing.end();
         }
 
-        // Only required for class merging, clear instance to save memory.
-        classMergingEnqueuerExtension = null;
       }
 
       // None of the optimizations above should lead to the creation of type lattice elements.
@@ -998,7 +1002,7 @@
       ExecutorService executorService,
       AppView<AppInfoWithClassHierarchy> appView,
       SubtypingInfo subtypingInfo,
-      ClassMergingEnqueuerExtension classMergingEnqueuerExtension)
+      RuntimeTypeCheckInfo.Builder classMergingEnqueuerExtensionBuilder)
       throws ExecutionException {
     Enqueuer enqueuer = EnqueuerFactory.createForInitialTreeShaking(appView, subtypingInfo);
     enqueuer.setAnnotationRemoverBuilder(annotationRemoverBuilder);
@@ -1012,7 +1016,7 @@
     }
 
     if (options.isClassMergingExtensionRequired()) {
-      classMergingEnqueuerExtension.attach(enqueuer);
+      classMergingEnqueuerExtensionBuilder.attach(enqueuer);
     }
 
     AppView<AppInfoWithLiveness> appViewWithLiveness =
diff --git a/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java b/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
index 31ea06f..8a44fd0 100644
--- a/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
+++ b/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
@@ -465,7 +465,8 @@
     if (parameters.isEmpty()) {
       return this;
     }
-    List<ExtraParameter> newExtraParameters = new ArrayList<>();
+    List<ExtraParameter> newExtraParameters =
+        new ArrayList<>(extraParameters.size() + parameters.size());
     newExtraParameters.addAll(extraParameters);
     newExtraParameters.addAll(parameters);
     return new RewrittenPrototypeDescription(
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
index 4536c09..763827a 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
@@ -30,9 +30,9 @@
 import com.android.tools.r8.horizontalclassmerging.policies.SameNestHost;
 import com.android.tools.r8.horizontalclassmerging.policies.SameParentClass;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
-import com.android.tools.r8.shaking.ClassMergingEnqueuerExtension;
 import com.android.tools.r8.shaking.FieldAccessInfoCollectionModifier;
 import com.android.tools.r8.shaking.MainDexTracingResult;
+import com.android.tools.r8.shaking.RuntimeTypeCheckInfo;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
 import java.util.ArrayList;
@@ -48,7 +48,7 @@
   public HorizontalClassMerger(
       AppView<AppInfoWithLiveness> appView,
       MainDexTracingResult mainDexTracingResult,
-      ClassMergingEnqueuerExtension classMergingEnqueuerExtension) {
+      RuntimeTypeCheckInfo runtimeTypeCheckInfo) {
     this.appView = appView;
     assert appView.options().enableInlining;
 
@@ -67,7 +67,7 @@
             new NoStaticClassInitializer(),
             new NoKeepRules(appView),
             new NotVerticallyMergedIntoSubtype(appView),
-            new NoRuntimeTypeChecks(classMergingEnqueuerExtension),
+            new NoRuntimeTypeChecks(runtimeTypeCheckInfo),
             new NotEntryPoint(appView.dexItemFactory()),
             new PreventMergeIntoMainDex(appView, mainDexTracingResult),
             new SameParentClass(),
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoRuntimeTypeChecks.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoRuntimeTypeChecks.java
index 23f10b2..010fee8 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoRuntimeTypeChecks.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/NoRuntimeTypeChecks.java
@@ -6,18 +6,18 @@
 
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.horizontalclassmerging.SingleClassPolicy;
-import com.android.tools.r8.shaking.ClassMergingEnqueuerExtension;
+import com.android.tools.r8.shaking.RuntimeTypeCheckInfo;
 
 public class NoRuntimeTypeChecks extends SingleClassPolicy {
-  private final ClassMergingEnqueuerExtension classMergingEnqueuerExtension;
+  private final RuntimeTypeCheckInfo runtimeTypeCheckInfo;
 
-  public NoRuntimeTypeChecks(ClassMergingEnqueuerExtension classMergingEnqueuerExtension) {
-    this.classMergingEnqueuerExtension = classMergingEnqueuerExtension;
+  public NoRuntimeTypeChecks(RuntimeTypeCheckInfo runtimeTypeCheckInfo) {
+    this.runtimeTypeCheckInfo = runtimeTypeCheckInfo;
   }
 
   @Override
   public boolean canMerge(DexProgramClass clazz) {
     // We currently assume we only merge classes that implement the same set of interfaces.
-    return !classMergingEnqueuerExtension.isRuntimeCheckType(clazz);
+    return !runtimeTypeCheckInfo.isRuntimeCheckType(clazz);
   }
 }
diff --git a/src/main/java/com/android/tools/r8/shaking/ClassMergingEnqueuerExtension.java b/src/main/java/com/android/tools/r8/shaking/ClassMergingEnqueuerExtension.java
deleted file mode 100644
index 6a2151b..0000000
--- a/src/main/java/com/android/tools/r8/shaking/ClassMergingEnqueuerExtension.java
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
-// for details. All rights reserved. Use of this source code is governed by a
-// BSD-style license that can be found in the LICENSE file.
-
-package com.android.tools.r8.shaking;
-
-import com.android.tools.r8.graph.DexItemFactory;
-import com.android.tools.r8.graph.DexProgramClass;
-import com.android.tools.r8.graph.DexType;
-import com.android.tools.r8.graph.ProgramMethod;
-import com.android.tools.r8.graph.analysis.EnqueuerCheckCastAnalysis;
-import com.android.tools.r8.graph.analysis.EnqueuerExceptionGuardAnalysis;
-import com.android.tools.r8.graph.analysis.EnqueuerInstanceOfAnalysis;
-import com.google.common.collect.Sets;
-import java.util.Set;
-
-public class ClassMergingEnqueuerExtension
-    implements EnqueuerInstanceOfAnalysis,
-        EnqueuerCheckCastAnalysis,
-        EnqueuerExceptionGuardAnalysis {
-
-  private final Set<DexType> instanceOfTypes = Sets.newIdentityHashSet();
-  private final Set<DexType> checkCastTypes = Sets.newIdentityHashSet();
-  private final Set<DexType> exceptionGuardTypes = Sets.newIdentityHashSet();
-  private final DexItemFactory factory;
-
-  public ClassMergingEnqueuerExtension(DexItemFactory factory) {
-    this.factory = factory;
-  }
-
-  @Override
-  public void traceCheckCast(DexType type, ProgramMethod context) {
-    checkCastTypes.add(type.toBaseType(factory));
-  }
-
-  @Override
-  public void traceInstanceOf(DexType type, ProgramMethod context) {
-    instanceOfTypes.add(type.toBaseType(factory));
-  }
-
-  @Override
-  public void traceExceptionGuard(DexType guard, ProgramMethod context) {
-    exceptionGuardTypes.add(guard);
-  }
-
-  public boolean isCheckCastType(DexProgramClass clazz) {
-    return checkCastTypes.contains(clazz.type);
-  }
-
-  public boolean isInstanceOfType(DexProgramClass clazz) {
-    return instanceOfTypes.contains(clazz.type);
-  }
-
-  public boolean isExceptionGuardType(DexProgramClass clazz) {
-    return exceptionGuardTypes.contains(clazz.type);
-  }
-
-  public boolean isRuntimeCheckType(DexProgramClass clazz) {
-    return isInstanceOfType(clazz) || isCheckCastType(clazz) || isExceptionGuardType(clazz);
-  }
-
-  public void attach(Enqueuer enqueuer) {
-    enqueuer
-        .registerInstanceOfAnalysis(this)
-        .registerCheckCastAnalysis(this)
-        .registerExceptionGuardAnalysis(this);
-  }
-}
diff --git a/src/main/java/com/android/tools/r8/shaking/RuntimeTypeCheckInfo.java b/src/main/java/com/android/tools/r8/shaking/RuntimeTypeCheckInfo.java
new file mode 100644
index 0000000..ddfea43
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/shaking/RuntimeTypeCheckInfo.java
@@ -0,0 +1,95 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.shaking;
+
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.GraphLens;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.graph.analysis.EnqueuerCheckCastAnalysis;
+import com.android.tools.r8.graph.analysis.EnqueuerExceptionGuardAnalysis;
+import com.android.tools.r8.graph.analysis.EnqueuerInstanceOfAnalysis;
+import com.android.tools.r8.utils.SetUtils;
+import com.google.common.collect.Sets;
+import java.util.Set;
+
+public class RuntimeTypeCheckInfo {
+
+  private final Set<DexType> instanceOfTypes;
+  private final Set<DexType> checkCastTypes;
+  private final Set<DexType> exceptionGuardTypes;
+
+  public RuntimeTypeCheckInfo(
+      Set<DexType> instanceOfTypes, Set<DexType> checkCastTypes, Set<DexType> exceptionGuardTypes) {
+    this.instanceOfTypes = instanceOfTypes;
+    this.checkCastTypes = checkCastTypes;
+    this.exceptionGuardTypes = exceptionGuardTypes;
+  }
+
+  public static class Builder
+      implements EnqueuerInstanceOfAnalysis,
+          EnqueuerCheckCastAnalysis,
+          EnqueuerExceptionGuardAnalysis {
+    private final DexItemFactory factory;
+
+    private final Set<DexType> instanceOfTypes = Sets.newIdentityHashSet();
+    private final Set<DexType> checkCastTypes = Sets.newIdentityHashSet();
+    private final Set<DexType> exceptionGuardTypes = Sets.newIdentityHashSet();
+
+    public Builder(DexItemFactory factory) {
+      this.factory = factory;
+    }
+
+    public RuntimeTypeCheckInfo build() {
+      return new RuntimeTypeCheckInfo(instanceOfTypes, checkCastTypes, exceptionGuardTypes);
+    }
+
+    @Override
+    public void traceCheckCast(DexType type, ProgramMethod context) {
+      checkCastTypes.add(type.toBaseType(factory));
+    }
+
+    @Override
+    public void traceInstanceOf(DexType type, ProgramMethod context) {
+      instanceOfTypes.add(type.toBaseType(factory));
+    }
+
+    @Override
+    public void traceExceptionGuard(DexType guard, ProgramMethod context) {
+      exceptionGuardTypes.add(guard);
+    }
+
+    public void attach(Enqueuer enqueuer) {
+      enqueuer
+          .registerInstanceOfAnalysis(this)
+          .registerCheckCastAnalysis(this)
+          .registerExceptionGuardAnalysis(this);
+    }
+  }
+
+  public boolean isCheckCastType(DexProgramClass clazz) {
+    return checkCastTypes.contains(clazz.type);
+  }
+
+  public boolean isInstanceOfType(DexProgramClass clazz) {
+    return instanceOfTypes.contains(clazz.type);
+  }
+
+  public boolean isExceptionGuardType(DexProgramClass clazz) {
+    return exceptionGuardTypes.contains(clazz.type);
+  }
+
+  public boolean isRuntimeCheckType(DexProgramClass clazz) {
+    return isInstanceOfType(clazz) || isCheckCastType(clazz) || isExceptionGuardType(clazz);
+  }
+
+  public RuntimeTypeCheckInfo rewriteWithLens(GraphLens.NonIdentityGraphLens graphLens) {
+    return new RuntimeTypeCheckInfo(
+        SetUtils.mapIdentityHashSet(instanceOfTypes, graphLens::lookupType),
+        SetUtils.mapIdentityHashSet(checkCastTypes, graphLens::lookupType),
+        SetUtils.mapIdentityHashSet(exceptionGuardTypes, graphLens::lookupType));
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/VerticallyMergedClassDistinguishedByCheckCastTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/VerticallyMergedClassDistinguishedByCheckCastTest.java
new file mode 100644
index 0000000..9ecbf5a
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/VerticallyMergedClassDistinguishedByCheckCastTest.java
@@ -0,0 +1,71 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.classmerging.horizontal;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestParameters;
+import org.junit.Test;
+
+public class VerticallyMergedClassDistinguishedByCheckCastTest
+    extends HorizontalClassMergingTestBase {
+
+  public VerticallyMergedClassDistinguishedByCheckCastTest(
+      TestParameters parameters, boolean enableHorizontalClassMerging) {
+    super(parameters, enableHorizontalClassMerging);
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addOptionsModification(
+            options -> options.enableHorizontalClassMerging = enableHorizontalClassMerging)
+        .enableNeverClassInliningAnnotations()
+        .setMinApi(parameters.getApiLevel())
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("fail", "bar")
+        .inspect(codeInspector -> {});
+  }
+
+  @NeverClassInline
+  public static class Parent {
+    @NeverInline
+    public void bar() {
+      System.out.println("bar");
+    }
+  }
+
+  @NeverClassInline
+  public static class A {
+    @NeverInline
+    public void foo() {
+      System.out.println("foo");
+    }
+  }
+
+  @NeverClassInline
+  public static class B extends Parent {}
+
+  public static class Main {
+    @NeverInline
+    public static void checkObject(Object o) {
+      try {
+        Parent b = (Parent) o;
+        b.bar();
+      } catch (ClassCastException ex) {
+        System.out.println("fail");
+      }
+    }
+
+    public static void main(String[] args) {
+      A a = new A();
+      B b = new B();
+      checkObject(a);
+      checkObject(b);
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/VerticallyMergedClassDistinguishedByInstanceOfTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/VerticallyMergedClassDistinguishedByInstanceOfTest.java
new file mode 100644
index 0000000..bba42bb
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/VerticallyMergedClassDistinguishedByInstanceOfTest.java
@@ -0,0 +1,66 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.classmerging.horizontal;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestParameters;
+import org.junit.Test;
+
+public class VerticallyMergedClassDistinguishedByInstanceOfTest
+    extends HorizontalClassMergingTestBase {
+
+  public VerticallyMergedClassDistinguishedByInstanceOfTest(
+      TestParameters parameters, boolean enableHorizontalClassMerging) {
+    super(parameters, enableHorizontalClassMerging);
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addOptionsModification(
+            options -> options.enableHorizontalClassMerging = enableHorizontalClassMerging)
+        .enableNeverClassInliningAnnotations()
+        .setMinApi(parameters.getApiLevel())
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("false", "true")
+        .inspect(codeInspector -> {});
+  }
+
+  @NeverClassInline
+  public static class Parent {
+    @NeverInline
+    public void bar() {
+      System.out.println("bar");
+    }
+  }
+
+  @NeverClassInline
+  public static class A {
+    @NeverInline
+    public void foo() {
+      System.out.println("foo");
+    }
+  }
+
+  @NeverClassInline
+  public static class B extends Parent {}
+
+  public static class Main {
+    @NeverInline
+    public static void checkObject(Object o) {
+      System.out.println(o instanceof Parent);
+    }
+
+    public static void main(String[] args) {
+      A a = new A();
+      B b = new B();
+      checkObject(a);
+      checkObject(b);
+    }
+  }
+}