Extend generate_startup_descriptors.py to app bundles

Change-Id: Id2e97a31481e40f22bc9fcdea0f116e6e1b0eef4
diff --git a/tools/startup/adb_utils.py b/tools/startup/adb_utils.py
index a5a4e89..108896e 100755
--- a/tools/startup/adb_utils.py
+++ b/tools/startup/adb_utils.py
@@ -56,6 +56,15 @@
   cmd = create_adb_cmd('shell am broadcast -a %s %s' % (action, component), device_id)
   return subprocess.check_output(cmd).decode('utf-8').strip().splitlines()
 
+def build_apks_from_bundle(bundle, output):
+  print('Building %s' % bundle)
+  cmd = [
+      'java', '-jar', utils.BUNDLETOOL_JAR,
+      'build-apks',
+      '--bundle=%s' % bundle,
+      '--output=%s' % output]
+  subprocess.check_call(cmd, stdout=DEVNULL, stderr=DEVNULL)
+
 def capture_screen(target, device_id=None):
   print('Taking screenshot to %s' % target)
   tmp = '/sdcard/screencap.png'
@@ -235,6 +244,23 @@
   stdout = subprocess.check_output(cmd).decode('utf-8')
   assert 'Success' in stdout
 
+def install_apks(apks, device_id=None):
+  print('Installing %s' % apks)
+  cmd = [
+      'java', '-jar', utils.BUNDLETOOL_JAR,
+      'install-apks',
+      '--apks=%s' % apks]
+  if device_id is not None:
+    cmd.append('--device-id=%s' % device_id)
+  subprocess.check_call(cmd, stdout=DEVNULL, stderr=DEVNULL)
+
+def install_bundle(bundle, device_id=None):
+  print('Installing %s' % bundle)
+  with utils.TempDir() as temp:
+    apks = os.path.join(temp, 'Bundle.apks')
+    build_apks_from_bundle(bundle, apks)
+    install_apks(apks, device_id)
+
 def install_profile(app_id, device_id=None):
   # This assumes that the profileinstaller library has been added to the app,
   # https://developer.android.com/jetpack/androidx/releases/profileinstaller.
diff --git a/tools/startup/generate_startup_descriptors.py b/tools/startup/generate_startup_descriptors.py
index 822fc1c..5116938 100755
--- a/tools/startup/generate_startup_descriptors.py
+++ b/tools/startup/generate_startup_descriptors.py
@@ -302,10 +302,14 @@
   result = argparse.ArgumentParser(
       description='Generate a perfetto trace file.')
   result.add_argument('--apk',
-                      help='Path to the APK')
+                      help='Path to the .apk')
+  result.add_argument('--apks',
+                      help='Path to the .apks')
   result.add_argument('--app-id',
                       help='The application ID of interest',
                       required=True)
+  result.add_argument('--bundle',
+                      help='Path to the .aab')
   result.add_argument('--device-id',
                       help='Device id (e.g., emulator-5554).',
                       action='append')
@@ -374,6 +378,10 @@
       options.devices.append(Device(device_id, device_pin))
   del options.device_id
 
+  paths = [
+      path for path in [options.apk, options.apks, options.bundle]
+      if path is not None]
+  assert len(paths) <= 1, 'Expected at most one .apk, .apks, or .aab file.'
   assert options.main_activity is not None or options.use_existing_profile, \
       'Argument --main-activity is required except when running with ' \
       '--use-existing-profile.'
@@ -385,6 +393,12 @@
   if options.apk:
     adb_utils.uninstall(options.app_id, device.device_id)
     adb_utils.install(options.apk, device.device_id)
+  elif options.apks:
+    adb_utils.uninstall(options.app_id, device.device_id)
+    adb_utils.install_apks(options.apks, device.device_id)
+  elif options.bundle:
+    adb_utils.uninstall(options.app_id, device.device_id)
+    adb_utils.install_bundle(options.bundle, device.device_id)
   if options.until_stable:
     iteration = 0
     stable_iterations = 0
diff --git a/tools/utils.py b/tools/utils.py
index ecd172c..9bd518f 100644
--- a/tools/utils.py
+++ b/tools/utils.py
@@ -23,6 +23,8 @@
 TOOLS_DIR = defines.TOOLS_DIR
 REPO_ROOT = defines.REPO_ROOT
 THIRD_PARTY = defines.THIRD_PARTY
+BUNDLETOOL_JAR_DIR = os.path.join(THIRD_PARTY, 'bundletool/bundletool-1.11.0')
+BUNDLETOOL_JAR = os.path.join(BUNDLETOOL_JAR_DIR, 'bundletool-all-1.11.0.jar')
 ANDROID_SDK = os.path.join(THIRD_PARTY, 'android_sdk')
 MEMORY_USE_TMP_FILE = 'memory_use.tmp'
 DEX_SEGMENTS_RESULT_PATTERN = re.compile('- ([^:]+): ([0-9]+)')