Extend profile flag propagation to cycles

Fixes: b/278826733
Change-Id: Icde823a7f84bb9f47c7008decdd3fe3c3e554c91
diff --git a/src/main/java/com/android/tools/r8/profile/AbstractProfileMethodRule.java b/src/main/java/com/android/tools/r8/profile/AbstractProfileMethodRule.java
index f39ea03..861083c 100644
--- a/src/main/java/com/android/tools/r8/profile/AbstractProfileMethodRule.java
+++ b/src/main/java/com/android/tools/r8/profile/AbstractProfileMethodRule.java
@@ -20,6 +20,8 @@
 
     MethodRuleBuilder join(MethodRuleBuilder methodRuleBuilder);
 
+    MethodRuleBuilder join(MethodRuleBuilder methodRuleBuilder, Runnable onChangedHandler);
+
     MethodRuleBuilder setIsStartup();
 
     MethodRuleBuilder setMethod(DexMethod method);
diff --git a/src/main/java/com/android/tools/r8/profile/art/ArtProfileMethodRule.java b/src/main/java/com/android/tools/r8/profile/art/ArtProfileMethodRule.java
index 5f4defc..f3032b2 100644
--- a/src/main/java/com/android/tools/r8/profile/art/ArtProfileMethodRule.java
+++ b/src/main/java/com/android/tools/r8/profile/art/ArtProfileMethodRule.java
@@ -136,6 +136,16 @@
     }
 
     @Override
+    public Builder join(Builder builder, Runnable onChangedHandler) {
+      int oldFlags = methodRuleInfoBuilder.getFlags();
+      join(builder);
+      if (methodRuleInfoBuilder.getFlags() != oldFlags) {
+        onChangedHandler.run();
+      }
+      return this;
+    }
+
+    @Override
     public Builder join(ArtProfileMethodRule methodRule) {
       methodRuleInfoBuilder.joinFlags(methodRule.getMethodRuleInfo());
       return this;
diff --git a/src/main/java/com/android/tools/r8/profile/rewriting/ProfileAdditions.java b/src/main/java/com/android/tools/r8/profile/rewriting/ProfileAdditions.java
index bf4b3eb..c21af30 100644
--- a/src/main/java/com/android/tools/r8/profile/rewriting/ProfileAdditions.java
+++ b/src/main/java/com/android/tools/r8/profile/rewriting/ProfileAdditions.java
@@ -17,7 +17,6 @@
 import com.android.tools.r8.profile.AbstractProfileClassRule;
 import com.android.tools.r8.profile.AbstractProfileMethodRule;
 import com.android.tools.r8.profile.AbstractProfileRule;
-import com.android.tools.r8.utils.SetUtils;
 import com.android.tools.r8.utils.WorkList;
 import com.google.common.collect.Sets;
 import java.util.ArrayList;
@@ -29,7 +28,6 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Consumer;
 import java.util.function.Function;
-import java.util.stream.Collectors;
 
 /** Mutable extension of an existing profile. */
 public abstract class ProfileAdditions<
@@ -84,8 +82,8 @@
   final Map<DexMethod, MethodRuleBuilder> methodRuleAdditions = new ConcurrentHashMap<>();
   private final Set<DexMethod> methodRuleRemovals = Sets.newConcurrentHashSet();
 
-  private final NestedMethodRuleAdditionsGraph nestedMethodRuleAdditionsGraph =
-      new NestedMethodRuleAdditionsGraph();
+  private final NestedMethodRuleAdditionsGraph<MethodRule, MethodRuleBuilder>
+      nestedMethodRuleAdditionsGraph = new NestedMethodRuleAdditionsGraph<>();
 
   protected ProfileAdditions(Profile profile) {
     this.profile = profile;
@@ -214,6 +212,11 @@
       return profile;
     }
 
+    // Assert that there are no cycles in the propagation graph. If there are any cycles, this
+    // likely means that we have mutually recursive synthetics, which could be unintentional.
+    // Note that this algorithm correctly deals with cycles, and thus this assertion can simply be
+    // disabled to allow cycles.
+    assert nestedMethodRuleAdditionsGraph.verifyNoCycles();
     nestedMethodRuleAdditionsGraph.propagateMethodRuleInfoFlags(methodRuleAdditions);
 
     // Add existing rules to new profile.
@@ -289,12 +292,14 @@
     this.profile = profile;
   }
 
