Resolve instance initializer collisions in proto normalization
Bug: 195112263
Change-Id: Ib3e8172db63123acd0ce77f23a04654f59b2b3c9
diff --git a/src/main/java/com/android/tools/r8/optimize/proto/ProtoNormalizer.java b/src/main/java/com/android/tools/r8/optimize/proto/ProtoNormalizer.java
index 29ed273..e28c76b 100644
--- a/src/main/java/com/android/tools/r8/optimize/proto/ProtoNormalizer.java
+++ b/src/main/java/com/android/tools/r8/optimize/proto/ProtoNormalizer.java
@@ -6,6 +6,7 @@
import static com.android.tools.r8.utils.MapUtils.ignoreKey;
+import com.android.tools.r8.errors.Unreachable;
import com.android.tools.r8.graph.AppView;
import com.android.tools.r8.graph.DexItemFactory;
import com.android.tools.r8.graph.DexMethod;
@@ -16,11 +17,15 @@
import com.android.tools.r8.graph.ProgramMethod;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.IterableUtils;
+import com.android.tools.r8.utils.MapUtils;
import com.android.tools.r8.utils.ThreadUtils;
import com.android.tools.r8.utils.Timing;
+import com.android.tools.r8.utils.WorkList;
import com.android.tools.r8.utils.collections.BidirectionalOneToOneHashMap;
import com.android.tools.r8.utils.collections.DexMethodSignatureSet;
import com.android.tools.r8.utils.collections.MutableBidirectionalOneToOneMap;
+import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import java.util.Collections;
import java.util.HashMap;
@@ -59,14 +64,19 @@
LocalReservationState localReservationState = new LocalReservationState();
ProtoNormalizerGraphLens.Builder lensBuilder = ProtoNormalizerGraphLens.builder(appView);
for (DexProgramClass clazz : appView.appInfo().classesWithDeterministicOrder()) {
+ Map<DexMethodSignature, DexMethodSignature> newInstanceInitializerSignatures =
+ computeNewInstanceInitializerSignatures(
+ clazz, localReservationState, globalReservationState);
clazz
.getMethodCollection()
.replaceMethods(
method -> {
DexMethodSignature methodSignature = method.getSignature();
DexMethodSignature newMethodSignature =
- localReservationState.getNewMethodSignature(
- methodSignature, dexItemFactory, globalReservationState);
+ method.isInstanceInitializer()
+ ? newInstanceInitializerSignatures.get(methodSignature)
+ : localReservationState.getAndReserveNewMethodSignature(
+ methodSignature, dexItemFactory, globalReservationState);
if (methodSignature.equals(newMethodSignature)) {
return method;
}
@@ -170,6 +180,90 @@
}
}
+ Map<DexMethodSignature, DexMethodSignature> computeNewInstanceInitializerSignatures(
+ DexProgramClass clazz,
+ LocalReservationState localReservationState,
+ GlobalReservationState globalReservationState) {
+ // Create a map from new method signatures to old method signatures. This produces a one-to-many
+ // mapping since multiple instance initializers may normalize to the same signature.
+ Map<DexMethodSignature, DexMethodSignatureSet> instanceInitializerCollisions =
+ computeInstanceInitializerCollisions(clazz, localReservationState, globalReservationState);
+
+ // Resolve each collision to ensure that the mapping is one-to-one.
+ resolveInstanceInitializerCollisions(instanceInitializerCollisions);
+
+ // Inverse the one-to-one map to produce a mapping from old method signatures to new method
+ // signatures.
+ return MapUtils.transform(
+ instanceInitializerCollisions,
+ HashMap::new,
+ (newMethodSignature, methodSignatures) -> Iterables.getFirst(methodSignatures, null),
+ (newMethodSignature, methodSignatures) -> newMethodSignature,
+ (newMethodSignature, methodSignature, otherMethodSignature) -> {
+ throw new Unreachable();
+ });
+ }
+
+ private Map<DexMethodSignature, DexMethodSignatureSet> computeInstanceInitializerCollisions(
+ DexProgramClass clazz,
+ LocalReservationState localReservationState,
+ GlobalReservationState globalReservationState) {
+ Map<DexMethodSignature, DexMethodSignatureSet> instanceInitializerCollisions = new HashMap<>();
+ clazz.forEachProgramInstanceInitializer(
+ method -> {
+ DexMethodSignature methodSignature = method.getMethodSignature();
+ DexMethodSignature newMethodSignature =
+ localReservationState.getNewMethodSignature(
+ methodSignature, dexItemFactory, globalReservationState);
+ instanceInitializerCollisions
+ .computeIfAbsent(newMethodSignature, ignoreKey(DexMethodSignatureSet::create))
+ .add(methodSignature);
+ });
+ return instanceInitializerCollisions;
+ }
+
+ private void resolveInstanceInitializerCollisions(
+ Map<DexMethodSignature, DexMethodSignatureSet> instanceInitializerCollisions) {
+ WorkList<DexMethodSignature> worklist = WorkList.newEqualityWorkList();
+ instanceInitializerCollisions.forEach(
+ (newMethodSignature, methodSignatures) -> {
+ if (methodSignatures.size() > 1) {
+ worklist.addIfNotSeen(newMethodSignature);
+ }
+ });
+
+ while (worklist.hasNext()) {
+ DexMethodSignature newMethodSignature = worklist.removeSeen();
+ DexMethodSignatureSet methodSignatures =
+ instanceInitializerCollisions.get(newMethodSignature);
+ assert methodSignatures.size() > 1;
+
+ // Resolve this conflict in a deterministic way.
+ DexMethodSignature survivor =
+ methodSignatures.contains(newMethodSignature)
+ ? newMethodSignature
+ : IterableUtils.min(methodSignatures, DexMethodSignature::compareTo);
+
+ // Disallow optimizations of all other methods than the `survivor`.
+ for (DexMethodSignature methodSignature : methodSignatures) {
+ if (!methodSignature.equals(survivor)) {
+ DexMethodSignatureSet originalMethodSignaturesForMethodSignature =
+ instanceInitializerCollisions.computeIfAbsent(
+ methodSignature, ignoreKey(DexMethodSignatureSet::create));
+ originalMethodSignaturesForMethodSignature.add(methodSignature);
+ if (originalMethodSignaturesForMethodSignature.size() > 1) {
+ worklist.addIfNotSeen(methodSignature);
+ }
+ }
+ }
+
+ // Remove all pinned methods from the set of original method signatures stored at
+ // instanceInitializerCollisions.get(newMethodSignature).
+ methodSignatures.clear();
+ methodSignatures.add(survivor);
+ }
+ }
+
private boolean isUnoptimizable(ProgramMethod method) {
// TODO(b/195112263): This is incomplete.
return appView.getKeepInfo(method).isPinned(options)
@@ -226,11 +320,27 @@
MutableBidirectionalOneToOneMap<DexMethodSignature, DexMethodSignature> newMethodSignatures =
new BidirectionalOneToOneHashMap<>();
- // TODO: avoid sorting multiple times.
DexMethodSignature getNewMethodSignature(
DexMethodSignature methodSignature,
DexItemFactory dexItemFactory,
GlobalReservationState globalReservationState) {
+ return internalGetAndReserveNewMethodSignature(
+ methodSignature, dexItemFactory, globalReservationState, false);
+ }
+
+ DexMethodSignature getAndReserveNewMethodSignature(
+ DexMethodSignature methodSignature,
+ DexItemFactory dexItemFactory,
+ GlobalReservationState globalReservationState) {
+ return internalGetAndReserveNewMethodSignature(
+ methodSignature, dexItemFactory, globalReservationState, true);
+ }
+
+ private DexMethodSignature internalGetAndReserveNewMethodSignature(
+ DexMethodSignature methodSignature,
+ DexItemFactory dexItemFactory,
+ GlobalReservationState globalReservationState,
+ boolean reserve) {
if (globalReservationState.isUnoptimizable(methodSignature)) {
assert !newMethodSignatures.containsKey(methodSignature);
return methodSignature;
@@ -254,7 +364,9 @@
newMethodSignature = newMethodSignature.withName(newMethodName);
} while (newMethodSignatures.containsValue(newMethodSignature));
}
- newMethodSignatures.put(methodSignature, newMethodSignature);
+ if (reserve) {
+ newMethodSignatures.put(methodSignature, newMethodSignature);
+ }
return newMethodSignature;
}
}
diff --git a/src/main/java/com/android/tools/r8/utils/MapUtils.java b/src/main/java/com/android/tools/r8/utils/MapUtils.java
index a29af0b..2f55680 100644
--- a/src/main/java/com/android/tools/r8/utils/MapUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/MapUtils.java
@@ -9,6 +9,7 @@
import it.unimi.dsi.fastutil.ints.Int2ReferenceMaps;
import java.util.IdentityHashMap;
import java.util.Map;
+import java.util.function.BiFunction;
import java.util.function.BiPredicate;
import java.util.function.Function;
import java.util.function.IntFunction;
@@ -63,14 +64,28 @@
Function<K1, K2> keyMapping,
Function<V1, V2> valueMapping,
TriFunction<K2, V2, V2, V2> valueMerger) {
+ return transform(
+ map,
+ factory,
+ (key, value) -> keyMapping.apply(key),
+ (key, value) -> valueMapping.apply(value),
+ valueMerger);
+ }
+
+ public static <K1, V1, K2, V2> Map<K2, V2> transform(
+ Map<K1, V1> map,
+ IntFunction<Map<K2, V2>> factory,
+ BiFunction<K1, V1, K2> keyMapping,
+ BiFunction<K1, V1, V2> valueMapping,
+ TriFunction<K2, V2, V2, V2> valueMerger) {
Map<K2, V2> result = factory.apply(map.size());
map.forEach(
(key, value) -> {
- K2 newKey = keyMapping.apply(key);
+ K2 newKey = keyMapping.apply(key, value);
if (newKey == null) {
return;
}
- V2 newValue = valueMapping.apply(value);
+ V2 newValue = valueMapping.apply(key, value);
V2 existingValue = result.put(newKey, newValue);
if (existingValue != null) {
result.put(newKey, valueMerger.apply(newKey, existingValue, newValue));
diff --git a/src/test/java/com/android/tools/r8/graph/genericsignature/GenericSignaturePartialTypeArgumentApplierTest.java b/src/test/java/com/android/tools/r8/graph/genericsignature/GenericSignaturePartialTypeArgumentApplierTest.java
index da89df6..4b04321 100644
--- a/src/test/java/com/android/tools/r8/graph/genericsignature/GenericSignaturePartialTypeArgumentApplierTest.java
+++ b/src/test/java/com/android/tools/r8/graph/genericsignature/GenericSignaturePartialTypeArgumentApplierTest.java
@@ -33,6 +33,7 @@
import java.util.Map;
import java.util.Set;
import java.util.function.BiPredicate;
+import java.util.function.Function;
import java.util.function.Predicate;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -121,7 +122,7 @@
MapUtils.transform(
substitutions,
HashMap::new,
- s -> s,
+ Function.identity(),
ClassTypeSignature::new,
(key, val1, val2) -> {
throw new Unreachable("No keys should be merged");
diff --git a/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationWithInstanceInitializerCollisionTest.java b/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationWithInstanceInitializerCollisionTest.java
new file mode 100644
index 0000000..66a4cdc
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/optimize/proto/ProtoNormalizationWithInstanceInitializerCollisionTest.java
@@ -0,0 +1,107 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.optimize.proto;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import com.android.tools.r8.NoHorizontalClassMerging;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import com.android.tools.r8.utils.codeinspector.TypeSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class ProtoNormalizationWithInstanceInitializerCollisionTest extends TestBase {
+
+ @Parameter(0)
+ public TestParameters parameters;
+
+ @Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withAllRuntimesAndApiLevels().build();
+ }
+
+ @Test
+ public void test() throws Exception {
+ testForR8(parameters.getBackend())
+ .addInnerClasses(getClass())
+ .addKeepMainRule(Main.class)
+ .addOptionsModification(
+ options -> options.testing.enableExperimentalProtoNormalization = true)
+ .enableNoHorizontalClassMergingAnnotations()
+ .setMinApi(parameters.getApiLevel())
+ .compile()
+ .inspect(
+ inspector -> {
+ ClassSubject mainClassSubject = inspector.clazz(Main.class);
+ assertThat(mainClassSubject, isPresent());
+
+ ClassSubject aClassSubject = inspector.clazz(A.class);
+ assertThat(aClassSubject, isPresent());
+
+ ClassSubject bClassSubject = inspector.clazz(B.class);
+ assertThat(bClassSubject, isPresent());
+
+ TypeSubject aTypeSubject = aClassSubject.asTypeSubject();
+ TypeSubject bTypeSubject = bClassSubject.asTypeSubject();
+
+ // Main.<init>(A, B) is unchanged.
+ MethodSubject initMethodSubject =
+ mainClassSubject.initFromTypes(aTypeSubject, bTypeSubject);
+ assertThat(initMethodSubject, isPresent());
+
+ // Main.<init>(B, A) is unchanged.
+ MethodSubject otherInitMethodSubject =
+ mainClassSubject.initFromTypes(bTypeSubject, aTypeSubject);
+ assertThat(otherInitMethodSubject, isPresent());
+ })
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines("A", "B", "A", "B");
+ }
+
+ static class Main {
+
+ public static void main(String[] args) {
+ new Main(new A(), new B());
+ new Main(new B(), new A());
+ }
+
+ Main(A a, B b) {
+ System.out.println(a);
+ System.out.println(b);
+ }
+
+ Main(B b, A a) {
+ System.out.println(a);
+ System.out.println(b);
+ }
+ }
+
+ @NoHorizontalClassMerging
+ static class A {
+
+ @Override
+ public String toString() {
+ return "A";
+ }
+ }
+
+ @NoHorizontalClassMerging
+ static class B {
+
+ @Override
+ public String toString() {
+ return "B";
+ }
+ }
+}
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/ClassSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/ClassSubject.java
index 30b3fb0..d604156 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/ClassSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/ClassSubject.java
@@ -26,6 +26,7 @@
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Predicate;
+import java.util.stream.Collectors;
import kotlinx.metadata.jvm.KotlinClassMetadata;
import org.junit.rules.TemporaryFolder;
@@ -128,6 +129,14 @@
return init(Arrays.asList(parameters));
}
+ public MethodSubject initFromTypes(List<TypeSubject> parameters) {
+ return init(parameters.stream().map(TypeSubject::getTypeName).collect(Collectors.toList()));
+ }
+
+ public MethodSubject initFromTypes(TypeSubject... parameters) {
+ return initFromTypes(Arrays.asList(parameters));
+ }
+
public MethodSubject method(MethodSignature signature) {
return method(signature.type, signature.name, ImmutableList.copyOf(signature.parameters));
}