Add policy to prevent increasing main DEX size

Bug: 165741800
Bug: 165000217
Change-Id: Ie8bbbaf93c10aa98f22e8a9ebe188178a9fbd01b
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
index 8b17c78..581915e 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/HorizontalClassMerger.java
@@ -15,6 +15,7 @@
 import com.android.tools.r8.horizontalclassmerging.policies.NoStaticClassInitializer;
 import com.android.tools.r8.horizontalclassmerging.policies.NotEntryPoint;
 import com.android.tools.r8.horizontalclassmerging.policies.NotMatchedByNoHorizontalClassMerging;
+import com.android.tools.r8.horizontalclassmerging.policies.PreventMergeIntoMainDex;
 import com.android.tools.r8.horizontalclassmerging.policies.RespectPackageBoundaries;
 import com.android.tools.r8.horizontalclassmerging.policies.SameParentClass;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
@@ -34,7 +35,7 @@
 
   public HorizontalClassMerger(
       AppView<AppInfoWithLiveness> appView,
-      MainDexTracingResult mainDexClasses,
+      MainDexTracingResult mainDexTracingResult,
       ClassMergingEnqueuerExtension classMergingEnqueuerExtension) {
     this.appView = appView;
 
@@ -50,6 +51,7 @@
             new NoKeepRules(appView),
             new NoRuntimeTypeChecks(classMergingEnqueuerExtension),
             new NotEntryPoint(appView.dexItemFactory()),
+            new PreventMergeIntoMainDex(appView, mainDexTracingResult),
             new SameParentClass(),
             new RespectPackageBoundaries(appView)
             // TODO: add policies
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/MultiClassPolicy.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/MultiClassPolicy.java
index 5c67423..829c1f3 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/MultiClassPolicy.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/MultiClassPolicy.java
@@ -10,12 +10,17 @@
 
 public abstract class MultiClassPolicy extends Policy {
 
+  // TODO(b/165577835): Move to a virtual method on MergeGroup.
+  protected boolean isTrivial(Collection<DexProgramClass> group) {
+    return group.size() < 2;
+  }
+
   /**
    * Remove all groups containing no or only a single class, as there is no point in merging these.
    */
   protected void removeTrivialGroups(Collection<Collection<DexProgramClass>> groups) {
     assert !(groups instanceof ArrayList);
-    groups.removeIf(group -> group.size() < 2);
+    groups.removeIf(this::isTrivial);
   }
 
   /**
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoMainDex.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoMainDex.java
new file mode 100644
index 0000000..da65126
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/PreventMergeIntoMainDex.java
@@ -0,0 +1,55 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.horizontalclassmerging.policies;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.horizontalclassmerging.MultiClassPolicy;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.shaking.MainDexClasses;
+import com.android.tools.r8.shaking.MainDexTracingResult;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+
+public class PreventMergeIntoMainDex extends MultiClassPolicy {
+  private final MainDexClasses mainDexClasses;
+  private final MainDexTracingResult mainDexTracingResult;
+
+  public PreventMergeIntoMainDex(
+      AppView<AppInfoWithLiveness> appView, MainDexTracingResult mainDexTracingResult) {
+    this.mainDexClasses = appView.appInfo().getMainDexClasses();
+    this.mainDexTracingResult = mainDexTracingResult;
+  }
+
+  public boolean isMainDexClass(DexProgramClass clazz) {
+    return mainDexClasses.contains(clazz) || mainDexTracingResult.contains(clazz);
+  }
+
+  @Override
+  public Collection<Collection<DexProgramClass>> apply(Collection<DexProgramClass> group) {
+    List<DexProgramClass> mainDexMembers = new LinkedList<>();
+    Iterator<DexProgramClass> iterator = group.iterator();
+    while (iterator.hasNext()) {
+      DexProgramClass clazz = iterator.next();
+      if (isMainDexClass(clazz)) {
+        iterator.remove();
+        mainDexMembers.add(clazz);
+      }
+    }
+
+    Collection<Collection<DexProgramClass>> newGroups = new LinkedList<>();
+    if (!isTrivial(mainDexMembers)) {
+      // TODO(b/165577835) remove this cast when we introduce MergeGroup.
+      newGroups.add((Collection) mainDexMembers);
+    }
+    if (!isTrivial(group)) {
+      newGroups.add(group);
+    }
+
+    return newGroups;
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/shaking/MainDexTracingResult.java b/src/main/java/com/android/tools/r8/shaking/MainDexTracingResult.java
index d333211..76b6d8a 100644
--- a/src/main/java/com/android/tools/r8/shaking/MainDexTracingResult.java
+++ b/src/main/java/com/android/tools/r8/shaking/MainDexTracingResult.java
@@ -6,6 +6,7 @@
 
 import com.android.tools.r8.graph.AppInfo;
 import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Sets;
@@ -92,6 +93,14 @@
     return classes;
   }
 
+  public boolean contains(DexProgramClass clazz) {
+    return contains(clazz.type);
+  }
+
+  public boolean contains(DexType type) {
+    return getClasses().contains(type);
+  }
+
   private void collectTypesMatching(
       Set<DexType> types, Predicate<DexType> predicate, Consumer<DexType> consumer) {
     types.forEach(
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/PreventMergeMainDexListTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/PreventMergeMainDexListTest.java
new file mode 100644
index 0000000..4a50200
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/PreventMergeMainDexListTest.java
@@ -0,0 +1,104 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.classmerging.horizontal;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.IsNot.not;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.OutputMode;
+import com.android.tools.r8.R8TestCompileResult;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.classmerging.horizontal.EmptyClassTest.A;
+import com.android.tools.r8.classmerging.horizontal.EmptyClassTest.B;
+import com.android.tools.r8.classmerging.horizontal.EmptyClassTest.Main;
+import com.android.tools.r8.utils.BooleanUtils;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import java.nio.file.Path;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class PreventMergeMainDexListTest extends HorizontalClassMergingTestBase {
+  public PreventMergeMainDexListTest(
+      TestParameters parameters, boolean enableHorizontalClassMerging) {
+    super(parameters, enableHorizontalClassMerging);
+  }
+
+  @Parameterized.Parameters(name = "{0}, horizontalClassMerging:{1}")
+  public static List<Object[]> data() {
+    return buildParameters(
+        getTestParameters()
+            .withDexRuntimes()
+            .withApiLevelsEndingAtExcluding(apiLevelWithNativeMultiDexSupport())
+            .build(),
+        BooleanUtils.values());
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepClassAndMembersRules(Main.class)
+        .addMainDexListClasses(A.class, Main.class)
+        .addOptionsModification(
+            options -> {
+              options.enableHorizontalClassMerging = enableHorizontalClassMerging;
+              options.minimalMainDex = true;
+            })
+        .enableNeverClassInliningAnnotations()
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .apply(this::checkCompileResult)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("main dex");
+  }
+
+  private void checkCompileResult(R8TestCompileResult compileResult) throws Exception {
+    Path out = temp.newFolder().toPath();
+    compileResult.app.writeToDirectory(out, OutputMode.DexIndexed);
+    Path classes = out.resolve("classes.dex");
+    Path classes2 = out.resolve("classes2.dex");
+    inspectMainDex(new CodeInspector(classes, compileResult.getProguardMap()));
+    inspectSecondaryDex(new CodeInspector(classes2, compileResult.getProguardMap()));
+  }
+
+  private void inspectMainDex(CodeInspector inspector) {
+    assertThat(inspector.clazz(A.class), isPresent());
+    assertThat(inspector.clazz(B.class), not(isPresent()));
+  }
+
+  private void inspectSecondaryDex(CodeInspector inspector) {
+    assertThat(inspector.clazz(A.class), not(isPresent()));
+    assertThat(inspector.clazz(B.class), isPresent());
+  }
+
+  public static class Main {
+    public static void main(String[] args) {
+      A a = new A();
+    }
+
+    public static void otherDex() {
+      B b = new B();
+    }
+  }
+
+  @NeverClassInline
+  public static class A {
+    public A() {
+      System.out.println("main dex");
+    }
+  }
+
+  @NeverClassInline
+  public static class B {
+    public B() {
+      System.out.println("not main dex");
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/PreventMergeMainDexTracingTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/PreventMergeMainDexTracingTest.java
new file mode 100644
index 0000000..9e81bc6
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/PreventMergeMainDexTracingTest.java
@@ -0,0 +1,108 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.classmerging.horizontal;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.IsNot.not;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NoHorizontalClassMerging;
+import com.android.tools.r8.OutputMode;
+import com.android.tools.r8.R8TestCompileResult;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.classmerging.horizontal.PreventMergeMainDexListTest.Main;
+import com.android.tools.r8.utils.BooleanUtils;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import java.nio.file.Path;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class PreventMergeMainDexTracingTest extends HorizontalClassMergingTestBase {
+  public PreventMergeMainDexTracingTest(
+      TestParameters parameters, boolean enableHorizontalClassMerging) {
+    super(parameters, enableHorizontalClassMerging);
+  }
+
+  @Parameterized.Parameters(name = "{0}, horizontalClassMerging:{1}")
+  public static List<Object[]> data() {
+    return buildParameters(
+        getTestParameters()
+            .withDexRuntimes()
+            .withApiLevelsEndingAtExcluding(apiLevelWithNativeMultiDexSupport())
+            .build(),
+        BooleanUtils.values());
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addKeepClassAndMembersRules(Other.class)
+        .addMainDexClassRules(Main.class)
+        .addOptionsModification(
+            options -> {
+              options.enableHorizontalClassMerging = enableHorizontalClassMerging;
+              options.minimalMainDex = true;
+            })
+        .enableNeverClassInliningAnnotations()
+        .enableNoHorizontalClassMergingAnnotations()
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .apply(this::checkCompileResult)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("main dex");
+  }
+
+  private void checkCompileResult(R8TestCompileResult compileResult) throws Exception {
+    Path out = temp.newFolder().toPath();
+    compileResult.app.writeToDirectory(out, OutputMode.DexIndexed);
+    Path classes = out.resolve("classes.dex");
+    Path classes2 = out.resolve("classes2.dex");
+    inspectMainDex(new CodeInspector(classes, compileResult.getProguardMap()));
+    inspectSecondaryDex(new CodeInspector(classes2, compileResult.getProguardMap()));
+  }
+
+  private void inspectMainDex(CodeInspector inspector) {
+    assertThat(inspector.clazz(A.class), isPresent());
+    assertThat(inspector.clazz(B.class), not(isPresent()));
+  }
+
+  private void inspectSecondaryDex(CodeInspector inspector) {
+    assertThat(inspector.clazz(A.class), not(isPresent()));
+    assertThat(inspector.clazz(B.class), isPresent());
+  }
+
+  public static class Main {
+    public static void main(String[] args) {
+      A a = new A();
+    }
+  }
+
+  @NoHorizontalClassMerging
+  public static class Other {
+    public static void otherDex() {
+      B b = new B();
+    }
+  }
+
+  @NeverClassInline
+  public static class A {
+    public A() {
+      System.out.println("main dex");
+    }
+  }
+
+  @NeverClassInline
+  public static class B {
+    public B() {
+      System.out.println("not main dex");
+    }
+  }
+}