Remove references to dead builders

Change-Id: I15da928efe73e0ab6b39e55aecdda05bde8578df
Bug: 149363884
diff --git a/src/main/java/com/android/tools/r8/R8.java b/src/main/java/com/android/tools/r8/R8.java
index c036f5f..8817968 100644
--- a/src/main/java/com/android/tools/r8/R8.java
+++ b/src/main/java/com/android/tools/r8/R8.java
@@ -25,6 +25,7 @@
 import com.android.tools.r8.graph.DexReference;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.DirectMappedDexApplication;
+import com.android.tools.r8.graph.EnumValueInfoMapCollection;
 import com.android.tools.r8.graph.GraphLense;
 import com.android.tools.r8.graph.GraphLense.NestedGraphLense;
 import com.android.tools.r8.graph.analysis.ClassInitializerAssertionEnablingAnalysis;
@@ -561,6 +562,10 @@
       // Collect the already pruned types before creating a new app info without liveness.
       Set<DexType> prunedTypes = appView.withLiveness().appInfo().getPrunedTypes();
 
+      // TODO: move to appview.
+      EnumValueInfoMapCollection enumValueInfoMapCollection =
+          appViewWithLiveness.appInfo().getEnumValueInfoMapCollection();
+
       if (!options.mainDexKeepRules.isEmpty()) {
         appView.setAppInfo(new AppInfoWithSubtyping(application));
         // No need to build a new main dex root set
@@ -625,11 +630,18 @@
 
           Enqueuer enqueuer = EnqueuerFactory.createForFinalTreeShaking(appView, keptGraphConsumer);
           appView.setAppInfo(
-              enqueuer.traceApplication(
-                  appView.rootSet(),
-                  options.getProguardConfiguration().getDontWarnPatterns(),
-                  executorService,
-                  timing));
+              enqueuer
+                  .traceApplication(
+                      appView.rootSet(),
+                      options.getProguardConfiguration().getDontWarnPatterns(),
+                      executorService,
+                      timing)
+                  .withEnumValueInfoMaps(enumValueInfoMapCollection));
+
+          appView.withGeneratedMessageLiteBuilderShrinker(
+              shrinker ->
+                  shrinker.removeDeadBuilderReferencesFromDynamicMethods(
+                      appViewWithLiveness, executorService, timing));
 
           if (Log.ENABLED && Log.isLoggingEnabledFor(GeneratedExtensionRegistryShrinker.class)) {
             appView.withGeneratedExtensionRegistryShrinker(
diff --git a/src/main/java/com/android/tools/r8/graph/AppView.java b/src/main/java/com/android/tools/r8/graph/AppView.java
index 401173c..9484c21 100644
--- a/src/main/java/com/android/tools/r8/graph/AppView.java
+++ b/src/main/java/com/android/tools/r8/graph/AppView.java
@@ -22,6 +22,7 @@
 import com.android.tools.r8.shaking.RootSetBuilder.RootSet;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.OptionalBool;
+import com.android.tools.r8.utils.ThrowingConsumer;
 import com.google.common.base.Predicates;
 import com.google.common.collect.ImmutableSet;
 import java.util.IdentityHashMap;
@@ -267,8 +268,8 @@
     }
   }
 
