Optimize ServiceLoader.load() for impls within feature splits

Just check that the service class and the impl classes are in the same
feature.

Bug: b/291923475
Change-Id: I0850b0578a461c8b5ab21ee63f7bcbba64e15dde
diff --git a/src/main/java/com/android/tools/r8/graph/AppServices.java b/src/main/java/com/android/tools/r8/graph/AppServices.java
index 5e03a38..ea50b5a 100644
--- a/src/main/java/com/android/tools/r8/graph/AppServices.java
+++ b/src/main/java/com/android/tools/r8/graph/AppServices.java
@@ -12,10 +12,8 @@
 import com.android.tools.r8.ProgramResourceProvider;
 import com.android.tools.r8.ResourceException;
 import com.android.tools.r8.errors.CompilationError;
-import com.android.tools.r8.features.ClassToFeatureSplitMap;
 import com.android.tools.r8.graph.lens.GraphLens;
 import com.android.tools.r8.origin.Origin;
-import com.android.tools.r8.shaking.AppInfoWithLiveness;
 import com.android.tools.r8.utils.DescriptorUtils;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.StringDiagnostic;
@@ -91,43 +89,8 @@
     return builder.build();
   }
 
-  public boolean hasServiceImplementationsInFeature(
-      AppView<? extends AppInfoWithLiveness> appView, DexType serviceType) {
-    ClassToFeatureSplitMap classToFeatureSplitMap = appView.appInfo().getClassToFeatureSplitMap();
-    if (classToFeatureSplitMap.isEmpty()) {
-      return false;
-    }
-    Map<FeatureSplit, List<DexType>> featureImplementations = services.get(serviceType);
-    if (featureImplementations == null || featureImplementations.isEmpty()) {
-      assert false
-          : "Unexpected attempt to get service implementations for non-service type `"
-              + serviceType.toSourceString()
-              + "`";
-      return true;
-    }
-    if (featureImplementations.keySet().stream().anyMatch(feature -> !feature.isBase())) {
-      return true;
-    }
-    // All service implementations are in one of the base splits.
-    assert featureImplementations.size() <= 2;
-    // Check if service is defined feature
-    DexProgramClass serviceClass = appView.definitionForProgramType(serviceType);
-    if (serviceClass != null && classToFeatureSplitMap.isInFeature(serviceClass, appView)) {
-      return true;
-    }
-    for (Entry<FeatureSplit, List<DexType>> entry : featureImplementations.entrySet()) {
-      FeatureSplit feature = entry.getKey();
-      assert feature.isBase();
-      List<DexType> implementationTypes = entry.getValue();
-      for (DexType implementationType : implementationTypes) {
-        DexProgramClass implementationClass = appView.definitionForProgramType(implementationType);
-        if (implementationClass != null
-            && classToFeatureSplitMap.isInFeature(implementationClass, appView)) {
-          return true;
-        }
-      }
-    }
-    return false;
+  public Map<FeatureSplit, List<DexType>> serviceImplementationsByFeatureFor(DexType serviceType) {
+    return services.get(serviceType);
   }
 
   public AppServices rewrittenWithLens(GraphLens graphLens, Timing timing) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/ServiceLoaderRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/ServiceLoaderRewriter.java
index 86706f2..2ca09f2 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/ServiceLoaderRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/ServiceLoaderRewriter.java
@@ -4,14 +4,17 @@
 
 package com.android.tools.r8.ir.optimize;
 
+import com.android.tools.r8.FeatureSplit;
 import com.android.tools.r8.androidapi.AndroidApiLevelCompute;
 import com.android.tools.r8.contexts.CompilationContext.MethodProcessingContext;
+import com.android.tools.r8.features.ClassToFeatureSplitMap;
 import com.android.tools.r8.graph.AppServices;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexClass;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexItemFactory.ServiceLoaderMethods;
 import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexProgramClass;
 import com.android.tools.r8.graph.DexProto;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.graph.MethodAccessFlags;
@@ -159,22 +162,6 @@
         continue;
       }
 
