Generalize class staticizer to allow users of instantiation

Change-Id: Ib5368d40926bf8c3095a4c6ef83cae3269c07015
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/staticizer/ClassStaticizer.java b/src/main/java/com/android/tools/r8/ir/optimize/staticizer/ClassStaticizer.java
index 9680a8e..41df56a 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/staticizer/ClassStaticizer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/staticizer/ClassStaticizer.java
@@ -30,6 +30,7 @@
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedback;
 import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.ListUtils;
+import com.android.tools.r8.utils.SetUtils;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 import java.util.ArrayList;
@@ -267,6 +268,7 @@
         NewInstance newInstance = instruction.asNewInstance();
         CandidateInfo candidateInfo = processInstantiation(method, iterator, newInstance);
         if (candidateInfo != null) {
+          alreadyProcessed.addAll(newInstance.outValue().aliasedUsers());
           // For host class initializers having eligible instantiation we also want to
           // ensure that the rest of the initializer consist of code w/o side effects.
           // This must guarantee that removing field access will not result in missing side
@@ -396,10 +398,10 @@
       return candidateInfo.invalidate();
     }
 
-    if (candidateValue.numberOfUsers() != 2) {
-      // We expect only two users for each instantiation: constructor call and
-      // static field write. We only check count here, since the exact instructions
-      // will be checked later.
+    if (candidateValue.numberOfUsers() < 2) {
+      // We expect two special users for each instantiation: constructor call and static field
+      // write. We allow the instance to have other users as well, as long as they are valid
+      // according to the user analysis.
       return candidateInfo.invalidate();
     }
 
@@ -411,7 +413,7 @@
     //        invoke-direct {v0, ...}, void <candidate-type>.<init>(...)
     //        sput-object v0, <instance-field>
     //        ...
-    //        ...
+    //        ... // other usages that are valid according to the user analysis.
     //
     // In case we guarantee candidate constructor does not access <instance-field>
     // directly or indirectly we can guarantee that all the potential reads get
@@ -422,26 +424,33 @@
       // Intentionally empty.
     }
     iterator.previous();
-
     if (!iterator.hasNext()) {
       return candidateInfo.invalidate();
     }
