Perform depth first traversal during horizontal merging tree fixer
Bug: 169527976
Bug: 163311975
Change-Id: Ic69dc77d3b049558244c009a078cebbb11cfb906
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/SubtypingForrestForClasses.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/SubtypingForrestForClasses.java
new file mode 100644
index 0000000..402ddb1
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/SubtypingForrestForClasses.java
@@ -0,0 +1,77 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.horizontalclassmerging;
+
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.IdentityHashMap;
+import java.util.Map;
+import java.util.function.BiFunction;
+
+/**
+ * Calculates the subtyping forrest for all classes. Unlike {@link
+ * com.android.tools.r8.graph.SubtypingInfo}, interfaces are not included in this subtyping
+ * information and only the immediate parents are stored (i.e. the transitive parents are not
+ * calculated). In the following example graph, the roots are A, E and G, and each edge indicates an
+ * entry in {@link SubtypingForrestForClasses#subtypeMap} going from the parent to an entry in the
+ * collection of children. <code>
+ * A E G
+ * / \ |
+ * B C F
+ * |
+ * D
+ * </code>
+ */
+public class SubtypingForrestForClasses {
+ private final AppView<AppInfoWithLiveness> appView;
+
+ private final Collection<DexProgramClass> roots = new ArrayList<>();
+ private final Map<DexProgramClass, Collection<DexProgramClass>> subtypeMap =
+ new IdentityHashMap<>();
+
+ public SubtypingForrestForClasses(AppView<AppInfoWithLiveness> appView) {
+ this.appView = appView;
+
+ calculateSubtyping(appView.appInfo().classes());
+ }
+
+ private DexProgramClass superClass(DexProgramClass clazz) {
+ return appView.programDefinitionFor(clazz.superType, clazz);
+ }
+
+ private void calculateSubtyping(Iterable<DexProgramClass> classes) {
+ classes.forEach(this::calculateSubtyping);
+ }
+
+ private void calculateSubtyping(DexProgramClass clazz) {
+ if (clazz.isInterface()) {
+ return;
+ }
+ DexProgramClass superClass = superClass(clazz);
+ if (superClass == null) {
+ roots.add(clazz);
+ } else {
+ subtypeMap.computeIfAbsent(superClass, ignore -> new ArrayList<>()).add(clazz);
+ }
+ }
+
+ public Collection<DexProgramClass> getProgramRoots() {
+ return roots;
+ }
+
+ private Collection<DexProgramClass> getSubtypesFor(DexProgramClass clazz) {
+ return subtypeMap.getOrDefault(clazz, Collections.emptyList());
+ }
+
+ public <T> void traverseNodeDepthFirst(
+ DexProgramClass clazz, T state, BiFunction<DexProgramClass, T, T> consumer) {
+ T newState = consumer.apply(clazz, state);
+ getSubtypesFor(clazz).forEach(subClazz -> traverseNodeDepthFirst(subClazz, newState, consumer));
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java
index 7f12cc9..b102b69 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/TreeFixer.java
@@ -34,7 +34,6 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
-import java.util.function.Consumer;
/**
* The tree fixer traverses all program classes and finds and fixes references to old classes which
@@ -53,10 +52,6 @@
private final BiMap<Wrapper<DexMethod>, Wrapper<DexMethod>> reservedInterfaceSignatures =
HashBiMap.create();
- // Store which methods have been renamed in parent classes.
- private final Map<DexType, Map<Wrapper<DexMethod>, DexString>> renamedVirtualMethods =
- new IdentityHashMap<>();
-
public TreeFixer(
AppView<AppInfoWithLiveness> appView,
HorizontallyMergedClasses mergedClasses,
@@ -128,8 +123,13 @@
Iterable<DexProgramClass> classes = appView.appInfo().classesWithDeterministicOrder();
Iterables.filter(classes, DexProgramClass::isInterface).forEach(this::fixupInterfaceClass);
- forEachClassTypeTraverseHierarchy(
- Iterables.filter(classes, clazz -> !clazz.isInterface()), this::fixupProgramClass);
+ classes.forEach(this::fixupProgramClassSuperType);
+ SubtypingForrestForClasses subtypingForrest = new SubtypingForrestForClasses(appView);
+ // TODO(b/170078037): parallelize this code segment.
+ for (DexProgramClass root : subtypingForrest.getProgramRoots()) {
+ subtypingForrest.traverseNodeDepthFirst(
+ root, new IdentityHashMap<>(), this::fixupProgramClass);
+ }
lensBuilder.remapMethods(movedMethods);
@@ -139,14 +139,19 @@
return lens;
}
- private void fixupProgramClass(DexProgramClass clazz) {
+ private void fixupProgramClassSuperType(DexProgramClass clazz) {
+ clazz.superType = fixupType(clazz.superType);
+ }
+
+ private Map<Wrapper<DexMethod>, DexString> fixupProgramClass(
+ DexProgramClass clazz, Map<Wrapper<DexMethod>, DexString> remappedVirtualMethods) {
assert !clazz.isInterface();
// TODO(b/169395592): ensure merged classes have been removed using:
// assert !mergedClasses.hasBeenMergedIntoDifferentType(clazz.type);
- Map<Wrapper<DexMethod>, DexString> renamedClassVirtualMethods =
- new HashMap<>(renamedVirtualMethods.getOrDefault(clazz.superType, new HashMap<>()));
+ Map<Wrapper<DexMethod>, DexString> remappedClassVirtualMethods =
+ new HashMap<>(remappedVirtualMethods);
Set<DexMethod> newDirectMethodReferences = new LinkedHashSet<>();
Set<DexMethod> newVirtualMethodReferences = new LinkedHashSet<>();
@@ -155,39 +160,16 @@
.getMethodCollection()
.replaceVirtualMethods(
method ->
- fixupVirtualMethod(renamedClassVirtualMethods, newVirtualMethodReferences, method));
+ fixupVirtualMethod(
+ remappedClassVirtualMethods, newVirtualMethodReferences, method));
clazz
.getMethodCollection()
.replaceDirectMethods(method -> fixupDirectMethod(newDirectMethodReferences, method));
- if (!renamedClassVirtualMethods.isEmpty()) {
- renamedVirtualMethods.put(clazz.type, renamedClassVirtualMethods);
- }
fixupFields(clazz.staticFields(), clazz::setStaticField);
fixupFields(clazz.instanceFields(), clazz::setInstanceField);
- }
- private void traverseUp(
- DexProgramClass clazz, Set<DexProgramClass> seenClasses, Consumer<DexProgramClass> fn) {
- if (clazz == null || !seenClasses.add(clazz)) {
- return;
- }
-
- clazz.superType = mergedClasses.getMergeTargetOrDefault(clazz.superType);
- if (clazz.superType != null) {
- DexProgramClass superClass = appView.programDefinitionFor(clazz.superType, clazz);
- traverseUp(superClass, seenClasses, fn);
- }
-
- fn.accept(clazz);
- }
-
- private void forEachClassTypeTraverseHierarchy(
- Iterable<DexProgramClass> classes, Consumer<DexProgramClass> fn) {
- Set<DexProgramClass> seenClasses = Sets.newIdentityHashSet();
- for (DexProgramClass clazz : classes) {
- traverseUp(clazz, seenClasses, fn);
- }
+ return remappedClassVirtualMethods;
}
private DexEncodedMethod fixupVirtualInterfaceMethod(DexEncodedMethod method) {
diff --git a/src/main/java/com/android/tools/r8/utils/ConsumerUtils.java b/src/main/java/com/android/tools/r8/utils/ConsumerUtils.java
index 8d66524..7259af1 100644
--- a/src/main/java/com/android/tools/r8/utils/ConsumerUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/ConsumerUtils.java
@@ -7,9 +7,18 @@
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
+import java.util.function.Function;
public class ConsumerUtils {
+ public static <S, T> Function<S, Consumer<T>> curry(BiConsumer<S, T> function) {
+ return arg -> arg2 -> function.accept(arg, arg2);
+ }
+
+ public static <S, T> Consumer<T> apply(BiConsumer<S, T> function, S arg) {
+ return curry(function).apply(arg);
+ }
+
public static <T> Consumer<T> acceptIfNotSeen(Consumer<T> consumer, Set<T> seen) {
return element -> {
if (seen.add(element)) {
diff --git a/src/main/java/com/android/tools/r8/utils/FunctionUtils.java b/src/main/java/com/android/tools/r8/utils/FunctionUtils.java
index 64e8537..b7c7a87 100644
--- a/src/main/java/com/android/tools/r8/utils/FunctionUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/FunctionUtils.java
@@ -4,11 +4,20 @@
package com.android.tools.r8.utils;
+import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
public class FunctionUtils {
+ public static <S, T, R> Function<S, Function<T, R>> curry(BiFunction<S, T, R> function) {
+ return arg -> arg2 -> function.apply(arg, arg2);
+ }
+
+ public static <S, T, R> Function<T, R> apply(BiFunction<S, T, R> function, S arg) {
+ return curry(function).apply(arg);
+ }
+
public static <T, R> void forEachApply(
Iterable<T> list, Function<T, Consumer<R>> func, R argument) {
for (T t : list) {
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/InheritsFromLibraryClassTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/InheritsFromLibraryClassTest.java
new file mode 100644
index 0000000..cfafb12
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/InheritsFromLibraryClassTest.java
@@ -0,0 +1,90 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.classmerging.horizontal;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static com.android.tools.r8.utils.codeinspector.Matchers.notIf;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.classmerging.horizontal.SuperConstructorCallsVirtualMethodTest.A;
+import com.android.tools.r8.classmerging.horizontal.SuperConstructorCallsVirtualMethodTest.B;
+import com.android.tools.r8.classmerging.horizontal.SuperConstructorCallsVirtualMethodTest.Main;
+import com.android.tools.r8.classmerging.horizontal.SuperConstructorCallsVirtualMethodTest.Parent;
+import java.util.ArrayList;
+import org.junit.Test;
+
+public class InheritsFromLibraryClassTest extends HorizontalClassMergingTestBase {
+ public InheritsFromLibraryClassTest(
+ TestParameters parameters, boolean enableHorizontalClassMerging) {
+ super(parameters, enableHorizontalClassMerging);
+ }
+
+ @Test
+ public void testR8() throws Exception {
+ testForR8(parameters.getBackend())
+ .addInnerClasses(getClass())
+ .addKeepMainRule(Main.class)
+ .addOptionsModification(
+ options -> options.enableHorizontalClassMerging = enableHorizontalClassMerging)
+ .enableInliningAnnotations()
+ .enableNeverClassInliningAnnotations()
+ .setMinApi(parameters.getApiLevel())
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines("a", "foo a", "b", "foo")
+ .inspect(
+ codeInspector -> {
+ assertThat(codeInspector.clazz(Parent.class), isPresent());
+ assertThat(codeInspector.clazz(A.class), isPresent());
+ assertThat(
+ codeInspector.clazz(B.class), notIf(isPresent(), enableHorizontalClassMerging));
+ assertThat(codeInspector.clazz(C.class), isPresent());
+ });
+ }
+
+ public static class Parent {
+ @NeverInline
+ public void foo() {
+ System.out.println("foo");
+ }
+ }
+
+ @NeverClassInline
+ public static class A extends Parent {
+ public A() {
+ System.out.println("a");
+ }
+
+ @NeverInline
+ public void foo() {
+ System.out.println("foo a");
+ }
+ }
+
+ @NeverClassInline
+ public static class B extends Parent {
+ public B() {
+ System.out.println("b");
+ }
+ }
+
+ @NeverClassInline
+ public static class C extends ArrayList {
+ public C() {}
+
+ public void fooB(B b) {
+ b.foo();
+ }
+ }
+
+ public static class Main {
+ public static void main(String[] args) {
+ new A().foo();
+ new C().fooB(new B());
+ }
+ }
+}