-      // Check that the service is configured in the META-INF/services.
-      AppServices appServices = appView.appServices();
-      if (!appServices.allServiceTypes().contains(constClass.getValue())) {
-        // Error already reported in the Enqueuer.
-        continue;
-      }
-
-      // Check that we are not service loading anything from a feature into base.
-      if (appServices.hasServiceImplementationsInFeature(appView(), constClass.getValue())) {
-        report(
-            code.context(),
-            constClass.getType(),
-            "The service loader type has implementations in a feature split");
-        continue;
-      }
-
       // Check that ClassLoader used is the ClassLoader defined for the service configuration
       // that we are instantiating or NULL.
       Value classLoaderValue = serviceLoaderLoad.getLastArgument().getAliasedValue();
@@ -187,20 +174,20 @@
                 + ".class.getClassLoader()");
         continue;
       }
+      boolean isNullClassLoader = classLoaderValue.getType().isNullType();
       InvokeVirtual classLoaderInvoke = classLoaderValue.getDefinition().asInvokeVirtual();
-      boolean isGetClassLoaderOnConstClassOrNull =
-          classLoaderValue.getType().isNullType()
-              || (classLoaderInvoke != null
-                  && classLoaderInvoke.arguments().size() == 1
-                  && classLoaderInvoke.getReceiver().getAliasedValue().isConstClass()
-                  && classLoaderInvoke
-                      .getReceiver()
-                      .getAliasedValue()
-                      .getDefinition()
-                      .asConstClass()
-                      .getValue()
-                      .isIdenticalTo(constClass.getValue()));
-      if (!isGetClassLoaderOnConstClassOrNull) {
+      boolean isGetClassLoaderOnConstClass =
+          classLoaderInvoke != null
+              && classLoaderInvoke.arguments().size() == 1
+              && classLoaderInvoke.getReceiver().getAliasedValue().isConstClass()
+              && classLoaderInvoke
+                  .getReceiver()
+                  .getAliasedValue()
+                  .getDefinition()
+                  .asConstClass()
+                  .getValue()
+                  .isIdenticalTo(constClass.getType());
+      if (!isNullClassLoader && !isGetClassLoaderOnConstClass) {
         report(
             code.context(),
             constClass.getType(),
@@ -210,6 +197,19 @@
         continue;
       }
 
+      // Check that the service is configured in the META-INF/services.
+      AppServices appServices = appView.appServices();
+      if (!appServices.allServiceTypes().contains(constClass.getValue())) {
+        report(code.context(), constClass.getType(), "No META-INF/services file found.");
+        continue;
+      }
+
+      // Check that we are not service loading anything from a feature into base.
+      if (hasServiceImplementationInDifferentFeature(
+          code, constClass.getType(), isNullClassLoader)) {
+        continue;
+      }
+
       List<DexType> dexTypes = appServices.serviceImplementationsFor(constClass.getValue());
       List<DexClass> classes = new ArrayList<>(dexTypes.size());
       boolean seenNull = false;
@@ -268,6 +268,64 @@
     }
   }
 
