Support for adding casts to invoke arguments

Change-Id: I7f4eb8d1ef1985ff19ac18523372a5dbe1a236ea
diff --git a/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java b/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
index 0a60b01..1dd24ff 100644
--- a/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
+++ b/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
@@ -204,21 +204,28 @@
 
   public static class RewrittenTypeInfo extends ArgumentInfo {
 
+    private final DexType castType;
     private final DexType oldType;
     private final DexType newType;
     private final SingleValue singleValue;
 
+    public static Builder builder() {
+      return new Builder();
+    }
+
     public static RewrittenTypeInfo toVoid(
         DexType oldReturnType, DexItemFactory dexItemFactory, SingleValue singleValue) {
       assert singleValue != null;
-      return new RewrittenTypeInfo(oldReturnType, dexItemFactory.voidType, singleValue);
+      return new RewrittenTypeInfo(oldReturnType, dexItemFactory.voidType, null, singleValue);
     }
 
     public RewrittenTypeInfo(DexType oldType, DexType newType) {
-      this(oldType, newType, null);
+      this(oldType, newType, null, null);
     }
 
-    public RewrittenTypeInfo(DexType oldType, DexType newType, SingleValue singleValue) {
+    public RewrittenTypeInfo(
+        DexType oldType, DexType newType, DexType castType, SingleValue singleValue) {
+      this.castType = castType;
       this.oldType = oldType;
       this.newType = newType;
       this.singleValue = singleValue;
@@ -229,6 +236,10 @@
       return other.hasRewrittenReturnInfo() ? combine(other.getRewrittenReturnInfo()) : this;
     }
 
+    public DexType getCastType() {
+      return castType;
+    }
+
     public DexType getNewType() {
       return newType;
     }
@@ -245,6 +256,10 @@
       return newType.isVoidType();
     }
 
+    public boolean hasCastType() {
+      return castType != null;
+    }
+
     public boolean hasSingleValue() {
       return singleValue != null;
     }
@@ -271,18 +286,23 @@
     public RewrittenTypeInfo combine(RewrittenTypeInfo other) {
       assert !getNewType().isVoidType();
       assert getNewType() == other.getOldType();
-      return new RewrittenTypeInfo(getOldType(), other.getNewType(), other.getSingleValue());
+      return new RewrittenTypeInfo(
+          getOldType(), other.getNewType(), getCastType(), other.getSingleValue());
     }
 
     @Override
     public RewrittenTypeInfo rewrittenWithLens(
         AppView<AppInfoWithLiveness> appView, GraphLens graphLens) {
+      DexType rewrittenCastType = castType != null ? graphLens.lookupType(castType) : null;
       DexType rewrittenNewType = graphLens.lookupType(newType);
       SingleValue rewrittenSingleValue =
           hasSingleValue() ? getSingleValue().rewrittenWithLens(appView, graphLens) : null;
-      if (rewrittenNewType != newType || rewrittenSingleValue != singleValue) {
+      if (rewrittenCastType != castType
+          || rewrittenNewType != newType
+          || rewrittenSingleValue != singleValue) {
         // The old type is intentionally not rewritten.
-        return new RewrittenTypeInfo(oldType, rewrittenNewType, rewrittenSingleValue);
+        return new RewrittenTypeInfo(
+            oldType, rewrittenNewType, rewrittenCastType, rewrittenSingleValue);
       }
       return this;
     }
@@ -308,6 +328,38 @@
       assert newType.toBaseType(dexItemFactory).isIntType();
       return true;
     }
+
+    public static class Builder {
+
+      private DexType castType;
+      private DexType oldType;
+      private DexType newType;
+      private SingleValue singleValue;
+
+      public Builder setCastType(DexType castType) {
+        this.castType = castType;
+        return this;
+      }
+
+      public Builder setOldType(DexType oldType) {
+        this.oldType = oldType;
+        return this;
+      }
+
+      public Builder setNewType(DexType newType) {
+        this.newType = newType;
+        return this;
+      }
+
+      public Builder setSingleValue(SingleValue singleValue) {
+        this.singleValue = singleValue;
+        return this;
+      }
+
+      public RewrittenTypeInfo build() {
+        return new RewrittenTypeInfo(oldType, newType, castType, singleValue);
+      }
+    }
   }
 
   public static class ArgumentInfoCollection {
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/inlining/NullSimpleInliningConstraint.java b/src/main/java/com/android/tools/r8/ir/analysis/inlining/NullSimpleInliningConstraint.java
index bb4f505..10af841 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/inlining/NullSimpleInliningConstraint.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/inlining/NullSimpleInliningConstraint.java
@@ -52,12 +52,13 @@
           : NeverSimpleInliningConstraint.getInstance();
     } else if (argumentInfo.isRewrittenTypeInfo()) {
       RewrittenTypeInfo rewrittenTypeInfo = argumentInfo.asRewrittenTypeInfo();
-      // We should only get here as a result of enum unboxing.
-      assert rewrittenTypeInfo.verifyIsDueToUnboxing(appView.dexItemFactory());
-      // Rewrite definitely-null constraints to definitely-zero constraints.
-      return nullability.isDefinitelyNull()
-          ? factory.createEqualToNumberConstraint(getArgumentIndex(), 0)
-          : factory.createNotEqualToNumberConstraint(getArgumentIndex(), 0);
+      if (rewrittenTypeInfo.getNewType().isIntType()) {
+        // Rewrite definitely-null constraints to definitely-zero constraints.
+        return nullability.isDefinitelyNull()
+            ? factory.createEqualToNumberConstraint(getArgumentIndex(), 0)
+            : factory.createNotEqualToNumberConstraint(getArgumentIndex(), 0);
+      }
+      return this;
     }
     return withArgumentIndex(changes.getNewArgumentIndex(getArgumentIndex()), factory);
   }
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index 5794f8d..3dc74b2 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -105,7 +105,6 @@
 import java.util.HashSet;
 import java.util.IdentityHashMap;
 import java.util.List;
