Strengthen method return types

Bug: 214329925
Change-Id: Ibda77542b4f79e9e9e30f2dc8f23fd8ed1140a8a
diff --git a/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java b/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
index 1dd24ff..dda7746 100644
--- a/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
+++ b/src/main/java/com/android/tools/r8/graph/RewrittenPrototypeDescription.java
@@ -336,6 +336,13 @@
       private DexType newType;
       private SingleValue singleValue;
 
+      public Builder applyIf(boolean condition, Consumer<Builder> consumer) {
+        if (condition) {
+          consumer.accept(this);
+        }
+        return this;
+      }
+
       public Builder setCastType(DexType castType) {
         this.castType = castType;
         return this;
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index 31f99e5..f6c957c 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -34,7 +34,6 @@
 import static com.android.tools.r8.utils.ObjectUtils.getBooleanOrElse;
 
 import com.android.tools.r8.errors.CompilationError;
-import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AccessControl;
 import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
 import com.android.tools.r8.graph.AppView;
@@ -338,10 +337,22 @@
                   }
                 }
 
-                Value newOutValue =
-                    prototypeChanges.hasBeenChangedToReturnVoid()
-                        ? null
-                        : makeOutValue(invoke, code);
+                Value newOutValue;
+                if (prototypeChanges.hasRewrittenReturnInfo()) {
+                  if (invoke.hasOutValue() && !prototypeChanges.hasBeenChangedToReturnVoid()) {
+                    TypeElement newReturnType =
+                        prototypeChanges
+                            .getRewrittenReturnInfo()
+                            .getNewType()
+                            .toTypeElement(appView);
+                    newOutValue = code.createValue(newReturnType, invoke.getLocalInfo());
+                    affectedPhis.addAll(invoke.outValue().uniquePhiUsers());
+                  } else {
+                    newOutValue = null;
+                  }
+                } else {
+                  newOutValue = makeOutValue(invoke, code);
+                }
 
                 Map<SingleNumberValue, Map<DexType, Value>> parameterMap = new IdentityHashMap<>();
 
@@ -418,19 +429,6 @@
                     iterator.add(constantReturnMaterializingInstruction);
                   }
                 }
-
-                DexType actualReturnType = actualTarget.proto.returnType;
-                DexType expectedReturnType = graphLens.lookupType(invokedMethod.proto.returnType);
-                if (newInvoke.hasOutValue() && actualReturnType != expectedReturnType) {
-                  throw new Unreachable(
-                      "Unexpected need to insert a cast. Possibly related to resolving"
-                          + " b/79143143.\n"
-                          + invokedMethod
-                          + " type changed from "
-                          + expectedReturnType
-                          + " to "
-                          + actualReturnType);
-                }
               }
             }
             break;
@@ -663,6 +661,9 @@
               if (ret.isReturnVoid()) {
                 break;
               }
+
+              insertCastForReturnIfNeeded(code, blocks, iterator, ret);
+
               DexType returnType = code.context().getReturnType();
               Value retValue = ret.returnValue();
               DexType initialType =
@@ -820,6 +821,46 @@
     return iterator;
   }
 
