Forcefully remove assignments to dead proto extension fields

Bug: 143588134
Change-Id: Ib7f275b066bf9cfdc742c8e19170dd209f386a0c
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedExtensionRegistryShrinker.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedExtensionRegistryShrinker.java
index 624cf90..e84cd11 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedExtensionRegistryShrinker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedExtensionRegistryShrinker.java
@@ -14,8 +14,13 @@
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.FieldAccessInfo;
 import com.android.tools.r8.graph.FieldAccessInfoCollection;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.IRCodeUtils;
+import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.ir.code.StaticPut;
 import com.android.tools.r8.ir.conversion.IRConverter;
 import com.android.tools.r8.ir.conversion.OneTimeMethodProcessor;
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedbackIgnore;
@@ -27,6 +32,8 @@
 import com.google.common.collect.Sets;
 import java.io.IOException;
 import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Set;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
@@ -66,6 +73,7 @@
   private final AppView<AppInfoWithLiveness> appView;
   private final ProtoReferences references;
 
+  private final Set<DexType> classesWithRemovedExtensionFields = Sets.newIdentityHashSet();
   private final Set<DexField> removedExtensionFields = Sets.newIdentityHashSet();
 
   GeneratedExtensionRegistryShrinker(
@@ -81,7 +89,41 @@
    * const-null.
    */
   public void run() {
-    forEachDeadProtoExtensionField(removedExtensionFields::add);
+    forEachDeadProtoExtensionField(this::recordDeadProtoExtensionField);
+  }
+
+  private void recordDeadProtoExtensionField(DexField field) {
+    classesWithRemovedExtensionFields.add(field.holder);
+    removedExtensionFields.add(field);
+  }
+
+  /**
+   * If {@param method} is a class initializer that initializes a dead proto extension field, then
+   * forcefully remove the field assignment and all the code that contributes to the initialization
+   * of the value of the field assignment.
+   */
+  public void rewriteCode(DexEncodedMethod method, IRCode code) {
+    if (method.isClassInitializer()
+        && classesWithRemovedExtensionFields.contains(method.method.holder)
+        && code.metadata().mayHaveStaticPut()) {
+      rewriteClassInitializer(code);
+    }
+  }
+
+  private void rewriteClassInitializer(IRCode code) {
+    List<StaticPut> toBeRemoved = new ArrayList<>();
+    for (StaticPut staticPut : code.<StaticPut>instructions(Instruction::isStaticPut)) {
+      if (removedExtensionFields.contains(staticPut.getField())) {
+        toBeRemoved.add(staticPut);
+      }
+    }
+    for (StaticPut instruction : toBeRemoved) {
+      if (!instruction.hasBlock()) {
+        // Already removed.
+        continue;
+      }
+      IRCodeUtils.removeInstructionAndTransitiveInputsIfNotUsed(code, instruction);
+    }
   }
 
   public boolean wasRemoved(DexField field) {
diff --git a/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java b/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java
new file mode 100644
index 0000000..dfc3f0e
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java
@@ -0,0 +1,60 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.code;
+
+import com.android.tools.r8.utils.DequeUtils;
+import com.google.common.collect.Sets;
+import java.util.Deque;
+import java.util.Set;
+
+public class IRCodeUtils {
+
+  /**
+   * Removes the given instruction and all the instructions that are used to define the in-values of
+   * the given instruction, even if the instructions may have side effects (!).
+   *
+   * <p>Use with caution!
+   */
+  public static void removeInstructionAndTransitiveInputsIfNotUsed(
+      IRCode code, Instruction instruction) {
+    Set<InstructionOrPhi> removed = Sets.newIdentityHashSet();
+    Deque<InstructionOrPhi> worklist = DequeUtils.newArrayDeque(instruction);
+    while (!worklist.isEmpty()) {
+      InstructionOrPhi instructionOrPhi = worklist.removeFirst();
+      if (removed.contains(instructionOrPhi)) {
+        // Already removed.
+        continue;
+      }
+      if (instructionOrPhi.isPhi()) {
+        Phi current = instructionOrPhi.asPhi();
+        if (!current.hasUsers() && !current.hasDebugUsers()) {
+          boolean hasOtherPhiUserThanSelf = false;
+          for (Phi phiUser : current.uniquePhiUsers()) {
+            if (phiUser != current) {
+              hasOtherPhiUserThanSelf = true;
+              break;
+            }
+          }
+          if (!hasOtherPhiUserThanSelf) {
+            current.removeDeadPhi();
+            for (Value operand : current.getOperands()) {
+              worklist.add(operand.isPhi() ? operand.asPhi() : operand.definition);
+            }
+            removed.add(current);
+          }
+        }
+      } else {
+        Instruction current = instructionOrPhi.asInstruction();
+        if (!current.hasOutValue() || !current.outValue().hasAnyUsers()) {
+          current.getBlock().listIterator(code, current).removeOrReplaceByDebugLocalRead();
+          for (Value inValue : current.inValues()) {
+            worklist.add(inValue.isPhi() ? inValue.asPhi() : inValue.definition);
+          }
+          removed.add(current);
+        }
+      }
+    }
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/code/Instruction.java b/src/main/java/com/android/tools/r8/ir/code/Instruction.java
index 1129616..8836361 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Instruction.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Instruction.java
@@ -604,6 +604,11 @@
     return true;
   }
 
+  @Override
+  public Instruction asInstruction() {
+    return this;
+  }
+
   public boolean isArrayGet() {
     return false;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/code/InstructionOrPhi.java b/src/main/java/com/android/tools/r8/ir/code/InstructionOrPhi.java
index 10110a0..e9a5347 100644
--- a/src/main/java/com/android/tools/r8/ir/code/InstructionOrPhi.java
+++ b/src/main/java/com/android/tools/r8/ir/code/InstructionOrPhi.java
@@ -10,7 +10,15 @@
     return false;
   }
 
+  default Instruction asInstruction() {
+    return null;
+  }
+
   default boolean isPhi() {
     return false;
   }
+
+  default Phi asPhi() {
+    return null;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/code/Phi.java b/src/main/java/com/android/tools/r8/ir/code/Phi.java
index 47eb9d4..3252d67 100644
--- a/src/main/java/com/android/tools/r8/ir/code/Phi.java
+++ b/src/main/java/com/android/tools/r8/ir/code/Phi.java
@@ -224,11 +224,11 @@
     return true;
   }
 
-  public void removeTrivialPhi() {
-    removeTrivialPhi(null, null);
+  public boolean removeTrivialPhi() {
+    return removeTrivialPhi(null, null);
   }
 
-  public void removeTrivialPhi(IRBuilder builder, Set<Value> affectedValues) {
+  public boolean removeTrivialPhi(IRBuilder builder, Set<Value> affectedValues) {
     Value same = null;
     for (Value op : operands) {
       if (op == same || op == this) {
@@ -238,7 +238,7 @@
       if (same != null) {
         // Merged at least two values and is therefore not trivial.
         assert !isTrivialPhi();
-        return;
+        return false;
       }
       same = op;
     }
@@ -247,7 +247,7 @@
       // When doing if-simplification we remove blocks and we can end up with cyclic phis
       // of the form v1 = phi(v1, v1) in dead blocks. If we encounter that case we just
       // leave the phi in there and check at the end that there are no trivial phis.
-      return;
+      return false;
     }
     // Ensure that the value that replaces this phi is constrained to the type of this phi.
     if (builder != null && typeLattice.isPreciseType() && !typeLattice.isBottom()) {
@@ -286,6 +286,7 @@
     }
     // Get rid of the phi itself.
     block.removePhi(this);
+    return true;
   }
 
   public void removeDeadPhi() {
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 358a7fe..8a93298 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -1154,6 +1154,10 @@
 
     previous = printMethod(code, "IR after inserting assume instructions (SSA)", previous);
 
+    appView.withGeneratedExtensionRegistryShrinker(shrinker -> shrinker.rewriteCode(method, code));
+
+    previous = printMethod(code, "IR after generated extension registry shrinking (SSA)", previous);
+
     appView.withGeneratedMessageLiteShrinker(shrinker -> shrinker.run(method, code));
 
     previous = printMethod(code, "IR after generated message lite shrinking (SSA)", previous);
diff --git a/src/test/java/com/android/tools/r8/internal/proto/Proto2ShrinkingTest.java b/src/test/java/com/android/tools/r8/internal/proto/Proto2ShrinkingTest.java
index dcba861..bd9c192 100644
--- a/src/test/java/com/android/tools/r8/internal/proto/Proto2ShrinkingTest.java
+++ b/src/test/java/com/android/tools/r8/internal/proto/Proto2ShrinkingTest.java
@@ -59,7 +59,7 @@
     return buildParameters(
         BooleanUtils.values(),
         BooleanUtils.values(),
-        getTestParameters().withAllRuntimes().build());
+        getTestParameters().withAllRuntimesAndApiLevels().build());
   }
 
   public Proto2ShrinkingTest(
@@ -78,12 +78,6 @@
             .addKeepMainRule("proto2.TestClass")
             .addKeepRuleFiles(PROTOBUF_LITE_PROGUARD_RULES)
             .addKeepRules(alwaysInlineNewSingularGeneratedExtensionRule())
-            // TODO(b/112437944): Attempt to prove that DEFAULT_INSTANCE is non-null, such that the
-            //  following "assumenotnull" rule can be omitted.
-            .addKeepRules(
-                "-assumenosideeffects class " + EXT_C + " {",
-                "  private static final " + EXT_C + " DEFAULT_INSTANCE return 1..42;",
-                "}")
             .addOptionsModification(
                 options -> {
                   options.enableFieldBitAccessAnalysis = true;
@@ -94,7 +88,7 @@
             .allowAccessModification(allowAccessModification)
             .allowUnusedProguardConfigurationRules()
             .minification(enableMinification)
-            .setMinApi(parameters.getRuntime())
+            .setMinApi(parameters.getApiLevel())
             .compile()
             .inspect(
                 outputInspector -> {
@@ -245,8 +239,8 @@
           ImmutableList.of(FLAGGED_OFF_EXTENSION, HAS_NO_USED_EXTENSIONS, EXT_B, EXT_C);
       for (String unusedExtensionName : unusedExtensionNames) {
         assertThat(inputInspector.clazz(unusedExtensionName), isPresent());
-        // TODO(b/143588134): Re-enable this assertion.
-        // assertThat(outputInspector.clazz(unusedExtensionName), not(isPresent()));
+        assertThat(
+            unusedExtensionName, outputInspector.clazz(unusedExtensionName), not(isPresent()));
       }
     }
   }
@@ -362,7 +356,7 @@
         .allowAccessModification(allowAccessModification)
         .allowUnusedProguardConfigurationRules()
         .minification(enableMinification)
-        .setMinApi(parameters.getRuntime())
+        .setMinApi(parameters.getApiLevel())
         .compile()
         .inspect(
             inspector ->
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/AbsentClassSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/AbsentClassSubject.java
index e769a54..8e2f554 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/AbsentClassSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/AbsentClassSubject.java
@@ -49,6 +49,11 @@
   }
 
   @Override
+  public FieldSubject uniqueFieldWithFinalName(String name) {
+    return new AbsentFieldSubject();
+  }
+
+  @Override
   public boolean isAbstract() {
     throw new Unreachable("Cannot determine if an absent class is abstract");
   }
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/ClassSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/ClassSubject.java
index 4ed548e..8b8f519 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/ClassSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/ClassSubject.java
@@ -123,6 +123,8 @@
 
   public abstract FieldSubject uniqueFieldWithName(String name);
 
+  public abstract FieldSubject uniqueFieldWithFinalName(String name);
+
   public FoundClassSubject asFoundClassSubject() {
     return null;
   }
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/FieldSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/FieldSubject.java
index c182f74..dc9940f 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/FieldSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/FieldSubject.java
@@ -26,6 +26,10 @@
     return this;
   }
 
+  public FoundFieldSubject asFoundFieldSubject() {
+    return null;
+  }
+
   @Override
   public boolean isFieldSubject() {
     return true;
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/FoundClassSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/FoundClassSubject.java
index d7e46b1..45276f4 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/FoundClassSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/FoundClassSubject.java
@@ -171,6 +171,18 @@
   }
 
   @Override
+  public FieldSubject uniqueFieldWithFinalName(String name) {
+    FieldSubject fieldSubject = null;
+    for (FoundFieldSubject candidate : allFields()) {
+      if (candidate.getFinalName().equals(name)) {
+        assert fieldSubject == null;
+        fieldSubject = candidate;
+      }
+    }
+    return fieldSubject != null ? fieldSubject : new AbsentFieldSubject();
+  }
+
+  @Override
   public FoundClassSubject asFoundClassSubject() {
     return this;
   }
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/FoundFieldSubject.java b/src/test/java/com/android/tools/r8/utils/codeinspector/FoundFieldSubject.java
index a4d8210..f659abe 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/FoundFieldSubject.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/FoundFieldSubject.java
@@ -27,6 +27,11 @@
   }
 
   @Override
+  public FoundFieldSubject asFoundFieldSubject() {
+    return this;
+  }
+
+  @Override
   public boolean isPublic() {
     return dexField.accessFlags.isPublic();
   }
diff --git a/src/test/java/com/android/tools/r8/utils/codeinspector/analysis/ProtoApplicationStats.java b/src/test/java/com/android/tools/r8/utils/codeinspector/analysis/ProtoApplicationStats.java
index 4a0f63d..8e65495 100644
--- a/src/test/java/com/android/tools/r8/utils/codeinspector/analysis/ProtoApplicationStats.java
+++ b/src/test/java/com/android/tools/r8/utils/codeinspector/analysis/ProtoApplicationStats.java
@@ -11,6 +11,7 @@
 import com.android.tools.r8.utils.StringUtils;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.FieldSubject;
 import com.android.tools.r8.utils.codeinspector.FoundClassSubject;
 import com.android.tools.r8.utils.codeinspector.FoundFieldSubject;
 import com.android.tools.r8.utils.codeinspector.FoundMethodSubject;
@@ -99,8 +100,8 @@
   class GeneratedExtensionRegistryStats extends Stats {
 
     final Set<DexMethod> findLiteExtensionByNumberMethods = Sets.newIdentityHashSet();
-    final Set<DexType> retainedExtensions = Sets.newIdentityHashSet();
-    final Set<DexType> spuriouslyRetainedExtensions = Sets.newIdentityHashSet();
+    final Set<DexField> retainedExtensionFields = Sets.newIdentityHashSet();
+    final Set<DexField> spuriouslyRetainedExtensionFields = Sets.newIdentityHashSet();
 
     String getStats(
         GeneratedExtensionRegistryStats baseline, GeneratedExtensionRegistryStats original) {
@@ -109,8 +110,8 @@
           "  # findLiteExtensionByNumber() methods: "
               + progress(this, baseline, original, x -> x.findLiteExtensionByNumberMethods),
           "  # retained extensions: "
-              + progress(this, baseline, original, x -> x.retainedExtensions),
-          "  # spuriously retained extensions: " + spuriouslyRetainedExtensions.size());
+              + progress(this, baseline, original, x -> x.retainedExtensionFields),
+          "  # spuriously retained extension fields: " + spuriouslyRetainedExtensionFields.size());
     }
   }
 
@@ -181,8 +182,12 @@
                 }
                 FoundClassSubject extensionClassSubject =
                     inspector.clazz(field.holder.toSourceString()).asFoundClassSubject();
-                generatedExtensionRegistryStats.retainedExtensions.add(
-                    extensionClassSubject.getOriginalDexType(dexItemFactory));
+                FoundFieldSubject extensionFieldSubject =
+                    extensionClassSubject
+                        .uniqueFieldWithFinalName(field.name.toSourceString())
+                        .asFoundFieldSubject();
+                generatedExtensionRegistryStats.retainedExtensionFields.add(
+                    extensionFieldSubject.getOriginalDexField(dexItemFactory));
               }
             }
           }
@@ -203,10 +208,19 @@
     }
 
     if (original != null) {
-      for (DexType extensionType : original.generatedExtensionRegistryStats.retainedExtensions) {
-        if (!generatedExtensionRegistryStats.retainedExtensions.contains(extensionType)
-            && inspector.clazz(extensionType.toSourceString()).isPresent()) {
-          generatedExtensionRegistryStats.spuriouslyRetainedExtensions.add(extensionType);
+      for (DexField extensionField :
+          original.generatedExtensionRegistryStats.retainedExtensionFields) {
+        if (generatedExtensionRegistryStats.retainedExtensionFields.contains(extensionField)) {
+          continue;
+        }
+        ClassSubject classSubject = inspector.clazz(extensionField.holder.toSourceString());
+        if (!classSubject.isPresent()) {
+          continue;
+        }
+        FieldSubject fieldSubject =
+            classSubject.uniqueFieldWithName(extensionField.name.toSourceString());
+        if (fieldSubject.isPresent()) {
+          generatedExtensionRegistryStats.spuriouslyRetainedExtensionFields.add(extensionField);
         }
       }
     }