[Partition] Track package names for fast lookup before key request

Change-Id: I31c73dd31e79e27eac6385d29976e13cbfc784b4
diff --git a/src/main/java/com/android/tools/r8/naming/ClassNameMapper.java b/src/main/java/com/android/tools/r8/naming/ClassNameMapper.java
index cfd7b32..ad33448 100644
--- a/src/main/java/com/android/tools/r8/naming/ClassNameMapper.java
+++ b/src/main/java/com/android/tools/r8/naming/ClassNameMapper.java
@@ -18,6 +18,7 @@
 import com.android.tools.r8.position.Position;
 import com.android.tools.r8.utils.BiMapContainer;
 import com.android.tools.r8.utils.ChainableStringConsumer;
+import com.android.tools.r8.utils.DescriptorUtils;
 import com.android.tools.r8.utils.Reporter;
 import com.google.common.collect.BiMap;
 import com.google.common.collect.ImmutableBiMap;
@@ -32,6 +33,7 @@
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedHashSet;
 import java.util.List;
@@ -247,6 +249,14 @@
     return preamble;
   }
 
+  public Set<String> getObfuscatedPackages() {
+    Set<String> packages = new HashSet<>();
+    classNameMappings.forEach(
+        (s, classNamingForNameMapper) ->
+            packages.add(DescriptorUtils.getPackageNameFromTypeName(s)));
+    return packages;
+  }
+
   public void setPreamble(List<String> preamble) {
     this.preamble = preamble;
   }
diff --git a/src/main/java/com/android/tools/r8/retrace/internal/MappingPartitionMetadataInternal.java b/src/main/java/com/android/tools/r8/retrace/internal/MappingPartitionMetadataInternal.java
index 9896856..b5438c6 100644
--- a/src/main/java/com/android/tools/r8/retrace/internal/MappingPartitionMetadataInternal.java
+++ b/src/main/java/com/android/tools/r8/retrace/internal/MappingPartitionMetadataInternal.java
@@ -43,7 +43,7 @@
   }
 
   default MetadataAdditionalInfo getAdditionalInfo() {
-    return MetadataAdditionalInfo.create(null);
+    return MetadataAdditionalInfo.create(null, null);
   }
 
   // Magic byte put into the metadata
diff --git a/src/main/java/com/android/tools/r8/retrace/internal/MetadataAdditionalInfo.java b/src/main/java/com/android/tools/r8/retrace/internal/MetadataAdditionalInfo.java
index 706584b..7e0e22d 100644
--- a/src/main/java/com/android/tools/r8/retrace/internal/MetadataAdditionalInfo.java
+++ b/src/main/java/com/android/tools/r8/retrace/internal/MetadataAdditionalInfo.java
@@ -5,20 +5,30 @@
 package com.android.tools.r8.retrace.internal;
 
 import com.android.tools.r8.dex.CompatByteBuffer;