+  private InstructionListIterator insertCastForReturnIfNeeded(
+      IRCode code, BasicBlockIterator blocks, InstructionListIterator iterator, Return ret) {
+    RewrittenPrototypeDescription prototypeChanges =
+        appView
+            .graphLens()
+            .lookupPrototypeChangesForMethodDefinition(code.context().getReference());
+    if (!prototypeChanges.hasRewrittenReturnInfo()
+        || !prototypeChanges.getRewrittenReturnInfo().hasCastType()) {
+      return iterator;
+    }
+
+    iterator.previous();
+
+    // Split the block and reset the block iterator.
+    if (ret.getBlock().hasCatchHandlers()) {
+      BasicBlock splitBlock = iterator.splitCopyCatchHandlers(code, blocks, options);
+      BasicBlock previousBlock = blocks.previousUntil(block -> block == splitBlock);
+      assert previousBlock != null;
+      blocks.next();
+      iterator = splitBlock.listIterator(code);
+    }
+
+    DexType castType = prototypeChanges.getRewrittenReturnInfo().getCastType();
+    Value returnValue = ret.returnValue();
+    CheckCast checkCast =
+        SafeCheckCast.builder()
+            .setObject(returnValue)
+            .setFreshOutValue(
+                code, castType.toTypeElement(appView, returnValue.getType().nullability()))
+            .setCastType(castType)
+            .setPosition(ret.getPosition())
+            .build();
+    iterator.add(checkCast);
+    ret.replaceValue(0, checkCast.outValue());
+
+    Instruction next = iterator.next();
+    assert next == ret;
+    return iterator;
+  }
+
   private DexField rewriteFieldReference(FieldLookupResult lookup, ProgramMethod context) {
     if (lookup.hasReboundReference()) {
       DexClass holder = appView.definitionFor(lookup.getReboundReference().getHolderType());
diff --git a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorProgramOptimizer.java b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorProgramOptimizer.java
index 0264a50..2a5a4cd 100644
--- a/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorProgramOptimizer.java
+++ b/src/main/java/com/android/tools/r8/optimize/argumentpropagation/ArgumentPropagatorProgramOptimizer.java
@@ -36,7 +36,6 @@
 import com.android.tools.r8.shaking.KeepFieldInfo;
 import com.android.tools.r8.utils.AccessUtils;
 import com.android.tools.r8.utils.BooleanBox;
-import com.android.tools.r8.utils.BooleanUtils;
 import com.android.tools.r8.utils.IntBox;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.Pair;
@@ -61,6 +60,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutionException;
@@ -74,17 +74,17 @@
   static class AllowedPrototypeChanges {
 
     private static final AllowedPrototypeChanges EMPTY =
-        new AllowedPrototypeChanges(false, Int2ReferenceMaps.emptyMap(), IntSets.EMPTY_SET);
+        new AllowedPrototypeChanges(null, Int2ReferenceMaps.emptyMap(), IntSets.EMPTY_SET);
 
-    boolean canRewriteToVoid;
+    DexType newReturnType;
     Int2ReferenceMap<DexType> newParameterTypes;
     IntSet removableParameterIndices;
 
     AllowedPrototypeChanges(
-        boolean canRewriteToVoid,
+        DexType newReturnType,
         Int2ReferenceMap<DexType> newParameterTypes,
         IntSet removableParameterIndices) {
-      this.canRewriteToVoid = canRewriteToVoid;
+      this.newReturnType = newReturnType;
       this.newParameterTypes = newParameterTypes;
       this.removableParameterIndices = removableParameterIndices;
     }
@@ -93,6 +93,10 @@
       if (prototypeChanges.isEmpty()) {
         return empty();
       }
+      DexType newReturnType =
+          prototypeChanges.hasRewrittenReturnInfo()
+              ? prototypeChanges.getRewrittenReturnInfo().getNewType()
+              : null;
       Int2ReferenceMap<DexType> newParameterTypes = new Int2ReferenceOpenHashMap<>();
       IntSet removableParameterIndices = new IntOpenHashSet();
       prototypeChanges
@@ -108,9 +112,7 @@
                 }
               });
       return new AllowedPrototypeChanges(
-          prototypeChanges.hasBeenChangedToReturnVoid(),
-          newParameterTypes,
-          removableParameterIndices);
+          newReturnType, newParameterTypes, removableParameterIndices);
     }
 
     public static AllowedPrototypeChanges empty() {
@@ -119,7 +121,7 @@
 
     @Override
     public int hashCode() {
-      return BooleanUtils.intValue(canRewriteToVoid) | (removableParameterIndices.hashCode() << 1);
+      return Objects.hash(newReturnType, newParameterTypes, removableParameterIndices);
     }
 
     @Override
@@ -128,7 +130,7 @@
         return false;
       }
       AllowedPrototypeChanges other = (AllowedPrototypeChanges) obj;
-      return canRewriteToVoid == other.canRewriteToVoid
+      return newReturnType == other.newReturnType
           && newParameterTypes.equals(other.newParameterTypes)
           && removableParameterIndices.equals(other.removableParameterIndices);
     }
@@ -335,6 +337,10 @@
               return;
             }
 
+            if (containsImmediateInterfaceOfInstantiatedLambda(methods)) {
+              return;
+            }
+
             // Find the parameters that are either (i) the same constant, (ii) all unused, or (iii)
             // all possible to strengthen to the same stronger type, in all methods.
             Int2ReferenceMap<DexType> newParameterTypes = new Int2ReferenceOpenHashMap<>();