-import java.util.ListIterator;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.BiFunction;
@@ -256,6 +255,9 @@
               DexMethod actualTarget = lensLookup.getReference();
               Invoke.Type actualInvokeType = lensLookup.getType();
 
+              iterator =
+                  insertCastsForInvokeArgumentsIfNeeded(code, blocks, iterator, invoke, lensLookup);
+
               RewrittenPrototypeDescription prototypeChanges = lensLookup.getPrototypeChanges();
               if (prototypeChanges.requiresRewritingAtCallSite()
                   || invoke.getType() != actualInvokeType
@@ -756,7 +758,7 @@
 
   private InstructionListIterator insertCastForFieldAssignmentIfNeeded(
       IRCode code,
-      ListIterator<BasicBlock> blocks,
+      BasicBlockIterator blocks,
       InstructionListIterator iterator,
       FieldPut fieldPut,
       FieldLookupResult lookup) {
@@ -770,15 +772,72 @@
               .setPosition(fieldPut.getPosition())
               .build();
       iterator.add(checkCast);
-      iterator =
-          iterator.splitCopyCatchHandlers(code, blocks, appView.options()).listIterator(code);
       fieldPut.setValue(checkCast.outValue());
+
+      if (checkCast.getBlock().hasCatchHandlers()) {
+        // Split the block and reset the block iterator.
+        BasicBlock splitBlock = iterator.splitCopyCatchHandlers(code, blocks, appView.options());
+        BasicBlock previousBlock = blocks.previousUntil(block -> block == splitBlock);
+        assert previousBlock == splitBlock;
+        blocks.next();
+        iterator = splitBlock.listIterator(code);
+      }
+
       Instruction next = iterator.next();
       assert next == fieldPut;
     }
     return iterator;
   }
 
