Parallelize run_on_app_dump.py

Bug: b/297302759
Change-Id: I4d14ab3b4001ad37823446069e7c8eb335b635ad
diff --git a/tools/compiledump.py b/tools/compiledump.py
index 39315c9..1587e7d 100755
--- a/tools/compiledump.py
+++ b/tools/compiledump.py
@@ -243,13 +243,14 @@
   if os.path.isdir(dump):
     return Dump(dump)
   dump_file = zipfile.ZipFile(os.path.abspath(dump), 'r')
-  with utils.ChangedWorkingDirectory(temp, quiet=True):
-    if override or not os.path.isfile('r8-version'):
-      dump_file.extractall()
-      if not os.path.isfile('r8-version'):
-        error("Did not extract into %s. Either the zip file is invalid or the "
-              "dump is missing files" % temp)
-    return Dump(temp)
+  r8_version_file = os.path.join(temp, 'r8-version')
+
+  if override or not os.path.isfile(r8_version_file):
+    dump_file.extractall(temp)
+    if not os.path.isfile(r8_version_file):
+      error("Did not extract into %s. Either the zip file is invalid or the "
+            "dump is missing files" % temp)
+  return Dump(temp)
 
 def determine_build_properties(args, dump):
   build_properties = {}
@@ -463,7 +464,7 @@
 def is_hash(version):
   return len(version) == 40
 
-def run1(out, args, otherargs, jdkhome=None):
+def run1(out, args, otherargs, jdkhome=None, worker_id=None):
   jvmargs = []
   compilerargs = []
   for arg in otherargs:
@@ -584,7 +585,7 @@
     if args.threads:
       cmd.extend(['--threads', args.threads])
     cmd.extend(compilerargs)
-    utils.PrintCmd(cmd)
+    utils.PrintCmd(cmd, worker_id=worker_id)
     try:
       print(subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode('utf-8'))
       return 0
diff --git a/tools/run_on_app_dump.py b/tools/run_on_app_dump.py
index a814e95..44239b8 100755
--- a/tools/run_on_app_dump.py
+++ b/tools/run_on_app_dump.py
@@ -18,6 +18,8 @@
 import compiledump
 import gradle
 import jdk
+import thread_utils
+from thread_utils import print_thread
 import update_prebuilds_in_android
 import utils
 
@@ -509,11 +511,15 @@
 def get_r8_jar(options, temp_dir, shrinker):
   if (options.version == 'source'):
     return None
-  return os.path.join(
-      temp_dir, 'r8lib.jar' if is_minified_r8(shrinker) else 'r8.jar')
+  jar = os.path.abspath(
+      os.path.join(
+          temp_dir,
+          '..',
+          'r8lib.jar' if is_minified_r8(shrinker) else 'r8.jar'))
+  return jar
 
 
-def get_results_for_app(app, options, temp_dir):
+def get_results_for_app(app, options, temp_dir, worker_id):
   app_folder = app.folder if app.folder else app.name + "_" + app.revision
   # Golem extraction will extract to the basename under the benchmarks dir.
   app_location = os.path.basename(app_folder) if options.golem else app_folder
@@ -532,23 +538,23 @@
   result = {}
   result['status'] = 'success'
   result_per_shrinker = build_app_with_shrinkers(
-    app, options, temp_dir, app_dir)
+    app, options, temp_dir, app_dir, worker_id=worker_id)
   for shrinker, shrinker_result in result_per_shrinker.items():
     result[shrinker] = shrinker_result
   return result
 
 
-def build_app_with_shrinkers(app, options, temp_dir, app_dir):
+def build_app_with_shrinkers(app, options, temp_dir, app_dir, worker_id):
   result_per_shrinker = {}
   for shrinker in options.shrinker:
     results = []
     build_app_and_run_with_shrinker(
-      app, options, temp_dir, app_dir, shrinker, results)
+      app, options, temp_dir, app_dir, shrinker, results, worker_id=worker_id)
     result_per_shrinker[shrinker] = results
   if len(options.apps) > 1:
-    print('')
-    log_results_for_app(app, result_per_shrinker, options)
-    print('')
+    print_thread('', worker_id)
+    log_results_for_app(app, result_per_shrinker, options, worker_id=worker_id)
+    print_thread('', worker_id)
 
   return result_per_shrinker
 
