Insert class cast exceptions for null-holders of AtomicFieldUpdaters.

Bug: (b/453628974)
Change-Id: I6e3dc434482e4dc3ecda1901dc7af1b7f38c5688
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/passes/AtomicFieldUpdaterOptimizer.java b/src/main/java/com/android/tools/r8/ir/conversion/passes/AtomicFieldUpdaterOptimizer.java
index 1740878..13226e7 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/passes/AtomicFieldUpdaterOptimizer.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/passes/AtomicFieldUpdaterOptimizer.java
@@ -3,14 +3,17 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.ir.conversion.passes;
 
+import com.android.tools.r8.contexts.CompilationContext.MethodProcessingContext;
 import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.ir.analysis.type.TypeElement;
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.IRCodeInstructionListIterator;
 import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.ir.code.InvokeStatic;
 import com.android.tools.r8.ir.code.InvokeVirtual;
 import com.android.tools.r8.ir.code.Position;
 import com.android.tools.r8.ir.code.StaticGet;
@@ -18,6 +21,7 @@
 import com.android.tools.r8.ir.conversion.MethodProcessor;
 import com.android.tools.r8.ir.conversion.passes.result.CodeRewriterResult;
 import com.android.tools.r8.ir.optimize.AtomicFieldUpdaterInstrumentor.AtomicFieldUpdaterInstrumentorInfo;
+import com.android.tools.r8.ir.optimize.UtilityMethodsForCodeOptimizations;
 import com.android.tools.r8.utils.AndroidApiLevel;
 import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
@@ -81,7 +85,10 @@
   }
 
   @Override
