Parallelize run_on_app.py

Run the --run-all compilations on a configurable number of worker
threads (--workers, default 1) using the new tools/thread_utils.py
helper, and add a --no-fail-fast option to report all failures instead
of stopping at the first one.

Bug: b/297302759
Change-Id: I454286c5d673c7cec41f14f0acad89f1842122cd
diff --git a/tools/internal_test.py b/tools/internal_test.py
index e8a4f5a..b32c378 100755
--- a/tools/internal_test.py
+++ b/tools/internal_test.py
@@ -117,7 +117,7 @@
     ['tools/test.py', '--only_internal', '--slow_tests',
      '--java_max_memory_size=8G'],
     # Ensure that all internal apps compile.
-    ['tools/run_on_app.py', '--run-all', '--out=out'],
+    ['tools/run_on_app.py', '--run-all', '--out=out', '--workers', '4'],
 ]
 
 # Command timeout, in seconds.
diff --git a/tools/run_on_app.py b/tools/run_on_app.py
index b4d0418..2340b46 100755
--- a/tools/run_on_app.py
+++ b/tools/run_on_app.py
@@ -18,6 +18,8 @@
 import gmscore_data
 import nest_data
 from sanitize_libraries import SanitizeLibraries, SanitizeLibrariesInPgconf
+import thread_utils
+from thread_utils import print_thread
 import toolhelper
 import update_prebuilds_in_android
 import utils
@@ -48,6 +50,11 @@
                     help='Compiler build to use',
                     choices=COMPILER_BUILDS,
                     default='lib')
+  result.add_option('--no-fail-fast',
+                    help='Report all failures instead of stopping at the '
+                         'first one',
+                    default=False,
+                    action='store_true')
   result.add_option('--hash',
                     help='The version of D8/R8 to use')
   result.add_option('--app',
@@ -181,6 +188,10 @@
                     help='Disable compiler logging',
                     default=False,
                     action='store_true')
+  result.add_option('--workers',
+                    help='Number of workers to use',
+                    default=1,
+                    type=int)
   (options, args) = result.parse_args(argv)
   assert not options.hash or options.no_build, (
       'Argument --no-build is required when using --hash')
@@ -227,25 +238,53 @@
             yield app, version, type, use_r8lib
 
 def run_all(options, args):
+  # Build r8lib once up front so that the individual jobs do not have to.
+  if should_build(options):
+    gradle.RunGradle(['r8lib'])
+    options.no_build = True
+  assert not should_build(options)
+
   # Args will be destroyed
   assert len(args) == 0
+  jobs = []
   for name, version, type, use_r8lib in get_permutations():
     compiler = 'r8' if type == 'deploy' else 'd8'
     compiler_build = 'lib' if use_r8lib else 'full'
-    print('Executing %s/%s with %s %s %s' % (compiler, compiler_build, name,
-      version, type))
-
     fixed_options = copy.copy(options)
     fixed_options.app = name
     fixed_options.version = version
     fixed_options.compiler = compiler
     fixed_options.compiler_build = compiler_build
     fixed_options.type = type
-    exit_code = run_with_options(fixed_options, [])
-    if exit_code != 0:
-      print('Failed %s %s %s with %s/%s' % (name, version, type, compiler,
-        compiler_build))
-      exit(exit_code)
+    jobs.append(
+        create_job(
+            compiler, compiler_build, name, fixed_options, type, version))
+  exit_code = thread_utils.run_in_parallel(
+      jobs,
+      number_of_workers=options.workers,
+      stop_on_first_failure=not options.no_fail_fast)
+  exit(exit_code)
+
+def create_job(compiler, compiler_build, name, options, type, version):
+  return lambda worker_id: run_job(
+      compiler, compiler_build, name, options, type, version, worker_id)
+
+def run_job(
+    compiler, compiler_build, name, options, type, version, worker_id):
+  print_thread(
+      'Executing %s/%s with %s %s %s'
+          % (compiler, compiler_build, name, version, type),
+      worker_id)
+  if worker_id is not None:
+    options.out = os.path.join(options.out, str(worker_id))
+    os.makedirs(options.out, exist_ok=True)
+  exit_code = run_with_options(options, [], worker_id=worker_id)
+  if exit_code:
+    print_thread(
+        'Failed %s %s %s with %s/%s'
+            % (name, version, type, compiler, compiler_build),
+        worker_id)
+  return exit_code
 
 def find_min_xmx(options, args):
   # Args will be destroyed
@@ -492,7 +531,8 @@
       os.path.join(android_java8_libs_output, 'classes.dex'),
       os.path.join(outdir, dex_file_name))
 
-def run_with_options(options, args, extra_args=None, stdout=None, quiet=False):
+def run_with_options(
+    options, args, extra_args=None, stdout=None, quiet=False, worker_id=None):
   if extra_args is None:
     extra_args = []
   app_provided_pg_conf = False;
@@ -550,9 +590,9 @@
 
   if options.compiler == 'r8':
     if 'pgconf' in values and not options.k:
+      sanitized_lib_path = os.path.join(
+          os.path.abspath(outdir), 'sanitized_lib.jar')
       if has_injars_and_libraryjars(values['pgconf']):
-        sanitized_lib_path = os.path.join(
-            os.path.abspath(outdir), 'sanitized_lib.jar')
         sanitized_pgconf_path = os.path.join(
             os.path.abspath(outdir), 'sanitized.config')
         SanitizeLibrariesInPgconf(
@@ -566,8 +606,6 @@
         for pgconf in values['pgconf']:
           args.extend(['--pg-conf', pgconf])
         if 'sanitize_libraries' in values and values['sanitize_libraries']:
-          sanitized_lib_path = os.path.join(
-              os.path.abspath(outdir), 'sanitized_lib.jar')
           SanitizeLibraries(
             sanitized_lib_path, values['libraries'], values['inputs'])
           libraries = [sanitized_lib_path]
@@ -693,7 +731,8 @@
             cmd_prefix=[
                 'taskset', '-c', options.cpu_list] if options.cpu_list else [],
             jar=jar,
-            main=main)
+            main=main,
+            worker_id=worker_id)
       if exit_code != 0:
         with open(stderr_path) as stderr:
           stderr_text = stderr.read()
diff --git a/tools/thread_utils.py b/tools/thread_utils.py
new file mode 100755
index 0000000..28fd348
--- /dev/null
+++ b/tools/thread_utils.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+# for details. All rights reserved. Use of this source code is governed by a
+# BSD-style license that can be found in the LICENSE file.
+
+import sys
+import threading
+from threading import Thread
+import traceback
+
+# A thread that is given a list of jobs. The thread will repeatedly take a job
+# from the list of jobs (which is shared with other threads) and execute it
+# until there are no more jobs.
+#
+# If stop_on_first_failure is True, then the thread will exit upon the first
+# failing job. The thread will then clear the jobs to ensure that all other
+# workers will also terminate after completing their current job.
+#
+# Each job is a lambda that takes the worker_id as an argument. To guarantee
+# termination each job must itself terminate (i.e., each job is responsible for
+# setting an appropriate timeout).
+class WorkerThread(Thread):
+
+  # The initialization of a WorkerThread is never run concurrently with the
+  # initialization of other WorkerThreads.
+  def __init__(self, jobs, jobs_lock, stop_on_first_failure, worker_id):
+    Thread.__init__(self)
+    self.jobs = jobs
+    self.jobs_lock = jobs_lock
+    self.number_of_jobs = len(jobs)
+    self.stop_on_first_failure = stop_on_first_failure
+    self.success = True
+    self.worker_id = worker_id
+
+  def run(self):
+    print_thread("Starting worker", self.worker_id)
+    while True:
+      (job, job_id) = self.take_job(self.jobs, self.jobs_lock)
+      if job is None:
+        break
+      try:
+        print_thread(
+            "Starting job %s/%s" % (job_id, self.number_of_jobs),
+            self.worker_id)
+        exit_code = job(self.worker_id)
+        print_thread(
+            "Job %s finished with exit code %s"
+                % (job_id, exit_code),
+            self.worker_id)
+        if exit_code:
+          self.success = False
+          if self.stop_on_first_failure:
+            self.clear_jobs(self.jobs, self.jobs_lock)
+            break
+      except:
+        print_thread("Job %s crashed" % job_id, self.worker_id)
+        print_thread(traceback.format_exc(), self.worker_id)
+        self.success = False
+        if self.stop_on_first_failure:
+          self.clear_jobs(self.jobs, self.jobs_lock)
+          break
+    print_thread("Exiting", self.worker_id)
+
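+  # Takes the next job from the shared job list. Returns (None, job_id) when
+  # there are no jobs left.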
+  def take_job(self, jobs, jobs_lock):
+    jobs_lock.acquire()
+    job_id = self.number_of_jobs - len(jobs) + 1
+    job = jobs.pop(0) if jobs else None
+    jobs_lock.release()
+    return (job, job_id)
+
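+  # Empties the shared job list so that the remaining workers terminate.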
+  def clear_jobs(self, jobs, jobs_lock):
+    jobs_lock.acquire()
+    jobs.clear()
+    jobs_lock.release()
+
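+# Runs the given jobs on number_of_workers threads. Each job is a callable
+# that takes a worker_id and returns an exit code. Returns 0 if all jobs
+# succeeded and 1 otherwise. Minimal usage sketch (make_job is a placeholder):
+#
+#   exit_code = run_in_parallel(
+#       [make_job(app) for app in apps], number_of_workers=4,
+#       stop_on_first_failure=True)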
+def run_in_parallel(jobs, number_of_workers, stop_on_first_failure):
+  assert number_of_workers > 0
+  if number_of_workers > len(jobs):
+    number_of_workers = len(jobs)
+  if number_of_workers == 1:
+    return run_in_sequence(jobs, stop_on_first_failure)
+  jobs_lock = threading.Lock()
+  threads = []
+  for worker_id in range(1, number_of_workers + 1):
+    threads.append(
+        WorkerThread(jobs, jobs_lock, stop_on_first_failure, worker_id))
+  for thread in threads:
+    thread.start()
+  for thread in threads:
+    thread.join()
+  for thread in threads:
+    if not thread.success:
+      return 1
+  return 0
+
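+# Runs the given jobs one after another on the calling thread. Returns the
+# exit code of the last failing job, or 0 if all jobs succeeded.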
+def run_in_sequence(jobs, stop_on_first_failure):
+  combined_exit_code = 0
+  worker_id = None
+  for job in jobs:
+    try:
+      exit_code = job(worker_id)
+      if exit_code:
+        combined_exit_code = exit_code
+        if stop_on_first_failure:
+          break
+    except:
+      print(traceback.format_exc())
+      combined_exit_code = 1
+      if stop_on_first_failure:
+        break
+  return combined_exit_code
+
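+# Prints msg, prefixed with the worker id when running on a worker thread.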
+def print_thread(msg, worker_id):
+  if worker_id is None:
+    print(msg)
+  else:
+    print('WORKER %s: %s' % (worker_id, msg))
diff --git a/tools/toolhelper.py b/tools/toolhelper.py
index 881f6dc..c8d8a08 100644
--- a/tools/toolhelper.py
+++ b/tools/toolhelper.py
@@ -15,7 +15,8 @@
 def run(tool, args, build=None, debug=True,
         profile=False, track_memory_file=None, extra_args=None,
         stderr=None, stdout=None, return_stdout=False, timeout=0, quiet=False,
-        cmd_prefix=None, jar=None, main=None, time_consumer=None):
+        cmd_prefix=None, jar=None, main=None, time_consumer=None,
+        worker_id=None):
   cmd = []
   if cmd_prefix:
     cmd.extend(cmd_prefix)
@@ -52,7 +53,7 @@
   if lib:
     cmd.extend(["--lib", lib])
   cmd.extend(args)
-  utils.PrintCmd(cmd, quiet=quiet)
+  utils.PrintCmd(cmd, quiet=quiet, worker_id=worker_id)
   start = time.time()
   if timeout > 0:
     kill = lambda process: process.kill()
diff --git a/tools/utils.py b/tools/utils.py
index b691b16..9e9e831 100644
--- a/tools/utils.py
+++ b/tools/utils.py
@@ -17,6 +17,7 @@
 
 import defines
 import jdk
+from thread_utils import print_thread
 
 ANDROID_JAR_DIR = 'third_party/android_jar/lib-v{api}'
 ANDROID_JAR = os.path.join(ANDROID_JAR_DIR, 'android.jar')
@@ -203,16 +204,16 @@
   CEND = '\033[0m'
   print(CRED + message + CEND)
 
-def PrintCmd(cmd, env=None, quiet=False):
+def PrintCmd(cmd, env=None, quiet=False, worker_id=None):
   if quiet:
     return
   if type(cmd) is list:
     cmd = ' '.join(cmd)
   if env:
     env = ' '.join(['{}=\"{}\"'.format(x, y) for x, y in env.iteritems()])
-    print('Running: {} {}'.format(env, cmd))
+    print_thread('Running: {} {}'.format(env, cmd), worker_id)
   else:
-    print('Running: {}'.format(cmd))
+    print_thread('Running: {}'.format(cmd), worker_id)
   # I know this will hit os on windows eventually if we don't do this.
   sys.stdout.flush()