+  private boolean hasServiceImplementationInDifferentFeature(
+      IRCode code, DexType serviceType, boolean baseFeatureOnly) {
+    AppView<AppInfoWithLiveness> appViewWithClasses = appView();
+    ClassToFeatureSplitMap classToFeatureSplitMap =
+        appViewWithClasses.appInfo().getClassToFeatureSplitMap();
+    if (classToFeatureSplitMap.isEmpty()) {
+      return false;
+    }
+    Map<FeatureSplit, List<DexType>> featureImplementations =
+        appView.appServices().serviceImplementationsByFeatureFor(serviceType);
+    if (featureImplementations == null || featureImplementations.isEmpty()) {
+      return false;
+    }
+    DexProgramClass serviceClass = appView.definitionForProgramType(serviceType);
+    if (serviceClass == null) {
+      return false;
+    }
+    FeatureSplit serviceFeature =
+        classToFeatureSplitMap.getFeatureSplit(serviceClass, appViewWithClasses);
+    if (baseFeatureOnly && !serviceFeature.isBase()) {
+      report(
+          code.context(),
+          serviceType,
+          "ClassLoader arg was null and service interface is in non-base feature");
+      return true;
+    }
+    for (var entry : featureImplementations.entrySet()) {
+      FeatureSplit metaInfFeature = entry.getKey();
+      if (!metaInfFeature.isBase()) {
+        if (baseFeatureOnly) {
+          report(
+              code.context(),
+              serviceType,
+              "ClassLoader arg was null and META-INF/ service entry found in non-base feature");
+          return true;
+        }
+        if (metaInfFeature != serviceFeature) {
+          report(
+              code.context(),
+              serviceType,
+              "META-INF/ service found in different feature from service interface");
+          return true;
+        }
+      }
+      for (DexType impl : entry.getValue()) {
+        FeatureSplit implFeature = classToFeatureSplitMap.getFeatureSplit(impl, appViewWithClasses);
+        if (implFeature != serviceFeature) {
+          report(
+              code.context(),
+              serviceType,
+              "Implementation found in different feature from service interface: " + impl);
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
   private DexEncodedMethod createSynthesizedMethod(
       DexType serviceType,
       List<DexClass> classes,
diff --git a/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceLoaderRewritingTest.java b/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceLoaderRewritingTest.java
index c9fc05e..09025e9 100644
--- a/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceLoaderRewritingTest.java
+++ b/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceLoaderRewritingTest.java
@@ -8,12 +8,10 @@
 import static org.junit.Assume.assumeTrue;
 
 import com.android.tools.r8.CompilationFailedException;
-import com.android.tools.r8.DiagnosticsMatcher;
 import com.android.tools.r8.NeverInline;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.TestParametersCollection;
 import com.android.tools.r8.ToolHelper.DexVm.Version;
-import com.android.tools.r8.ir.optimize.ServiceLoaderRewriterDiagnostic;
 import com.android.tools.r8.utils.StringUtils;
 import java.io.IOException;
 import java.util.List;
@@ -155,9 +153,8 @@
       throws IOException, CompilationFailedException, ExecutionException {
     serviceLoaderTest(null)
         .addKeepMainRule(MainRunner.class)
-        .compile()
-        .assertAllInfosMatch(
-            DiagnosticsMatcher.diagnosticType(ServiceLoaderRewriterDiagnostic.class))
+        .allowDiagnosticInfoMessages()
+        .compileWithExpectedDiagnostics(REWRITER_DIAGNOSTICS)
         .run(parameters.getRuntime(), MainRunner.class)
         .assertFailureWithErrorThatThrows(NoSuchElementException.class)
         .inspectFailure(inspector -> assertEquals(1, getServiceLoaderLoads(inspector)));
@@ -213,10 +210,7 @@
     serviceLoaderTest(Service.class, ServiceImpl.class)
         .addKeepMainRule(OtherRunner.class)
         .allowDiagnosticInfoMessages()
-        .compile()
-        .assertAllInfosMatch(
-            DiagnosticsMatcher.diagnosticType(ServiceLoaderRewriterDiagnostic.class))
-        .assertAtLeastOneInfoMessage()
+        .compileWithExpectedDiagnostics(REWRITER_DIAGNOSTICS)
         .run(parameters.getRuntime(), OtherRunner.class)
         .assertSuccessWithOutput(EXPECTED_OUTPUT)
         .inspect(
@@ -234,10 +228,7 @@
         .enableInliningAnnotations()
         .addDontObfuscate()
         .allowDiagnosticInfoMessages()
-        .compile()
-        .assertAllInfosMatch(
-            DiagnosticsMatcher.diagnosticType(ServiceLoaderRewriterDiagnostic.class))
-        .assertAtLeastOneInfoMessage()
+        .compileWithExpectedDiagnostics(REWRITER_DIAGNOSTICS)
         .run(parameters.getRuntime(), EscapingRunner.class)
         .assertSuccessWithOutput(EXPECTED_OUTPUT)
         .inspect(
@@ -254,10 +245,7 @@
         .addKeepMainRule(LoadWhereClassLoaderIsPhi.class)
         .enableInliningAnnotations()
         .allowDiagnosticInfoMessages()
-        .compile()
-        .assertAllInfosMatch(
-            DiagnosticsMatcher.diagnosticType(ServiceLoaderRewriterDiagnostic.class))
-        .assertAtLeastOneInfoMessage()
+        .compileWithExpectedDiagnostics(REWRITER_DIAGNOSTICS)
         .run(parameters.getRuntime(), LoadWhereClassLoaderIsPhi.class)
         .assertSuccessWithOutputLines("Hello World!")
         .inspect(
@@ -279,10 +267,7 @@
         .addKeepMainRule(MainRunner.class)
         .addKeepClassRules(Service.class)
         .allowDiagnosticInfoMessages()
-        .compile()
-        .assertAllInfosMatch(
-            DiagnosticsMatcher.diagnosticType(ServiceLoaderRewriterDiagnostic.class))
-        .assertAtLeastOneInfoMessage()
+        .compileWithExpectedDiagnostics(REWRITER_DIAGNOSTICS)
         .run(parameters.getRuntime(), MainRunner.class)
         .assertSuccessWithOutput(EXPECTED_OUTPUT)
         .inspect(
diff --git a/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceLoaderTestBase.java b/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceLoaderTestBase.java
index cf1fc6f..edb1c4e 100644
--- a/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceLoaderTestBase.java
+++ b/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceLoaderTestBase.java
@@ -9,11 +9,15 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
+import com.android.tools.r8.CompilationFailedException;
 import com.android.tools.r8.DataEntryResource;
+import com.android.tools.r8.DiagnosticsMatcher;
 import com.android.tools.r8.R8FullTestBuilder;
 import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestCompilerBuilder.DiagnosticsConsumer;
 import com.android.tools.r8.TestParameters;
 import com.android.tools.r8.graph.AppServices;
+import com.android.tools.r8.ir.optimize.ServiceLoaderRewriterDiagnostic;
 import com.android.tools.r8.origin.Origin;
 import com.android.tools.r8.utils.DataResourceConsumerForTesting;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
@@ -25,12 +29,17 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
-import java.util.ServiceLoader;
 import java.util.stream.Collectors;
 
 public class ServiceLoaderTestBase extends TestBase {
   protected final TestParameters parameters;
   protected DataResourceConsumerForTesting dataResourceConsumer;
+  protected static final DiagnosticsConsumer<CompilationFailedException> REWRITER_DIAGNOSTICS =
+      diagnostics ->
+          diagnostics
+              .assertOnlyInfos()
+              .assertAllInfosMatch(
+                  DiagnosticsMatcher.diagnosticType(ServiceLoaderRewriterDiagnostic.class));
 
   public ServiceLoaderTestBase(TestParameters parameters) {
     this.parameters = parameters;
diff --git a/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceWithFeatureNullClassLoaderTest.java b/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceWithFeatureNullClassLoaderTest.java
new file mode 100644
index 0000000..76ea395
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceWithFeatureNullClassLoaderTest.java
@@ -0,0 +1,80 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.optimize.serviceloader;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.R8TestCompileResult;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import java.util.ServiceLoader;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class ServiceWithFeatureNullClassLoaderTest extends ServiceLoaderTestBase {
+
+  public interface Service {
+
+    void print();
+  }
+
+  public static class ServiceImpl implements Service {
+
+    @Override
+    public void print() {
+      System.out.println("Hello World!");
+    }
+  }
+
+  public static class MainRunner {
+
+    public static void main(String[] args) {
+      run1();
+    }
+
+    @NeverInline
+    public static void run1() {
+      for (Service x : ServiceLoader.load(Service.class, null)) {
+        x.print();
+      }
+    }
+
+    @NeverInline
+    public static void checkNotNull(ClassLoader classLoader) {
+      if (classLoader == null) {
+        throw new NullPointerException("ClassLoader should not be null");
+      }
+    }
+  }
+
+  @Parameterized.Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withDexRuntimes().withAllApiLevels().build();
+  }
+
+  public ServiceWithFeatureNullClassLoaderTest(TestParameters parameters) {
+    super(parameters);
+  }
+
+  @Test
+  public void testNoRewritings() throws Exception {
+    R8TestCompileResult result =
+        serviceLoaderTestNoClasses(Service.class, ServiceImpl.class)
+            .addFeatureSplit(MainRunner.class, Service.class, ServiceImpl.class)
+            .enableInliningAnnotations()
+            .addKeepMainRule(MainRunner.class)
+            .allowDiagnosticInfoMessages()
+            .compileWithExpectedDiagnostics(REWRITER_DIAGNOSTICS);
+
+    CodeInspector inspector = result.featureInspector();
+    assertEquals(getServiceLoaderLoads(inspector), 1);
+    // Check that we have not removed the service configuration from META-INF/services.
+    verifyServiceMetaInf(inspector, Service.class, ServiceImpl.class);
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceWithFeatureTest.java b/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceWithFeatureTest.java
new file mode 100644
index 0000000..34237e1
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/optimize/serviceloader/ServiceWithFeatureTest.java
@@ -0,0 +1,77 @@
+// Copyright (c) 2024, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.optimize.serviceloader;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import java.util.ServiceLoader;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class ServiceWithFeatureTest extends ServiceLoaderTestBase {
+
+  public interface Service {
+
+    void print();
+  }
+
+  public static class ServiceImpl implements Service {
+
+    @Override
+    public void print() {
+      System.out.println("Hello World!");
+    }
+  }
+
+  public static class MainRunner {
+
+    public static void main(String[] args) {
+      run1();
+    }
+
+    @NeverInline
+    public static void run1() {
+      ClassLoader classLoader = Service.class.getClassLoader();
+      checkNotNull(classLoader);
+      for (Service x : ServiceLoader.load(Service.class, classLoader)) {
+        x.print();
+      }
+    }
+
+    @NeverInline
+    public static void checkNotNull(ClassLoader classLoader) {
+      if (classLoader == null) {
+        throw new NullPointerException("ClassLoader should not be null");
+      }
+    }
+  }
+
+  @Parameterized.Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withDexRuntimes().withAllApiLevels().build();
+  }
+
+  public ServiceWithFeatureTest(TestParameters parameters) {
+    super(parameters);
+  }
+
+  @Test
+  public void testRewritings() throws Exception {
+    serviceLoaderTestNoClasses(Service.class, ServiceImpl.class)
+        .addFeatureSplit(MainRunner.class, Service.class, ServiceImpl.class)
+        .enableInliningAnnotations()
+        .addKeepMainRule(MainRunner.class)
+        .compile()
+        .inspect(
+            inspector -> {},
+            inspector -> {
+              verifyNoServiceLoaderLoads(inspector.clazz(MainRunner.class));
+              verifyServiceMetaInf(inspector, Service.class, ServiceImpl.class);
+            });
+  }
+}
diff --git a/src/test/testbase/java/com/android/tools/r8/R8TestCompileResult.java b/src/test/testbase/java/com/android/tools/r8/R8TestCompileResult.java
index f9db531..942cc92 100644
--- a/src/test/testbase/java/com/android/tools/r8/R8TestCompileResult.java
+++ b/src/test/testbase/java/com/android/tools/r8/R8TestCompileResult.java
@@ -140,6 +140,11 @@
         AndroidApp.builder().addProgramFile(feature).setProguardMapOutputData(proguardMap).build());
   }
 
+  public CodeInspector featureInspector() throws IOException {
+    assert features.size() == 1;
+    return featureInspector(features.get(0));
+  }
+
   @SafeVarargs
   public final <E extends Throwable> R8TestCompileResult inspect(
       ThrowingConsumer<CodeInspector, E>... consumers) throws IOException, E {