@@ -558,21 +564,27 @@
 
 
 def build_app_and_run_with_shrinker(app, options, temp_dir, app_dir, shrinker,
-                                    results):
-  print('[{}] Building {} with {}'.format(
-    datetime.now().strftime("%H:%M:%S"),
-    app.name,
-    shrinker))
-  print('To compile locally: '
-        'tools/run_on_app_dump.py --shrinker {} --r8-compilation-steps {} '
-        '--app {} --minify {} --optimize {} --shrink {}'.format(
-    shrinker,
-    options.r8_compilation_steps,
-    app.name,
-    options.minify,
-    options.optimize,
-    options.shrink))
-  print('HINT: use --shrinker r8-nolib --no-build if you have a local R8.jar')
+                                    results, worker_id):
+  print_thread(
+      '[{}] Building {} with {}'.format(
+          datetime.now().strftime("%H:%M:%S"),
+          app.name,
+          shrinker),
+      worker_id)
+  print_thread(
+      'To compile locally: '
+          'tools/run_on_app_dump.py --shrinker {} --r8-compilation-steps {} '
+          '--app {} --minify {} --optimize {} --shrink {}'.format(
+              shrinker,
+              options.r8_compilation_steps,
+              app.name,
+              options.minify,
+              options.optimize,
+              options.shrink),
+      worker_id)
+  print_thread(
+      'HINT: use --shrinker r8-nolib --no-build if you have a local R8.jar',
+      worker_id)
   recomp_jar = None
   status = 'success'
   if options.r8_compilation_steps < 1:
@@ -581,14 +593,16 @@
   for compilation_step in range(0, compilation_steps):
     if status != 'success':
       break
-    print('Compiling {} of {}'.format(compilation_step + 1, compilation_steps))
+    print_thread(
+        'Compiling {} of {}'.format(compilation_step + 1, compilation_steps),
+        worker_id)
     result = {}
     try:
       start = time.time()
       (app_jar, mapping, new_recomp_jar) = \
         build_app_with_shrinker(
           app, options, temp_dir, app_dir, shrinker, compilation_step,
-          compilation_steps, recomp_jar)
+          compilation_steps, recomp_jar, worker_id=worker_id)
       end = time.time()
       dex_size = compute_size_of_dex_files_in_package(app_jar)
       result['build_status'] = 'success'
@@ -605,7 +619,7 @@
     except Exception as e:
       warn('Failed to build {} with {}'.format(app.name, shrinker))
       if e:
-        print('Error: ' + str(e))
+        print_thread('Error: ' + str(e), worker_id)
       result['build_status'] = 'failed'
       status = 'failed'
 
@@ -666,7 +680,7 @@
 
 def build_app_with_shrinker(app, options, temp_dir, app_dir, shrinker,
                             compilation_step_index, compilation_steps,
-                            prev_recomp_jar):
+                            prev_recomp_jar, worker_id):
   def config_files_consumer(files):
     for file in files:
       compiledump.clean_config(file, options)
@@ -675,7 +689,7 @@
     'dump': dump_for_app(app_dir, app),
     'r8_jar': get_r8_jar(options, temp_dir, shrinker),
     'r8_flags': options.r8_flags,
-    'ea': False if options.disable_assertions else True,
+    'ea': not options.disable_assertions,
     'version': options.version,
     'compiler': 'r8full' if is_full_r8(shrinker) else 'r8',
     'debug_agent': options.debug_agent,
@@ -696,7 +710,8 @@
   recomp_jar = None
   jdkhome = get_jdk_home(options, app)
   with utils.TempDir() as compile_temp_dir:
-    compile_result = compiledump.run1(compile_temp_dir, args, [], jdkhome)
+    compile_result = compiledump.run1(
+        compile_temp_dir, args, [], jdkhome, worker_id=worker_id)
     out_jar = os.path.join(compile_temp_dir, "out.jar")
     out_mapping = os.path.join(compile_temp_dir, "out.jar.map")
 
