Break from use registry in argument propagation

Fixes: 150583533
Change-Id: I4ff35745829e7cb2200daea0509de761de271e83
diff --git a/src/main/java/com/android/tools/r8/graph/CfCode.java b/src/main/java/com/android/tools/r8/graph/CfCode.java
index 0a7c0e5..fd550c1 100644
--- a/src/main/java/com/android/tools/r8/graph/CfCode.java
+++ b/src/main/java/com/android/tools/r8/graph/CfCode.java
@@ -551,12 +551,21 @@
 
   @Override
   public void registerCodeReferences(ProgramMethod method, UseRegistry registry) {
+    assert registry.getTraversalContinuation().shouldContinue();
     ListIterator<CfInstruction> iterator = instructions.listIterator();
     while (iterator.hasNext()) {
       CfInstruction instruction = iterator.next();
       instruction.registerUse(registry, method, iterator);
+      if (registry.getTraversalContinuation().shouldBreak()) {
+        return;
+      }
     }
-    tryCatchRanges.forEach(tryCatch -> tryCatch.internalRegisterUse(registry, method));
+    for (CfTryCatch tryCatch : tryCatchRanges) {
+      tryCatch.internalRegisterUse(registry, method);
+      if (registry.getTraversalContinuation().shouldBreak()) {
+        return;
+      }
+    }
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/graph/DexCode.java b/src/main/java/com/android/tools/r8/graph/DexCode.java
index 589ae3a..2c36882 100644
--- a/src/main/java/com/android/tools/r8/graph/DexCode.java
+++ b/src/main/java/com/android/tools/r8/graph/DexCode.java
@@ -301,12 +301,19 @@
   }
 
   private void internalRegisterCodeReferences(DexClassAndMethod method, UseRegistry registry) {
+    assert registry.getTraversalContinuation().shouldContinue();
     for (Instruction insn : instructions) {
       insn.registerUse(registry);
+      if (registry.getTraversalContinuation().shouldBreak()) {
+        return;
+      }
     }
     for (TryHandler handler : handlers) {
       for (TypeAddrPair pair : handler.pairs) {
         registry.registerExceptionGuard(pair.type);
+        if (registry.getTraversalContinuation().shouldBreak()) {
+          return;
+        }
       }
     }
   }
diff --git a/src/main/java/com/android/tools/r8/graph/ProgramMethod.java b/src/main/java/com/android/tools/r8/graph/ProgramMethod.java
index 5d5a755..4d1790e 100644
--- a/src/main/java/com/android/tools/r8/graph/ProgramMethod.java
+++ b/src/main/java/com/android/tools/r8/graph/ProgramMethod.java
@@ -62,6 +62,11 @@
     }
   }
 
