Add Python formatting check to presubmit.

Change-Id: Ib2fb18aa48375007d2cc2e52f5c6030ab3d7dc15
diff --git a/PRESUBMIT.py b/PRESUBMIT.py
index 2ea419e..8f44aaf 100644
--- a/PRESUBMIT.py
+++ b/PRESUBMIT.py
@@ -4,7 +4,7 @@
 
 from os import path
 import datetime
-from subprocess import check_output, Popen, PIPE, STDOUT
+from subprocess import check_output, check_call, CalledProcessError, Popen, PIPE, STDOUT, DEVNULL
 import inspect
 import os
 import sys
@@ -22,7 +22,7 @@
 KOTLIN_FMT_SHA1 = path.join('third_party', 'google', 'google-kotlin-format',
                             '0.54.tar.gz.sha1')
 KOTLIN_FMT_TGZ = path.join('third_party', 'google', 'google-kotlin-format',
-                           '0.54.tar.gz.sha1')
+                           '0.54.tar.gz')
 KOTLIN_FMT_IGNORE = {
     'src/test/java/com/android/tools/r8/kotlin/metadata/inline_class_fun_descriptor_classes_app/main.kt'
 }
@@ -38,6 +38,14 @@
 FMT_TGZ = path.join('third_party', 'google', 'google-java-format',
                     '1.24.0.tar.gz')
 
+PYTHON_FMT = path.join('third_party', 'google', 'yapf', '20231013')
+PYTHON_FMT_EXEC = path.join('third_party', 'google', 'yapf', '20231013', 'yapf')
+PYTHON_FMT_SHA1 = path.join('third_party', 'google', 'yapf',
+                            '20231013.tar.gz.sha1')
+PYTHON_FMT_TGZ = path.join('third_party', 'google', 'yapf', '20231013.tar.gz')
+
+YAPF_PYTHON_PATH = [PYTHON_FMT, os.path.join(PYTHON_FMT, 'third_party')]
+
 
 def CheckDoNotMerge(input_api, output_api):
     for l in input_api.change.FullDescriptionText().splitlines():
@@ -47,32 +55,50 @@
     return []
 
 
+def is_java_extension(file_path):
+    return file_path.endswith('.java')
+
+
+def is_kotlin_extension(file_path):
+    return file_path.endswith(('.kt', '.kts'))  # Kotlin sources and scripts.
+
+
+def is_python_extension(file_path):
+    return file_path.endswith('.py')
+
+
 def CheckFormatting(input_api, output_api, branch):
     seen_kotlin_error = False
     seen_java_error = False
+    seen_python_error = False
     pending_kotlin_files = []
     EnsureDepFromGoogleCloudStorage(KOTLIN_FMT_JAR, KOTLIN_FMT_TGZ,
                                     KOTLIN_FMT_SHA1, 'google-kotlin-format')
     EnsureDepFromGoogleCloudStorage(FMT_CMD, FMT_TGZ, FMT_SHA1,
                                     'google-java-format')
+    EnsureDepFromGoogleCloudStorage(PYTHON_FMT_EXEC, PYTHON_FMT_TGZ,
+                                    PYTHON_FMT_SHA1, 'yapf')
     results = []
+    python_runtime = PythonRuntime()
     for f in input_api.AffectedFiles():
-        path = f.LocalPath()
-        if not path.endswith('.java') and not path.endswith(
-                '.kt') and not path.endswith('.kts'):
-            continue
-        if path.endswith('.kt') or path.endswith('.kts'):
-            if path in KOTLIN_FMT_IGNORE:
+        file_path = f.LocalPath()
+        if is_kotlin_extension(file_path):
+            if file_path in KOTLIN_FMT_IGNORE:
                 continue
-            pending_kotlin_files.append(path)
+            pending_kotlin_files.append(file_path)
             if len(pending_kotlin_files) == KOTLIN_FMT_BATCH_SIZE:
                 seen_kotlin_error = (CheckKotlinFormatting(
                     pending_kotlin_files, output_api, results) or
                                      seen_kotlin_error)
                 pending_kotlin_files = []
+        elif is_java_extension(file_path):
+            seen_java_error = (CheckJavaFormatting(
+                file_path, branch, output_api, results) or seen_java_error)
+        elif is_python_extension(file_path):
+            seen_python_error = (python_runtime.check_formatting(
+                file_path, output_api, results) or seen_python_error)
         else:
