Extend service loader optimization to const class after inlining

Change-Id: I7822245fba574f235f91acb3367d9a3d626d7945
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index 7bab38d..9f18e74 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -56,7 +56,6 @@
 import com.android.tools.r8.ir.optimize.Inliner.ConstraintWithTarget;
 import com.android.tools.r8.ir.optimize.ReflectionOptimizer;
 import com.android.tools.r8.ir.optimize.RemoveVerificationErrorForUnknownReturnedValues;
-import com.android.tools.r8.ir.optimize.ServiceLoaderRewriter;
 import com.android.tools.r8.ir.optimize.api.InstanceInitializerOutliner;
 import com.android.tools.r8.ir.optimize.classinliner.ClassInliner;
 import com.android.tools.r8.ir.optimize.enums.EnumUnboxer;
@@ -605,9 +604,6 @@
     CheckNotNullConverter.runIfNecessary(appView, code);
     previous = printMethod(code, "IR after disable assertions (SSA)", previous);
 
-    new ServiceLoaderRewriter(appView).run(code, methodProcessor, methodProcessingContext, timing);
-    previous = printMethod(code, "IR after service rewriting (SSA)", previous);
-
     if (identifierNameStringMarker != null) {
       timing.begin("Decouple identifier-name strings");
       identifierNameStringMarker.decoupleIdentifierNameStringsInMethod(code);
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/CodeRewriterPassCollection.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/CodeRewriterPassCollection.java
index dbbf1f6..b96c810 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/CodeRewriterPassCollection.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/CodeRewriterPassCollection.java
@@ -11,6 +11,7 @@
 import com.android.tools.r8.ir.conversion.MethodProcessor;
 import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import com.android.tools.r8.ir.optimize.RedundantFieldLoadAndStoreElimination;
+import com.android.tools.r8.ir.optimize.ServiceLoaderRewriter;
 import com.android.tools.r8.ir.optimize.enums.EnumValueOptimizer;
 import com.android.tools.r8.ir.optimize.string.StringBuilderAppendOptimizer;
 import com.android.tools.r8.utils.Timing;
@@ -49,6 +50,7 @@
       passes.add(new RedundantFieldLoadAndStoreElimination(appView));
     }
     passes.add(new BinopRewriter(appView));
+    passes.add(new ServiceLoaderRewriter(appView));
     return new CodeRewriterPassCollection(passes);
   }
 
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/ServiceLoaderRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/ServiceLoaderRewriter.java
index 4d791b0..86706f2 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/ServiceLoaderRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/ServiceLoaderRewriter.java
@@ -130,7 +130,7 @@
         report(
             code.context(),
             constClass.getType(),
-            "Inlining is only support for `java.util.ServiceLoader.load(java.lang.Class,"
+            "Inlining is only supported for `java.util.ServiceLoader.load(java.lang.Class,"
                 + " java.lang.ClassLoader)`");
         continue;
       }
@@ -177,7 +177,8 @@
 
       // Check that ClassLoader used is the ClassLoader defined for the service configuration
       // that we are instantiating or NULL.
-      if (serviceLoaderLoad.getLastArgument().isPhi()) {
+      Value classLoaderValue = serviceLoaderLoad.getLastArgument().getAliasedValue();
+      if (classLoaderValue.isPhi()) {
         report(
             code.context(),
             constClass.getType(),
@@ -186,10 +187,9 @@
                 + ".class.getClassLoader()");
         continue;
       }
-      InvokeVirtual classLoaderInvoke =
-          serviceLoaderLoad.getLastArgument().getDefinition().asInvokeVirtual();
+      InvokeVirtual classLoaderInvoke = classLoaderValue.getDefinition().asInvokeVirtual();
       boolean isGetClassLoaderOnConstClassOrNull =
-          serviceLoaderLoad.getLastArgument().getType().isNullType()
+          classLoaderValue.getType().isNullType()
               || (classLoaderInvoke != null
                   && classLoaderInvoke.arguments().size() == 1
                   && classLoaderInvoke.getReceiver().getAliasedValue().isConstClass()
@@ -343,12 +343,8 @@
             new BooleanBox(!classLoaderInvoke.outValue().hasPhiUsers());
         classLoaderInvoke
             .outValue()
-            .uniqueUsers()
-            .forEach(
-                user -> {
-                  assert !user.isAssume();
-                  allClassLoaderUsersAreServiceLoaders.and(user == serviceLoaderLoad);
-                });
+            .aliasedUsers()
+            .forEach(user -> allClassLoaderUsersAreServiceLoaders.and(user == serviceLoaderLoad));
         if (allClassLoaderUsersAreServiceLoaders.get()) {
           clearGetClassLoader(classLoaderInvoke);
           iterator.nextUntil(i -> i == serviceLoaderLoad);
diff --git a/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceLoaderConstClassFromCalleeTest.java b/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceLoaderConstClassFromCalleeTest.java
new file mode 100644
index 0000000..aa411ab
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceLoaderConstClassFromCalleeTest.java
@@ -0,0 +1,84 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize.serviceloader;
+
+import static com.android.tools.r8.optimize.serviceloader.ServiceLoaderRewritingTest.getServiceLoaderLoads;
+import static junit.framework.TestCase.assertEquals;
+
+import com.android.tools.r8.DataEntryResource;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.origin.Origin;
+import com.android.tools.r8.utils.StringUtils;
+import java.util.ServiceLoader;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class ServiceLoaderConstClassFromCalleeTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .setMinApi(parameters)
+        .addDataEntryResources(
+            DataEntryResource.fromBytes(
+                StringUtils.lines(ServiceImpl.class.getTypeName(), ServiceImpl2.class.getTypeName())
+                    .getBytes(),
+                "META-INF/services/" + Service.class.getTypeName(),
+                Origin.unknown()))
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("Hello, world!")
+        // Check that the call to ServiceLoader.load is removed.
+        .inspect(inspector -> assertEquals(0, getServiceLoaderLoads(inspector, Main.class)));
+  }
+
+  public static class Main {
+
+    public static void main(String[] args) {
+      for (Service service : ServiceLoader.load(getServiceClass(), null)) {
+        service.print();
+      }
+    }
+
+    private static Class<Service> getServiceClass() {
+      return Service.class;
+    }
+  }
+
+  public interface Service {
+
+    void print();
+  }
+
+  public static class ServiceImpl implements Service {
+
+    @Override
+    public void print() {
+      System.out.print("Hello");
+    }
+  }
+
+  public static class ServiceImpl2 implements Service {
+
+    @Override
+    public void print() {
+      System.out.println(", world!");
+    }
+  }
+}