Dedup code objects on api 31 and above

Art will no longer quicken and we can safely share code objects.

Bug: 174120485
Bug: 278742499
Change-Id: Ibde9c342b7f93cfe5c09423242fcd21386533e0b
diff --git a/src/main/java/com/android/tools/r8/dex/FileWriter.java b/src/main/java/com/android/tools/r8/dex/FileWriter.java
index 8c6f33a..686d327 100644
--- a/src/main/java/com/android/tools/r8/dex/FileWriter.java
+++ b/src/main/java/com/android/tools/r8/dex/FileWriter.java
@@ -40,6 +40,7 @@
 import com.android.tools.r8.graph.DexTypeList;
 import com.android.tools.r8.graph.DexValue;
 import com.android.tools.r8.graph.DexWritableCode;
+import com.android.tools.r8.graph.DexWritableCode.DexWritableCacheKey;
 import com.android.tools.r8.graph.IndexedDexItem;
 import com.android.tools.r8.graph.ObjectToOffsetMapping;
 import com.android.tools.r8.graph.ParameterAnnotationsList;
@@ -56,6 +57,7 @@
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.IterableUtils;
 import com.android.tools.r8.utils.LebUtils;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Sets;
 import it.unimi.dsi.fastutil.objects.Object2IntLinkedOpenHashMap;
 import it.unimi.dsi.fastutil.objects.Object2IntMap;
@@ -67,6 +69,7 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.IdentityHashMap;
 import java.util.List;
@@ -223,7 +226,8 @@
     Collection<ProgramMethod> codes = mixedSectionLayoutStrategy.getCodeLayout();
 
     // Output the debug_info_items first, as they have no dependencies.
-    dest.moveTo(layout.getCodesOffset() + sizeOfCodeItems(codes));
+    SizeAndCount sizeAndCountOfCodeItems = sizeAndCountOfCodeItems(codes);
+    dest.moveTo(layout.getCodesOffset() + sizeAndCountOfCodeItems.size);
     if (mixedSectionOffsets.getDebugInfos().isEmpty()) {
       layout.setDebugInfosOffset(0);
     } else {
@@ -245,7 +249,25 @@
     // Now output the code.
     dest.moveTo(layout.getCodesOffset());
     assert dest.isAligned(4);
-    writeItems(codes, layout::alreadySetOffset, this::writeCodeItem, 4);
+    Map<DexWritableCacheKey, Integer> offsetCache = new HashMap<>();
+    for (ProgramMethod method : codes) {
+      DexWritableCode dexWritableCode = method.getDefinition().getCode().asDexWritableCode();
+      if (!options.canUseCanonicalizedCodeObjects()) {
+        writeCodeItem(method, dexWritableCode);
+      } else {
+        DexWritableCacheKey cacheLookupKey =
+            dexWritableCode.getCacheLookupKey(method, appView.dexItemFactory());
+        Integer offsetOrNull = offsetCache.get(cacheLookupKey);
+        if (offsetOrNull != null) {
+          mixedSectionOffsets.setOffsetFor(method.getDefinition(), offsetOrNull);
+        } else {
+          offsetCache.put(cacheLookupKey, writeCodeItem(method, dexWritableCode));
+        }
+      }
+    }
+    assert sizeAndCountOfCodeItems.getCount()
+        == ImmutableSet.copyOf(mixedSectionOffsets.codes.values()).size();
+    layout.setCodeCount(sizeAndCountOfCodeItems.getCount());
     assert layout.getDebugInfosOffset() == 0 || dest.position() == layout.getDebugInfosOffset();
 
     // Now the type lists and rest.
@@ -434,13 +456,32 @@
     }
   }
 
