Handle check-cast users in the class inliner

Change-Id: I089b63316044d44059417057d88a6b1d0e78b7d4
diff --git a/src/main/java/com/android/tools/r8/ir/code/AliasedValueConfiguration.java b/src/main/java/com/android/tools/r8/ir/code/AliasedValueConfiguration.java
new file mode 100644
index 0000000..008d363
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/code/AliasedValueConfiguration.java
@@ -0,0 +1,12 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.code;
+
+public interface AliasedValueConfiguration {
+
+  boolean isIntroducingAnAlias(Instruction instruction);
+
+  Value getAliasForOutValue(Instruction instruction);
+}
diff --git a/src/main/java/com/android/tools/r8/ir/code/AssumeAndCheckCastAliasedValueConfiguration.java b/src/main/java/com/android/tools/r8/ir/code/AssumeAndCheckCastAliasedValueConfiguration.java
new file mode 100644
index 0000000..8a6517b
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/code/AssumeAndCheckCastAliasedValueConfiguration.java
@@ -0,0 +1,30 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.code;
+
+import com.android.tools.r8.utils.ListUtils;
+
+public class AssumeAndCheckCastAliasedValueConfiguration implements AliasedValueConfiguration {
+
+  private static final AssumeAndCheckCastAliasedValueConfiguration INSTANCE =
+      new AssumeAndCheckCastAliasedValueConfiguration();
+
+  private AssumeAndCheckCastAliasedValueConfiguration() {}
+
+  public static AssumeAndCheckCastAliasedValueConfiguration getInstance() {
+    return INSTANCE;
+  }
+
+  @Override
+  public boolean isIntroducingAnAlias(Instruction instruction) {
+    return instruction.isAssume() || instruction.isCheckCast();
+  }
+
+  @Override
+  public Value getAliasForOutValue(Instruction instruction) {
+    assert instruction.isAssume() || instruction.isCheckCast();
+    return ListUtils.first(instruction.inValues());
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/code/DefaultAliasedValueConfiguration.java b/src/main/java/com/android/tools/r8/ir/code/DefaultAliasedValueConfiguration.java
new file mode 100644
index 0000000..b147deb
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/code/DefaultAliasedValueConfiguration.java
@@ -0,0 +1,28 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.code;
+
+public class DefaultAliasedValueConfiguration implements AliasedValueConfiguration {
+
+  private static final DefaultAliasedValueConfiguration INSTANCE =
+      new DefaultAliasedValueConfiguration();
+
+  private DefaultAliasedValueConfiguration() {}
+
+  public static DefaultAliasedValueConfiguration getInstance() {
+    return INSTANCE;
+  }
+
+  @Override
+  public boolean isIntroducingAnAlias(Instruction instruction) {
+    return instruction.isAssume();
+  }
+
+  @Override
+  public Value getAliasForOutValue(Instruction instruction) {
+    assert instruction.isAssume();
+    return instruction.asAssume().src();
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/code/Value.java b/src/main/java/com/android/tools/r8/ir/code/Value.java
index d6372b5..3079026 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Value.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Value.java
@@ -264,25 +264,31 @@
    * <p>This method is useful to find the "true" definition of a value inside the current method.
    */
   public Value getAliasedValue() {
-    return getAliasedValue(Predicates.alwaysFalse());
+    return getAliasedValue(
+        DefaultAliasedValueConfiguration.getInstance(), Predicates.alwaysFalse());
   }
 
-  public Value getAliasedValue(Predicate<Value> stoppingCriterion) {
+  public Value getAliasedValue(AliasedValueConfiguration configuration) {
+    return getAliasedValue(configuration, Predicates.alwaysFalse());
+  }
+
+  public Value getAliasedValue(
+      AliasedValueConfiguration configuration, Predicate<Value> stoppingCriterion) {
     assert stoppingCriterion != null;
     Set<Value> visited = Sets.newIdentityHashSet();
     Value lastAliasedValue;
     Value aliasedValue = this;
     do {
-      if (stoppingCriterion.test(aliasedValue)) {
-        return aliasedValue;
-      }
       lastAliasedValue = aliasedValue;
       if (aliasedValue.isPhi()) {
         return aliasedValue;
       }
+      if (stoppingCriterion.test(aliasedValue)) {
+        return aliasedValue;
+      }
       Instruction definitionOfAliasedValue = aliasedValue.definition;
-      if (definitionOfAliasedValue.isIntroducingAnAlias()) {
-        aliasedValue = definitionOfAliasedValue.getAliasForOutValue();
+      if (configuration.isIntroducingAnAlias(definitionOfAliasedValue)) {
+        aliasedValue = configuration.getAliasForOutValue(definitionOfAliasedValue);
 
         // There shouldn't be a cycle.
         assert visited.add(aliasedValue);
@@ -293,7 +299,8 @@
   }
 
   public Value getSpecificAliasedValue(Predicate<Value> stoppingCriterion) {
-    Value aliasedValue = getAliasedValue(stoppingCriterion);
+    Value aliasedValue =
+        getAliasedValue(DefaultAliasedValueConfiguration.getInstance(), stoppingCriterion);
     return stoppingCriterion.test(aliasedValue) ? aliasedValue : null;
   }
 
@@ -440,21 +447,29 @@
   }
 
   public Set<Instruction> aliasedUsers() {
+    return aliasedUsers(DefaultAliasedValueConfiguration.getInstance());
+  }
+
+  public Set<Instruction> aliasedUsers(AliasedValueConfiguration configuration) {
     Set<Instruction> users = SetUtils.newIdentityHashSet(uniqueUsers());
     Set<Instruction> visited = Sets.newIdentityHashSet();
-    collectAliasedUsersViaAssume(visited, uniqueUsers(), users);
+    collectAliasedUsersViaAssume(configuration, visited, uniqueUsers(), users);
     return users;
   }
 
   private static void collectAliasedUsersViaAssume(
-      Set<Instruction> visited, Set<Instruction> usersToTest, Set<Instruction> collectedUsers) {
+      AliasedValueConfiguration configuration,
+      Set<Instruction> visited,
+      Set<Instruction> usersToTest,
+      Set<Instruction> collectedUsers) {
     for (Instruction user : usersToTest) {
       if (!visited.add(user)) {
         continue;
       }
-      if (user.isAssume()) {
+      if (configuration.isIntroducingAnAlias(user)) {
         collectedUsers.addAll(user.outValue().uniqueUsers());
-        collectAliasedUsersViaAssume(visited, user.outValue().uniqueUsers(), collectedUsers);
+        collectAliasedUsersViaAssume(
+            configuration, visited, user.outValue().uniqueUsers(), collectedUsers);
       }
     }
   }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/classinliner/InlineCandidateProcessor.java b/src/main/java/com/android/tools/r8/ir/optimize/classinliner/InlineCandidateProcessor.java
index 17510c2..9bb2ba5 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/classinliner/InlineCandidateProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/classinliner/InlineCandidateProcessor.java
@@ -20,8 +20,10 @@
 import com.android.tools.r8.ir.analysis.ClassInitializationAnalysis;
 import com.android.tools.r8.ir.analysis.type.ClassTypeLatticeElement;
 import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
-import com.android.tools.r8.ir.code.Assume;
+import com.android.tools.r8.ir.code.AliasedValueConfiguration;
+import com.android.tools.r8.ir.code.AssumeAndCheckCastAliasedValueConfiguration;
 import com.android.tools.r8.ir.code.BasicBlock;
+import com.android.tools.r8.ir.code.CheckCast;
 import com.android.tools.r8.ir.code.ConstNumber;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.If;
@@ -50,6 +52,7 @@
 import com.android.tools.r8.ir.optimize.inliner.NopWhyAreYouNotInliningReporter;
 import com.android.tools.r8.kotlin.KotlinInfo;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.ListUtils;
 import com.android.tools.r8.utils.Pair;
 import com.android.tools.r8.utils.StringUtils;
 import com.google.common.collect.ImmutableSet;
@@ -76,6 +79,8 @@
 
   private static final ImmutableSet<If.Type> ALLOWED_ZERO_TEST_TYPES =
       ImmutableSet.of(If.Type.EQ, If.Type.NE);
+  private static final AliasedValueConfiguration aliasesThroughAssumeAndCheckCasts =
+      AssumeAndCheckCastAliasedValueConfiguration.getInstance();
 
   private final AppView<AppInfoWithLiveness> appView;
   private final Inliner inliner;
@@ -355,7 +360,7 @@
     }
 
     anyInlinedMethods |= forceInlineDirectMethodInvocations(code, inliningIRProvider);
-    removeAssumeInstructionsLinkedToEligibleInstance();
+    removeAliasIntroducingInstructionsLinkedToEligibleInstance();
     removeMiscUsages(code);
     removeFieldReads(code);
     removeFieldWrites();
@@ -448,21 +453,26 @@
     return true;
   }
 
-  private void removeAssumeInstructionsLinkedToEligibleInstance() {
-    for (Instruction user : eligibleInstance.aliasedUsers()) {
-      if (!user.isAssume()) {
-        continue;
+  private void removeAliasIntroducingInstructionsLinkedToEligibleInstance() {
+    Set<Instruction> currentUsers = eligibleInstance.uniqueUsers();
+    while (!currentUsers.isEmpty()) {
+      Set<Instruction> indirectOutValueUsers = Sets.newIdentityHashSet();
+      for (Instruction instruction : currentUsers) {
+        if (instruction.isAssume() || instruction.isCheckCast()) {
+          Value src = ListUtils.first(instruction.inValues());
+          Value dest = instruction.outValue();
+          indirectOutValueUsers.addAll(dest.uniqueUsers());
+          assert !dest.hasPhiUsers();
+          dest.replaceUsers(src);
+          removeInstruction(instruction);
+        }
       }
-      Assume<?> assumeInstruction = user.asAssume();
-      Value src = assumeInstruction.src();
-      Value dest = assumeInstruction.outValue();
-      assert receivers.isReceiverAlias(dest);
-      assert !dest.hasPhiUsers();
-      dest.replaceUsers(src);
-      removeInstruction(user);
+      currentUsers = indirectOutValueUsers;
     }
-    // Verify that no more assume instructions are left as users.
+
+    // Verify that no more assume or check-cast instructions are left as users.
     assert eligibleInstance.aliasedUsers().stream().noneMatch(Instruction::isAssume);
+    assert eligibleInstance.aliasedUsers().stream().noneMatch(Instruction::isCheckCast);
   }
 
   // Remove miscellaneous users before handling field reads.
@@ -700,7 +710,13 @@
     while (!currentUsers.isEmpty()) {
       Set<Instruction> indirectOutValueUsers = Sets.newIdentityHashSet();
       for (Instruction instruction : currentUsers) {
-        if (instruction.isAssume()) {
+        if (instruction.isAssume() || instruction.isCheckCast()) {
+          if (instruction.isCheckCast()) {
+            CheckCast checkCast = instruction.asCheckCast();
+            if (!appView.appInfo().isSubtype(eligibleClass.type, checkCast.getType())) {
+              return false; // Unsafe cast.
+            }
+          }
           Value outValueAlias = instruction.outValue();
           if (outValueAlias.hasPhiUsers() || outValueAlias.hasDebugUsers()) {
             return false;
@@ -714,11 +730,12 @@
 
         if (instruction.isInvokeMethodWithReceiver()) {
           InvokeMethodWithReceiver user = instruction.asInvokeMethodWithReceiver();
-          if (user.getReceiver().getAliasedValue() != outValue) {
+          if (user.getReceiver().getAliasedValue(aliasesThroughAssumeAndCheckCasts) != outValue) {
             return false;
           }
           for (int i = 1; i < user.inValues().size(); i++) {
-            if (user.inValues().get(i).getAliasedValue() == outValue) {
+            Value inValue = user.inValues().get(i);
+            if (inValue.getAliasedValue(aliasesThroughAssumeAndCheckCasts) == outValue) {
               return false;
             }
           }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfoCollector.java b/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfoCollector.java
index 8238723..b97cc0d 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfoCollector.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/info/MethodOptimizationInfoCollector.java
@@ -27,6 +27,7 @@
 import static com.android.tools.r8.ir.code.Opcodes.INVOKE_NEW_ARRAY;
 import static com.android.tools.r8.ir.code.Opcodes.INVOKE_STATIC;
 import static com.android.tools.r8.ir.code.Opcodes.INVOKE_VIRTUAL;
+import static com.android.tools.r8.ir.code.Opcodes.MONITOR;
 import static com.android.tools.r8.ir.code.Opcodes.MUL;
 import static com.android.tools.r8.ir.code.Opcodes.NEW_ARRAY_EMPTY;
 import static com.android.tools.r8.ir.code.Opcodes.NEW_INSTANCE;
@@ -59,6 +60,8 @@
 import com.android.tools.r8.ir.analysis.type.Nullability;
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
 import com.android.tools.r8.ir.analysis.value.AbstractValue;
+import com.android.tools.r8.ir.code.AliasedValueConfiguration;
+import com.android.tools.r8.ir.code.AssumeAndCheckCastAliasedValueConfiguration;
 import com.android.tools.r8.ir.code.BasicBlock;
 import com.android.tools.r8.ir.code.DominatorTree;
 import com.android.tools.r8.ir.code.FieldInstruction;
@@ -99,6 +102,7 @@
 import java.util.List;
 import java.util.Set;
 import java.util.function.BiFunction;
+import java.util.function.Predicate;
 
 public class MethodOptimizationInfoCollector {
   private final AppView<AppInfoWithLiveness> appView;
@@ -167,80 +171,79 @@
     List<Pair<Invoke.Type, DexMethod>> callsReceiver = new ArrayList<>();
     boolean seenSuperInitCall = false;
     boolean seenMonitor = false;
-    for (Instruction insn : receiver.aliasedUsers()) {
-      if (insn.isAssume()) {
-        continue;
-      }
 
-      if (insn.isMonitor()) {
-        seenMonitor = true;
-        continue;
-      }
+    AliasedValueConfiguration configuration =
+        AssumeAndCheckCastAliasedValueConfiguration.getInstance();
+    Predicate<Value> isReceiverAlias = value -> value.getAliasedValue(configuration) == receiver;
+    for (Instruction insn : receiver.aliasedUsers(configuration)) {
+      switch (insn.opcode()) {
+        case ASSUME:
+        case CHECK_CAST:
+        case RETURN:
+          break;
 
-      if (insn.isInstanceGet() || insn.isInstancePut()) {
-        if (insn.isInstancePut()) {
-          InstancePut instancePutInstruction = insn.asInstancePut();
-          // Only allow field writes to the receiver.
-          if (instancePutInstruction.object().getAliasedValue() != receiver) {
+        case MONITOR:
+          seenMonitor = true;
+          break;
+
+        case INSTANCE_GET:
+        case INSTANCE_PUT:
+          {
+            if (insn.isInstancePut()) {
+              InstancePut instancePutInstruction = insn.asInstancePut();
+              // Only allow field writes to the receiver.
+              if (!isReceiverAlias.test(instancePutInstruction.object())) {
+                return;
+              }
+              // Do not allow the receiver to escape via a field write.
+              if (isReceiverAlias.test(instancePutInstruction.value())) {
+                return;
+              }
+            }
+            DexField field = insn.asFieldInstruction().getField();
+            if (appView.appInfo().resolveFieldOn(clazz, field) != null) {
+              // Require only accessing direct or indirect instance fields of the current class.
+              break;
+            }
             return;
           }
-          // Do not allow the receiver to escape via a field write.
-          if (instancePutInstruction.value().getAliasedValue() == receiver) {
+
+        case INVOKE_DIRECT:
+          {
+            InvokeDirect invoke = insn.asInvokeDirect();
+            DexMethod invokedMethod = invoke.getInvokedMethod();
+            if (dexItemFactory.isConstructor(invokedMethod)
+                && invokedMethod.holder == clazz.superType
+                && ListUtils.lastIndexMatching(invoke.arguments(), isReceiverAlias) == 0
+                && !seenSuperInitCall
+                && instanceInitializer) {
+              seenSuperInitCall = true;
+              break;
+            }
+            // We don't support other direct calls yet.
             return;
           }
-        }
-        DexField field = insn.asFieldInstruction().getField();
-        if (appView.appInfo().resolveFieldOn(clazz, field) != null) {
-          // Require only accessing direct or indirect instance fields of the current class.
-          continue;
-        }
-        return;
-      }
 
-      // If this is an instance initializer allow one call to superclass instance initializer.
-      if (insn.isInvokeDirect()) {
-        InvokeDirect invokedDirect = insn.asInvokeDirect();
-        DexMethod invokedMethod = invokedDirect.getInvokedMethod();
-        if (dexItemFactory.isConstructor(invokedMethod)
-            && invokedMethod.holder == clazz.superType
-            && ListUtils.lastIndexMatching(
-                invokedDirect.inValues(), v -> v.getAliasedValue() == receiver) == 0
-            && !seenSuperInitCall
-            && instanceInitializer) {
-          seenSuperInitCall = true;
-          continue;
-        }
-        // We don't support other direct calls yet.
-        return;
-      }
-
-      if (insn.isInvokeVirtual()) {
-        InvokeVirtual invoke = insn.asInvokeVirtual();
-        if (invoke.getReceiver().getAliasedValue() != receiver) {
-          return; // Not allowed.
-        }
-        for (int i = 1; i < invoke.arguments().size(); i++) {
-          Value argument = invoke.arguments().get(i);
-          if (argument.getAliasedValue() == receiver) {
-            return; // Not allowed.
+        case INVOKE_VIRTUAL:
+          {
+            InvokeVirtual invoke = insn.asInvokeVirtual();
+            if (ListUtils.lastIndexMatching(invoke.arguments(), isReceiverAlias) != 0) {
+              return; // Not allowed.
+            }
+            DexMethod invokedMethod = invoke.getInvokedMethod();
+            DexType returnType = invokedMethod.proto.returnType;
+            if (returnType.isClassType()
+                && appView.appInfo().isRelatedBySubtyping(returnType, method.method.holder)) {
+              return; // Not allowed, could introduce an alias of the receiver.
+            }
+            callsReceiver.add(new Pair<>(Invoke.Type.VIRTUAL, invokedMethod));
           }
-        }
-        DexMethod invokedMethod = invoke.getInvokedMethod();
-        DexType returnType = invokedMethod.proto.returnType;
-        if (returnType.isClassType()
-            && appView.appInfo().isRelatedBySubtyping(returnType, method.method.holder)) {
-          return; // Not allowed, could introduce an alias of the receiver.
-        }
-        callsReceiver.add(new Pair<>(Invoke.Type.VIRTUAL, invokedMethod));
-        continue;
-      }
+          break;
 
-      if (insn.isReturn()) {
-        continue;
+        default:
+          // Other receiver usages make the method not eligible.
+          return;
       }
-
-      // Other receiver usages make the method not eligible.
-      return;
     }
 
     if (instanceInitializer && !seenSuperInitCall) {