Insert class cast exceptions for null-holders of AtomicFieldUpdaters.
Bug: (b/453628974)
Change-Id: I6e3dc434482e4dc3ecda1901dc7af1b7f38c5688
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/AtomicFieldUpdaterOptimizer.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/AtomicFieldUpdaterOptimizer.java
index 1740878..13226e7 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/AtomicFieldUpdaterOptimizer.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/AtomicFieldUpdaterOptimizer.java
@@ -3,14 +3,17 @@
// BSD-style license that can be found in the LICENSE file.
package com.android.tools.r8.ir.conversion.passes;
+import com.android.tools.r8.contexts.CompilationContext.MethodProcessingContext;
import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
import com.android.tools.r8.graph.AppView;
import com.android.tools.r8.graph.DexField;
import com.android.tools.r8.graph.DexMethod;
import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.ir.analysis.type.TypeElement;
import com.android.tools.r8.ir.code.IRCode;
import com.android.tools.r8.ir.code.IRCodeInstructionListIterator;
import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.ir.code.InvokeStatic;
import com.android.tools.r8.ir.code.InvokeVirtual;
import com.android.tools.r8.ir.code.Position;
import com.android.tools.r8.ir.code.StaticGet;
@@ -18,6 +21,7 @@
import com.android.tools.r8.ir.conversion.MethodProcessor;
import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
import com.android.tools.r8.ir.optimize.AtomicFieldUpdaterInstrumentor.AtomicFieldUpdaterInstrumentorInfo;
+import com.android.tools.r8.ir.optimize.UtilityMethodsForCodeOptimizations;
import com.android.tools.r8.utils.AndroidApiLevel;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
@@ -81,7 +85,10 @@
}
@Override
- protected CodeRewriterResult rewriteCode(IRCode code) {
+ protected CodeRewriterResult rewriteCode(
+ IRCode code,
+ MethodProcessor methodProcessor,
+ MethodProcessingContext methodProcessingContext) {
AtomicFieldUpdaterInstrumentorInfo info = appView.getAtomicFieldUpdaterInstrumentorInfo();
var atomicUpdaterFields = info.getInstrumentations().get(code.context().getHolderType());
assert atomicUpdaterFields != null;
@@ -116,9 +123,11 @@
if (visitCompareAndSet(
code,
it,
+ methodProcessor,
+ methodProcessingContext,
invoke,
atomicUpdaterFields,
- info.getUnsafeInstanceField(),
+ info,
next.outValue())) {
changed = true;
}
@@ -127,9 +136,11 @@
if (visitGet(
code,
it,
+ methodProcessor,
+ methodProcessingContext,
invoke,
atomicUpdaterFields,
- info.getUnsafeInstanceField(),
+ info,
next.outValue())) {
changed = true;
}
@@ -138,9 +149,11 @@
if (visitSet(
code,
it,
+ methodProcessor,
+ methodProcessingContext,
invoke,
atomicUpdaterFields,
- info.getUnsafeInstanceField(),
+ info,
next.outValue())) {
changed = true;
}
@@ -149,10 +162,12 @@
if (visitGetAndSet(
code,
it,
+ methodProcessor,
+ methodProcessingContext,
next.getPosition(),
invoke,
atomicUpdaterFields,
- info.getUnsafeInstanceField(),
+ info,
next.outValue())) {
changed = true;
}
@@ -167,9 +182,11 @@
private boolean visitCompareAndSet(
IRCode code,
IRCodeInstructionListIterator it,
+ MethodProcessor methodProcessor,
+ MethodProcessingContext methodProcessingContext,
InvokeVirtual invoke,
Map<DexField, AtomicFieldUpdaterInfo> atomicUpdaterFields,
- DexField unsafeInstanceField,
+ AtomicFieldUpdaterInstrumentorInfo info,
Value outValue) {
// Resolve updater.
var updaterValue = invoke.getReceiver();
@@ -183,7 +200,9 @@
// Resolve holder.
var holderValue = invoke.getFirstNonReceiverArgument();
var expectedHolder = resolvedUpdater.updaterFieldInfo.holder;
- if (!isHolderValid(invoke.getPosition(), holderValue, expectedHolder, "compareAndSet")) {
+ var resolvedHolder =
+ resolveHolder(invoke.getPosition(), holderValue, expectedHolder, "compareAndSet");
+ if (resolvedHolder == null) {
return false;
}
@@ -200,12 +219,14 @@
rewriteCompareAndSet(
code,
it,
+ methodProcessor,
+ methodProcessingContext,
invoke.getPosition(),
- resolvedUpdater.isNullable,
- unsafeInstanceField,
+ resolvedUpdater,
+ resolvedHolder,
+ info,
updaterValue,
holderValue,
- resolvedUpdater.updaterFieldInfo.offsetField,
expectValue,
updateValue,
outValue);
@@ -215,9 +236,11 @@
private boolean visitGet(
IRCode code,
IRCodeInstructionListIterator it,
+ MethodProcessor methodProcessor,
+ MethodProcessingContext methodProcessingContext,
InvokeVirtual invoke,
Map<DexField, AtomicFieldUpdaterInfo> atomicUpdaterFields,
- DexField unsafeInstanceField,
+ AtomicFieldUpdaterInstrumentorInfo info,
Value outValue) {
// Resolve updater.
var updaterValue = invoke.getReceiver();
@@ -230,7 +253,8 @@
// Resolve holder.
var holderValue = invoke.getFirstNonReceiverArgument();
var expectedHolder = resolvedUpdater.updaterFieldInfo.holder;
- if (!isHolderValid(invoke.getPosition(), holderValue, expectedHolder, "get")) {
+ var resolvedHolder = resolveHolder(invoke.getPosition(), holderValue, expectedHolder, "get");
+ if (resolvedHolder == null) {
return false;
}
@@ -238,12 +262,14 @@
rewriteGet(
code,
it,
+ methodProcessor,
+ methodProcessingContext,
invoke.getPosition(),
- resolvedUpdater.isNullable,
- unsafeInstanceField,
+ resolvedUpdater,
+ resolvedHolder,
+ info,
updaterValue,
holderValue,
- resolvedUpdater.updaterFieldInfo.offsetField,
outValue);
return true;
}
@@ -251,9 +277,11 @@
private boolean visitSet(
IRCode code,
IRCodeInstructionListIterator it,
+ MethodProcessor methodProcessor,
+ MethodProcessingContext methodProcessingContext,
InvokeVirtual invoke,
Map<DexField, AtomicFieldUpdaterInfo> atomicUpdaterFields,
- DexField unsafeInstanceField,
+ AtomicFieldUpdaterInstrumentorInfo info,
Value outValue) {
// Resolve updater.
var updaterValue = invoke.getReceiver();
@@ -266,7 +294,8 @@
// Resolve holder.
var holderValue = invoke.getFirstNonReceiverArgument();
var expectedHolder = resolvedUpdater.updaterFieldInfo.holder;
- if (!isHolderValid(invoke.getPosition(), holderValue, expectedHolder, "set")) {
+ var resolvedHolder = resolveHolder(invoke.getPosition(), holderValue, expectedHolder, "set");
+ if (resolvedHolder == null) {
return false;
}
@@ -280,12 +309,14 @@
rewriteSet(
code,
it,
+ methodProcessor,
+ methodProcessingContext,
invoke.getPosition(),
- resolvedUpdater.isNullable,
- unsafeInstanceField,
+ resolvedUpdater,
+ resolvedHolder,
+ info,
updaterValue,
holderValue,
- resolvedUpdater.updaterFieldInfo.offsetField,
newValueValue,
outValue);
return true;
@@ -294,10 +325,12 @@
private boolean visitGetAndSet(
IRCode code,
IRCodeInstructionListIterator it,
+ MethodProcessor methodProcessor,
+ MethodProcessingContext methodProcessingContext,
Position position,
InvokeVirtual invoke,
Map<DexField, AtomicFieldUpdaterInfo> atomicUpdaterFields,
- DexField unsafeInstanceField,
+ AtomicFieldUpdaterInstrumentorInfo info,
Value outValue) {
if (appView.options().isGeneratingDex()
&& appView.options().getMinApiLevel().isLessThan(AndroidApiLevel.N)) {
@@ -316,7 +349,9 @@
// Resolve holder.
var holderValue = invoke.getFirstNonReceiverArgument();
var expectedHolder = resolvedUpdater.updaterFieldInfo.holder;
- if (!isHolderValid(position, holderValue, expectedHolder, "getAndSet")) {
+ var resolvedHolder =
+ resolveHolder(invoke.getPosition(), holderValue, expectedHolder, "getAndSet");
+ if (resolvedHolder == null) {
return false;
}
@@ -330,12 +365,14 @@
rewriteGetAndSet(
code,
it,
+ methodProcessor,
+ methodProcessingContext,
position,
- resolvedUpdater.isNullable,
- unsafeInstanceField,
+ resolvedUpdater,
+ resolvedHolder,
+ info,
updaterValue,
holderValue,
- resolvedUpdater.updaterFieldInfo.offsetField,
newValueValue,
outValue);
return true;
@@ -384,21 +421,23 @@
}
}
- private boolean isHolderValid(
+ private ResolvedHolder resolveHolder(
Position position, Value holderValue, DexType expectedHolder, String methodNameForLogging) {
- if (holderValue
- .getType()
- .lessThanOrEqual(expectedHolder.toNonNullTypeElement(appView), appView)) {
- return true;
+ TypeElement holderType = holderValue.getType();
+ if (!holderType.lessThanOrEqual(expectedHolder.toTypeElement(appView), appView)) {
+ reportFailure(position, "_." + methodNameForLogging + "(HERE, ..) is a wrong type");
+ return null;
}
- if (appView.testing().enableAtomicFieldUpdaterLogs) {
- if (holderValue.getType().lessThanOrEqual(expectedHolder.toTypeElement(appView), appView)) {
- reportFailure(position, "_." + methodNameForLogging + "(HERE, ..) is nullable");
- } else {
- reportFailure(position, "_." + methodNameForLogging + "(HERE, ..) is of unexpected type");
- }
+ var isNullable = holderType.isNullable();
+ return new ResolvedHolder(isNullable);
+ }
+
+ private static class ResolvedHolder {
+ public final boolean isNullable;
+
+ private ResolvedHolder(boolean isNullable) {
+ this.isNullable = isNullable;
}
- return false;
}
private boolean isNewValueValid(
@@ -423,25 +462,34 @@
private void rewriteCompareAndSet(
IRCode code,
IRCodeInstructionListIterator it,
+ MethodProcessor methodProcessor,
+ MethodProcessingContext methodProcessingContext,
Position position,
- boolean updaterMightBeNull,
- DexField unsafeInstanceField,
+ ResolvedUpdater resolvedUpdater,
+ ResolvedHolder resolvedHolder,
+ AtomicFieldUpdaterInstrumentorInfo info,
Value updaterValue,
Value holderValue,
- DexField offsetField,
Value expectValue,
Value updateValue,
Value outValue) {
- var instructions = new ArrayList<Instruction>(3);
+ var instructions = new ArrayList<Instruction>(4);
- if (updaterMightBeNull) {
+ if (resolvedUpdater.isNullable) {
instructions.add(createNullCheck(code, position, updaterValue));
}
- Instruction unsafeInstance = createUnsafeGet(code, position, unsafeInstanceField);
+ if (resolvedHolder.isNullable) {
+ instructions.add(
+ createNullCheckWithClassCastException(
+ methodProcessor, methodProcessingContext, position, holderValue));
+ }
+
+ Instruction unsafeInstance = createUnsafeGet(code, position, info.getUnsafeInstanceField());
instructions.add(unsafeInstance);
- Instruction offset = createOffsetGet(code, position, offsetField);
+ Instruction offset =
+ createOffsetGet(code, position, resolvedUpdater.updaterFieldInfo.offsetField);
instructions.add(offset);
// Add instructions BEFORE the compareAndSet instruction.
@@ -481,26 +529,35 @@
private void rewriteGet(
IRCode code,
IRCodeInstructionListIterator it,
+ MethodProcessor methodProcessor,
+ MethodProcessingContext methodProcessingContext,
Position position,
- boolean updaterMightBeNull,
- DexField unsafeInstanceField,
+ ResolvedUpdater resolvedUpdater,
+ ResolvedHolder resolvedHolder,
+ AtomicFieldUpdaterInstrumentorInfo info,
Value updaterValue,
Value holderValue,
- DexField offsetField,
Value outValue) {
- var instructions = new ArrayList<Instruction>(3);
+ var instructions = new ArrayList<Instruction>(4);
// Null-check for updater.
- if (updaterMightBeNull) {
+ if (resolvedUpdater.isNullable) {
instructions.add(createNullCheck(code, position, updaterValue));
}
+ if (resolvedHolder.isNullable) {
+ instructions.add(
+ createNullCheckWithClassCastException(
+ methodProcessor, methodProcessingContext, position, holderValue));
+ }
+
// Get unsafe instance.
- Instruction unsafeInstance = createUnsafeGet(code, position, unsafeInstanceField);
+ Instruction unsafeInstance = createUnsafeGet(code, position, info.getUnsafeInstanceField());
instructions.add(unsafeInstance);
// Get offset field.
- Instruction offset = createOffsetGet(code, position, offsetField);
+ Instruction offset =
+ createOffsetGet(code, position, resolvedUpdater.updaterFieldInfo.offsetField);
instructions.add(offset);
// Add instructions BEFORE the get instruction.
@@ -532,27 +589,36 @@
private void rewriteSet(
IRCode code,
IRCodeInstructionListIterator it,
+ MethodProcessor methodProcessor,
+ MethodProcessingContext methodProcessingContext,
Position position,
- boolean updaterMightBeNull,
- DexField unsafeInstanceField,
+ ResolvedUpdater resolvedUpdater,
+ ResolvedHolder resolvedHolder,
+ AtomicFieldUpdaterInstrumentorInfo info,
Value updaterValue,
Value holderValue,
- DexField offsetField,
Value newValueValue,
Value outValue) {
- var instructions = new ArrayList<Instruction>(3);
+ var instructions = new ArrayList<Instruction>(4);
// Null-check for updater.
- if (updaterMightBeNull) {
+ if (resolvedUpdater.isNullable) {
instructions.add(createNullCheck(code, position, updaterValue));
}
+ if (resolvedHolder.isNullable) {
+ instructions.add(
+ createNullCheckWithClassCastException(
+ methodProcessor, methodProcessingContext, position, holderValue));
+ }
+
// Get unsafe instance.
- Instruction unsafeInstance = createUnsafeGet(code, position, unsafeInstanceField);
+ Instruction unsafeInstance = createUnsafeGet(code, position, info.getUnsafeInstanceField());
instructions.add(unsafeInstance);
// Get offset field.
- Instruction offset = createOffsetGet(code, position, offsetField);
+ Instruction offset =
+ createOffsetGet(code, position, resolvedUpdater.updaterFieldInfo.offsetField);
instructions.add(offset);
// Add instructions BEFORE the get instruction.
@@ -588,27 +654,36 @@
private void rewriteGetAndSet(
IRCode code,
IRCodeInstructionListIterator it,
+ MethodProcessor methodProcessor,
+ MethodProcessingContext methodProcessingContext,
Position position,
- boolean updaterMightBeNull,
- DexField unsafeInstanceField,
+ ResolvedUpdater resolvedUpdater,
+ ResolvedHolder resolvedHolder,
+ AtomicFieldUpdaterInstrumentorInfo info,
Value updaterValue,
Value holderValue,
- DexField offsetField,
Value newValueValue,
Value outValue) {
- var instructions = new ArrayList<Instruction>(3);
+ var instructions = new ArrayList<Instruction>(4);
// Null-check for updater.
- if (updaterMightBeNull) {
+ if (resolvedUpdater.isNullable) {
instructions.add(createNullCheck(code, position, updaterValue));
}
+ if (resolvedHolder.isNullable) {
+ instructions.add(
+ createNullCheckWithClassCastException(
+ methodProcessor, methodProcessingContext, position, holderValue));
+ }
+
// Get unsafe instance.
- Instruction unsafeInstance = createUnsafeGet(code, position, unsafeInstanceField);
+ Instruction unsafeInstance = createUnsafeGet(code, position, info.getUnsafeInstanceField());
instructions.add(unsafeInstance);
// Get offset field.
- Instruction offset = createOffsetGet(code, position, offsetField);
+ Instruction offset =
+ createOffsetGet(code, position, resolvedUpdater.updaterFieldInfo.offsetField);
instructions.add(offset);
// Add instructions BEFORE the get instruction.
@@ -644,16 +719,31 @@
return offset;
}
- private InvokeVirtual createNullCheck(IRCode code, Position position, Value updaterValue) {
+ private InvokeVirtual createNullCheck(IRCode code, Position position, Value value) {
var nullCheck =
new InvokeVirtual(
dexItemFactory.objectMembers.getClass,
code.createValue(dexItemFactory.classType.toTypeElement(appView)),
- ImmutableList.of(updaterValue));
+ ImmutableList.of(value));
nullCheck.setPosition(position);
return nullCheck;
}
+ private InvokeStatic createNullCheckWithClassCastException(
+ MethodProcessor methodProcessor,
+ MethodProcessingContext methodProcessingContext,
+ Position position,
+ Value value) {
+ var optimizations =
+ UtilityMethodsForCodeOptimizations.synthesizeThrowClassCastExceptionIfNullMethod(
+ appView, methodProcessor.getEventConsumer(), methodProcessingContext);
+ optimizations.optimize(methodProcessor);
+ InvokeStatic invokeStatic =
+ new InvokeStatic(optimizations.getMethod().getReference(), null, ImmutableList.of(value));
+ invokeStatic.setPosition(position);
+ return invokeStatic;
+ }
+
private Instruction createUnsafeGet(
IRCode code, Position position, DexField unsafeInstanceField) {
assert unsafeInstanceField.type.isIdenticalTo(dexItemFactory.unsafeType);
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/UtilityMethodsForCodeOptimizations.java b/src/main/java/com/android/tools/r8/ir/optimize/UtilityMethodsForCodeOptimizations.java
index 9d6c650..ff8f5b5 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/UtilityMethodsForCodeOptimizations.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/UtilityMethodsForCodeOptimizations.java
@@ -87,6 +87,34 @@
return new UtilityMethodForCodeOptimizations(syntheticMethod);
}
+ public static UtilityMethodForCodeOptimizations synthesizeThrowClassCastExceptionIfNullMethod(
+ AppView<?> appView,
+ UtilityMethodsForCodeOptimizationsEventConsumer eventConsumer,
+ MethodProcessingContext methodProcessingContext) {
+ DexItemFactory dexItemFactory = appView.dexItemFactory();
+ DexProto proto = dexItemFactory.createProto(dexItemFactory.voidType, dexItemFactory.objectType);
+ SyntheticItems syntheticItems = appView.getSyntheticItems();
+ UniqueContext positionContext = methodProcessingContext.createUniqueContext();
+ ProgramMethod syntheticMethod =
+ syntheticItems.createMethod(
+ kinds -> kinds.THROW_CCE_IF_NULL,
+ positionContext,
+ appView,
+ builder ->
+ builder
+ .setAccessFlags(MethodAccessFlags.createPublicStaticSynthetic())
+ .setClassFileVersion(CfVersion.V1_8)
+ .setApiLevelForDefinition(appView.computedMinApiLevel())
+ .setApiLevelForCode(appView.computedMinApiLevel())
+ .setCode(
+ method ->
+ getThrowClassCastExceptionIfNullCodeTemplate(method, dexItemFactory))
+ .setProto(proto));
+ eventConsumer.acceptUtilityThrowClassCastExceptionIfNotNullMethod(
+ syntheticMethod, methodProcessingContext.getMethodContext());
+ return new UtilityMethodForCodeOptimizations(syntheticMethod);
+ }
+
private static CfCode getThrowClassCastExceptionIfNotNullCodeTemplate(
DexMethod method, DexItemFactory dexItemFactory) {
return CfUtilityMethodsForCodeOptimizations
@@ -94,6 +122,13 @@
dexItemFactory, method);
}
+ private static CfCode getThrowClassCastExceptionIfNullCodeTemplate(
+ DexMethod method, DexItemFactory dexItemFactory) {
+ return CfUtilityMethodsForCodeOptimizations
+ .CfUtilityMethodsForCodeOptimizationsTemplates_throwClassCastExceptionIfNull(
+ dexItemFactory, method);
+ }
+
public static UtilityMethodForCodeOptimizations synthesizeNonNullMethod(
AppView<?> appView,
UtilityMethodsForCodeOptimizationsEventConsumer eventConsumer,
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizations.java b/src/main/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizations.java
index 831f03a..c144687 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizations.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizations.java
@@ -136,6 +136,42 @@
ImmutableList.of());
}
+ public static CfCode CfUtilityMethodsForCodeOptimizationsTemplates_throwClassCastExceptionIfNull(
+ DexItemFactory factory, DexMethod method) {
+ CfLabel label0 = new CfLabel();
+ CfLabel label1 = new CfLabel();
+ CfLabel label2 = new CfLabel();
+ CfLabel label3 = new CfLabel();
+ return new CfCode(
+ method.holder,
+ 2,
+ 1,
+ ImmutableList.of(
+ label0,
+ new CfLoad(ValueType.OBJECT, 0),
+ new CfIf(IfType.NE, ValueType.OBJECT, label2),
+ label1,
+ new CfNew(factory.createType("Ljava/lang/ClassCastException;")),
+ new CfStackInstruction(CfStackInstruction.Opcode.Dup),
+ new CfInvoke(
+ 183,
+ factory.createMethod(
+ factory.createType("Ljava/lang/ClassCastException;"),
+ factory.createProto(factory.voidType),
+ factory.createString("<init>")),
+ false),
+ new CfThrow(),
+ label2,
+ new CfFrame(
+ new Int2ObjectAVLTreeMap<>(
+ new int[] {0},
+ new FrameType[] {FrameType.initializedNonNullReference(factory.objectType)})),
+ new CfReturnVoid(),
+ label3),
+ ImmutableList.of(),
+ ImmutableList.of());
+ }
+
public static CfCode CfUtilityMethodsForCodeOptimizationsTemplates_throwIllegalAccessError(
DexItemFactory factory, DexMethod method) {
CfLabel label0 = new CfLabel();
diff --git a/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java b/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
index 449df92..648c0e6 100644
--- a/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
+++ b/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
@@ -95,6 +95,8 @@
generator.forSingleMethodWithGlobalMerging("ThrowCCEIfNotEquals");
public final SyntheticKind THROW_CCE_IF_NOT_NULL =
generator.forSingleMethodWithGlobalMerging("ThrowCCEIfNotNull");
+ public final SyntheticKind THROW_CCE_IF_NULL =
+ generator.forSingleMethodWithGlobalMerging("ThrowCCEIfNull");
public final SyntheticKind NON_NULL = generator.forSingleMethodWithGlobalMerging("NonNull");
public final SyntheticKind THROW_AME = generator.forSingleMethodWithGlobalMerging("ThrowAME");
public final SyntheticKind THROW_IAE = generator.forSingleMethodWithGlobalMerging("ThrowIAE");
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizationsTemplates.java b/src/test/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizationsTemplates.java
index de5de0e..3b67f93 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizationsTemplates.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizationsTemplates.java
@@ -22,6 +22,12 @@
}
}
+ public static void throwClassCastExceptionIfNull(Object o) {
+ if (o == null) {
+ throw new ClassCastException();
+ }
+ }
+
public static AbstractMethodError throwAbstractMethodError() {
throw new AbstractMethodError();
}
diff --git a/src/test/java/com/android/tools/r8/optimize/AtomicFieldUpdaterNullHolderTest.java b/src/test/java/com/android/tools/r8/optimize/AtomicFieldUpdaterNullHolderTest.java
new file mode 100644
index 0000000..ed75031
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/optimize/AtomicFieldUpdaterNullHolderTest.java
@@ -0,0 +1,136 @@
+// Copyright (c) 2026, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.Diagnostic;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestShrinkerBuilder;
+import com.android.tools.r8.ToolHelper.DexVm.Version;
+import com.android.tools.r8.utils.BooleanUtils;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.CodeMatchers;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import com.google.common.collect.ImmutableList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class AtomicFieldUpdaterNullHolderTest extends TestBase {
+
+ @Parameter(0)
+ public TestParameters parameters;
+
+ @Parameter(1)
+ public boolean dontObfuscate;
+
+ @Parameters(name = "{0}, dontObfuscate:{1}")
+ public static List<Object[]> data() {
+ // TODO(b/453628974): test all dex and api levels.
+ return buildParameters(
+ TestParameters.builder()
+ .withDexRuntimesStartingFromIncluding(
+ Version.V4_4_4) // Unsafe synthetic doesn't work for 4.0.4.
+ .withAllApiLevels()
+ .build(),
+ BooleanUtils.values());
+ }
+
+ @Test
+ public void testR8() throws Exception {
+ Class<TestClass> testClass = TestClass.class;
+ testForR8(parameters)
+ .addOptionsModification(
+ options -> {
+ assertFalse(options.enableAtomicFieldUpdaterOptimization);
+ options.enableAtomicFieldUpdaterOptimization = true;
+ assertFalse(options.testing.enableAtomicFieldUpdaterLogs);
+ options.testing.enableAtomicFieldUpdaterLogs = true;
+ })
+ .addProgramClasses(testClass)
+ .allowDiagnosticInfoMessages()
+ .addKeepMainRule(testClass)
+ .applyIf(dontObfuscate, TestShrinkerBuilder::addDontObfuscate)
+ .compile()
+ .inspectDiagnosticMessages(
+ diagnostics -> {
+ assertEquals(3, diagnostics.getInfos().size());
+ Diagnostic diagnostic = diagnostics.getInfos().get(0);
+ List<String> diagnosticLines =
+ StringUtils.splitLines(diagnostic.getDiagnosticMessage());
+ for (String message : diagnosticLines) {
+ assertTrue(
+ "Does not contain 'Can instrument': " + message,
+ message.contains("Can instrument"));
+ }
+ assertEquals(1, diagnosticLines.size());
+ diagnostic = diagnostics.getInfos().get(1);
+ diagnosticLines = StringUtils.splitLines(diagnostic.getDiagnosticMessage());
+ for (String message : diagnosticLines) {
+ assertTrue(
+ "Does not contain 'Can optimize': " + message,
+ message.contains("Can optimize"));
+ }
+ assertEquals(1, diagnosticLines.size());
+ diagnostic = diagnostics.getInfos().get(2);
+ diagnosticLines = StringUtils.splitLines(diagnostic.getDiagnosticMessage());
+ for (String message : diagnosticLines) {
+ assertTrue(
+ "Does not contain 'Can remove': " + message, message.contains("Can remove"));
+ }
+ assertEquals(1, diagnosticLines.size());
+ })
+ .inspect(
+ inspector -> {
+ MethodSubject method = inspector.clazz(testClass).mainMethod();
+ assertThat(
+ method,
+ CodeMatchers.invokesMethod(
+ "java.lang.Object",
+ "sun.misc.Unsafe",
+ "getObjectVolatile",
+ ImmutableList.of("java.lang.Object", "long")));
+ })
+ .run(parameters.getRuntime(), testClass)
+ .assertFailureWithErrorThatThrows(ClassCastException.class);
+ }
+
+ // Corresponding to simple kotlin usage of `atomic("Hello")` via atomicfu.
+ public static class TestClass {
+
+ private volatile Object myString;
+
+ private static final AtomicReferenceFieldUpdater<TestClass, Object> myString$FU;
+
+ static {
+ myString$FU =
+ AtomicReferenceFieldUpdater.newUpdater(TestClass.class, Object.class, "myString");
+ }
+
+ public TestClass() {
+ super();
+ myString = "Hello";
+ }
+
+ public static void main(String[] args) {
+ TestClass holder;
+ if (System.out != null) {
+ holder = null;
+ } else {
+ holder = new TestClass();
+ }
+ System.out.println(myString$FU.get(holder));
+ }
+ }
+}