Change the rewriting of the two args AssertionConstructor

The rewriting now use an outline which will perform a runtime
reflection check to see of the two args AssertionError constructor
is present.

Bug: b/244473445
Change-Id: I74214feda4130da0072539d4f6e6ae151293aa03
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 17ce101..9967ae3 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -49,6 +49,7 @@
 import com.android.tools.r8.ir.desugar.itf.L8InnerOuterAttributeEraser;
 import com.android.tools.r8.ir.desugar.lambda.LambdaDeserializationMethodRemover;
 import com.android.tools.r8.ir.desugar.nest.D8NestBasedAccessDesugaring;
+import com.android.tools.r8.ir.optimize.AssertionErrorTwoArgsConstructorRewriter;
 import com.android.tools.r8.ir.optimize.AssertionsRewriter;
 import com.android.tools.r8.ir.optimize.AssumeInserter;
 import com.android.tools.r8.ir.optimize.CheckNotNullConverter;
@@ -132,6 +133,7 @@
   private final InternalOptions options;
   private final CfgPrinter printer;
   public final CodeRewriter codeRewriter;
+  public final AssertionErrorTwoArgsConstructorRewriter assertionErrorTwoArgsConstructorRewriter;
   private final NaturalIntLoopRemover naturalIntLoopRemover = new NaturalIntLoopRemover();
   public final MemberValuePropagation<?> memberValuePropagation;
   private final LensCodeRewriter lensCodeRewriter;
@@ -181,6 +183,8 @@
     this.options = appView.options();
     this.printer = printer;
     this.codeRewriter = new CodeRewriter(appView);
+    this.assertionErrorTwoArgsConstructorRewriter =
+        new AssertionErrorTwoArgsConstructorRewriter(appView);
     this.classInitializerDefaultsOptimization =
         new ClassInitializerDefaultsOptimization(appView, this);
     this.stringOptimizer = new StringOptimizer(appView);
@@ -370,6 +374,10 @@
     if (instanceInitializerOutliner != null) {
       processSimpleSynthesizeMethods(instanceInitializerOutliner.getSynthesizedMethods(), executor);
     }
+    if (assertionErrorTwoArgsConstructorRewriter != null) {
+      processSimpleSynthesizeMethods(
+          assertionErrorTwoArgsConstructorRewriter.getSynthesizedMethods(), executor);
+    }
 
     application = commitPendingSyntheticItemsD8(appView, application);
 
@@ -801,6 +809,10 @@
       processSimpleSynthesizeMethods(
           instanceInitializerOutliner.getSynthesizedMethods(), executorService);
     }
+    if (assertionErrorTwoArgsConstructorRewriter != null) {
+      processSimpleSynthesizeMethods(
+          assertionErrorTwoArgsConstructorRewriter.getSynthesizedMethods(), executorService);
+    }
 
     // Update optimization info for all synthesized methods at once.
     feedback.updateVisibleOptimizationInfo();
@@ -1333,7 +1345,7 @@
     naturalIntLoopRemover.run(appView, code);
     timing.end();
     timing.begin("Rewrite AssertionError");
-    codeRewriter.rewriteAssertionErrorTwoArgumentConstructor(code, options);
+    assertionErrorTwoArgsConstructorRewriter.rewrite(code, methodProcessingContext);
     timing.end();
     timing.begin("Run CSE");
     codeRewriter.commonSubexpressionElimination(code);
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/backports/BackportedMethods.java b/src/main/java/com/android/tools/r8/ir/desugar/backports/BackportedMethods.java
index bf643ae..480ac1d 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/backports/BackportedMethods.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/backports/BackportedMethods.java
@@ -74,6 +74,7 @@
     factory.createSynthesizedType("Ljava/lang/OutOfMemoryError;");
     factory.createSynthesizedType("Ljava/lang/Runnable;");
     factory.createSynthesizedType("Ljava/lang/SecurityException;");
+    factory.createSynthesizedType("Ljava/lang/reflect/Constructor;");
     factory.createSynthesizedType("Ljava/lang/reflect/InvocationTargetException;");
     factory.createSynthesizedType("Ljava/lang/reflect/Method;");
     factory.createSynthesizedType("Ljava/util/AbstractMap$SimpleImmutableEntry;");
