Automatically soft pin modeled proto methods

Change-Id: I398c1a16dc8f95818424af499d054fd55e9a1e70
diff --git a/src/main/java/com/android/tools/r8/graph/DexDefinitionSupplier.java b/src/main/java/com/android/tools/r8/graph/DexDefinitionSupplier.java
index d18db79..57212d4 100644
--- a/src/main/java/com/android/tools/r8/graph/DexDefinitionSupplier.java
+++ b/src/main/java/com/android/tools/r8/graph/DexDefinitionSupplier.java
@@ -76,6 +76,11 @@
   @Deprecated
   DexClass definitionFor(DexType type);
 
+  default DexClassAndMethod definitionFor(DexMethod method) {
+    DexClass holder = definitionFor(method.getHolderType());
+    return holder != null ? holder.lookupClassMethod(method) : null;
+  }
+
   // Use programDefinitionFor with a context.
   @Deprecated
   default DexProgramClass definitionForProgramType(DexType type) {
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
index 7596684..4d9b9a0 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteShrinker.java
@@ -4,38 +4,49 @@
 
 package com.android.tools.r8.ir.analysis.proto;
 
+import static com.android.tools.r8.graph.DexClassAndMethod.asProgramMethodOrNull;
 import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
 import static com.android.tools.r8.ir.analysis.proto.ProtoUtils.getInfoValueFromMessageInfoConstructionInvoke;
 import static com.android.tools.r8.ir.analysis.proto.ProtoUtils.getObjectsValueFromMessageInfoConstructionInvoke;
 import static com.android.tools.r8.ir.analysis.proto.ProtoUtils.setObjectsValueForMessageInfoConstructionInvoke;
+import static com.android.tools.r8.ir.analysis.type.Nullability.definitelyNotNull;
 
+import com.android.tools.r8.graph.AccessControl;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.analysis.proto.schema.ProtoMessageInfo;
 import com.android.tools.r8.ir.analysis.proto.schema.ProtoObject;
-import com.android.tools.r8.ir.analysis.type.Nullability;
+import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
 import com.android.tools.r8.ir.analysis.type.TypeElement;
 import com.android.tools.r8.ir.code.ArrayPut;
 import com.android.tools.r8.ir.code.BasicBlock;
+import com.android.tools.r8.ir.code.BasicBlockIterator;
 import com.android.tools.r8.ir.code.ConstString;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.IRCodeUtils;
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.ir.code.InstructionListIterator;
+import com.android.tools.r8.ir.code.InvokeDirect;
 import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.ir.code.InvokeMethodWithReceiver;
 import com.android.tools.r8.ir.code.MemberType;
 import com.android.tools.r8.ir.code.NewArrayEmpty;
+import com.android.tools.r8.ir.code.NewInstance;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.conversion.IRConverter;
 import com.android.tools.r8.ir.conversion.OneTimeMethodProcessor;
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedbackIgnore;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.shaking.DependentMinimumKeepInfoCollection;
 import com.android.tools.r8.utils.Timing;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
+import com.google.common.collect.Sets;
 import java.util.List;
+import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.function.Consumer;
@@ -60,8 +71,8 @@
     // Types.
     this.objectArrayType =
         TypeElement.fromDexType(
-            appView.dexItemFactory().objectArrayType, Nullability.definitelyNotNull(), appView);
-    this.stringType = TypeElement.stringClassType(appView, Nullability.definitelyNotNull());
+            appView.dexItemFactory().objectArrayType, definitelyNotNull(), appView);
+    this.stringType = TypeElement.stringClassType(appView, definitelyNotNull());
   }
 
   public void extendRootSet(DependentMinimumKeepInfoCollection dependentMinimumKeepInfo) {
@@ -81,25 +92,17 @@
             .disallowOptimization();
       }
 
