Extend stack map validation to check for max stack height violation

Bug: 181185068
Change-Id: I8ecb63db3e0d740decbe3ffcecf8be0d94a7589e
diff --git a/src/main/java/com/android/tools/r8/cf/code/CfFrame.java b/src/main/java/com/android/tools/r8/cf/code/CfFrame.java
index b2b18b7..98bc101 100644
--- a/src/main/java/com/android/tools/r8/cf/code/CfFrame.java
+++ b/src/main/java/com/android/tools/r8/cf/code/CfFrame.java
@@ -371,6 +371,14 @@
     return stack.size();
   }
 
+  public int computeStackSize() {
+    int size = 0;
+    for (FrameType frameType : stack) {
+      size += frameType.isWide() ? 2 : 1;
+    }
+    return size;
+  }
+
   private Object[] computeStackTypes(int stackCount, GraphLens graphLens, NamingLens namingLens) {
     assert stackCount == stack.size();
     if (stackCount == 0) {
diff --git a/src/main/java/com/android/tools/r8/cf/code/CfFrameVerificationHelper.java b/src/main/java/com/android/tools/r8/cf/code/CfFrameVerificationHelper.java
index bc0da18..9272226 100644
--- a/src/main/java/com/android/tools/r8/cf/code/CfFrameVerificationHelper.java
+++ b/src/main/java/com/android/tools/r8/cf/code/CfFrameVerificationHelper.java
@@ -41,6 +41,7 @@
   private final DexItemFactory factory;
   private final List<CfTryCatch> tryCatchRanges;
   private final GraphLens graphLens;
+  private final int maxStackHeight;
 
   private final Deque<CfTryCatch> currentCatchRanges = new ArrayDeque<>();
   private final Set<CfLabel> tryCatchRangeLabels;
@@ -51,13 +52,15 @@
       List<CfTryCatch> tryCatchRanges,
       BiPredicate<DexType, DexType> isJavaAssignable,
       DexItemFactory factory,
-      GraphLens graphLens) {
+      GraphLens graphLens,
+      int maxStackHeight) {
     this.context = context;
     this.stateMap = stateMap;
     this.tryCatchRanges = tryCatchRanges;
     this.isJavaAssignable = isJavaAssignable;
     this.factory = factory;
     this.graphLens = graphLens;
+    this.maxStackHeight = maxStackHeight;
     throwStack = ImmutableDeque.of(FrameType.initialized(factory.throwableType));
     // Compute all labels that marks a start or end to catch ranges.
     tryCatchRangeLabels = Sets.newIdentityHashSet();
@@ -148,6 +151,15 @@
   public CfFrameVerificationHelper push(FrameType type) {
     checkFrameIsSet();
     currentFrame.getStack().addLast(type);
+    if (currentFrame.computeStackSize() > maxStackHeight) {
+      throw CfCodeStackMapValidatingException.error(
+          "The max stack height of "
+              + maxStackHeight
+              + " is violated when pushing type "
+              + type
+              + " to existing stack of size "
+              + currentFrame.getStack().size());
+    }
     return this;
   }
 
diff --git a/src/main/java/com/android/tools/r8/graph/CfCode.java b/src/main/java/com/android/tools/r8/graph/CfCode.java
index 025a20e..71b938f 100644
--- a/src/main/java/com/android/tools/r8/graph/CfCode.java
+++ b/src/main/java/com/android/tools/r8/graph/CfCode.java
@@ -67,8 +67,17 @@
 
   public enum StackMapStatus {
     NOT_VERIFIED,
-    INVALID_OR_NOT_PRESENT,
-    VALID
+    NOT_PRESENT,
+    INVALID,
+    VALID;
+
+    public boolean isValid() {
+      return this == VALID || this == NOT_PRESENT;
+    }
+
+    public boolean isInvalidOrNotPresent() {
+      return this == INVALID || this == NOT_PRESENT;
+    }
   }
 
   public static class LocalVariableInfo {
@@ -305,7 +314,7 @@
       LensCodeRewriterUtils rewriter,
       MethodVisitor visitor) {
     GraphLens graphLens = appView.graphLens();
-    assert verifyFrames(method.getDefinition(), appView, null, false)
+    assert verifyFrames(method.getDefinition(), appView, null, false).isValid()
         : "Could not validate stack map frames";
     DexItemFactory dexItemFactory = appView.dexItemFactory();
     InitClassLens initClassLens = appView.initClassLens();
@@ -433,7 +442,8 @@
       AppView<?> appView,
       Origin origin,
       boolean shouldApplyCodeRewritings) {
-    if (!verifyFrames(method, appView, origin, shouldApplyCodeRewritings)) {
+    stackMapStatus = verifyFrames(method, appView, origin, shouldApplyCodeRewritings);
+    if (!stackMapStatus.isValid()) {
       ArrayList<CfInstruction> copy = new ArrayList<>(instructions);
       copy.removeIf(CfInstruction::isFrame);
       setInstructions(copy);
@@ -706,16 +716,14 @@
             thisLocalInfo.index, debugLocalInfo, thisLocalInfo.start, thisLocalInfo.end));
   }
 
-  public boolean verifyFrames(
+  public StackMapStatus verifyFrames(
       DexEncodedMethod method, AppView<?> appView, Origin origin, boolean applyProtoTypeChanges) {
     if (!appView.options().canUseInputStackMaps()
         || appView.options().testing.disableStackMapVerification) {
-      stackMapStatus = StackMapStatus.INVALID_OR_NOT_PRESENT;
-      return true;
+      return StackMapStatus.NOT_PRESENT;
     }
     if (method.hasClassFileVersion() && method.getClassFileVersion().isLessThan(CfVersion.V1_7)) {
-      stackMapStatus = StackMapStatus.INVALID_OR_NOT_PRESENT;
-      return true;
+      return StackMapStatus.NOT_PRESENT;
     }
     if (!method.isInstanceInitializer()
         && appView
@@ -723,8 +731,7 @@
             .getOriginalMethodSignature(method.method)
             .isInstanceInitializer(appView.dexItemFactory())) {
       // We cannot verify instance initializers if they are moved.
-      stackMapStatus = StackMapStatus.INVALID_OR_NOT_PRESENT;
-      return true;
+      return StackMapStatus.NOT_PRESENT;
     }
     // Build a map from labels to frames.
     Map<CfLabel, CfFrame> stateMap = new IdentityHashMap<>();
@@ -791,7 +798,8 @@
             tryCatchRanges,
             isAssignablePredicate(appView),
             appView.dexItemFactory(),
-            appView.graphLens());
+            appView.graphLens(),
+            maxStack);
     if (stateMap.containsKey(null)) {
       assert !shouldComputeInitialFrame();
       builder.checkFrameAndSet(stateMap.get(null));
@@ -824,18 +832,16 @@
             appView);
       }
     }
