Handle non-linear control flow in redundant field load elimination

Bug: 152280793, 111380066
Change-Id: I9cc6254e624ea93859f50be068faaaf429c5b1ae
diff --git a/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java b/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java
index e36adab..52aed30 100644
--- a/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java
+++ b/src/main/java/com/android/tools/r8/ir/code/BasicBlock.java
@@ -540,6 +540,10 @@
     return phis;
   }
 
+  public boolean isEntry() {
+    return getPredecessors().isEmpty();
+  }
+
   public boolean isFilled() {
     return filled;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/RedundantFieldLoadElimination.java b/src/main/java/com/android/tools/r8/ir/optimize/RedundantFieldLoadElimination.java
index a5501a6..81b32e8 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/RedundantFieldLoadElimination.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/RedundantFieldLoadElimination.java
@@ -4,6 +4,8 @@
 
 package com.android.tools.r8.ir.optimize;
 
+import static com.android.tools.r8.utils.PredicateUtils.not;
+
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexEncodedField;
@@ -30,8 +32,9 @@
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.optimize.info.field.InstanceFieldInitializationInfoCollection;
 import com.android.tools.r8.ir.optimize.info.initializer.InstanceInitializerInfo;
-import com.android.tools.r8.utils.SetUtils;
 import com.google.common.collect.Sets;
+import java.util.ArrayDeque;
+import java.util.Deque;
 import java.util.HashMap;
 import java.util.IdentityHashMap;
 import java.util.Map;
@@ -55,14 +58,11 @@
   private final Set<Value> affectedValues = Sets.newIdentityHashSet();
 
   // Maps keeping track of fields that have an already loaded value at basic block entry.
-  private final Map<BasicBlock, Set<DexType>> activeInitializedClassesAtEntry =
-      new IdentityHashMap<>();
-  private final Map<BasicBlock, FieldValuesMap> activeFieldsAtEntry = new IdentityHashMap<>();
+  private final Map<BasicBlock, State> activeStateAtExit = new IdentityHashMap<>();
 
   // Maps keeping track of fields with already loaded values for the current block during
   // elimination.
-  private Set<DexType> activeInitializedClasses;
-  private FieldValuesMap activeFieldValues;
+  private State activeState;
 
   public RedundantFieldLoadElimination(AppView<?> appView, IRCode code) {
     this.appView = appView;
@@ -153,14 +153,7 @@
   public void run() {
     DexType context = method.holder();
     for (BasicBlock block : dominatorTree.getSortedBlocks()) {
-      activeInitializedClasses =
-          activeInitializedClassesAtEntry.containsKey(block)
-              ? activeInitializedClassesAtEntry.get(block)
-              : Sets.newIdentityHashSet();
-      activeFieldValues =
-          activeFieldsAtEntry.containsKey(block)
-              ? activeFieldsAtEntry.get(block)
-              : new FieldValuesMap();
+      computeActiveStateOnBlockEntry(block);
       InstructionListIterator it = block.listIterator(code);
       while (it.hasNext()) {
         Instruction instruction = it.next();
@@ -179,11 +172,11 @@
             }
             Value object = instanceGet.object().getAliasedValue();
             FieldAndObject fieldAndObject = new FieldAndObject(field, object);
-            FieldValue replacement = activeFieldValues.getInstanceFieldValue(fieldAndObject);
+            FieldValue replacement = activeState.getInstanceFieldValue(fieldAndObject);
             if (replacement != null) {
               replacement.eliminateRedundantRead(it, instanceGet);
             } else {
-              activeFieldValues.putNonFinalInstanceField(
+              activeState.putNonFinalInstanceField(
                   fieldAndObject, new ExistingValue(instanceGet.value()));
             }
           } else if (instruction.isInstancePut()) {
@@ -197,23 +190,23 @@
             ExistingValue value = new ExistingValue(instancePut.value());
             if (definition.isFinal()) {
               assert method.isInstanceInitializer() || verifyWasInstanceInitializer();
-              activeFieldValues.putFinalInstanceField(fieldAndObject, value);
+              activeState.putFinalInstanceField(fieldAndObject, value);
             } else {
-              activeFieldValues.putNonFinalInstanceField(fieldAndObject, value);
+              activeState.putNonFinalInstanceField(fieldAndObject, value);
             }
           } else if (instruction.isStaticGet()) {
             StaticGet staticGet = instruction.asStaticGet();
             if (staticGet.outValue().hasLocalInfo()) {
               continue;
             }
-            FieldValue replacement = activeFieldValues.getStaticFieldValue(field);
+            FieldValue replacement = activeState.getStaticFieldValue(field);
             if (replacement != null) {
               replacement.eliminateRedundantRead(it, staticGet);
             } else {
               // A field get on a different class can cause <clinit> to run and change static
               // field values.
               killNonFinalActiveFields(staticGet);
-              activeFieldValues.putNonFinalStaticField(field, new ExistingValue(staticGet.value()));
+              activeState.putNonFinalStaticField(field, new ExistingValue(staticGet.value()));
             }
           } else if (instruction.isStaticPut()) {
             StaticPut staticPut = instruction.asStaticPut();
@@ -223,17 +216,19 @@
             ExistingValue value = new ExistingValue(staticPut.value());
             if (definition.isFinal()) {
               assert method.isClassInitializer();
-              activeFieldValues.putFinalStaticField(field, value);
+              activeState.putFinalStaticField(field, value);
             } else {
-              activeFieldValues.putNonFinalStaticField(field, value);
+              activeState.putNonFinalStaticField(field, value);
             }
           }
         } else if (instruction.isInitClass()) {
           InitClass initClass = instruction.asInitClass();
           assert !initClass.outValue().hasAnyUsers();
-          if (activeInitializedClasses.contains(initClass.getClassValue())) {
+          DexType clazz = initClass.getClassValue();
+          if (activeState.isClassInitialized(clazz)) {
             it.removeOrReplaceByDebugLocalRead();
           }
+          activeState.markClassAsInitialized(clazz);
         } else if (instruction.isMonitor()) {
           if (instruction.asMonitor().isEnter()) {
             killAllNonFinalActiveFields();
@@ -288,7 +283,7 @@
               : "Unexpected instruction of type " + instruction.getClass().getTypeName();
         }
       }
-      propagateActiveStateFrom(block);
+      recordActiveStateOnBlockExit(block);
     }
     if (!affectedValues.isEmpty()) {
       new TypeAnalysis(appView).narrowing(affectedValues);
@@ -338,14 +333,13 @@
                 invoke.getArgument(info.asArgumentInitializationInfo().getArgumentIndex());
             Value object = invoke.getReceiver().getAliasedValue();
             FieldAndObject fieldAndObject = new FieldAndObject(field.field, object);
-            activeFieldValues.putNonFinalInstanceField(fieldAndObject, new ExistingValue(value));
+            activeState.putNonFinalInstanceField(fieldAndObject, new ExistingValue(value));
           } else if (info.isSingleValue()) {
             SingleValue value = info.asSingleValue();
             if (value.isMaterializableInContext(appView, method.holder())) {
               Value object = invoke.getReceiver().getAliasedValue();
               FieldAndObject fieldAndObject = new FieldAndObject(field.field, object);
-              activeFieldValues.putNonFinalInstanceField(
-                  fieldAndObject, new MaterializableValue(value));
+              activeState.putNonFinalInstanceField(fieldAndObject, new MaterializableValue(value));
             }
           } else {
             assert info.isTypeInitializationInfo();
@@ -353,33 +347,49 @@
         });
   }
 
-  private void propagateActiveStateFrom(BasicBlock block) {
-    for (BasicBlock successor : block.getSuccessors()) {
+  private void computeActiveStateOnBlockEntry(BasicBlock block) {
+    if (block.isEntry()) {
+      activeState = new State();
+      return;
+    }
+    Deque<State> predecessorExitStates = new ArrayDeque<>(block.getPredecessors().size());
+    for (BasicBlock predecessor : block.getPredecessors()) {
+      State predecessorExitState = activeStateAtExit.get(predecessor);
+      if (predecessorExitState == null) {
+        // Not processed yet.
+        activeState = new State();
+        return;
+      }
       // Allow propagation across exceptional edges, just be careful not to propagate if the
       // throwing instruction is a field instruction.
-      if (successor.getPredecessors().size() == 1) {
-        if (block.hasCatchSuccessor(successor)) {
-          Instruction exceptionalExit = block.exceptionalExit();
-          if (exceptionalExit != null) {
-            if (exceptionalExit.isFieldInstruction()) {
-              killActiveFieldsForExceptionalExit(exceptionalExit.asFieldInstruction());
-            } else if (exceptionalExit.isInitClass()) {
-              killActiveInitializedClassesForExceptionalExit(exceptionalExit.asInitClass());
-            }
+      if (predecessor.hasCatchSuccessor(block)) {
+        Instruction exceptionalExit = predecessor.exceptionalExit();
+        if (exceptionalExit != null) {
+          predecessorExitState = new State(predecessorExitState);
+          if (exceptionalExit.isFieldInstruction()) {
+            predecessorExitState.killActiveFieldsForExceptionalExit(
+                exceptionalExit.asFieldInstruction());
+          } else if (exceptionalExit.isInitClass()) {
+            predecessorExitState.killActiveInitializedClassesForExceptionalExit(
+                exceptionalExit.asInitClass());
           }
         }
-        assert !activeInitializedClassesAtEntry.containsKey(successor);
-        activeInitializedClassesAtEntry.put(
-            successor, SetUtils.newIdentityHashSet(activeInitializedClasses));
-        assert !activeFieldsAtEntry.containsKey(successor);
-        activeFieldsAtEntry.put(successor, new FieldValuesMap(activeFieldValues));
       }
+      predecessorExitStates.addLast(predecessorExitState);
     }
+    State state = new State(predecessorExitStates.removeFirst());
+    predecessorExitStates.forEach(state::intersect);
+    activeState = state;
+  }
+
+  private void recordActiveStateOnBlockExit(BasicBlock block) {
+    assert !activeStateAtExit.containsKey(block);
+    activeStateAtExit.put(block, activeState);
   }
 
   private void killAllNonFinalActiveFields() {
-    activeFieldValues.clearNonFinalInstanceFields();
-    activeFieldValues.clearNonFinalStaticFields();
+    activeState.clearNonFinalInstanceFields();
+    activeState.clearNonFinalStaticFields();
   }
 
   private void killNonFinalActiveFields(FieldInstruction instruction) {
@@ -387,62 +397,46 @@
     if (instruction.isInstancePut()) {
       // Remove all the field/object pairs that refer to this field to make sure
       // that we are conservative.
-      activeFieldValues.removeNonFinalInstanceFields(field);
+      activeState.removeNonFinalInstanceFields(field);
     } else if (instruction.isStaticPut()) {
       if (field.holder != code.method.holder()) {
         // Accessing a static field on a different object could cause <clinit> to run which
         // could modify any static field on any other object.
-        activeFieldValues.clearNonFinalStaticFields();
+        activeState.clearNonFinalStaticFields();
       } else {
-        activeFieldValues.removeNonFinalStaticField(field);
+        activeState.removeNonFinalStaticField(field);
       }
     } else if (instruction.isStaticGet()) {
       if (field.holder != code.method.holder()) {
         // Accessing a static field on a different object could cause <clinit> to run which
         // could modify any static field on any other object.
-        activeFieldValues.clearNonFinalStaticFields();
+        activeState.clearNonFinalStaticFields();
       }
     } else if (instruction.isInstanceGet()) {
       throw new Unreachable();
     }
   }
 
-  // If a field get instruction throws an exception it did not have an effect on the
-  // value of the field. Therefore, when propagating across exceptional edges for a
-  // field get instruction we have to exclude that field from the set of known
-  // field values.
-  private void killActiveFieldsForExceptionalExit(FieldInstruction instruction) {
-    DexField field = instruction.getField();
-    if (instruction.isInstanceGet()) {
-      Value object = instruction.asInstanceGet().object().getAliasedValue();
-      FieldAndObject fieldAndObject = new FieldAndObject(field, object);
-      activeFieldValues.removeInstanceField(fieldAndObject);
-    } else if (instruction.isStaticGet()) {
-      activeFieldValues.removeStaticField(field);
-    }
-  }
-
-  private void killActiveInitializedClassesForExceptionalExit(InitClass instruction) {
-    activeInitializedClasses.remove(instruction.getClassValue());
-  }
-
-  static class FieldValuesMap {
+  static class State {
 
     private final Map<FieldAndObject, FieldValue> finalInstanceFieldValues = new HashMap<>();
 
     private final Map<DexField, FieldValue> finalStaticFieldValues = new IdentityHashMap<>();
 
+    private final Set<DexType> initializedClasses = Sets.newIdentityHashSet();
+
     private final Map<FieldAndObject, FieldValue> nonFinalInstanceFieldValues = new HashMap<>();
 
     private final Map<DexField, FieldValue> nonFinalStaticFieldValues = new IdentityHashMap<>();
 
-    public FieldValuesMap() {}
+    public State() {}
 
-    public FieldValuesMap(FieldValuesMap map) {
-      finalInstanceFieldValues.putAll(map.finalInstanceFieldValues);
-      finalStaticFieldValues.putAll(map.finalStaticFieldValues);
-      nonFinalInstanceFieldValues.putAll(map.nonFinalInstanceFieldValues);
-      nonFinalStaticFieldValues.putAll(map.nonFinalStaticFieldValues);
+    public State(State state) {
+      finalInstanceFieldValues.putAll(state.finalInstanceFieldValues);
+      finalStaticFieldValues.putAll(state.finalStaticFieldValues);
+      initializedClasses.addAll(state.initializedClasses);
+      nonFinalInstanceFieldValues.putAll(state.nonFinalInstanceFieldValues);
+      nonFinalStaticFieldValues.putAll(state.nonFinalStaticFieldValues);
     }
 
     public void clearNonFinalInstanceFields() {
@@ -463,6 +457,50 @@
       return value != null ? value : finalStaticFieldValues.get(field);
     }
 
+    public void intersect(State state) {
+      intersectFieldValues(finalInstanceFieldValues, state.finalInstanceFieldValues);
+      intersectFieldValues(finalStaticFieldValues, state.finalStaticFieldValues);
+      intersectInitializedClasses(initializedClasses, state.initializedClasses);
+      intersectFieldValues(nonFinalInstanceFieldValues, state.nonFinalInstanceFieldValues);
+      intersectFieldValues(nonFinalStaticFieldValues, state.nonFinalStaticFieldValues);
+    }
+
+    private static <K> void intersectFieldValues(
+        Map<K, FieldValue> fieldValues, Map<K, FieldValue> other) {
+      fieldValues.entrySet().removeIf(entry -> other.get(entry.getKey()) != entry.getValue());
+    }
+
+    private static void intersectInitializedClasses(
+        Set<DexType> initializedClasses, Set<DexType> other) {
+      initializedClasses.removeIf(not(other::contains));
+    }
+
+    public boolean isClassInitialized(DexType clazz) {
+      return initializedClasses.contains(clazz);
+    }
+
+    // If a field get instruction throws an exception it did not have an effect on the value of the
+    // field. Therefore, when propagating across exceptional edges for a field get instruction we
+    // have to exclude that field from the set of known field values.
+    public void killActiveFieldsForExceptionalExit(FieldInstruction instruction) {
+      DexField field = instruction.getField();
+      if (instruction.isInstanceGet()) {
+        Value object = instruction.asInstanceGet().object().getAliasedValue();
+        FieldAndObject fieldAndObject = new FieldAndObject(field, object);
+        removeNonFinalInstanceField(fieldAndObject);
+      } else if (instruction.isStaticGet()) {
+        removeNonFinalStaticField(field);
+      }
+    }
+
+    private void killActiveInitializedClassesForExceptionalExit(InitClass instruction) {
+      initializedClasses.remove(instruction.getClassValue());
+    }
+
+    public void markClassAsInitialized(DexType clazz) {
+      initializedClasses.add(clazz);
+    }
+
     public void removeInstanceField(FieldAndObject field) {
       removeFinalInstanceField(field);
       removeNonFinalInstanceField(field);
diff --git a/src/main/java/com/android/tools/r8/utils/PredicateUtils.java b/src/main/java/com/android/tools/r8/utils/PredicateUtils.java
index 2c6fac0..880da5e 100644
--- a/src/main/java/com/android/tools/r8/utils/PredicateUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/PredicateUtils.java
@@ -16,4 +16,8 @@
     }
     return null;
   }
+
+  public static <T> Predicate<T> not(Predicate<T> predicate) {
+    return t -> !predicate.test(t);
+  }
 }
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/redundantfieldloadelimination/RedundantFieldLoadEliminationMeetTest.java b/src/test/java/com/android/tools/r8/ir/optimize/redundantfieldloadelimination/RedundantFieldLoadEliminationMeetTest.java
index 4ee4380..776df6b 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/redundantfieldloadelimination/RedundantFieldLoadEliminationMeetTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/redundantfieldloadelimination/RedundantFieldLoadEliminationMeetTest.java
@@ -48,9 +48,8 @@
 
     MethodSubject mainMethodSubject = testClassSubject.mainMethod();
     assertThat(mainMethodSubject, isPresent());
-    // TODO(b/152280793): Should be 1.
     assertEquals(
-        2, mainMethodSubject.streamInstructions().filter(InstructionSubject::isStaticGet).count());
+        1, mainMethodSubject.streamInstructions().filter(InstructionSubject::isStaticGet).count());
   }
 
   static class TestClass {
diff --git a/src/test/java/com/android/tools/r8/regress/b111250398/B111250398.java b/src/test/java/com/android/tools/r8/regress/b111250398/B111250398.java
index 32c754e..0556439 100644
--- a/src/test/java/com/android/tools/r8/regress/b111250398/B111250398.java
+++ b/src/test/java/com/android/tools/r8/regress/b111250398/B111250398.java
@@ -276,12 +276,8 @@
     // compilation (R8) will eliminate field loads on non-volatile fields.
     assertEquals(1, countIget(mfOnA.getMethod().getCode().asDexCode(), fOnA.getField().field));
     assertEquals(1, countSget(msfOnA.getMethod().getCode().asDexCode(), sfOnA.getField().field));
-    // TODO(111380066). This could be 2 in stead of 4, but right now the optimization tracks the
-    // combined set of fields for all successors, and for synchronized code all blocks have
-    // exceptional edges for ensuring monitor exit causing the active load to be invalidated for
-    // both normal and exceptional successors.
-    assertEquals(4,
-        countIget(mfWithMonitorOnA.getMethod().getCode().asDexCode(), fOnA.getField().field));
+    assertEquals(
+        2, countIget(mfWithMonitorOnA.getMethod().getCode().asDexCode(), fOnA.getField().field));
 
     // For fields on other class both separate compilation (D8) and whole program
     // compilation (R8) will differ in the eliminated field loads of non-volatile fields.