+  private InstructionListIterator insertCastsForInvokeArgumentsIfNeeded(
+      IRCode code,
+      BasicBlockIterator blocks,
+      InstructionListIterator iterator,
+      InvokeMethod invoke,
+      MethodLookupResult lookup) {
+    RewrittenPrototypeDescription prototypeChanges = lookup.getPrototypeChanges();
+    if (prototypeChanges.isEmpty()) {
+      return iterator;
+    }
+    for (int argumentIndex = 0; argumentIndex < invoke.arguments().size(); argumentIndex++) {
+      RewrittenTypeInfo rewrittenTypeInfo =
+          prototypeChanges
+              .getArgumentInfoCollection()
+              .getArgumentInfo(argumentIndex)
+              .asRewrittenTypeInfo();
+      if (rewrittenTypeInfo != null && rewrittenTypeInfo.hasCastType()) {
+        iterator.previous();
+        Value object = invoke.getArgument(argumentIndex);
+        CheckCast checkCast =
+            SafeCheckCast.builder()
+                .setObject(object)
+                .setFreshOutValue(
+                    code,
+                    rewrittenTypeInfo
+                        .getCastType()
+                        .toTypeElement(appView, object.getType().nullability()))
+                .setCastType(rewrittenTypeInfo.getCastType())
+                .setPosition(invoke.getPosition())
+                .build();
+        iterator.add(checkCast);
+        invoke.replaceValue(argumentIndex, checkCast.outValue());
+
+        if (checkCast.getBlock().hasCatchHandlers()) {
+          // Split the block and reset the block iterator.
+          BasicBlock splitBlock = iterator.splitCopyCatchHandlers(code, blocks, appView.options());
+          BasicBlock previousBlock = blocks.previousUntil(block -> block == splitBlock);
+          assert previousBlock == splitBlock;
+          blocks.next();
+          iterator = splitBlock.listIterator(code);
+        }
+
+        Instruction next = iterator.next();
+        assert next == invoke;
+      }
+    }
+    return iterator;
+  }
+
   private DexField rewriteFieldReference(FieldLookupResult lookup, ProgramMethod context) {
     if (lookup.hasReboundReference()) {
       DexClass holder = appView.definitionFor(lookup.getReboundReference().getHolderType());
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/classinliner/constraint/ConditionalClassInlinerMethodConstraint.java b/src/main/java/com/android/tools/r8/ir/optimize/classinliner/constraint/ConditionalClassInlinerMethodConstraint.java
index 5403b86..4aceb17 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/classinliner/constraint/ConditionalClassInlinerMethodConstraint.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/classinliner/constraint/ConditionalClassInlinerMethodConstraint.java
@@ -11,7 +11,6 @@
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.graph.RewrittenPrototypeDescription.ArgumentInfo;
 import com.android.tools.r8.graph.RewrittenPrototypeDescription.ArgumentInfoCollection;
-import com.android.tools.r8.graph.RewrittenPrototypeDescription.RewrittenTypeInfo;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
 import com.android.tools.r8.ir.analysis.value.SingleConstValue;
 import com.android.tools.r8.ir.analysis.value.objectstate.ObjectState;
@@ -51,11 +50,10 @@
                 // usages of that parameter for class inlining.
                 return;
               }
-              if (argumentInfo.isRewrittenTypeInfo()) {
+              if (argumentInfo.isRewrittenTypeInfo()
+                  && argumentInfo.asRewrittenTypeInfo().getNewType().isIntType()) {
                 // This is due to enum unboxing. After enum unboxing, we no longer need information
                 // about the usages of this parameter for class inlining.
-                RewrittenTypeInfo rewrittenTypeInfo = argumentInfo.asRewrittenTypeInfo();
-                assert rewrittenTypeInfo.verifyIsDueToUnboxing(appView.dexItemFactory());
                 return;
               }
               backing.put(changes.getNewArgumentIndex(argumentIndex), usagePerContext);
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingLens.java b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingLens.java
index 1fe7a27..e055f70 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingLens.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/enums/EnumUnboxingLens.java
@@ -63,11 +63,14 @@
           if (unboxedEnums.hasUnboxedValueFor(singleFieldValue.getField())) {
             prototypeChanges =
                 prototypeChanges.withRewrittenReturnInfo(
-                    new RewrittenTypeInfo(
-                        rewrittenTypeInfo.getOldType(),
-                        rewrittenTypeInfo.getNewType(),
-                        abstractValueFactory.createSingleNumberValue(
-                            unboxedEnums.getUnboxedValue(singleFieldValue.getField()))));
+                    RewrittenTypeInfo.builder()
+                        .setCastType(rewrittenTypeInfo.getCastType())
+                        .setOldType(rewrittenTypeInfo.getOldType())
+                        .setNewType(rewrittenTypeInfo.getNewType())
+                        .setSingleValue(
+                            abstractValueFactory.createSingleNumberValue(
+                                unboxedEnums.getUnboxedValue(singleFieldValue.getField())))
+                        .build());
           }
         }
       }