Revising class map full loading.

Revise implementation of ClassMap.forceLoad(), replacing
collectLoadedClasses() with getAllClasses(), removing
unnecessary allocations in fully loaded class maps.

Bug:
Change-Id: Ia1a74c283797dc90835d55af53064d628ada8717
diff --git a/src/main/java/com/android/tools/r8/graph/Code.java b/src/main/java/com/android/tools/r8/graph/Code.java
index dca5e02..84d6d5d 100644
--- a/src/main/java/com/android/tools/r8/graph/Code.java
+++ b/src/main/java/com/android/tools/r8/graph/Code.java
@@ -34,15 +34,15 @@
   }
 
   public DexCode asDexCode() {
-    throw new Unreachable();
+    throw new Unreachable(getClass().getCanonicalName() + ".asDexCode()");
   }
 
   public JarCode asJarCode() {
-    throw new Unreachable();
+    throw new Unreachable(getClass().getCanonicalName() + ".asJarCode()");
   }
 
   public OutlineCode asOutlineCode() {
-    throw new Unreachable();
+    throw new Unreachable(getClass().getCanonicalName() + ".asOutlineCode()");
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/graph/DexApplication.java b/src/main/java/com/android/tools/r8/graph/DexApplication.java
index ca702fe..d84dd8c 100644
--- a/src/main/java/com/android/tools/r8/graph/DexApplication.java
+++ b/src/main/java/com/android/tools/r8/graph/DexApplication.java
@@ -25,7 +25,6 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
-import java.util.Hashtable;
 import java.util.IdentityHashMap;
 import java.util.List;
 import java.util.Map;
@@ -87,7 +86,8 @@
   }
 
   public List<DexProgramClass> classes() {
-    List<DexProgramClass> classes = programClasses.collectLoadedClasses();
+    programClasses.forceLoad(type -> true);
+    List<DexProgramClass> classes = programClasses.getAllClasses();
     assert reorderClasses(classes);
     return classes;
   }
@@ -110,16 +110,16 @@
 
     // program classes are supposed to be loaded, but force-loading them is no-op.
     programClasses.forceLoad(type -> true);
-    programClasses.collectLoadedClasses().forEach(clazz -> loaded.put(clazz.type, clazz));
+    programClasses.getAllClasses().forEach(clazz -> loaded.put(clazz.type, clazz));
 
     if (classpathClasses != null) {
       classpathClasses.forceLoad(type -> !loaded.containsKey(type));
-      classpathClasses.collectLoadedClasses().forEach(clazz -> loaded.put(clazz.type, clazz));
+      classpathClasses.getAllClasses().forEach(clazz -> loaded.putIfAbsent(clazz.type, clazz));
     }
 
     if (libraryClasses != null) {
       libraryClasses.forceLoad(type -> !loaded.containsKey(type));
-      libraryClasses.collectLoadedClasses().forEach(clazz -> loaded.put(clazz.type, clazz));
+      libraryClasses.getAllClasses().forEach(clazz -> loaded.putIfAbsent(clazz.type, clazz));
     }
 
     return loaded;