-      ProgramMethod newRepeatedGeneratedExtensionMethod =
-          generatedMessageLiteClass.lookupProgramMethod(
-              references.generatedMessageLiteMethods.newRepeatedGeneratedExtension);
-      if (newRepeatedGeneratedExtensionMethod != null) {
-        dependentMinimumKeepInfo
-            .getOrCreateUnconditionalMinimumKeepInfoFor(
-                newRepeatedGeneratedExtensionMethod.getReference())
-            .disallowOptimization();
-      }
-
-      ProgramMethod newSingularGeneratedExtensionMethod =
-          generatedMessageLiteClass.lookupProgramMethod(
-              references.generatedMessageLiteMethods.newSingularGeneratedExtension);
-      if (newSingularGeneratedExtensionMethod != null) {
-        dependentMinimumKeepInfo
-            .getOrCreateUnconditionalMinimumKeepInfoFor(
-                newSingularGeneratedExtensionMethod.getReference())
-            .disallowOptimization();
-      }
+      references.forEachMethodReference(
+          reference -> {
+            DexProgramClass holder =
+                asProgramClassOrNull(appView.definitionFor(reference.getHolderType()));
+            ProgramMethod method = reference.lookupOnProgramClass(holder);
+            if (method != null) {
+              dependentMinimumKeepInfo
+                  .getOrCreateUnconditionalMinimumKeepInfoFor(method.getReference())
+                  .disallowOptimization();
+            }
+          });
     }
   }
 
@@ -107,9 +110,98 @@
     ProgramMethod method = code.context();
     if (references.isDynamicMethod(method.getReference())) {
       rewriteDynamicMethod(method, code);
+    } else if (appView.hasLiveness()) {
+      optimizeNewMutableInstance(appView.withLiveness(), code);
     }
   }
 
+  private void optimizeNewMutableInstance(AppView<AppInfoWithLiveness> appView, IRCode code) {
+    Set<Value> affectedValues = Sets.newIdentityHashSet();
+    BasicBlockIterator blockIterator = code.listIterator();
+    while (blockIterator.hasNext()) {
+      BasicBlock block = blockIterator.next();
+      InstructionListIterator instructionIterator = block.listIterator(code);
+      while (instructionIterator.hasNext()) {
+        Instruction instruction = instructionIterator.next();
+        DexType newMutableInstanceType = getNewMutableInstanceType(appView, instruction);
+        if (newMutableInstanceType == null) {
+          continue;
+        }
+
+        DexMethod instanceInitializerReference =
+            appView.dexItemFactory().createInstanceInitializer(newMutableInstanceType);
+        ProgramMethod instanceInitializer =
+            asProgramMethodOrNull(appView.definitionFor(instanceInitializerReference));
+        if (instanceInitializer == null
+            || AccessControl.isMemberAccessible(
+                    instanceInitializer, instanceInitializer.getHolder(), code.context(), appView)
+                .isPossiblyFalse()) {
+          continue;
+        }
+
+        NewInstance newInstance =
+            NewInstance.builder()
+                .setType(newMutableInstanceType)
+                .setFreshOutValue(
+                    code, newMutableInstanceType.toTypeElement(appView, definitelyNotNull()))
+                .setPosition(instruction)
+                .build();
+        instructionIterator.replaceCurrentInstruction(newInstance, affectedValues);
+
+        InvokeDirect constructorInvoke =
+            InvokeDirect.builder()
+                .setMethod(instanceInitializerReference)
+                .setSingleArgument(newInstance.outValue())
+                .setPosition(instruction)
+                .build();
+
+        if (block.hasCatchHandlers()) {
+          // Split the block after the new-instance instruction and insert the constructor call in
+          // the split block.
+          BasicBlock splitBlock =
+              instructionIterator.splitCopyCatchHandlers(code, blockIterator, appView.options());
+          instructionIterator = splitBlock.listIterator(code);
+          instructionIterator.add(constructorInvoke);
+          BasicBlock previousBlock =
+              blockIterator.previousUntil(previous -> previous == splitBlock);
+          assert previousBlock != null;
+          blockIterator.next();
+        } else {
+          instructionIterator.add(constructorInvoke);
+        }
+      }
+    }
+    if (!affectedValues.isEmpty()) {
+      new TypeAnalysis(appView).narrowing(affectedValues);
+    }
+  }
+
+  private DexType getNewMutableInstanceType(
+      AppView<AppInfoWithLiveness> appView, Instruction instruction) {
+    if (!instruction.isInvokeMethodWithReceiver()) {
+      return null;
+    }
+    InvokeMethodWithReceiver invoke = instruction.asInvokeMethodWithReceiver();
+    DexMethod invokedMethod = invoke.getInvokedMethod();
+    if (!references.isDynamicMethod(invokedMethod)
+        && !references.isDynamicMethodBridge(invokedMethod)) {
+      return null;
+    }
+    assert invokedMethod.getParameter(0) == references.methodToInvokeType;
+    if (!references.methodToInvokeMembers.isNewMutableInstanceEnum(
+        invoke.getFirstNonReceiverArgument())) {
+      return null;
+    }
+    TypeElement receiverType = invoke.getReceiver().getDynamicUpperBoundType(appView);
+    if (!receiverType.isClassType()) {
+      return null;
+    }
+    DexType rawReceiverType = receiverType.asClassType().getClassType();
+    return appView.appInfo().isStrictSubtypeOf(rawReceiverType, references.generatedMessageLiteType)
+        ? rawReceiverType
+        : null;
+  }
+
   public void postOptimizeDynamicMethods(
       IRConverter converter, ExecutorService executorService, Timing timing)
       throws ExecutionException {
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java
index 030fe3d..31633de 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java
@@ -14,6 +14,7 @@
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.ProgramMethod;
 import com.android.tools.r8.ir.code.Value;
+import java.util.function.Consumer;
 
 public class ProtoReferences {
 
@@ -129,6 +130,17 @@
     methodToInvokeMembers = new MethodToInvokeMembers(factory);
   }
 
+  public void forEachMethodReference(Consumer<DexMethod> consumer) {
+    generatedExtensionMethods.forEachMethodReference(consumer);
+    generatedMessageLiteMethods.forEachMethodReference(consumer);
+    generatedMessageLiteBuilderMethods.forEachMethodReference(consumer);
+    generatedMessageLiteExtendableBuilderMethods.forEachMethodReference(consumer);
+    methodToInvokeMembers.forEachMethodReference(consumer);
+    consumer.accept(dynamicMethod);
+    consumer.accept(newMessageInfoMethod);
+    consumer.accept(rawMessageInfoConstructor);
+  }
+
   public DexField getDefaultInstanceField(DexProgramClass holder) {
     return dexItemFactory.createField(holder.type, holder.type, defaultInstanceFieldName);
   }
@@ -220,6 +232,11 @@
               dexItemFactory.constructorMethodName);
     }
 