+  public <T> T registerCodeReferencesWithResult(UseRegistryWithResult<T> registry) {
+    registerCodeReferences(registry);
+    return registry.getResult();
+  }
+
   @Override
   public ProgramMethod getContext() {
     return this;
diff --git a/src/main/java/com/android/tools/r8/graph/UseRegistry.java b/src/main/java/com/android/tools/r8/graph/UseRegistry.java
index c9c0bc9..3913383 100644
--- a/src/main/java/com/android/tools/r8/graph/UseRegistry.java
+++ b/src/main/java/com/android/tools/r8/graph/UseRegistry.java
@@ -4,11 +4,13 @@
 package com.android.tools.r8.graph;
 
 import com.android.tools.r8.code.CfOrDexInstruction;
+import com.android.tools.r8.utils.TraversalContinuation;
 import java.util.ListIterator;
 
 public abstract class UseRegistry {
 
   private DexItemFactory factory;
+  private TraversalContinuation continuation = TraversalContinuation.CONTINUE;
 
   public enum MethodHandleUse {
     ARGUMENT_TO_LAMBDA_METAFACTORY,
@@ -23,6 +25,15 @@
     method.registerCodeReferences(this);
   }
 
+  public void doBreak() {
+    assert continuation.shouldContinue();
+    continuation = TraversalContinuation.BREAK;
+  }
+
+  public TraversalContinuation getTraversalContinuation() {
+    return continuation;
+  }
+
   public abstract void registerInitClass(DexType type);
 
   public abstract void registerInvokeVirtual(DexMethod method);
diff --git a/src/main/java/com/android/tools/r8/graph/UseRegistryWithResult.java b/src/main/java/com/android/tools/r8/graph/UseRegistryWithResult.java
new file mode 100644
index 0000000..916a2f9
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/graph/UseRegistryWithResult.java
@@ -0,0 +1,28 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.graph;
+
+public abstract class UseRegistryWithResult<T> extends UseRegistry {
+
+  private T result;
+
+  public UseRegistryWithResult(DexItemFactory factory) {
+    super(factory);
+  }
+
+  public UseRegistryWithResult(DexItemFactory factory, T defaultResult) {
+    super(factory);
+    this.result = defaultResult;
+  }
+
+  public T getResult() {
+    return result;
+  }
+
+  public void setResult(T result) {
+    this.result = result;
+    doBreak();
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/code/ConstructorEntryPointSynthesizedCode.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/code/ConstructorEntryPointSynthesizedCode.java
index da12b74..036a381 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/code/ConstructorEntryPointSynthesizedCode.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/code/ConstructorEntryPointSynthesizedCode.java
@@ -42,8 +42,12 @@
   }
 
   private void registerReachableDefinitions(UseRegistry registry) {
+    assert registry.getTraversalContinuation().shouldContinue();
     for (DexMethod typeConstructor : typeConstructors.values()) {
       registry.registerInvokeDirect(typeConstructor);
+      if (registry.getTraversalContinuation().shouldBreak()) {
+        return;
+      }
     }
   }
 
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/code/VirtualMethodEntryPointSynthesizedCode.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/code/VirtualMethodEntryPointSynthesizedCode.java
index 4402a83..c258558 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/code/VirtualMethodEntryPointSynthesizedCode.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/code/VirtualMethodEntryPointSynthesizedCode.java
@@ -17,8 +17,6 @@
 public class VirtualMethodEntryPointSynthesizedCode extends SynthesizedCode {
   private final Int2ReferenceSortedMap<DexMethod> mappedMethods;
 
-  private final DexItemFactory factory;
-
   public VirtualMethodEntryPointSynthesizedCode(
       Int2ReferenceSortedMap<DexMethod> mappedMethods,
       DexField classIdField,
@@ -35,7 +33,6 @@
                 method,
                 position,
                 originalMethod));
-    this.factory = factory;
     this.mappedMethods = mappedMethods;
   }
 
@@ -55,8 +52,12 @@
   }
 
   private void registerReachableDefinitions(UseRegistry registry) {
+    assert registry.getTraversalContinuation().shouldContinue();
     for (DexMethod mappedMethod : mappedMethods.values()) {
       registry.registerInvokeDirect(mappedMethod);
+      if (registry.getTraversalContinuation().shouldBreak()) {
+        return;
+      }
     }
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/synthetic/SynthesizedCode.java b/src/main/java/com/android/tools/r8/ir/synthetic/SynthesizedCode.java
index eb56742..6bd2080 100644
--- a/src/main/java/com/android/tools/r8/ir/synthetic/SynthesizedCode.java
+++ b/src/main/java/com/android/tools/r8/ir/synthetic/SynthesizedCode.java
@@ -4,22 +4,15 @@
 
 package com.android.tools.r8.ir.synthetic;
 
-import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.UseRegistry;
 import java.util.function.Consumer;
 
-public class SynthesizedCode extends AbstractSynthesizedCode {
+public abstract class SynthesizedCode extends AbstractSynthesizedCode {
 
   private final SourceCodeProvider sourceCodeProvider;
-  private final Consumer<UseRegistry> registryCallback;
 
   public SynthesizedCode(SourceCodeProvider sourceCodeProvider) {
-    this(sourceCodeProvider, SynthesizedCode::registerReachableDefinitionsDefault);
-  }
-
-  private SynthesizedCode(SourceCodeProvider sourceCodeProvider, Consumer<UseRegistry> callback) {
     this.sourceCodeProvider = sourceCodeProvider;
-    this.registryCallback = callback;
   }
 
   @Override
@@ -28,11 +21,5 @@
   }
 
   @Override
-  public Consumer<UseRegistry> getRegistryCallback() {
-    return registryCallback;
-  }
-
-  private static void registerReachableDefinitionsDefault(UseRegistry registry) {
-    throw new Unreachable();
-  }
+  public abstract Consumer<UseRegistry> getRegistryCallback();
 }
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorMethodReprocessingEnqueuer.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorMethodReprocessingEnqueuer.java
index 3c732df..37adcfa 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorMethodReprocessingEnqueuer.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorMethodReprocessingEnqueuer.java
@@ -13,7 +13,7 @@
 import com.android.tools.r8.graph.GraphLens;
 import com.android.tools.r8.graph.MethodResolutionResult.SingleResolutionResult;
 import com.android.tools.r8.graph.ProgramMethod;
-import com.android.tools.r8.graph.UseRegistry;
+import com.android.tools.r8.graph.UseRegistryWithResult;
 import com.android.tools.r8.ir.conversion.IRConverter;
 import com.android.tools.r8.ir.conversion.PostMethodProcessor;
 import com.android.tools.r8.ir.optimize.info.CallSiteOptimizationInfo;
@@ -93,8 +93,7 @@
                   method -> {
                     AffectedMethodUseRegistry registry =
                         new AffectedMethodUseRegistry(appView, graphLens);
-                    method.registerCodeReferences(registry);
-                    if (registry.isAffected()) {
+                    if (method.registerCodeReferencesWithResult(registry)) {
                       methodsToReprocessInClass.add(method);
                     }
                   });
@@ -107,24 +106,20 @@
             methodsToReprocessBuilder.addAll(methodsToReprocessForClass, currentGraphLens));
   }
 