@@ -192,7 +192,7 @@
    * <p>If no directory is provided everything is written to System.out.
    */
   public void disassemble(Path outputDir, InternalOptions options) {
-    for (DexProgramClass clazz : programClasses.collectLoadedClasses()) {
+    for (DexProgramClass clazz : programClasses.getAllClasses()) {
       for (DexEncodedMethod method : clazz.virtualMethods()) {
         if (options.methodMatchesFilter(method)) {
           disassemble(method, getProguardMap(), outputDir);
@@ -246,7 +246,7 @@
    * Write smali source for the application code on the provided PrintStream.
    */
   public void smali(InternalOptions options, PrintStream ps) {
-    List<DexProgramClass> classes = programClasses.collectLoadedClasses();
+    List<DexProgramClass> classes = programClasses.getAllClasses();
     classes.sort(Comparator.comparing(DexProgramClass::toSourceString));
     boolean firstClass = true;
     for (DexClass clazz : classes) {
@@ -322,7 +322,7 @@
     }
 
     public Builder(DexApplication application) {
-      programClasses = application.programClasses.collectLoadedClasses();
+      programClasses = application.programClasses.getAllClasses();
       classpathClasses = application.classpathClasses;
       libraryClasses = application.libraryClasses;
       proguardMap = application.proguardMap;
diff --git a/src/main/java/com/android/tools/r8/graph/DexClasspathClass.java b/src/main/java/com/android/tools/r8/graph/DexClasspathClass.java
index dd5efbb..036b8bc 100644
--- a/src/main/java/com/android/tools/r8/graph/DexClasspathClass.java
+++ b/src/main/java/com/android/tools/r8/graph/DexClasspathClass.java
@@ -7,8 +7,9 @@
 import com.android.tools.r8.dex.IndexedItemCollection;
 import com.android.tools.r8.dex.MixedSectionCollection;
 import com.android.tools.r8.errors.Unreachable;
+import java.util.function.Supplier;
 
-public class DexClasspathClass extends DexClass {
+public class DexClasspathClass extends DexClass implements Supplier<DexClasspathClass> {
 
   public DexClasspathClass(DexType type, Resource.Kind origin, DexAccessFlags accessFlags,
       DexType superType, DexTypeList interfaces, DexString sourceFile, DexAnnotationSet annotations,
@@ -43,4 +44,9 @@
   public DexClasspathClass asClasspathClass() {
     return this;
   }
+
+  @Override
+  public DexClasspathClass get() {
+    return this;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/graph/DexLibraryClass.java b/src/main/java/com/android/tools/r8/graph/DexLibraryClass.java
index c1b9205..0c8ce8c 100644
--- a/src/main/java/com/android/tools/r8/graph/DexLibraryClass.java
+++ b/src/main/java/com/android/tools/r8/graph/DexLibraryClass.java
@@ -7,8 +7,9 @@
 import com.android.tools.r8.dex.IndexedItemCollection;
 import com.android.tools.r8.dex.MixedSectionCollection;
 import com.android.tools.r8.errors.Unreachable;
+import java.util.function.Supplier;
 
-public class DexLibraryClass extends DexClass {
+public class DexLibraryClass extends DexClass implements Supplier<DexLibraryClass> {
 
   public DexLibraryClass(DexType type, Resource.Kind origin, DexAccessFlags accessFlags,
       DexType superType, DexTypeList interfaces, DexString sourceFile, DexAnnotationSet annotations,
@@ -48,4 +49,9 @@
   public DexLibraryClass asLibraryClass() {
     return this;
   }
+
+  @Override
+  public DexLibraryClass get() {
+    return this;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/graph/DexProgramClass.java b/src/main/java/com/android/tools/r8/graph/DexProgramClass.java
index f93a7da..a38c9cb 100644
--- a/src/main/java/com/android/tools/r8/graph/DexProgramClass.java
+++ b/src/main/java/com/android/tools/r8/graph/DexProgramClass.java
@@ -7,8 +7,9 @@
 import com.android.tools.r8.dex.IndexedItemCollection;
 import com.android.tools.r8.dex.MixedSectionCollection;
 import java.util.Arrays;
+import java.util.function.Supplier;
 
-public class DexProgramClass extends DexClass {
+public class DexProgramClass extends DexClass implements Supplier<DexProgramClass> {
 
   private DexEncodedArray staticValues;
 
@@ -152,4 +153,9 @@
     directMethods = Arrays.copyOf(directMethods, directMethods.length + 1);
     directMethods[directMethods.length - 1] = staticMethod;
   }
+
+  @Override
+  public DexProgramClass get() {
+    return this;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/utils/ClassMap.java b/src/main/java/com/android/tools/r8/utils/ClassMap.java
index 91adca4..9677473 100644
--- a/src/main/java/com/android/tools/r8/utils/ClassMap.java
+++ b/src/main/java/com/android/tools/r8/utils/ClassMap.java
@@ -4,17 +4,19 @@
 package com.android.tools.r8.utils;
 
 import com.android.tools.r8.errors.CompilationError;
+import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.ClassKind;
 import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexType;
 import com.google.common.collect.Sets;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.IdentityHashMap;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.Predicate;
+import java.util.function.Supplier;
 
 /**
  * Represents a collection of classes. Collection can be fully loaded,
@@ -25,13 +27,17 @@
   // resources provided by different resource providers.
   //
   // NOTE: all access must be synchronized on `classes`.
-  private final Map<DexType, Value<T>> classes;
+  private final Map<DexType, Supplier<T>> classes;
 
-  // Class provider if available. In case it's `null`, all classes of
-  // the collection must be pre-populated in `classes`.
-  private final ClassProvider<T> classProvider;
+  // Class provider if available.
+  //
+  // If the class provider is `null` it indicates that all classes are already present
+  // in a map referenced by `classes` and thus the collection is fully loaded.
+  //
+  // NOTE: all access must be synchronized on `classes`.
+  private ClassProvider<T> classProvider;
 
-  ClassMap(Map<DexType, Value<T>> classes, ClassProvider<T> classProvider) {
+  ClassMap(Map<DexType, Supplier<T>> classes, ClassProvider<T> classProvider) {
     this.classes = classes == null ? new IdentityHashMap<>() : classes;
     this.classProvider = classProvider;
     assert this.classProvider == null || this.classProvider.getClassKind() == getClassKind();
@@ -40,6 +46,9 @@
   /** Resolves a class conflict by selecting a class, may generate compilation error. */
   abstract T resolveClassConflict(T a, T b);
 
+  /** Return supplier for preloaded class. */
+  abstract Supplier<T> getTransparentSupplier(T clazz);
+
   /** Kind of the classes supported by this collection. */
   abstract ClassKind getClassKind();
 
@@ -53,115 +62,175 @@
 
   /** Returns a definition for a class or `null` if there is no such class in the collection. */
   public T get(DexType type) {
-    final Value<T> value = getOrCreateValue(type);
+    Supplier<T> supplier;
 
-    if (value == null) {
-      return null;
-    }
-
-    if (!value.ready) {
-      // Load the value in a context synchronized on value instance. This way
-      // we avoid locking the whole collection during expensive resource loading
-      // and classes construction operations.
-      synchronized (value) {
-        if (!value.ready) {
-          assert classProvider != null : "getOrCreateValue() created "
-              + "Value for missing type when there is no classProvider.";
-          classProvider.collectClass(type, clazz -> {
-            assert clazz != null;
-            assert getClassKind().isOfKind(clazz);
-            assert !value.ready;
-
-            if (clazz.type != type) {
-              throw new CompilationError("Class content provided for type descriptor "
-                  + type.toSourceString() + " actually defines class " + clazz.type
-                  .toSourceString());
-            }
-
-            if (value.clazz == null) {
-              value.clazz = clazz;
-            } else {
-              // The class resolution *may* generate a compilation error as one of
-              // possible resolutions. In this case we leave `value` in (false, null)
-              // state so in rare case of another thread trying to get the same class
-              // before this error is propagated it will get the same conflict.
-              T oldClass = value.clazz;
-              value.clazz = null;
-              value.clazz = resolveClassConflict(oldClass, clazz);
-            }
-          });
-          value.ready = true;
-        }
-      }
-    }
-
-    assert value.ready;
-    return value.clazz;
-  }
-
-  private Value<T> getOrCreateValue(DexType type) {
     synchronized (classes) {
-      Value<T> value = classes.get(type);
-      if (value == null && classProvider != null) {
-        value = new Value<>();
-        classes.put(type, value);
+      supplier = classes.get(type);
+
+      // Get class supplier, create it if it does not
+      // exist and the collection is NOT fully loaded.
+      if (supplier == null) {
+        if (classProvider == null) {
+          // There is no supplier, but the collection is fully loaded.
+          return null;
+        }
+
+        supplier = new ConcurrentClassLoader<>(this, this.classProvider, type);
+        classes.put(type, supplier);
       }
-      return value;
     }
+
+    return supplier.get();
   }
 
-  /**
-   * Returns currently loaded classes.
-   *
-   * Method is assumed to be called when the collection is fully loaded,
-   * otherwise only a subset of potentially loaded classes may be returned.
-   */
-  public List<T> collectLoadedClasses() {
+  /** Returns all classes from the collection. The collection must be force-loaded. */
+  public List<T> getAllClasses() {
     List<T> loadedClasses = new ArrayList<>();
     synchronized (classes) {
-      for (Value<T> value : classes.values()) {
-        // NOTE: value mutations are NOT synchronized on `classes`, here we actually
-        // can see value which is not ready yet. Since everything that exists should
-        // be guaranteed by the caller to be loaded at this point, this can only happen
-        // if the code references classes that do not exist. Therefore, if the value is
-        // not ready here, we know that the loaded value will be 'null' once it is ready.
-        if (value.ready && value.clazz != null) {
-          loadedClasses.add(value.clazz);
-        }
+      if (classProvider != null) {
+        throw new Unreachable("Getting all classes from not fully loaded collection.");
+      }
+      for (Supplier<T> supplier : classes.values()) {
+        // Since the class map is fully loaded, all suppliers must be
+        // loaded and non-null.
+        T clazz = supplier.get();
+        assert clazz != null;
+        loadedClasses.add(clazz);
       }
     }
     return loadedClasses;
   }
 
-  /** Forces loading of all the classes satisfying the criteria specified. */
+  /**
+   * Forces loading of all the classes satisfying the criteria specified.
+   *
+   * NOTE: after this method finishes, the class map is considered to be fully-loaded
+   * and thus sealed. This has one side-effect: if we filter out some of the classes
+   * with `load` predicate, these classes will never be loaded.
+   */
   public void forceLoad(Predicate<DexType> load) {
-    if (classProvider != null) {
-      Set<DexType> loaded = Sets.newIdentityHashSet();
-      synchronized (classes) {
-        loaded.addAll(classes.keySet());
+    Set<DexType> knownClasses;
+    ClassProvider<T> classProvider;
+
+    synchronized (classes) {
+      classProvider = this.classProvider;
+      if (classProvider == null) {
+        return;
       }
-      Collection<DexType> types = classProvider.collectTypes();
-      for (DexType type : types) {
-        if (load.test(type) && !loaded.contains(type)) {
-          get(type); // force-load type.
+
+      // Collects the types which might be represented in fully loaded class map.
+      knownClasses = Sets.newIdentityHashSet();
+      knownClasses.addAll(classes.keySet());
+    }
+
+    // Add all types the class provider provides. Note that it may take time for class
+    // provider to collect these types, so ve do it outside synchronized context.
+    knownClasses.addAll(classProvider.collectTypes());
+
+    // Make sure all the types in `knownClasses` are loaded.
+    //
+    // We just go and touch every class, thus triggering their loading if they
+    // are not loaded so far. In case the class has already been loaded,
+    // touching the class will be a no-op with minimal overhead.
+    for (DexType type : knownClasses) {
+      if (load.test(type)) {
+        get(type);
+      }
+    }
+
+    synchronized (classes) {
+      if (this.classProvider == null) {
+        return; // Has been force-loaded concurrently.
+      }
+
+      // We avoid calling get() on a class supplier unless we know it was loaded.
+      // At this time `classes` may have more types then `knownClasses`, but for
+      // all extra classes we expect the supplier to return 'null' after loading.
+      Iterator<Map.Entry<DexType, Supplier<T>>> iterator = classes.entrySet().iterator();
+      while (iterator.hasNext()) {
+        Map.Entry<DexType, Supplier<T>> e = iterator.next();
+
+        if (knownClasses.contains(e.getKey())) {
+          // Get the class (it is expected to be loaded by this time).
+          T clazz = e.getValue().get();
+          if (clazz != null) {
+            // Since the class is already loaded, get rid of possible wrapping suppliers.
+            assert clazz.type == e.getKey();
+            e.setValue(getTransparentSupplier(clazz));
+            continue;
+          }
         }
+
+        // If the type is not in `knownClasses` or resolves to `null`,
+        // just remove the record from the map.
+        iterator.remove();
       }
+
+      // Mark the class map as fully loaded.
+      this.classProvider = null;
     }
   }
 
-  // Represents a value in the class map.
-  final static class Value<T> {
-    volatile boolean ready;
-    T clazz;
+  // Supplier implementing a thread-safe loader for a class loaded from a
+  // class provider. Helps avoid synchronizing on the whole class map
+  // when loading a class.
+  private static class ConcurrentClassLoader<T extends DexClass> implements Supplier<T> {
+    private ClassMap<T> classMap;
+    private ClassProvider<T> provider;
+    private DexType type;
 
-    Value() {
-      ready = false;
-      clazz = null;
+    private T clazz = null;
+    private volatile boolean ready = false;
+
+    ConcurrentClassLoader(ClassMap<T> classMap, ClassProvider<T> provider, DexType type) {
+      this.classMap = classMap;
+      this.provider = provider;
+      this.type = type;
     }
 
-    Value(T clazz) {
-      this.clazz = clazz;
-      this.ready = true;
+    @Override
+    public T get() {
+      if (ready) {
+        return clazz;
+      }
+
+      synchronized (this) {
+        if (!ready) {
+          assert classMap != null && provider != null && type != null;
+          provider.collectClass(type, createdClass -> {
+            assert createdClass != null;
+            assert classMap.getClassKind().isOfKind(createdClass);
+            assert !ready;
+
+            if (createdClass.type != type) {
+              throw new CompilationError(
+                  "Class content provided for type descriptor " + type.toSourceString() +
+                      " actually defines class " + createdClass.type.toSourceString());
+            }
+
+            if (clazz == null) {
+              clazz = createdClass;
+            } else {
+              // The class resolution *may* generate a compilation error as one of
+              // possible resolutions. In this case we leave `value` in (false, null)
+              // state so in rare case of another thread trying to get the same class
+              // before this error is propagated it will get the same conflict.
+              T oldClass = clazz;
+              clazz = null;
+              clazz = classMap.resolveClassConflict(oldClass, createdClass);
+            }
+          });
+
+          classMap = null;
+          provider = null;
+          type = null;
+          ready = true;
+        }
+      }
+
+      assert ready;
+      assert classMap == null && provider == null && type == null;
+      return clazz;
     }
   }
 }
diff --git a/src/main/java/com/android/tools/r8/utils/ClasspathClassCollection.java b/src/main/java/com/android/tools/r8/utils/ClasspathClassCollection.java
index ee00c39..45002d0 100644
--- a/src/main/java/com/android/tools/r8/utils/ClasspathClassCollection.java
+++ b/src/main/java/com/android/tools/r8/utils/ClasspathClassCollection.java
@@ -6,6 +6,7 @@
 import com.android.tools.r8.errors.CompilationError;
 import com.android.tools.r8.graph.ClassKind;
 import com.android.tools.r8.graph.DexClasspathClass;
+import java.util.function.Supplier;
 
 /** Represents a collection of classpath classes. */
 public class ClasspathClassCollection extends ClassMap<DexClasspathClass> {
@@ -19,6 +20,11 @@
   }
 
   @Override
+  Supplier<DexClasspathClass> getTransparentSupplier(DexClasspathClass clazz) {
+    return clazz;
+  }
+
+  @Override
   ClassKind getClassKind() {
     return ClassKind.CLASSPATH;
   }
diff --git a/src/main/java/com/android/tools/r8/utils/LibraryClassCollection.java b/src/main/java/com/android/tools/r8/utils/LibraryClassCollection.java
index e11e660..c94d7be 100644
--- a/src/main/java/com/android/tools/r8/utils/LibraryClassCollection.java
+++ b/src/main/java/com/android/tools/r8/utils/LibraryClassCollection.java
@@ -9,6 +9,7 @@
 import com.android.tools.r8.graph.DexApplication;
 import com.android.tools.r8.graph.DexLibraryClass;
 import com.android.tools.r8.logging.Log;
+import java.util.function.Supplier;
 
 /** Represents a collection of library classes. */
 public class LibraryClassCollection extends ClassMap<DexLibraryClass> {
@@ -31,6 +32,11 @@
   }
 
   @Override
+  Supplier<DexLibraryClass> getTransparentSupplier(DexLibraryClass clazz) {
+    return clazz;
+  }
+
+  @Override
   ClassKind getClassKind() {
     return ClassKind.LIBRARY;
   }
diff --git a/src/main/java/com/android/tools/r8/utils/ProgramClassCollection.java b/src/main/java/com/android/tools/r8/utils/ProgramClassCollection.java
index a37eb92..590ba83 100644
--- a/src/main/java/com/android/tools/r8/utils/ProgramClassCollection.java
+++ b/src/main/java/com/android/tools/r8/utils/ProgramClassCollection.java
@@ -11,25 +11,20 @@
 import com.android.tools.r8.ir.desugar.LambdaRewriter;
 import java.util.IdentityHashMap;
 import java.util.List;
+import java.util.function.Supplier;
 
 /** Represents a collection of library classes. */
 public class ProgramClassCollection extends ClassMap<DexProgramClass> {
   public static ProgramClassCollection create(List<DexProgramClass> classes) {
     // We have all classes preloaded, but not necessarily without conflicts.
-    IdentityHashMap<DexType, Value<DexProgramClass>> map = new IdentityHashMap<>();
+    IdentityHashMap<DexType, Supplier<DexProgramClass>> map = new IdentityHashMap<>();
     for (DexProgramClass clazz : classes) {
-      Value<DexProgramClass> value = map.get(clazz.type);
-      if (value == null) {
-        value = new Value<>(clazz);
-        map.put(clazz.type, value);
-      } else {
-        value.clazz = resolveClassConflictImpl(value.clazz, clazz);
-      }
+      map.merge(clazz.type, clazz, (a, b) -> resolveClassConflictImpl(a.get(), b.get()));
     }
     return new ProgramClassCollection(map);
   }
 
-  private ProgramClassCollection(IdentityHashMap<DexType, Value<DexProgramClass>> classes) {
+  private ProgramClassCollection(IdentityHashMap<DexType, Supplier<DexProgramClass>> classes) {
     super(classes, null);
   }
 
@@ -44,6 +39,11 @@
   }
 
   @Override
+  Supplier<DexProgramClass> getTransparentSupplier(DexProgramClass clazz) {
+    return clazz;
+  }
+
+  @Override
   ClassKind getClassKind() {
     return ClassKind.PROGRAM;
   }