Share bridges in synthetic super class

Bug: b/309575527
Change-Id: Ibce87fc14932e2fa8b5351388c0c82ced9ca9d71
diff --git a/src/main/java/com/android/tools/r8/R8.java b/src/main/java/com/android/tools/r8/R8.java
index 7d04b5d..9371a3f 100644
--- a/src/main/java/com/android/tools/r8/R8.java
+++ b/src/main/java/com/android/tools/r8/R8.java
@@ -69,6 +69,7 @@
 import com.android.tools.r8.naming.RecordInvokeDynamicInvokeCustomRewriter;
 import com.android.tools.r8.naming.RecordRewritingNamingLens;
 import com.android.tools.r8.naming.signature.GenericSignatureRewriter;
+import com.android.tools.r8.optimize.BridgeHoistingToSharedSyntheticSuperClass;
 import com.android.tools.r8.optimize.MemberRebindingAnalysis;
 import com.android.tools.r8.optimize.MemberRebindingIdentityLens;
 import com.android.tools.r8.optimize.MemberRebindingIdentityLensFactory;
@@ -495,12 +496,13 @@
           .setMustRetargetInvokesToTargetMethod()
           .run(executorService, timing);
 
-      RuntimeTypeCheckInfo runtimeTypeCheckInfo =
-          classMergingEnqueuerExtensionBuilder.build(appView.graphLens());
-      classMergingEnqueuerExtensionBuilder = null;
+      BridgeHoistingToSharedSyntheticSuperClass.run(appViewWithLiveness, executorService, timing);
 
       assert ArtProfileCompletenessChecker.verify(appView);
 
