Add synthetic $r8$HorizontalClassMergingArgument as main dex root if merged class is root

Change-Id: I23ef7f2e0e488ecf9092f77787bfb6b148bf8fbc
diff --git a/src/main/java/com/android/tools/r8/R8.java b/src/main/java/com/android/tools/r8/R8.java
index f43bfc9..67bd8bd 100644
--- a/src/main/java/com/android/tools/r8/R8.java
+++ b/src/main/java/com/android/tools/r8/R8.java
@@ -42,7 +42,7 @@
 import com.android.tools.r8.graph.classmerging.StaticallyMergedClasses;
 import com.android.tools.r8.graph.classmerging.VerticallyMergedClasses;
 import com.android.tools.r8.horizontalclassmerging.HorizontalClassMerger;
-import com.android.tools.r8.horizontalclassmerging.HorizontalClassMergerGraphLens;
+import com.android.tools.r8.horizontalclassmerging.HorizontalClassMergerResult;
 import com.android.tools.r8.horizontalclassmerging.HorizontallyMergedClasses;
 import com.android.tools.r8.inspector.internal.InspectorImpl;
 import com.android.tools.r8.ir.conversion.IRConverter;
@@ -560,20 +560,26 @@
           HorizontalClassMerger merger = new HorizontalClassMerger(appViewWithLiveness);
           DirectMappedDexApplication.Builder appBuilder =
               appView.appInfo().app().asDirect().builder();
-          HorizontalClassMergerGraphLens lens =
+          HorizontalClassMergerResult horizontalClassMergerResult =
               merger.run(appBuilder, mainDexTracingResult, runtimeTypeCheckInfo);