-  private int sizeOfCodeItems(Iterable<ProgramMethod> methods) {
-    int size = 0;
-    for (ProgramMethod method : methods) {
-      size = alignSize(4, size);
-      size += sizeOfCodeItem(method.getDefinition().getCode().asDexWritableCode());
+  static class SizeAndCount {
+
+    private int size = 0;
+    private int count = 0;
+
+    public int getCount() {
+      return count;
     }
-    return size;
+
+    public int getSize() {
+      return size;
+    }
+  }
+
+  private SizeAndCount sizeAndCountOfCodeItems(Iterable<ProgramMethod> methods) {
+    SizeAndCount sizeAndCount = new SizeAndCount();
+    Set<DexWritableCacheKey> cache = new HashSet<>();
+    for (ProgramMethod method : methods) {
+      DexWritableCode code = method.getDefinition().getCode().asDexWritableCode();
+      if (!options.canUseCanonicalizedCodeObjects()
+          || cache.add(code.getCacheLookupKey(method, appView.dexItemFactory()))) {
+        sizeAndCount.count++;
+        sizeAndCount.size = alignSize(4, sizeAndCount.size) + sizeOfCodeItem(code);
+      }
+    }
+    return sizeAndCount;
   }
 
   private int sizeOfCodeItem(DexWritableCode code) {
@@ -525,12 +566,9 @@
     dest.putBytes(new DebugBytecodeWriter(debugInfo, mapping, graphLens).generate());
   }
 
-  private void writeCodeItem(ProgramMethod method) {
-    writeCodeItem(method, method.getDefinition().getCode().asDexWritableCode());
-  }
-
-  private void writeCodeItem(ProgramMethod method, DexWritableCode code) {
-    mixedSectionOffsets.setOffsetFor(method.getDefinition(), code, dest.align(4));
+  private int writeCodeItem(ProgramMethod method, DexWritableCode code) {
+    int codeOffset = dest.align(4);
+    mixedSectionOffsets.setOffsetFor(method.getDefinition(), codeOffset);
     // Fixed size header information.
     dest.putShort((short) code.getRegisterSize(method));
     dest.putShort((short) code.getIncomingRegisterSize(method));
@@ -580,6 +618,7 @@
       // And move to the end.
       dest.moveTo(endOfCodeOffset);
     }
+    return codeOffset;
   }
 
   private void writeTypeList(DexTypeList list) {
@@ -943,6 +982,7 @@
     private int encodedArraysOffset = NOT_SET;
     private int mapOffset = NOT_SET;
     private int endOfFile = NOT_SET;
+    private int codeCount = NOT_SET;
 
     private Layout(
         int headerOffset,
@@ -1024,6 +1064,15 @@
       this.codesOffset = codesOffset;
     }
 
+    public void setCodeCount(int codeCount) {
+      assert this.codeCount == NOT_SET;
+      this.codeCount = codeCount;
+    }
+
+    public int getCodeCount() {
+      return codeCount;
+    }
+
     public int getDebugInfosOffset() {
       assert isValidOffset(debugInfosOffset, false);
       return debugInfosOffset;
@@ -1185,11 +1234,7 @@
               Constants.TYPE_METHOD_HANDLE_ITEM,
               methodHandleIdsOffset,
               fileWriter.mapping.getMethodHandles().size()));
-      mapItems.add(
-          new MapItem(
-              Constants.TYPE_CODE_ITEM,
-              getCodesOffset(),
-              fileWriter.mixedSectionOffsets.getCodes().size()));
+      mapItems.add(new MapItem(Constants.TYPE_CODE_ITEM, getCodesOffset(), codeCount));
       mapItems.add(
           new MapItem(
               Constants.TYPE_DEBUG_INFO_ITEM,
@@ -1601,7 +1646,7 @@
       setOffsetFor(debugInfo, offset, debugInfos);
     }
 
-    void setOffsetFor(DexEncodedMethod method, DexWritableCode code, int offset) {
+    void setOffsetFor(DexEncodedMethod method, int offset) {
       setOffsetFor(method, offset, codes);
     }
 
diff --git a/src/main/java/com/android/tools/r8/graph/DefaultInstanceInitializerCode.java b/src/main/java/com/android/tools/r8/graph/DefaultInstanceInitializerCode.java
index 06fd92e..66df280 100644
--- a/src/main/java/com/android/tools/r8/graph/DefaultInstanceInitializerCode.java
+++ b/src/main/java/com/android/tools/r8/graph/DefaultInstanceInitializerCode.java
@@ -397,6 +397,15 @@
     return toString();
   }
 
+  @Override
+  public DexWritableCacheKey getCacheLookupKey(ProgramMethod method, DexItemFactory factory) {
+    return new AmendedDexWritableCodeKey<DexMethod>(
+        this,
+        getParentConstructor(method, factory),
+        getIncomingRegisterSize(method),
+        getRegisterSize(method));
+  }
+
   static class DefaultInstanceInitializerSourceCode extends SyntheticStraightLineSourceCode {
 
     DefaultInstanceInitializerSourceCode(DexMethod method) {
diff --git a/src/main/java/com/android/tools/r8/graph/DexCode.java b/src/main/java/com/android/tools/r8/graph/DexCode.java
index ae96c5e..039a0f0 100644
--- a/src/main/java/com/android/tools/r8/graph/DexCode.java
+++ b/src/main/java/com/android/tools/r8/graph/DexCode.java
@@ -22,6 +22,7 @@
 import com.android.tools.r8.graph.DexDebugEvent.SetPositionFrame;
 import com.android.tools.r8.graph.DexDebugEvent.StartLocal;
 import com.android.tools.r8.graph.DexDebugInfo.EventBasedDebugInfo;
+import com.android.tools.r8.graph.DexWritableCode.DexWritableCacheKey;
 import com.android.tools.r8.graph.bytecodemetadata.BytecodeInstructionMetadata;
 import com.android.tools.r8.graph.bytecodemetadata.BytecodeMetadata;
 import com.android.tools.r8.graph.lens.GraphLens;
@@ -62,7 +63,8 @@
 import java.util.function.Consumer;
 
 // DexCode corresponds to code item in dalvik/dex-format.html
-public class DexCode extends Code implements DexWritableCode, StructuralItem<DexCode> {
+public class DexCode extends Code
+    implements DexWritableCode, StructuralItem<DexCode>, DexWritableCacheKey {
 
   public static final String FAKE_THIS_PREFIX = "_";
   public static final String FAKE_THIS_SUFFIX = "this";
@@ -283,6 +285,7 @@
     if (debugInfoForWriting != null) {
       debugInfoForWriting = null;
     }
+    flushCachedValues();
   }
 
   public DexDebugInfo debugInfoWithFakeThisParameter(DexItemFactory factory) {
@@ -858,6 +861,11 @@
     }
   }
 
+  @Override
+  public DexWritableCacheKey getCacheLookupKey(ProgramMethod method, DexItemFactory factory) {
+    return this;
+  }
+
   public static class Try extends DexItem implements StructuralItem<Try> {
 
     public static final Try[] EMPTY_ARRAY = new Try[0];
diff --git a/src/main/java/com/android/tools/r8/graph/DexWritableCode.java b/src/main/java/com/android/tools/r8/graph/DexWritableCode.java
index 1e8856c..99f8679 100644
--- a/src/main/java/com/android/tools/r8/graph/DexWritableCode.java
+++ b/src/main/java/com/android/tools/r8/graph/DexWritableCode.java
@@ -92,6 +92,8 @@
     return null;
   }
 
+  DexWritableCacheKey getCacheLookupKey(ProgramMethod method, DexItemFactory factory);
+
   /** Rewrites the code to have JumboString bytecode if required by mapping. */
   DexWritableCode rewriteCodeWithJumboStrings(
       ProgramMethod method, ObjectToOffsetMapping mapping, DexItemFactory factory, boolean force);
@@ -105,4 +107,65 @@
       GraphLens codeLens,
       LensCodeRewriterUtils lensCodeRewriter,
       ObjectToOffsetMapping mapping);
+
+  interface DexWritableCacheKey {}
+
+  class DexWritableCodeKey implements DexWritableCacheKey {
+
+    private final DexWritableCode code;
+    private final int incomingRegisterSize;
+    private final int registerSize;
+
+    public DexWritableCodeKey(DexWritableCode code, int incomingRegisterSize, int registerSize) {
+      this.code = code;
+      this.incomingRegisterSize = incomingRegisterSize;
+      this.registerSize = registerSize;
+    }
+
+    @Override
+    public int hashCode() {
+      return code.hashCode() + incomingRegisterSize * 13 + registerSize * 17;
+    }
+
+    @Override
+    public boolean equals(Object other) {
+      if (this == other) {
+        return true;
+      }
+      if (!(other instanceof DexWritableCodeKey)) {
+        return false;
+      }
+      DexWritableCodeKey that = (DexWritableCodeKey) other;
+      return code.equals(that.code)
+          && incomingRegisterSize == that.incomingRegisterSize
+          && registerSize == that.registerSize;
+    }
+  }
+
+  class AmendedDexWritableCodeKey<S> extends DexWritableCodeKey {
+    private final S extra;
+
+    public AmendedDexWritableCodeKey(
+        DexWritableCode code, S extra, int incomingRegisterSize, int registerSize) {
+      super(code, incomingRegisterSize, registerSize);
+      this.extra = extra;
+    }
+
+    @Override
+    public int hashCode() {
+      return super.hashCode() + extra.hashCode() * 7;
+    }
+
+    @Override
+    public boolean equals(Object other) {
+      if (this == other) {
+        return true;
+      }
+      if (!(other instanceof AmendedDexWritableCodeKey)) {
+        return false;
+      }
+      AmendedDexWritableCodeKey that = (AmendedDexWritableCodeKey) other;
+      return super.equals(other) && extra.equals(that.extra);
+    }
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/graph/ThrowExceptionCode.java b/src/main/java/com/android/tools/r8/graph/ThrowExceptionCode.java
index 383504d..19d78c1 100644
--- a/src/main/java/com/android/tools/r8/graph/ThrowExceptionCode.java
+++ b/src/main/java/com/android/tools/r8/graph/ThrowExceptionCode.java
@@ -235,6 +235,12 @@
   }
 
   @Override
+  public DexWritableCacheKey getCacheLookupKey(ProgramMethod method, DexItemFactory factory) {
+    return new AmendedDexWritableCodeKey<DexType>(
+        this, exceptionType, getIncomingRegisterSize(method), getRegisterSize(method));
+  }
+
+  @Override
   public String toString() {
     return "ThrowExceptionCode";
   }
diff --git a/src/main/java/com/android/tools/r8/graph/ThrowNullCode.java b/src/main/java/com/android/tools/r8/graph/ThrowNullCode.java
index 82f320c..707e45a 100644
--- a/src/main/java/com/android/tools/r8/graph/ThrowNullCode.java
+++ b/src/main/java/com/android/tools/r8/graph/ThrowNullCode.java
@@ -274,6 +274,12 @@
     return "ThrowNullCode";
   }
 
+  @Override
+  public DexWritableCacheKey getCacheLookupKey(ProgramMethod method, DexItemFactory factory) {
+    return new AmendedDexWritableCodeKey<DexWritableCode>(
+        this, this, getIncomingRegisterSize(method), getRegisterSize(method));
+  }
+
   static class ThrowNullSourceCode extends SyntheticStraightLineSourceCode {
 
     ThrowNullSourceCode(ProgramMethod method) {
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index 615c439..41ada41 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -2541,6 +2541,10 @@
     return hasFeaturePresentFrom(AndroidApiLevel.K);
   }
 
+  public boolean canUseCanonicalizedCodeObjects() {
+    return hasFeaturePresentFrom(AndroidApiLevel.S);
+  }
+
   public CfVersion classFileVersionAfterDesugaring(CfVersion version) {
     assert isGeneratingClassFiles();
     if (!isDesugaring()) {
diff --git a/src/test/java/com/android/tools/r8/dex/DexCodeDeduppingTest.java b/src/test/java/com/android/tools/r8/dex/DexCodeDeduppingTest.java
new file mode 100644
index 0000000..4b0b57a
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/dex/DexCodeDeduppingTest.java
@@ -0,0 +1,207 @@
+// Copyright (c) 2023, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.dex;
+
+import static org.junit.Assert.assertEquals;
+
+import com.android.tools.r8.CompilationFailedException;
+import com.android.tools.r8.D8TestCompileResult;
+import com.android.tools.r8.DexSegments;
+import com.android.tools.r8.DexSegments.Command;
+import com.android.tools.r8.DexSegments.SegmentInfo;
+import com.android.tools.r8.R8TestCompileResult;
+import com.android.tools.r8.ResourceException;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.AndroidApiLevel;
+import com.google.common.collect.ImmutableList;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.List;
+import java.util.Map;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class DexCodeDeduppingTest extends TestBase {
+  private final TestParameters parameters;
+  private static final List<String> EXPECTED = ImmutableList.of("foo", "bar", "foo", "bar");
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withDexRuntimes().withAllApiLevels().build();
+  }
+
+  public DexCodeDeduppingTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void testR8SingleClass() throws Exception {
+    R8TestCompileResult compile =
+        testForR8(parameters.getBackend())
+            .addProgramClasses(Foo.class)
+            .setMinApi(parameters)
+            .addKeepAllClassesRule()
+            .compile();
+    compile.run(parameters.getRuntime(), Foo.class).assertSuccessWithOutputLines(EXPECTED);
+    assertFooSizes(compile.writeToZip());
+  }
+
+  @Test
+  public void testR8WithLinesSingleClass() throws Exception {
+    R8TestCompileResult compile =
+        testForR8(parameters.getBackend())
+            .addProgramClasses(Foo.class)
+            .setMinApi(parameters)
+            .addKeepAllClassesRule()
+            .addKeepAttributeLineNumberTable()
+            .compile();
+    compile.run(parameters.getRuntime(), Foo.class).assertSuccessWithOutputLines(EXPECTED);
+    assertFooSizes(compile.writeToZip());
+  }
+
+  @Test
+  public void testD8SingleClassMappingOutput() throws Exception {
+    D8TestCompileResult compile =
+        testForD8(parameters.getBackend())
+            .addProgramClasses(Foo.class)
+            .setMinApi(parameters)
+            .release()
+            .internalEnableMappingOutput()
+            .compile();
+    compile.run(parameters.getRuntime(), Foo.class).assertSuccessWithOutputLines(EXPECTED);
+    assertFooSizes(compile.writeToZip());
+  }
+
+  @Test
+  public void testD8SingleClassNoMappingOutput() throws Exception {
+    D8TestCompileResult compile =
+        testForD8(parameters.getBackend())
+            .addProgramClasses(Foo.class)
+            .setMinApi(parameters)
+            .release()
+            .compile();
+    compile.run(parameters.getRuntime(), Foo.class).assertSuccessWithOutputLines(EXPECTED);
+    // When d8 has no map output we can't share debug info and hence can't share code.
+    assertSizes(compile.writeToZip(), 4, 4);
+  }
+
+  @Test
+  public void testR8TwoClasses() throws Exception {
+    R8TestCompileResult compile =
+        testForR8(parameters.getBackend())
+            .addProgramClasses(Foo.class, Bar.class)
+            .setMinApi(parameters)
+            .addKeepAllClassesRule()
+            .compile();
+    compile.run(parameters.getRuntime(), Foo.class).assertSuccessWithOutputLines(EXPECTED);
+    assertFooAndBarSizes(compile.writeToZip());
+  }
+
+  @Test
+  public void testR8WithLinesTwoClasses() throws Exception {
+    R8TestCompileResult compile =
+        testForR8(parameters.getBackend())
+            .addProgramClasses(Foo.class, Bar.class)
+            .addKeepAttributeLineNumberTable()
+            .setMinApi(parameters)
+            .addKeepAllClassesRule()
+            .compile();
+    compile.run(parameters.getRuntime(), Foo.class).assertSuccessWithOutputLines(EXPECTED);
+    assertFooAndBarSizes(compile.writeToZip());
+  }
+
+  @Test
+  public void testD8TwoClassesMappingOutput() throws Exception {
+    D8TestCompileResult compile =
+        testForD8(parameters.getBackend())
+            .addProgramClasses(Foo.class, Bar.class)
+            .setMinApi(parameters)
+            .release()
+            .internalEnableMappingOutput()
+            .compile();
+    compile.run(parameters.getRuntime(), Foo.class).assertSuccessWithOutputLines(EXPECTED);
+    assertFooAndBarSizes(compile.writeToZip());
+  }
+
+  @Test
+  public void testD8TwoClassesNoMappingOutput() throws Exception {
+    D8TestCompileResult compile =
+        testForD8(parameters.getBackend())
+            .addProgramClasses(Foo.class, Bar.class)
+            .setMinApi(parameters)
+            .release()
+            .compile();
+    compile.run(parameters.getRuntime(), Foo.class).assertSuccessWithOutputLines(EXPECTED);
+    // When d8 has no map output we can't share debug info and hence can't share code.
+    assertSizes(compile.writeToZip(), 6, 6);
+  }
+
+  private void assertFooSizes(Path output) throws Exception {
+    assertSizes(output, 3, 4);
+  }
+
+  private void assertFooAndBarSizes(Path output) throws Exception {
+    assertSizes(output, 3, 6);
+  }
+
+  private void assertSizes(Path output, int deduppedSize, int originalSize)
+      throws CompilationFailedException, ResourceException, IOException {
+    if (parameters.isDexRuntime()) {
+      SegmentInfo codeSegmentInfo = getCodeSegmentInfo(output);
+      if (parameters.getApiLevel().isGreaterThanOrEqualTo(AndroidApiLevel.S)) {
+        assertEquals(codeSegmentInfo.getItemCount(), deduppedSize);
+      } else {
+        assertEquals(codeSegmentInfo.getItemCount(), originalSize);
+      }
+    }
+  }
+
+  public SegmentInfo getCodeSegmentInfo(Path path)
+      throws CompilationFailedException, ResourceException, IOException {
+    Command.Builder builder = Command.builder().addProgramFiles(path);
+    Map<Integer, SegmentInfo> segmentInfoMap = DexSegments.run(builder.build());
+    return segmentInfoMap.get(Constants.TYPE_CODE_ITEM);
+  }
+
+  public static class Foo {
+    public static void main(String[] args) {
+      foo();
+      bar();
+    }
+
+    public static void foo() {
+      if (System.currentTimeMillis() == 0) {
+        System.out.println("That was early");
+      } else {
+        System.out.println("foo");
+      }
+      System.out.println("bar");
+    }
+
+    public static void bar() {
+      if (System.currentTimeMillis() == 0) {
+        System.out.println("That was early");
+      } else {
+        System.out.println("foo");
+      }
+      System.out.println("bar");
+    }
+  }
+
+  public static class Bar {
+    public static void foo() {
+      if (System.currentTimeMillis() == 0) {
+        System.out.println("That was early");
+      } else {
+        System.out.println("foo");
+      }
+      System.out.println("bar");
+    }
+  }
+}