Protect api surface in horizontal class merging

Change-Id: Ie42c489c8bcc4da37c11740288a21b824f2f345e
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/PolicyScheduler.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/PolicyScheduler.java
index 6377f75..0f1355f 100644
--- a/src/main/java/com/android/tools/r8/horizontalclassmerging/PolicyScheduler.java
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/PolicyScheduler.java
@@ -47,6 +47,7 @@
 import com.android.tools.r8.horizontalclassmerging.policies.PreserveInterfaceMethodDispatch;
 import com.android.tools.r8.horizontalclassmerging.policies.PreserveMethodCharacteristics;
 import com.android.tools.r8.horizontalclassmerging.policies.PreventClassMethodAndDefaultMethodCollisions;
+import com.android.tools.r8.horizontalclassmerging.policies.ProtectApiSurface;
 import com.android.tools.r8.horizontalclassmerging.policies.RespectPackageBoundaries;
 import com.android.tools.r8.horizontalclassmerging.policies.SameFeatureSplit;
 import com.android.tools.r8.horizontalclassmerging.policies.SameFilePolicy;
@@ -97,7 +98,7 @@
       RuntimeTypeCheckInfo runtimeTypeCheckInfo) {
     List<Policy> policies =
         ImmutableList.<Policy>builder()
-            .addAll(getSingleClassPolicies(appView, runtimeTypeCheckInfo))
+            .addAll(getSingleClassPolicies(appView, immediateSubtypingInfo, runtimeTypeCheckInfo))
             .addAll(getMultiClassPolicies(appView, immediateSubtypingInfo, runtimeTypeCheckInfo))
             .build();
     policies = appView.options().testing.horizontalClassMergingPolicyRewriter.apply(policies);
@@ -107,10 +108,11 @@
 
   private static List<SingleClassPolicy> getSingleClassPolicies(
       AppView<? extends AppInfoWithClassHierarchy> appView,
+      ImmediateProgramSubtypingInfo immediateSubtypingInfo,
       RuntimeTypeCheckInfo runtimeTypeCheckInfo) {
     ImmutableList.Builder<SingleClassPolicy> builder = ImmutableList.builder();
 
-    addRequiredSingleClassPolicies(appView, builder);
+    addRequiredSingleClassPolicies(appView, immediateSubtypingInfo, builder);
 
     if (appView.options().horizontalClassMergerOptions().isRestrictedToSynthetics()) {
       assert verifySingleClassPoliciesIrrelevantForMergingSynthetics(appView, builder);
@@ -134,12 +136,16 @@
 
   private static void addRequiredSingleClassPolicies(
       AppView<? extends AppInfoWithClassHierarchy> appView,
+      ImmediateProgramSubtypingInfo immediateSubtypingInfo,
       ImmutableList.Builder<SingleClassPolicy> builder) {
     builder.add(
         new CheckSyntheticClasses(appView),
         new NoCheckDiscard(appView),
         new NoKeepRules(appView),
         new NoClassInitializerWithObservableSideEffects());
+    if (appView.options().shouldProtectApiSurface()) {
+      builder.add(new ProtectApiSurface(appView, immediateSubtypingInfo));
+    }
     if (appView.hasLiveness() && appView.options().isGeneratingClassFiles()) {
       builder.add(new NoMethodHandleFromLambda(appView.withLiveness()));
     }
diff --git a/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/ProtectApiSurface.java b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/ProtectApiSurface.java
new file mode 100644
index 0000000..b0c1348
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/horizontalclassmerging/policies/ProtectApiSurface.java
@@ -0,0 +1,116 @@
+// Copyright (c) 2026, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.horizontalclassmerging.policies;
+
+import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.ImmediateProgramSubtypingInfo;
+import com.android.tools.r8.horizontalclassmerging.SingleClassPolicy;
+import com.android.tools.r8.optimize.argumentpropagation.utils.DepthFirstTopDownClassHierarchyTraversal;
+import com.android.tools.r8.shaking.KeepClassInfo;
+import com.android.tools.r8.utils.InternalOptions;
+import com.google.common.collect.Sets;
+import java.util.IdentityHashMap;
+import java.util.Map;
+import java.util.Set;
+
+public class ProtectApiSurface extends SingleClassPolicy {
+
+  private final Set<DexProgramClass> apiClasses = Sets.newIdentityHashSet();
+
+  public ProtectApiSurface(
+      AppView<? extends AppInfoWithClassHierarchy> appView,
+      ImmediateProgramSubtypingInfo immediateSubtypingInfo) {
+    new Traversal(appView, immediateSubtypingInfo).run(appView.appInfo().classes());
+  }
+
+  @Override
+  public boolean canMerge(DexProgramClass program) {
+    return !apiClasses.contains(program);
+  }
+
+  @Override
+  public void clear() {
+    apiClasses.clear();
+  }
+
+  @Override
+  public boolean shouldSkipPolicy() {
+    return apiClasses.isEmpty();
+  }
+
+  @Override
+  public String getName() {
+    return "ProtectApiSurface";
+  }
+
+  private class Traversal extends DepthFirstTopDownClassHierarchyTraversal {
+
+    private final InternalOptions options;
+    private final Map<DexProgramClass, BottomUpTraversalState> states = new IdentityHashMap<>();
+
+    Traversal(
+        AppView<? extends AppInfoWithClassHierarchy> appView,
+        ImmediateProgramSubtypingInfo immediateSubtypingInfo) {
+      super(appView, immediateSubtypingInfo);
+      this.options = appView.options();
+    }
+
+    @Override
+    public void visit(DexProgramClass clazz) {
+      // Intentionally empty.
+    }
+
+    @Override
+    public void prune(DexProgramClass clazz) {
+      boolean isKeptOrHasKeptSubclass = unsetBottomUpTraversalState(clazz);
+      if (isKeptOrHasKeptSubclass) {
+        apiClasses.add(clazz);
+        immediateSubtypingInfo.forEachImmediateProgramSuperClass(
+            clazz,
+            superClass ->
+                getOrCreateBottomUpTraversalState(superClass).setIsKeptOrHasKeptSubclass());
+      }
+    }
+
+    private BottomUpTraversalState getOrCreateBottomUpTraversalState(DexProgramClass clazz) {
+      return states.computeIfAbsent(
+          clazz,
+          c -> {
+            BottomUpTraversalState newState = new BottomUpTraversalState(isKept(clazz));
+            states.put(c, newState);
+            return newState;
+          });
+    }
+
+    private boolean isKept(DexProgramClass clazz) {
+      KeepClassInfo keepInfo = appView.getKeepInfo(clazz);
+      return !keepInfo.isMinificationAllowed(options) && !keepInfo.isShrinkingAllowed(options);
+    }
+
+    // Returns whether the current class is kept or has a kept subclass.
+    private boolean unsetBottomUpTraversalState(DexProgramClass clazz) {
+      BottomUpTraversalState state = states.remove(clazz);
+      if (state != null) {
+        return state.isKeptOrHasKeptSubclass;
+      }
+      return isKept(clazz);
+    }
+  }
+
+  static class BottomUpTraversalState {
+
+    boolean isKeptOrHasKeptSubclass;
+
+    private BottomUpTraversalState(boolean isKept) {
+      this.isKeptOrHasKeptSubclass = isKept;
+    }
+
+    void setIsKeptOrHasKeptSubclass() {
+      isKeptOrHasKeptSubclass = true;
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/bridgeremoval/BridgeWithInvokeSuperOnInterfaceTest.java b/src/test/java/com/android/tools/r8/bridgeremoval/BridgeWithInvokeSuperOnInterfaceTest.java
index e6d8721..9ff0edc 100644
--- a/src/test/java/com/android/tools/r8/bridgeremoval/BridgeWithInvokeSuperOnInterfaceTest.java
+++ b/src/test/java/com/android/tools/r8/bridgeremoval/BridgeWithInvokeSuperOnInterfaceTest.java
@@ -4,7 +4,7 @@
 
 package com.android.tools.r8.bridgeremoval;
 
-import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsent;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsentIf;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertTrue;
 
@@ -60,7 +60,7 @@
               // Check that we are removing the bridge if we support default methods.
               if (parameters.canUseDefaultAndStaticInterfaceMethods()) {
                 ClassSubject J = inspector.clazz(J.class);
-                assertThat(J, isAbsent());
+                assertThat(J, isAbsentIf(parameters.isDexRuntime()));
                 assertTrue(
                     inspector.allClasses().stream()
                         .allMatch(
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/ProtectApiSurfaceHorizontalClassMergingTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/ProtectApiSurfaceHorizontalClassMergingTest.java
new file mode 100644
index 0000000..13b58cf
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/ProtectApiSurfaceHorizontalClassMergingTest.java
@@ -0,0 +1,88 @@
+// Copyright (c) 2026, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.classmerging.horizontal;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NoVerticalClassMerging;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.utils.internal.BooleanUtils;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class ProtectApiSurfaceHorizontalClassMergingTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameter(1)
+  public boolean protectApiSurface;
+
+  @Parameters(name = "{0}, protect: {1}")
+  public static List<Object[]> data() {
+    return buildParameters(
+        getTestParameters().withAllRuntimesAndApiLevels().build(), BooleanUtils.values());
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters)
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addKeepClassAndMembersRules(B.class)
+        .addHorizontallyMergedClassesInspector(
+            inspector -> {
+              if (!protectApiSurface && parameters.isDexRuntime()) {
+                inspector.assertIsCompleteMergeGroup(A.class, C.class);
+              }
+              inspector.assertNoOtherClassesMerged();
+            })
+        .apply(b -> b.getBuilder().setProtectApiSurface(protectApiSurface))
+        .enableInliningAnnotations()
+        .enableNeverClassInliningAnnotations()
+        .enableNoVerticalClassMergingAnnotations()
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("A", "A", "B", "C");
+  }
+
+  @NoVerticalClassMerging
+  public static class A {
+
+    public A() {
+      System.out.println("A");
+    }
+  }
+
+  // Kept.
+  public static class B extends A {
+
+    public B() {
+      System.out.println("B");
+    }
+  }
+
+  @NeverClassInline
+  public static class C {
+
+    @NeverInline
+    public void m() {
+      System.out.println("C");
+    }
+  }
+
+  public static class Main {
+    public static void main(String[] args) {
+      new A();
+      new B();
+      new C().m();
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/interfaces/ClassHierarchyCycleCrossGroupMergingTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/interfaces/ClassHierarchyCycleCrossGroupMergingTest.java
index a2bce6c..007f387 100644
--- a/src/test/java/com/android/tools/r8/classmerging/horizontal/interfaces/ClassHierarchyCycleCrossGroupMergingTest.java
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/interfaces/ClassHierarchyCycleCrossGroupMergingTest.java
@@ -36,7 +36,11 @@
         .addKeepMainRule(Main.class)
         .addHorizontallyMergedClassesInspector(
             inspector ->
-                inspector.assertIsCompleteMergeGroup(I.class, J.class).assertNoOtherClassesMerged())
+                inspector
+                    .applyIf(
+                        parameters.isDexRuntime(),
+                        i -> i.assertIsCompleteMergeGroup(I.class, J.class))
+                    .assertNoOtherClassesMerged())
         .enableNoHorizontalClassMergingAnnotations()
         .enableNoUnusedInterfaceRemovalAnnotations()
         .enableNoVerticalClassMergingAnnotations()
diff --git a/src/test/java8/shaking/com/android/tools/r8/shaking/KeepAnnotatedMemberTest.java b/src/test/java8/shaking/com/android/tools/r8/shaking/KeepAnnotatedMemberTest.java
index 4f80d47..e928a76 100644
--- a/src/test/java8/shaking/com/android/tools/r8/shaking/KeepAnnotatedMemberTest.java
+++ b/src/test/java8/shaking/com/android/tools/r8/shaking/KeepAnnotatedMemberTest.java
@@ -254,7 +254,7 @@
             .apply(this::suppressZipFileAssignmentsToJavaLangAutoCloseable)
             .compile()
             .graphInspector();
-    assertRetainedClassesEqual(referenceInspector, ifThenKeepClassesWithMembersInspector);
+    assertRetainedClassesEqual(referenceInspector, ifThenKeepClassesWithMembersInspector, true);
 
     GraphInspector ifHasMemberThenKeepClassInspector =
         testForR8(Backend.CF)