Add XML file tracing to the enqueuer

The actual parsing and tracing of the XML files is done in
R8ResourceShrinkerState, which calls back into the enqueuer with the
strings that may name classes.

Bug: b/287398085
Change-Id: I0b075ac158395830768fff654fd78e3fb381fbda
diff --git a/src/resourceshrinker/java/com/android/build/shrinker/r8integration/R8ResourceShrinkerState.java b/src/resourceshrinker/java/com/android/build/shrinker/r8integration/R8ResourceShrinkerState.java
index 2392218..d8c98c7 100644
--- a/src/resourceshrinker/java/com/android/build/shrinker/r8integration/R8ResourceShrinkerState.java
+++ b/src/resourceshrinker/java/com/android/build/shrinker/r8integration/R8ResourceShrinkerState.java
@@ -5,11 +5,16 @@
 
 import static com.android.build.shrinker.r8integration.LegacyResourceShrinker.getUtfReader;
 
+import com.android.aapt.Resources;
 import com.android.aapt.Resources.ConfigValue;
 import com.android.aapt.Resources.Entry;
+import com.android.aapt.Resources.FileReference;
 import com.android.aapt.Resources.Item;
+import com.android.aapt.Resources.Package;
 import com.android.aapt.Resources.ResourceTable;
 import com.android.aapt.Resources.Value;
+import com.android.aapt.Resources.XmlAttribute;
+import com.android.aapt.Resources.XmlElement;
 import com.android.aapt.Resources.XmlNode;
 import com.android.build.shrinker.NoDebugReporter;
 import com.android.build.shrinker.ResourceShrinkerImplKt;
@@ -26,15 +31,19 @@
 import com.android.ide.common.resources.usage.ResourceUsageModel.Resource;
 import com.android.resources.ResourceType;
 import com.android.tools.r8.FeatureSplit;
+import com.android.tools.r8.origin.Origin;
+import com.android.tools.r8.origin.PathOrigin;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import java.io.IOException;
 import java.io.InputStream;
-import java.util.Collections;
+import java.nio.file.Paths;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.IdentityHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.function.Function;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
@@ -48,18 +57,55 @@
   private Supplier<InputStream> manifestProvider;
   private final Map<String, Supplier<InputStream>> resfileProviders = new HashMap<>();
   private final Map<ResourceTable, FeatureSplit> resourceTables = new HashMap<>();
+  private ClassReferenceCallback enqueuerCallback;
+  // Both maps are built lazily in a single pass over the resource tables. Each configuration
+  // value of a resource can reference its own proto XML file: the first map keeps only the last
+  // one seen, the second keeps all of them and is the one used for tracing.
+  private Map<Integer, String> resourceIdToXmlFiles;
+  private Map<Integer, Set<String>> resourceIdToAllXmlFiles;
+  private Set<String> packageNames;
+  // Strings the enqueuer rejected as class names. The same xml tags and attribute values show up
+  // over and over, so each of them is only offered once.
+  private final Set<String> seenNonClassValues = new HashSet<>();
+  // Resource ids traced since the last call to enqueuerDone.
+  private final Set<Integer> seenResourceIds = new HashSet<>();
+
+  /** Receives strings found in traced XML files that may name a class. */
+  @FunctionalInterface
+  public interface ClassReferenceCallback {
+    /**
+     * Offers a string found in an XML file to the enqueuer.
+     *
+     * @param possibleClass a tag name or attribute value that may name a class
+     * @param xmlFileOrigin origin of the XML file the string was read from
+     * @return false if {@code possibleClass} does not name a class; such strings are not offered
+     *     again
+     */
+    boolean tryClass(String possibleClass, Origin xmlFileOrigin);
+  }
 
   public R8ResourceShrinkerState(Function<Exception, RuntimeException> errorHandler) {
     r8ResourceShrinkerModel = new R8ResourceShrinkerModel(NoDebugReporter.INSTANCE, true);
     this.errorHandler = errorHandler;
   }
 