-    stackMapStatus = StackMapStatus.VALID;
-    return true;
+    return StackMapStatus.VALID;
   }
 
-  private boolean reportStackMapError(CfCodeDiagnostics diagnostics, AppView<?> appView) {
+  private StackMapStatus reportStackMapError(CfCodeDiagnostics diagnostics, AppView<?> appView) {
     // Stack maps was required from version V1_6 (50), but the JVM gave a grace-period and only
     // started enforcing stack maps from 51 in JVM 8. As a consequence, we have different android
     // libraries that has V1_7 code but has no stack maps. To not fail on compilations we only
     // report a warning.
-    stackMapStatus = StackMapStatus.INVALID_OR_NOT_PRESENT;
     appView.options().reporter.warning(diagnostics);
-    return false;
+    return StackMapStatus.INVALID;
   }
 
   private boolean finalAndExitInstruction(CfInstruction instruction) {
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/CfBuilder.java b/src/main/java/com/android/tools/r8/ir/conversion/CfBuilder.java
index e0841f7..8a63d63 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/CfBuilder.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/CfBuilder.java
@@ -164,7 +164,7 @@
     DexBuilder.removeRedundantDebugPositions(code);
     CfCode code = buildCfCode();
     assert verifyInvokeInterface(code, appView);
-    assert code.verifyFrames(method, appView, this.code.origin, false);
+    assert code.verifyFrames(method, appView, this.code.origin, false).isValid();
     return code;
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/CfSourceCode.java b/src/main/java/com/android/tools/r8/ir/conversion/CfSourceCode.java
index 50ba68b..1d7f619 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/CfSourceCode.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/CfSourceCode.java
@@ -671,7 +671,7 @@
   public DexType getPhiTypeForBlock(
       int register, int blockOffset, ValueTypeConstraint constraint, RegisterReadType readType) {
     assert code.getStackMapStatus() != StackMapStatus.NOT_VERIFIED;
-    if (code.getStackMapStatus() == StackMapStatus.INVALID_OR_NOT_PRESENT) {
+    if (code.getStackMapStatus().isInvalidOrNotPresent()) {
       return null;
     }
     // We should be able to find the a snapshot at the block-offset:
diff --git a/src/test/java/com/android/tools/r8/cf/stackmap/InvalidLongStackValueMaxHeightTest.java b/src/test/java/com/android/tools/r8/cf/stackmap/InvalidLongStackValueMaxHeightTest.java
new file mode 100644
index 0000000..ae574e9
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/cf/stackmap/InvalidLongStackValueMaxHeightTest.java
@@ -0,0 +1,91 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.cf.stackmap;
+
+import static com.android.tools.r8.DiagnosticsMatcher.diagnosticMessage;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.junit.Assume.assumeTrue;
+
+import com.android.tools.r8.CompilationFailedException;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.transformers.ClassFileTransformer.MethodPredicate;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class InvalidLongStackValueMaxHeightTest extends TestBase {
+
+  private final String[] EXPECTED = new String[] {"52"};
+  private final TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimes().withAllApiLevelsAlsoForCf().build();
+  }
+
+  public InvalidLongStackValueMaxHeightTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void smokeTest() throws Exception {
+    testForRuntime(parameters)
+        .addProgramClasses(Main.class, Tester.class)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(EXPECTED);
+  }
+
+  @Test(expected = CompilationFailedException.class)
+  public void testD8Cf() throws Exception {
+    assumeTrue(parameters.isCfRuntime());
+    testForD8(parameters.getBackend())
+        .addProgramClasses(Tester.class)
+        .addProgramClassFileData(getMainWithChangedMaxStackHeight())
+        .setMinApi(parameters.getApiLevel())
+        .compileWithExpectedDiagnostics(
+            diagnostics -> {
+              diagnostics.assertWarningThatMatches(
+                  diagnosticMessage(containsString("The max stack height of 2 is violated")));
+            });
+  }
+
+  @Test()
+  public void testD8Dex() throws Exception {
+    assumeTrue(parameters.isDexRuntime());
+    testForD8(parameters.getBackend())
+        .addProgramClasses(Tester.class)
+        .addProgramClassFileData(getMainWithChangedMaxStackHeight())
+        .setMinApi(parameters.getApiLevel())
+        .compileWithExpectedDiagnostics(
+            diagnostics -> {
+              diagnostics.assertWarningThatMatches(
+                  diagnosticMessage(containsString("The max stack height of 2 is violated")));
+            })
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(EXPECTED);
+  }
+
+  public byte[] getMainWithChangedMaxStackHeight() throws Exception {
+    return transformer(Main.class).setMaxStackHeight(MethodPredicate.onName("main"), 2).transform();
+  }
+
+  public static class Tester {
+
+    public static void test(long x, int y) {
+      System.out.println(x + y);
+    }
+  }
+
+  public static class Main {
+
+    public static void main(String[] args) {
+      Tester.test(10L, 42);
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/cf/stackmap/InvalidStackHeightTest.java b/src/test/java/com/android/tools/r8/cf/stackmap/InvalidStackHeightTest.java
new file mode 100644
index 0000000..71b6637
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/cf/stackmap/InvalidStackHeightTest.java
@@ -0,0 +1,106 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.cf.stackmap;
+
+import static com.android.tools.r8.DiagnosticsMatcher.diagnosticMessage;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.junit.Assume.assumeTrue;
+
+import com.android.tools.r8.CompilationFailedException;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.transformers.ClassFileTransformer.MethodPredicate;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class InvalidStackHeightTest extends TestBase {
+
+  private final String[] EXPECTED = new String[] {"42"};
+
+  private final TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimes().withAllApiLevelsAlsoForCf().build();
+  }
+
+  public InvalidStackHeightTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void smokeTest() throws Exception {
+    testForRuntime(parameters)
+        .addProgramClasses(Main.class)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(EXPECTED);
+  }
+
+  @Test(expected = CompilationFailedException.class)
+  public void testD8Cf() throws Exception {
+    assumeTrue(parameters.isCfRuntime());
+    testForD8(parameters.getBackend())
+        .addProgramClassFileData(getMainWithChangedMaxStackHeight())
+        .setMinApi(parameters.getApiLevel())
+        .compileWithExpectedDiagnostics(
+            diagnostics -> {
+              diagnostics.assertWarningMessageThatMatches(
+                  containsString("The max stack height of 1 is violated"));
+            });
+  }
+
+  @Test
+  public void testD8Dex() throws Exception {
+    assumeTrue(parameters.isDexRuntime());
+    testForD8(parameters.getBackend())
+        .addProgramClassFileData(getMainWithChangedMaxStackHeight())
+        .setMinApi(parameters.getApiLevel())
+        .compileWithExpectedDiagnostics(
+            diagnostics -> {
+              diagnostics.assertWarningMessageThatMatches(
+                  containsString("The max stack height of 1 is violated"));
+            })
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(EXPECTED);
+  }
+
+  @Test()
+  public void testR8() throws Exception {
+    testForR8(parameters.getBackend())
+        .addProgramClassFileData(getMainWithChangedMaxStackHeight())
+        .enableInliningAnnotations()
+        .addKeepMainRule(Main.class)
+        .setMinApi(parameters.getApiLevel())
+        .allowDiagnosticWarningMessages()
+        .compileWithExpectedDiagnostics(
+            diagnostics -> {
+              diagnostics.assertWarningsMatch(
+                  diagnosticMessage(containsString("The max stack height of 1 is violated")));
+            })
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(EXPECTED);
+  }
+
+  public byte[] getMainWithChangedMaxStackHeight() throws Exception {
+    return transformer(Main.class).setMaxStackHeight(MethodPredicate.onName("main"), 1).transform();
+  }
+
+  public static class Main {
+
+    @NeverInline
+    private void test(int x, int y) {
+      System.out.println(x + y);
+    }
+
+    public static void main(String[] args) {
+      new Main().test(args.length, 42);
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/cf/stackmap/LongStackValuesInFramesTest.java b/src/test/java/com/android/tools/r8/cf/stackmap/LongStackValuesInFramesTest.java
new file mode 100644
index 0000000..742472b
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/cf/stackmap/LongStackValuesInFramesTest.java
@@ -0,0 +1,151 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.cf.stackmap;
+
+import static com.android.tools.r8.cf.stackmap.LongStackValuesInFramesTest.LongStackValuesInFramesTest$MainDump.dump;
+import static org.junit.Assume.assumeTrue;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+import org.objectweb.asm.ClassWriter;
+import org.objectweb.asm.Label;
+import org.objectweb.asm.MethodVisitor;
+import org.objectweb.asm.Opcodes;
+
+@RunWith(Parameterized.class)
+public class LongStackValuesInFramesTest extends TestBase {
+
+  private final String[] EXPECTED = new String[] {"52"};
+
+  private final TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimes().withAllApiLevelsAlsoForCf().build();
+  }
+
+  public LongStackValuesInFramesTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void testJvm() throws Exception {
+    assumeTrue(parameters.isCfRuntime());
+    testForJvm()
+        .addProgramClasses(Tester.class)
+        .addProgramClassFileData(dump())
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(EXPECTED);
+  }
+
+  @Test
+  public void testD8() throws Exception {
+    testForD8(parameters.getBackend())
+        .addProgramClasses(Tester.class)
+        .addProgramClassFileData(dump())
+        .setMinApi(parameters.getApiLevel())
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines(EXPECTED);
+  }
+
+  public static class Tester {
+
+    @NeverInline
+    public static void test(long x, int y) {
+      System.out.println(x + y);
+    }
+  }
+
+  public static class Main {
+
+    // This code will be rewritten to:
+    // ldc_w 10
+    // bipush 10
+    // if (args.length == 0) {
+    //   invoke Tester.test(JI)V
+    // }
+    // pop
+    // pop2
+    public static void main(String[] args) {
+      long x = 10L;
+      int y = 42;
+      if (args.length == 0) {
+        Tester.test(x, y);
+      }
+    }
+  }
+
+  public static class LongStackValuesInFramesTest$MainDump implements Opcodes {
+
+    public static byte[] dump() {
+
+      ClassWriter classWriter = new ClassWriter(0);
+      MethodVisitor methodVisitor;
+
+      classWriter.visit(
+          V1_8,
+          ACC_PUBLIC | ACC_SUPER,
+          "com/android/tools/r8/cf/stackmap/LongStackValuesInFramesTest$Main",
+          null,
+          "java/lang/Object",
+          null);
+      classWriter.visitSource("LongStackValuesInFramesTest.java", null);
+      classWriter.visitInnerClass(
+          "com/android/tools/r8/cf/stackmap/LongStackValuesInFramesTest$Main",
+          "com/android/tools/r8/cf/stackmap/LongStackValuesInFramesTest",
+          "Main",
+          ACC_PUBLIC | ACC_STATIC);
+      classWriter.visitInnerClass(
+          "com/android/tools/r8/cf/stackmap/LongStackValuesInFramesTest$Tester",
+          "com/android/tools/r8/cf/stackmap/LongStackValuesInFramesTest",
+          "Tester",
+          ACC_PUBLIC | ACC_STATIC);
+
+      {
+        methodVisitor =
+            classWriter.visitMethod(
+                ACC_PUBLIC | ACC_STATIC, "main", "([Ljava/lang/String;)V", null, null);
+        methodVisitor.visitCode();
+        methodVisitor.visitLdcInsn(10L);
+        methodVisitor.visitIntInsn(BIPUSH, 42);
+        methodVisitor.visitVarInsn(ALOAD, 0);
+        methodVisitor.visitInsn(ARRAYLENGTH);
+        Label label1 = new Label();
+        methodVisitor.visitJumpInsn(IFNE, label1);
+        methodVisitor.visitMethodInsn(
+            INVOKESTATIC,
+            "com/android/tools/r8/cf/stackmap/LongStackValuesInFramesTest$Tester",
+            "test",
+            "(JI)V",
+            false);
+        Label label2 = new Label();
+        methodVisitor.visitJumpInsn(Opcodes.GOTO, label2);
+        methodVisitor.visitLabel(label1);
+        methodVisitor.visitFrame(
+            Opcodes.F_FULL,
+            1,
+            new Object[] {"[Ljava/lang/String;"},
+            2,
+            new Object[] {Opcodes.LONG, Opcodes.INTEGER});
+        methodVisitor.visitInsn(Opcodes.POP);
+        methodVisitor.visitInsn(Opcodes.POP2);
+        methodVisitor.visitLabel(label2);
+        methodVisitor.visitFrame(Opcodes.F_FULL, 1, new Object[] {"[Ljava/lang/String;"}, 0, null);
+        methodVisitor.visitInsn(RETURN);
+        methodVisitor.visitMaxs(4, 3);
+        methodVisitor.visitEnd();
+      }
+      classWriter.visitEnd();
+
+      return classWriter.toByteArray();
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/transformers/ClassFileTransformer.java b/src/test/java/com/android/tools/r8/transformers/ClassFileTransformer.java
index f828cfc..0169604 100644
--- a/src/test/java/com/android/tools/r8/transformers/ClassFileTransformer.java
+++ b/src/test/java/com/android/tools/r8/transformers/ClassFileTransformer.java
@@ -517,6 +517,16 @@
     static MethodPredicate onName(String name) {
       return (access, otherName, descriptor, signature, exceptions) -> name.equals(otherName);
     }
+
+    static boolean testContext(MethodPredicate predicate, MethodContext context) {
+      MethodReference reference = context.getReference();
+      return predicate.test(
+          context.accessFlags,
+          reference.getMethodName(),
+          reference.getMethodDescriptor(),
+          null,
+          null);
+    }
   }
 
   @FunctionalInterface
@@ -970,4 +980,16 @@
           }
         });
   }
+
+  public ClassFileTransformer setMaxStackHeight(MethodPredicate predicate, int newMaxStack) {
+    return addMethodTransformer(
+        new MethodTransformer() {
+          @Override
+          public void visitMaxs(int maxStack, int maxLocals) {
+            super.visitMaxs(
+                MethodPredicate.testContext(predicate, getContext()) ? newMaxStack : maxStack,
+                maxLocals);
+          }
+        });
+  }
 }