-    if (!isValidInitCall(candidateInfo, iterator.next(), candidateValue, candidateType)) {
+    Set<Instruction> users = SetUtils.newIdentityHashSet(candidateValue.uniqueUsers());
+    Instruction constructorCall = iterator.next();
+    if (!isValidInitCall(candidateInfo, constructorCall, candidateValue, candidateType)) {
       iterator.previous();
       return candidateInfo.invalidate();
     }
-
+    boolean removedConstructorCall = users.remove(constructorCall);
+    assert removedConstructorCall;
     if (!iterator.hasNext()) {
       return candidateInfo.invalidate();
     }
-    if (!isValidStaticPut(candidateInfo, iterator.next())) {
+    Instruction staticPut = iterator.next();
+    if (!isValidStaticPut(candidateInfo, staticPut)) {
       iterator.previous();
       return candidateInfo.invalidate();
     }
+    boolean removedStaticPut = users.remove(staticPut);
+    assert removedStaticPut;
     if (candidateInfo.fieldWrites.incrementAndGet() > 1) {
       return candidateInfo.invalidate();
     }
-
+    if (!isSelectedValueUsersValid(candidateInfo, candidateValue, false, users)) {
+      return candidateInfo.invalidate();
+    }
     return candidateInfo;
   }
 
@@ -567,59 +576,79 @@
   private CandidateInfo analyzeAllValueUsers(
       CandidateInfo candidateInfo, Value value, boolean ignoreSuperClassInitInvoke) {
     assert value != null && value == value.getAliasedValue();
-
     if (value.numberOfPhiUsers() > 0) {
       return candidateInfo.invalidate();
     }
+    if (!isSelectedValueUsersValid(
+        candidateInfo, value, ignoreSuperClassInitInvoke, value.uniqueUsers())) {
+      return candidateInfo.invalidate();
+    }
+    return candidateInfo;
+  }
 
-    Set<Instruction> currentUsers = value.uniqueUsers();
+  private boolean isSelectedValueUsersValid(
+      CandidateInfo candidateInfo,
+      Value value,
+      boolean ignoreSuperClassInitInvoke,
+      Set<Instruction> currentUsers) {
     while (!currentUsers.isEmpty()) {
       Set<Instruction> indirectUsers = Sets.newIdentityHashSet();
       for (Instruction user : currentUsers) {
-        if (user.isAssume()) {
-          if (user.outValue().numberOfPhiUsers() > 0) {
-            return candidateInfo.invalidate();
-          }
-          indirectUsers.addAll(user.outValue().uniqueUsers());
-          continue;
+        if (!isValidValueUser(
+            candidateInfo, value, ignoreSuperClassInitInvoke, indirectUsers, user)) {
+          return false;
         }
-        if (user.isInvokeVirtual() || user.isInvokeDirect() /* private methods */) {
-          InvokeMethodWithReceiver invoke = user.asInvokeMethodWithReceiver();
-          Predicate<Value> isAliasedValue = v -> v.getAliasedValue() == value;
-          DexMethod methodReferenced = invoke.getInvokedMethod();
-          if (factory.isConstructor(methodReferenced)) {
-            assert user.isInvokeDirect();
-            if (ignoreSuperClassInitInvoke
-                && ListUtils.lastIndexMatching(invoke.inValues(), isAliasedValue) == 0
-                && methodReferenced == factory.objectMembers.constructor) {
-              // If we are inside candidate constructor and analyzing usages
-              // of the receiver, we want to ignore invocations of superclass
-              // constructor which will be removed after staticizing.
-              continue;
-            }
-            return candidateInfo.invalidate();
-          }
-          AppInfoWithLiveness appInfo = appView.appInfo();
-          ResolutionResult resolutionResult =
-              appInfo.resolveMethod(methodReferenced.holder, methodReferenced);
-          DexEncodedMethod methodInvoked =
-              user.isInvokeDirect()
-                  ? resolutionResult.lookupInvokeDirectTarget(candidateInfo.candidate, appInfo)
-                  : resolutionResult.isVirtualTarget() ? resolutionResult.getSingleTarget() : null;
-          if (ListUtils.lastIndexMatching(invoke.inValues(), isAliasedValue) == 0
-              && methodInvoked != null
-              && methodInvoked.holder() == candidateInfo.candidate.type) {
-            continue;
-          }
-        }
-
-        // All other users are not allowed.
-        return candidateInfo.invalidate();
       }
       currentUsers = indirectUsers;
     }
+    return true;
+  }
 
-    return candidateInfo;
+  private boolean isValidValueUser(
+      CandidateInfo candidateInfo,
+      Value value,
+      boolean ignoreSuperClassInitInvoke,
+      Set<Instruction> indirectUsers,
+      Instruction user) {
+    if (user.isAssume()) {
+      if (user.outValue().numberOfPhiUsers() > 0) {
+        return false;
+      }
+      indirectUsers.addAll(user.outValue().uniqueUsers());
+      return true;
+    }
+    if (user.isInvokeVirtual() || user.isInvokeDirect() /* private methods */) {
+      InvokeMethodWithReceiver invoke = user.asInvokeMethodWithReceiver();
+      Predicate<Value> isAliasedValue = v -> v.getAliasedValue() == value;
+      DexMethod methodReferenced = invoke.getInvokedMethod();
+      if (factory.isConstructor(methodReferenced)) {
+        assert user.isInvokeDirect();
+        if (ignoreSuperClassInitInvoke
+            && ListUtils.lastIndexMatching(invoke.inValues(), isAliasedValue) == 0
+            && methodReferenced == factory.objectMembers.constructor) {
+          // If we are inside candidate constructor and analyzing usages
+          // of the receiver, we want to ignore invocations of superclass
+          // constructor which will be removed after staticizing.
+          return true;
+        }
+        return false;
+      }
+      AppInfoWithLiveness appInfo = appView.appInfo();
+      ResolutionResult resolutionResult =
+          appInfo.resolveMethod(methodReferenced.holder, methodReferenced);
+      DexEncodedMethod methodInvoked =
+          user.isInvokeDirect()
+              ? resolutionResult.lookupInvokeDirectTarget(candidateInfo.candidate, appInfo)
+              : resolutionResult.isVirtualTarget() ? resolutionResult.getSingleTarget() : null;
+      if (ListUtils.lastIndexMatching(invoke.inValues(), isAliasedValue) == 0
+          && methodInvoked != null
+          && methodInvoked.holder() == candidateInfo.candidate.type) {
+        return true;
+      }
+    }
+
+    // All other users are not allowed.
+    return false;
   }
 
   // Perform staticizing candidates:
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/staticizer/StaticizingProcessor.java b/src/main/java/com/android/tools/r8/ir/optimize/staticizer/StaticizingProcessor.java
index ecdc979..4a1924b 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/staticizer/StaticizingProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/staticizer/StaticizingProcessor.java
@@ -4,6 +4,8 @@
 
 package com.android.tools.r8.ir.optimize.staticizer;
 
+import static com.android.tools.r8.ir.analysis.type.Nullability.maybeNull;
+
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DebugLocalInfo;
 import com.android.tools.r8.graph.DexClass;
@@ -18,8 +20,10 @@
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.ir.code.InstructionListIterator;
+import com.android.tools.r8.ir.code.InvokeDirect;
 import com.android.tools.r8.ir.code.InvokeMethodWithReceiver;
 import com.android.tools.r8.ir.code.InvokeStatic;
+import com.android.tools.r8.ir.code.NewInstance;
 import com.android.tools.r8.ir.code.Phi;
 import com.android.tools.r8.ir.code.StaticGet;
 import com.android.tools.r8.ir.code.StaticPut;
@@ -329,9 +333,8 @@
     assert candidateInfo != null;
 
     // Find and remove instantiation and its users.
-    for (Instruction instruction : code.instructions()) {
-      if (instruction.isNewInstance()
-          && instruction.asNewInstance().clazz == candidateInfo.candidate.type) {
+    for (NewInstance newInstance : code.<NewInstance>instructions(Instruction::isNewInstance)) {
+      if (newInstance.clazz == candidateInfo.candidate.type) {
         // Remove all usages
         // NOTE: requiring (a) the instance initializer to be trivial, (b) not allowing
         //       candidates with instance fields and (c) requiring candidate to directly
@@ -340,10 +343,31 @@
         assert candidateInfo.candidate.superType == factory().objectType;
         assert candidateInfo.candidate.instanceFields().size() == 0;
 
-        Value singletonValue = instruction.outValue();
+        Value singletonValue = newInstance.outValue();
         assert singletonValue != null;
-        singletonValue.uniqueUsers().forEach(user -> user.removeOrReplaceByDebugLocalRead(code));
-        instruction.removeOrReplaceByDebugLocalRead(code);
+
+        InvokeDirect uniqueConstructorInvoke =
+            newInstance.getUniqueConstructorInvoke(appView.dexItemFactory());
+        assert uniqueConstructorInvoke != null;
+        uniqueConstructorInvoke.removeOrReplaceByDebugLocalRead(code);
+
+        StaticPut uniqueStaticPut = null;
+        for (Instruction user : singletonValue.uniqueUsers()) {
+          if (user.isStaticPut()) {
+            assert uniqueStaticPut == null;
+            uniqueStaticPut = user.asStaticPut();
+          }
+        }
+        assert uniqueStaticPut != null;
+        uniqueStaticPut.removeOrReplaceByDebugLocalRead(code);
+
+        if (newInstance.outValue().hasAnyUsers()) {
+          TypeElement type = TypeElement.fromDexType(newInstance.clazz, maybeNull(), appView);
+          newInstance.replace(
+              new StaticGet(code.createValue(type), candidateInfo.singletonField.field), code);
+        } else {
+          newInstance.removeOrReplaceByDebugLocalRead(code);
+        }
         return;
       }
     }
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/staticizer/CompanionClassWithNewInstanceUserTest.java b/src/test/java/com/android/tools/r8/ir/optimize/staticizer/CompanionClassWithNewInstanceUserTest.java
index 5aeb159..42586c7 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/staticizer/CompanionClassWithNewInstanceUserTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/staticizer/CompanionClassWithNewInstanceUserTest.java
@@ -1,12 +1,14 @@
 package com.android.tools.r8.ir.optimize.staticizer;
 
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.CoreMatchers.not;
 import static org.hamcrest.MatcherAssert.assertThat;
 
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.TestBase;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -41,8 +43,13 @@
   }
 
   private void inspect(CodeInspector inspector) {
-    // The companion class has not been removed.
-    assertThat(inspector.clazz(Companion.class), isPresent());
+    // The companion class has been removed.
+    assertThat(inspector.clazz(Companion.class), not(isPresent()));
+
+    // The companion method has been moved to the companion host class.
+    ClassSubject hostClassSubject = inspector.clazz(CompanionHost.class);
+    assertThat(hostClassSubject, isPresent());
+    assertThat(hostClassSubject.uniqueMethodWithName("method"), isPresent());
   }
 
   static class TestClass {
diff --git a/src/test/java/com/android/tools/r8/kotlin/KotlinClassStaticizerTest.java b/src/test/java/com/android/tools/r8/kotlin/KotlinClassStaticizerTest.java
index 6087795..6fb984f 100644
--- a/src/test/java/com/android/tools/r8/kotlin/KotlinClassStaticizerTest.java
+++ b/src/test/java/com/android/tools/r8/kotlin/KotlinClassStaticizerTest.java
@@ -53,7 +53,6 @@
           // The Util class is there, but its instance methods have been inlined.
           ClassSubject utilClass = inspector.clazz("class_staticizer.Util");
           assertThat(utilClass, isPresent());
-          AtomicInteger nonStaticMethodCount = new AtomicInteger();
           assertTrue(
               utilClass.allMethods().stream()
                   .filter(Predicates.not(FoundMethodSubject::isStatic))