Extend profile flag propagation to cycles
Fixes: b/278826733
Change-Id: Icde823a7f84bb9f47c7008decdd3fe3c3e554c91
diff --git a/src/main/java/com/android/tools/r8/profile/AbstractProfileMethodRule.java b/src/main/java/com/android/tools/r8/profile/AbstractProfileMethodRule.java
index f39ea03..861083c 100644
--- a/src/main/java/com/android/tools/r8/profile/AbstractProfileMethodRule.java
+++ b/src/main/java/com/android/tools/r8/profile/AbstractProfileMethodRule.java
@@ -20,6 +20,8 @@
MethodRuleBuilder join(MethodRuleBuilder methodRuleBuilder);
+ MethodRuleBuilder join(MethodRuleBuilder methodRuleBuilder, Runnable onChangedHandler);
+
MethodRuleBuilder setIsStartup();
MethodRuleBuilder setMethod(DexMethod method);
diff --git a/src/main/java/com/android/tools/r8/profile/art/ArtProfileMethodRule.java b/src/main/java/com/android/tools/r8/profile/art/ArtProfileMethodRule.java
index 5f4defc..f3032b2 100644
--- a/src/main/java/com/android/tools/r8/profile/art/ArtProfileMethodRule.java
+++ b/src/main/java/com/android/tools/r8/profile/art/ArtProfileMethodRule.java
@@ -136,6 +136,16 @@
}
@Override
+ public Builder join(Builder builder, Runnable onChangedHandler) {
+ int oldFlags = methodRuleInfoBuilder.getFlags();
+ join(builder);
+ if (methodRuleInfoBuilder.getFlags() != oldFlags) {
+ onChangedHandler.run();
+ }
+ return this;
+ }
+
+ @Override
public Builder join(ArtProfileMethodRule methodRule) {
methodRuleInfoBuilder.joinFlags(methodRule.getMethodRuleInfo());
return this;
diff --git a/src/main/java/com/android/tools/r8/profile/rewriting/ProfileAdditions.java b/src/main/java/com/android/tools/r8/profile/rewriting/ProfileAdditions.java
index bf4b3eb..c21af30 100644
--- a/src/main/java/com/android/tools/r8/profile/rewriting/ProfileAdditions.java
+++ b/src/main/java/com/android/tools/r8/profile/rewriting/ProfileAdditions.java
@@ -17,7 +17,6 @@
import com.android.tools.r8.profile.AbstractProfileClassRule;
import com.android.tools.r8.profile.AbstractProfileMethodRule;
import com.android.tools.r8.profile.AbstractProfileRule;
-import com.android.tools.r8.utils.SetUtils;
import com.android.tools.r8.utils.WorkList;
import com.google.common.collect.Sets;
import java.util.ArrayList;
@@ -29,7 +28,6 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Function;
-import java.util.stream.Collectors;
/** Mutable extension of an existing profile. */
public abstract class ProfileAdditions<
@@ -84,8 +82,8 @@
final Map<DexMethod, MethodRuleBuilder> methodRuleAdditions = new ConcurrentHashMap<>();
private final Set<DexMethod> methodRuleRemovals = Sets.newConcurrentHashSet();
- private final NestedMethodRuleAdditionsGraph nestedMethodRuleAdditionsGraph =
- new NestedMethodRuleAdditionsGraph();
+ private final NestedMethodRuleAdditionsGraph<MethodRule, MethodRuleBuilder>
+ nestedMethodRuleAdditionsGraph = new NestedMethodRuleAdditionsGraph<>();
protected ProfileAdditions(Profile profile) {
this.profile = profile;
@@ -214,6 +212,11 @@
return profile;
}
+ // Assert that there are no cycles in the propagation graph. If there are any cycles, this
+ // likely means that we have mutually recursive synthetics, which could be unintentional.
+ // Note that this algorithm correctly deals with cycles, and thus this assertion can simply be
+ // disabled to allow cycles.
+ assert nestedMethodRuleAdditionsGraph.verifyNoCycles();
nestedMethodRuleAdditionsGraph.propagateMethodRuleInfoFlags(methodRuleAdditions);
// Add existing rules to new profile.
@@ -289,12 +292,14 @@
this.profile = profile;
}
- private class NestedMethodRuleAdditionsGraph {
+ public static class NestedMethodRuleAdditionsGraph<
+ MethodRule extends AbstractProfileMethodRule,
+ MethodRuleBuilder extends AbstractProfileMethodRule.Builder<MethodRule, MethodRuleBuilder>> {
private final Map<DexMethod, Set<DexMethod>> successors = new ConcurrentHashMap<>();
private final Map<DexMethod, Set<DexMethod>> predecessors = new ConcurrentHashMap<>();
- void recordMethodRuleInfoFlagsLargerThan(DexMethod largerFlags, DexMethod smallerFlags) {
+ public void recordMethodRuleInfoFlagsLargerThan(DexMethod largerFlags, DexMethod smallerFlags) {
predecessors
.computeIfAbsent(largerFlags, ignoreKey(Sets::newConcurrentHashSet))
.add(smallerFlags);
@@ -303,59 +308,58 @@
.add(largerFlags);
}
- void propagateMethodRuleInfoFlags(Map<DexMethod, MethodRuleBuilder> methodRuleAdditions) {
- List<DexMethod> leaves =
- successors.keySet().stream()
- .filter(method -> predecessors.getOrDefault(method, Collections.emptySet()).isEmpty())
- .collect(Collectors.toList());
- WorkList<DexMethod> worklist = WorkList.newIdentityWorkList(leaves);
- while (worklist.hasNext()) {
- DexMethod method = worklist.next();
- MethodRuleBuilder methodRuleBuilder = methodRuleAdditions.get(method);
- for (DexMethod successor : successors.getOrDefault(method, Collections.emptySet())) {
- MethodRuleBuilder successorMethodRuleBuilder = methodRuleAdditions.get(successor);
- // If this assertion fails, that means we have synthetics with multiple
- // synthesizing contexts, which are not guaranteed to be processed before the
- // synthetic itself. In that case this assertion should simply be removed.
- assert successorMethodRuleBuilder.isGreaterThanOrEqualTo(methodRuleBuilder)
- : getGraphString(methodRuleAdditions, method, successor);
- successorMethodRuleBuilder.join(methodRuleBuilder);
- // Note: no need to addIgnoringSeenSet() since the graph will not have cycles. Indeed, it
- // should never be the case that a method m2(), which is synthesized from method context
- // m1(), would itself be a synthesizing context for m1().
- worklist.addIfNotSeen(successor);
- }
- }
+ public void propagateMethodRuleInfoFlags(
+ Map<DexMethod, MethodRuleBuilder> methodRuleAdditions) {
+ WorkList.newIdentityWorkList(successors.keySet())
+ .process(
+ (method, worklist) -> {
+ MethodRuleBuilder methodRuleBuilder = methodRuleAdditions.get(method);
+ for (DexMethod successor :
+ successors.getOrDefault(method, Collections.emptySet())) {
+ MethodRuleBuilder successorMethodRuleBuilder = methodRuleAdditions.get(successor);
+ successorMethodRuleBuilder.join(
+ methodRuleBuilder,
+ // If the successor's flags changed, then reprocess the successor to propagate
+ // its flags to the successors of the successor.
+ () -> worklist.addIgnoringSeenSet(successor));
+ }
+ });
}
- // Return a string representation of the graph for diagnosing b/278524993.
- private String getGraphString(
- Map<DexMethod, MethodRuleBuilder> methodRuleAdditions,
- DexMethod context,
- DexMethod method) {
- StringBuilder builder =
- new StringBuilder("Error at edge: ")
- .append(context.toSourceString())
- .append(" -> ")
- .append(method.toSourceString());
- Set<DexMethod> nodes =
- SetUtils.unionIdentityHashSet(predecessors.keySet(), successors.keySet());
- for (DexMethod node : nodes) {
- builder
- .append(System.lineSeparator())
- .append(System.lineSeparator())
- .append(node.toSourceString());
- for (DexMethod predecessor : predecessors.getOrDefault(node, Collections.emptySet())) {
- builder
- .append(System.lineSeparator())
- .append(" <- ")
- .append(predecessor.toSourceString());
- }
- for (DexMethod successor : successors.getOrDefault(node, Collections.emptySet())) {
- builder.append(System.lineSeparator()).append(" -> ").append(successor.toSourceString());
+ public boolean verifyNoCycles() {
+ Set<DexMethod> seen = Sets.newIdentityHashSet();
+ for (DexMethod method : successors.keySet()) {
+ if (seen.add(method)) {
+ seen.addAll(verifyNoCyclesStartingFrom(method));
}
}
- return builder.toString();
+ return true;
+ }
+
+ public Set<DexMethod> verifyNoCyclesStartingFrom(DexMethod root) {
+ Set<DexMethod> seen = Sets.newIdentityHashSet();
+ Set<DexMethod> stack = Sets.newIdentityHashSet();
+ WorkList<DexMethod> worklist = WorkList.newIdentityWorkList(root);
+ worklist.process(
+ current -> {
+ if (seen.add(current)) {
+ // Seen for the first time, append to stack and continue the search for a cycle from
+ // the successors.
+ stack.add(current);
+ worklist.addFirstIgnoringSeenSet(current);
+ for (DexMethod successor : successors.getOrDefault(current, Collections.emptySet())) {
+ assert !stack.contains(successor) : "Found a cycle";
+ worklist.addFirstIfNotSeen(successor);
+ }
+ } else {
+ // Backtracking, remove current method from stack since we are done exploring the
+ // (transitive) successors.
+ boolean removed = stack.remove(current);
+ assert removed;
+ }
+ });
+ assert stack.isEmpty();
+ return worklist.getSeenSet();
}
}
}
diff --git a/src/main/java/com/android/tools/r8/profile/startup/profile/StartupProfileMethodRule.java b/src/main/java/com/android/tools/r8/profile/startup/profile/StartupProfileMethodRule.java
index 9c2afae..bc5e098 100644
--- a/src/main/java/com/android/tools/r8/profile/startup/profile/StartupProfileMethodRule.java
+++ b/src/main/java/com/android/tools/r8/profile/startup/profile/StartupProfileMethodRule.java
@@ -100,6 +100,11 @@
}
@Override
+ public Builder join(Builder builder, Runnable onChangedHandler) {
+ return this;
+ }
+
+ @Override
public Builder join(StartupProfileMethodRule methodRule) {
return this;
}
diff --git a/src/test/java/com/android/tools/r8/profile/art/flagpropagation/CyclicFlagPropagationTest.java b/src/test/java/com/android/tools/r8/profile/art/flagpropagation/CyclicFlagPropagationTest.java
new file mode 100644
index 0000000..7c4fba3
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/profile/art/flagpropagation/CyclicFlagPropagationTest.java
@@ -0,0 +1,86 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.profile.art.flagpropagation;
+
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexItemFactory.ObjectMembers;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.profile.art.ArtProfileMethodRule;
+import com.android.tools.r8.profile.art.ArtProfileMethodRuleInfoImpl.Builder;
+import com.android.tools.r8.profile.rewriting.ProfileAdditions.NestedMethodRuleAdditionsGraph;
+import com.google.common.collect.ImmutableList;
+import java.util.IdentityHashMap;
+import java.util.List;
+import java.util.Map;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class CyclicFlagPropagationTest extends TestBase {
+
+ @Parameter(0)
+ public TestParameters parameters;
+
+ @Parameters(name = "{0}")
+ public static TestParametersCollection data() {
+ return getTestParameters().withNoneRuntime().build();
+ }
+
+ @Test
+ public void test() throws Exception {
+ NestedMethodRuleAdditionsGraph<ArtProfileMethodRule, ArtProfileMethodRule.Builder> graph =
+ new NestedMethodRuleAdditionsGraph<>();
+ ObjectMembers objectMembers = new DexItemFactory().objectMembers;
+
+ // Add an edge from equals() -> hashCode().
+ graph.recordMethodRuleInfoFlagsLargerThan(objectMembers.hashCode, objectMembers.equals);
+
+ // Add an edge from hashCode() -> finalize().
+ graph.recordMethodRuleInfoFlagsLargerThan(objectMembers.finalize, objectMembers.hashCode);
+
+ // Add an edge from hashCode() -> toString().
+ graph.recordMethodRuleInfoFlagsLargerThan(objectMembers.toString, objectMembers.hashCode);
+
+ // Add an edge from toString() -> equals().
+ graph.recordMethodRuleInfoFlagsLargerThan(objectMembers.equals, objectMembers.toString);
+
+ // Verify we detect the cycle.
+ assertThrows(AssertionError.class, graph::verifyNoCycles);
+
+ // Verify we detect the cycle when starting from equals(), hashCode(), and toString().
+ assertThrows(
+ AssertionError.class, () -> graph.verifyNoCyclesStartingFrom(objectMembers.equals));
+ assertThrows(
+ AssertionError.class, () -> graph.verifyNoCyclesStartingFrom(objectMembers.hashCode));
+ assertThrows(
+ AssertionError.class, () -> graph.verifyNoCyclesStartingFrom(objectMembers.toString));
+
+ // Verify that flag propagation works.
+ List<DexMethod> methods =
+ ImmutableList.of(
+ objectMembers.equals,
+ objectMembers.finalize,
+ objectMembers.hashCode,
+ objectMembers.toString);
+ Map<DexMethod, ArtProfileMethodRule.Builder> methodRuleBuilders = new IdentityHashMap<>();
+ for (DexMethod method : methods) {
+ methodRuleBuilders.put(method, ArtProfileMethodRule.builder());
+ }
+ methodRuleBuilders.get(objectMembers.equals).acceptMethodRuleInfoBuilder(Builder::setIsStartup);
+ graph.propagateMethodRuleInfoFlags(methodRuleBuilders);
+ for (DexMethod method : methods) {
+ assertTrue(methodRuleBuilders.get(method).build().getMethodRuleInfo().isStartup());
+ }
+ }
+}