Register field accesses to stateless lambda singleton fields

Bug: 165229577
Change-Id: I559538c6c058760094bea441f88d91f0cfc5ee16
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java b/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
index fd4eede..d27fb16 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/LambdaRewriter.java
@@ -18,7 +18,6 @@
 import com.android.tools.r8.graph.CfCode;
 import com.android.tools.r8.graph.DexApplication.Builder;
 import com.android.tools.r8.graph.DexCallSite;
-import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexItemFactory;
 import com.android.tools.r8.graph.DexMethod;
@@ -133,7 +132,7 @@
    */
   public int desugarLambdas(ProgramMethod method, AppInfoWithClassHierarchy appInfo) {
     return desugarLambdas(
-        method.getDefinition(),
+        method,
         callsite -> {
           LambdaDescriptor descriptor = LambdaDescriptor.tryInfer(callsite, appInfo, method);
           if (descriptor == null) {
@@ -145,8 +144,8 @@
 
   // Same as above, but where lambdas are always known to exist for the call sites.
   public static int desugarLambdas(
-      DexEncodedMethod method, Function<DexCallSite, LambdaClass> callSites) {
-    CfCode code = method.getCode().asCfCode();
+      ProgramMethod method, Function<DexCallSite, LambdaClass> callSites) {
+    CfCode code = method.getDefinition().getCode().asCfCode();
     List<CfInstruction> instructions = code.getInstructions();
     Supplier<List<CfInstruction>> lazyNewInstructions =
         Suppliers.memoize(() -> new ArrayList<>(instructions));
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index f6b5aa2..dd08886 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -115,6 +115,7 @@
 import com.android.tools.r8.utils.Visibility;
 import com.android.tools.r8.utils.WorkList;
 import com.android.tools.r8.utils.collections.ProgramFieldSet;
+import com.android.tools.r8.utils.collections.ProgramMethodMap;
 import com.android.tools.r8.utils.collections.ProgramMethodSet;
 import com.google.common.base.Equivalence.Wrapper;
 import com.google.common.collect.ImmutableSet;
@@ -364,8 +365,8 @@
   private final DesugaredLibraryConversionWrapperAnalysis desugaredLibraryWrapperAnalysis;
   private final Map<DexType, Pair<LambdaClass, ProgramMethod>> lambdaClasses =
       new IdentityHashMap<>();
-  private final Map<DexEncodedMethod, Map<DexCallSite, LambdaClass>> lambdaCallSites =
-      new IdentityHashMap<>();
+  private final ProgramMethodMap<Map<DexCallSite, LambdaClass>> lambdaCallSites =
+      ProgramMethodMap.create();
   private final Map<DexMethod, ProgramMethod> methodsWithBackports = new IdentityHashMap<>();
   private final Set<DexProgramClass> classesWithSerializableLambdas = Sets.newIdentityHashSet();
 
@@ -929,7 +930,7 @@
         LambdaClass lambdaClass = lambdaRewriter.getOrCreateLambdaClass(descriptor, context);
         lambdaClasses.put(lambdaClass.type, new Pair<>(lambdaClass, context));
         lambdaCallSites
-            .computeIfAbsent(contextMethod, k -> new IdentityHashMap<>())
+            .computeIfAbsent(context, k -> new IdentityHashMap<>())
             .put(callSite, lambdaClass);
         if (lambdaClass.descriptor.interfaces.contains(appView.dexItemFactory().serializableType)) {
           classesWithSerializableLambdas.add(context.getHolder());
@@ -3023,6 +3024,9 @@
     Map<DexType, Pair<DexProgramClass, ProgramMethod>> syntheticInstantiations =
         new IdentityHashMap<>();
 
+    ProgramMethodMap<Set<DexField>> syntheticStaticFieldReadsByContext =
+        ProgramMethodMap.createLinked();
+
     Map<DexMethod, ProgramMethod> liveMethods = new IdentityHashMap<>();
 
     Map<DexType, DexClasspathClass> syntheticClasspathClasses = new IdentityHashMap<>();
@@ -3099,6 +3103,10 @@
             InstantiationReason.SYNTHESIZED_CLASS,
             fakeReason);
       }
+      syntheticStaticFieldReadsByContext.forEach(
+          (context, fields) ->
+              fields.forEach(
+                  field -> enqueuer.workList.enqueueTraceStaticFieldRead(field, context)));
       for (ProgramMethod liveMethod : liveMethods.values()) {
         assert !enqueuer.targetedMethods.contains(liveMethod.getDefinition());
         enqueuer.markMethodAsTargeted(liveMethod, fakeReason);
@@ -3106,6 +3114,26 @@
       }
       enqueuer.liveNonProgramTypes.addAll(syntheticClasspathClasses.values());
     }
+
+    void registerStatelessLambdaInstanceFieldReads(
+        ProgramMethodMap<Map<DexCallSite, LambdaClass>> lambdaCallSites) {
+      lambdaCallSites.forEach(this::registerStatelessLambdaInstanceFieldReads);
+    }
+
+    private void registerStatelessLambdaInstanceFieldReads(
+        ProgramMethod context, Map<DexCallSite, LambdaClass> callSites) {
+      Set<DexField> syntheticStaticFieldReadsInContext = null;
+      for (LambdaClass lambdaClass : callSites.values()) {
+        if (lambdaClass.isStateless()) {
+          if (syntheticStaticFieldReadsInContext == null) {
+            syntheticStaticFieldReadsInContext =
+                syntheticStaticFieldReadsByContext.computeIfAbsent(
+                    context, ignore -> Sets.newLinkedHashSet());
+          }
+          syntheticStaticFieldReadsInContext.add(lambdaClass.lambdaField);
+        }
+      }
+    }
   }
 
   private void synthesize() {
@@ -3179,6 +3207,7 @@
 
     // Rewrite all of the invoke-dynamic instructions to lambda class instantiations.
     lambdaCallSites.forEach(this::rewriteLambdaCallSites);
+    additions.registerStatelessLambdaInstanceFieldReads(lambdaCallSites);
 
     // Remove all '$deserializeLambda$' methods which are not supported by desugaring.
     for (DexProgramClass clazz : classesWithSerializableLambdas) {
@@ -3441,9 +3470,9 @@
   }
 
   private void rewriteLambdaCallSites(
-      DexEncodedMethod method, Map<DexCallSite, LambdaClass> callSites) {
+      ProgramMethod context, Map<DexCallSite, LambdaClass> callSites) {
     assert !callSites.isEmpty();
-    int replaced = LambdaRewriter.desugarLambdas(method, callSites::get);
+    int replaced = LambdaRewriter.desugarLambdas(context, callSites::get);
     assert replaced == callSites.size();
   }
 
diff --git a/src/main/java/com/android/tools/r8/utils/collections/ProgramMethodMap.java b/src/main/java/com/android/tools/r8/utils/collections/ProgramMethodMap.java
new file mode 100644
index 0000000..d49928a
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/utils/collections/ProgramMethodMap.java
@@ -0,0 +1,57 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.utils.collections;
+
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.utils.ProgramMethodEquivalence;
+import com.google.common.base.Equivalence.Wrapper;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.function.BiConsumer;
+import java.util.function.Function;
+import java.util.function.Supplier;
+
+public class ProgramMethodMap<V> {
+
+  private final Map<Wrapper<ProgramMethod>, V> backing;
+
+  private ProgramMethodMap(Supplier<Map<Wrapper<ProgramMethod>, V>> backingFactory) {
+    backing = backingFactory.get();
+  }
+
+  public static <V> ProgramMethodMap<V> create() {
+    return new ProgramMethodMap<>(HashMap::new);
+  }
+
+  public static <V> ProgramMethodMap<V> createLinked() {
+    return new ProgramMethodMap<>(LinkedHashMap::new);
+  }
+
+  public void clear() {
+    backing.clear();
+  }
+
+  public V computeIfAbsent(ProgramMethod method, Function<ProgramMethod, V> fn) {
+    return backing.computeIfAbsent(wrap(method), key -> fn.apply(key.get()));
+  }
+
+  public void forEach(BiConsumer<ProgramMethod, V> consumer) {
+    backing.forEach((wrapper, value) -> consumer.accept(wrapper.get(), value));
+  }
+
+  public boolean isEmpty() {
+    return backing.isEmpty();
+  }
+
+  public V put(ProgramMethod method, V value) {
+    Wrapper<ProgramMethod> wrapper = ProgramMethodEquivalence.get().wrap(method);
+    return backing.put(wrapper, value);
+  }
+
+  private static Wrapper<ProgramMethod> wrap(ProgramMethod method) {
+    return ProgramMethodEquivalence.get().wrap(method);
+  }
+}