-          if (lens != null) {
+          if (horizontalClassMergerResult != null) {
             // Must rewrite AppInfoWithLiveness before pruning the merged classes, to ensure that
             // allocations sites, fields accesses, etc. are correctly transferred to the target
             // classes.
-            appView.rewriteWithLensAndApplication(lens, appBuilder.build());
-            merger.recordSyntheticFieldAccesses();
+            appView.rewriteWithLensAndApplication(
+                horizontalClassMergerResult.getGraphLens(), appBuilder.build());
+            horizontalClassMergerResult
+                .getFieldAccessInfoCollectionModifier()
+                .modify(appViewWithLiveness);
+
             appView.pruneItems(
                 PrunedItems.builder()
                     .setPrunedApp(appView.appInfo().app())
                     .addRemovedClasses(appView.horizontallyMergedClasses().getSources())
                     .addNoLongerSyntheticItems(appView.horizontallyMergedClasses().getTargets())
                     .build());
+
+            mainDexTracingResult = horizontalClassMergerResult.getMainDexTracingResult();
           }
           timing.end();
         } else {
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
index b1f1c31..09b975b 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
@@ -52,19 +52,14 @@
 public class HorizontalClassMerger {
 
   private final AppView<AppInfoWithLiveness> appView;
-  private FieldAccessInfoCollectionModifier fieldAccessChanges;
 
   public HorizontalClassMerger(AppView<AppInfoWithLiveness> appView) {
     this.appView = appView;
     assert appView.options().enableInlining;
   }
 
-  public void recordSyntheticFieldAccesses() {
-    fieldAccessChanges.modify(appView);
-  }
-
   // TODO(b/165577835): replace Collection<DexProgramClass> with MergeGroup
-  public HorizontalClassMergerGraphLens run(
+  public HorizontalClassMergerResult run(
       DirectMappedDexApplication.Builder appBuilder,
       MainDexTracingResult mainDexTracingResult,
       RuntimeTypeCheckInfo runtimeTypeCheckInfo) {
@@ -87,6 +82,8 @@
         new HorizontalClassMergerGraphLens.Builder();
     FieldAccessInfoCollectionModifier.Builder fieldAccessChangesBuilder =
         new FieldAccessInfoCollectionModifier.Builder();
+    MainDexTracingResult.Builder mainDexTracingResultBuilder =
+        mainDexTracingResult.extensionBuilder(appView.appInfo());
 
     // Set up a class merger for each group.
     List<ClassMerger> classMergers =
@@ -98,7 +95,9 @@
 
     // Merge the classes.
     SyntheticArgumentClass syntheticArgumentClass =
-        new SyntheticArgumentClass.Builder().build(appView, appBuilder, allMergeClasses);
+        new SyntheticArgumentClass.Builder(
+                appBuilder, appView, mainDexTracingResult, mainDexTracingResultBuilder)
+            .build(allMergeClasses);
     applyClassMergers(classMergers, syntheticArgumentClass);
 
     // Generate the graph lens.
@@ -106,13 +105,13 @@
     appView.setHorizontallyMergedClasses(mergedClasses);
     HorizontalClassMergerGraphLens lens =
         createLens(mergedClasses, lensBuilder, fieldAccessChangesBuilder, syntheticArgumentClass);
-    fieldAccessChanges = fieldAccessChangesBuilder.build();
 
     // Prune keep info.
     KeepInfoCollection keepInfo = appView.appInfo().getKeepInfo();
     keepInfo.mutate(mutator -> mutator.removeKeepInfoForPrunedItems(mergedClasses.getSources()));
 
-    return lens;
+    return new HorizontalClassMergerResult(
+        fieldAccessChangesBuilder.build(), lens, mainDexTracingResultBuilder.build());
   }
 
   private List<Policy> getPolicies(
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerResult.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerResult.java
new file mode 100644
index 0000000..75aafb4
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMergerResult.java
@@ -0,0 +1,36 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.horizontalclassmerging;
+
+import com.android.tools.r8.shaking.FieldAccessInfoCollectionModifier;
+import com.android.tools.r8.shaking.MainDexTracingResult;
+
+public class HorizontalClassMergerResult {
+
+  private final FieldAccessInfoCollectionModifier fieldAccessInfoCollectionModifier;
+  private final HorizontalClassMergerGraphLens graphLens;
+  private final MainDexTracingResult mainDexTracingResult;
+
+  HorizontalClassMergerResult(
+      FieldAccessInfoCollectionModifier fieldAccessInfoCollectionModifier,
+      HorizontalClassMergerGraphLens graphLens,
+      MainDexTracingResult mainDexTracingResult) {
+    this.fieldAccessInfoCollectionModifier = fieldAccessInfoCollectionModifier;
+    this.graphLens = graphLens;
+    this.mainDexTracingResult = mainDexTracingResult;
+  }
+
+  public FieldAccessInfoCollectionModifier getFieldAccessInfoCollectionModifier() {
+    return fieldAccessInfoCollectionModifier;
+  }
+
+  public HorizontalClassMergerGraphLens getGraphLens() {
+    return graphLens;
+  }
+
+  public MainDexTracingResult getMainDexTracingResult() {
+    return mainDexTracingResult;
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/SyntheticArgumentClass.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/SyntheticArgumentClass.java
index c3aa3c0..ec64e8a 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/SyntheticArgumentClass.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/SyntheticArgumentClass.java
@@ -17,6 +17,8 @@
 import com.android.tools.r8.graph.GenericSignature.ClassSignature;
 import com.android.tools.r8.origin.SynthesizedOrigin;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.shaking.MainDexTracingResult;
+import com.google.common.collect.Iterables;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
@@ -53,13 +55,27 @@
 
   public static class Builder {
 
-    private DexType synthesizeClass(
-        AppView<AppInfoWithLiveness> appView,
+    private final DirectMappedDexApplication.Builder appBuilder;
+    private final AppView<AppInfoWithLiveness> appView;
+    private final MainDexTracingResult mainDexTracingResult;
+    private final MainDexTracingResult.Builder mainDexTracingResultBuilder;
+
+    Builder(
         DirectMappedDexApplication.Builder appBuilder,
+        AppView<AppInfoWithLiveness> appView,
+        MainDexTracingResult mainDexTracingResult,
+        MainDexTracingResult.Builder mainDexTracingResultBuilder) {
+      this.appBuilder = appBuilder;
+      this.appView = appView;
+      this.mainDexTracingResult = mainDexTracingResult;
+      this.mainDexTracingResultBuilder = mainDexTracingResultBuilder;
+    }
+
+    private DexType synthesizeClass(
         DexProgramClass context,
         boolean requiresMainDex,
+        boolean addToMainDexTracingResult,
         int index) {
-
       DexType syntheticClassType =
           appView
               .dexItemFactory()
@@ -92,26 +108,31 @@
 
       appBuilder.addSynthesizedClass(clazz);
       appView.appInfo().addSynthesizedClass(clazz, requiresMainDex);
+      if (addToMainDexTracingResult) {
+        mainDexTracingResultBuilder.addRoot(clazz);
+      }
 
       return clazz.type;
     }
 
-    public SyntheticArgumentClass build(
-        AppView<AppInfoWithLiveness> appView,
-        DirectMappedDexApplication.Builder appBuilder,
-        Iterable<DexProgramClass> mergeClasses) {
-
+    public SyntheticArgumentClass build(Iterable<DexProgramClass> mergeClasses) {
       // Find a fresh name in an existing package.
       DexProgramClass context = mergeClasses.iterator().next();
 
+      // Add to the main dex list if one of the merged classes is in the main dex.
       boolean requiresMainDex = appView.appInfo().getMainDexClasses().containsAnyOf(mergeClasses);
 
+      // Also add as a root to the main dex tracing result if any of the merged classes is a root.
+      // This is needed to satisfy an assertion in the inliner that verifies that we do not inline
+      // methods with references to non-roots into classes that are roots.
+      boolean addToMainDexTracingResult = Iterables.any(mergeClasses, mainDexTracingResult::isRoot);
+
       List<DexType> syntheticArgumentTypes = new ArrayList<>();
       for (int i = 0;
           i < appView.options().horizontalClassMergerOptions().getSyntheticArgumentCount();
           i++) {
         syntheticArgumentTypes.add(
-            synthesizeClass(appView, appBuilder, context, requiresMainDex, i));
+            synthesizeClass(context, requiresMainDex, addToMainDexTracingResult, i));
       }
 
       return new SyntheticArgumentClass(syntheticArgumentTypes);
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java b/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
index 603e21f..4534c9b 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/DefaultInliningOracle.java
@@ -204,7 +204,9 @@
       return false;
     }
     assert reason != Reason.FORCE
-        || !inlineeRefersToClassesNotInMainDex(method.getHolderType(), singleTarget);
+            || !inlineeRefersToClassesNotInMainDex(method.getHolderType(), singleTarget)
+        : MainDexDirectReferenceTracer.getFirstReferenceOutsideFromCode(
+            appView.appInfo(), singleTarget, inliner.mainDexClasses.getRoots());
     return true;
   }
 
diff --git a/src/main/java/com/android/tools/r8/shaking/MainDexDirectReferenceTracer.java b/src/main/java/com/android/tools/r8/shaking/MainDexDirectReferenceTracer.java
index 91f9bbc..6f7cade 100644
--- a/src/main/java/com/android/tools/r8/shaking/MainDexDirectReferenceTracer.java
+++ b/src/main/java/com/android/tools/r8/shaking/MainDexDirectReferenceTracer.java
@@ -22,7 +22,7 @@
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.graph.UseRegistry;
-import com.android.tools.r8.utils.BooleanBox;
+import com.android.tools.r8.utils.Box;
 import java.util.Set;
 import java.util.function.Consumer;
 
@@ -68,9 +68,12 @@
 
   public static boolean hasReferencesOutsideFromCode(
       AppInfoWithClassHierarchy appInfo, ProgramMethod method, Set<DexType> classes) {
+    return getFirstReferenceOutsideFromCode(appInfo, method, classes) != null;
+  }
 
-    BooleanBox result = new BooleanBox();
-
+  public static DexProgramClass getFirstReferenceOutsideFromCode(
+      AppInfoWithClassHierarchy appInfo, ProgramMethod method, Set<DexType> classes) {
+    Box<DexProgramClass> result = new Box<>();
     new MainDexDirectReferenceTracer(
             appInfo,
             type -> {
@@ -78,12 +81,11 @@
               if (baseType.isClassType() && !classes.contains(baseType)) {
                 DexClass cls = appInfo.definitionFor(baseType);
                 if (cls != null && cls.isProgramClass()) {
-                  result.set(true);
+                  result.set(cls.asProgramClass());
                 }
               }
             })
         .runOnCode(method);
-
     return result.get();
   }
 
diff --git a/src/main/java/com/android/tools/r8/shaking/MainDexTracingResult.java b/src/main/java/com/android/tools/r8/shaking/MainDexTracingResult.java
index 6383473..3af323c 100644
--- a/src/main/java/com/android/tools/r8/shaking/MainDexTracingResult.java
+++ b/src/main/java/com/android/tools/r8/shaking/MainDexTracingResult.java
@@ -6,8 +6,10 @@
 
 import com.android.tools.r8.graph.AppInfo;
 import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramDefinition;
+import com.android.tools.r8.utils.SetUtils;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Sets;
 import java.util.Collection;
@@ -23,11 +25,29 @@
 
   public static class Builder {
     public final AppInfo appInfo;
-    public final Set<DexType> roots = Sets.newIdentityHashSet();
-    public final Set<DexType> dependencies = Sets.newIdentityHashSet();
+    public final Set<DexType> roots;
+    public final Set<DexType> dependencies;
 
     private Builder(AppInfo appInfo) {
+      this(appInfo, Sets.newIdentityHashSet(), Sets.newIdentityHashSet());
+    }
+
+    private Builder(AppInfo appInfo, MainDexTracingResult mainDexTracingResult) {
+      this(
+          appInfo,
+          SetUtils.newIdentityHashSet(mainDexTracingResult.getRoots()),
+          SetUtils.newIdentityHashSet(mainDexTracingResult.getDependencies()));
+    }
+
+    private Builder(AppInfo appInfo, Set<DexType> roots, Set<DexType> dependencies) {
       this.appInfo = appInfo;
+      this.roots = roots;
+      this.dependencies = dependencies;
+    }
+
+    public Builder addRoot(DexProgramClass clazz) {
+      roots.add(clazz.getType());
+      return this;
     }
 
     public Builder addRoot(DexType type) {
@@ -141,4 +161,8 @@
   public static Builder builder(AppInfo appInfo) {
     return new Builder(appInfo);
   }
+
+  public Builder extensionBuilder(AppInfo appInfo) {
+    return new Builder(appInfo, this);
+  }
 }