Generalize Assume instruction

Bug: 129860265
Change-Id: Ib60c956721d95702b7c28ba05a78586760039f71
diff --git a/src/main/java/com/android/tools/r8/ir/code/Assume.java b/src/main/java/com/android/tools/r8/ir/code/Assume.java
index f5d95bf..7978468 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Assume.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Assume.java
@@ -4,27 +4,39 @@
 package com.android.tools.r8.ir.code;
 
 import com.android.tools.r8.cf.LoadStoreHelper;
+import com.android.tools.r8.errors.Unimplemented;
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AppInfo;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
+import com.android.tools.r8.ir.code.Assume.Assumption;
 import com.android.tools.r8.ir.conversion.CfBuilder;
 import com.android.tools.r8.ir.conversion.DexBuilder;
 import com.android.tools.r8.ir.optimize.Inliner.ConstraintWithTarget;
 import com.android.tools.r8.ir.optimize.InliningConstraints;
 
-public class Assume extends Instruction {
-  private final static String ERROR_MESSAGE = "This fake IR should be removed after inlining.";
+public class Assume<An extends Assumption> extends Instruction {
 
-  final Instruction origin;
+  private static final String ERROR_MESSAGE =
+      "Expected Assume instructions to be removed after IR processing.";
 
-  public Assume(Value dest, Value src, Instruction origin) {
+  private final An assumption;
+  private final Instruction origin;
+
+  private Assume(An assumption, Value dest, Value src, Instruction origin) {
     super(dest, src);
-    assert !src.isNeverNull();
+    assert assumption != null;
+    assert assumption.verifyCorrectnessOfValues(dest, src);
+    this.assumption = assumption;
     this.origin = origin;
   }
 
+  public static Assume<NonNullAssumption> createAssumeNonNullInstruction(
+      Value dest, Value src, Instruction origin) {
+    return new Assume<>(NonNullAssumption.get(), dest, src, origin);
+  }
+
   @Override
   public <T> T accept(InstructionVisitor<T> visitor) {
     return visitor.visit(this);
@@ -43,16 +55,29 @@
   }
 
   @Override
-  public boolean isNonNull() {
+  public boolean isAssume() {
     return true;
   }
 
   @Override
-  public Assume asNonNull() {
+  public Assume<An> asAssume() {
     return this;
   }
 
   @Override
+  public boolean isAssumeNonNull() {
+    return assumption.isAssumeNonNull();
+  }
+
+  @Override
+  public Assume<NonNullAssumption> asAssumeNonNull() {
+    assert isAssumeNonNull();
+    @SuppressWarnings("unchecked")
+    Assume<NonNullAssumption> self = (Assume<NonNullAssumption>) this;
+    return self;
+  }
+
+  @Override
   public boolean isIntroducingAnAlias() {
     return true;
   }
@@ -94,19 +119,26 @@
 
   @Override
   public boolean identicalNonValueNonPositionParts(Instruction other) {
-    return other.isNonNull();
+    if (!other.isAssume()) {
+      return false;
+    }
+    Assume<?> assumeInstruction = other.asAssume();
+    return assumption.equals(assumeInstruction.assumption);
   }
 
   @Override
   public ConstraintWithTarget inliningConstraint(
       InliningConstraints inliningConstraints, DexType invocationContext) {
-    return inliningConstraints.forNonNull();
+    return inliningConstraints.forAssume();
   }
 
   @Override
   public TypeLatticeElement evaluate(AppView<? extends AppInfo> appView) {
-    assert src().getTypeLattice().isReference();
-    return src().getTypeLattice().asReferenceTypeLatticeElement().asNotNull();
+    if (assumption.isAssumeNonNull()) {
+      assert src().getTypeLattice().isReference();
+      return src().getTypeLattice().asReferenceTypeLatticeElement().asNotNull();
+    }
+    throw new Unimplemented();
   }
 
   @Override
@@ -118,4 +150,37 @@
   public void insertLoadAndStores(InstructionListIterator it, LoadStoreHelper helper) {
     throw new Unreachable(ERROR_MESSAGE);
   }
+
+  abstract static class Assumption {
+
+    public boolean isAssumeNonNull() {
+      return false;
+    }
+
+    public boolean verifyCorrectnessOfValues(Value dest, Value src) {
+      return true;
+    }
+  }
+
+  public static class NonNullAssumption extends Assumption {
+
+    private static final NonNullAssumption instance = new NonNullAssumption();
+
+    private NonNullAssumption() {}
+
+    public static NonNullAssumption get() {
+      return instance;
+    }
+
+    @Override
+    public boolean isAssumeNonNull() {
+      return true;
+    }
+
+    @Override
+    public boolean verifyCorrectnessOfValues(Value dest, Value src) {
+      assert !src.isNeverNull();
+      return true;
+    }
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/code/DefaultInstructionVisitor.java b/src/main/java/com/android/tools/r8/ir/code/DefaultInstructionVisitor.java
index 5ec29e4..8a3fb3a 100644
--- a/src/main/java/com/android/tools/r8/ir/code/DefaultInstructionVisitor.java
+++ b/src/main/java/com/android/tools/r8/ir/code/DefaultInstructionVisitor.java
@@ -35,6 +35,11 @@
   }
 
   @Override
+  public T visit(Assume<?> instruction) {
+    return null;
+  }
+
+  @Override
   public T visit(And instruction) {
     return null;
   }
@@ -260,11 +265,6 @@
   }
 
   @Override
-  public T visit(Assume instruction) {
-    return null;
-  }
-
-  @Override
   public T visit(Not instruction) {
     return null;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/code/Instruction.java b/src/main/java/com/android/tools/r8/ir/code/Instruction.java
index 7bf9eed..dcbef34 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Instruction.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Instruction.java
@@ -18,6 +18,7 @@
 import com.android.tools.r8.ir.analysis.constant.ConstRangeLatticeElement;
 import com.android.tools.r8.ir.analysis.constant.LatticeElement;
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
+import com.android.tools.r8.ir.code.Assume.NonNullAssumption;
 import com.android.tools.r8.ir.conversion.CfBuilder;
 import com.android.tools.r8.ir.conversion.DexBuilder;
 import com.android.tools.r8.ir.optimize.Inliner.ConstraintWithTarget;
@@ -604,6 +605,22 @@
     return null;
   }
 
+  public boolean isAssume() {
+    return false;
+  }
+
+  public Assume<?> asAssume() {
+    return null;
+  }
+
+  public boolean isAssumeNonNull() {
+    return false;
+  }
+
+  public Assume<NonNullAssumption> asAssumeNonNull() {
+    return null;
+  }
+
   public boolean isBinop() {
     return false;
   }
@@ -832,14 +849,6 @@
     return null;
   }
 
-  public boolean isNonNull() {
-    return false;
-  }
-
-  public Assume asNonNull() {
-    return null;
-  }
-
   public boolean isNot() {
     return false;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/code/InstructionVisitor.java b/src/main/java/com/android/tools/r8/ir/code/InstructionVisitor.java
index d6160d7..d7f9ae3 100644
--- a/src/main/java/com/android/tools/r8/ir/code/InstructionVisitor.java
+++ b/src/main/java/com/android/tools/r8/ir/code/InstructionVisitor.java
@@ -24,6 +24,8 @@
 
   T visit(ArrayPut instruction);
 
+  T visit(Assume<?> instruction);
+
   T visit(CheckCast instruction);
 
   T visit(Cmp instruction);
@@ -104,8 +106,6 @@
 
   T visit(NewInstance instruction);
 
-  T visit(Assume instruction);
-
   T visit(Not instruction);
 
   T visit(NumberConversion instruction);
diff --git a/src/main/java/com/android/tools/r8/ir/code/Value.java b/src/main/java/com/android/tools/r8/ir/code/Value.java
index 7c9fda8..5138bc1 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Value.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Value.java
@@ -236,8 +236,9 @@
 
   /**
    * If this value is defined by an instruction that defines an alias of another value, such as the
-   * NonNull instruction, then the incoming value to the NonNull instruction is returned (if the
-   * incoming value is not itself defined by an instruction that introduces an alias).
+   * {@link Assume} instruction, then the incoming value to the {@link Assume} instruction is
+   * returned (if the incoming value is not itself defined by an instruction that introduces an
+   * alias).
    *
    * <p>If a phi value is found, then that phi value is returned.
    *
@@ -260,7 +261,7 @@
         assert visited.add(aliasedValue);
       }
     } while (aliasedValue != lastAliasedValue);
-    assert aliasedValue.isPhi() || !aliasedValue.definition.isNonNull();
+    assert aliasedValue.isPhi() || !aliasedValue.definition.isAssume();
     return aliasedValue;
   }
 
@@ -783,7 +784,7 @@
    */
   public boolean isNeverNull() {
     assert typeLattice.isReference();
-    return (definition != null && definition.isNonNull())
+    return (definition != null && definition.isAssumeNonNull())
         || typeLattice.nullability().isDefinitelyNotNull();
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java b/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java
index 38f7d27..af2b34e 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java
@@ -10,6 +10,7 @@
 import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
 import com.android.tools.r8.ir.code.Assume;
+import com.android.tools.r8.ir.code.Assume.NonNullAssumption;
 import com.android.tools.r8.ir.code.BasicBlock;
 import com.android.tools.r8.ir.code.CheckCast;
 import com.android.tools.r8.ir.code.DominatorTree;
@@ -65,8 +66,8 @@
         // (out <-) invoke-virtual rcv_c, ... C#foo
         // ...
         // non_null_rcv <- non-null rcv_c  // <- Update the input rcv to the non-null, too.
-        if (current.isNonNull()) {
-          Assume nonNull = current.asNonNull();
+        if (current.isAssumeNonNull()) {
+          Assume<NonNullAssumption> nonNull = current.asAssumeNonNull();
           Instruction origin = nonNull.origin();
           if (origin.isInvokeInterface()
               && !origin.asInvokeInterface().getReceiver().hasLocalInfo()
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/InliningConstraints.java b/src/main/java/com/android/tools/r8/ir/optimize/InliningConstraints.java
index 2a66866..b5f8538 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/InliningConstraints.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/InliningConstraints.java
@@ -221,7 +221,7 @@
     return ConstraintWithTarget.classIsVisible(invocationContext, type, appView);
   }
 
-  public ConstraintWithTarget forNonNull() {
+  public ConstraintWithTarget forAssume() {
     return ConstraintWithTarget.ALWAYS;
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/NonNullTracker.java b/src/main/java/com/android/tools/r8/ir/optimize/NonNullTracker.java
index 94bf914..fffd3da 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/NonNullTracker.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/NonNullTracker.java
@@ -13,6 +13,7 @@
 import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
 import com.android.tools.r8.ir.code.Assume;
+import com.android.tools.r8.ir.code.Assume.NonNullAssumption;
 import com.android.tools.r8.ir.code.BasicBlock;
 import com.android.tools.r8.ir.code.DominatorTree;
 import com.android.tools.r8.ir.code.IRCode;
@@ -225,7 +226,9 @@
                         typeLattice.asReferenceTypeLatticeElement().asNotNull(),
                         knownToBeNonNullValue.getLocalInfo());
                 affectedValues.addAll(knownToBeNonNullValue.affectedValues());
-                Assume nonNull = new Assume(nonNullValue, knownToBeNonNullValue, theIf);
+                Assume<NonNullAssumption> nonNull =
+                    Assume.createAssumeNonNullInstruction(
+                        nonNullValue, knownToBeNonNullValue, theIf);
                 InstructionListIterator targetIterator = target.listIterator();
                 nonNull.setPosition(targetIterator.next().getPosition());
                 targetIterator.previous();
@@ -318,7 +321,8 @@
                 typeLattice.asReferenceTypeLatticeElement().asNotNull(),
                 knownToBeNonNullValue.getLocalInfo());
         affectedValues.addAll(knownToBeNonNullValue.affectedValues());
-        Assume nonNull = new Assume(nonNullValue, knownToBeNonNullValue, current);
+        Assume<NonNullAssumption> nonNull =
+            Assume.createAssumeNonNullInstruction(nonNullValue, knownToBeNonNullValue, current);
         nonNull.setPosition(current.getPosition());
         if (blockWithNonNullInstruction != block) {
           // If we split, add non-null IR on top of the new split block.
@@ -387,8 +391,8 @@
       // Collect basic blocks that check nullability of the parameter.
       nullCheckedBlocks.clear();
       for (Instruction user : argument.uniqueUsers()) {
-        if (user.isNonNull()) {
-          nullCheckedBlocks.add(user.asNonNull().getBlock());
+        if (user.isAssumeNonNull()) {
+          nullCheckedBlocks.add(user.asAssumeNonNull().getBlock());
         }
         if (user.isIf()
             && user.asIf().isZeroTest()
@@ -483,8 +487,8 @@
       //  ~>
       //
       // rcv#foo
-      if (instruction.isNonNull()) {
-        Assume nonNull = instruction.asNonNull();
+      if (instruction.isAssumeNonNull()) {
+        Assume<NonNullAssumption> nonNull = instruction.asAssumeNonNull();
         Value src = nonNull.src();
         Value dest = nonNull.dest();
         affectedValues.addAll(dest.affectedValues());
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/RedundantFieldLoadElimination.java b/src/main/java/com/android/tools/r8/ir/optimize/RedundantFieldLoadElimination.java
index 5c55d31..7cd0a5b 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/RedundantFieldLoadElimination.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/RedundantFieldLoadElimination.java
@@ -111,10 +111,10 @@
             assert !couldBeVolatile(field);
             if (instruction.isInstanceGet() && !instruction.outValue().hasLocalInfo()) {
               Value object = instruction.asInstanceGet().object();
-              // Values from NonNull instructions will always be replaced with their original
-              // value before code is generated.
-              if (!object.isPhi() && object.definition.isNonNull()) {
-                object = object.definition.asNonNull().src();
+              // Values from Assume instructions will always be replaced with their original value
+              // before code is generated.
+              if (!object.isPhi() && object.definition.isAssume()) {
+                object = object.definition.asAssume().src();
               }
               FieldAndObject fieldAndObject = new FieldAndObject(field, object);
               if (activeInstanceFields.containsKey(fieldAndObject)) {
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/NonNullTrackerTest.java b/src/test/java/com/android/tools/r8/ir/optimize/NonNullTrackerTest.java
index e971cc1..3f3ccd2 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/NonNullTrackerTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/NonNullTrackerTest.java
@@ -61,15 +61,15 @@
     while (it.hasNext()) {
       prev = curr != null && !curr.isGoto() ? curr : prev;
       curr = it.next();
-      if (curr.isNonNull()) {
+      if (curr.isAssumeNonNull()) {
         // Make sure non-null is added to the right place.
         assertTrue(prev == null
             || NonNullTracker.throwsOnNullInput(prev)
             || (prev.isIf() && prev.asIf().isZeroTest())
             || !curr.getBlock().getPredecessors().contains(prev.getBlock()));
         // Make sure non-null is used or inserted for arguments.
-        assertTrue(curr.outValue().numberOfAllUsers() > 0
-            || curr.asNonNull().src().isArgument());
+        assertTrue(
+            curr.outValue().numberOfAllUsers() > 0 || curr.asAssumeNonNull().src().isArgument());
         count++;
       }
     }
@@ -143,28 +143,33 @@
   public void avoidRedundantNonNull() throws Exception {
     MethodSignature signature = new MethodSignature("foo2", "int",
         new String[]{FieldAccessTest.class.getCanonicalName()});
-    buildAndTest(NonNullAfterFieldAccess.class, signature, 1, ircode -> {
-      // There are two InstancePut instructions of interest.
-      int count = 0;
-      InstructionIterator it = ircode.instructionIterator();
-      while (it.hasNext()) {
-        Instruction instruction = it.nextUntil(Instruction::isInstancePut);
-        if (instruction == null) {
-          break;
-        }
-        InstancePut iput = instruction.asInstancePut();
-        if (count == 0) {
-          // First one in the very first line: its value should not be replaced by NonNullMarker
-          // because this instruction will happen _before_ non-null.
-          assertFalse(iput.value().definition.isNonNull());
-        } else if (count == 1) {
-          // Second one after a safe invocation, which should use the value added by NonNullMarker.
-          assertTrue(iput.object().definition.isNonNull());
-        }
-        count++;
-      }
-      assertEquals(2, count);
-    });
+    buildAndTest(
+        NonNullAfterFieldAccess.class,
+        signature,
+        1,
+        ircode -> {
+          // There are two InstancePut instructions of interest.
+          int count = 0;
+          InstructionIterator it = ircode.instructionIterator();
+          while (it.hasNext()) {
+            Instruction instruction = it.nextUntil(Instruction::isInstancePut);
+            if (instruction == null) {
+              break;
+            }
+            InstancePut iput = instruction.asInstancePut();
+            if (count == 0) {
+              // First one in the very first line: its value should not be replaced by NonNullMarker
+              // because this instruction will happen _before_ non-null.
+              assertFalse(iput.value().definition.isAssumeNonNull());
+            } else if (count == 1) {
+              // Second one after a safe invocation, which should use the value added by
+              // NonNullMarker.
+              assertTrue(iput.object().definition.isAssumeNonNull());
+            }
+            count++;
+          }
+          assertEquals(2, count);
+        });
   }
 
   @Test