@@ -342,36 +348,35 @@
             for (int parameterIndex = 1;
                 parameterIndex < signature.getProto().getArity() + 1;
                 parameterIndex++) {
-              if (!containsImmediateInterfaceOfInstantiatedLambda(methods)) {
-                if (canRemoveParameterFromVirtualMethods(methods, parameterIndex)) {
-                  removableVirtualMethodParametersInAllMethods.add(parameterIndex);
-                } else {
-                  DexType newParameterType =
-                      getNewParameterTypeForVirtualMethods(methods, parameterIndex);
-                  if (newParameterType != null) {
-                    newParameterTypes.put(parameterIndex, newParameterType);
-                  }
+              if (canRemoveParameterFromVirtualMethods(methods, parameterIndex)) {
+                removableVirtualMethodParametersInAllMethods.add(parameterIndex);
+              } else {
+                DexType newParameterType =
+                    getNewParameterTypeForVirtualMethods(methods, parameterIndex);
+                if (newParameterType != null) {
+                  newParameterTypes.put(parameterIndex, newParameterType);
                 }
               }
             }
 
             // If any prototype changes can be made, record it.
             SingleValue returnValueForVirtualMethods =
-                getReturnValueForVirtualMethods(signature, methods);
-            boolean canRewriteVirtualMethodsToVoid = returnValueForVirtualMethods != null;
-            if (canRewriteVirtualMethodsToVoid
+                getReturnValueForVirtualMethods(methods, signature);
+            DexType newReturnType =
+                getNewReturnTypeForVirtualMethods(methods, returnValueForVirtualMethods);
+            if (newReturnType != null
                 || !newParameterTypes.isEmpty()
                 || !removableVirtualMethodParametersInAllMethods.isEmpty()) {
               allowedPrototypeChangesForVirtualMethods.put(
                   signature,
                   new AllowedPrototypeChanges(
-                      canRewriteVirtualMethodsToVoid,
+                      newReturnType,
                       newParameterTypes,
                       removableVirtualMethodParametersInAllMethods));
             }
 
             // Also record the found return value for abstract virtual methods.
-            if (canRewriteVirtualMethodsToVoid) {
+            if (newReturnType == dexItemFactory.voidType) {
               for (ProgramMethod method : methods) {
                 if (method.getAccessFlags().isAbstract()) {
                   returnValuesForVirtualMethods.put(method, returnValueForVirtualMethods);
@@ -409,7 +414,7 @@
     }
 
     private SingleValue getReturnValueForVirtualMethods(
-        DexMethodSignature signature, ProgramMethodSet methods) {
+        ProgramMethodSet methods, DexMethodSignature signature) {
       if (signature.getReturnType().isVoidType()) {
         return null;
       }
@@ -417,14 +422,6 @@
       SingleValue returnValue = null;
       for (ProgramMethod method : methods) {
         if (method.getDefinition().isAbstract()) {
-          DexProgramClass holder = method.getHolder();
-          if (holder.isInterface()) {
-            ObjectAllocationInfoCollection objectAllocationInfoCollection =
-                appView.appInfo().getObjectAllocationInfoCollection();
-            if (objectAllocationInfoCollection.isImmediateInterfaceOfInstantiatedLambda(holder)) {
-              return null;
-            }
-          }
           // OK, this can be rewritten to have void return type.
           continue;
         }
@@ -480,6 +477,28 @@
       return true;
     }
 
+    private DexType getNewReturnTypeForVirtualMethods(
+        ProgramMethodSet methods, SingleValue returnValue) {
+      if (returnValue != null) {
+        return dexItemFactory.voidType;
+      }
+      DexType newReturnType = null;
+      for (ProgramMethod method : methods) {
+        if (method.getDefinition().isAbstract()) {
+          // OK, this method can have any return type.
+          continue;
+        }
+        DexType newReturnTypeForMethod = getNewReturnType(method, null);
+        if (newReturnTypeForMethod == null
+            || (newReturnType != null && newReturnType != newReturnTypeForMethod)) {
+          return null;
+        }
+        newReturnType = newReturnTypeForMethod;
+      }
+      assert newReturnType == null || newReturnType != methods.getFirst().getReturnType();
+      return newReturnType;
+    }
+
     private DexType getNewParameterTypeForVirtualMethods(
         ProgramMethodSet methods, int parameterIndex) {
       DexType newParameterType = null;
@@ -723,7 +742,7 @@
       if (method.getAccessFlags().isAbstract()) {
         return computePrototypeChangesForAbstractVirtualMethod(
             method,
-            allowedPrototypeChanges.canRewriteToVoid,
+            allowedPrototypeChanges.newReturnType,
             newParameterTypes,
             removableParameterIndices);
       }
@@ -731,7 +750,7 @@
       RewrittenPrototypeDescription prototypeChanges =
           computePrototypeChangesForMethod(
               method,
-              allowedPrototypeChanges.canRewriteToVoid,
+              allowedPrototypeChanges.newReturnType,
               newParameterTypes::get,
               removableParameterIndices::contains);
       assert prototypeChanges.getArgumentInfoCollection().numberOfRemovedArguments()
@@ -741,10 +760,9 @@
 
     private RewrittenPrototypeDescription computePrototypeChangesForAbstractVirtualMethod(
         ProgramMethod method,
-        boolean canRewriteToVoid,
+        DexType newReturnType,
         Int2ReferenceMap<DexType> newParameterTypes,
         IntSet removableParameterIndices) {
-
       // Treat the parameters as unused.
       ArgumentInfoCollection.Builder argumentInfoCollectionBuilder =
           ArgumentInfoCollection.builder();
@@ -768,7 +786,7 @@
       }
       return RewrittenPrototypeDescription.create(
           Collections.emptyList(),
-          computeReturnChangesForMethod(method, canRewriteToVoid),
+          computeReturnChangesForMethod(method, newReturnType),
           argumentInfoCollectionBuilder.build());
     }
 
@@ -778,7 +796,62 @@
               ? parameterIndex -> getNewParameterType(method, parameterIndex)
               : parameterIndex -> null;
       return computePrototypeChangesForMethod(
-          method, true, parameterIndexToParameterType, parameterIndex -> true);
+          method, getNewReturnType(method), parameterIndexToParameterType, parameterIndex -> true);
+    }
+
+    private DexType getNewReturnType(ProgramMethod method) {
+      return getNewReturnType(method, getReturnValue(method));
+    }
+
+    private DexType getNewReturnType(ProgramMethod method, SingleValue returnValue) {
+      DexType staticType = method.getReturnType();
+      if (staticType.isVoidType()
+          || !appView.getKeepInfo(method).isReturnTypeStrengtheningAllowed(options)) {
+        return null;
+      }
+      if (returnValue != null) {
+        return dexItemFactory.voidType;
+      }
+      TypeElement newReturnTypeElement =
+          method
+              .getOptimizationInfo()
+              .getDynamicType()
+              .getDynamicUpperBoundType(staticType.toTypeElement(appView));
+      assert newReturnTypeElement.isTop()
+          || newReturnTypeElement.lessThanOrEqual(staticType.toTypeElement(appView), appView);
+      if (!newReturnTypeElement.isClassType()) {
+        assert newReturnTypeElement.isArrayType() || newReturnTypeElement.isTop();
+        return null;
+      }
+      DexType newReturnType = newReturnTypeElement.asClassType().toDexType(dexItemFactory);
+      if (newReturnType == staticType) {
+        return null;
+      }
+      if (!appView.appInfo().isSubtype(newReturnType, staticType)) {
+        return null;
+      }
+      return AccessUtils.isAccessibleInSameContextsAs(
+              newReturnType, method.getReturnType(), appView)
+          ? newReturnType
+          : null;
+    }
+
+    private SingleValue getReturnValue(ProgramMethod method) {
+      AbstractValue returnValue;
+      if (method.getReturnType().isAlwaysNull(appView)) {
+        returnValue = appView.abstractValueFactory().createNullValue();
+      } else if (method.getDefinition().belongsToVirtualPool()
+          && returnValuesForVirtualMethods.containsKey(method)) {
+        assert method.getAccessFlags().isAbstract();
+        returnValue = returnValuesForVirtualMethods.get(method);
+      } else {
+        returnValue = method.getOptimizationInfo().getAbstractReturnValue();
+      }
+
+      return returnValue.isSingleValue()
+              && returnValue.asSingleValue().isMaterializableInAllContexts(appView)
+          ? returnValue.asSingleValue()
+          : null;
     }
 
     private DexType getNewParameterType(ProgramMethod method, int parameterIndex) {
@@ -809,6 +882,9 @@
       if (newParameterType == staticType) {
         return null;
       }
+      if (!appView.appInfo().isSubtype(newParameterType, staticType)) {
+        return null;
+      }
       return AccessUtils.isAccessibleInSameContextsAs(newParameterType, staticType, appView)
           ? newParameterType
           : null;
@@ -816,12 +892,12 @@
 
     private RewrittenPrototypeDescription computePrototypeChangesForMethod(
         ProgramMethod method,
-        boolean allowToVoidRewriting,
+        DexType newReturnType,
         IntFunction<DexType> newParameterTypes,
         IntPredicate removableParameterIndices) {
       return RewrittenPrototypeDescription.create(
           Collections.emptyList(),
-          computeReturnChangesForMethod(method, allowToVoidRewriting),
+          computeReturnChangesForMethod(method, newReturnType),
           computeParameterChangesForMethod(method, newParameterTypes, removableParameterIndices));
     }
 
@@ -870,30 +946,20 @@
     }
 
     private RewrittenTypeInfo computeReturnChangesForMethod(
-        ProgramMethod method, boolean allowToVoidRewriting) {
-      if (!allowToVoidRewriting) {
+        ProgramMethod method, DexType newReturnType) {
+      if (newReturnType == null) {
         assert !returnValuesForVirtualMethods.containsKey(method);
         return null;
       }
-
-      AbstractValue returnValue;
-      if (method.getReturnType().isAlwaysNull(appView)) {
-        returnValue = appView.abstractValueFactory().createNullValue();
-      } else if (method.getDefinition().belongsToVirtualPool()
-          && returnValuesForVirtualMethods.containsKey(method)) {
-        assert method.getAccessFlags().isAbstract();
-        returnValue = returnValuesForVirtualMethods.get(method);
-      } else {
-        returnValue = method.getOptimizationInfo().getAbstractReturnValue();
-      }
-
-      if (!returnValue.isSingleValue()
-          || !returnValue.asSingleValue().isMaterializableInAllContexts(appView)) {
-        return null;
-      }
-
-      SingleValue singleValue = returnValue.asSingleValue();
-      return RewrittenTypeInfo.toVoid(method.getReturnType(), dexItemFactory, singleValue);
+      assert newReturnType != method.getReturnType();
+      return RewrittenTypeInfo.builder()
+          .applyIf(
+              newReturnType == dexItemFactory.voidType,
+              builder -> builder.setSingleValue(getReturnValue(method)))
+          .setCastType(newReturnType)
+          .setOldType(method.getReturnType())
+          .setNewType(newReturnType)
+          .build();
     }
   }
 }
diff --git a/src/test/java/com/android/tools/r8/apimodel/ApiModelNoOutlineForFullyMockedTest.java b/src/test/java/com/android/tools/r8/apimodel/ApiModelNoOutlineForFullyMockedTest.java
index 91c55e4..76e8b57 100644
--- a/src/test/java/com/android/tools/r8/apimodel/ApiModelNoOutlineForFullyMockedTest.java
+++ b/src/test/java/com/android/tools/r8/apimodel/ApiModelNoOutlineForFullyMockedTest.java
@@ -13,6 +13,7 @@
 import static org.junit.Assume.assumeFalse;
 
 import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NoReturnTypeStrengthening;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
@@ -60,6 +61,7 @@
         .apply(ApiModelingTestHelper::enableOutliningOfMethods)
         .apply(ApiModelingTestHelper::enableStubbingOfClasses)
         .enableInliningAnnotations()
+        .enableNoReturnTypeStrengtheningAnnotations()
         .compile()
         .applyIf(
             parameters.isDexRuntime()
@@ -93,6 +95,8 @@
   public static class Main {
 
     @NeverInline
+    // TODO(b/214329925): Type strengthening should consult API database.
+    @NoReturnTypeStrengthening
     public static Object create() {
       return AndroidBuildVersion.VERSION >= 23 ? new LibraryClass() : null;
     }
diff --git a/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/BridgeHoistingAccessibilityTest.java b/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/BridgeHoistingAccessibilityTest.java
index 45855d8..bf166ba 100644
--- a/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/BridgeHoistingAccessibilityTest.java
+++ b/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/BridgeHoistingAccessibilityTest.java
@@ -77,6 +77,8 @@
         .enableNoHorizontalClassMergingAnnotations()
         .enableNoVerticalClassMergingAnnotations()
         .enableNeverClassInliningAnnotations()
+        // TODO(b/173398086): uniqueMethodWithName() does not work with argument changes.
+        .noMinification()
         .setMinApi(parameters.getApiLevel())
         .compile()
         .inspect(this::inspect)
diff --git a/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/KeptBridgeHoistingTest.java b/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/KeptBridgeHoistingTest.java
index fc6ac6b..e45bdcc 100644
--- a/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/KeptBridgeHoistingTest.java
+++ b/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/KeptBridgeHoistingTest.java
@@ -45,6 +45,8 @@
         .enableInliningAnnotations()
         .enableNoVerticalClassMergingAnnotations()
         .enableNeverClassInliningAnnotations()
+        // TODO(b/173398086): uniqueMethodWithName() does not work with signature changes.
+        .noMinification()
         .setMinApi(parameters.getApiLevel())
         .compile()
         .inspect(this::inspect)
diff --git a/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/NonSuperclassBridgeHoistingTest.java b/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/NonSuperclassBridgeHoistingTest.java
index e987a43..2819311 100644
--- a/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/NonSuperclassBridgeHoistingTest.java
+++ b/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/NonSuperclassBridgeHoistingTest.java
@@ -47,6 +47,8 @@
         .enableInliningAnnotations()
         .enableNoVerticalClassMergingAnnotations()
         .enableNeverClassInliningAnnotations()
+        // TODO(b/173398086): uniqueMethodWithName() does not work with signature changes.
+        .noMinification()
         .setMinApi(parameters.getApiLevel())
         .compile()
         .inspect(this::inspect)
diff --git a/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/PositiveBridgeHoistingTest.java b/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/PositiveBridgeHoistingTest.java
index ae77a45..bdd55fc 100644
--- a/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/PositiveBridgeHoistingTest.java
+++ b/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/PositiveBridgeHoistingTest.java
@@ -56,6 +56,8 @@
         .enableInliningAnnotations()
         .enableNeverClassInliningAnnotations()
         .enableNoHorizontalClassMergingAnnotations()
+        // TODO(b/173398086): uniqueMethodWithName() does not work with argument changes.
+        .noMinification()
         .setMinApi(parameters.getApiLevel())
         .compile()
         .inspect(this::inspect)
diff --git a/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/testclasses/BridgeHoistingAccessibilityTestClasses.java b/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/testclasses/BridgeHoistingAccessibilityTestClasses.java
index 64c19e5..36cc8be 100644
--- a/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/testclasses/BridgeHoistingAccessibilityTestClasses.java
+++ b/src/test/java/com/android/tools/r8/bridgeremoval/hoisting/testclasses/BridgeHoistingAccessibilityTestClasses.java
@@ -31,6 +31,7 @@
     }
   }
 
+  @NoHorizontalClassMerging
   @NoVerticalClassMerging
   public static class AWithRangedInvoke {
 
diff --git a/src/test/java/com/android/tools/r8/enumunboxing/FailingMethodEnumUnboxingTest.java b/src/test/java/com/android/tools/r8/enumunboxing/FailingMethodEnumUnboxingTest.java
index ce3e3f9..0a00994 100644
--- a/src/test/java/com/android/tools/r8/enumunboxing/FailingMethodEnumUnboxingTest.java
+++ b/src/test/java/com/android/tools/r8/enumunboxing/FailingMethodEnumUnboxingTest.java
@@ -64,6 +64,8 @@
                         FailingParameterType.MyEnum.class))
             .enableInliningAnnotations()
             .enableNeverClassInliningAnnotations()
+            // TODO(b/173398086): uniqueMethodWithName() does not work with signature changes.
+            .noMinification()
             .setMinApi(parameters.getApiLevel())
             .compile()
             .inspect(this::assertEnumsAsExpected);
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/dynamictype/DynamicTypeOptimizationTest.java b/src/test/java/com/android/tools/r8/ir/optimize/dynamictype/DynamicTypeOptimizationTest.java
index 52736da..f3db3a7 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/dynamictype/DynamicTypeOptimizationTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/dynamictype/DynamicTypeOptimizationTest.java
@@ -10,6 +10,7 @@
 import static org.junit.Assert.assertTrue;
 
 import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NoReturnTypeStrengthening;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.utils.StringUtils;
@@ -43,6 +44,7 @@
         // Keep B to ensure that we will treat it as being instantiated.
         .addKeepClassRulesWithAllowObfuscation(B.class)
         .enableInliningAnnotations()
+        .enableNoReturnTypeStrengtheningAnnotations()
         .setMinApi(parameters.getApiLevel())
         .compile()
         .inspect(this::inspect)
@@ -134,6 +136,7 @@
     }
 
     @NeverInline
+    @NoReturnTypeStrengthening
     private static I get() {
       return new A();
     }
diff --git a/src/test/java/com/android/tools/r8/kotlin/metadata/MetadataRewriteInCompanionTest.java b/src/test/java/com/android/tools/r8/kotlin/metadata/MetadataRewriteInCompanionTest.java
index eae75ac..472435a 100644
--- a/src/test/java/com/android/tools/r8/kotlin/metadata/MetadataRewriteInCompanionTest.java
+++ b/src/test/java/com/android/tools/r8/kotlin/metadata/MetadataRewriteInCompanionTest.java
@@ -128,6 +128,8 @@
             .addKeepRules("-keepclassmembers class **.B { *** Companion; }")
             // Keep the class of the companion class.
             .addKeepRules("-keep class **.*$Companion")
+            // TODO(b/173398086): uniqueMethodWithName() does not work with signature changes.
+            .addKeepRules("-noreturntypestrengthening class **.B { *** access$getElt1$cp(...); }")
             // No rule for Super, but will be kept and renamed.
             .addKeepAttributes(ProguardKeepAttributes.RUNTIME_VISIBLE_ANNOTATIONS)
             // To keep @JvmField annotation
@@ -135,6 +137,7 @@
             // To keep ...$Companion structure
             .addKeepAttributes(ProguardKeepAttributes.INNER_CLASSES)
             .addKeepAttributes(ProguardKeepAttributes.ENCLOSING_METHOD)
+            .enableProguardTestOptions()
             .compile()
             .inspect(codeInspector -> inspect(codeInspector, false))
             .writeToZip();
diff --git a/src/test/java/com/android/tools/r8/maindexlist/MainDexDevirtualizerTest.java b/src/test/java/com/android/tools/r8/maindexlist/MainDexDevirtualizerTest.java
index ad0cf84..783b9f1 100644
--- a/src/test/java/com/android/tools/r8/maindexlist/MainDexDevirtualizerTest.java
+++ b/src/test/java/com/android/tools/r8/maindexlist/MainDexDevirtualizerTest.java
@@ -16,6 +16,7 @@
 
 import com.android.tools.r8.NeverClassInline;
 import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NoReturnTypeStrengthening;
 import com.android.tools.r8.NoVerticalClassMerging;
 import com.android.tools.r8.R8FullTestBuilder;
 import com.android.tools.r8.TestBase;
@@ -108,6 +109,7 @@
     Box<String> mainDexStringList = new Box<>("");
     testForR8(parameters.getBackend())
         .addProgramClasses(I.class, Provider.class, A.class, Main.class)
+        .enableNoReturnTypeStrengtheningAnnotations()
         .enableNoVerticalClassMergingAnnotations()
         .enableInliningAnnotations()
         .enableNeverClassInliningAnnotations()
@@ -156,6 +158,7 @@
 
   public static class Provider {
     @NeverInline
+    @NoReturnTypeStrengthening
     public static I getImpl() {
       return new A(); // <-- We will call-site optimize getImpl() to always return A.
     }
diff --git a/src/test/java/com/android/tools/r8/naming/overloadaggressively/B.java b/src/test/java/com/android/tools/r8/naming/overloadaggressively/B.java
index 298c597..c18b3d7 100644
--- a/src/test/java/com/android/tools/r8/naming/overloadaggressively/B.java
+++ b/src/test/java/com/android/tools/r8/naming/overloadaggressively/B.java
@@ -4,6 +4,7 @@
 package com.android.tools.r8.naming.overloadaggressively;
 
 import com.android.tools.r8.NeverPropagateValue;
+import com.android.tools.r8.NoReturnTypeStrengthening;
 
 public class B {
   volatile int f1 = 8;
@@ -16,6 +17,7 @@
   }
 
   @NeverPropagateValue
+  @NoReturnTypeStrengthening
   public Object getF2() {
     return f2;
   }
diff --git a/src/test/java/com/android/tools/r8/naming/overloadaggressively/OverloadAggressivelyTest.java b/src/test/java/com/android/tools/r8/naming/overloadaggressively/OverloadAggressivelyTest.java
index cbd22a0..c54b51b 100644
--- a/src/test/java/com/android/tools/r8/naming/overloadaggressively/OverloadAggressivelyTest.java
+++ b/src/test/java/com/android/tools/r8/naming/overloadaggressively/OverloadAggressivelyTest.java
@@ -11,6 +11,8 @@
 
 import com.android.tools.r8.R8Command;
 import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.ToolHelper;
 import com.android.tools.r8.ToolHelper.ProcessResult;
 import com.android.tools.r8.graph.DexEncodedField;
@@ -27,19 +29,18 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
 
 @RunWith(Parameterized.class)
 public class OverloadAggressivelyTest extends TestBase {
 
-  private Backend backend;
+  @Parameter(0)
+  public TestParameters parameters;
 
-  @Parameterized.Parameters(name = "Backend: {0}")
-  public static Backend[] data() {
-    return ToolHelper.getBackends();
-  }
-
-  public OverloadAggressivelyTest(Backend backend) {
-    this.backend = backend;
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withCfRuntimes().build();
   }
 
   private AndroidApp runR8(AndroidApp app, Class<?> main, Path out, boolean overloadaggressively)
@@ -56,8 +57,8 @@
                     keepMainProguardConfiguration(main),
                     overloadaggressively ? "-overloadaggressively" : ""),
                 Origin.unknown())
-            .setOutput(out, outputMode(backend))
-            .addLibraryFiles(TestBase.runtimeJar(backend))
+            .setOutput(out, outputMode(parameters.getBackend()))
+            .addLibraryFiles(parameters.getDefaultRuntimeLibrary())
             .build();
     return ToolHelper.runR8(
         command,
@@ -68,10 +69,10 @@
   }
 
   private ProcessResult runRaw(AndroidApp app, String main) throws IOException {
-    if (backend == Backend.DEX) {
+    if (parameters.isDexRuntime()) {
       return runOnArtRaw(app, main);
     } else {
-      assert backend == Backend.CF;
+      assert parameters.isCfRuntime();
       return runOnJavaRaw(app, main, Collections.emptyList());
     }
   }
@@ -178,19 +179,20 @@
     String expected = StringUtils.lines("diff: 0", "d8 v.s. d8", "r8 v.s. r8");
     String expectedOverloadAggressively = StringUtils.lines("diff: 0", "d8 v.s. 8", "r8 v.s. 8");
 
-    if (backend.isCf()) {
+    if (parameters.isCfRuntime()) {
       testForJvm().addTestClasspath().run(MethodResolution.class).assertSuccessWithOutput(expected);
     }
 
-    testForR8Compat(backend)
+    testForR8Compat(parameters.getBackend())
         .addProgramClasses(MethodResolution.class, B.class)
         .addKeepMainRule(MethodResolution.class)
         .addOptionsModification(options -> options.inlinerOptions().enableInlining = false)
         .applyIf(overloadaggressively, builder -> builder.addKeepRules("-overloadaggressively"))
         .enableMemberValuePropagationAnnotations()
+        .enableNoReturnTypeStrengtheningAnnotations()
         .compile()
         .inspect(inspector -> inspect(inspector, overloadaggressively))
-        .run(MethodResolution.class)
+        .run(parameters.getRuntime(), MethodResolution.class)
         .applyIf(
             overloadaggressively,
             runResult -> runResult.assertSuccessWithOutput(expectedOverloadAggressively),
@@ -218,7 +220,7 @@
 
   @Test
   public void testMethodResolution_aggressively() throws Exception {
-    assumeTrue(backend == Backend.CF);
+    assumeTrue(parameters.isCfRuntime());
     methodResolution(true);
   }
 
diff --git a/src/test/java/com/android/tools/r8/optimize/argumentpropagation/ReturnTypeStrengtheningTest.java b/src/test/java/com/android/tools/r8/optimize/argumentpropagation/ReturnTypeStrengtheningTest.java
new file mode 100644
index 0000000..89dccdb
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/optimize/argumentpropagation/ReturnTypeStrengtheningTest.java
@@ -0,0 +1,116 @@
+// Copyright (c) 2022, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.optimize.argumentpropagation;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NoVerticalClassMerging;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.FoundClassSubject;
+import com.android.tools.r8.utils.codeinspector.FoundMethodSubject;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class ReturnTypeStrengtheningTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .enableInliningAnnotations()
+        .enableNoVerticalClassMergingAnnotations()
+        // TODO(b/173398086): uniqueMethodWithName() does not work with argument changes.
+        .noMinification()
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .inspect(
+            inspector -> {
+              ClassSubject mainClassSubject = inspector.clazz(Main.class);
+              assertThat(mainClassSubject, isPresent());
+
+              ClassSubject aClassSubject = inspector.clazz(A.class);
+              assertThat(aClassSubject, isPresent());
+
+              // Return type of get() should be strengthened to A.
+              MethodSubject getMethodSubject = mainClassSubject.uniqueMethodWithName("get");
+              assertThat(getMethodSubject, isPresent());
+              assertEquals(
+                  aClassSubject.getFinalName(),
+                  getMethodSubject.getProgramMethod().getReturnType().getTypeName());
+
+              // Method consume(I) should be rewritten to consume(A).
+              MethodSubject testBMethodSubject = mainClassSubject.uniqueMethodWithName("consume");
+              assertThat(testBMethodSubject, isPresent());
+              assertEquals(
+                  aClassSubject.getFinalName(),
+                  testBMethodSubject.getProgramMethod().getParameter(0).getTypeName());
+
+              // There should be no casts in the application.
+              for (FoundClassSubject classSubject : inspector.allClasses()) {
+                for (FoundMethodSubject methodSubject : classSubject.allMethods()) {
+                  assertTrue(
+                      methodSubject
+                          .streamInstructions()
+                          .noneMatch(InstructionSubject::isCheckCast));
+                }
+              }
+            })
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("A");
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      consume(get());
+    }
+
+    @NeverInline
+    static I get() {
+      return new A();
+    }
+
+    @NeverInline
+    static void consume(I i) {
+      i.m();
+    }
+  }
+
+  @NoVerticalClassMerging
+  interface I {
+
+    void m();
+  }
+
+  static class A implements I {
+
+    @Override
+    public void m() {
+      System.out.println("A");
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/shaking/keptgraph/KeptByReachableSubclassTest.java b/src/test/java/com/android/tools/r8/shaking/keptgraph/KeptByReachableSubclassTest.java
index 528a4ac..b848b1c 100644
--- a/src/test/java/com/android/tools/r8/shaking/keptgraph/KeptByReachableSubclassTest.java
+++ b/src/test/java/com/android/tools/r8/shaking/keptgraph/KeptByReachableSubclassTest.java
@@ -7,6 +7,7 @@
 import static org.junit.Assert.assertEquals;
 
 import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NoReturnTypeStrengthening;
 import com.android.tools.r8.NoVerticalClassMerging;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
@@ -44,6 +45,7 @@
     GraphInspector inspector =
         testForR8(parameters.getBackend())
             .enableGraphInspector()
+            .enableNoReturnTypeStrengtheningAnnotations()
             .enableNoVerticalClassMergingAnnotations()
             .enableInliningAnnotations()
             .addProgramClasses(CLASS, A.class, B.class)
@@ -95,6 +97,7 @@
   public static class TestClass {
 
     @NeverInline
+    @NoReturnTypeStrengthening
     static A create() {
       return new B();
     }