SafeCheckCast instruction for compiler synthesized casts
Fixes: 190368585
Change-Id: I63f45367d749f0911c6e75f7e1d048ecf83ea596
diff --git a/src/main/java/com/android/tools/r8/cf/code/CfCheckCast.java b/src/main/java/com/android/tools/r8/cf/code/CfCheckCast.java
index 5ed4009..cfad75e 100644
--- a/src/main/java/com/android/tools/r8/cf/code/CfCheckCast.java
+++ b/src/main/java/com/android/tools/r8/cf/code/CfCheckCast.java
@@ -101,6 +101,10 @@
// Pop the top value and push it back on with the checked type.
state.pop();
Slot object = state.push(type);
+ addCheckCast(builder, object);
+ }
+
+ void addCheckCast(IRBuilder builder, Slot object) {
builder.addCheckCast(object.register, type);
}
diff --git a/src/main/java/com/android/tools/r8/cf/code/CfSafeCheckCast.java b/src/main/java/com/android/tools/r8/cf/code/CfSafeCheckCast.java
new file mode 100644
index 0000000..e1473d3
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/cf/code/CfSafeCheckCast.java
@@ -0,0 +1,35 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.cf.code;
+
+import com.android.tools.r8.graph.DexClassAndMethod;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.UseRegistry;
+import com.android.tools.r8.ir.conversion.CfState.Slot;
+import com.android.tools.r8.ir.conversion.IRBuilder;
+import java.util.ListIterator;
+
+public class CfSafeCheckCast extends CfCheckCast {
+
+ public CfSafeCheckCast(DexType type) {
+ super(type);
+ }
+
+ @Override
+ void addCheckCast(IRBuilder builder, Slot object) {
+ builder.addSafeCheckCast(object.register, getType());
+ }
+
+ @Override
+ void internalRegisterUse(
+ UseRegistry registry, DexClassAndMethod context, ListIterator<CfInstruction> iterator) {
+ registry.registerSafeCheckCast(getType());
+ }
+
+ @Override
+ public CfInstruction withType(DexType newType) {
+ return new CfSafeCheckCast(newType);
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/code/SafeCheckCast.java b/src/main/java/com/android/tools/r8/code/SafeCheckCast.java
new file mode 100644
index 0000000..e4af2a0
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/code/SafeCheckCast.java
@@ -0,0 +1,31 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.code;
+
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.OffsetToObjectMapping;
+import com.android.tools.r8.graph.UseRegistry;
+import com.android.tools.r8.ir.conversion.IRBuilder;
+
+public class SafeCheckCast extends CheckCast {
+
+ SafeCheckCast(int high, BytecodeStream stream, OffsetToObjectMapping mapping) {
+ super(high, stream, mapping);
+ }
+
+ public SafeCheckCast(int valueRegister, DexType type) {
+ super(valueRegister, type);
+ }
+
+ @Override
+ public void buildIR(IRBuilder builder) {
+ builder.addSafeCheckCast(AA, getType());
+ }
+
+ @Override
+ public void registerUse(UseRegistry registry) {
+ registry.registerSafeCheckCast(getType());
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/graph/UseRegistry.java b/src/main/java/com/android/tools/r8/graph/UseRegistry.java
index 565511a..708a02f 100644
--- a/src/main/java/com/android/tools/r8/graph/UseRegistry.java
+++ b/src/main/java/com/android/tools/r8/graph/UseRegistry.java
@@ -82,6 +82,10 @@
registerTypeReference(type);
}
+ public void registerSafeCheckCast(DexType type) {
+ registerCheckCast(type);
+ }
+
public void registerExceptionGuard(DexType guard) {
registerTypeReference(guard);
}
diff --git a/src/main/java/com/android/tools/r8/graph/analysis/EnqueuerCheckCastAnalysis.java b/src/main/java/com/android/tools/r8/graph/analysis/EnqueuerCheckCastAnalysis.java
index 503e74e..ad0dcbf 100644
--- a/src/main/java/com/android/tools/r8/graph/analysis/EnqueuerCheckCastAnalysis.java
+++ b/src/main/java/com/android/tools/r8/graph/analysis/EnqueuerCheckCastAnalysis.java
@@ -8,5 +8,8 @@
import com.android.tools.r8.graph.ProgramMethod;
public interface EnqueuerCheckCastAnalysis {
+
void traceCheckCast(DexType type, ProgramMethod context);
+
+ void traceSafeCheckCast(DexType type, ProgramMethod context);
}
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java
index 06e1e90..3aa2dac 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/GeneratedMessageLiteBuilderShrinker.java
@@ -28,6 +28,7 @@
import com.android.tools.r8.ir.code.InvokeDirect;
import com.android.tools.r8.ir.code.InvokeVirtual;
import com.android.tools.r8.ir.code.NewInstance;
+import com.android.tools.r8.ir.code.SafeCheckCast;
import com.android.tools.r8.ir.code.StaticGet;
import com.android.tools.r8.ir.code.Value;
import com.android.tools.r8.ir.conversion.CallGraph.Node;
@@ -373,7 +374,7 @@
DexType rawReceiverType = receiverType.getClassType();
if (appInfo.isStrictSubtypeOf(rawReceiverType, references.generatedMessageLiteType)) {
Value dest = code.createValue(receiverType.asMaybeNull(), checkCast.getLocalInfo());
- CheckCast replacement = new CheckCast(dest, checkCast.object(), rawReceiverType);
+ SafeCheckCast replacement = new SafeCheckCast(dest, checkCast.object(), rawReceiverType);
instructionIterator.replaceCurrentInstruction(replacement, affectedValues);
}
}
diff --git a/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionListIterator.java b/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionListIterator.java
index 6917a40..2448b0c 100644
--- a/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionListIterator.java
+++ b/src/main/java/com/android/tools/r8/ir/code/BasicBlockInstructionListIterator.java
@@ -788,8 +788,8 @@
Value receiver = invoke.inValues().get(0);
TypeElement castTypeLattice =
TypeElement.fromDexType(downcast, receiver.getType().nullability(), appView);
- CheckCast castInstruction =
- new CheckCast(code.createValue(castTypeLattice), receiver, downcast);
+ SafeCheckCast castInstruction =
+ new SafeCheckCast(code.createValue(castTypeLattice), receiver, downcast);
castInstruction.setPosition(invoke.getPosition());
// Splice in the check cast operation.
diff --git a/src/main/java/com/android/tools/r8/ir/code/CheckCast.java b/src/main/java/com/android/tools/r8/ir/code/CheckCast.java
index 3e0f741..b02559a 100644
--- a/src/main/java/com/android/tools/r8/ir/code/CheckCast.java
+++ b/src/main/java/com/android/tools/r8/ir/code/CheckCast.java
@@ -67,14 +67,13 @@
// we have to insert a move before the check cast instruction.
int inRegister = builder.allocatedRegister(inValues.get(0), getNumber());
if (outValue == null) {
- builder.add(this, new com.android.tools.r8.code.CheckCast(inRegister, type));
+ builder.add(this, createCheckCast(inRegister));
} else {
int outRegister = builder.allocatedRegister(outValue, getNumber());
if (inRegister == outRegister) {
- builder.add(this, new com.android.tools.r8.code.CheckCast(outRegister, type));
+ builder.add(this, createCheckCast(outRegister));
} else {
- com.android.tools.r8.code.CheckCast cast =
- new com.android.tools.r8.code.CheckCast(outRegister, type);
+ com.android.tools.r8.code.CheckCast cast = createCheckCast(outRegister);
if (outRegister <= Constants.U4BIT_MAX && inRegister <= Constants.U4BIT_MAX) {
builder.add(this, new MoveObject(outRegister, inRegister), cast);
} else {
@@ -84,6 +83,10 @@
}
}
+ com.android.tools.r8.code.CheckCast createCheckCast(int register) {
+ return new com.android.tools.r8.code.CheckCast(register, getType());
+ }
+
@Override
public boolean identicalNonValueNonPositionParts(Instruction other) {
return other.isCheckCast() && other.asCheckCast().type == type;
@@ -237,8 +240,8 @@
public static class Builder extends BuilderBase<Builder, CheckCast> {
- private DexType castType;
- private Value object;
+ protected DexType castType;
+ protected Value object;
public Builder setCastType(DexType castType) {
this.castType = castType;
diff --git a/src/main/java/com/android/tools/r8/ir/code/SafeCheckCast.java b/src/main/java/com/android/tools/r8/ir/code/SafeCheckCast.java
new file mode 100644
index 0000000..8b36c90
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/code/SafeCheckCast.java
@@ -0,0 +1,45 @@
+// Copyright (c) 2021, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.code;
+
+import com.android.tools.r8.cf.code.CfSafeCheckCast;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.graph.ProgramMethod;
+import com.android.tools.r8.ir.conversion.CfBuilder;
+
+public class SafeCheckCast extends CheckCast {
+
+ public SafeCheckCast(Value dest, Value value, DexType type) {
+ super(dest, value, type);
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ @Override
+ public void buildCf(CfBuilder builder) {
+ builder.add(new CfSafeCheckCast(getType()));
+ }
+
+ @Override
+ com.android.tools.r8.code.CheckCast createCheckCast(int register) {
+ return new com.android.tools.r8.code.SafeCheckCast(register, getType());
+ }
+
+ @Override
+ public boolean instructionInstanceCanThrow(AppView<?> appView, ProgramMethod context) {
+ return false;
+ }
+
+ public static class Builder extends CheckCast.Builder {
+
+ @Override
+ public CheckCast build() {
+ return amend(new SafeCheckCast(outValue, object, castType));
+ }
+ }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java b/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java
index 4ae55f6..ba1a15c 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRBuilder.java
@@ -102,6 +102,7 @@
import com.android.tools.r8.ir.code.Position;
import com.android.tools.r8.ir.code.Rem;
import com.android.tools.r8.ir.code.Return;
+import com.android.tools.r8.ir.code.SafeCheckCast;
import com.android.tools.r8.ir.code.Shl;
import com.android.tools.r8.ir.code.Shr;
import com.android.tools.r8.ir.code.StaticGet;
@@ -1181,11 +1182,20 @@
}
public void addCheckCast(int value, DexType type) {
+ internalAddCheckCast(value, type, false);
+ }
+
+ public void addSafeCheckCast(int value, DexType type) {
+ internalAddCheckCast(value, type, true);
+ }
+
+ private void internalAddCheckCast(int value, DexType type, boolean isSafe) {
Value in = readRegister(value, ValueTypeConstraint.OBJECT);
TypeElement castTypeLattice =
TypeElement.fromDexType(type, in.getType().nullability(), appView);
Value out = writeRegister(value, castTypeLattice, ThrowingInfo.CAN_THROW);
- CheckCast instruction = new CheckCast(out, in, type);
+ CheckCast instruction =
+ isSafe ? new SafeCheckCast(out, in, type) : new CheckCast(out, in, type);
assert instruction.instructionTypeCanThrow();
add(instruction);
}
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index b02d5e7..f44ec67 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -86,6 +86,7 @@
import com.android.tools.r8.ir.code.Phi;
import com.android.tools.r8.ir.code.Position;
import com.android.tools.r8.ir.code.Return;
+import com.android.tools.r8.ir.code.SafeCheckCast;
import com.android.tools.r8.ir.code.StaticGet;
import com.android.tools.r8.ir.code.StaticPut;
import com.android.tools.r8.ir.code.TypeAndLocalInfoSupplier;
@@ -415,7 +416,7 @@
Value castOutValue = code.createValue(castType);
newOutValue.replaceUsers(castOutValue);
CheckCast checkCast =
- CheckCast.builder()
+ SafeCheckCast.builder()
.setCastType(lookup.getCastType())
.setObject(newOutValue)
.setOutValue(castOutValue)
@@ -479,7 +480,7 @@
Value castOutValue = code.createValue(castType);
newOutValue.replaceUsers(castOutValue);
CheckCast checkCast =
- CheckCast.builder()
+ SafeCheckCast.builder()
.setCastType(lookup.getCastType())
.setObject(newOutValue)
.setOutValue(castOutValue)
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java b/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java
index 7dae8cd..55780b6 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/Devirtualizer.java
@@ -25,6 +25,7 @@
import com.android.tools.r8.ir.code.InvokeInterface;
import com.android.tools.r8.ir.code.InvokeSuper;
import com.android.tools.r8.ir.code.InvokeVirtual;
+import com.android.tools.r8.ir.code.SafeCheckCast;
import com.android.tools.r8.ir.code.Value;
import com.android.tools.r8.shaking.AppInfoWithLiveness;
import com.android.tools.r8.utils.InternalOptions;
@@ -63,7 +64,7 @@
Map<InvokeInterface, InvokeVirtual> devirtualizedCall = new IdentityHashMap<>();
DominatorTree dominatorTree = new DominatorTree(code);
Map<Value, Map<DexType, Value>> castedReceiverCache = new IdentityHashMap<>();
- Set<CheckCast> newCheckCastInstructions = Sets.newIdentityHashSet();
+ Set<SafeCheckCast> newCheckCastInstructions = Sets.newIdentityHashSet();
ListIterator<BasicBlock> blocks = code.listIterator();
while (blocks.hasNext()) {
@@ -261,7 +262,8 @@
castedReceiverCache.putIfAbsent(receiver, new IdentityHashMap<>());
castedReceiverCache.get(receiver).put(holderClass.getType(), newReceiver);
}
- CheckCast checkCast = new CheckCast(newReceiver, receiver, holderClass.getType());
+ SafeCheckCast checkCast =
+ new SafeCheckCast(newReceiver, receiver, holderClass.getType());
checkCast.setPosition(invoke.getPosition());
newCheckCastInstructions.add(checkCast);
diff --git a/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java b/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java
index b560847..da86fe5 100644
--- a/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java
+++ b/src/main/java/com/android/tools/r8/shaking/DefaultEnqueuerUseRegistry.java
@@ -155,6 +155,11 @@
}
@Override
+ public void registerSafeCheckCast(DexType type) {
+ enqueuer.traceSafeCheckCast(type, context);
+ }
+
+ @Override
public void registerTypeReference(DexType type) {
enqueuer.traceTypeReference(type, context);
}
diff --git a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
index 16d8707..edb24b0 100644
--- a/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
+++ b/src/main/java/com/android/tools/r8/shaking/Enqueuer.java
@@ -1032,6 +1032,11 @@
traceConstClassOrCheckCast(type, currentMethod);
}
+ void traceSafeCheckCast(DexType type, ProgramMethod currentMethod) {
+ checkCastAnalyses.forEach(analysis -> analysis.traceSafeCheckCast(type, currentMethod));
+ traceCompilerSynthesizedConstClassOrCheckCast(type, currentMethod);
+ }
+
void traceConstClass(
DexType type,
ProgramMethod currentMethod,
@@ -1092,8 +1097,21 @@
}
private void traceConstClassOrCheckCast(DexType type, ProgramMethod currentMethod) {
+ internalTraceConstClassOrCheckCast(type, currentMethod, false);
+ }
+
+ // TODO(b/190487539): Currently only used by traceSafeCheckCast(), but should also be used to
+ // ensure we don't trigger compat behavior for const-class instructions synthesized for
+ // synchronized methods.
+ private void traceCompilerSynthesizedConstClassOrCheckCast(
+ DexType type, ProgramMethod currentMethod) {
+ internalTraceConstClassOrCheckCast(type, currentMethod, true);
+ }
+
+ private void internalTraceConstClassOrCheckCast(
+ DexType type, ProgramMethod currentMethod, boolean isCompilerSynthesized) {
traceTypeReference(type, currentMethod);
- if (!forceProguardCompatibility) {
+ if (!forceProguardCompatibility || isCompilerSynthesized) {
return;
}
DexType baseType = type.toBaseType(appView.dexItemFactory());
diff --git a/src/main/java/com/android/tools/r8/shaking/RuntimeTypeCheckInfo.java b/src/main/java/com/android/tools/r8/shaking/RuntimeTypeCheckInfo.java
index fc2756a..67f1c75 100644
--- a/src/main/java/com/android/tools/r8/shaking/RuntimeTypeCheckInfo.java
+++ b/src/main/java/com/android/tools/r8/shaking/RuntimeTypeCheckInfo.java
@@ -61,6 +61,11 @@
}
@Override
+ public void traceSafeCheckCast(DexType type, ProgramMethod context) {
+ // Intentionally empty.
+ }
+
+ @Override
public void traceInstanceOf(DexType type, ProgramMethod context) {
add(type, instanceOfTypes);
}
diff --git a/src/test/java/com/android/tools/r8/classmerging/horizontal/MergingWithSafeCheckCastTest.java b/src/test/java/com/android/tools/r8/classmerging/horizontal/MergingWithSafeCheckCastTest.java
new file mode 100644
index 0000000..02f0f5b
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/classmerging/horizontal/MergingWithSafeCheckCastTest.java
@@ -0,0 +1,125 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.classmerging.horizontal;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.NeverClassInline;
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.NeverPropagateValue;
+import com.android.tools.r8.NoHorizontalClassMerging;
+import com.android.tools.r8.NoUnusedInterfaceRemoval;
+import com.android.tools.r8.NoVerticalClassMerging;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import org.junit.Test;
+
+public class MergingWithSafeCheckCastTest extends HorizontalClassMergingTestBase {
+
+ public MergingWithSafeCheckCastTest(TestParameters parameters) {
+ super(parameters);
+ }
+
+ @Test
+ public void testR8() throws Exception {
+ testForR8(parameters.getBackend())
+ .addInnerClasses(getClass())
+ .addKeepMainRule(Main.class)
+ .enableInliningAnnotations()
+ .enableMemberValuePropagationAnnotations()
+ .enableNeverClassInliningAnnotations()
+ .enableNoHorizontalClassMergingAnnotations()
+ .enableNoUnusedInterfaceRemovalAnnotations()
+ .enableNoVerticalClassMergingAnnotations()
+ .setMinApi(parameters.getApiLevel())
+ .addHorizontallyMergedClassesInspector(
+ inspector ->
+ inspector
+ .assertIsCompleteMergeGroup(A.class, B.class)
+ .assertIsCompleteMergeGroup(I.class, J.class)
+ .assertNoOtherClassesMerged())
+ .compile()
+ .inspect(
+ inspector -> {
+ ClassSubject aClassSubject = inspector.clazz(A.class);
+ assertThat(aClassSubject, isPresent());
+
+ // Check that the field f has been changed to have type java.lang.Object.
+ assertEquals(1, aClassSubject.allFields().size());
+ assertEquals(
+ Object.class.getTypeName(),
+ aClassSubject.allFields().get(0).getField().getType().getTypeName());
+
+ // Check that casts have been inserted into main().
+ MethodSubject mainMethodSubject = inspector.clazz(Main.class).mainMethod();
+ assertTrue(
+ mainMethodSubject.streamInstructions().anyMatch(InstructionSubject::isCheckCast));
+ })
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutputLines("I", "J");
+ }
+
+ public static class Main {
+ public static void main(String[] args) {
+ new A(new IImpl()).f.printI();
+ new B(new JImpl()).f.printJ();
+ }
+ }
+
+ @NeverClassInline
+ public static class A {
+
+ @NeverPropagateValue I f;
+
+ A(I f) {
+ this.f = f;
+ }
+ }
+
+ @NeverClassInline
+ public static class B {
+
+ @NeverPropagateValue J f;
+
+ B(J f) {
+ this.f = f;
+ }
+ }
+
+ @NoUnusedInterfaceRemoval
+ @NoVerticalClassMerging
+ interface I {
+ void printI();
+ }
+
+ @NoUnusedInterfaceRemoval
+ @NoVerticalClassMerging
+ interface J {
+ void printJ();
+ }
+
+ @NeverClassInline
+ @NoHorizontalClassMerging
+ static class IImpl implements I {
+ @NeverInline
+ public void printI() {
+ System.out.println("I");
+ }
+ }
+
+ @NeverClassInline
+ @NoHorizontalClassMerging
+ static class JImpl implements J {
+ @NeverInline
+ public void printJ() {
+ System.out.println("J");
+ }
+ }
+}