-  static class AffectedMethodUseRegistry extends UseRegistry {
+  static class AffectedMethodUseRegistry extends UseRegistryWithResult<Boolean> {
 
     private final AppView<AppInfoWithLiveness> appView;
     private final ArgumentPropagatorGraphLens graphLens;
 
-    // Set to true if the given piece of code resolves to a method that needs rewriting according to
-    // the graph lens.
-    private boolean affected;
-
     AffectedMethodUseRegistry(
         AppView<AppInfoWithLiveness> appView, ArgumentPropagatorGraphLens graphLens) {
-      super(appView.dexItemFactory());
+      super(appView.dexItemFactory(), false);
       this.appView = appView;
       this.graphLens = graphLens;
     }
 
-    boolean isAffected() {
-      return affected;
+    private void markAffected() {
+      setResult(Boolean.TRUE);
     }
 
     @Override
@@ -163,8 +158,7 @@
       DexMethod rewrittenMethodReference =
           graphLens.internalGetNextMethodSignature(resolvedMethod.getReference());
       if (rewrittenMethodReference != resolvedMethod.getReference()) {
-        affected = true;
-        // TODO(b/150583533): break/abort!
+        markAffected();
       }
     }
 
diff --git a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
index 6a89e9b..6e06948 100644
--- a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
@@ -2186,6 +2186,7 @@
     @Override
     public Consumer<UseRegistry> getRegistryCallback() {
       return registry -> {
+        assert registry.getTraversalContinuation().shouldContinue();
         switch (type) {
           case DIRECT:
             registry.registerInvokeDirect(invocationTarget);
diff --git a/src/test/java/com/android/tools/r8/maindexlist/MainDexListTests.java b/src/test/java/com/android/tools/r8/maindexlist/MainDexListTests.java
index c6182b5..be4f6de 100644
--- a/src/test/java/com/android/tools/r8/maindexlist/MainDexListTests.java
+++ b/src/test/java/com/android/tools/r8/maindexlist/MainDexListTests.java
@@ -51,6 +51,7 @@
 import com.android.tools.r8.graph.InitClassLens;
 import com.android.tools.r8.graph.MethodAccessFlags;
 import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.graph.UseRegistry;
 import com.android.tools.r8.ir.code.CatchHandlers;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.Phi.RegisterReadType;
@@ -852,7 +853,12 @@
                 DexString.EMPTY_ARRAY);
         Code code =
             new SynthesizedCode(
-                (ignored, callerPosition) -> new ReturnVoidCode(voidReturnMethod, callerPosition));
+                (ignored, callerPosition) -> new ReturnVoidCode(voidReturnMethod, callerPosition)) {
+              @Override
+              public Consumer<UseRegistry> getRegistryCallback() {
+                throw new Unreachable();
+              }
+            };
         DexEncodedMethod method =
             DexEncodedMethod.builder()
                 .setMethod(voidReturnMethod)