Reapply "Fix TypeSwitch for enums"
This reverts commit 700298b73999005a05be4c3190a4d8aaa6690b9e.
Bug: b/336510513
Change-Id: I00b04b80e36750be6fc6404ae62d80ba05452d6a
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index e3eaa24..77e3369 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -261,6 +261,8 @@
public final DexString classLoaderDescriptor = createString("Ljava/lang/ClassLoader;");
public final DexString autoCloseableDescriptor = createString("Ljava/lang/AutoCloseable;");
public final DexString classArrayDescriptor = createString("[Ljava/lang/Class;");
+ public final DexString classDescDescriptor = createString("Ljava/lang/constant/ClassDesc;");
+ public final DexString enumDescDescriptor = createString("Ljava/lang/Enum$EnumDesc;");
public final DexString constructorDescriptor = createString("Ljava/lang/reflect/Constructor;");
public final DexString fieldDescriptor = createString("Ljava/lang/reflect/Field;");
public final DexString methodDescriptor = createString("Ljava/lang/reflect/Method;");
@@ -457,6 +459,9 @@
public final DexType stringBuilderType = createStaticallyKnownType(stringBuilderDescriptor);
public final DexType stringBufferType = createStaticallyKnownType(stringBufferDescriptor);
+ public final DexType classDescType = createStaticallyKnownType(classDescDescriptor);
+ public final DexType enumDescType = createStaticallyKnownType(enumDescDescriptor);
+
public final DexType javaLangAnnotationRetentionPolicyType =
createStaticallyKnownType("Ljava/lang/annotation/RetentionPolicy;");
public final DexType javaLangReflectArrayType =
@@ -817,6 +822,17 @@
public final DexType callSiteType = createStaticallyKnownType("Ljava/lang/invoke/CallSite;");
public final DexType lookupType =
createStaticallyKnownType("Ljava/lang/invoke/MethodHandles$Lookup;");
+ public final DexMethod constantDynamicBootstrapMethod =
+ createMethod(
+ constantBootstrapsType,
+ createProto(
+ objectType,
+ methodHandlesLookupType,
+ stringType,
+ classType,
+ methodHandleType,
+ objectArrayType),
+ invokeMethodName);
public final DexType objectMethodsType =
createStaticallyKnownType("Ljava/lang/runtime/ObjectMethods;");
public final DexType typeDescriptorType =
diff --git a/src/main/java/com/android/tools/r8/graph/DexValue.java b/src/main/java/com/android/tools/r8/graph/DexValue.java
index f4f7a60..fe532fc 100644
--- a/src/main/java/com/android/tools/r8/graph/DexValue.java
+++ b/src/main/java/com/android/tools/r8/graph/DexValue.java
@@ -2118,6 +2118,10 @@
return DexValueKind.CONST_DYNAMIC;
}
+ public ConstantDynamicReference getValue() {
+ return value;
+ }
+
private CompilationError throwCannotConvertToDex() {
throw new CompilationError("DexValueConstDynamic should be desugared");
}
diff --git a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
index ce86ab6..e6e99e7 100644
--- a/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
+++ b/src/main/java/com/android/tools/r8/ir/desugar/typeswitch/TypeSwitchDesugaring.java
@@ -10,20 +10,27 @@
import com.android.tools.r8.cf.code.CfConstString;
import com.android.tools.r8.cf.code.CfInstruction;
import com.android.tools.r8.cf.code.CfInvoke;
+import com.android.tools.r8.cf.code.CfNew;
import com.android.tools.r8.cf.code.CfNewArray;
import com.android.tools.r8.cf.code.CfStackInstruction;
import com.android.tools.r8.cf.code.CfStackInstruction.Opcode;
+import com.android.tools.r8.cf.code.CfStaticFieldRead;
import com.android.tools.r8.contexts.CompilationContext.MethodProcessingContext;
-import com.android.tools.r8.errors.Unreachable;
+import com.android.tools.r8.errors.CompilationError;
import com.android.tools.r8.graph.AppView;
import com.android.tools.r8.graph.CfCode;
import com.android.tools.r8.graph.DexCallSite;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexEncodedField;
+import com.android.tools.r8.graph.DexField;
import com.android.tools.r8.graph.DexItemFactory;
import com.android.tools.r8.graph.DexMethod;
+import com.android.tools.r8.graph.DexMethodHandle;
import com.android.tools.r8.graph.DexProto;
import com.android.tools.r8.graph.DexString;
import com.android.tools.r8.graph.DexType;
import com.android.tools.r8.graph.DexValue;
+import com.android.tools.r8.graph.DexValue.DexValueConstDynamic;
import com.android.tools.r8.graph.MethodAccessFlags;
import com.android.tools.r8.graph.ProgramMethod;
import com.android.tools.r8.ir.code.MemberType;
@@ -31,6 +38,9 @@
import com.android.tools.r8.ir.desugar.CfInstructionDesugaring;
import com.android.tools.r8.ir.desugar.CfInstructionDesugaringEventConsumer;
import com.android.tools.r8.ir.desugar.DesugarDescription;
+import com.android.tools.r8.ir.desugar.constantdynamic.ConstantDynamicReference;
+import com.android.tools.r8.utils.DescriptorUtils;
+import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import org.objectweb.asm.Opcodes;
@@ -39,17 +49,19 @@
private final AppView<?> appView;
- private final DexString typeSwitch;
private final DexMethod typeSwitchMethod;
private final DexProto typeSwitchProto;
private final DexProto typeSwitchHelperProto;
+ private final DexMethod enumDescMethod;
+ private final DexMethod classDescMethod;
+ private final DexType matchException;
+ private final DexMethod matchExceptionInit;
+ private final DexItemFactory factory;
public TypeSwitchDesugaring(AppView<?> appView) {
this.appView = appView;
- DexItemFactory factory = appView.dexItemFactory();
- typeSwitchProto = factory.createProto(factory.intType, factory.objectType, factory.intType);
+ this.factory = appView.dexItemFactory();
DexType switchBootstrap = factory.createType("Ljava/lang/runtime/SwitchBootstraps;");
- typeSwitch = factory.createString("typeSwitch");
typeSwitchMethod =
factory.createMethod(
switchBootstrap,
@@ -59,22 +71,85 @@
factory.stringType,
factory.methodTypeType,
factory.objectArrayType),
- typeSwitch);
+ factory.createString("typeSwitch"));
+ typeSwitchProto = factory.createProto(factory.intType, factory.objectType, factory.intType);
typeSwitchHelperProto =
factory.createProto(
factory.intType, factory.objectType, factory.intType, factory.objectArrayType);
+ enumDescMethod =
+ factory.createMethod(
+ factory.enumDescType,
+ factory.createProto(factory.enumDescType, factory.classDescType, factory.stringType),
+ "of");
+ classDescMethod =
+ factory.createMethod(
+ factory.classDescType,
+ factory.createProto(factory.classDescType, factory.stringType),
+ "of");
+ matchException = factory.createType("Ljava/lang/MatchException;");
+ matchExceptionInit =
+ factory.createInstanceInitializer(
+ matchException, factory.stringType, factory.throwableType);
+ }
+
+ private boolean methodHandleIsInvokeStaticTo(DexValue dexValue, DexMethod method) {
+ if (!dexValue.isDexValueMethodHandle()) {
+ return false;
+ }
+ return methodHandleIsInvokeStaticTo(dexValue.asDexValueMethodHandle().getValue(), method);
+ }
+
+ private boolean methodHandleIsInvokeStaticTo(DexMethodHandle methodHandle, DexMethod method) {
+ return methodHandle.type.isInvokeStatic() && methodHandle.asMethod().isIdenticalTo(method);
}
@Override
public DesugarDescription compute(CfInstruction instruction, ProgramMethod context) {
if (!instruction.isInvokeDynamic()) {
+ // We need to replace the new MatchException with RuntimeException.
+ if (instruction.isNew() && instruction.asNew().getType().isIdenticalTo(matchException)) {
+ return DesugarDescription.builder()
+ .setDesugarRewrite(
+ (position,
+ freshLocalProvider,
+ localStackAllocator,
+ desugaringInfo,
+ eventConsumer,
+ theContext,
+ methodProcessingContext,
+ desugaringCollection,
+ dexItemFactory) -> ImmutableList.of(new CfNew(factory.runtimeExceptionType)))
+ .build();
+ }
+ if (instruction.isInvokeSpecial()
+ && instruction.asInvoke().getMethod().isIdenticalTo(matchExceptionInit)) {
+ return DesugarDescription.builder()
+ .setDesugarRewrite(
+ (position,
+ freshLocalProvider,
+ localStackAllocator,
+ desugaringInfo,
+ eventConsumer,
+ theContext,
+ methodProcessingContext,
+ desugaringCollection,
+ dexItemFactory) ->
+ ImmutableList.of(
+ new CfInvoke(
+ Opcodes.INVOKESPECIAL,
+ factory.createInstanceInitializer(
+ factory.runtimeExceptionType,
+ factory.stringType,
+ factory.throwableType),
+ false)))
+ .build();
+ }
return DesugarDescription.nothing();
}
DexCallSite callSite = instruction.asInvokeDynamic().getCallSite();
- if (!(callSite.methodName.isIdenticalTo(typeSwitch)
+ if (!(callSite.methodName.isIdenticalTo(typeSwitchMethod.getName())
&& callSite.methodProto.isIdenticalTo(typeSwitchProto)
- && callSite.bootstrapMethod.member.isDexMethod()
- && callSite.bootstrapMethod.member.asDexMethod().isIdenticalTo(typeSwitchMethod))) {
+ && methodHandleIsInvokeStaticTo(callSite.bootstrapMethod, typeSwitchMethod))) {
return DesugarDescription.nothing();
}
// Call the desugared method.
@@ -91,7 +166,7 @@
dexItemFactory) -> {
// We add on stack (2) array, (3) dupped array, (4) index, (5) value.
localStackAllocator.allocateLocalStack(4);
- List<CfInstruction> cfInstructions = generateLoadArguments(callSite);
+ List<CfInstruction> cfInstructions = generateLoadArguments(callSite, context);
generateInvokeToDesugaredMethod(
methodProcessingContext, cfInstructions, theContext, eventConsumer);
return cfInstructions;
@@ -120,14 +195,10 @@
methodSig -> {
CfCode code =
TypeSwitchMethods.TypeSwitchMethods_typeSwitch(
- appView.dexItemFactory(), methodSig);
+ factory, methodSig);
if (appView.options().hasMappingFileSupport()) {
return code.getCodeAsInlining(
- methodSig,
- true,
- context.getReference(),
- false,
- appView.dexItemFactory());
+ methodSig, true, context.getReference(), false, factory);
}
return code;
}));
@@ -135,13 +206,13 @@
cfInstructions.add(new CfInvoke(Opcodes.INVOKESTATIC, method.getReference(), false));
}
- private List<CfInstruction> generateLoadArguments(DexCallSite callSite) {
+ private List<CfInstruction> generateLoadArguments(DexCallSite callSite, ProgramMethod context) {
// We need to call the method with the bootstrap args as parameters.
// We need to convert the bootstrap args into a list of cf instructions.
// The object and the int are already pushed on stack, we simply need to push the extra array.
List<CfInstruction> cfInstructions = new ArrayList<>();
cfInstructions.add(new CfConstNumber(callSite.bootstrapArgs.size(), ValueType.INT));
- cfInstructions.add(new CfNewArray(appView.dexItemFactory().objectArrayType));
+ cfInstructions.add(new CfNewArray(factory.objectArrayType));
for (int i = 0; i < callSite.bootstrapArgs.size(); i++) {
DexValue bootstrapArg = callSite.bootstrapArgs.get(i);
cfInstructions.add(new CfStackInstruction(Opcode.Dup));
@@ -152,16 +223,72 @@
cfInstructions.add(
new CfConstNumber(bootstrapArg.asDexValueInt().getValue(), ValueType.INT));
cfInstructions.add(
- new CfInvoke(
- Opcodes.INVOKESTATIC, appView.dexItemFactory().integerMembers.valueOf, false));
+ new CfInvoke(Opcodes.INVOKESTATIC, factory.integerMembers.valueOf, false));
} else if (bootstrapArg.isDexValueString()) {
cfInstructions.add(new CfConstString(bootstrapArg.asDexValueString().getValue()));
} else {
assert bootstrapArg.isDexValueConstDynamic();
- throw new Unreachable("TODO(b/336510513): Enum descriptor should be implemented");
+ DexField enumField = extractEnumField(bootstrapArg.asDexValueConstDynamic(), context);
+ cfInstructions.add(new CfStaticFieldRead(enumField));
}
cfInstructions.add(new CfArrayStore(MemberType.OBJECT));
}
return cfInstructions;
}
+
+ private CompilationError throwEnumFieldConstantDynamic(String msg, ProgramMethod context) {
+ throw new CompilationError(
+ "Unexpected ConstantDynamic in TypeSwitch: " + msg, context.getOrigin());
+ }
+
+ private DexField extractEnumField(
+ DexValueConstDynamic dexValueConstDynamic, ProgramMethod context) {
+ ConstantDynamicReference enumCstDynamic = dexValueConstDynamic.getValue();
+ DexMethod bootstrapMethod = factory.constantDynamicBootstrapMethod;
+ if (!(enumCstDynamic.getType().isIdenticalTo(factory.enumDescType)
+ && enumCstDynamic.getName().isIdenticalTo(bootstrapMethod.getName())
+ && enumCstDynamic.getBootstrapMethod().asMethod().isIdenticalTo(bootstrapMethod)
+ && enumCstDynamic.getBootstrapMethodArguments().size() == 3
+ && methodHandleIsInvokeStaticTo(
+ enumCstDynamic.getBootstrapMethodArguments().get(0), enumDescMethod))) {
+ throw throwEnumFieldConstantDynamic("Invalid EnumDesc", context);
+ }
+ DexValue dexValueFieldName = enumCstDynamic.getBootstrapMethodArguments().get(2);
+ if (!dexValueFieldName.isDexValueString()) {
+ throw throwEnumFieldConstantDynamic("Field name " + dexValueFieldName, context);
+ }
+ DexString fieldName = dexValueFieldName.asDexValueString().getValue();
+
+ DexValue dexValueClassCstDynamic = enumCstDynamic.getBootstrapMethodArguments().get(1);
+ if (!dexValueClassCstDynamic.isDexValueConstDynamic()) {
+ throw throwEnumFieldConstantDynamic("Enum class " + dexValueClassCstDynamic, context);
+ }
+ ConstantDynamicReference classCstDynamic =
+ dexValueClassCstDynamic.asDexValueConstDynamic().getValue();
+ if (!(classCstDynamic.getType().isIdenticalTo(factory.classDescType)
+ && classCstDynamic.getName().isIdenticalTo(bootstrapMethod.getName())
+ && classCstDynamic.getBootstrapMethod().asMethod().isIdenticalTo(bootstrapMethod)
+ && classCstDynamic.getBootstrapMethodArguments().size() == 2
+ && methodHandleIsInvokeStaticTo(
+ classCstDynamic.getBootstrapMethodArguments().get(0), classDescMethod))) {
+ throw throwEnumFieldConstantDynamic("Class descriptor " + classCstDynamic, context);
+ }
+ DexValue dexValueClassName = classCstDynamic.getBootstrapMethodArguments().get(1);
+ if (!dexValueClassName.isDexValueString()) {
+ throw throwEnumFieldConstantDynamic("Class name " + dexValueClassName, context);
+ }
+ DexString className = dexValueClassName.asDexValueString().getValue();
+ DexType enumType =
+ factory.createType(DescriptorUtils.javaTypeToDescriptor(className.toString()));
+ DexClass enumClass = appView.definitionFor(enumType);
+ if (enumClass == null) {
+ throw throwEnumFieldConstantDynamic("Missing enum class " + enumType, context);
+ }
+ DexEncodedField dexEncodedField = enumClass.lookupUniqueStaticFieldWithName(fieldName);
+ if (dexEncodedField == null) {
+ throw throwEnumFieldConstantDynamic(
+ "Missing enum field " + fieldName + " in " + enumType, context);
+ }
+ return dexEncodedField.getReference();
+ }
}
diff --git a/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java b/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java
index 90001a5..e57e31c 100644
--- a/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java
+++ b/src/test/examplesJava21/switchpatternmatching/EnumSwitchTest.java
@@ -4,11 +4,10 @@
package switchpatternmatching;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertThrows;
import static org.junit.Assume.assumeTrue;
-import com.android.tools.r8.CompilationFailedException;
import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestBuilder;
import com.android.tools.r8.TestParameters;
import com.android.tools.r8.TestParametersCollection;
import com.android.tools.r8.TestRuntime.CfVm;
@@ -16,11 +15,13 @@
import com.android.tools.r8.utils.StringUtils;
import com.android.tools.r8.utils.codeinspector.CodeInspector;
import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import org.junit.Assume;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;
+import switchpatternmatching.StringSwitchTest.Main;
@RunWith(Parameterized.class)
public class EnumSwitchTest extends TestBase {
@@ -32,7 +33,8 @@
return getTestParameters().withAllRuntimesAndApiLevels().build();
}
- public static String EXPECTED_OUTPUT = StringUtils.lines("null", "E1", "E2", "E3", "E4", "a C");
+ public static String EXPECTED_OUTPUT =
+ StringUtils.lines("null", "E1", "E2", "E3", "E4", "a C", "class %s");
@Test
public void testJvm() throws Exception {
@@ -68,36 +70,72 @@
parameters.assumeJvmTestParameters();
testForJvm(parameters)
- .addInnerClassesAndStrippedOuter(getClass())
+ .apply(this::addModifiedProgramClasses)
.run(parameters.getRuntime(), Main.class)
.applyIf(
parameters.getCfRuntime().isNewerThanOrEqual(CfVm.JDK21),
- r -> r.assertSuccessWithOutput(EXPECTED_OUTPUT),
+ r ->
+ r.assertSuccessWithOutput(
+ String.format(EXPECTED_OUTPUT, "java.lang.MatchException")),
r -> r.assertFailureWithErrorThatThrows(UnsupportedClassVersionError.class));
}
+ private <T extends TestBuilder<?, T>> void addModifiedProgramClasses(
+ TestBuilder<?, T> testBuilder) throws Exception {
+ testBuilder
+ .addStrippedOuter(getClass())
+ .addProgramClasses(FakeI.class, E.class, C.class)
+ .addProgramClassFileData(
+ transformer(I.class)
+ .setPermittedSubclasses(I.class, E.class, C.class, D.class)
+ .transform())
+ .addProgramClassFileData(transformer(D.class).setImplements(I.class).transform())
+ .addProgramClassFileData(
+ transformer(Main.class)
+ .transformTypeInsnInMethod(
+ "getD",
+ (opcode, type, visitor) ->
+ visitor.visitTypeInsn(opcode, "switchpatternmatching/EnumSwitchTest$D"))
+ .transformMethodInsnInMethod(
+ "getD",
+ (opcode, owner, name, descriptor, isInterface, visitor) -> {
+ assert name.equals("<init>");
+ visitor.visitMethodInsn(
+ opcode,
+ "switchpatternmatching/EnumSwitchTest$D",
+ name,
+ descriptor,
+ isInterface);
+ })
+ .transform());
+ }
+
@Test
public void testD8() throws Exception {
parameters.assumeDexRuntime();
- assertThrows(
- CompilationFailedException.class,
- () -> testForD8().addInnerClasses(getClass()).setMinApi(parameters).compile());
+ testForD8()
+ .apply(this::addModifiedProgramClasses)
+ .setMinApi(parameters)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutput(String.format(EXPECTED_OUTPUT, "java.lang.RuntimeException"));
}
@Test
public void testR8() throws Exception {
- assertThrows(
- CompilationFailedException.class,
- () ->
- testForR8(parameters.getBackend())
- .addInnerClasses(getClass())
- .setMinApi(parameters)
- .addKeepMainRule(Main.class)
- .compile());
+ Assume.assumeTrue("For Cf we should compile with Jdk 21 library", parameters.isDexRuntime());
+ testForR8(parameters.getBackend())
+ .apply(this::addModifiedProgramClasses)
+ .setMinApi(parameters)
+ .addKeepMainRule(Main.class)
+ .run(parameters.getRuntime(), Main.class)
+ .assertSuccessWithOutput(String.format(EXPECTED_OUTPUT, "java.lang.RuntimeException"));
}
+ // D is added to the list of permitted subclasses to reproduce the MatchException.
sealed interface I permits E, C {}
+ interface FakeI {}
+
public enum E implements I {
E1,
E2,
@@ -105,6 +143,9 @@
E4
}
+ // Replaced with I.
+ static final class D implements FakeI {}
+
static final class C implements I {}
static class Main {
@@ -140,6 +181,16 @@
enumSwitch(E.E3);
enumSwitch(E.E4);
enumSwitch(new C());
+ try {
+ enumSwitch(getD());
+ } catch (Throwable t) {
+ System.out.println(t.getClass());
+ }
+ }
+
+ public static I getD() {
+ // Replaced by new D();
+ return new C();
}
}
}
diff --git a/src/test/testbase/java/com/android/tools/r8/TestBuilder.java b/src/test/testbase/java/com/android/tools/r8/TestBuilder.java
index 02d0e06..32a40c0 100644
--- a/src/test/testbase/java/com/android/tools/r8/TestBuilder.java
+++ b/src/test/testbase/java/com/android/tools/r8/TestBuilder.java
@@ -147,16 +147,19 @@
return addProgramFiles(getFilesForInnerClasses(classes));
}
- public T addInnerClassesAndStrippedOuter(Class<?> clazz) throws IOException {
+ public T addStrippedOuter(Class<?> clazz) throws IOException {
return addProgramClassFileData(
- TestBase.transformer(clazz)
- .removeFields(FieldPredicate.all())
- .removeMethods(MethodPredicate.all())
- .removeAllAnnotations()
- .setSuper(descriptor(Object.class))
- .setImplements()
- .transform())
- .addInnerClasses(ImmutableList.of(clazz));
+ TestBase.transformer(clazz)
+ .removeFields(FieldPredicate.all())
+ .removeMethods(MethodPredicate.all())
+ .removeAllAnnotations()
+ .setSuper(descriptor(Object.class))
+ .setImplements()
+ .transform());
+ }
+
+ public T addInnerClassesAndStrippedOuter(Class<?> clazz) throws IOException {
+ return addStrippedOuter(clazz).addInnerClasses(ImmutableList.of(clazz));
}
public abstract T addLibraryFiles(Collection<Path> files);