Synthesize keep rules for recompilation

Change-Id: Idb94b5efc22ebc0965bcd7dbc097f29f8425dae1
diff --git a/src/main/java/com/android/tools/r8/R8Command.java b/src/main/java/com/android/tools/r8/R8Command.java
index e4d1496..3d9336e 100644
--- a/src/main/java/com/android/tools/r8/R8Command.java
+++ b/src/main/java/com/android/tools/r8/R8Command.java
@@ -375,6 +375,9 @@
       }
       ProguardConfiguration.Builder configurationBuilder = parser.getConfigurationBuilder();
       configurationBuilder.setForceProguardCompatibility(forceProguardCompatibility);
+      if (InternalOptions.shouldEnableKeepRuleSynthesisForRecompilation()) {
+        configurationBuilder.enableKeepRuleSynthesisForRecompilation();
+      }
 
       if (proguardConfigurationConsumer != null) {
         proguardConfigurationConsumer.accept(configurationBuilder);
diff --git a/src/main/java/com/android/tools/r8/shaking/ProguardConfiguration.java b/src/main/java/com/android/tools/r8/shaking/ProguardConfiguration.java
index 0f511d4..eb9916e 100644
--- a/src/main/java/com/android/tools/r8/shaking/ProguardConfiguration.java
+++ b/src/main/java/com/android/tools/r8/shaking/ProguardConfiguration.java
@@ -9,6 +9,7 @@
 import com.android.tools.r8.position.Position;
 import com.android.tools.r8.utils.InternalOptions.PackageObfuscationMode;
 import com.android.tools.r8.utils.Reporter;
+import com.android.tools.r8.utils.StringUtils;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Sets;
 import java.nio.file.Path;
@@ -63,6 +64,7 @@
         ProguardPathFilter.builder().disable();
     private boolean forceProguardCompatibility = false;
     private boolean overloadAggressively;
+    private boolean keepRuleSynthesisForRecompilation = false;
 
     private Builder(DexItemFactory dexItemFactory, Reporter reporter) {
       this.dexItemFactory = dexItemFactory;
@@ -273,8 +275,29 @@
       this.overloadAggressively = overloadAggressively;
     }
 
-    public ProguardConfiguration buildRaw() {
+    public void enableKeepRuleSynthesisForRecompilation() {
+      this.keepRuleSynthesisForRecompilation = true;
+    }
 
+    /**
+     * This synthesizes a set of keep rules that are necessary in order to be able to successfully
+     * recompile the generated dex files with the same keep rules.
+     */
+    public void synthesizeKeepRulesForRecompilation() {
+      List<ProguardConfigurationRule> synthesizedKeepRules = new ArrayList<>();
+      for (ProguardConfigurationRule rule : rules) {
+        ProguardConfigurationUtils.synthesizeKeepRulesForRecompilation(rule, synthesizedKeepRules);
+      }
+      if (rules.addAll(synthesizedKeepRules)) {
+        parsedConfiguration.add(
+            StringUtils.lines(
+                synthesizedKeepRules.stream()
+                    .map(ProguardClassSpecification::toString)
+                    .toArray(String[]::new)));
+      }
+    }
+
+    public ProguardConfiguration buildRaw() {
       ProguardConfiguration configuration = new ProguardConfiguration(
           String.join(System.lineSeparator(), parsedConfiguration),
           dexItemFactory,
@@ -334,6 +357,10 @@
         }));
       }
 
+      if (keepRuleSynthesisForRecompilation) {
+        synthesizeKeepRulesForRecompilation();
+      }
+
       return buildRaw();
     }
   }
diff --git a/src/main/java/com/android/tools/r8/shaking/ProguardConfigurationUtils.java b/src/main/java/com/android/tools/r8/shaking/ProguardConfigurationUtils.java
index f9846b8..7dbae8a 100644
--- a/src/main/java/com/android/tools/r8/shaking/ProguardConfigurationUtils.java
+++ b/src/main/java/com/android/tools/r8/shaking/ProguardConfigurationUtils.java
@@ -26,6 +26,14 @@
         }
       };
 
+  private static Origin synthesizedRecompilationOrigin =
+      new Origin(Origin.root()) {
+        @Override
+        public String part() {
+          return "<SYNTHESIZED_RECOMPILATION_RULE>";
+        }
+      };
+
   public static ProguardKeepRule buildDefaultInitializerKeepRule(DexClass clazz) {
     ProguardKeepRule.Builder builder = ProguardKeepRule.builder();
     builder.setOrigin(proguardCompatOrigin);
@@ -165,4 +173,21 @@
     }
     return false;
   }