-  protected CodeRewriterResult rewriteCode(IRCode code) {
+  protected CodeRewriterResult rewriteCode(
+      IRCode code,
+      MethodProcessor methodProcessor,
+      MethodProcessingContext methodProcessingContext) {
     AtomicFieldUpdaterInstrumentorInfo info = appView.getAtomicFieldUpdaterInstrumentorInfo();
     var atomicUpdaterFields = info.getInstrumentations().get(code.context().getHolderType());
     assert atomicUpdaterFields != null;
@@ -116,9 +123,11 @@
         if (visitCompareAndSet(
             code,
             it,
+            methodProcessor,
+            methodProcessingContext,
             invoke,
             atomicUpdaterFields,
-            info.getUnsafeInstanceField(),
+            info,
             next.outValue())) {
           changed = true;
         }
@@ -127,9 +136,11 @@
         if (visitGet(
             code,
             it,
+            methodProcessor,
+            methodProcessingContext,
             invoke,
             atomicUpdaterFields,
-            info.getUnsafeInstanceField(),
+            info,
             next.outValue())) {
           changed = true;
         }
@@ -138,9 +149,11 @@
         if (visitSet(
             code,
             it,
+            methodProcessor,
+            methodProcessingContext,
             invoke,
             atomicUpdaterFields,
-            info.getUnsafeInstanceField(),
+            info,
             next.outValue())) {
           changed = true;
         }
@@ -149,10 +162,12 @@
         if (visitGetAndSet(
             code,
             it,
+            methodProcessor,
+            methodProcessingContext,
             next.getPosition(),
             invoke,
             atomicUpdaterFields,
-            info.getUnsafeInstanceField(),
+            info,
             next.outValue())) {
           changed = true;
         }
@@ -167,9 +182,11 @@
   private boolean visitCompareAndSet(
       IRCode code,
       IRCodeInstructionListIterator it,
+      MethodProcessor methodProcessor,
+      MethodProcessingContext methodProcessingContext,
       InvokeVirtual invoke,
       Map<DexField, AtomicFieldUpdaterInfo> atomicUpdaterFields,
-      DexField unsafeInstanceField,
+      AtomicFieldUpdaterInstrumentorInfo info,
       Value outValue) {
     // Resolve updater.
     var updaterValue = invoke.getReceiver();
@@ -183,7 +200,9 @@
     // Resolve holder.
     var holderValue = invoke.getFirstNonReceiverArgument();
     var expectedHolder = resolvedUpdater.updaterFieldInfo.holder;
-    if (!isHolderValid(invoke.getPosition(), holderValue, expectedHolder, "compareAndSet")) {
+    var resolvedHolder =
+        resolveHolder(invoke.getPosition(), holderValue, expectedHolder, "compareAndSet");
+    if (resolvedHolder == null) {
       return false;
     }
 
@@ -200,12 +219,14 @@
     rewriteCompareAndSet(
         code,
         it,
+        methodProcessor,
+        methodProcessingContext,
         invoke.getPosition(),
-        resolvedUpdater.isNullable,
-        unsafeInstanceField,
+        resolvedUpdater,
+        resolvedHolder,
+        info,
         updaterValue,
         holderValue,
-        resolvedUpdater.updaterFieldInfo.offsetField,
         expectValue,
         updateValue,
         outValue);
@@ -215,9 +236,11 @@
   private boolean visitGet(
       IRCode code,
       IRCodeInstructionListIterator it,
+      MethodProcessor methodProcessor,
+      MethodProcessingContext methodProcessingContext,
       InvokeVirtual invoke,
       Map<DexField, AtomicFieldUpdaterInfo> atomicUpdaterFields,
-      DexField unsafeInstanceField,
+      AtomicFieldUpdaterInstrumentorInfo info,
       Value outValue) {
     // Resolve updater.
     var updaterValue = invoke.getReceiver();
@@ -230,7 +253,8 @@
     // Resolve holder.
     var holderValue = invoke.getFirstNonReceiverArgument();
     var expectedHolder = resolvedUpdater.updaterFieldInfo.holder;
-    if (!isHolderValid(invoke.getPosition(), holderValue, expectedHolder, "get")) {
+    var resolvedHolder = resolveHolder(invoke.getPosition(), holderValue, expectedHolder, "get");
+    if (resolvedHolder == null) {
       return false;
     }
 
@@ -238,12 +262,14 @@
     rewriteGet(
         code,
         it,
+        methodProcessor,
+        methodProcessingContext,
         invoke.getPosition(),
-        resolvedUpdater.isNullable,
-        unsafeInstanceField,
+        resolvedUpdater,
+        resolvedHolder,
+        info,
         updaterValue,
         holderValue,
-        resolvedUpdater.updaterFieldInfo.offsetField,
         outValue);
     return true;
   }
@@ -251,9 +277,11 @@
   private boolean visitSet(
       IRCode code,
       IRCodeInstructionListIterator it,
+      MethodProcessor methodProcessor,
+      MethodProcessingContext methodProcessingContext,
       InvokeVirtual invoke,
       Map<DexField, AtomicFieldUpdaterInfo> atomicUpdaterFields,
-      DexField unsafeInstanceField,
+      AtomicFieldUpdaterInstrumentorInfo info,
       Value outValue) {
     // Resolve updater.
     var updaterValue = invoke.getReceiver();
@@ -266,7 +294,8 @@
     // Resolve holder.
     var holderValue = invoke.getFirstNonReceiverArgument();
     var expectedHolder = resolvedUpdater.updaterFieldInfo.holder;
-    if (!isHolderValid(invoke.getPosition(), holderValue, expectedHolder, "set")) {
+    var resolvedHolder = resolveHolder(invoke.getPosition(), holderValue, expectedHolder, "set");
+    if (resolvedHolder == null) {
       return false;
     }
 
@@ -280,12 +309,14 @@
     rewriteSet(
         code,
         it,
+        methodProcessor,
+        methodProcessingContext,
         invoke.getPosition(),
-        resolvedUpdater.isNullable,
-        unsafeInstanceField,
+        resolvedUpdater,
+        resolvedHolder,
+        info,
         updaterValue,
         holderValue,
-        resolvedUpdater.updaterFieldInfo.offsetField,
         newValueValue,
         outValue);
     return true;
@@ -294,10 +325,12 @@
   private boolean visitGetAndSet(
       IRCode code,
       IRCodeInstructionListIterator it,
+      MethodProcessor methodProcessor,
+      MethodProcessingContext methodProcessingContext,
       Position position,
       InvokeVirtual invoke,
       Map<DexField, AtomicFieldUpdaterInfo> atomicUpdaterFields,
-      DexField unsafeInstanceField,
+      AtomicFieldUpdaterInstrumentorInfo info,
       Value outValue) {
     if (appView.options().isGeneratingDex()
         && appView.options().getMinApiLevel().isLessThan(AndroidApiLevel.N)) {
@@ -316,7 +349,9 @@
     // Resolve holder.
     var holderValue = invoke.getFirstNonReceiverArgument();
     var expectedHolder = resolvedUpdater.updaterFieldInfo.holder;
-    if (!isHolderValid(position, holderValue, expectedHolder, "getAndSet")) {
+    var resolvedHolder =
+        resolveHolder(invoke.getPosition(), holderValue, expectedHolder, "getAndSet");
+    if (resolvedHolder == null) {
       return false;
     }
 
@@ -330,12 +365,14 @@
     rewriteGetAndSet(
         code,
         it,
+        methodProcessor,
+        methodProcessingContext,
         position,
-        resolvedUpdater.isNullable,
-        unsafeInstanceField,
+        resolvedUpdater,
+        resolvedHolder,
+        info,
         updaterValue,
         holderValue,
-        resolvedUpdater.updaterFieldInfo.offsetField,
         newValueValue,
         outValue);
     return true;
@@ -384,21 +421,23 @@
     }
   }
 
-  private boolean isHolderValid(
+  private ResolvedHolder resolveHolder(
       Position position, Value holderValue, DexType expectedHolder, String methodNameForLogging) {
-    if (holderValue
-        .getType()
-        .lessThanOrEqual(expectedHolder.toNonNullTypeElement(appView), appView)) {
-      return true;
+    TypeElement holderType = holderValue.getType();
+    if (!holderType.lessThanOrEqual(expectedHolder.toTypeElement(appView), appView)) {
+      reportFailure(position, "_." + methodNameForLogging + "(HERE, ..) is a wrong type");
+      return null;
     }
-    if (appView.testing().enableAtomicFieldUpdaterLogs) {
-      if (holderValue.getType().lessThanOrEqual(expectedHolder.toTypeElement(appView), appView)) {
-        reportFailure(position, "_." + methodNameForLogging + "(HERE, ..) is nullable");
-      } else {
-        reportFailure(position, "_." + methodNameForLogging + "(HERE, ..) is of unexpected type");
-      }
+    var isNullable = holderType.isNullable();
+    return new ResolvedHolder(isNullable);
+  }
+
+  private static class ResolvedHolder {
+    public final boolean isNullable;
+
+    private ResolvedHolder(boolean isNullable) {
+      this.isNullable = isNullable;
     }
-    return false;
   }
 
   private boolean isNewValueValid(
@@ -423,25 +462,34 @@
   private void rewriteCompareAndSet(
       IRCode code,
       IRCodeInstructionListIterator it,
+      MethodProcessor methodProcessor,
+      MethodProcessingContext methodProcessingContext,
       Position position,
-      boolean updaterMightBeNull,
-      DexField unsafeInstanceField,
+      ResolvedUpdater resolvedUpdater,
+      ResolvedHolder resolvedHolder,
+      AtomicFieldUpdaterInstrumentorInfo info,
       Value updaterValue,
       Value holderValue,
-      DexField offsetField,
       Value expectValue,
       Value updateValue,
       Value outValue) {
-    var instructions = new ArrayList<Instruction>(3);
+    var instructions = new ArrayList<Instruction>(4);
 
-    if (updaterMightBeNull) {
+    if (resolvedUpdater.isNullable) {
       instructions.add(createNullCheck(code, position, updaterValue));
     }
 
-    Instruction unsafeInstance = createUnsafeGet(code, position, unsafeInstanceField);
+    if (resolvedHolder.isNullable) {
+      instructions.add(
+          createNullCheckWithClassCastException(
+              methodProcessor, methodProcessingContext, position, holderValue));
+    }
+
+    Instruction unsafeInstance = createUnsafeGet(code, position, info.getUnsafeInstanceField());
     instructions.add(unsafeInstance);
 
-    Instruction offset = createOffsetGet(code, position, offsetField);
+    Instruction offset =
+        createOffsetGet(code, position, resolvedUpdater.updaterFieldInfo.offsetField);
     instructions.add(offset);
 
     // Add instructions BEFORE the compareAndSet instruction.
@@ -481,26 +529,35 @@
   private void rewriteGet(
       IRCode code,
       IRCodeInstructionListIterator it,
+      MethodProcessor methodProcessor,
+      MethodProcessingContext methodProcessingContext,
       Position position,
-      boolean updaterMightBeNull,
-      DexField unsafeInstanceField,
+      ResolvedUpdater resolvedUpdater,
+      ResolvedHolder resolvedHolder,
+      AtomicFieldUpdaterInstrumentorInfo info,
       Value updaterValue,
       Value holderValue,
-      DexField offsetField,
       Value outValue) {
-    var instructions = new ArrayList<Instruction>(3);
+    var instructions = new ArrayList<Instruction>(4);
 
     // Null-check for updater.
-    if (updaterMightBeNull) {
+    if (resolvedUpdater.isNullable) {
       instructions.add(createNullCheck(code, position, updaterValue));
     }
 
+    if (resolvedHolder.isNullable) {
+      instructions.add(
+          createNullCheckWithClassCastException(
+              methodProcessor, methodProcessingContext, position, holderValue));
+    }
+
     // Get unsafe instance.
-    Instruction unsafeInstance = createUnsafeGet(code, position, unsafeInstanceField);
+    Instruction unsafeInstance = createUnsafeGet(code, position, info.getUnsafeInstanceField());
     instructions.add(unsafeInstance);
 
     // Get offset field.
-    Instruction offset = createOffsetGet(code, position, offsetField);
+    Instruction offset =
+        createOffsetGet(code, position, resolvedUpdater.updaterFieldInfo.offsetField);
     instructions.add(offset);
 
     // Add instructions BEFORE the get instruction.
@@ -532,27 +589,36 @@
   private void rewriteSet(
       IRCode code,
       IRCodeInstructionListIterator it,
+      MethodProcessor methodProcessor,
+      MethodProcessingContext methodProcessingContext,
       Position position,
-      boolean updaterMightBeNull,
-      DexField unsafeInstanceField,
+      ResolvedUpdater resolvedUpdater,
+      ResolvedHolder resolvedHolder,
+      AtomicFieldUpdaterInstrumentorInfo info,
       Value updaterValue,
       Value holderValue,
-      DexField offsetField,
       Value newValueValue,
       Value outValue) {
-    var instructions = new ArrayList<Instruction>(3);
+    var instructions = new ArrayList<Instruction>(4);
 
     // Null-check for updater.
-    if (updaterMightBeNull) {
+    if (resolvedUpdater.isNullable) {
       instructions.add(createNullCheck(code, position, updaterValue));
     }
 
+    if (resolvedHolder.isNullable) {
+      instructions.add(
+          createNullCheckWithClassCastException(
+              methodProcessor, methodProcessingContext, position, holderValue));
+    }
+
     // Get unsafe instance.
-    Instruction unsafeInstance = createUnsafeGet(code, position, unsafeInstanceField);
+    Instruction unsafeInstance = createUnsafeGet(code, position, info.getUnsafeInstanceField());
     instructions.add(unsafeInstance);
 
     // Get offset field.
-    Instruction offset = createOffsetGet(code, position, offsetField);
+    Instruction offset =
+        createOffsetGet(code, position, resolvedUpdater.updaterFieldInfo.offsetField);
     instructions.add(offset);
 
     // Add instructions BEFORE the get instruction.
@@ -588,27 +654,36 @@
   private void rewriteGetAndSet(
       IRCode code,
       IRCodeInstructionListIterator it,
+      MethodProcessor methodProcessor,
+      MethodProcessingContext methodProcessingContext,
       Position position,
-      boolean updaterMightBeNull,
-      DexField unsafeInstanceField,
+      ResolvedUpdater resolvedUpdater,
+      ResolvedHolder resolvedHolder,
+      AtomicFieldUpdaterInstrumentorInfo info,
       Value updaterValue,
       Value holderValue,
-      DexField offsetField,
       Value newValueValue,
       Value outValue) {
-    var instructions = new ArrayList<Instruction>(3);
+    var instructions = new ArrayList<Instruction>(4);
 
     // Null-check for updater.
-    if (updaterMightBeNull) {
+    if (resolvedUpdater.isNullable) {
       instructions.add(createNullCheck(code, position, updaterValue));
     }
 
+    if (resolvedHolder.isNullable) {
+      instructions.add(
+          createNullCheckWithClassCastException(
+              methodProcessor, methodProcessingContext, position, holderValue));
+    }
+
     // Get unsafe instance.
-    Instruction unsafeInstance = createUnsafeGet(code, position, unsafeInstanceField);
+    Instruction unsafeInstance = createUnsafeGet(code, position, info.getUnsafeInstanceField());
     instructions.add(unsafeInstance);
 
     // Get offset field.
-    Instruction offset = createOffsetGet(code, position, offsetField);
+    Instruction offset =
+        createOffsetGet(code, position, resolvedUpdater.updaterFieldInfo.offsetField);
     instructions.add(offset);
 
     // Add instructions BEFORE the get instruction.
@@ -644,16 +719,31 @@
     return offset;
   }
 
-  private InvokeVirtual createNullCheck(IRCode code, Position position, Value updaterValue) {
+  private InvokeVirtual createNullCheck(IRCode code, Position position, Value value) {
     var nullCheck =
         new InvokeVirtual(
             dexItemFactory.objectMembers.getClass,
             code.createValue(dexItemFactory.classType.toTypeElement(appView)),
-            ImmutableList.of(updaterValue));
+            ImmutableList.of(value));
     nullCheck.setPosition(position);
     return nullCheck;
   }
 
+  private InvokeStatic createNullCheckWithClassCastException(
+      MethodProcessor methodProcessor,
+      MethodProcessingContext methodProcessingContext,
+      Position position,
+      Value value) {
+    var optimizations =
+        UtilityMethodsForCodeOptimizations.synthesizeThrowClassCastExceptionIfNullMethod(
+            appView, methodProcessor.getEventConsumer(), methodProcessingContext);
+    optimizations.optimize(methodProcessor);
+    InvokeStatic invokeStatic =
+        new InvokeStatic(optimizations.getMethod().getReference(), null, ImmutableList.of(value));
+    invokeStatic.setPosition(position);
+    return invokeStatic;
+  }
+
   private Instruction createUnsafeGet(
       IRCode code, Position position, DexField unsafeInstanceField) {
     assert unsafeInstanceField.type.isIdenticalTo(dexItemFactory.unsafeType);
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/UtilityMethodsForCodeOptimizations.java b/src/main/java/com/android/tools/r8/ir/optimize/UtilityMethodsForCodeOptimizations.java
index 9d6c650..ff8f5b5 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/UtilityMethodsForCodeOptimizations.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/UtilityMethodsForCodeOptimizations.java
@@ -87,6 +87,34 @@
     return new UtilityMethodForCodeOptimizations(syntheticMethod);
   }
 
+  public static UtilityMethodForCodeOptimizations synthesizeThrowClassCastExceptionIfNullMethod(
+      AppView<?> appView,
+      UtilityMethodsForCodeOptimizationsEventConsumer eventConsumer,
+      MethodProcessingContext methodProcessingContext) {
+    DexItemFactory dexItemFactory = appView.dexItemFactory();
+    DexProto proto = dexItemFactory.createProto(dexItemFactory.voidType, dexItemFactory.objectType);
+    SyntheticItems syntheticItems = appView.getSyntheticItems();
+    UniqueContext positionContext = methodProcessingContext.createUniqueContext();
+    ProgramMethod syntheticMethod =
+        syntheticItems.createMethod(
+            kinds -> kinds.THROW_CCE_IF_NULL,
+            positionContext,
+            appView,
+            builder ->
+                builder
+                    .setAccessFlags(MethodAccessFlags.createPublicStaticSynthetic())
+                    .setClassFileVersion(CfVersion.V1_8)
+                    .setApiLevelForDefinition(appView.computedMinApiLevel())
+                    .setApiLevelForCode(appView.computedMinApiLevel())
+                    .setCode(
+                        method ->
+                            getThrowClassCastExceptionIfNullCodeTemplate(method, dexItemFactory))
+                    .setProto(proto));
+    eventConsumer.acceptUtilityThrowClassCastExceptionIfNotNullMethod(
+        syntheticMethod, methodProcessingContext.getMethodContext());
+    return new UtilityMethodForCodeOptimizations(syntheticMethod);
+  }
+
   private static CfCode getThrowClassCastExceptionIfNotNullCodeTemplate(
       DexMethod method, DexItemFactory dexItemFactory) {
     return CfUtilityMethodsForCodeOptimizations
@@ -94,6 +122,13 @@
             dexItemFactory, method);
   }
 
+  private static CfCode getThrowClassCastExceptionIfNullCodeTemplate(
+      DexMethod method, DexItemFactory dexItemFactory) {
+    return CfUtilityMethodsForCodeOptimizations
+        .CfUtilityMethodsForCodeOptimizationsTemplates_throwClassCastExceptionIfNull(
+            dexItemFactory, method);
+  }
+
   public static UtilityMethodForCodeOptimizations synthesizeNonNullMethod(
       AppView<?> appView,
       UtilityMethodsForCodeOptimizationsEventConsumer eventConsumer,
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizations.java b/src/main/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizations.java
index 831f03a..c144687 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizations.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizations.java
@@ -136,6 +136,42 @@
         ImmutableList.of());
   }
 
+  public static CfCode CfUtilityMethodsForCodeOptimizationsTemplates_throwClassCastExceptionIfNull(
+      DexItemFactory factory, DexMethod method) {
+    CfLabel label0 = new CfLabel();
+    CfLabel label1 = new CfLabel();
+    CfLabel label2 = new CfLabel();
+    CfLabel label3 = new CfLabel();
+    return new CfCode(
+        method.holder,
+        2,
+        1,
+        ImmutableList.of(
+            label0,
+            new CfLoad(ValueType.OBJECT, 0),
+            new CfIf(IfType.NE, ValueType.OBJECT, label2),
+            label1,
+            new CfNew(factory.createType("Ljava/lang/ClassCastException;")),
+            new CfStackInstruction(CfStackInstruction.Opcode.Dup),
+            new CfInvoke(
+                183,
+                factory.createMethod(
+                    factory.createType("Ljava/lang/ClassCastException;"),
+                    factory.createProto(factory.voidType),
+                    factory.createString("<init>")),
+                false),
+            new CfThrow(),
+            label2,
+            new CfFrame(
+                new Int2ObjectAVLTreeMap<>(
+                    new int[] {0},
+                    new FrameType[] {FrameType.initializedNonNullReference(factory.objectType)})),
+            new CfReturnVoid(),
+            label3),
+        ImmutableList.of(),
+        ImmutableList.of());
+  }
+
   public static CfCode CfUtilityMethodsForCodeOptimizationsTemplates_throwIllegalAccessError(
       DexItemFactory factory, DexMethod method) {
     CfLabel label0 = new CfLabel();
diff --git a/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java b/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
index 449df92..648c0e6 100644
--- a/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
+++ b/src/main/java/com/android/tools/r8/synthesis/SyntheticNaming.java
@@ -95,6 +95,8 @@
       generator.forSingleMethodWithGlobalMerging("ThrowCCEIfNotEquals");
   public final SyntheticKind THROW_CCE_IF_NOT_NULL =
       generator.forSingleMethodWithGlobalMerging("ThrowCCEIfNotNull");
+  public final SyntheticKind THROW_CCE_IF_NULL =
+      generator.forSingleMethodWithGlobalMerging("ThrowCCEIfNull");
   public final SyntheticKind NON_NULL = generator.forSingleMethodWithGlobalMerging("NonNull");
   public final SyntheticKind THROW_AME = generator.forSingleMethodWithGlobalMerging("ThrowAME");
   public final SyntheticKind THROW_IAE = generator.forSingleMethodWithGlobalMerging("ThrowIAE");
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizationsTemplates.java b/src/test/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizationsTemplates.java
index de5de0e..3b67f93 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizationsTemplates.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/templates/CfUtilityMethodsForCodeOptimizationsTemplates.java
@@ -22,6 +22,12 @@
     }
   }
 
+  public static void throwClassCastExceptionIfNull(Object o) {
+    if (o == null) {
+      throw new ClassCastException();
+    }
+  }
+
   public static AbstractMethodError throwAbstractMethodError() {
     throw new AbstractMethodError();
   }
diff --git a/src/test/java/com/android/tools/r8/optimize/AtomicFieldUpdaterNullHolderTest.java b/src/test/java/com/android/tools/r8/optimize/AtomicFieldUpdaterNullHolderTest.java
new file mode 100644
index 0000000..ed75031
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/optimize/AtomicFieldUpdaterNullHolderTest.java
@@ -0,0 +1,136 @@
+// Copyright (c) 2026, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.optimize;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.Diagnostic;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestShrinkerBuilder;
+import com.android.tools.r8.ToolHelper.DexVm.Version;
+import com.android.tools.r8.utils.BooleanUtils;
+import com.android.tools.r8.utils.StringUtils;
+import com.android.tools.r8.utils.codeinspector.CodeMatchers;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import com.google.common.collect.ImmutableList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class AtomicFieldUpdaterNullHolderTest extends TestBase {
+
+  @Parameter(0)
+  public TestParameters parameters;
+
+  @Parameter(1)
+  public boolean dontObfuscate;
+
+  @Parameters(name = "{0}, dontObfuscate:{1}")
+  public static List<Object[]> data() {
+    // TODO(b/453628974): test all dex and api levels.
+    return buildParameters(
+        TestParameters.builder()
+            .withDexRuntimesStartingFromIncluding(
+                Version.V4_4_4) // Unsafe synthetic doesn't work for 4.0.4.
+            .withAllApiLevels()
+            .build(),
+        BooleanUtils.values());
+  }
+
+  @Test
+  public void testR8() throws Exception {
+    Class<TestClass> testClass = TestClass.class;
+    testForR8(parameters)
+        .addOptionsModification(
+            options -> {
+              assertFalse(options.enableAtomicFieldUpdaterOptimization);
+              options.enableAtomicFieldUpdaterOptimization = true;
+              assertFalse(options.testing.enableAtomicFieldUpdaterLogs);
+              options.testing.enableAtomicFieldUpdaterLogs = true;
+            })
+        .addProgramClasses(testClass)
+        .allowDiagnosticInfoMessages()
+        .addKeepMainRule(testClass)
+        .applyIf(dontObfuscate, TestShrinkerBuilder::addDontObfuscate)
+        .compile()
+        .inspectDiagnosticMessages(
+            diagnostics -> {
+              assertEquals(3, diagnostics.getInfos().size());
+              Diagnostic diagnostic = diagnostics.getInfos().get(0);
+              List<String> diagnosticLines =
+                  StringUtils.splitLines(diagnostic.getDiagnosticMessage());
+              for (String message : diagnosticLines) {
+                assertTrue(
+                    "Does not contain 'Can instrument': " + message,
+                    message.contains("Can instrument"));
+              }
+              assertEquals(1, diagnosticLines.size());
+              diagnostic = diagnostics.getInfos().get(1);
+              diagnosticLines = StringUtils.splitLines(diagnostic.getDiagnosticMessage());
+              for (String message : diagnosticLines) {
+                assertTrue(
+                    "Does not contain 'Can optimize': " + message,
+                    message.contains("Can optimize"));
+              }
+              assertEquals(1, diagnosticLines.size());
+              diagnostic = diagnostics.getInfos().get(2);
+              diagnosticLines = StringUtils.splitLines(diagnostic.getDiagnosticMessage());
+              for (String message : diagnosticLines) {
+                assertTrue(
+                    "Does not contain 'Can remove': " + message, message.contains("Can remove"));
+              }
+              assertEquals(1, diagnosticLines.size());
+            })
+        .inspect(
+            inspector -> {
+              MethodSubject method = inspector.clazz(testClass).mainMethod();
+              assertThat(
+                  method,
+                  CodeMatchers.invokesMethod(
+                      "java.lang.Object",
+                      "sun.misc.Unsafe",
+                      "getObjectVolatile",
+                      ImmutableList.of("java.lang.Object", "long")));
+            })
+        .run(parameters.getRuntime(), testClass)
+        .assertFailureWithErrorThatThrows(ClassCastException.class);
+  }
+
+  // Corresponding to simple kotlin usage of `atomic("Hello")` via atomicfu.
+  public static class TestClass {
+
+    private volatile Object myString;
+
+    private static final AtomicReferenceFieldUpdater<TestClass, Object> myString$FU;
+
+    static {
+      myString$FU =
+          AtomicReferenceFieldUpdater.newUpdater(TestClass.class, Object.class, "myString");
+    }
+
+    public TestClass() {
+      super();
+      myString = "Hello";
+    }
+
+    public static void main(String[] args) {
+      TestClass holder;
+      if (System.out != null) {
+        holder = null;
+      } else {
+        holder = new TestClass();
+      }
+      System.out.println(myString$FU.get(holder));
+    }
+  }
+}