-            seen_java_error = (CheckJavaFormatting(path, branch, output_api,
-                                                   results) or seen_java_error)
+            continue
     # Check remaining Kotlin files if any.
     if len(pending_kotlin_files) > 0:
         seen_kotlin_error = (CheckKotlinFormatting(
@@ -82,9 +108,12 @@
         results.append(output_api.PresubmitError(
             KotlinFormatPresubmitMessage()))
     if seen_java_error:
-        results.append(output_api.PresubmitError(JavaFormatPresubMessage()))
+        results.append(output_api.PresubmitError(JavaFormatPresubmitMessage()))
+    if seen_python_error:
+        results.append(output_api.PresubmitError(
+            PythonFormatPresubmitMessage()))
 
-    # Comment this out to easily presumbit changes
+    # Uncomment the next line to force a presubmit failure while testing.
     # results.append(output_api.PresubmitError("TESTING"))
     return results
 
@@ -144,7 +173,7 @@
     return len(stdout) > 0
 
 
-def JavaFormatPresubMessage():
+def JavaFormatPresubmitMessage():
     return """Please fix the Java formatting by running:
 
   git diff -U0 $(git cl upstream) | %s -p1 -i
@@ -166,6 +195,83 @@
     )
 
 
+def get_env_with_python_path():
+    """Return a copy of os.environ whose PYTHONPATH is YAPF_PYTHON_PATH."""
+    # os.pathsep keeps the list separator correct where it is not ':'.
+    return dict(os.environ, PYTHONPATH=os.pathsep.join(YAPF_PYTHON_PATH))
+
+
+class PythonRuntime:
+    """Lazily picks a Python interpreter and checks files with yapf --diff."""
+
+    def __init__(self):
+        self.interpreter = None  # Set by initialize_runtime() on success.
+        self.has_failed = False  # True once no interpreter could be found.
+
+    def initialize_runtime(self):
+        """Pick an interpreter; return None, or an error text on failure."""
+        # Ensure a python interpreter with platformdirs.
+        # This search allows manual setup of .venv.
+        python_env = get_env_with_python_path()
+        for candidate in [sys.executable, 'python3']:
+            try:
+                check_call([candidate, '-c', 'import platformdirs'],
+                           stdout=DEVNULL,
+                           stderr=DEVNULL,
+                           env=python_env)
+                self.interpreter = candidate
+                return None
+            except (CalledProcessError, FileNotFoundError):
+                continue
+        self.has_failed = True
+        return (
+            "Error: Could not find a Python interpreter with `platformdirs` installed.\n"
+            "Please ensure it is installed in your environment:\n"
+            "  $ python3 -m venv .venv\n"
+            "  $ source .venv/bin/activate\n"
+            "  $ pip3 install platformdirs")
+
+    def check_formatting(self, file_path, output_api, results):
+        """Return True if an error was appended to results, else False."""
+        # Avoid repeating initialization errors.
+        if self.has_failed:
+            return False
+        # Initialize interpreter if not done already.
+        if self.interpreter is None:
+            init_error = self.initialize_runtime()
+            if init_error:
+                results.append(output_api.PresubmitError(init_error))
+                return True
+        format_cmd = [
+            self.interpreter, PYTHON_FMT_EXEC, '--diff', '--style', 'google',
+            file_path
+        ]
+        try:
+            check_output(format_cmd, env=get_env_with_python_path())
+        except CalledProcessError as e:
+            # --diff returns non-zero if there is a diff; e.output holds the
+            # captured stdout as bytes, so decode it before reporting.
+            diff_text = e.output.decode('utf-8', errors='replace')
+            results.append(output_api.PresubmitError(diff_text))
+            return True
+        return False
+
+
+def PythonFormatPresubmitMessage():
+    return """Please fix the Python formatting by running:
+
+  tools/fmt-diff.py --no-java --no-kotlin --python
+
+or fix formatting, commit and upload:
+
+  tools/fmt-diff.py --no-java --no-kotlin --python && git commit -a --amend --no-edit && git cl upload
+
+or bypass the checks with:
+
+  git cl upload --bypass-hooks
+    """
+
+
 def CheckDeterministicDebuggingChanged(input_api, output_api, branch):
     for f in input_api.AffectedFiles():
         path = f.LocalPath()
@@ -181,7 +287,7 @@
 
 def IsTestFile(file):
     localPath = file.LocalPath()
-    return localPath.endswith('.java') and '/test/' in localPath
+    return is_java_extension(localPath) and '/test/' in localPath
 
 
 def CheckForAddedDisassemble(input_api, output_api):
@@ -209,7 +315,7 @@
 def CheckForAddedPartialDebug(input_api, output_api):
     results = []
     for (file, line_nr, line) in input_api.RightHandSideLines():
-        if not file.LocalPath().endswith('.java'):
+        if not is_java_extension(file.LocalPath()):
             continue
         if '.enablePrintPartialCompilationPartitioning(' in line:
             results.append(
@@ -224,7 +330,7 @@
     return results
 
 
-def CheckForCopyRight(input_api, output_api, branch):
+def CheckForCopyright(input_api, output_api, branch):
     results = []
     for f in input_api.AffectedSourceFiles(None):
         # Check if it is a new file.
@@ -233,16 +339,16 @@
         contents = f.NewContents()
         if (not contents) or (len(contents) == 0):
             continue
-        if not CopyRightInContents(f, contents):
+        if not CopyrightInContents(f, contents):
             results.append(
                 output_api.PresubmitError('Could not find correctly formatted '
                                           'copyright in file: %s' % f))
     return results
 
 
-def CopyRightInContents(f, contents):
+def CopyrightInContents(f, contents):
     expected = '//'
-    if f.LocalPath().endswith('.py') or f.LocalPath().endswith('.sh'):
+    if is_python_extension(f.LocalPath()) or f.LocalPath().endswith('.sh'):
         expected = '#'
     expected = expected + ' Copyright (c) ' + str(datetime.datetime.now().year)
     for content_line in contents:
@@ -263,7 +369,7 @@
     results.extend(CheckForAddedDisassemble(input_api, output_api))
     results.extend(CheckForAddedAllowXxxxxxMessages(input_api, output_api))
     results.extend(CheckForAddedPartialDebug(input_api, output_api))
-    results.extend(CheckForCopyRight(input_api, output_api, branch))
+    results.extend(CheckForCopyright(input_api, output_api, branch))
     return results