Support passing class/field/method of interest to disasm.py

Change-Id: I372c7bb4ea4c37cad8282ba55f7fdad5f358c92b
diff --git a/src/main/java/com/android/tools/r8/Disassemble.java b/src/main/java/com/android/tools/r8/Disassemble.java
index a47366e..7ddbafb 100644
--- a/src/main/java/com/android/tools/r8/Disassemble.java
+++ b/src/main/java/com/android/tools/r8/Disassemble.java
@@ -11,9 +11,15 @@
 import com.android.tools.r8.graph.SmaliWriter;
 import com.android.tools.r8.naming.ClassNameMapper;
 import com.android.tools.r8.origin.CommandLineOrigin;
+import com.android.tools.r8.references.ClassReference;
+import com.android.tools.r8.references.FieldReference;
+import com.android.tools.r8.references.MethodReference;
+import com.android.tools.r8.references.Reference;
 import com.android.tools.r8.utils.AndroidApp;
 import com.android.tools.r8.utils.ConsumerUtils;
+import com.android.tools.r8.utils.FieldReferenceUtils;
 import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.MethodReferenceUtils;
 import com.android.tools.r8.utils.StringDiagnostic;
 import com.android.tools.r8.utils.ThreadUtils;
 import com.android.tools.r8.utils.timing.Timing;
@@ -26,6 +32,8 @@
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.function.Consumer;
@@ -45,6 +53,10 @@
       private boolean noCode = false;
       private boolean useIr;
 
+      private Set<ClassReference> classReferences = null;
+      private Set<FieldReference> fieldReferences = null;
+      private Set<MethodReference> methodReferences = null;
+
       @Override
       Builder self() {
         return this;
@@ -84,6 +96,30 @@
         return this;
       }
 
+      public Builder addClassReference(ClassReference classReference) {
+        if (classReferences == null) {
+          classReferences = new HashSet<>();
+        }
+        classReferences.add(classReference);
+        return this;
+      }
+
+      public Builder addFieldReference(FieldReference fieldReference) {
+        if (fieldReferences == null) {
+          fieldReferences = new HashSet<>();
+        }
+        fieldReferences.add(fieldReference);
+        return this;
+      }
+
+      public Builder addMethodReference(MethodReference methodReference) {
+        if (methodReferences == null) {
+          methodReferences = new HashSet<>();
+        }
+        methodReferences.add(methodReference);
+        return this;
+      }
+
       @Override
       protected DisassembleCommand makeCommand() {
         // If printing versions ignore everything else.
@@ -97,7 +133,10 @@
             allInfo,
             useSmali,
             useIr,
-            noCode);
+            noCode,
+            classReferences,
+            fieldReferences,
+            methodReferences);
       }
     }
 
@@ -112,6 +151,12 @@
             + "  --pg-map <file>             # Proguard map <file> for mapping names.\n"
             + "  --pg-map-charset <charset>  # Charset for Proguard map file.\n"
             + "  --output                    # Specify a file or directory to write to.\n"
+            + "  --class <descriptor>        # Only disassemble the given class "
+            + "(e.g., Lcom/example/Class;).\n"
+            + "  --field <descriptor>        # Only disassemble the given field "
+            + "(e.g., Lcom/example/Class;->method()V).\n"
+            + "  --method <descriptor>       # Only disassemble the given method "
+            + "(e.g., Lcom/example/Class;->field:I).\n"
             + "  --version                   # Print the version of r8.\n"
             + "  --help                      # Print this message.";
 
@@ -119,6 +164,9 @@
     private final boolean useSmali;
     private final boolean useIr;
     private final boolean noCode;