@@ -117,6 +118,101 @@
     factory.createSynthesizedType("[Ljava/util/Map$Entry;");
   }
 
+  public static CfCode AssertionErrorMethods_createAssertionError(
+      DexItemFactory factory, DexMethod method) {
+    CfLabel label0 = new CfLabel();
+    CfLabel label1 = new CfLabel();
+    CfLabel label2 = new CfLabel();
+    CfLabel label3 = new CfLabel();
+    CfLabel label4 = new CfLabel();
+    CfLabel label5 = new CfLabel();
+    CfLabel label6 = new CfLabel();
+    return new CfCode(
+        method.holder,
+        5,
+        3,
+        ImmutableList.of(
+            label0,
+            new CfConstClass(factory.createType("Ljava/lang/AssertionError;")),
+            new CfConstNumber(2, ValueType.INT),
+            new CfNewArray(factory.createType("[Ljava/lang/Class;")),
+            new CfStackInstruction(CfStackInstruction.Opcode.Dup),
+            new CfConstNumber(0, ValueType.INT),
+            new CfConstClass(factory.stringType),
+            new CfArrayStore(MemberType.OBJECT),
+            new CfStackInstruction(CfStackInstruction.Opcode.Dup),
+            new CfConstNumber(1, ValueType.INT),
+            new CfConstClass(factory.throwableType),
+            new CfArrayStore(MemberType.OBJECT),
+            label1,
+            new CfInvoke(
+                182,
+                factory.createMethod(
+                    factory.classType,
+                    factory.createProto(
+                        factory.createType("Ljava/lang/reflect/Constructor;"),
+                        factory.createType("[Ljava/lang/Class;")),
+                    factory.createString("getDeclaredConstructor")),
+                false),
+            new CfStore(ValueType.OBJECT, 2),
+            label2,
+            new CfLoad(ValueType.OBJECT, 2),
+            new CfConstNumber(2, ValueType.INT),
+            new CfNewArray(factory.createType("[Ljava/lang/Object;")),
+            new CfStackInstruction(CfStackInstruction.Opcode.Dup),
+            new CfConstNumber(0, ValueType.INT),
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfArrayStore(MemberType.OBJECT),
+            new CfStackInstruction(CfStackInstruction.Opcode.Dup),
+            new CfConstNumber(1, ValueType.INT),
+            new CfLoad(ValueType.OBJECT, 1),
+            new CfArrayStore(MemberType.OBJECT),
+            new CfInvoke(
+                182,
+                factory.createMethod(
+                    factory.createType("Ljava/lang/reflect/Constructor;"),
+                    factory.createProto(
+                        factory.objectType, factory.createType("[Ljava/lang/Object;")),
+                    factory.createString("newInstance")),
+                false),
+            new CfCheckCast(factory.createType("Ljava/lang/AssertionError;")),
+            label3,
+            new CfReturn(ValueType.OBJECT),
+            label4,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0, 1},
+                    new FrameType[] {
+                      FrameType.initializedNonNullReference(factory.stringType),
+                      FrameType.initializedNonNullReference(factory.throwableType)
+                    }),
+                new ArrayDeque<>(
+                    Arrays.asList(
+                        FrameType.initializedNonNullReference(
+                            factory.createType("Ljava/lang/Exception;"))))),
+            new CfStore(ValueType.OBJECT, 2),
+            label5,
+            new CfNew(factory.createType("Ljava/lang/AssertionError;")),
+            new CfStackInstruction(CfStackInstruction.Opcode.Dup),
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfInvoke(
+                183,
+                factory.createMethod(
+                    factory.createType("Ljava/lang/AssertionError;"),
+                    factory.createProto(factory.voidType, factory.objectType),
+                    factory.createString("<init>")),
+                false),
+            new CfReturn(ValueType.OBJECT),
+            label6),
+        ImmutableList.of(
+            new CfTryCatch(
+                label0,
+                label3,
+                ImmutableList.of(factory.createType("Ljava/lang/Exception;")),
+                ImmutableList.of(label4))),
+        ImmutableList.of());
+  }
+
   public static CfCode AtomicReferenceArrayMethods_compareAndSet(
       DexItemFactory factory, DexMethod method) {
     CfLabel label0 = new CfLabel();
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/AssertionErrorTwoArgsConstructorRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/AssertionErrorTwoArgsConstructorRewriter.java
new file mode 100644
index 0000000..06d9d8f
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/optimize/AssertionErrorTwoArgsConstructorRewriter.java
@@ -0,0 +1,109 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize;
+
+import com.android.tools.r8.contexts.CompilationContext.MethodProcessingContext;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProto;
+import com.android.tools.r8.graph.MethodAccessFlags;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.analysis.type.Nullability;
+import com.android.tools.r8.ir.analysis.type.TypeElement;
+import com.android.tools.r8.ir.code.BasicBlock;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.ir.code.InstructionListIterator;
+import com.android.tools.r8.ir.code.InvokeStatic;
+import com.android.tools.r8.ir.code.Value;
+import com.android.tools.r8.ir.desugar.backports.BackportedMethods;
+import com.android.tools.r8.utils.InternalOptions;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.ListIterator;
+
+public class AssertionErrorTwoArgsConstructorRewriter {
+
+  private final AppView<?> appView;
+  private final DexItemFactory dexItemFactory;
+  private final InternalOptions options;
+
+  public AssertionErrorTwoArgsConstructorRewriter(AppView<?> appView) {
+    this.appView = appView;
+    this.options = appView.options();
+    this.dexItemFactory = appView.dexItemFactory();
+  }
+
+  public void rewrite(IRCode code, MethodProcessingContext methodProcessingContext) {
+    if (options.canUseAssertionErrorTwoArgumentConstructor()) {
+      return;
+    }
+
+    ListIterator<BasicBlock> blockIterator = code.listIterator();
+    while (blockIterator.hasNext()) {
+      BasicBlock block = blockIterator.next();
+      InstructionListIterator insnIterator = block.listIterator(code);
+      while (insnIterator.hasNext()) {
+        Instruction current = insnIterator.next();
+        if (current.isInvokeMethod()) {
+          DexMethod invokedMethod = current.asInvokeMethod().getInvokedMethod();
+          if (invokedMethod == dexItemFactory.assertionErrorMethods.initMessageAndCause) {
+            List<Value> inValues = current.inValues();
+            assert inValues.size() == 3; // receiver, message, cause
+
+            Value assertionError =
+                code.createValue(
+                    TypeElement.fromDexType(
+                        dexItemFactory.assertionErrorType,
+                        Nullability.definitelyNotNull(),
+                        appView));
+            Instruction invoke =
+                new InvokeStatic(
+                    createSynthetic(methodProcessingContext).getReference(),
+                    assertionError,
+                    inValues.subList(1, 3));
+            insnIterator.replaceCurrentInstruction(invoke);
+            inValues.get(0).replaceUsers(assertionError);
+            inValues.get(0).definition.removeOrReplaceByDebugLocalRead(code);
+          }
+        }
+      }
+    }
+    assert code.isConsistentSSA(appView);
+  }
+
+  private final List<ProgramMethod> synthesizedMethods = new ArrayList<>();
+
+  public List<ProgramMethod> getSynthesizedMethods() {
+    return synthesizedMethods;
+  }
+
+  private ProgramMethod createSynthetic(MethodProcessingContext methodProcessingContext) {
+    DexItemFactory factory = appView.dexItemFactory();
+    DexProto proto =
+        factory.createProto(factory.assertionErrorType, factory.stringType, factory.throwableType);
+    ProgramMethod method =
+        appView
+            .getSyntheticItems()
+            .createMethod(
+                kinds -> kinds.BACKPORT,
+                methodProcessingContext.createUniqueContext(),
+                appView,
+                builder ->
+                    builder
+                        .setApiLevelForCode(appView.computedMinApiLevel())
+                        .setProto(proto)
+                        .setAccessFlags(MethodAccessFlags.createPublicStaticSynthetic())
+                        .setCode(
+                            methodSig ->
+                                BackportedMethods.AssertionErrorMethods_createAssertionError(
+                                    factory, methodSig)));
+    synchronized (synthesizedMethods) {
+      synthesizedMethods.add(method);
+    }
+    return method;
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 36b96dc..4d31397 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -138,7 +138,6 @@
 import it.unimi.dsi.fastutil.objects.Reference2IntMap;
 import it.unimi.dsi.fastutil.objects.Reference2IntOpenHashMap;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.BitSet;
 import java.util.Collection;
 import java.util.Collections;
@@ -3424,55 +3423,6 @@
     assert code.isConsistentSSA(appView);
   }
 
-  public void rewriteAssertionErrorTwoArgumentConstructor(IRCode code, InternalOptions options) {
-    if (options.canUseAssertionErrorTwoArgumentConstructor()) {
-      return;
-    }
-
-    ListIterator<BasicBlock> blockIterator = code.listIterator();
-    while (blockIterator.hasNext()) {
-      BasicBlock block = blockIterator.next();
-      InstructionListIterator insnIterator = block.listIterator(code);
-      while (insnIterator.hasNext()) {
-        Instruction current = insnIterator.next();
-        if (current.isInvokeMethod()) {
-          DexMethod invokedMethod = current.asInvokeMethod().getInvokedMethod();
-          if (invokedMethod == dexItemFactory.assertionErrorMethods.initMessageAndCause) {
-            // Rewrite calls to new AssertionError(message, cause) to new AssertionError(message)
-            // and then initCause(cause).
-            List<Value> inValues = current.inValues();
-            assert inValues.size() == 3; // receiver, message, cause
-
-            // Remove cause from the constructor call
-            List<Value> newInitInValues = inValues.subList(0, 2);
-            insnIterator.replaceCurrentInstruction(
-                new InvokeDirect(
-                    dexItemFactory.assertionErrorMethods.initMessage, null, newInitInValues));
-
-            // On API 15 and older we cannot use initCause because of a bug in AssertionError.
-            if (options.canInitCauseAfterAssertionErrorObjectConstructor()) {
-              // Add a call to Throwable.initCause(cause)
-              if (block.hasCatchHandlers()) {
-                insnIterator = insnIterator.split(code, blockIterator).listIterator(code);
-              }
-              List<Value> initCauseArguments = Arrays.asList(inValues.get(0), inValues.get(2));
-              InvokeVirtual initCause =
-                  new InvokeVirtual(
-                      dexItemFactory.throwableMethods.initCause,
-                      code.createValue(
-                          TypeElement.fromDexType(
-                              dexItemFactory.throwableType, maybeNull(), appView)),
-                      initCauseArguments);
-              initCause.setPosition(current.getPosition());
-              insnIterator.add(initCause);
-            }
-          }
-        }
-      }
-    }
-    assert code.isConsistentSSA(appView);
-  }
-
   /**
    * Remove moves that are not actually used by instructions in exiting paths. These moves can arise
    * due to debug local info needing a particular value and the live-interval for it then moves it
diff --git a/src/test/java/com/android/tools/r8/ir/desugar/backports/AssertionErrorMethods.java b/src/test/java/com/android/tools/r8/ir/desugar/backports/AssertionErrorMethods.java
new file mode 100644
index 0000000..7b8c096
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/desugar/backports/AssertionErrorMethods.java
@@ -0,0 +1,20 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.desugar.backports;
+
+import java.lang.reflect.Constructor;
+
+public final class AssertionErrorMethods {
+
+  public static AssertionError createAssertionError(String message, Throwable cause) {
+    try {
+      Constructor<AssertionError> twoArgsConstructor =
+          AssertionError.class.getDeclaredConstructor(String.class, Throwable.class);
+      return twoArgsConstructor.newInstance(message, cause);
+    } catch (Exception e) {
+      return new AssertionError(message);
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/desugar/backports/GenerateBackportMethods.java b/src/test/java/com/android/tools/r8/ir/desugar/backports/GenerateBackportMethods.java
index 81f92c3..429f2be 100644
--- a/src/test/java/com/android/tools/r8/ir/desugar/backports/GenerateBackportMethods.java
+++ b/src/test/java/com/android/tools/r8/ir/desugar/backports/GenerateBackportMethods.java
@@ -34,6 +34,7 @@
       factory.createType("Lcom/android/tools/r8/ir/desugar/backports/BackportedMethods;");
   private final List<Class<?>> METHOD_TEMPLATE_CLASSES =
       ImmutableList.of(
+          AssertionErrorMethods.class,
           AtomicReferenceArrayMethods.class,
           AtomicReferenceFieldUpdaterMethods.class,
           AtomicReferenceMethods.class,
diff --git a/src/test/java/com/android/tools/r8/rewrite/assertionerror/AssertionErrorRewriteApi16Test.java b/src/test/java/com/android/tools/r8/rewrite/assertionerror/AssertionErrorRewriteApi16Test.java
index 4fbb7c3..e8ca0e3 100644
--- a/src/test/java/com/android/tools/r8/rewrite/assertionerror/AssertionErrorRewriteApi16Test.java
+++ b/src/test/java/com/android/tools/r8/rewrite/assertionerror/AssertionErrorRewriteApi16Test.java
@@ -44,6 +44,6 @@
         .enableInliningAnnotations()
         .setMinApi(AndroidApiLevel.J)
         .run(parameters.getRuntime(), Main.class, String.valueOf(false))
-        .assertSuccessWithOutputLines("OK", "OK");
+        .assertSuccessWithOutputLines("message", "java.lang.RuntimeException: cause message");
   }
 }
diff --git a/src/test/java/com/android/tools/r8/rewrite/assertionerror/AssertionErrorRewriteTest.java b/src/test/java/com/android/tools/r8/rewrite/assertionerror/AssertionErrorRewriteTest.java
index 2318414..f22208c 100644
--- a/src/test/java/com/android/tools/r8/rewrite/assertionerror/AssertionErrorRewriteTest.java
+++ b/src/test/java/com/android/tools/r8/rewrite/assertionerror/AssertionErrorRewriteTest.java
@@ -3,7 +3,6 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.rewrite.assertionerror;
 
-import static com.android.tools.r8.ToolHelper.getDefaultAndroidJar;
 import static org.junit.Assume.assumeTrue;
 
 import com.android.tools.r8.NeverInline;
@@ -38,46 +37,33 @@
   @Test public void d8() throws Exception {
     assumeTrue(parameters.isDexRuntime());
     testForD8()
-        .addLibraryFiles(getDefaultAndroidJar())
         .addProgramClasses(Main.class)
         .setMinApi(parameters.getApiLevel())
-        .run(parameters.getRuntime(), Main.class, String.valueOf(expectCause))
-        .assertSuccessWithOutputLines("OK", "OK");
+        .run(parameters.getRuntime(), Main.class)
+        // None of the VMs we have for testing is missing the two args constructor.
+        .assertSuccessWithOutputLines("message", "java.lang.RuntimeException: cause message");
   }
 
   @Test public void r8() throws Exception {
     testForR8(parameters.getBackend())
-        .addLibraryFiles(getDefaultAndroidJar())
         .addProgramClasses(Main.class)
         .addKeepMainRule(Main.class)
         .enableInliningAnnotations()
         .setMinApi(parameters.getApiLevel())
         .run(parameters.getRuntime(), Main.class, String.valueOf(expectCause))
-        .assertSuccessWithOutputLines("OK", "OK");
+        // None of the VMs we have for testing is missing the two args constructor.
+        .assertSuccessWithOutputLines("message", "java.lang.RuntimeException: cause message");
   }
 
   public static final class Main {
     public static void main(String[] args) {
-      boolean expectCause = Boolean.parseBoolean(args[0]);
-
       Throwable expectedCause = new RuntimeException("cause message");
       try {
         throwAssertionError(expectedCause);
         System.out.println("unreachable");
       } catch (AssertionError e) {
-        String message = e.getMessage();
-        if (!message.equals("message")) {
-          throw new RuntimeException("Incorrect AssertionError message: " + message);
-        } else {
-          System.out.println("OK");
-        }
-
-        Throwable cause = e.getCause();
-        if (expectCause && cause != expectedCause) {
-          throw new RuntimeException("Incorrect AssertionError cause", cause);
-        } else {
-          System.out.println("OK");
-        }
+        System.out.println(e.getMessage());
+        System.out.println(e.getCause());
       }
     }