+      RuntimeTypeCheckInfo runtimeTypeCheckInfo =
+          classMergingEnqueuerExtensionBuilder.build(appView.graphLens());
+      classMergingEnqueuerExtensionBuilder = null;
       if (!appView.hasCfByteCodePassThroughMethods()
           && options.getProguardConfiguration().isOptimizing()) {
         if (options.enableVerticalClassMerging) {
diff --git a/src/main/java/com/android/tools/r8/graph/DexClass.java b/src/main/java/com/android/tools/r8/graph/DexClass.java
index 616c9ec..2a80c0b 100644
--- a/src/main/java/com/android/tools/r8/graph/DexClass.java
+++ b/src/main/java/com/android/tools/r8/graph/DexClass.java
@@ -734,6 +734,10 @@
     return superType;
   }
 
+  public void setSuperType(DexType superType) {
+    this.superType = superType;
+  }
+
   public boolean hasClassInitializer() {
     return getClassInitializer() != null;
   }
diff --git a/src/main/java/com/android/tools/r8/optimize/BridgeHoistingToSharedSyntheticSuperClass.java b/src/main/java/com/android/tools/r8/optimize/BridgeHoistingToSharedSyntheticSuperClass.java
new file mode 100644
index 0000000..c8d99c7
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/optimize/BridgeHoistingToSharedSyntheticSuperClass.java
@@ -0,0 +1,335 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize;
+
+import static com.android.tools.r8.ir.optimize.info.OptimizationFeedback.getSimpleFeedback;
+import static com.android.tools.r8.utils.MapUtils.ignoreKey;
+
+import com.android.tools.r8.contexts.CompilationContext.MainThreadContext;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexMethodSignature;
+import com.android.tools.r8.graph.DexProgramClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.MethodAccessFlags;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.conversion.MethodConversionOptions;
+import com.android.tools.r8.ir.optimize.info.bridge.BridgeAnalyzer;
+import com.android.tools.r8.ir.optimize.info.bridge.BridgeInfo;
+import com.android.tools.r8.ir.optimize.info.bridge.VirtualBridgeInfo;
+import com.android.tools.r8.optimize.bridgehoisting.BridgeHoisting;
+import com.android.tools.r8.profile.rewriting.ConcreteProfileCollectionAdditions;
+import com.android.tools.r8.profile.rewriting.ProfileCollectionAdditions;
+import com.android.tools.r8.shaking.AppInfoWithLiveness;
+import com.android.tools.r8.utils.InternalOptions;
+import com.android.tools.r8.utils.InternalOptions.TestingOptions;
+import com.android.tools.r8.utils.ListUtils;
+import com.android.tools.r8.utils.OptionalBool;
+import com.android.tools.r8.utils.SetUtils;
+import com.android.tools.r8.utils.Timing;
+import com.android.tools.r8.utils.collections.DexMethodSignatureMap;
+import com.google.common.collect.Iterables;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.function.BiConsumer;
+
+public class BridgeHoistingToSharedSyntheticSuperClass {
+
+  private final AppView<AppInfoWithLiveness> appView;
+
+  BridgeHoistingToSharedSyntheticSuperClass(AppView<AppInfoWithLiveness> appView) {
+    this.appView = appView;
+  }
+
+  public static void run(
+      AppView<AppInfoWithLiveness> appView, ExecutorService executorService, Timing timing)
+      throws ExecutionException {
+    InternalOptions options = appView.options();
+    if (!options.isOptimizing() || !options.isShrinking()) {
+      return;
+    }
+    if (!appView.options().canHaveNonReboundConstructorInvoke()) {
+      // TODO(b/309575527): Extend to all runtimes.
+      return;
+    }
+    TestingOptions testingOptions = options.getTestingOptions();
+    if (!testingOptions.enableBridgeHoistingToSharedSyntheticSuperclass) {
+      return;
+    }
+    timing.time(
+        "BridgeHoistingToSharedSyntheticSuperClass",
+        () -> new BridgeHoistingToSharedSyntheticSuperClass(appView).run(executorService, timing));
+  }
+
+  private void run(ExecutorService executorService, Timing timing) throws ExecutionException {
+    Collection<Group> groups = createInitialGroups(appView);
+    groups = refineGroups(groups);
+    if (!groups.isEmpty()) {
+      rewriteApplication(groups);
+      commitPendingSyntheticClasses();
+      updateArtProfiles(groups);
+      new BridgeHoisting(appView).run(executorService, timing);
+    }
+  }
+
+  /** Returns the set of (non-singleton) groups that have the same superclass. */
+  private Collection<Group> createInitialGroups(AppView<AppInfoWithLiveness> appView) {
+    Map<DexClass, Group> groups = new LinkedHashMap<>();
+    for (DexProgramClass clazz : appView.appInfo().classesWithDeterministicOrder()) {
+      if (!clazz.hasSuperType()) {
+        continue;
+      }
+      DexClass superclass = appView.definitionFor(clazz.getSuperType());
+      if (superclass != null) {
+        groups.computeIfAbsent(superclass, ignoreKey(Group::new)).addClass(clazz);
+      }
+    }
+    groups.values().removeIf(Group::isSingleton);
+    return groups.values();
+  }
+
+  private Collection<Group> refineGroups(Collection<Group> groups) {
+    Collection<Group> newGroups = new ArrayList<>();
+    for (Group group : groups) {
+      Iterables.addAll(newGroups, refineGroup(group));
+    }
+    return newGroups;
+  }
+
+  /**
+   * Splits the group into a collection of smaller groups that should receive a shared superclass.
+   *
+   * <p>For each class, this creates a specification of the bridges (a mapping from bridge method
+   * signatures to their bridge implementation). Two classes are selected for getting a shared
+   * synthetic super class if the bridge specification of one is a subset of the other (i.e., a
+   * subset of the bridges can be shared and there are no bridges with the same signature that have
+   * different behavior).
+   */
+  private Iterable<Group> refineGroup(Group group) {
+    List<Group> newGroups = new ArrayList<>();
+    for (DexProgramClass clazz : group) {
+      BridgeSpecification bridgeSpecification = getBridgeSpecification(clazz);
+      if (bridgeSpecification.isEmpty()) {
+        continue;
+      }
+      Group targetGroup = getGroupForClass(newGroups, clazz, bridgeSpecification);
+      if (targetGroup == null) {
+        newGroups.add(new Group(clazz, bridgeSpecification));
+      }
+    }
+    // Only introduce a shared super class for non-singleton groups that do not already have a
+    // shared superclass in the first place.
+    return Iterables.filter(
+        newGroups, newGroup -> !newGroup.isSingleton() && newGroup.size() < group.size());
+  }
+
+  // TODO(b/309575527): Avoid building IR for all methods.
+  private BridgeSpecification getBridgeSpecification(DexProgramClass clazz) {
+    BridgeSpecification bridgeSpecification = new BridgeSpecification();
+    clazz.forEachProgramVirtualMethodMatching(
+        DexEncodedMethod::hasCode,
+        method -> {
+          IRCode code = method.buildIR(appView, MethodConversionOptions.nonConverting());
+          BridgeInfo bridgeInfo = BridgeAnalyzer.analyzeMethod(method.getDefinition(), code);
+          if (bridgeInfo != null) {
+            getSimpleFeedback().setBridgeInfo(method, bridgeInfo);
+            if (bridgeInfo.isVirtualBridgeInfo()) {
+              bridgeSpecification.addBridge(
+                  method.getMethodSignature(), bridgeInfo.asVirtualBridgeInfo());
+            }
+          }
+        });
+    return bridgeSpecification;
+  }
+
+  private Group getGroupForClass(
+      Collection<Group> groups, DexProgramClass clazz, BridgeSpecification bridgeSpecification) {
+    for (Group group : groups) {
+      if (bridgeSpecification.lessThanOrEquals(group.getBridgeSpecification())) {
+        group.addClass(clazz);
+        return group;
+      } else if (group.getBridgeSpecification().lessThanOrEquals(bridgeSpecification)) {
+        group.addClass(clazz);
+        group.setBridgeSpecification(bridgeSpecification);
+        return group;
+      }
+    }
+    return null;
+  }
+
+  private void rewriteApplication(Collection<Group> groups) {
+    MainThreadContext mainThreadContext =
+        appView.createProcessorContext().createMainThreadContext();
+    for (Group group : groups) {
+      DexProgramClass representative = ListUtils.first(group.getClasses());
+      Set<DexType> interfaces = SetUtils.newIdentityHashSet(representative.getInterfaces());
+      for (DexProgramClass clazz : Iterables.skip(group.getClasses(), 1)) {
+        interfaces.removeIf(type -> !clazz.getInterfaces().contains(type));
+      }
+      DexProgramClass syntheticSuperclass =
+          appView
+              .getSyntheticItems()
+              .createClass(
+                  kinds -> kinds.SHARED_SUPER_CLASS,
+                  mainThreadContext.createUniqueContext(representative),
+                  appView,
+                  classBuilder -> {
+                    classBuilder
+                        .setAbstract()
+                        .setSuperType(representative.getSuperType())
+                        .setInterfaces(ListUtils.sort(interfaces, Comparator.naturalOrder()));
+                    group
+                        .getBridgeSpecification()
+                        .forEach(
+                            (bridge, target) ->
+                                classBuilder.addMethod(
+                                    methodBuilder ->
+                                        methodBuilder
+                                            .setAccessFlags(
+                                                MethodAccessFlags.builder()
+                                                    .setAbstract()
+                                                    .setPublic()
+                                                    .build())
+                                            // TODO(b/309575527): Set correct api level.
+                                            .setApiLevelForDefinition(appView.computedMinApiLevel())
+                                            // TODO(b/309575527): Set correct library override info.
+                                            .setIsLibraryMethodOverride(OptionalBool.FALSE)
+                                            .setName(target.getName())
+                                            .setProto(target.getProto())));
+                  });
+      for (DexProgramClass clazz : group) {
+        clazz.setSuperType(syntheticSuperclass.getType());
+      }
+    }
+  }
+
+  private void commitPendingSyntheticClasses() {
+    assert appView.getSyntheticItems().hasPendingSyntheticClasses();
+    appView.setAppInfo(
+        appView.appInfo().rebuildWithLiveness(appView.getSyntheticItems().commit(appView.app())));
+  }
+
+  private void updateArtProfiles(Collection<Group> groups) {
+    ConcreteProfileCollectionAdditions profileCollectionAdditions =
+        ProfileCollectionAdditions.create(appView).asConcrete();
+    if (profileCollectionAdditions == null) {
+      return;
+    }
+    for (Group group : groups) {
+      for (DexProgramClass clazz : group) {
+        profileCollectionAdditions.applyIfContextIsInProfile(
+            clazz, additionsBuilder -> additionsBuilder.addClassRule(clazz.getSuperType()));
+        group
+            .getBridgeSpecification()
+            .forEach(
+                (bridge, target) -> {
+                  DexEncodedMethod targetMethod = clazz.getMethodCollection().getMethod(target);
+                  if (targetMethod != null) {
+                    profileCollectionAdditions.applyIfContextIsInProfile(
+                        targetMethod.getReference(),
+                        additionsBuilder ->
+                            additionsBuilder.addMethodRule(
+                                target.withHolder(clazz.getSuperType(), appView.dexItemFactory())));
+                  }
+                });
+      }
+    }
+    profileCollectionAdditions.commit(appView);
+  }
+
+  private static class Group implements Iterable<DexProgramClass> {
+
+    private final List<DexProgramClass> classes;
+    private BridgeSpecification bridgeSpecification;
+
+    public Group() {
+      this.classes = new ArrayList<>();
+      this.bridgeSpecification = null;
+    }
+
+    public Group(DexProgramClass clazz, BridgeSpecification bridgeSpecification) {
+      this.classes = ListUtils.newArrayList(clazz);
+      this.bridgeSpecification = bridgeSpecification;
+    }
+
+    void addClass(DexProgramClass clazz) {
+      classes.add(clazz);
+    }
+
+    BridgeSpecification getBridgeSpecification() {
+      return bridgeSpecification;
+    }
+
+    List<DexProgramClass> getClasses() {
+      return classes;
+    }
+
+    void setBridgeSpecification(BridgeSpecification bridgeSpecification) {
+      this.bridgeSpecification = bridgeSpecification;
+    }
+
+    boolean isSingleton() {
+      return size() == 1;
+    }
+
+    @Override
+    public Iterator<DexProgramClass> iterator() {
+      return classes.iterator();
+    }
+
+    public int size() {
+      return classes.size();
+    }
+  }
+
+  private static class BridgeSpecification {
+
+    private final DexMethodSignatureMap<DexMethodSignature> bridges =
+        DexMethodSignatureMap.create();
+
+    void addBridge(DexMethodSignature method, VirtualBridgeInfo bridgeInfo) {
+      bridges.put(method, bridgeInfo.getInvokedMethod().getSignature());
+    }
+
+    boolean containsBridgeWithTarget(DexMethodSignature method, DexMethodSignature target) {
+      return target.equals(bridges.get(method));
+    }
+
+    void forEach(BiConsumer<? super DexMethodSignature, ? super DexMethodSignature> consumer) {
+      bridges.forEach(consumer);
+    }
+
+    boolean isEmpty() {
+      return bridges.isEmpty();
+    }
+
+    boolean lessThanOrEquals(BridgeSpecification bridgeSpecification) {
+      if (size() > bridgeSpecification.size()) {
+        return false;
+      }
+      for (Entry<DexMethodSignature, DexMethodSignature> entry : bridges.entrySet()) {
+        DexMethodSignature method = entry.getKey();
+        DexMethodSignature target = entry.getValue();
+        if (!bridgeSpecification.containsBridgeWithTarget(method, target)) {
+          return false;
+        }
+      }
+      return true;
+    }
+
+    int size() {
+      return bridges.size();
+    }
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java b/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
index 4707606..a1795a7 100644
--- a/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
+++ b/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
@@ -61,6 +61,10 @@
   public final SyntheticKind LAMBDA = generator.forInstanceClass("Lambda");
   public final SyntheticKind THREAD_LOCAL = generator.forInstanceClass("ThreadLocal");
 
+  // Merging not permitted since this could defeat the purpose of the synthetic class.
+  public final SyntheticKind SHARED_SUPER_CLASS =
+      generator.forNonSharableInstanceClass("SharedSuper");
+
   // TODO(b/214901256): Sharing of synthetic classes may lead to duplicate method errors.
   public final SyntheticKind NON_FIXED_INIT_TYPE_ARGUMENT =
       generator.forNonSharableInstanceClass("$IA");
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index a372d38..e90c9a1 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -2365,6 +2365,7 @@
     public boolean allowUnusedDontWarnRules = true;
     public boolean alwaysUseExistingAccessInfoCollectionsInMemberRebinding = true;
     public boolean alwaysUsePessimisticRegisterAllocation = false;
+    public boolean enableBridgeHoistingToSharedSyntheticSuperclass = false;
     public boolean enableCheckCastAndInstanceOfRemoval = true;
     public boolean enableDeadSwitchCaseElimination = true;
     public boolean enableInvokeSuperToInvokeVirtualRewriting = true;
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/boxedprimitives/BoxedPrimitiveFromGenericUnboxingTest.java b/src/test/java/com/android/tools/r8/ir/optimize/boxedprimitives/BoxedPrimitiveFromGenericUnboxingTest.java
new file mode 100644
index 0000000..5a23263
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/boxedprimitives/BoxedPrimitiveFromGenericUnboxingTest.java
@@ -0,0 +1,159 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.optimize.boxedprimitives;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isAbsentIf;
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.NoHorizontalClassMerging;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.utils.BooleanUtils;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import java.util.List;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class BoxedPrimitiveFromGenericUnboxingTest extends TestBase {
+
+  @Parameter(0)
+  public boolean enableBridgeHoistingToSharedSyntheticSuperclass;
+
+  @Parameter(1)
+  public TestParameters parameters;
+
+  @Parameters(name = "{1}, opt: {0}")
+  public static List<Object[]> data() {
+    return buildParameters(
+        BooleanUtils.values(), getTestParameters().withAllRuntimesAndApiLevels().build());
+  }
+
+  @Test
+  public void test() throws Exception {
+    boolean optimize =
+        enableBridgeHoistingToSharedSyntheticSuperclass
+            && parameters.canHaveNonReboundConstructorInvoke();
+    testForR8(parameters.getBackend())
+        .addInnerClasses(getClass())
+        .addKeepMainRule(Main.class)
+        .addOptionsModification(
+            options ->
+                options.testing.enableBridgeHoistingToSharedSyntheticSuperclass =
+                    enableBridgeHoistingToSharedSyntheticSuperclass)
+        .enableNoHorizontalClassMergingAnnotations()
+        .setMinApi(parameters)
+        .compile()
+        .inspect(
+            inspector -> {
+              // Function should be removed as a result of bridge hoisting + inlining when adding a
+              // shared superclass to Increment and Decrement, and another shared superclass to
+              // StdoutPrinter and StderrPrinter.
+              ClassSubject functionClassSubject = inspector.clazz(Function.class);
+              assertThat(functionClassSubject, isAbsentIf(optimize));
+
+              // Check that the cast to java.lang.Integer in Increment.apply has been removed as a
+              // result of devirtualization.
+              ClassSubject incrementClassSubject = inspector.clazz(Increment.class);
+              assertThat(incrementClassSubject, isPresent());
+
+              MethodSubject incrementApplyMethodSubject =
+                  incrementClassSubject.uniqueMethodWithOriginalName("apply");
+              assertThat(incrementApplyMethodSubject, isPresent());
+              assertEquals(
+                  optimize,
+                  incrementApplyMethodSubject
+                      .streamInstructions()
+                      .noneMatch(
+                          instruction -> instruction.isCheckCast(Integer.class.getTypeName())));
+
+              // Check that the cast to java.lang.String in StdoutPrinter.apply has been removed as
+              // result of devirtualization (in fact the `Void apply(String)` method has been
+              // optimized to `void apply()` as a result of constant propagation).
+              ClassSubject stdoutPrinterClassSubject = inspector.clazz(StdoutPrinter.class);
+              assertThat(stdoutPrinterClassSubject, isPresent());
+
+              MethodSubject stdoutPrinterApplyMethodSubject =
+                  stdoutPrinterClassSubject.uniqueMethodWithOriginalName("apply");
+              assertThat(stdoutPrinterApplyMethodSubject, isPresent());
+              assertEquals(
+                  optimize,
+                  stdoutPrinterApplyMethodSubject.getProgramMethod().getReturnType().isVoidType());
+              assertEquals(
+                  optimize ? 0 : 1, stdoutPrinterApplyMethodSubject.getParameters().size());
+              assertEquals(
+                  optimize,
+                  stdoutPrinterApplyMethodSubject
+                      .streamInstructions()
+                      .noneMatch(
+                          instruction -> instruction.isCheckCast(String.class.getTypeName())));
+            })
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("42", "42", "42");
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      Function<Integer, Integer> inc =
+          System.currentTimeMillis() > 0 ? new Increment() : new Decrement();
+      Function<Integer, Integer> dec =
+          System.currentTimeMillis() > 0 ? new Decrement() : new Increment();
+      Function<String, Void> printer =
+          System.currentTimeMillis() > 0 ? new StdoutPrinter() : new StderrPrinter();
+      System.out.println(inc.apply(41));
+      System.out.println(dec.apply(43));
+      printer.apply("42");
+    }
+  }
+
+  interface Function<S, T> {
+
+    T apply(S s);
+  }
+
+  @NoHorizontalClassMerging
+  static class Increment implements Function<Integer, Integer> {
+
+    @Override
+    public Integer apply(Integer i) {
+      return i + 1;
+    }
+  }
+
+  @NoHorizontalClassMerging
+  static class Decrement implements Function<Integer, Integer> {
+
+    @Override
+    public Integer apply(Integer i) {
+      return i - 1;
+    }
+  }
+
+  @NoHorizontalClassMerging
+  static class StdoutPrinter implements Function<String, Void> {
+
+    @Override
+    public Void apply(String obj) {
+      System.out.println(obj);
+      return null;
+    }
+  }
+
+  @NoHorizontalClassMerging
+  static class StderrPrinter implements Function<String, Void> {
+
+    @Override
+    public Void apply(String obj) {
+      System.err.println(obj);
+      return null;
+    }
+  }
+}