-  public void withGeneratedMessageLiteBuilderShrinker(
-      Consumer<GeneratedMessageLiteBuilderShrinker> consumer) {
+  public <E extends Throwable> void withGeneratedMessageLiteBuilderShrinker(
+      ThrowingConsumer<GeneratedMessageLiteBuilderShrinker, E> consumer) throws E {
     if (protoShrinker != null && protoShrinker.generatedMessageLiteBuilderShrinker != null) {
       consumer.accept(protoShrinker.generatedMessageLiteBuilderShrinker);
     }
diff --git a/src/main/java/com/android/tools/r8/graph/DexProto.java b/src/main/java/com/android/tools/r8/graph/DexProto.java
index 5c05605..43d6a29 100644
--- a/src/main/java/com/android/tools/r8/graph/DexProto.java
+++ b/src/main/java/com/android/tools/r8/graph/DexProto.java
@@ -20,6 +20,10 @@
     this.parameters = parameters;
   }
 
+  public DexType getParameter(int index) {
+    return parameters.values[index];
+  }
+
   @Override
   public int computeHashCode() {
     return shorty.hashCode()
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java
index 717df30..ed5de69 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java
@@ -11,36 +11,51 @@
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.EnumValueInfoMapCollection.EnumValueInfo;
+import com.android.tools.r8.graph.EnumValueInfoMapCollection.EnumValueInfoMap;
 import com.android.tools.r8.ir.analysis.type.ClassTypeLatticeElement;
 import com.android.tools.r8.ir.analysis.type.TypeAnalysis;
+import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
 import com.android.tools.r8.ir.code.CheckCast;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.ir.code.InstructionListIterator;
+import com.android.tools.r8.ir.code.IntSwitch;
 import com.android.tools.r8.ir.code.InvokeVirtual;
 import com.android.tools.r8.ir.code.Value;
 import com.android.tools.r8.ir.conversion.CallGraph.Node;
+import com.android.tools.r8.ir.conversion.IRConverter;
 import com.android.tools.r8.ir.conversion.MethodProcessor;
+import com.android.tools.r8.ir.optimize.CodeRewriter;
 import com.android.tools.r8.ir.optimize.Inliner;
 import com.android.tools.r8.ir.optimize.Inliner.Reason;
+import com.android.tools.r8.ir.optimize.controlflow.SwitchCaseAnalyzer;
 import com.android.tools.r8.ir.optimize.enums.EnumValueOptimizer;
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedback;
+import com.android.tools.r8.ir.optimize.info.OptimizationFeedbackSimple;
 import com.android.tools.r8.ir.optimize.inliner.FixedInliningReasonStrategy;
+import com.android.tools.r8.origin.Origin;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.PredicateSet;
+import com.android.tools.r8.utils.ThreadUtils;
+import com.android.tools.r8.utils.Timing;
 import com.google.common.collect.Sets;
 import java.util.ArrayList;
+import java.util.IdentityHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
 import java.util.function.BooleanSupplier;
 
-// TODO(b/112437944): Should remove the new Builder() instructions from each dynamicMethod() that
-//  references a dead proto builder.
 public class GeneratedMessageLiteBuilderShrinker {
 
   private final AppView<? extends AppInfoWithSubtyping> appView;
   private final ProtoReferences references;
 
+  private final Map<DexProgramClass, DexEncodedMethod> builders = new IdentityHashMap<>();
+
   GeneratedMessageLiteBuilderShrinker(
       AppView<? extends AppInfoWithSubtyping> appView, ProtoReferences references) {
     this.appView = appView;
@@ -51,11 +66,55 @@
   public boolean deferDeadProtoBuilders(
       DexProgramClass clazz, DexEncodedMethod context, BooleanSupplier register) {
     if (references.isDynamicMethod(context) && references.isGeneratedMessageLiteBuilder(clazz)) {
-      return register.getAsBoolean();
+      if (register.getAsBoolean()) {
+        assert builders.getOrDefault(clazz, context) == context;
+        builders.put(clazz, context);
+        return true;
+      }
     }
     return false;
   }
 
+  /**
+   * Reprocesses each dynamicMethod() that references a dead builder to remove the dead builder
+   * references.
+   */
+  public void removeDeadBuilderReferencesFromDynamicMethods(
+      AppView<AppInfoWithLiveness> appView, ExecutorService executorService, Timing timing)
+      throws ExecutionException {
+    timing.begin("Remove dead builder references");
+    AppInfoWithLiveness appInfo = appView.appInfo();
+    IRConverter converter = new IRConverter(appView, Timing.empty());
+    CodeRewriter codeRewriter = new CodeRewriter(appView, converter);
+    MethodToInvokeSwitchCaseAnalyzer switchCaseAnalyzer = new MethodToInvokeSwitchCaseAnalyzer();
+    if (switchCaseAnalyzer.isInitialized()) {
+      ThreadUtils.processItems(
+          builders.entrySet(),
+          entry -> {
+            if (!appInfo.isLiveProgramClass(entry.getKey())) {
+              removeDeadBuilderReferencesFromDynamicMethod(
+                  appView, entry.getValue(), converter, codeRewriter, switchCaseAnalyzer);
+            }
+          },
+          executorService);
+    }
+    builders.clear();
+    timing.end(); // Remove dead builder references
+  }
+
+  private void removeDeadBuilderReferencesFromDynamicMethod(
+      AppView<AppInfoWithLiveness> appView,
+      DexEncodedMethod dynamicMethod,
+      IRConverter converter,
+      CodeRewriter codeRewriter,
+      SwitchCaseAnalyzer switchCaseAnalyzer) {
+    Origin origin = appView.appInfo().originFor(dynamicMethod.holder());
+    IRCode code = dynamicMethod.buildIR(appView, origin);
+    codeRewriter.rewriteSwitch(code, switchCaseAnalyzer);
+    converter.removeDeadCodeAndFinalizeIR(
+        dynamicMethod, code, OptimizationFeedbackSimple.getInstance(), Timing.empty());
+  }
+
   public static void addInliningHeuristicsForBuilderInlining(
       AppView<? extends AppInfoWithSubtyping> appView,
       PredicateSet<DexType> alwaysClassInline,
@@ -177,6 +236,47 @@
     }
   }
 
+  private class MethodToInvokeSwitchCaseAnalyzer extends SwitchCaseAnalyzer {
+
+    private final int newBuilderOrdinal;
+
+    private MethodToInvokeSwitchCaseAnalyzer() {
+      EnumValueInfoMap enumValueInfoMap =
+          appView.appInfo().withLiveness().getEnumValueInfoMap(references.methodToInvokeType);
+      if (enumValueInfoMap != null) {
+        EnumValueInfo newBuilderValueInfo =
+            enumValueInfoMap.getEnumValueInfo(references.methodToInvokeMembers.newBuilderField);
+        if (newBuilderValueInfo != null) {
+          newBuilderOrdinal = newBuilderValueInfo.ordinal;
+          return;
+        }
+      }
+      newBuilderOrdinal = -1;
+    }
+
+    public boolean isInitialized() {
+      return newBuilderOrdinal >= 0;
+    }
+
+    @Override
+    public boolean switchCaseIsUnreachable(IntSwitch theSwitch, int index) {
+      if (index != newBuilderOrdinal) {
+        return false;
+      }
+      Value switchValue = theSwitch.value();
+      if (!switchValue.isDefinedByInstructionSatisfying(Instruction::isInvokeVirtual)) {
+        return false;
+      }
+      InvokeVirtual definition = switchValue.definition.asInvokeVirtual();
+      if (definition.getInvokedMethod() != appView.dexItemFactory().enumMethods.ordinal) {
+        return false;
+      }
+      TypeLatticeElement enumType = definition.getReceiver().getTypeLattice();
+      return enumType.isClassType()
+          && enumType.asClassTypeLatticeElement().getClassType() == references.methodToInvokeType;
+    }
+  }
+
   private static class RootSetExtension {
 
     private final AppView<? extends AppInfoWithSubtyping> appView;
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index d8d503d..f77eb9e 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -73,6 +73,7 @@
 import com.android.tools.r8.ir.code.ValueType;
 import com.android.tools.r8.ir.code.Xor;
 import com.android.tools.r8.ir.conversion.IRConverter;
+import com.android.tools.r8.ir.optimize.controlflow.SwitchCaseAnalyzer;
 import com.android.tools.r8.ir.regalloc.LinearScanRegisterAllocator;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.InternalOutputMode;
@@ -843,7 +844,11 @@
     return outliersAsIfSize;
   }
 
-  private boolean rewriteSwitch(IRCode code) {
+  public boolean rewriteSwitch(IRCode code) {
+    return rewriteSwitch(code, SwitchCaseAnalyzer.getInstance());
+  }
+
+  public boolean rewriteSwitch(IRCode code, SwitchCaseAnalyzer switchCaseAnalyzer) {
     if (!code.metadata().mayHaveIntSwitch()) {
       return false;
     }
@@ -859,7 +864,7 @@
           IntSwitch theSwitch = instruction.asIntSwitch();
           if (options.testing.enableDeadSwitchCaseElimination) {
             SwitchCaseEliminator eliminator =
-                removeUnnecessarySwitchCases(code, theSwitch, iterator);
+                removeUnnecessarySwitchCases(code, theSwitch, iterator, switchCaseAnalyzer);
             if (eliminator != null) {
               if (eliminator.mayHaveIntroducedUnreachableBlocks()) {
                 needToRemoveUnreachableBlocks = true;
@@ -1004,7 +1009,10 @@
   }
 
   private SwitchCaseEliminator removeUnnecessarySwitchCases(
-      IRCode code, IntSwitch theSwitch, InstructionListIterator iterator) {
+      IRCode code,
+      IntSwitch theSwitch,
+      InstructionListIterator iterator,
+      SwitchCaseAnalyzer switchCaseAnalyzer) {
     BasicBlock defaultTarget = theSwitch.fallthroughBlock();
     SwitchCaseEliminator eliminator = null;
     BasicBlockBehavioralSubsumption behavioralSubsumption =
@@ -1016,7 +1024,7 @@
 
       // This switch case can be removed if the behavior of the target block is equivalent to the
       // behavior of the default block, or if the switch case is unreachable.
-      if (switchCaseIsUnreachable(theSwitch, i)
+      if (switchCaseAnalyzer.switchCaseIsUnreachable(theSwitch, i)
           || behavioralSubsumption.isSubsumedBy(targetBlock, defaultTarget)) {
         if (eliminator == null) {
           eliminator = new SwitchCaseEliminator(theSwitch, iterator);
@@ -1030,12 +1038,6 @@
     return eliminator;
   }
 
-  private boolean switchCaseIsUnreachable(IntSwitch theSwitch, int index) {
-    Value switchValue = theSwitch.value();
-    return switchValue.hasValueRange()
-        && !switchValue.getValueRange().containsValue(theSwitch.getKey(index));
-  }
-
   /**
    * Rewrite all branch targets to the destination of trivial goto chains when possible. Does not
    * rewrite fallthrough targets as that would require block reordering and the transformation only
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/controlflow/SwitchCaseAnalyzer.java b/src/main/java/com/android/tools/r8/ir/optimize/controlflow/SwitchCaseAnalyzer.java
new file mode 100644
index 0000000..6c22a72
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/optimize/controlflow/SwitchCaseAnalyzer.java
@@ -0,0 +1,25 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.controlflow;
+
+import com.android.tools.r8.ir.code.IntSwitch;
+import com.android.tools.r8.ir.code.Value;
+
+public class SwitchCaseAnalyzer {
+
+  private static final SwitchCaseAnalyzer INSTANCE = new SwitchCaseAnalyzer();
+
+  public SwitchCaseAnalyzer() {}
+
+  public static SwitchCaseAnalyzer getInstance() {
+    return INSTANCE;
+  }
+
+  public boolean switchCaseIsUnreachable(IntSwitch theSwitch, int index) {
+    Value switchValue = theSwitch.value();
+    return switchValue.hasValueRange()
+        && !switchValue.getValueRange().containsValue(theSwitch.getKey(index));
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
index 61ed827..4234043 100644
--- a/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
+++ b/src/main/java/com/android/tools/r8/shaking/AppInfoWithLiveness.java
@@ -698,6 +698,11 @@
     return result;
   }
 
+  public EnumValueInfoMapCollection getEnumValueInfoMapCollection() {
+    assert checkIfObsolete();
+    return enumValueInfoMaps;
+  }
+
   public EnumValueInfoMap getEnumValueInfoMap(DexType enumType) {
     assert checkIfObsolete();
     return enumValueInfoMaps.getEnumValueInfoMap(enumType);