-  private class NestedMethodRuleAdditionsGraph {
+  public static class NestedMethodRuleAdditionsGraph<
+      MethodRule extends AbstractProfileMethodRule,
+      MethodRuleBuilder extends AbstractProfileMethodRule.Builder<MethodRule, MethodRuleBuilder>> {
 
     private final Map<DexMethod, Set<DexMethod>> successors = new ConcurrentHashMap<>();
     private final Map<DexMethod, Set<DexMethod>> predecessors = new ConcurrentHashMap<>();
 
-    void recordMethodRuleInfoFlagsLargerThan(DexMethod largerFlags, DexMethod smallerFlags) {
+    public void recordMethodRuleInfoFlagsLargerThan(DexMethod largerFlags, DexMethod smallerFlags) {
       predecessors
           .computeIfAbsent(largerFlags, ignoreKey(Sets::newConcurrentHashSet))
           .add(smallerFlags);
@@ -303,59 +308,58 @@
           .add(largerFlags);
     }
 
-    void propagateMethodRuleInfoFlags(Map<DexMethod, MethodRuleBuilder> methodRuleAdditions) {
-      List<DexMethod> leaves =
-          successors.keySet().stream()
-              .filter(method -> predecessors.getOrDefault(method, Collections.emptySet()).isEmpty())
-              .collect(Collectors.toList());
-      WorkList<DexMethod> worklist = WorkList.newIdentityWorkList(leaves);
-      while (worklist.hasNext()) {
-        DexMethod method = worklist.next();
-        MethodRuleBuilder methodRuleBuilder = methodRuleAdditions.get(method);
-        for (DexMethod successor : successors.getOrDefault(method, Collections.emptySet())) {
-          MethodRuleBuilder successorMethodRuleBuilder = methodRuleAdditions.get(successor);
-          // If this assertion fails, that means we have synthetics with multiple
-          // synthesizing contexts, which are not guaranteed to be processed before the
-          // synthetic itself. In that case this assertion should simply be removed.
-          assert successorMethodRuleBuilder.isGreaterThanOrEqualTo(methodRuleBuilder)
-              : getGraphString(methodRuleAdditions, method, successor);
-          successorMethodRuleBuilder.join(methodRuleBuilder);
-          // Note: no need to addIgnoringSeenSet() since the graph will not have cycles. Indeed, it
-          // should never be the case that a method m2(), which is synthesized from method context
-          // m1(), would itself be a synthesizing context for m1().
-          worklist.addIfNotSeen(successor);
-        }
-      }
+    public void propagateMethodRuleInfoFlags(
+        Map<DexMethod, MethodRuleBuilder> methodRuleAdditions) {
+      WorkList.newIdentityWorkList(successors.keySet())
+          .process(
+              (method, worklist) -> {
+                MethodRuleBuilder methodRuleBuilder = methodRuleAdditions.get(method);
+                for (DexMethod successor :
+                    successors.getOrDefault(method, Collections.emptySet())) {
+                  MethodRuleBuilder successorMethodRuleBuilder = methodRuleAdditions.get(successor);
+                  successorMethodRuleBuilder.join(
+                      methodRuleBuilder,
+                      // If the successor's flags changed, then reprocess the successor to propagate
+                      // its flags to the successors of the successor.
+                      () -> worklist.addIgnoringSeenSet(successor));
+                }
+              });
     }
 
-    // Return a string representation of the graph for diagnosing b/278524993.
-    private String getGraphString(
-        Map<DexMethod, MethodRuleBuilder> methodRuleAdditions,
-        DexMethod context,
-        DexMethod method) {
-      StringBuilder builder =
-          new StringBuilder("Error at edge: ")
-              .append(context.toSourceString())
-              .append(" -> ")
-              .append(method.toSourceString());
-      Set<DexMethod> nodes =
-          SetUtils.unionIdentityHashSet(predecessors.keySet(), successors.keySet());
-      for (DexMethod node : nodes) {
-        builder
-            .append(System.lineSeparator())
-            .append(System.lineSeparator())
-            .append(node.toSourceString());
-        for (DexMethod predecessor : predecessors.getOrDefault(node, Collections.emptySet())) {
-          builder
-              .append(System.lineSeparator())
-              .append("  <- ")
-              .append(predecessor.toSourceString());
-        }
-        for (DexMethod successor : successors.getOrDefault(node, Collections.emptySet())) {
-          builder.append(System.lineSeparator()).append("  -> ").append(successor.toSourceString());
+    public boolean verifyNoCycles() {
+      Set<DexMethod> seen = Sets.newIdentityHashSet();
+      for (DexMethod method : successors.keySet()) {
+        if (seen.add(method)) {
+          seen.addAll(verifyNoCyclesStartingFrom(method));
         }
       }
-      return builder.toString();
+      return true;
+    }
+
+    public Set<DexMethod> verifyNoCyclesStartingFrom(DexMethod root) {
+      Set<DexMethod> seen = Sets.newIdentityHashSet();
+      Set<DexMethod> stack = Sets.newIdentityHashSet();
+      WorkList<DexMethod> worklist = WorkList.newIdentityWorkList(root);
+      worklist.process(
+          current -> {
+            if (seen.add(current)) {
+              // Seen for the first time, append to stack and continue the search for a cycle from
+              // the successors.
+              stack.add(current);
+              worklist.addFirstIgnoringSeenSet(current);
+              for (DexMethod successor : successors.getOrDefault(current, Collections.emptySet())) {
+                assert !stack.contains(successor) : "Found a cycle";
+                worklist.addFirstIfNotSeen(successor);
+              }
+            } else {
+              // Backtracking, remove current method from stack since we are done exploring the
+              // (transitive) successors.
+              boolean removed = stack.remove(current);
+              assert removed;
+            }
+          });
+      assert stack.isEmpty();
+      return worklist.getSeenSet();
     }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/profile/startup/profile/StartupProfileMethodRule.java b/src/main/java/com/android/tools/r8/profile/startup/profile/StartupProfileMethodRule.java
index 9c2afae..bc5e098 100644
--- a/src/main/java/com/android/tools/r8/profile/startup/profile/StartupProfileMethodRule.java
+++ b/src/main/java/com/android/tools/r8/profile/startup/profile/StartupProfileMethodRule.java
@@ -100,6 +100,11 @@
     }
 
     @Override
+    public Builder join(Builder builder, Runnable onChangedHandler) {
+      return this;
+    }
+
+    @Override
     public Builder join(StartupProfileMethodRule methodRule) {
       return this;
     }
diff --git a/src/test/java/com/android/tools/r8/profile/art/flagpropagation/CyclicFlagPropagationTest.java b/src/test/java/com/android/tools/r8/profile/art/flagpropagation/CyclicFlagPropagationTest.java
new file mode 100644
index 0000000..7c4fba3
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/profile/art/flagpropagation/CyclicFlagPropagationTest.java
@@ -0,0 +1,86 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.profile.art.flagpropagation;
+
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexItemFactory.ObjectMembers;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.profile.art.ArtProfileMethodRule;
+import com.android.tools.r8.profile.art.ArtProfileMethodRuleInfoImpl.Builder;
+import com.android.tools.r8.profile.rewriting.ProfileAdditions.NestedMethodRuleAdditionsGraph;
+import com.google.common.collect.ImmutableList;
+import java.util.IdentityHashMap;
+import java.util.List;
+import java.util.Map;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class CyclicFlagPropagationTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withNoneRuntime().build();
+  }
+
+  @Test
+  public void test() throws Exception {
+    NestedMethodRuleAdditionsGraph<ArtProfileMethodRule, ArtProfileMethodRule.Builder> graph =
+        new NestedMethodRuleAdditionsGraph<>();
+    ObjectMembers objectMembers = new DexItemFactory().objectMembers;
+
+    // Add an edge from equals() -> hashCode().
+    graph.recordMethodRuleInfoFlagsLargerThan(objectMembers.hashCode, objectMembers.equals);
+
+    // Add an edge from hashCode() -> finalize().
+    graph.recordMethodRuleInfoFlagsLargerThan(objectMembers.finalize, objectMembers.hashCode);
+
+    // Add an edge from hashCode() -> toString().
+    graph.recordMethodRuleInfoFlagsLargerThan(objectMembers.toString, objectMembers.hashCode);
+
+    // Add an edge from toString() -> equals().
+    graph.recordMethodRuleInfoFlagsLargerThan(objectMembers.equals, objectMembers.toString);
+
+    // Verify we detect the cycle.
+    assertThrows(AssertionError.class, graph::verifyNoCycles);
+
+    // Verify we detect the cycle when starting from equals(), hashCode(), and toString().
+    assertThrows(
+        AssertionError.class, () -> graph.verifyNoCyclesStartingFrom(objectMembers.equals));
+    assertThrows(
+        AssertionError.class, () -> graph.verifyNoCyclesStartingFrom(objectMembers.hashCode));
+    assertThrows(
+        AssertionError.class, () -> graph.verifyNoCyclesStartingFrom(objectMembers.toString));
+
+    // Verify that flag propagation works.
+    List<DexMethod> methods =
+        ImmutableList.of(
+            objectMembers.equals,
+            objectMembers.finalize,
+            objectMembers.hashCode,
+            objectMembers.toString);
+    Map<DexMethod, ArtProfileMethodRule.Builder> methodRuleBuilders = new IdentityHashMap<>();
+    for (DexMethod method : methods) {
+      methodRuleBuilders.put(method, ArtProfileMethodRule.builder());
+    }
+    methodRuleBuilders.get(objectMembers.equals).acceptMethodRuleInfoBuilder(Builder::setIsStartup);
+    graph.propagateMethodRuleInfoFlags(methodRuleBuilders);
+    for (DexMethod method : methods) {
+      assertTrue(methodRuleBuilders.get(method).build().getMethodRuleInfo().isStartup());
+    }
+  }
+}