@@ -740,7 +755,7 @@
   args = AttrDict({
     'dump': dump_test_for_app(app_dir, app),
     'r8_jar': get_r8_jar(options, temp_dir, shrinker),
-    'ea': False if options.disable_assertions else True,
+    'ea': not options.disable_assertions,
     'version': options.version,
     'compiler': 'r8full' if is_full_r8(shrinker) else 'r8',
     'debug_agent': options.debug_agent,
@@ -774,24 +789,29 @@
   return app_errors
 
 
-def log_results_for_app(app, result_per_shrinker, options):
+def log_results_for_app(app, result_per_shrinker, options, worker_id=None):
   if options.print_dexsegments:
-    log_segments_for_app(app, result_per_shrinker, options)
+    log_segments_for_app(app, result_per_shrinker, options, worker_id=worker_id)
     return False
   else:
-    return log_comparison_results_for_app(app, result_per_shrinker, options)
+    return log_comparison_results_for_app(app, result_per_shrinker, options, worker_id=worker_id)
 
 
-def log_segments_for_app(app, result_per_shrinker, options):
+def log_segments_for_app(app, result_per_shrinker, options, worker_id):
   for shrinker in SHRINKERS:
     if shrinker not in result_per_shrinker:
       continue
     for result in result_per_shrinker.get(shrinker):
       benchmark_name = '{}-{}'.format(options.print_dexsegments, app.name)
-      utils.print_dexsegments(benchmark_name, [result.get('output_jar')])
+      utils.print_dexsegments(
+        benchmark_name, [result.get('output_jar')], worker_id=worker_id)
       duration = result.get('duration')
-      print('%s-Total(RunTimeRaw): %s ms' % (benchmark_name, duration))
-      print('%s-Total(CodeSize): %s' % (benchmark_name, result.get('dex_size')))
+      print_thread(
+        '%s-Total(RunTimeRaw): %s ms' % (benchmark_name, duration),
+        worker_id)
+      print_thread(
+        '%s-Total(CodeSize): %s' % (benchmark_name, result.get('dex_size')),
+        worker_id)
 
 
 def percentage_diff_as_string(before, after):
@@ -801,12 +821,12 @@
     return '+' + str(round((after - before) / before * 100)) + '%'
 
 
-def log_comparison_results_for_app(app, result_per_shrinker, options):
-  print(app.name + ':')
+def log_comparison_results_for_app(app, result_per_shrinker, options, worker_id):
+  print_thread(app.name + ':', worker_id)
   app_error = False
   if result_per_shrinker.get('status', 'success') != 'success':
     error_message = result_per_shrinker.get('error_message')
-    print('  skipped ({})'.format(error_message))
+    print_thread('  skipped ({})'.format(error_message), worker_id)
     return
 
   proguard_result = result_per_shrinker.get('pg', {})
@@ -824,22 +844,26 @@
         continue
 
       if options.golem:
-        print('%s(RunTimeRaw): %s ms' % (app.name, result.get('duration')))
-        print('%s(CodeSize): %s' % (app.name, result.get('dex_size')))
+        print_thread(
+          '%s(RunTimeRaw): %s ms' % (app.name, result.get('duration')),
+          worker_id)
+        print_thread(
+          '%s(CodeSize): %s' % (app.name, result.get('dex_size')), worker_id)
         continue
 
-      print('  {}-#{}:'.format(shrinker, compilation_index))
+      print_thread('  {}-#{}:'.format(shrinker, compilation_index), worker_id)
       dex_size = result.get('dex_size')
       msg = '    dex size: {}'.format(dex_size)
       if options.print_runtimeraw:
-        print('    run time raw: {} ms'.format(result.get('duration')))
+        print_thread(
+            '    run time raw: {} ms'.format(result.get('duration')), worker_id)
       if dex_size != proguard_dex_size and proguard_dex_size >= 0:
         msg = '{} ({}, {})'.format(
           msg, dex_size - proguard_dex_size,
           percentage_diff_as_string(proguard_dex_size, dex_size))
         success(msg) if dex_size < proguard_dex_size else warn(msg)
       else:
-        print(msg)
+        print_thread(msg, worker_id)
 
       if options.monkey:
         monkey_status = result.get('monkey_status')
@@ -1197,12 +1221,18 @@
         assert os.path.isfile(utils.R8LIB_JAR), 'Cannot build without r8lib.jar'
         shutil.copyfile(utils.R8LIB_JAR, os.path.join(temp_dir, 'r8lib.jar'))
 
+    jobs = []
     result_per_shrinker_per_app = []
     for app in options.apps:
       if app.skip:
         continue
-      result_per_shrinker_per_app.append(
-        (app, get_results_for_app(app, options, temp_dir)))
+      result = {}
+      result_per_shrinker_per_app.append((app, result))
+      jobs.append(create_job(app, options, result, temp_dir))
+    thread_utils.run_in_parallel(
+        jobs,
+        number_of_workers=options.workers,
+        stop_on_first_failure=False)
     errors = log_results_for_apps(result_per_shrinker_per_app, options)
     if errors > 0:
       dest = 'gs://r8-test-results/r8-libs/' + str(int(time.time()))
@@ -1210,6 +1240,15 @@
       print('R8lib saved to %s' % dest)
     return errors
 
+def create_job(app, options, result, temp_dir):
+  return lambda worker_id: run_job(
+      app, options, result, temp_dir, worker_id)
+
+def run_job(app, options, result, temp_dir, worker_id):
+  job_temp_dir = os.path.join(temp_dir, str(worker_id or 0))
+  os.makedirs(job_temp_dir, exist_ok=True)
+  result.update(get_results_for_app(app, options, job_temp_dir, worker_id))
+  return 0
 
 def success(message):
   CGREEN = '\033[32m'
diff --git a/tools/utils.py b/tools/utils.py
index 9e9e831..c737776 100644
--- a/tools/utils.py
+++ b/tools/utils.py
@@ -592,10 +592,11 @@
       print('{}-{}(CodeSize): {}'
             .format(prefix, segment_name, size))
 
-def print_dexsegments(prefix, dex_files):
+def print_dexsegments(prefix, dex_files, worker_id=None):
   for segment_name, size in getDexSegmentSizes(dex_files).items():
-    print('{}-{}(CodeSize): {}'
-        .format(prefix, segment_name, size))
+    print_thread(
+      '{}-{}(CodeSize): {}'.format(prefix, segment_name, size),
+      worker_id)
 
 # Ensure that we are not benchmarking with a google jvm.
 def check_java_version():