Disable code deduplication for code objects containing invoke-super

Bug: b/445349082
Change-Id: I21f3ba04da7bb0febddb5d9b0c5b1f3e64e9faf6
diff --git a/src/main/java/com/android/tools/r8/dex/DefaultMixedSectionLayoutStrategy.java b/src/main/java/com/android/tools/r8/dex/DefaultMixedSectionLayoutStrategy.java
index d9f6def..9f0fa1f 100644
--- a/src/main/java/com/android/tools/r8/dex/DefaultMixedSectionLayoutStrategy.java
+++ b/src/main/java/com/android/tools/r8/dex/DefaultMixedSectionLayoutStrategy.java
@@ -78,13 +78,12 @@
       this.appView = appView;
     }
 
-    public void addCode(DexWritableCode dexWritableCode, ProgramMethod method) {
+    public void addCode(DexWritableCode code, ProgramMethod method) {
       assert appView.options().canUseCanonicalizedCodeObjects();
       if (counts == null) {
         counts = new HashMap<>();
       }
-      DexWritableCacheKey cacheKey =
-          dexWritableCode.getCacheLookupKey(method, appView.dexItemFactory());
+      DexWritableCacheKey cacheKey = code.getCacheLookupKey(method, appView.dexItemFactory());
       if (!counts.containsKey(cacheKey)) {
         counts.put(cacheKey, 1);
       } else {
@@ -98,9 +97,11 @@
             || method.getDefinition().getDexWritableCodeOrNull() == null;
         return 1;
       }
-      DexWritableCode dexWritableCodeOrNull = method.getDefinition().getDexWritableCodeOrNull();
-      DexWritableCacheKey cacheLookupKey =
-          dexWritableCodeOrNull.getCacheLookupKey(method, appView.dexItemFactory());
+      DexWritableCode code = method.getDefinition().getCode().asDexWritableCode();
+      if (!code.canBeCanonicalized(appView.options())) {
+        return 1;
+      }
+      DexWritableCacheKey cacheLookupKey = code.getCacheLookupKey(method, appView.dexItemFactory());
       assert counts.containsKey(cacheLookupKey);
       return counts.get(cacheLookupKey);
     }
@@ -117,7 +118,7 @@
             DexWritableCode code = method.getDefinition().getDexWritableCodeOrNull();
             assert code != null || method.getDefinition().shouldNotHaveCode();
             if (code != null) {
-              if (appView.options().canUseCanonicalizedCodeObjects()) {
+              if (code.canBeCanonicalized(appView.options())) {
                 codeCounts.addCode(code, method);
               }
               codesSorted.add(method);
diff --git a/src/main/java/com/android/tools/r8/dex/FileWriter.java b/src/main/java/com/android/tools/r8/dex/FileWriter.java
index 7cdd5cc..5826972 100644
--- a/src/main/java/com/android/tools/r8/dex/FileWriter.java
+++ b/src/main/java/com/android/tools/r8/dex/FileWriter.java
@@ -251,9 +251,7 @@
       Map<DexWritableCacheKey, Integer> offsetCache = new HashMap<>();
       for (ProgramMethod method : codes) {
         DexWritableCode dexWritableCode = method.getDefinition().getCode().asDexWritableCode();
-        if (!options.canUseCanonicalizedCodeObjects()) {
-          writeCodeItem(method, dexWritableCode);
-        } else {
+        if (dexWritableCode.canBeCanonicalized(options)) {
           DexWritableCacheKey cacheLookupKey =
               dexWritableCode.getCacheLookupKey(method, appView.dexItemFactory());
           Integer offsetOrNull = offsetCache.get(cacheLookupKey);
@@ -262,6 +260,8 @@
           } else {
             offsetCache.put(cacheLookupKey, writeCodeItem(method, dexWritableCode));
           }
+        } else {
+          writeCodeItem(method, dexWritableCode);
         }
       }
       assert sizeAndCountOfCodeItems.getCount()
@@ -479,7 +479,7 @@
     Set<DexWritableCacheKey> cache = new HashSet<>();
     for (ProgramMethod method : methods) {
       DexWritableCode code = method.getDefinition().getCode().asDexWritableCode();
-      if (!options.canUseCanonicalizedCodeObjects()
+      if (!code.canBeCanonicalized(options)
           || cache.add(code.getCacheLookupKey(method, appView.dexItemFactory()))) {
         sizeAndCount.count++;
         sizeAndCount.size = alignSize(4, sizeAndCount.size) + sizeOfCodeItem(code);
diff --git a/src/main/java/com/android/tools/r8/dex/code/DexInstruction.java b/src/main/java/com/android/tools/r8/dex/code/DexInstruction.java
index 70079d8..faa31b9 100644
--- a/src/main/java/com/android/tools/r8/dex/code/DexInstruction.java
+++ b/src/main/java/com/android/tools/r8/dex/code/DexInstruction.java
@@ -229,6 +229,14 @@
     return null;
   }
 
+  public boolean isInvokeSuper() {
+    return false;
+  }
+
+  public boolean isInvokeSuperRange() {
+    return false;
+  }
+
   public boolean isInvokeVirtual() {
     return false;
   }
diff --git a/src/main/java/com/android/tools/r8/dex/code/DexInvokeSuper.java b/src/main/java/com/android/tools/r8/dex/code/DexInvokeSuper.java
index 46af3323229..1586793 100644
--- a/src/main/java/com/android/tools/r8/dex/code/DexInvokeSuper.java
+++ b/src/main/java/com/android/tools/r8/dex/code/DexInvokeSuper.java
@@ -49,6 +49,11 @@
   }
 
   @Override
+  public boolean isInvokeSuper() {
+    return true;
+  }
+
+  @Override
   public void registerUse(UseRegistry<?> registry) {
     registry.registerInvokeSuper(getMethod());
   }
diff --git a/src/main/java/com/android/tools/r8/dex/code/DexInvokeSuperRange.java b/src/main/java/com/android/tools/r8/dex/code/DexInvokeSuperRange.java
index 015f6a5..2a8e5a0 100644
--- a/src/main/java/com/android/tools/r8/dex/code/DexInvokeSuperRange.java
+++ b/src/main/java/com/android/tools/r8/dex/code/DexInvokeSuperRange.java
@@ -49,6 +49,11 @@
   }
 
   @Override
+  public boolean isInvokeSuperRange() {
+    return true;
+  }
+
+  @Override
   public void registerUse(UseRegistry<?> registry) {
     registry.registerInvokeSuper(getMethod());
   }
diff --git a/src/main/java/com/android/tools/r8/graph/DexCode.java b/src/main/java/com/android/tools/r8/graph/DexCode.java
index e64405a..1150791 100644
--- a/src/main/java/com/android/tools/r8/graph/DexCode.java
+++ b/src/main/java/com/android/tools/r8/graph/DexCode.java
@@ -40,6 +40,7 @@
 import com.android.tools.r8.lightir.ByteUtils;
 import com.android.tools.r8.utils.ArrayUtils;
 import com.android.tools.r8.utils.DexDebugUtils.PositionInfo;
+import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.RetracerForCodePrinting;
 import com.android.tools.r8.utils.StringUtils;
 import com.android.tools.r8.utils.structural.Equatable;
@@ -159,6 +160,14 @@
     int unused = hashCode(); // Cache the hash code eagerly.
   }
 
+  @Override
+  public boolean canBeCanonicalized(InternalOptions options) {
+    // Do not canonicalize code objects with invoke-super instructions due to ART's thread
+    // interpreter cache. See also b/445349082.
+    return options.canUseCanonicalizedCodeObjects()
+        && ArrayUtils.none(instructions, i -> i.isInvokeSuper() || i.isInvokeSuperRange());
+  }
+
   public DexCode withCodeLens(GraphLens codeLens) {
     return new DexCode(this) {
 
diff --git a/src/main/java/com/android/tools/r8/graph/DexWritableCode.java b/src/main/java/com/android/tools/r8/graph/DexWritableCode.java
index e5037e5..7f018ce 100644
--- a/src/main/java/com/android/tools/r8/graph/DexWritableCode.java
+++ b/src/main/java/com/android/tools/r8/graph/DexWritableCode.java
@@ -11,6 +11,7 @@
 import com.android.tools.r8.graph.DexCode.TryHandler;
 import com.android.tools.r8.graph.lens.GraphLens;
 import com.android.tools.r8.ir.conversion.LensCodeRewriterUtils;
+import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.structural.CompareToVisitor;
 import com.android.tools.r8.utils.structural.HashingVisitor;
 import java.nio.ShortBuffer;
@@ -24,6 +25,10 @@
     THROW_EXCEPTION
   }
 
+  default boolean canBeCanonicalized(InternalOptions options) {
+    return options.canUseCanonicalizedCodeObjects();
+  }
+
   boolean isThrowExceptionCode();
 
   ThrowExceptionCode asThrowExceptionCode();
diff --git a/src/test/java/com/android/tools/r8/dex/DexCodeDeduppingTest.java b/src/test/java/com/android/tools/r8/dex/DexCodeDeduppingTest.java
index 55f23f9..90322ea 100644
--- a/src/test/java/com/android/tools/r8/dex/DexCodeDeduppingTest.java
+++ b/src/test/java/com/android/tools/r8/dex/DexCodeDeduppingTest.java
@@ -29,11 +29,11 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
 import org.junit.runners.Parameterized.Parameters;
 
 @RunWith(Parameterized.class)
 public class DexCodeDeduppingTest extends TestBase {
-  private final TestParameters parameters;
   private static final List<String> EXPECTED = ImmutableList.of("foo", "bar", "foo", "bar");
 
   private static final int ONE_CLASS_COUNT = 4;
@@ -42,15 +42,14 @@
   private static final int TWO_CLASS_COUNT = 6;
   private static final int TWO_CLASS_DEDUPLICATED_COUNT = 3;
 
+  @Parameter(0)
+  public TestParameters parameters;
+
   @Parameters(name = "{0}")
   public static TestParametersCollection data() {
     return getTestParameters().withDexRuntimes().withAllApiLevels().build();
   }
 
-  public DexCodeDeduppingTest(TestParameters parameters) {
-    this.parameters = parameters;
-  }
-
   @Test
   public void testR8SingleClass() throws Exception {
     R8TestCompileResult compile =
diff --git a/src/test/java/com/android/tools/r8/dex/DexCodeInvokeSuperDeduppingTest.java b/src/test/java/com/android/tools/r8/dex/DexCodeInvokeSuperDeduppingTest.java
new file mode 100644
index 0000000..b5dd40c
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/dex/DexCodeInvokeSuperDeduppingTest.java
@@ -0,0 +1,149 @@
+// Copyright (c) 2025, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.dex;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.CompilationFailedException;
+import com.android.tools.r8.D8TestBuilder;
+import com.android.tools.r8.D8TestCompileResult;
+import com.android.tools.r8.DexSegments;
+import com.android.tools.r8.DexSegments.SegmentInfo;
+import com.android.tools.r8.ResourceException;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import com.android.tools.r8.utils.BooleanUtils;
+import com.google.common.collect.ImmutableList;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.List;
+import java.util.Map;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+import org.objectweb.asm.Opcodes;
+
+@RunWith(Parameterized.class)
+public class DexCodeInvokeSuperDeduppingTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameter(1)
+  public boolean enableMappingOutput;
+
+  @Parameters(name = "{0}, map output: {1}")
+  public static List<Object[]> data() {
+    return buildParameters(
+        getTestParameters().withDexRuntimesAndAllApiLevels().build(), BooleanUtils.values());
+  }
+
+  @Test
+  public void testD8MappingOutput() throws Exception {
+    testForD8(parameters.getBackend())
+        .addProgramClasses(Base.class, FooBase.class, BarBase.class, Main.class)
+        .addProgramClassFileData(getProgramClassFileData())
+        .setMinApi(parameters)
+        .release()
+        .applyIf(enableMappingOutput, D8TestBuilder::internalEnableMappingOutput)
+        .compile()
+        .apply(this::inspectCodeSegment)
+        .run(parameters.getRuntime(), Main.class)
+        .assertSuccessWithOutputLines("FooBase", "BarBase");
+  }
+
+  private List<byte[]> getProgramClassFileData() throws IOException {
+    return ImmutableList.of(
+        transformer(Foo.class)
+            .transformMethodInsnInMethod(
+                "foo",
+                (opcode, owner, name, descriptor, isInterface, visitor) -> {
+                  if (opcode == Opcodes.INVOKESPECIAL) {
+                    assertEquals(
+                        "com/android/tools/r8/dex/DexCodeInvokeSuperDeduppingTest$FooBase", owner);
+                    owner = binaryName(Base.class);
+                  }
+                  visitor.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
+                })
+            .transform(),
+        transformer(Bar.class)
+            .transformMethodInsnInMethod(
+                "bar",
+                (opcode, owner, name, descriptor, isInterface, visitor) -> {
+                  if (opcode == Opcodes.INVOKESPECIAL) {
+                    assertEquals(
+                        "com/android/tools/r8/dex/DexCodeInvokeSuperDeduppingTest$BarBase", owner);
+                    owner = binaryName(Base.class);
+                  }
+                  visitor.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
+                })
+            .transform());
+  }
+
+  private void inspectCodeSegment(D8TestCompileResult compileResult) throws Exception {
+    SegmentInfo codeSegmentInfo = getCodeSegmentInfo(compileResult.writeToZip());
+    int expectedCodeItems = 11;
+    if (parameters.getApiLevel().isGreaterThanOrEqualTo(AndroidApiLevel.S) && enableMappingOutput) {
+      // Main.<init> and Base.<init> are canonicalized, and so is FooBase.<init> and BarBase.<init>.
+      assertEquals(expectedCodeItems - 2, codeSegmentInfo.getItemCount());
+    } else {
+      assertEquals(expectedCodeItems, codeSegmentInfo.getItemCount());
+    }
+  }
+
+  public SegmentInfo getCodeSegmentInfo(Path path)
+      throws CompilationFailedException, ResourceException, IOException {
+    DexSegments.Command command = DexSegments.Command.builder().addProgramFiles(path).build();
+    Map<Integer, SegmentInfo> segmentInfoMap = DexSegments.runForTesting(command);
+    return segmentInfoMap.get(Constants.TYPE_CODE_ITEM);
+  }
+
+  static class Main {
+
+    public static void main(String[] args) {
+      new Foo().foo();
+      new Bar().bar();
+    }
+  }
+
+  abstract static class Base {
+
+    abstract void m();
+  }
+
+  static class FooBase extends Base {
+
+    @Override
+    void m() {
+      System.out.println("FooBase");
+    }
+  }
+
+  static class Foo extends FooBase {
+
+    void foo() {
+      // Symbolic holder transformed to Base class.
+      super.m();
+    }
+  }
+
+  static class BarBase extends Base {
+
+    @Override
+    void m() {
+      System.out.println("BarBase");
+    }
+  }
+
+  static class Bar extends BarBase {
+
+    void bar() {
+      // Symbolic holder transformed to Base class.
+      super.m();
+    }
+  }
+}