Limit simple inlining constraint analysis

Bug: b/355238622
Change-Id: Ifd9460d53542d64cf87272ac1d5c61b0ee608e24
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraint.java b/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraint.java
index cbeb329..e32cc71 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraint.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraint.java
@@ -61,7 +61,7 @@
     }
     assert isArgumentConstraint() || isDisjunction();
     assert other.isArgumentConstraint() || other.isDisjunction();
-    return new SimpleInliningConstraintConjunction(ImmutableList.of(this, other));
+    return SimpleInliningConstraintConjunction.create(ImmutableList.of(this, other));
   }
 
   public final SimpleInliningConstraintWithDepth lazyMeet(
@@ -90,7 +90,7 @@
     }
     assert isArgumentConstraint() || isConjunction();
     assert other.isArgumentConstraint() || other.isConjunction();
-    return new SimpleInliningConstraintDisjunction(ImmutableList.of(this, other));
+    return SimpleInliningConstraintDisjunction.create(ImmutableList.of(this, other));
   }
 
   public abstract SimpleInliningConstraint fixupAfterParametersChanged(
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintAnalysis.java b/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintAnalysis.java
index 41e9bb0..de956ca 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintAnalysis.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintAnalysis.java
@@ -46,6 +46,8 @@
  */
 public class SimpleInliningConstraintAnalysis {
 
+  private static final int MAX_BRANCH_DEPTH = 3;
+
   private final SimpleInliningConstraintFactory constraintFactory;
   private final DexItemFactory dexItemFactory;
   private final ProgramMethod method;
@@ -78,18 +80,20 @@
     // returns.
     InstructionIterator instructionIterator =
         code.entryBlock().iterator(code.getNumberOfArguments());
-    return analyzeInstructionsInBlock(code.entryBlock(), 0, instructionIterator);
+    return analyzeInstructionsInBlock(code.entryBlock(), 0, 0, instructionIterator);
   }
 
   private SimpleInliningConstraintWithDepth analyzeInstructionsInBlock(
-      BasicBlock block, int depth) {
-    return analyzeInstructionsInBlock(block, depth, block.iterator());
+      BasicBlock block, int branchDepth, int instructionDepth) {
+    return analyzeInstructionsInBlock(block, branchDepth, instructionDepth, block.iterator());
   }
 
   private SimpleInliningConstraintWithDepth analyzeInstructionsInBlock(
-      BasicBlock block, int instructionDepth, InstructionIterator instructionIterator) {
-    // If we reach a block that has already been seen, or one that has catch handlers, then give up.
-    if (!seen.add(block) || block.hasCatchHandlers()) {
+      BasicBlock block,
+      int branchDepth,
+      int instructionDepth,
+      InstructionIterator instructionIterator) {
+    if (!seen.add(block) || block.hasCatchHandlers() || branchDepth > MAX_BRANCH_DEPTH) {
       return SimpleInliningConstraintWithDepth.getNever();
     }
 
@@ -119,7 +123,8 @@
     }
 
     SimpleInliningConstraintWithDepth jumpConstraint =
-        computeConstraintForJumpInstruction(instruction.asJumpInstruction(), instructionDepth);
+        computeConstraintForJumpInstruction(
+            instruction.asJumpInstruction(), branchDepth, instructionDepth);
     return jumpConstraint.meet(blockConstraint);
   }
 
@@ -154,79 +159,94 @@
   }
 
   private SimpleInliningConstraintWithDepth computeConstraintForJumpInstruction(
-      JumpInstruction instruction, int instructionDepth) {
+      JumpInstruction instruction, int branchDepth, int instructionDepth) {
     switch (instruction.opcode()) {
       case IF:
-        If ifInstruction = instruction.asIf();
-        Value singleArgumentOperand = getSingleArgumentOperand(ifInstruction);
-        if (singleArgumentOperand == null || singleArgumentOperand.isThis()) {
-          break;
+        {
+          If ifInstruction = instruction.asIf();
+          Value singleArgumentOperand = getSingleArgumentOperand(ifInstruction);
+          if (singleArgumentOperand == null || singleArgumentOperand.isThis()) {
+            break;
+          }
+
+          Value otherOperand =
+              ifInstruction.isZeroTest()
+                  ? null
+                  : ifInstruction.getOperand(
+                      1 - ifInstruction.inValues().indexOf(singleArgumentOperand));
+
+          int argumentIndex =
+              singleArgumentOperand.getAliasedValue().getDefinition().asArgument().getIndex();
+          DexType argumentType = method.getDefinition().getArgumentType(argumentIndex);
+          int currentBranchDepth = branchDepth;
+          int currentInstructionDepth = instructionDepth;
+
+          // Compute the constraint for which paths through the true target are guaranteed to exit
+          // early.
+          int newBranchDepth = currentBranchDepth + 1;
+          SimpleInliningConstraintWithDepth trueTargetConstraint =
+              computeConstraintFromIfTest(
+                      argumentIndex, argumentType, otherOperand, ifInstruction.getType())
+                  // Only recurse into the true target if the constraint from the if-instruction
+                  // is not 'never'.
+                  .lazyMeet(
+                      () ->
+                          analyzeInstructionsInBlock(
+                              ifInstruction.getTrueTarget(),
+                              newBranchDepth,
+                              currentInstructionDepth));
+
+          // Compute the constraint for which paths through the false target are guaranteed to
+          // exit early.
+          SimpleInliningConstraintWithDepth fallthroughTargetConstraint =
+              computeConstraintFromIfTest(
+                      argumentIndex, argumentType, otherOperand, ifInstruction.getType().inverted())
+                  // Only recurse into the false target if the constraint from the if-instruction
+                  // is not 'never'.
+                  .lazyMeet(
+                      () ->
+                          analyzeInstructionsInBlock(
+                              ifInstruction.fallthroughBlock(),
+                              newBranchDepth,
+                              currentInstructionDepth));
+
+          // Paths going through this basic block are guaranteed to exit early if the true target
+          // is guaranteed to exit early or the false target is.
+          return trueTargetConstraint.join(fallthroughTargetConstraint);
         }
 
-        Value otherOperand =
-            ifInstruction.isZeroTest()
-                ? null
-                : ifInstruction.getOperand(
-                    1 - ifInstruction.inValues().indexOf(singleArgumentOperand));
-
-        int argumentIndex =
-            singleArgumentOperand.getAliasedValue().getDefinition().asArgument().getIndex();
-        DexType argumentType = method.getDefinition().getArgumentType(argumentIndex);
-        int currentDepth = instructionDepth;
-
-        // Compute the constraint for which paths through the true target are guaranteed to exit
-        // early.
-        SimpleInliningConstraintWithDepth trueTargetConstraint =
-            computeConstraintFromIfTest(
-                    argumentIndex, argumentType, otherOperand, ifInstruction.getType())
-                // Only recurse into the true target if the constraint from the if-instruction
-                // is not 'never'.
-                .lazyMeet(
-                    () -> analyzeInstructionsInBlock(ifInstruction.getTrueTarget(), currentDepth));
-
-        // Compute the constraint for which paths through the false target are guaranteed to
-        // exit early.
-        SimpleInliningConstraintWithDepth fallthroughTargetConstraint =
-            computeConstraintFromIfTest(
-                    argumentIndex, argumentType, otherOperand, ifInstruction.getType().inverted())
-                // Only recurse into the false target if the constraint from the if-instruction
-                // is not 'never'.
-                .lazyMeet(
-                    () ->
-                        analyzeInstructionsInBlock(ifInstruction.fallthroughBlock(), currentDepth));
-
-        // Paths going through this basic block are guaranteed to exit early if the true target
-        // is guaranteed to exit early or the false target is.
-        return trueTargetConstraint.join(fallthroughTargetConstraint);
-
       case GOTO:
-        return analyzeInstructionsInBlock(instruction.asGoto().getTarget(), instructionDepth);
+        return analyzeInstructionsInBlock(
+            instruction.asGoto().getTarget(), branchDepth, instructionDepth);
 
       case RETURN:
         return AlwaysSimpleInliningConstraint.getInstance().withDepth(instructionDepth);
 
       case STRING_SWITCH:
-        // Require that all cases including the default case are simple. In that case we can
-        // guarantee simpleness by requiring that the switch value is constant.
-        StringSwitch stringSwitch = instruction.asStringSwitch();
-        Value valueRoot = stringSwitch.value().getAliasedValue();
-        if (!valueRoot.isDefinedByInstructionSatisfying(Instruction::isArgument)) {
-          return SimpleInliningConstraintWithDepth.getNever();
-        }
-        int maxInstructionDepth = instructionDepth;
-        for (BasicBlock successor : stringSwitch.getBlock().getNormalSuccessors()) {
-          SimpleInliningConstraintWithDepth successorConstraintWithDepth =
-              analyzeInstructionsInBlock(successor, instructionDepth);
-          if (!successorConstraintWithDepth.getConstraint().isAlways()) {
+        {
+          // Require that all cases including the default case are simple. In that case we can
+          // guarantee simpleness by requiring that the switch value is constant.
+          StringSwitch stringSwitch = instruction.asStringSwitch();
+          Value valueRoot = stringSwitch.value().getAliasedValue();
+          if (!valueRoot.isDefinedByInstructionSatisfying(Instruction::isArgument)) {
             return SimpleInliningConstraintWithDepth.getNever();
           }
-          maxInstructionDepth =
-              Math.max(maxInstructionDepth, successorConstraintWithDepth.getInstructionDepth());
+          int newBranchDepth = branchDepth + 1;
+          int maxInstructionDepth = instructionDepth;
+          for (BasicBlock successor : stringSwitch.getBlock().getNormalSuccessors()) {
+            SimpleInliningConstraintWithDepth successorConstraintWithDepth =
+                analyzeInstructionsInBlock(successor, newBranchDepth, instructionDepth);
+            if (!successorConstraintWithDepth.getConstraint().isAlways()) {
+              return SimpleInliningConstraintWithDepth.getNever();
+            }
+            maxInstructionDepth =
+                Math.max(maxInstructionDepth, successorConstraintWithDepth.getInstructionDepth());
+          }
+          Argument argument = valueRoot.getDefinition().asArgument();
+          ConstSimpleInliningConstraint simpleConstraint =
+              constraintFactory.createConstConstraint(argument.getIndex());
+          return simpleConstraint.withDepth(maxInstructionDepth);
         }
-        Argument argument = valueRoot.getDefinition().asArgument();
-        ConstSimpleInliningConstraint simpleConstraint =
-            constraintFactory.createConstConstraint(argument.getIndex());
-        return simpleConstraint.withDepth(maxInstructionDepth);
 
       case THROW:
         return instruction.getBlock().hasCatchHandlers()
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintConjunction.java b/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintConjunction.java
index 6147cd8..a282039 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintConjunction.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintConjunction.java
@@ -15,9 +15,11 @@
 
 public class SimpleInliningConstraintConjunction extends SimpleInliningConstraint {
 
+  private static final int MAX_SIZE = 3;
+
   private final List<SimpleInliningConstraint> constraints;
 
-  public SimpleInliningConstraintConjunction(List<SimpleInliningConstraint> constraints) {
+  private SimpleInliningConstraintConjunction(List<SimpleInliningConstraint> constraints) {
     assert constraints.size() > 1;
     assert constraints.stream().noneMatch(SimpleInliningConstraint::isAlways);
     assert constraints.stream().noneMatch(SimpleInliningConstraint::isConjunction);
@@ -25,6 +27,12 @@
     this.constraints = constraints;
   }
 
+  public static SimpleInliningConstraint create(List<SimpleInliningConstraint> constraints) {
+    return constraints.size() <= MAX_SIZE
+        ? new SimpleInliningConstraintConjunction(constraints)
+        : NeverSimpleInliningConstraint.getInstance();
+  }
+
   SimpleInliningConstraint add(SimpleInliningConstraint constraint) {
     assert !constraint.isAlways();
     assert !constraint.isNever();
@@ -32,16 +40,15 @@
       return addAll(constraint.asConjunction());
     }
     assert constraint.isArgumentConstraint() || constraint.isDisjunction();
-    return new SimpleInliningConstraintConjunction(
+    return create(
         ImmutableList.<SimpleInliningConstraint>builder()
             .addAll(constraints)
             .add(constraint)
             .build());
   }
 
-  public SimpleInliningConstraintConjunction addAll(
-      SimpleInliningConstraintConjunction conjunction) {
-    return new SimpleInliningConstraintConjunction(
+  public SimpleInliningConstraint addAll(SimpleInliningConstraintConjunction conjunction) {
+    return create(
         ImmutableList.<SimpleInliningConstraint>builder()
             .addAll(constraints)
             .addAll(conjunction.constraints)
@@ -103,6 +110,6 @@
       return NeverSimpleInliningConstraint.getInstance();
     }
 
-    return new SimpleInliningConstraintConjunction(rewrittenConstraints);
+    return create(rewrittenConstraints);
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintDisjunction.java b/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintDisjunction.java
index 02c0b01..660aa3b 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintDisjunction.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/inlining/SimpleInliningConstraintDisjunction.java
@@ -15,9 +15,11 @@
 
 public class SimpleInliningConstraintDisjunction extends SimpleInliningConstraint {
 
+  private static final int MAX_SIZE = 3;
+
   private final List<SimpleInliningConstraint> constraints;
 
-  public SimpleInliningConstraintDisjunction(List<SimpleInliningConstraint> constraints) {
+  private SimpleInliningConstraintDisjunction(List<SimpleInliningConstraint> constraints) {
     assert constraints.size() > 1;
     assert constraints.stream().noneMatch(SimpleInliningConstraint::isAlways);
     assert constraints.stream().noneMatch(SimpleInliningConstraint::isDisjunction);
@@ -25,6 +27,12 @@
     this.constraints = constraints;
   }
 
+  public static SimpleInliningConstraint create(List<SimpleInliningConstraint> constraints) {
+    return constraints.size() <= MAX_SIZE
+        ? new SimpleInliningConstraintDisjunction(constraints)
+        : NeverSimpleInliningConstraint.getInstance();
+  }
+
   SimpleInliningConstraint add(SimpleInliningConstraint constraint) {
     assert !constraint.isAlways();
     assert !constraint.isNever();
@@ -32,16 +40,15 @@
       return addAll(constraint.asDisjunction());
     }
     assert constraint.isArgumentConstraint() || constraint.isConjunction();
-    return new SimpleInliningConstraintDisjunction(
+    return create(
         ImmutableList.<SimpleInliningConstraint>builder()
             .addAll(constraints)
             .add(constraint)
             .build());
   }
 
-  public SimpleInliningConstraintDisjunction addAll(
-      SimpleInliningConstraintDisjunction disjunction) {
-    return new SimpleInliningConstraintDisjunction(
+  public SimpleInliningConstraint addAll(SimpleInliningConstraintDisjunction disjunction) {
+    return create(
         ImmutableList.<SimpleInliningConstraint>builder()
             .addAll(constraints)
             .addAll(disjunction.constraints)
@@ -103,6 +110,6 @@
       return AlwaysSimpleInliningConstraint.getInstance();
     }
 
-    return new SimpleInliningConstraintDisjunction(rewrittenConstraints);
+    return create(rewrittenConstraints);
   }
 }