+import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.retrace.RetracePartitionException;
 import com.android.tools.r8.utils.SerializationUtils;
 import com.android.tools.r8.utils.StringUtils;
 import java.io.ByteArrayOutputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Predicate;
 
 public class MetadataAdditionalInfo {
 
+  private static final int NUMBER_OF_ELEMENTS = 2;
+
   public enum AdditionalInfoTypes {
     UNKNOWN(-1),
-    PREAMBLE(0);
+    PREAMBLE(0),
+    OBFUSCATED_PACKAGES(1);
 
     private final int serializedKey;
 
@@ -29,15 +39,19 @@
     static AdditionalInfoTypes getByKey(int serializedKey) {
       if (serializedKey == 0) {
         return PREAMBLE;
+      } else if (serializedKey == 1) {
+        return OBFUSCATED_PACKAGES;
       }
       return UNKNOWN;
     }
   }
 
   protected final List<String> preamble;
+  protected final Set<String> obfuscatedPackages;
 
-  private MetadataAdditionalInfo(List<String> preamble) {
+  private MetadataAdditionalInfo(List<String> preamble, Set<String> obfuscatedPackages) {
     this.preamble = preamble;
+    this.obfuscatedPackages = obfuscatedPackages;
   }
 
   public boolean hasPreamble() {
@@ -48,70 +62,116 @@
     return preamble;
   }
 
+  public boolean hasObfuscatedPackages() {
+    return obfuscatedPackages != null;
+  }
+
+  public Set<String> getObfuscatedPackages() {
+    return obfuscatedPackages;
+  }
+
   // The serialized format is an extensible list where we first record the offsets for each data
   // section and then emit the data.
   // <total-size:int><number-of-elements:short>[<type-i:short><length-i:int><data-i>]
   public void serialize(DataOutputStream dataOutputStream) throws IOException {
     ByteArrayOutputStream temp = new ByteArrayOutputStream();
     DataOutputStream additionalInfoStream = new DataOutputStream(temp);
-    additionalInfoStream.writeShort(1);
+    additionalInfoStream.writeShort(NUMBER_OF_ELEMENTS);
     additionalInfoStream.writeShort(AdditionalInfoTypes.PREAMBLE.serializedKey);
     SerializationUtils.writeUTFOfIntSize(additionalInfoStream, StringUtils.unixLines(preamble));
+    additionalInfoStream.writeShort(AdditionalInfoTypes.OBFUSCATED_PACKAGES.serializedKey);
+    List<String> sortedPackages = new ArrayList<>(obfuscatedPackages);
+    Collections.sort(sortedPackages);
+    SerializationUtils.writeUTFOfIntSize(
+        additionalInfoStream, StringUtils.unixLines(sortedPackages));
     byte[] payload = temp.toByteArray();
     dataOutputStream.writeInt(payload.length);
     dataOutputStream.write(payload);
   }
 
-  private static MetadataAdditionalInfo deserialize(byte[] bytes) {
+  private static MetadataAdditionalInfo deserialize(
+      byte[] bytes, Predicate<AdditionalInfoTypes> serializeSection) {
     CompatByteBuffer compatByteBuffer = CompatByteBuffer.wrap(bytes);
     int numberOfElements = compatByteBuffer.getShort();
     List<String> preamble = null;
+    Set<String> packages = null;
     for (int i = 0; i < numberOfElements; i++) {
       // We are parsing <type:short><length:int><bytes>
       int additionInfoTypeKey = compatByteBuffer.getShort();
       AdditionalInfoTypes additionalInfoType = AdditionalInfoTypes.getByKey(additionInfoTypeKey);
-      if (additionalInfoType == AdditionalInfoTypes.PREAMBLE) {
-        preamble = StringUtils.splitLines(compatByteBuffer.getUTFOfIntSize());
-      } else {
+      if (additionalInfoType == AdditionalInfoTypes.UNKNOWN) {
         throw new RetracePartitionException(
             "Could not additional info from key: " + additionInfoTypeKey);
       }
+      if (serializeSection.test(additionalInfoType)) {
+        switch (additionalInfoType) {
+          case PREAMBLE:
+            preamble = StringUtils.splitLines(compatByteBuffer.getUTFOfIntSize());
+            break;
+          case OBFUSCATED_PACKAGES:
+            packages = StringUtils.splitLinesIntoSet(compatByteBuffer.getUTFOfIntSize());
+            break;
+          default:
+            throw new Unreachable("Unreachable since we already checked for UNKNOWN");
+        }
+      } else {
+        int length = compatByteBuffer.getInt();
+        compatByteBuffer.position(compatByteBuffer.position() + length);
+      }
     }
-    return new MetadataAdditionalInfo(preamble);
+    return new MetadataAdditionalInfo(preamble, packages);
   }
 
-  public static MetadataAdditionalInfo create(List<String> preamble) {
-    return new MetadataAdditionalInfo(preamble);
+  public static MetadataAdditionalInfo create(
+      List<String> preamble, Set<String> obfuscatedPackages) {
+    return new MetadataAdditionalInfo(preamble, obfuscatedPackages);
   }
 
   public static class LazyMetadataAdditionalInfo extends MetadataAdditionalInfo {
 
-    private byte[] bytes;
-    private MetadataAdditionalInfo metadataAdditionalInfo = null;
+    private final byte[] bytes;
+    private final Map<Integer, MetadataAdditionalInfo> metadataAdditionalInfo =
+        new ConcurrentHashMap<>();
 
     public LazyMetadataAdditionalInfo(byte[] bytes) {
-      super(null);
+      super(null, null);
       this.bytes = bytes;
     }
 
     @Override
     public boolean hasPreamble() {
-      MetadataAdditionalInfo metadataAdditionalInfo = getMetadataAdditionalInfo();
+      MetadataAdditionalInfo metadataAdditionalInfo =
+          getMetadataAdditionalInfo(AdditionalInfoTypes.PREAMBLE);
       return metadataAdditionalInfo != null && metadataAdditionalInfo.hasPreamble();
     }
 
     @Override
     public Collection<String> getPreamble() {
-      MetadataAdditionalInfo metadataAdditionalInfo = getMetadataAdditionalInfo();
+      MetadataAdditionalInfo metadataAdditionalInfo =
+          getMetadataAdditionalInfo(AdditionalInfoTypes.PREAMBLE);
       return metadataAdditionalInfo == null ? null : metadataAdditionalInfo.getPreamble();
     }
 
-    private MetadataAdditionalInfo getMetadataAdditionalInfo() {
-      if (metadataAdditionalInfo == null) {
-        metadataAdditionalInfo = MetadataAdditionalInfo.deserialize(bytes);
-        bytes = null;
-      }
-      return metadataAdditionalInfo;
+    @Override
+    public boolean hasObfuscatedPackages() {
+      MetadataAdditionalInfo metadataAdditionalInfo =
+          getMetadataAdditionalInfo(AdditionalInfoTypes.OBFUSCATED_PACKAGES);
+      return metadataAdditionalInfo != null && metadataAdditionalInfo.hasObfuscatedPackages();
+    }
+
+    @Override
+    public Set<String> getObfuscatedPackages() {
+      MetadataAdditionalInfo metadataAdditionalInfo =
+          getMetadataAdditionalInfo(AdditionalInfoTypes.OBFUSCATED_PACKAGES);
+      return metadataAdditionalInfo == null ? null : metadataAdditionalInfo.getObfuscatedPackages();
+    }
+
+    private MetadataAdditionalInfo getMetadataAdditionalInfo(AdditionalInfoTypes infoType) {
+      return metadataAdditionalInfo.computeIfAbsent(
+          infoType.serializedKey,
+          ignored ->
+              MetadataAdditionalInfo.deserialize(
+                  bytes, deserializeType -> deserializeType == infoType));
     }
 
     public static LazyMetadataAdditionalInfo create(CompatByteBuffer buffer) {
diff --git a/src/main/java/com/android/tools/r8/retrace/internal/PartitionMappingSupplierBase.java b/src/main/java/com/android/tools/r8/retrace/internal/PartitionMappingSupplierBase.java
index 012fb94..a15f89f 100644
--- a/src/main/java/com/android/tools/r8/retrace/internal/PartitionMappingSupplierBase.java
+++ b/src/main/java/com/android/tools/r8/retrace/internal/PartitionMappingSupplierBase.java
@@ -23,6 +23,8 @@
 import com.android.tools.r8.retrace.PrepareMappingPartitionsCallback;
 import com.android.tools.r8.retrace.RegisterMappingPartitionCallback;
 import com.android.tools.r8.retrace.internal.ProguardMapReaderWithFiltering.ProguardMapReaderWithFilteringInputBuffer;
+import com.android.tools.r8.utils.Box;
+import com.android.tools.r8.utils.DescriptorUtils;
 import com.android.tools.r8.utils.StringDiagnostic;
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
@@ -30,6 +32,7 @@
 import java.util.HashSet;
 import java.util.LinkedHashSet;
 import java.util.Set;
+import java.util.function.Predicate;
 
 public abstract class PartitionMappingSupplierBase<T extends PartitionMappingSupplierBase<T>>
     implements Finishable {
@@ -45,7 +48,8 @@
   private final Set<String> pendingKeys = new LinkedHashSet<>();
   private final Set<String> builtKeys = new HashSet<>();
 
-  private MappingPartitionMetadataInternal mappingPartitionMetadataCache;
+  private final Box<MappingPartitionMetadataInternal> mappingPartitionMetadataCache = new Box<>();
+  private final Box<Predicate<String>> typeNameCouldHavePartitionCache = new Box<>();
 
   protected PartitionMappingSupplierBase(
       RegisterMappingPartitionCallback registerCallback,
@@ -63,16 +67,49 @@
   }
 
   public MappingPartitionMetadataInternal getMetadata(DiagnosticsHandler diagnosticsHandler) {
-    if (mappingPartitionMetadataCache != null) {
-      return mappingPartitionMetadataCache;
+    if (mappingPartitionMetadataCache.isSet()) {
+      return mappingPartitionMetadataCache.get();
     }
-    return mappingPartitionMetadataCache =
-        MappingPartitionMetadataInternal.deserialize(
-            CompatByteBuffer.wrapOrNull(metadata), fallbackMapVersion, diagnosticsHandler);
+    synchronized (mappingPartitionMetadataCache) {
+      if (mappingPartitionMetadataCache.isSet()) {
+        return mappingPartitionMetadataCache.get();
+      }
+      MappingPartitionMetadataInternal data =
+          MappingPartitionMetadataInternal.deserialize(
+              CompatByteBuffer.wrapOrNull(metadata), fallbackMapVersion, diagnosticsHandler);
+      mappingPartitionMetadataCache.set(data);
+      return data;
+    }
   }
 
   public T registerClassUse(DiagnosticsHandler diagnosticsHandler, ClassReference classReference) {
-    return registerKeyUse(classReference.getTypeName());
+    // Check if the package name is registered before requesting the bytes for a partition.
+    String typeName = classReference.getTypeName();
+    if (isPotentialRetraceClass(diagnosticsHandler, typeName)) {
+      return registerKeyUse(typeName);
+    }
+    return self();
+  }
+
+  private boolean isPotentialRetraceClass(DiagnosticsHandler diagnosticsHandler, String typeName) {
+    if (typeNameCouldHavePartitionCache.isSet()) {
+      return typeNameCouldHavePartitionCache.get().test(typeName);
+    }
+    synchronized (typeNameCouldHavePartitionCache) {
+      if (typeNameCouldHavePartitionCache.isSet()) {
+        return typeNameCouldHavePartitionCache.get().test(typeName);
+      }
+      Predicate<String> typeNameCouldHavePartitionPredicate =
+          getPartitionPredicate(getPackagesWithClasses(diagnosticsHandler));
+      typeNameCouldHavePartitionCache.set(typeNameCouldHavePartitionPredicate);
+      return typeNameCouldHavePartitionPredicate.test(typeName);
+    }
+  }
+
+  private Predicate<String> getPartitionPredicate(Set<String> packagesWithClasses) {
+    return name ->
+        packagesWithClasses == null
+            || packagesWithClasses.contains(DescriptorUtils.getPackageNameFromTypeName(name));
   }
 
   public T registerMethodUse(
@@ -85,13 +122,24 @@
   }
 
   public T registerKeyUse(String key) {
-    // TODO(b/274735214): only call the register partition if we have a partition for it.
     if (!builtKeys.contains(key) && pendingKeys.add(key)) {
       registerCallback.register(key);
     }
     return self();
   }
 
+  private Set<String> getPackagesWithClasses(DiagnosticsHandler diagnosticsHandler) {
+    MappingPartitionMetadataInternal metadata = getMetadata(diagnosticsHandler);
+    if (metadata == null || !metadata.canGetAdditionalInfo()) {
+      return null;
+    }
+    MetadataAdditionalInfo additionalInfo = metadata.getAdditionalInfo();
+    if (!additionalInfo.hasObfuscatedPackages()) {
+      return null;
+    }
+    return additionalInfo.getObfuscatedPackages();
+  }
+
   public void verifyMappingFileHash(DiagnosticsHandler diagnosticsHandler) {
     String errorMessage = "Cannot verify map file hash for partitions";
     diagnosticsHandler.error(new StringDiagnostic(errorMessage));
@@ -115,7 +163,6 @@
     for (String pendingKey : pendingKeys) {
       try {
         byte[] suppliedPartition = partitionSupplier.get(pendingKey);
-        // TODO(b/274735214): only expect a partition if have generated one for the key.
         if (suppliedPartition == null) {
           continue;
         }
diff --git a/src/main/java/com/android/tools/r8/retrace/internal/ProguardMapPartitionerOnClassNameToText.java b/src/main/java/com/android/tools/r8/retrace/internal/ProguardMapPartitionerOnClassNameToText.java
index 3edde9a..3af7dff 100644
--- a/src/main/java/com/android/tools/r8/retrace/internal/ProguardMapPartitionerOnClassNameToText.java
+++ b/src/main/java/com/android/tools/r8/retrace/internal/ProguardMapPartitionerOnClassNameToText.java
@@ -168,7 +168,8 @@
       return ObfuscatedTypeNameAsKeyMetadataWithPartitionNames.create(
           mapVersion,
           MetadataPartitionCollection.create(keys),
-          MetadataAdditionalInfo.create(classMapper.getPreamble()));
+          MetadataAdditionalInfo.create(
+              classMapper.getPreamble(), classMapper.getObfuscatedPackages()));
     } else {
       RetracePartitionException retraceError =
           new RetracePartitionException("Unknown mapping partitioning strategy");
diff --git a/src/main/java/com/android/tools/r8/utils/PartitionMapZipContainer.java b/src/main/java/com/android/tools/r8/utils/PartitionMapZipContainer.java
index 242fbcc..e2af968 100644
--- a/src/main/java/com/android/tools/r8/utils/PartitionMapZipContainer.java
+++ b/src/main/java/com/android/tools/r8/utils/PartitionMapZipContainer.java
@@ -32,7 +32,6 @@
         .setMappingPartitionFromKeySupplier(
             key -> {
               try {
-                // TODO(b/274735214): The key should exist.
                 ZipEntry entry = zipFile.getEntry(key);
                 return entry == null
                     ? EMPTY_RESULT
diff --git a/src/main/java/com/android/tools/r8/utils/StringUtils.java b/src/main/java/com/android/tools/r8/utils/StringUtils.java
index 9885ab6..af857c4 100644
--- a/src/main/java/com/android/tools/r8/utils/StringUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/StringUtils.java
@@ -11,9 +11,11 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Set;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.regex.Matcher;
@@ -273,14 +275,25 @@
     return join(LINE_SEPARATOR, collection, BraceType.NONE);
   }
 
-
   public static List<String> splitLines(String content) {
     return splitLines(content, false);
   }
 
+  public static Set<String> splitLinesIntoSet(String content) {
+    Set<String> set = new HashSet<>();
+    splitLines(content, false, set::add);
+    return set;
+  }
+
   public static List<String> splitLines(String content, boolean includeTrailingEmptyLine) {
+    List<String> list = new ArrayList<>();
+    splitLines(content, includeTrailingEmptyLine, list::add);
+    return list;
+  }
+
+  private static void splitLines(
+      String content, boolean includeTrailingEmptyLine, Consumer<String> consumer) {
     int length = content.length();
-    List<String> lines = new ArrayList<>();
     int start = 0;
     for (int i = 0; i < length; i++) {
       char c = content.charAt(i);
@@ -290,16 +303,15 @@
       } else if (c != '\n') {
         continue;
       }
-      lines.add(content.substring(start, end));
+      consumer.accept(content.substring(start, end));
       start = i + 1;
     }
     if (start < length) {
       String line = content.substring(start);
       if (includeTrailingEmptyLine || !line.isEmpty()) {
-        lines.add(line);
+        consumer.accept(line);
       }
     }
-    return lines;
   }
 
   public static String zeroPrefix(int i, int width) {
diff --git a/src/test/java/com/android/tools/r8/retrace/partition/R8ZipContainerMappingFileTest.java b/src/test/java/com/android/tools/r8/retrace/partition/R8ZipContainerMappingFileTest.java
index 7c64527..255b693 100644
--- a/src/test/java/com/android/tools/r8/retrace/partition/R8ZipContainerMappingFileTest.java
+++ b/src/test/java/com/android/tools/r8/retrace/partition/R8ZipContainerMappingFileTest.java
@@ -26,7 +26,6 @@
 import java.io.IOException;
 import java.nio.file.Files;
 import java.nio.file.Path;
-import java.util.zip.ZipEntry;
 import java.util.zip.ZipFile;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -140,11 +139,7 @@
         .setMappingPartitionFromKeySupplier(
             key -> {
               try {
-                // TODO(b/274735214): The key should exist.
-                ZipEntry entry = zipFile.getEntry(key);
-                return entry == null
-                    ? null
-                    : ByteStreams.toByteArray(zipFile.getInputStream(entry));
+                return ByteStreams.toByteArray(zipFile.getInputStream(zipFile.getEntry(key)));
               } catch (IOException e) {
                 throw new RuntimeException(e);
               }