-  public List<String> trace(int id) {
+  /**
+   * Marks the resource with the given id as reachable, scans its proto XML files for class
+   * references and then visits the referenced resources that are not yet reachable.
+   *
+   * <p>Unknown ids, and ids already traced since the last call to {@link #enqueuerDone}, are
+   * ignored.
+   */
+  public void trace(int id) {
+    if (!seenResourceIds.add(id)) {
+      return;
+    }
     Resource resource = r8ResourceShrinkerModel.getResourceStore().getResource(id);
     if (resource == null) {
-      return Collections.emptyList();
+      return;
     }
     ResourceUsageModel.markReachable(resource);
+    traceXml(id);
     if (resource.references != null) {
       for (Resource reference : resource.references) {
         if (!reference.isReachable()) {
@@ -67,7 +113,30 @@
         }
       }
     }
-    return Collections.emptyList();
+  }
+
+  /**
+   * Sets the callback that receives possible class names found while tracing XML files. It must
+   * not already be set; {@link #enqueuerDone} clears it again.
+   */
+  public void setEnqueuerCallback(ClassReferenceCallback enqueuerCallback) {
+    assert this.enqueuerCallback == null;
+    this.enqueuerCallback = enqueuerCallback;
+  }
+
+  /** Returns the names of all packages in the resource tables; computed on first use. */
+  private synchronized Set<String> getPackageNames() {
+    // TODO(b/325888516): Consider only doing this for the package corresponding to the current
+    // feature.
+    if (packageNames == null) {
+      packageNames = new HashSet<>();
+      for (ResourceTable resourceTable : resourceTables.keySet()) {
+        for (Package aPackage : resourceTable.getPackageList()) {
+          packageNames.add(aPackage.getPackageName());
+        }
+      }
+    }
+    return packageNames;
   }
 
   public void setManifestProvider(Supplier<InputStream> manifestProvider) {
@@ -144,6 +213,127 @@
     return resEntriesToKeep.build();
   }
 
+  /** Scans every proto XML file of the resource with the given id for class references. */
+  private void traceXml(int id) {
+    Set<String> xmlFiles = getResourceIdToAllXmlFiles().get(id);
+    if (xmlFiles == null) {
+      return;
+    }
+    // Each configuration of a resource can have its own XML file, and a class referenced from
+    // any of them has to be kept, so all of them are traced.
+    for (String xmlFile : xmlFiles) {
+      traceXmlFile(xmlFile);
+    }
+  }
+
+  /** Parses the given proto XML file and offers its tag names and attribute values. */
+  private void traceXmlFile(String xmlFile) {
+    Supplier<InputStream> xmlFileProvider = xmlFileProviders.get(xmlFile);
+    if (xmlFileProvider == null) {
+      throw errorHandler.apply(
+          new IllegalStateException("No input provided for resource xml file " + xmlFile));
+    }
+    try (InputStream inputStream = xmlFileProvider.get()) {
+      visitNode(XmlNode.parseFrom(inputStream), xmlFile);
+    } catch (IOException e) {
+      throw errorHandler.apply(e);
+    }
+  }
+
+  /**
+   * Offers a string read from the XML file {@code xmlName} to the enqueuer. Empty strings and
+   * strings the enqueuer already rejected as class names are skipped.
+   */
+  private void tryEnqueuerOnString(String possibleClass, String xmlName) {
+    // There are a lot of xml tags and attributes that are evaluated over and over; once one is
+    // known not to be a class, it is not offered again.
+    if (possibleClass.isEmpty() || seenNonClassValues.contains(possibleClass)) {
+      return;
+    }
+    if (!enqueuerCallback.tryClass(possibleClass, new PathOrigin(Paths.get(xmlName)))) {
+      seenNonClassValues.add(possibleClass);
+    }
+  }
+
+  /** Offers the tag name and attribute values of the node and of all its descendants. */
+  private void visitNode(XmlNode xmlNode, String xmlName) {
+    if (!xmlNode.hasElement()) {
+      // Only element nodes have a tag name, attributes or children to look at.
+      return;
+    }
+    XmlElement element = xmlNode.getElement();
+    tryEnqueuerOnString(element.getName(), xmlName);
+    for (XmlAttribute xmlAttribute : element.getAttributeList()) {
+      String value = xmlAttribute.getValue();
+      tryEnqueuerOnString(value, xmlName);
+      if (value.startsWith(".")) {
+        // Package relative name (e.g. in a context attribute): try it in every package.
+        getPackageNames().forEach(s -> tryEnqueuerOnString(s + value, xmlName));
+      }
+    }
+    element.getChildList().forEach(e -> visitNode(e, xmlName));
+  }
+
+  /**
+   * Returns a map from resource id to a proto XML file of that resource. A resource with XML
+   * files in several configurations is mapped to the last one found; tracing uses all of them.
+   */
+  public Map<Integer, String> getResourceIdToXmlFiles() {
+    if (resourceIdToXmlFiles == null) {
+      computeXmlFileMaps();
+    }
+    return resourceIdToXmlFiles;
+  }
+
+  /** Returns a map from resource id to all proto XML files of that resource. */
+  private Map<Integer, Set<String>> getResourceIdToAllXmlFiles() {
+    if (resourceIdToAllXmlFiles == null) {
+      computeXmlFileMaps();
+    }
+    return resourceIdToAllXmlFiles;
+  }
+
+  /** Builds both resource id to xml file maps from the resource tables in a single pass. */
+  private void computeXmlFileMaps() {
+    resourceIdToXmlFiles = new HashMap<>();
+    resourceIdToAllXmlFiles = new HashMap<>();
+    for (ResourceTable resourceTable : resourceTables.keySet()) {
+      for (Package packageEntry : resourceTable.getPackageList()) {
+        for (Resources.Type type : packageEntry.getTypeList()) {
+          for (Entry entry : type.getEntryList()) {
+            for (ConfigValue configValue : entry.getConfigValueList()) {
+              String xmlFile = getProtoXmlFile(configValue);
+              if (xmlFile != null) {
+                int id = ResourceTableUtilKt.toIdentifier(packageEntry, type, entry);
+                resourceIdToXmlFiles.put(id, xmlFile);
+                resourceIdToAllXmlFiles
+                    .computeIfAbsent(id, unused -> new HashSet<>())
+                    .add(xmlFile);
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /** Returns the proto XML file path the value refers to, or null if there is none. */
+  private static String getProtoXmlFile(ConfigValue configValue) {
+    if (!configValue.hasValue()) {
+      return null;
+    }
+    Value value = configValue.getValue();
+    if (!value.hasItem()) {
+      return null;
+    }
+    Item item = value.getItem();
+    if (!item.hasFile()) {
+      return null;
+    }
+    FileReference file = item.getFile();
+    return file.getType() == FileReference.Type.PROTO_XML ? file.getPath() : null;
+  }
+
   private List<Integer> getResourcesToRemove() {
     return r8ResourceShrinkerModel.getResourceStore().getResources().stream()
         .filter(r -> !r.isReachable() && !r.isPublic())
@@ -171,6 +280,23 @@
     }
   }
 
+  /**
+   * Resets the tracing state after an enqueuer run: clears the enqueuer callback and the set of
+   * traced resource ids, so the next run traces every resource again.
+   *
+   * @param isFinalTreeshaking whether the run was the final tree shaking; if so the reachability
+   *     bits are kept, otherwise they are cleared as well
+   */
+  public void enqueuerDone(boolean isFinalTreeshaking) {
+    seenResourceIds.clear();
+    enqueuerCallback = null;
+    if (isFinalTreeshaking) {
+      // The reachability bits decide what is written out from the model, so keep them.
+      return;
+    }
+    clearReachableBits();
+  }
+
   public void clearReachableBits() {
     for (Resource resource : r8ResourceShrinkerModel.getResourceStore().getResources()) {
       resource.setReachable(false);