+
+  public static void synthesizeKeepRulesForRecompilation(
+      ProguardConfigurationRule rule, List<ProguardConfigurationRule> synthesizedKeepRules) {
+    if (rule.hasInheritanceClassName()) {
+      ProguardTypeMatcher inheritanceClassName = rule.getInheritanceClassName();
+      synthesizedKeepRules.add(
+          ProguardKeepRule.builder()
+              .setOrigin(synthesizedRecompilationOrigin)
+              .setType(ProguardKeepRuleType.KEEP)
+              .setClassType(
+                  rule.getInheritanceIsExtends()
+                      ? ProguardClassType.CLASS
+                      : ProguardClassType.INTERFACE)
+              .setClassNames(ProguardClassNameList.singletonList(inheritanceClassName))
+              .build());
+    }
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/shaking/ProguardKeepRuleBase.java b/src/main/java/com/android/tools/r8/shaking/ProguardKeepRuleBase.java
index f67634e..7455469 100644
--- a/src/main/java/com/android/tools/r8/shaking/ProguardKeepRuleBase.java
+++ b/src/main/java/com/android/tools/r8/shaking/ProguardKeepRuleBase.java
@@ -6,6 +6,7 @@
 import com.android.tools.r8.origin.Origin;
 import com.android.tools.r8.position.Position;
 import java.util.List;
+import java.util.function.Consumer;
 
 public class ProguardKeepRuleBase extends ProguardConfigurationRule {
 
@@ -20,13 +21,19 @@
       super();
     }
 
-    public void setType(ProguardKeepRuleType type) {
+    public B setType(ProguardKeepRuleType type) {
       this.type = type;
+      return self();
     }
 
     public ProguardKeepRuleModifiers.Builder getModifiersBuilder() {
       return modifiersBuilder;
     }
+
+    public B updateModifiers(Consumer<ProguardKeepRuleModifiers.Builder> consumer) {
+      consumer.accept(getModifiersBuilder());
+      return self();
+    }
   }
 
   private final ProguardKeepRuleType type;
diff --git a/src/main/java/com/android/tools/r8/shaking/ProguardKeepRuleModifiers.java b/src/main/java/com/android/tools/r8/shaking/ProguardKeepRuleModifiers.java
index 4b9c63b..eb2acee 100644
--- a/src/main/java/com/android/tools/r8/shaking/ProguardKeepRuleModifiers.java
+++ b/src/main/java/com/android/tools/r8/shaking/ProguardKeepRuleModifiers.java
@@ -21,8 +21,9 @@
       this.allowsOptimization = allowsOptimization;
     }
 