+    public void forEachMethodReference(Consumer<DexMethod> consumer) {
+      consumer.accept(constructor);
+      consumer.accept(constructorWithClass);
+    }
+
     public boolean isConstructor(DexMethod method) {
       return method == constructor || method == constructorWithClass;
     }
@@ -230,7 +247,6 @@
     public final DexMethod createBuilderMethod;
     public final DexMethod dynamicMethodBridgeMethod;
     public final DexMethod dynamicMethodBridgeMethodWithObject;
-    public final DexMethod isInitializedMethod;
     public final DexMethod newRepeatedGeneratedExtension;
     public final DexMethod newSingularGeneratedExtension;
 
@@ -251,11 +267,6 @@
               dexItemFactory.createProto(
                   dexItemFactory.objectType, methodToInvokeType, dexItemFactory.objectType),
               "dynamicMethod");
-      isInitializedMethod =
-          dexItemFactory.createMethod(
-              generatedMessageLiteType,
-              dexItemFactory.createProto(dexItemFactory.booleanType),
-              "isInitialized");
       newRepeatedGeneratedExtension =
           dexItemFactory.createMethod(
               generatedMessageLiteType,
@@ -283,25 +294,31 @@
                   dexItemFactory.classType),
               "newSingularGeneratedExtension");
     }