+    private final Set<ClassReference> classReferences;
+    private final Set<FieldReference> fieldReferences;
+    private final Set<MethodReference> methodReferences;
 
     public static Builder builder() {
       return new Builder();
@@ -164,6 +212,15 @@
         } else if (arg.equals("--output")) {
           String outputPath = args[++i];
           builder.setOutputPath(Paths.get(outputPath));
+        } else if (arg.equals("--class")) {
+          String nextArg = args[++i];
+          builder.addClassReference(Reference.classFromDescriptor(nextArg));
+        } else if (arg.equals("--field")) {
+          String nextArg = args[++i];
+          builder.addFieldReference(FieldReferenceUtils.parseSmaliString(nextArg));
+        } else if (arg.equals("--method")) {
+          String nextArg = args[++i];
+          builder.addMethodReference(MethodReferenceUtils.parseSmaliString(nextArg));
         } else {
           if (arg.startsWith("--")) {
             builder.getReporter().error(new StringDiagnostic("Unknown option: " + arg,
@@ -181,7 +238,10 @@
         boolean allInfo,
         boolean useSmali,
         boolean useIr,
-        boolean noCode) {
+        boolean noCode,
+        Set<ClassReference> classReferences,
+        Set<FieldReference> fieldReferences,
+        Set<MethodReference> methodReferences) {
       super(inputApp);
       this.outputPath = outputPath;
       this.proguardMap = proguardMap;
@@ -189,6 +249,9 @@
       this.useSmali = useSmali;
       this.useIr = useIr;
       this.noCode = noCode;
+      this.classReferences = classReferences;
+      this.fieldReferences = fieldReferences;
+      this.methodReferences = methodReferences;
     }
 
     private DisassembleCommand(boolean printHelp, boolean printVersion) {
@@ -199,6 +262,9 @@
       useSmali = false;
       useIr = false;
       noCode = false;
+      classReferences = null;
+      fieldReferences = null;
+      methodReferences = null;
     }
 
     public Path getOutputPath() {
@@ -217,6 +283,18 @@
       return noCode;
     }
 
+    public Set<ClassReference> getClassReferences() {
+      return classReferences;
+    }
+
+    public Set<FieldReference> getFieldReferences() {
+      return fieldReferences;
+    }
+
+    public Set<MethodReference> getMethodReferences() {
+      return methodReferences;
+    }
+
     @Override
     InternalOptions getInternalOptions() {
       InternalOptions internal = new InternalOptions();
@@ -283,9 +361,21 @@
               .read(command.proguardMap, executor);
       DexByteCodeWriter writer =
           command.useSmali()
-              ? new SmaliWriter(application, options)
+              ? new SmaliWriter(
+                  application,
+                  options,
+                  command.getClassReferences(),
+                  command.getFieldReferences(),
+                  command.getMethodReferences())
               : new AssemblyWriter(
-                  application, options, command.allInfo, command.useIr(), !command.noCode());
+                  application,
+                  options,
+                  command.allInfo,
+                  command.useIr(),
+                  !command.noCode(),
+                  command.getClassReferences(),
+                  command.getFieldReferences(),
+                  command.getMethodReferences());
       if (outputWriter.extractMarkers()) {
         writer.writeMarkers(
             outputWriter.outputStreamProvider(application.getProguardMap()).get(null));
diff --git a/src/main/java/com/android/tools/r8/graph/AssemblyWriter.java b/src/main/java/com/android/tools/r8/graph/AssemblyWriter.java
index e64c151..40c2ff0 100644
--- a/src/main/java/com/android/tools/r8/graph/AssemblyWriter.java
+++ b/src/main/java/com/android/tools/r8/graph/AssemblyWriter.java
@@ -15,6 +15,9 @@
 import com.android.tools.r8.ir.optimize.info.OptimizationFeedbackIgnore;
 import com.android.tools.r8.kotlin.Kotlin;
 import com.android.tools.r8.kotlin.KotlinMetadataWriter;
+import com.android.tools.r8.references.ClassReference;
+import com.android.tools.r8.references.FieldReference;
+import com.android.tools.r8.references.MethodReference;
 import com.android.tools.r8.synthesis.SyntheticItems.GlobalSyntheticsStrategy;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.RetracerForCodePrinting;
@@ -24,6 +27,7 @@
 import java.io.BufferedReader;
 import java.io.PrintStream;
 import java.io.StringReader;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 public class AssemblyWriter extends DexByteCodeWriter {
@@ -44,7 +48,19 @@
       boolean allInfo,
       boolean writeIR,
       boolean writeCode) {
-    super(application, options);
+    this(application, options, allInfo, writeIR, writeCode, null, null, null);
+  }
+
+  public AssemblyWriter(
+      DexApplication application,
+      InternalOptions options,
+      boolean allInfo,
+      boolean writeIR,
+      boolean writeCode,
+      Set<ClassReference> classReferences,
+      Set<FieldReference> fieldReferences,
+      Set<MethodReference> methodReferences) {
+    super(application, options, classReferences, fieldReferences, methodReferences);
     this.compilationContext = CompilationContext.createInitialContext(options);
     this.writeAllClassInfo = allInfo;
     this.writeFields = allInfo;
@@ -136,7 +152,7 @@
 
   @Override
   void writeField(DexEncodedField field, PrintStream ps) {
-    if (writeFields) {
+    if (writeFields && shouldWriteField(field)) {
       writeAnnotations(null, field.annotations(), ps);
       ps.print(field.accessFlags + " ");
       ps.print(retracer.toSourceString(field.getReference()));
@@ -160,6 +176,9 @@
 
   @Override
   void writeMethod(ProgramMethod method, PrintStream ps) {
+    if (!shouldWriteMethod(method.getDefinition())) {
+      return;
+    }
     DexEncodedMethod definition = method.getDefinition();
     ps.println("#");
     ps.println("# Method: '" + retracer.toSourceString(definition.getReference()) + "':");
diff --git a/src/main/java/com/android/tools/r8/graph/DexByteCodeWriter.java b/src/main/java/com/android/tools/r8/graph/DexByteCodeWriter.java
index be0c645..cd57175 100644
--- a/src/main/java/com/android/tools/r8/graph/DexByteCodeWriter.java
+++ b/src/main/java/com/android/tools/r8/graph/DexByteCodeWriter.java
@@ -5,6 +5,9 @@
 
 import com.android.tools.r8.dex.Marker;
 import com.android.tools.r8.naming.ClassNameMapper;
+import com.android.tools.r8.references.ClassReference;
+import com.android.tools.r8.references.FieldReference;
+import com.android.tools.r8.references.MethodReference;
 import com.android.tools.r8.utils.DescriptorUtils;
 import com.android.tools.r8.utils.InternalOptions;
 import java.io.File;
@@ -13,6 +16,7 @@
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.Collection;
+import java.util.Set;
 import java.util.function.Consumer;
 
 public abstract class DexByteCodeWriter {
@@ -24,10 +28,21 @@
   final DexApplication application;
   final InternalOptions options;
 
-  DexByteCodeWriter(DexApplication application,
-      InternalOptions options) {
+  final Set<ClassReference> classReferences;
+  final Set<FieldReference> fieldReferences;
+  final Set<MethodReference> methodReferences;
+
+  DexByteCodeWriter(
+      DexApplication application,
+      InternalOptions options,
+      Set<ClassReference> classReferences,
+      Set<FieldReference> fieldReferences,
+      Set<MethodReference> methodReferences) {
     this.application = application;
     this.options = options;
+    this.classReferences = classReferences;
+    this.fieldReferences = fieldReferences;
+    this.methodReferences = methodReferences;
   }
 
   private static void ensureParentExists(Path path) throws IOException {
@@ -65,7 +80,7 @@
       throws IOException {
     Iterable<DexProgramClass> classes = application.classesWithDeterministicOrder();
     for (DexProgramClass clazz : classes) {
-      if (anyMethodMatches(clazz)) {
+      if (shouldWriteClass(clazz)) {
         PrintStream ps = outputStreamProvider.get(clazz);
         try {
           writeClass(clazz, ps);
@@ -76,11 +91,6 @@
     }
   }
 
-  private boolean anyMethodMatches(DexClass clazz) {
-    return !options.hasMethodsFilter()
-        || clazz.getMethodCollection().hasMethods(options::methodMatchesFilter);
-  }
-
   private void writeClass(DexProgramClass clazz, PrintStream ps) {
     writeClassHeader(clazz, ps);
     writeFieldsHeader(clazz, ps);
@@ -115,4 +125,58 @@
   }
 
   abstract void writeClassFooter(DexProgramClass clazz, PrintStream ps);
+
+  boolean shouldWriteClass(DexClass clazz) {
+    if (classReferences == null && fieldReferences == null && methodReferences == null) {
+      return true;
+    }
+    if (classReferences != null && classReferences.contains(clazz.getClassReference())) {
+      return true;
+    }
+    if (fieldReferences != null
+        && clazz
+            .fields(field -> fieldReferences.contains(field.getReference().asFieldReference()))
+            .iterator()
+            .hasNext()) {
+      return true;
+    }
+    if (methodReferences != null
+        && clazz
+            .methods(method -> methodReferences.contains(method.getReference().asMethodReference()))
+            .iterator()
+            .hasNext()) {
+      return true;
+    }
+    return false;
+  }
+
+  boolean shouldWriteField(DexEncodedField field) {
+    if (classReferences == null && fieldReferences == null && methodReferences == null) {
+      return true;
+    }
+    if (classReferences != null
+        && classReferences.contains(field.getHolderType().asClassReference())) {
+      return true;
+    }
+    if (fieldReferences != null
+        && fieldReferences.contains(field.getReference().asFieldReference())) {
+      return true;
+    }
+    return false;
+  }
+
+  boolean shouldWriteMethod(DexEncodedMethod method) {
+    if (classReferences == null && fieldReferences == null && methodReferences == null) {
+      return true;
+    }
+    if (classReferences != null
+        && classReferences.contains(method.getHolderType().asClassReference())) {
+      return true;
+    }
+    if (methodReferences != null
+        && methodReferences.contains(method.getReference().asMethodReference())) {
+      return true;
+    }
+    return false;
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/graph/SmaliWriter.java b/src/main/java/com/android/tools/r8/graph/SmaliWriter.java
index 2229f85..0c464cb 100644
--- a/src/main/java/com/android/tools/r8/graph/SmaliWriter.java
+++ b/src/main/java/com/android/tools/r8/graph/SmaliWriter.java
@@ -6,6 +6,9 @@
 
 import com.android.tools.r8.dex.ApplicationReader;
 import com.android.tools.r8.errors.CompilationError;
+import com.android.tools.r8.references.ClassReference;
+import com.android.tools.r8.references.FieldReference;
+import com.android.tools.r8.references.MethodReference;
 import com.android.tools.r8.utils.AndroidApp;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.RetracerForCodePrinting;
@@ -14,11 +17,17 @@
 import java.io.IOException;
 import java.io.PrintStream;
 import java.nio.charset.StandardCharsets;
+import java.util.Set;
 
 public class SmaliWriter extends DexByteCodeWriter {
 
-  public SmaliWriter(DexApplication application, InternalOptions options) {
-    super(application, options);
+  public SmaliWriter(
+      DexApplication application,
+      InternalOptions options,
+      Set<ClassReference> classReferences,
+      Set<FieldReference> fieldReferences,
+      Set<MethodReference> methodReferences) {
+    super(application, options, classReferences, fieldReferences, methodReferences);
   }
 
   /** Return smali source for the application code. */
@@ -27,7 +36,7 @@
     try (PrintStream ps = new PrintStream(os)) {
       DexApplication dexApplication =
           new ApplicationReader(application, options, Timing.empty()).read();
-      SmaliWriter writer = new SmaliWriter(dexApplication, options);
+      SmaliWriter writer = new SmaliWriter(dexApplication, options, null, null, null);
       writer.write(ps);
     } catch (IOException e) {
       throw new CompilationError("Failed to generate smali sting", e);
diff --git a/src/main/java/com/android/tools/r8/utils/FieldReferenceUtils.java b/src/main/java/com/android/tools/r8/utils/FieldReferenceUtils.java
index d7b0439..105cbac 100644
--- a/src/main/java/com/android/tools/r8/utils/FieldReferenceUtils.java
+++ b/src/main/java/com/android/tools/r8/utils/FieldReferenceUtils.java
@@ -11,6 +11,7 @@
 import com.android.tools.r8.references.FieldReference;
 import com.android.tools.r8.references.MethodReference;
 import com.android.tools.r8.references.Reference;
+import com.android.tools.r8.references.TypeReference;
 import java.util.Comparator;
 
 public class FieldReferenceUtils {
@@ -57,6 +58,33 @@
     return COMPARATOR;
   }
 
+  public static FieldReference parseSmaliString(String classAndFieldDescriptor) {
+    int arrowStartIndex = classAndFieldDescriptor.indexOf("->");
+    if (arrowStartIndex >= 0) {
+      return parseSmaliString(classAndFieldDescriptor, arrowStartIndex);
+    }
+    return null;
+  }
+
+  public static FieldReference parseSmaliString(
+      String classAndFieldDescriptor, int arrowStartIndex) {
+    String classDescriptor = classAndFieldDescriptor.substring(0, arrowStartIndex);
+    ClassReference fieldHolder = ClassReferenceUtils.parseClassDescriptor(classDescriptor);
+    if (fieldHolder == null) {
+      return null;
+    }
+    int fieldNameStartIndex = arrowStartIndex + 2;
+    String fieldNameAndType = classAndFieldDescriptor.substring(fieldNameStartIndex);
+    int fieldNameEndIndex = fieldNameAndType.indexOf(':');
+    if (fieldNameEndIndex <= 0) {
+      return null;
+    }
+    String fieldName = fieldNameAndType.substring(0, fieldNameEndIndex);
+    String fieldTypeDescriptor = fieldNameAndType.substring(fieldNameEndIndex + 1);
+    TypeReference fieldType = Reference.returnTypeFromDescriptor(fieldTypeDescriptor);
+    return Reference.field(fieldHolder, fieldName, fieldType);
+  }
+
   public static String toSourceString(FieldReference fieldReference) {
     return fieldReference.getFieldType().getTypeName()
         + " "