-    public void setAllowsObfuscation(boolean allowsObfuscation) {
+    public Builder setAllowsObfuscation(boolean allowsObfuscation) {
       this.allowsObfuscation = allowsObfuscation;
+      return this;
     }
 
     public void setIncludeDescriptorClasses(boolean includeDescriptorClasses) {
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index 4e66abf..3015132 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -277,6 +277,11 @@
 
   public LineNumberOptimization lineNumberOptimization = LineNumberOptimization.ON;
 
+  public static boolean shouldEnableKeepRuleSynthesisForRecompilation() {
+    return Version.isDev()
+        && System.getProperty("com.android.tools.r8.keepRuleSynthesisForRecompilation") != null;
+  }
+
   private static Set<String> getExtensiveLoggingFilter() {
     String property = System.getProperty("com.android.tools.r8.extensiveLoggingFilter");
     if (property != null) {
diff --git a/tools/as_utils.py b/tools/as_utils.py
index 65ab470..93c910e 100644
--- a/tools/as_utils.py
+++ b/tools/as_utils.py
@@ -172,14 +172,16 @@
 # <td></td>
 # </tr>
 class ProfileReportParser(HTMLParser):
-  entered_table_row = False
-  entered_task_name_cell = False
-  entered_duration_cell = False
+  def __init__(self):
+    HTMLParser.__init__(self)
+    self.entered_table_row = False
+    self.entered_task_name_cell = False
+    self.entered_duration_cell = False
 
-  current_task_name = None
-  current_duration = None
+    self.current_task_name = None
+    self.current_duration = None
 
-  result = {}
+    self.result = {}
 
   def handle_starttag(self, tag, attrs):
     entered_table_row_before = self.entered_table_row
@@ -208,5 +210,13 @@
       if IsGradleTaskName(stripped):
         self.current_task_name = stripped
     elif self.entered_duration_cell and stripped.endswith('s'):
-      self.current_duration = float(stripped[:-1])
+      duration = stripped[:-1]
+      if 'm' in duration:
+        tmp = duration.split('m')
+        minutes = int(tmp[0])
+        seconds = float(tmp[1])
+      else:
+        minutes = 0
+        seconds = float(duration)
+      self.current_duration = 60 * minutes + seconds
     self.entered_table_row = False
diff --git a/tools/run_on_as_app.py b/tools/run_on_as_app.py
index a31456e..6c8de4b 100755
--- a/tools/run_on_as_app.py
+++ b/tools/run_on_as_app.py
@@ -231,10 +231,13 @@
         continue
 
       apk_dest = None
+
       result = {}
       try:
+        out_dir = os.path.join(checkout_dir, 'out', shrinker)
         (apk_dest, profile_dest_dir, proguard_config_file) = \
-            BuildAppWithShrinker(app, config, shrinker, checkout_dir, options)
+            BuildAppWithShrinker(app, config, shrinker, checkout_dir, out_dir,
+                options)
         dex_size = ComputeSizeOfDexFilesInApk(apk_dest)
         result['apk_dest'] = apk_dest,
         result['build_status'] = 'success'
@@ -258,13 +261,42 @@
 
         if 'r8' in shrinker and options.r8_compilation_steps > 1:
           recompilation_results = []
+
+          # Build app with gradle using -D...keepRuleSynthesisForRecompilation=
+          # true.
+          out_dir = os.path.join(checkout_dir, 'out', shrinker + '-1')
+          extra_env_vars = {
+            'JAVA_OPTS': ' '.join([
+              '-ea:com.android.tools.r8...',
+              '-Dcom.android.tools.r8.keepRuleSynthesisForRecompilation=true'
+            ])
+          }
+          (apk_dest, profile_dest_dir, ext_proguard_config_file) = \
+              BuildAppWithShrinker(app, config, shrinker, checkout_dir, out_dir,
+                  options, extra_env_vars)
+          dex_size = ComputeSizeOfDexFilesInApk(apk_dest)
+          recompilation_result = {
+            'apk_dest': apk_dest,
+            'build_status': 'success',
+            'dex_size': ComputeSizeOfDexFilesInApk(apk_dest),
+            'monkey_status': 'skipped'
+          }
+          recompilation_results.append(recompilation_result)
+
+          # Sanity check that keep rules have changed.
+          with open(ext_proguard_config_file) as new:
+            with open(proguard_config_file) as old:
+              assert(sum(1 for line in new) < sum(1 for line in old))
+
+          # Now rebuild generated apk.
           previous_apk = apk_dest
           for i in range(1, options.r8_compilation_steps):
             try:
               recompiled_apk_dest = os.path.join(
                   checkout_dir, 'out', shrinker, 'app-release-{}.apk'.format(i))
               RebuildAppWithShrinker(
-                  previous_apk, recompiled_apk_dest, proguard_config_file, shrinker)
+                  previous_apk, recompiled_apk_dest, ext_proguard_config_file,
+                  shrinker)
               recompilation_result = {
                 'apk_dest': recompiled_apk_dest,
                 'build_status': 'success',
@@ -285,7 +317,8 @@
 
   return result_per_shrinker
 
-def BuildAppWithShrinker(app, config, shrinker, checkout_dir, options):
+def BuildAppWithShrinker(
+    app, config, shrinker, checkout_dir, out_dir, options, env_vars=None):
   print()
   print('Building {} with {}'.format(app, shrinker))
 
@@ -299,19 +332,21 @@
   archives_base_name = config.get('archives_base_name', app_module)
   flavor = config.get('flavor')
 
-  out = os.path.join(checkout_dir, 'out', shrinker)
-  if not os.path.exists(out):
-    os.makedirs(out)
+  if not os.path.exists(out_dir):
+    os.makedirs(out_dir)
 
   # Set -printconfiguration in Proguard rules.
   proguard_config_dest = os.path.abspath(
-      os.path.join(out, 'proguard-rules.pro'))
+      os.path.join(out_dir, 'proguard-rules.pro'))
   as_utils.SetPrintConfigurationDirective(
       app, config, checkout_dir, proguard_config_dest)
 
   env = os.environ.copy()
   env['ANDROID_HOME'] = android_home
-  env['JAVA_OPTS'] = '-ea'
+  env['JAVA_OPTS'] = '-ea:com.android.tools.r8...'
+  if env_vars:
+    env.update(env_vars)
+
   releaseTarget = config.get('releaseTarget')
   if not releaseTarget:
     releaseTarget = app_module + ':' + 'assemble' + (
@@ -369,16 +404,16 @@
           keystore_password)
 
   if os.path.isfile(signed_apk):
-    apk_dest = os.path.join(out, signed_apk_name)
+    apk_dest = os.path.join(out_dir, signed_apk_name)
     as_utils.MoveFile(signed_apk, apk_dest)
   else:
-    apk_dest = os.path.join(out, unsigned_apk_name)
+    apk_dest = os.path.join(out_dir, unsigned_apk_name)
     as_utils.MoveFile(unsigned_apk, apk_dest)
 
   assert IsBuiltWithR8(apk_dest) == ('r8' in shrinker), (
       'Unexpected marker in generated APK for {}'.format(shrinker))
 
-  profile_dest_dir = os.path.join(out, 'profile')
+  profile_dest_dir = os.path.join(out_dir, 'profile')
   as_utils.MoveProfileReportTo(profile_dest_dir, stdout)
 
   return (apk_dest, profile_dest_dir, proguard_config_dest)
@@ -391,10 +426,11 @@
   api = 28 # TODO(christofferqa): Should be the one from build.gradle
   android_jar = os.path.join(utils.REPO_ROOT, utils.ANDROID_JAR.format(api=api))
   r8_jar = utils.R8LIB_JAR if IsMinifiedR8(shrinker) else utils.R8_JAR
-  zip_dest = apk_dest[:-3] + '.zip'
+  zip_dest = apk_dest[:-4] + '.zip'
 
-  cmd = ['java', '-ea', '-jar', r8_jar, '--release', '--pg-conf',
-      proguard_config_file, '--lib', android_jar, '--output', zip_dest, apk]
+  cmd = ['java', '-ea:com.android.tools.r8...', '-cp', r8_jar,
+      'com.android.tools.r8.R8', '--release', '--pg-conf', proguard_config_file,
+      '--lib', android_jar, '--output', zip_dest, apk]
   utils.PrintCmd(cmd)
 
   subprocess.check_output(cmd)
@@ -423,78 +459,90 @@
 
   try:
     stdout = subprocess.check_output(cmd)
+    succeeded = (
+        'Events injected: {}'.format(number_of_events_to_generate) in stdout)
   except subprocess.CalledProcessError as e:
-    return False
+    succeeded = False
 
-  return 'Events injected: {}'.format(number_of_events_to_generate) in stdout
+  UninstallApkOnEmulator(app, config)
 
-def LogResults(result_per_shrinker_per_app, options):
+  return succeeded
+
+def LogResultsForApps(result_per_shrinker_per_app, options):
   for app, result_per_shrinker in result_per_shrinker_per_app.iteritems():
-    print(app + ':')
+    LogResultsForApp(app, result_per_shrinker, options)
 
-    if result_per_shrinker.get('status') != 'success':
-      error_message = result_per_shrinker.get('error_message')
-      print('  skipped ({})'.format(error_message))
+def LogResultsForApp(app, result_per_shrinker, options):
+  print(app + ':')
+
+  if result_per_shrinker.get('status') != 'success':
+    error_message = result_per_shrinker.get('error_message')
+    print('  skipped ({})'.format(error_message))
+    return
+
+  proguard_result = result_per_shrinker.get('proguard', {})
+  proguard_dex_size = float(proguard_result.get('dex_size', -1))
+  proguard_duration = sum(proguard_result.get('profile', {}).values())
+
+  for shrinker in SHRINKERS:
+    if shrinker not in result_per_shrinker:
       continue
-
-    proguard_result = result_per_shrinker.get('proguard', {})
-    proguard_dex_size = float(proguard_result.get('dex_size', -1))
-    proguard_duration = sum(proguard_result.get('profile', {}).values())
-
-    for shrinker in SHRINKERS:
-      if shrinker not in result_per_shrinker:
-        continue
-      result = result_per_shrinker.get(shrinker)
-      build_status = result.get('build_status')
-      if build_status != 'success':
-        warn('  {}: {}'.format(shrinker, build_status))
+    result = result_per_shrinker.get(shrinker)
+    build_status = result.get('build_status')
+    if build_status != 'success':
+      warn('  {}: {}'.format(shrinker, build_status))
+    else:
+      print('  {}:'.format(shrinker))
+      dex_size = result.get('dex_size')
+      msg = '    dex size: {}'.format(dex_size)
+      if dex_size != proguard_dex_size and proguard_dex_size >= 0:
+        msg = '{} ({}, {})'.format(
+            msg, dex_size - proguard_dex_size,
+            PercentageDiffAsString(proguard_dex_size, dex_size))
+        success(msg) if dex_size < proguard_dex_size else warn(msg)
       else:
-        print('  {}:'.format(shrinker))
-        dex_size = result.get('dex_size')
-        msg = '    dex size: {}'.format(dex_size)
-        if dex_size != proguard_dex_size and proguard_dex_size >= 0:
-          msg = '{} ({}, {})'.format(
-              msg, dex_size - proguard_dex_size,
-              PercentageDiffAsString(proguard_dex_size, dex_size))
-          success(msg) if dex_size < proguard_dex_size else warn(msg)
-        else:
-          print(msg)
+        print(msg)
 
-        profile = result.get('profile')
-        duration = sum(profile.values())
-        msg = '    performance: {}s'.format(duration)
-        if duration != proguard_duration and proguard_duration > 0:
-          msg = '{} ({}s, {})'.format(
-              msg, duration - proguard_duration,
-              PercentageDiffAsString(proguard_duration, duration))
-          success(msg) if duration < proguard_duration else warn(msg)
-        else:
-          print(msg)
-        if len(profile) >= 2:
-          for task_name, task_duration in profile.iteritems():
-            print('      {}: {}s'.format(task_name, task_duration))
+      profile = result.get('profile')
+      duration = sum(profile.values())
+      msg = '    performance: {}s'.format(duration)
+      if duration != proguard_duration and proguard_duration > 0:
+        msg = '{} ({}s, {})'.format(
+            msg, duration - proguard_duration,
+            PercentageDiffAsString(proguard_duration, duration))
+        success(msg) if duration < proguard_duration else warn(msg)
+      else:
+        print(msg)
+      if len(profile) >= 2:
+        for task_name, task_duration in profile.iteritems():
+          print('      {}: {}s'.format(task_name, task_duration))
 
-        if options.monkey:
-          monkey_status = result.get('monkey_status')
-          if monkey_status != 'success':
-            warn('    monkey: {}'.format(monkey_status))
-          else:
-            success('    monkey: {}'.format(monkey_status))
-        recompilation_results = result.get('recompilation_results', [])
-        i = 1
-        for recompilation_result in recompilation_results:
-          build_status = recompilation_result.get('build_status')
-          if build_status != 'success':
-            print('    recompilation #{}: {}'.format(i, build_status))
-          else:
-            dex_size = recompilation_result.get('dex_size')
-            print('    recompilation #{}'.format(i))
-            print('      dex size: {}'.format(dex_size))
-            if options.monkey:
-              monkey_status = recompilation_result.get('monkey_status')
-              msg = '      monkey: {}'.format(monkey_status)
-              success(msg) if monkey_status == 'success' else warn(msg)
-          i += 1
+      if options.monkey:
+        monkey_status = result.get('monkey_status')
+        if monkey_status != 'success':
+          warn('    monkey: {}'.format(monkey_status))
+        else:
+          success('    monkey: {}'.format(monkey_status))
+      recompilation_results = result.get('recompilation_results', [])
+      i = 0
+      for recompilation_result in recompilation_results:
+        build_status = recompilation_result.get('build_status')
+        if build_status != 'success':
+          print('    recompilation #{}: {}'.format(i, build_status))
+        else:
+          dex_size = recompilation_result.get('dex_size')
+          print('    recompilation #{}'.format(i))
+          print('      dex size: {}'.format(dex_size))
+          if options.monkey:
+            monkey_status = recompilation_result.get('monkey_status')
+            msg = '      monkey: {}'.format(monkey_status)
+            if monkey_status == 'success':
+              success(msg)
+            elif monkey_status == 'skipped':
+              print(msg)
+            else:
+              warn(msg)
+        i += 1
 
 def ParseOptions(argv):
   result = optparse.OptionParser()
@@ -572,7 +620,7 @@
         result_per_shrinker_per_app[app] = GetResultsForApp(
             app, config, options)
 
-  LogResults(result_per_shrinker_per_app, options)
+  LogResultsForApps(result_per_shrinker_per_app, options)
 
 def success(message):
   CGREEN = '\033[32m'