Perform depth first traversal during horizontal merging tree fixer

Bug: 169527976
Bug: 163311975
Change-Id: Ic69dc77d3b049558244c009a078cebbb11cfb906
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/SubtypingForrestForClasses.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/SubtypingForrestForClasses.java
new file mode 100644
index 0000000..402ddb1
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/SubtypingForrestForClasses.java
@@ -0,0 +1,77 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.horizontalclassmerging;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.IdentityHashMap;
+import java.util.Map;
+import java.util.function.BiFunction;
+
+/**
+ * Calculates the subtyping forrest for all classes. Unlike {@link
+ * com.android.tools.r8.graph.SubtypingInfo}, interfaces are not included in this subtyping
+ * information and only the immediate parents are stored (i.e. the transitive parents are not
+ * calculated). In the following example graph, the roots are A, E and G, and each edge indicates an
+ * entry in {@link SubtypingForrestForClasses#subtypeMap} going from the parent to an entry in the
+ * collection of children. <code>
+ *     A      E     G
+ *    / \     |
+ *   B  C     F
+ *   |
+ *   D
+ * </code>
+ */
+public class SubtypingForrestForClasses {
+  private final AppView<AppInfoWithLiveness> appView;
+
+  private final Collection<DexProgramClass> roots = new ArrayList<>();
+  private final Map<DexProgramClass, Collection<DexProgramClass>> subtypeMap =
+      new IdentityHashMap<>();
+
+  public SubtypingForrestForClasses(AppView<AppInfoWithLiveness> appView) {
+    this.appView = appView;
+
+    calculateSubtyping(appView.appInfo().classes());
+  }
+
+  private DexProgramClass superClass(DexProgramClass clazz) {
+    return appView.programDefinitionFor(clazz.superType, clazz);
+  }
+
+  private void calculateSubtyping(Iterable<DexProgramClass> classes) {
+    classes.forEach(this::calculateSubtyping);
+  }
+
+  private void calculateSubtyping(DexProgramClass clazz) {
+    if (clazz.isInterface()) {
+      return;
+    }
+    DexProgramClass superClass = superClass(clazz);
+    if (superClass == null) {
+      roots.add(clazz);
+    } else {
+      subtypeMap.computeIfAbsent(superClass, ignore -> new ArrayList<>()).add(clazz);
+    }
+  }
+
+  public Collection<DexProgramClass> getProgramRoots() {
+    return roots;
+  }
+
+  private Collection<DexProgramClass> getSubtypesFor(DexProgramClass clazz) {
+    return subtypeMap.getOrDefault(clazz, Collections.emptyList());
+  }
+
+  public <T> void traverseNodeDepthFirst(
+      DexProgramClass clazz, T state, BiFunction<DexProgramClass, T, T> consumer) {
+    T newState = consumer.apply(clazz, state);
+    getSubtypesFor(clazz).forEach(subClazz -> traverseNodeDepthFirst(subClazz, newState, consumer));
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java
index 7f12cc9..b102b69 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java
@@ -34,7 +34,6 @@
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
-import java.util.function.Consumer;
 
 /**
  * The tree fixer traverses all program classes and finds and fixes references to old classes which
@@ -53,10 +52,6 @@
   private final BiMap<Wrapper<DexMethod>, Wrapper<DexMethod>> reservedInterfaceSignatures =
       HashBiMap.create();
 
-  // Store which methods have been renamed in parent classes.
-  private final Map<DexType, Map<Wrapper<DexMethod>, DexString>> renamedVirtualMethods =
-      new IdentityHashMap<>();
-
   public TreeFixer(
       AppView<AppInfoWithLiveness> appView,
       HorizontallyMergedClasses mergedClasses,
@@ -128,8 +123,13 @@
     Iterable<DexProgramClass> classes = appView.appInfo().classesWithDeterministicOrder();
     Iterables.filter(classes, DexProgramClass::isInterface).forEach(this::fixupInterfaceClass);
 
-    forEachClassTypeTraverseHierarchy(
-        Iterables.filter(classes, clazz -> !clazz.isInterface()), this::fixupProgramClass);
+    classes.forEach(this::fixupProgramClassSuperType);
+    SubtypingForrestForClasses subtypingForrest = new SubtypingForrestForClasses(appView);
+    // TODO(b/170078037): parallelize this code segment.
+    for (DexProgramClass root : subtypingForrest.getProgramRoots()) {
+      subtypingForrest.traverseNodeDepthFirst(
+          root, new IdentityHashMap<>(), this::fixupProgramClass);
+    }
 
     lensBuilder.remapMethods(movedMethods);
 
@@ -139,14 +139,19 @@
     return lens;
   }
 
-  private void fixupProgramClass(DexProgramClass clazz) {
+  private void fixupProgramClassSuperType(DexProgramClass clazz) {
+    clazz.superType = fixupType(clazz.superType);
+  }
+
+  private Map<Wrapper<DexMethod>, DexString> fixupProgramClass(
+      DexProgramClass clazz, Map<Wrapper<DexMethod>, DexString> remappedVirtualMethods) {
     assert !clazz.isInterface();
 
     // TODO(b/169395592): ensure merged classes have been removed using:
     //   assert !mergedClasses.hasBeenMergedIntoDifferentType(clazz.type);
 
-    Map<Wrapper<DexMethod>, DexString> renamedClassVirtualMethods =
-        new HashMap<>(renamedVirtualMethods.getOrDefault(clazz.superType, new HashMap<>()));
+    Map<Wrapper<DexMethod>, DexString> remappedClassVirtualMethods =
+        new HashMap<>(remappedVirtualMethods);
 
     Set<DexMethod> newDirectMethodReferences = new LinkedHashSet<>();
     Set<DexMethod> newVirtualMethodReferences = new LinkedHashSet<>();
@@ -155,39 +160,16 @@
         .getMethodCollection()
         .replaceVirtualMethods(
             method ->
-                fixupVirtualMethod(renamedClassVirtualMethods, newVirtualMethodReferences, method));
+                fixupVirtualMethod(
+                    remappedClassVirtualMethods, newVirtualMethodReferences, method));
     clazz
         .getMethodCollection()
         .replaceDirectMethods(method -> fixupDirectMethod(newDirectMethodReferences, method));
 
-    if (!renamedClassVirtualMethods.isEmpty()) {
-      renamedVirtualMethods.put(clazz.type, renamedClassVirtualMethods);
-    }
     fixupFields(clazz.staticFields(), clazz::setStaticField);
     fixupFields(clazz.instanceFields(), clazz::setInstanceField);
-  }
 
-  private void traverseUp(
-      DexProgramClass clazz, Set<DexProgramClass> seenClasses, Consumer<DexProgramClass> fn) {
-    if (clazz == null || !seenClasses.add(clazz)) {
-      return;
-    }
-
-    clazz.superType = mergedClasses.getMergeTargetOrDefault(clazz.superType);
-    if (clazz.superType != null) {
-      DexProgramClass superClass = appView.programDefinitionFor(clazz.superType, clazz);
-      traverseUp(superClass, seenClasses, fn);
-    }
-
-    fn.accept(clazz);
-  }
-
-  private void forEachClassTypeTraverseHierarchy(
-      Iterable<DexProgramClass> classes, Consumer<DexProgramClass> fn) {
-    Set<DexProgramClass> seenClasses = Sets.newIdentityHashSet();
-    for (DexProgramClass clazz : classes) {
-      traverseUp(clazz, seenClasses, fn);
-    }
+    return remappedClassVirtualMethods;
   }
 
   private DexEncodedMethod fixupVirtualInterfaceMethod(DexEncodedMethod method) {
diff --git a/src/main/java/com/android/tools/r8/utils/ConsumerUtils.java b/src/main/java/com/android/tools/r8/utils/ConsumerUtils.java
index 8d66524..7259af1 100644
--- a/src/main/java/com/android/tools/r8/utils/ConsumerUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/ConsumerUtils.java
@@ -7,9 +7,18 @@
 import java.util.Set;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
+import java.util.function.Function;
 
 public class ConsumerUtils {
 
+  public static <S, T> Function<S, Consumer<T>> curry(BiConsumer<S, T> function) {
+    return arg -> arg2 -> function.accept(arg, arg2);
+  }
+
+  public static <S, T> Consumer<T> apply(BiConsumer<S, T> function, S arg) {
+    return curry(function).apply(arg);
+  }
+
   public static <T> Consumer<T> acceptIfNotSeen(Consumer<T> consumer, Set<T> seen) {
     return element -> {
       if (seen.add(element)) {
diff --git a/src/main/java/com/android/tools/r8/utils/FunctionUtils.java b/src/main/java/com/android/tools/r8/utils/FunctionUtils.java
index 64e8537..b7c7a87 100644
--- a/src/main/java/com/android/tools/r8/utils/FunctionUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/FunctionUtils.java
@@ -4,11 +4,20 @@
 
 package com.android.tools.r8.utils;
 
+import java.util.function.BiFunction;
 import java.util.function.Consumer;
 import java.util.function.Function;
 
 public class FunctionUtils {
 
+  public static <S, T, R> Function<S, Function<T, R>> curry(BiFunction<S, T, R> function) {
+    return arg -> arg2 -> function.apply(arg, arg2);
+  }
+
+  public static <S, T, R> Function<T, R> apply(BiFunction<S, T, R> function, S arg) {
+    return curry(function).apply(arg);
+  }
+
   public static <T, R> void forEachApply(
       Iterable<T> list, Function<T, Consumer<R>> func, R argument) {
     for (T t : list) {
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/InheritsFromLibraryClassTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/InheritsFromLibraryClassTest.java
new file mode 100644
index 0000000..cfafb12
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/InheritsFromLibraryClassTest.java
@@ -0,0 +1,90 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.classmerging.horizontal;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static com.android.tools.r8.utils.codeinspector.Matchers.notIf;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.classmerging.horizontal.SuperConstructorCallsVirtualMethodTest.A;
+import com.android.tools.r8.classmerging.horizontal.SuperConstructorCallsVirtualMethodTest.B;
+import com.android.tools.r8.classmerging.horizontal.SuperConstructorCallsVirtualMethodTest.Main;
+import com.android.tools.r8.classmerging.horizontal.SuperConstructorCallsVirtualMethodTest.Parent;
+import java.util.ArrayList;
+import org.junit.Test;
+
+public class InheritsFromLibraryClassTest extends HorizontalClassMergingTestBase {
+  public InheritsFromLibraryClassTest(
+      TestParameters parameters, boolean enableHorizontalClassMerging) {
+    super(parameters, enableHorizontalClassMerging);
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addOptionsModification(
+            options -> options.enableHorizontalClassMerging = enableHorizontalClassMerging)
+        .enableInliningAnnotations()
+        .enableNeverClassInliningAnnotations()
+        .setMinApi(parameters.getApiLevel())
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("a", "foo a", "b", "foo")
+        .inspect(
+            codeInspector -> {
+              assertThat(codeInspector.clazz(Parent.class), isPresent());
+              assertThat(codeInspector.clazz(A.class), isPresent());
+              assertThat(
+                  codeInspector.clazz(B.class), notIf(isPresent(), enableHorizontalClassMerging));
+              assertThat(codeInspector.clazz(C.class), isPresent());
+            });
+  }
+
+  public static class Parent {
+    @NeverInline
+    public void foo() {
+      System.out.println("foo");
+    }
+  }
+
+  @NeverClassInline
+  public static class A extends Parent {
+    public A() {
+      System.out.println("a");
+    }
+
+    @NeverInline
+    public void foo() {
+      System.out.println("foo a");
+    }
+  }
+
+  @NeverClassInline
+  public static class B extends Parent {
+    public B() {
+      System.out.println("b");
+    }
+  }
+
+  @NeverClassInline
+  public static class C extends ArrayList {
+    public C() {}
+
+    public void fooB(B b) {
+      b.foo();
+    }
+  }
+
+  public static class Main {
+    public static void main(String[] args) {
+      new A().foo();
+      new C().fooB(new B());
+    }
+  }
+}