+
+    public void forEachMethodReference(Consumer<DexMethod> consumer) {
+      consumer.accept(createBuilderMethod);
+      consumer.accept(dynamicMethodBridgeMethod);
+      consumer.accept(dynamicMethodBridgeMethodWithObject);
+      consumer.accept(newRepeatedGeneratedExtension);
+      consumer.accept(newSingularGeneratedExtension);
+    }
   }
 
   public class GeneratedMessageLiteBuilderMethods {
 
-    public final DexMethod buildPartialMethod;
     public final DexMethod constructorMethod;
 
     private GeneratedMessageLiteBuilderMethods(DexItemFactory dexItemFactory) {
-      buildPartialMethod =
-          dexItemFactory.createMethod(
-              generatedMessageLiteBuilderType,
-              dexItemFactory.createProto(generatedMessageLiteType),
-              "buildPartial");
       constructorMethod =
           dexItemFactory.createMethod(
               generatedMessageLiteBuilderType,
               dexItemFactory.createProto(dexItemFactory.voidType, generatedMessageLiteType),
               dexItemFactory.constructorMethodName);
     }
+
+    public void forEachMethodReference(Consumer<DexMethod> consumer) {
+      consumer.accept(constructorMethod);
+    }
   }
 
   public class GeneratedMessageLiteExtendableBuilderMethods {
@@ -322,6 +339,11 @@
                   dexItemFactory.voidType, generatedMessageLiteExtendableMessageType),
               dexItemFactory.constructorMethodName);
     }
+
+    public void forEachMethodReference(Consumer<DexMethod> consumer) {
+      consumer.accept(buildPartialMethod);
+      consumer.accept(constructorMethod);
+    }
   }
 
   public class MethodToInvokeMembers {
@@ -355,6 +377,10 @@
               methodToInvokeType, methodToInvokeType, "SET_MEMOIZED_IS_INITIALIZED");
     }
 
+    public void forEachMethodReference(Consumer<DexMethod> consumer) {
+      // Intentionally empty.
+    }
+
     public boolean isNewMutableInstanceEnum(DexField field) {
       return field == newMutableInstanceField;
     }
diff --git a/src/main/java/com/android/tools/r8/ir/code/BasicBlockIterator.java b/src/main/java/com/android/tools/r8/ir/code/BasicBlockIterator.java
index 523ca73..4de1c80 100644
--- a/src/main/java/com/android/tools/r8/ir/code/BasicBlockIterator.java
+++ b/src/main/java/com/android/tools/r8/ir/code/BasicBlockIterator.java
@@ -6,6 +6,7 @@
 
 import com.android.tools.r8.utils.IteratorUtils;
 import java.util.ListIterator;
+import java.util.function.Predicate;
 
 public class BasicBlockIterator implements ListIterator<BasicBlock> {
 
@@ -63,6 +64,10 @@
     return listIterator.previousIndex();
   }
 
+  public BasicBlock previousUntil(Predicate<BasicBlock> predicate) {
+    return IteratorUtils.previousUntil(this, predicate);
+  }
+
   @Override
   public void add(BasicBlock block) {
     listIterator.add(block);
diff --git a/src/main/java/com/android/tools/r8/ir/code/InvokeMethod.java b/src/main/java/com/android/tools/r8/ir/code/InvokeMethod.java
index efbcd4c..06173d4 100644
--- a/src/main/java/com/android/tools/r8/ir/code/InvokeMethod.java
+++ b/src/main/java/com/android/tools/r8/ir/code/InvokeMethod.java
@@ -71,6 +71,10 @@
     }
   }
 
+  public Value getFirstNonReceiverArgument() {
+    return getArgument(getFirstNonReceiverArgumentIndex());
+  }
+
   public int getFirstNonReceiverArgumentIndex() {
     return BooleanUtils.intValue(isInvokeMethodWithReceiver());
   }
diff --git a/src/main/java/com/android/tools/r8/ir/code/NewInstance.java b/src/main/java/com/android/tools/r8/ir/code/NewInstance.java
index 05fba04..024f6f2 100644
--- a/src/main/java/com/android/tools/r8/ir/code/NewInstance.java
+++ b/src/main/java/com/android/tools/r8/ir/code/NewInstance.java
@@ -38,6 +38,10 @@
     this.clazz = clazz;
   }
 
+  public static Builder builder() {
+    return new Builder();
+  }
+
   public DexType getType() {
     return clazz;
   }
@@ -228,4 +232,24 @@
     assert type.isDefinitelyNotNull();
     return true;
   }
+
+  public static class Builder extends BuilderBase<Builder, NewInstance> {
+
+    private DexType type;
+
+    public Builder setType(DexType type) {
+      this.type = type;
+      return this;
+    }
+
+    @Override
+    public NewInstance build() {
+      return amend(new NewInstance(type, outValue));
+    }
+
+    @Override
+    public Builder self() {
+      return this;
+    }
+  }
 }