Refactor JDK lookup in third_party

Change-Id: Iafec2bccb8892f5cd7877227a9f357b45084fe55
diff --git a/tools/create_r8lib.py b/tools/create_r8lib.py
index 1c349e1..37bea93 100755
--- a/tools/create_r8lib.py
+++ b/tools/create_r8lib.py
@@ -134,7 +134,7 @@
     cmd.extend(['--pg-conf-output', args.output + '.config'])
     cmd.extend(['--pg-map-output', args.output + '.map'])
     cmd.extend(['--partition-map-output', args.output + '_map.zip'])
-    cmd.extend(['--lib', jdk.GetJdkHome()])
+    cmd.extend(['--lib', jdk.GetDefaultJdkHome()])
     if args.pg_conf:
         for pgconf in args.pg_conf:
             cmd.extend(['--pg-conf', pgconf])
diff --git a/tools/gradle.py b/tools/gradle.py
index 688b01f..159cd3f 100755
--- a/tools/gradle.py
+++ b/tools/gradle.py
@@ -55,9 +55,9 @@
 
 
 def GetJavaEnv(env):
-    java_env = dict(env if env else os.environ, JAVA_HOME=jdk.GetJdkHome())
+    java_env = dict(env if env else os.environ, JAVA_HOME=jdk.GetDefaultJdkHome())
     java_env['PATH'] = java_env['PATH'] + os.pathsep + os.path.join(
-        jdk.GetJdkHome(), 'bin')
+        jdk.GetDefaultJdkHome(), 'bin')
     java_env['GRADLE_OPTS'] = '-Xmx1g'
     return java_env
 
diff --git a/tools/jdk.py b/tools/jdk.py
index 3c2d3ec..6b92f2c 100755
--- a/tools/jdk.py
+++ b/tools/jdk.py
@@ -8,40 +8,54 @@
 
 import defines
 
-JDK_DIR = os.path.join(defines.THIRD_PARTY, 'openjdk')
+JDK_DIRS = os.path.join(defines.THIRD_PARTY, 'openjdk')
 
 ALL_JDKS = ['openjdk-9.0.4', 'jdk-11', 'jdk-17', 'jdk-21', 'jdk-23']
 
 
-def GetJdkHome():
+def GetDefaultJdkHome():
     return GetJdk11Home()
 
 
-def GetJdkRoot(name):
-    root = os.path.join(JDK_DIR, name)
-    os_root = GetOSPath(root)
-    return os_root if os_root else os.environ['JAVA_HOME']
+def GetJdkHome(name):
+    if name == 'jdk8':
+        return GetJdk8Home()
+    third_party_jdk_root = os.path.join(JDK_DIRS, name)
+    if not os.path.exists(third_party_jdk_root):
+        raise Exception('No JDKs found in ' + third_party_jdk_root)
+    os_root = GetOSJavaHome(third_party_jdk_root)
+    if not os.path.exists(os_root):
+        raise Exception('No platform JDK found in ' + os_root)
+    return os_root
 
 
-def GetJdk11Root():
-    return GetJdkRoot('jdk-11')
-
-
-def GetOSPath(root):
+def GetOSJavaHome(root):
     if defines.IsLinux():
         return os.path.join(root, 'linux')
     elif defines.IsOsX():
-        return os.path.join(root, 'osx')
+        return os.path.join(root, 'osx', 'Contents', 'Home')
     elif defines.IsWindows():
         return os.path.join(root, 'windows')
     else:
-        return None
+        raise Exception(
+            'Unsupported platform'
+            ' (not detected as either of Linux, macOS or Windows)')
 
 
 def GetAllJdkDirs():
     dirs = []
     for jdk in ALL_JDKS:
-        root = GetOSPath(os.path.join(JDK_DIR, jdk))
+        root = os.path.join(JDK_DIRS, jdk)
+        if defines.IsLinux():
+            root = os.path.join(root, 'linux')
+        elif defines.IsOsX():
+            root = os.path.join(root, 'osx')
+        elif defines.IsWindows():
+            root = os.path.join(root, 'windows')
+        else:
+            raise Exception(
+                'Unsupported platform'
+                ' (not detected as either of Linux, macOS or Windows)')
         # Some jdks are not available on windows, don't try to get these.
         if os.path.exists(root + '.tar.gz.sha1'):
             dirs.append(root)
@@ -49,28 +63,15 @@
 
 
 def GetJdk11Home():
-    root = GetJdk11Root()
-    # osx has the home inside Contents/Home in the bundle
-    if defines.IsOsX():
-        return os.path.join(root, 'Contents', 'Home')
-    else:
-        return root
+    return GetJdkHome('jdk-11')
 
 
 def GetJdk9Home():
-    root = os.path.join(JDK_DIR, 'openjdk-9.0.4')
-    if defines.IsLinux():
-        return os.path.join(root, 'linux')
-    elif defines.IsOsX():
-        return os.path.join(root, 'osx')
-    elif defines.IsWindows():
-        return os.path.join(root, 'windows')
-    else:
-        return os.environ['JAVA_HOME']
+    return GetJdkHome('openjdk-9.0.4')
 
 
 def GetJdk8Home():
-    root = os.path.join(JDK_DIR, 'jdk8')
+    root = os.path.join(JDK_DIRS, 'jdk8')
     if defines.IsLinux():
         return os.path.join(root, 'linux-x86')
     elif defines.IsOsX():
@@ -80,19 +81,19 @@
 
 
 def GetJavaExecutable(jdkHome=None):
-    jdkHome = jdkHome if jdkHome else GetJdkHome()
+    jdkHome = jdkHome if jdkHome else GetDefaultJdkHome()
     executable = 'java.exe' if defines.IsWindows() else 'java'
     return os.path.join(jdkHome, 'bin', executable) if jdkHome else executable
 
 
 def GetJavacExecutable(jdkHome=None):
-    jdkHome = jdkHome if jdkHome else GetJdkHome()
+    jdkHome = jdkHome if jdkHome else GetDefaultJdkHome()
     executable = 'javac.exe' if defines.IsWindows() else 'javac'
     return os.path.join(jdkHome, 'bin', executable) if jdkHome else executable
 
 
 def Main():
-    print(GetJdkHome())
+    print(GetDefaultJdkHome())
 
 
 if __name__ == '__main__':