Remove redundant init class instructions

Change-Id: I462cd6e63fab090f3f059e886de49b96a5f068a4
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/RedundantFieldLoadAndStoreElimination.java b/src/main/java/com/android/tools/r8/ir/optimize/RedundantFieldLoadAndStoreElimination.java
index 0e3dfdb..50bea08 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/RedundantFieldLoadAndStoreElimination.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/RedundantFieldLoadAndStoreElimination.java
@@ -29,6 +29,7 @@
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.ir.code.InstructionListIterator;
 import com.android.tools.r8.ir.code.InvokeDirect;
+import com.android.tools.r8.ir.code.InvokeStatic;
 import com.android.tools.r8.ir.code.NewInstance;
 import com.android.tools.r8.ir.code.Phi;
 import com.android.tools.r8.ir.code.StaticGet;
@@ -243,13 +244,12 @@
             }
           } else if (instruction.isInvokeDirect()) {
             handleInvokeDirect(instruction.asInvokeDirect());
+          } else if (instruction.isInvokeStatic()) {
+            handleInvokeStatic(instruction.asInvokeStatic());
           } else if (instruction.isInvokeMethod() || instruction.isInvokeCustom()) {
             killAllNonFinalActiveFields();
           } else if (instruction.isNewInstance()) {
-            NewInstance newInstance = instruction.asNewInstance();
-            if (newInstance.clazz.classInitializationMayHaveSideEffectsInContext(appView, method)) {
-              killAllNonFinalActiveFields();
-            }
+            handleNewInstance(instruction.asNewInstance());
           } else {
             // If the current instruction could trigger a method invocation, it could also cause
             // field values to change. In that case, it must be handled above.
@@ -258,6 +258,7 @@
             // Clear the field writes.
             if (instruction.instructionInstanceCanThrow(appView, method)) {
               activeState.clearMostRecentFieldWrites();
+              activeState.clearMostRecentInitClass();
             }
 
             // If this assertion fails for a new instruction we need to determine if that
@@ -394,6 +395,23 @@
         });
   }
 
+  private void handleInvokeStatic(InvokeStatic invoke) {
+    if (appView.hasClassHierarchy()) {
+      ProgramMethod resolvedMethod =
+          appView
+              .appInfo()
+              .withClassHierarchy()
+              .unsafeResolveMethodDueToDexFormat(invoke.getInvokedMethod())
+              .getResolvedProgramMethod();
+      if (resolvedMethod != null) {
+        markClassAsInitialized(resolvedMethod.getHolderType());
+        markMostRecentInitClassForRemoval(resolvedMethod.getHolderType());
+      }
+    }
+
+    killAllNonFinalActiveFields();
+  }
+
   private void handleInitClass(InstructionListIterator instructionIterator, InitClass initClass) {
     assert !initClass.outValue().hasAnyUsers();
 
@@ -406,11 +424,28 @@
     }
 
     DexType clazz = initClass.getClassValue();
-    if (!activeState.markClassAsInitialized(clazz)) {
+    if (markClassAsInitialized(clazz)) {
+      if (release) {
+        activeState.setMostRecentInitClass(initClass);
+      }
+    } else {
       instructionIterator.removeOrReplaceByDebugLocalRead();
     }
   }
 
+  private boolean markClassAsInitialized(DexType type) {
+    return activeState.markClassAsInitialized(type);
+  }
+
+  private void markMostRecentInitClassForRemoval(DexType initializedType) {
+    InitClass mostRecentInitClass = activeState.getMostRecentInitClass();
+    if (mostRecentInitClass != null && mostRecentInitClass.getClassValue() == initializedType) {
+      instructionsToRemove
+          .computeIfAbsent(mostRecentInitClass.getBlock(), ignoreKey(Sets::newIdentityHashSet))
+          .add(mostRecentInitClass);
+    }
+  }
+
   private void handleInstanceGet(
       InstructionListIterator it,
       InstanceGet instanceGet,
@@ -431,9 +466,18 @@
     }
 
     activeState.putNonFinalInstanceField(fieldAndObject, new ExistingValue(instanceGet.value()));
+    activeState.clearMostRecentInitClass();
     clearMostRecentInstanceFieldWrite(instanceGet, field);
   }
 
+  private void handleNewInstance(NewInstance newInstance) {
+    markClassAsInitialized(newInstance.getType());
+    markMostRecentInitClassForRemoval(newInstance.getType());
+    if (newInstance.getType().classInitializationMayHaveSideEffectsInContext(appView, method)) {
+      killAllNonFinalActiveFields();
+    }
+  }
+
   private void clearMostRecentInstanceFieldWrite(InstanceGet instanceGet, DexClassAndField field) {
     // If the instruction can throw, we need to clear all most-recent-writes, since subsequent field
     // writes (if any) are not guaranteed to be executed.
@@ -494,6 +538,8 @@
         }
       }
     }
+
+    activeState.clearMostRecentInitClass();
   }
 
   private void handleStaticGet(
@@ -501,6 +547,8 @@
       StaticGet staticGet,
       DexClassAndField field,
       AssumeRemover assumeRemover) {
+    markClassAsInitialized(field.getHolderType());
+
     if (staticGet.outValue().hasLocalInfo()) {
       killNonFinalActiveFields(staticGet);
       clearMostRecentStaticFieldWrite(staticGet, field);
@@ -532,6 +580,9 @@
         applyObjectState(staticGet.outValue(), singleFieldValue.getObjectState());
       }
     }
+
+    markMostRecentInitClassForRemoval(field.getHolderType());
+    activeState.clearMostRecentInitClass();
   }
 
   private void clearMostRecentStaticFieldWrite(StaticGet staticGet, DexClassAndField field) {
@@ -545,6 +596,8 @@
   }
 
   private void handleStaticPut(StaticPut staticPut, DexClassAndField field) {
+    markClassAsInitialized(field.getHolderType());
+
     // A field put on a different class can cause <clinit> to run and change static field values.
     killNonFinalActiveFields(staticPut);
 
@@ -573,6 +626,9 @@
         }
       }
     }
+
+    markMostRecentInitClassForRemoval(field.getHolderType());
+    activeState.clearMostRecentInitClass();
   }
 
   private void applyObjectState(Value value, ObjectState objectState) {
@@ -593,6 +649,7 @@
     activeState.clearNonFinalInstanceFields();
     activeState.clearNonFinalStaticFields();
     activeState.clearMostRecentFieldWrites();
+    activeState.clearMostRecentInitClass();
   }
 
   private void killNonFinalActiveFields(Instruction instruction) {
@@ -714,6 +771,7 @@
       }
       if (!block.hasUniqueSuccessorWithUniquePredecessor()) {
         state.clearMostRecentFieldWrites();
+        state.clearMostRecentInitClass();
       }
       ensureCapacity(state);
       activeStateAtExit.put(block, state);
@@ -753,6 +811,8 @@
 
     private LinkedHashMap<DexField, FieldValue> nonFinalStaticFieldValues;
 
+    private InitClass mostRecentInitClass;
+
     private LinkedHashMap<FieldAndObject, InstancePut> mostRecentInstanceFieldWrites;
 
     private LinkedHashMap<DexField, StaticPut> mostRecentStaticFieldWrites;
@@ -787,6 +847,7 @@
           nonFinalStaticFieldValues = new LinkedHashMap<>();
           nonFinalStaticFieldValues.putAll(state.nonFinalStaticFieldValues);
         }
+        mostRecentInitClass = state.mostRecentInitClass;
         if (state.mostRecentInstanceFieldWrites != null
             && !state.mostRecentInstanceFieldWrites.isEmpty()) {
           mostRecentInstanceFieldWrites = new LinkedHashMap<>();
@@ -885,6 +946,7 @@
       } else {
         nonFinalStaticFieldValues = null;
       }
+      assert mostRecentInitClass == null;
       assert mostRecentInstanceFieldWrites == null;
       assert mostRecentStaticFieldWrites == null;
     }
@@ -899,10 +961,6 @@
       initializedClasses.removeIf(not(other::contains));
     }
 
-    public boolean isClassInitialized(DexType clazz) {
-      return initializedClasses != null && initializedClasses.contains(clazz);
-    }
-
     public boolean isEmpty() {
       return isEmpty(finalInstanceFieldValues)
           && isEmpty(finalStaticFieldValues)
@@ -1082,6 +1140,20 @@
       nonFinalStaticFieldValues.put(field, value);
     }
 
+    public InitClass getMostRecentInitClass() {
+      return mostRecentInitClass;
+    }
+
+    public void setMostRecentInitClass(InitClass initClass) {
+      mostRecentInitClass = initClass;
+    }
+
+    public InitClass clearMostRecentInitClass() {
+      InitClass result = mostRecentInitClass;
+      mostRecentInitClass = null;
+      return result;
+    }
+
     public int size() {
       return size(finalInstanceFieldValues)
           + size(finalStaticFieldValues)
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/redundantfieldloadelimination/RedundantInitClassBeforeInvokeStaticTest.java b/src/test/java/com/android/tools/r8/ir/optimize/redundantfieldloadelimination/RedundantInitClassBeforeInvokeStaticTest.java
index b2ea2ce..d7e4296 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/redundantfieldloadelimination/RedundantInitClassBeforeInvokeStaticTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/redundantfieldloadelimination/RedundantInitClassBeforeInvokeStaticTest.java
@@ -45,7 +45,7 @@
               assertThat(greeterClassSubject, isPresent());
               assertThat(greeterClassSubject.uniqueMethodWithName("hello"), isAbsent());
               assertThat(greeterClassSubject.uniqueMethodWithName("world"), isPresent());
-              assertEquals(1, greeterClassSubject.allFields().size());
+              assertEquals(0, greeterClassSubject.allFields().size());
             })
         .run(parameters.getRuntime(), Main.class)
         .assertSuccessWithOutputLines("Hello world!");