Extend member value propagation to const-class

Bug: 148136558
Change-Id: Ie3e98a85a5f97d7445f6d3ba931f68550d784a14
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/value/AbstractValue.java b/src/main/java/com/android/tools/r8/ir/analysis/value/AbstractValue.java
index c70ff33..007fd9f 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/value/AbstractValue.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/value/AbstractValue.java
@@ -24,6 +24,14 @@
     return null;
   }
 
+  public boolean isSingleConstClassValue() {
+    return false;
+  }
+
+  public SingleConstClassValue asSingleConstClassValue() {
+    return null;
+  }
+
   public boolean isSingleEnumValue() {
     return false;
   }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/value/AbstractValueFactory.java b/src/main/java/com/android/tools/r8/ir/analysis/value/AbstractValueFactory.java
index 887a171..80ca063 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/value/AbstractValueFactory.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/value/AbstractValueFactory.java
@@ -6,10 +6,13 @@
 
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexString;
+import com.android.tools.r8.graph.DexType;
 import java.util.concurrent.ConcurrentHashMap;
 
 public class AbstractValueFactory {
 
+  private ConcurrentHashMap<DexType, SingleConstClassValue> singleConstClassValues =
+      new ConcurrentHashMap<>();
   private ConcurrentHashMap<DexField, SingleEnumValue> singleEnumValues = new ConcurrentHashMap<>();
   private ConcurrentHashMap<DexField, SingleFieldValue> singleFieldValues =
       new ConcurrentHashMap<>();
@@ -17,6 +20,10 @@
   private ConcurrentHashMap<DexString, SingleStringValue> singleStringValues =
       new ConcurrentHashMap<>();
 
+  public SingleConstClassValue createSingleConstClassValue(DexType type) {
+    return singleConstClassValues.computeIfAbsent(type, SingleConstClassValue::new);
+  }
+
   public SingleEnumValue createSingleEnumValue(DexField field) {
     return singleEnumValues.computeIfAbsent(field, SingleEnumValue::new);
   }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/value/SingleConstClassValue.java b/src/main/java/com/android/tools/r8/ir/analysis/value/SingleConstClassValue.java
new file mode 100644
index 0000000..a6bf499
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/analysis/value/SingleConstClassValue.java
@@ -0,0 +1,91 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.analysis.value;
+
+import static com.android.tools.r8.ir.analysis.type.Nullability.definitelyNotNull;
+import static com.android.tools.r8.ir.analysis.type.TypeLatticeElement.classClassType;
+import static com.android.tools.r8.optimize.MemberRebindingAnalysis.isClassTypeVisibleFromContext;
+
+import com.android.tools.r8.graph.AppInfoWithSubtyping;
+import com.android.tools.r8.graph.AppView;
+import com.android.tools.r8.graph.DebugLocalInfo;
+import com.android.tools.r8.graph.DexClass;
+import com.android.tools.r8.graph.DexType;
+import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
+import com.android.tools.r8.ir.code.ConstClass;
+import com.android.tools.r8.ir.code.IRCode;
+import com.android.tools.r8.ir.code.Instruction;
+import com.android.tools.r8.ir.code.TypeAndLocalInfoSupplier;
+import com.android.tools.r8.ir.code.Value;
+
+public class SingleConstClassValue extends SingleValue {
+
+  private final DexType type;
+
+  /** Intentionally package private, use {@link AbstractValueFactory} instead. */
+  SingleConstClassValue(DexType type) {
+    this.type = type;
+  }
+
+  @Override
+  public boolean isSingleConstClassValue() {
+    return true;
+  }
+
+  @Override
+  public SingleConstClassValue asSingleConstClassValue() {
+    return this;
+  }
+
+  public DexType getType() {
+    return type;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    return this == o;
+  }
+
+  @Override
+  public int hashCode() {
+    return type.hashCode();
+  }
+
+  @Override
+  public String toString() {
+    return "SingleConstClassValue(" + type.toSourceString() + ")";
+  }
+
+  @Override
+  public Instruction createMaterializingInstruction(
+      AppView<? extends AppInfoWithSubtyping> appView, IRCode code, TypeAndLocalInfoSupplier info) {
+    TypeLatticeElement typeLattice = info.getTypeLattice();
+    DebugLocalInfo debugLocalInfo = info.getLocalInfo();
+    assert typeLattice.isClassType();
+    assert appView
+        .isSubtype(
+            appView.dexItemFactory().classType,
+            typeLattice.asClassTypeLatticeElement().getClassType())
+        .isTrue();
+    Value returnedValue =
+        code.createValue(classClassType(appView, definitelyNotNull()), debugLocalInfo);
+    ConstClass instruction = new ConstClass(returnedValue, type);
+    assert !instruction.instructionMayHaveSideEffects(appView, code.method.method.holder);
+    return instruction;
+  }
+
+  @Override
+  public boolean isMaterializableInContext(AppView<?> appView, DexType context) {
+    DexType baseType = type.toBaseType(appView.dexItemFactory());
+    if (baseType.isClassType()) {
+      DexClass clazz = appView.definitionFor(type);
+      return clazz != null
+          && clazz.isResolvable(appView)
+          && isClassTypeVisibleFromContext(appView, context, clazz);
+    }
+    assert baseType.isPrimitiveType();
+    return true;
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/code/ConstClass.java b/src/main/java/com/android/tools/r8/ir/code/ConstClass.java
index f58d657..4a58946 100644
--- a/src/main/java/com/android/tools/r8/ir/code/ConstClass.java
+++ b/src/main/java/com/android/tools/r8/ir/code/ConstClass.java
@@ -16,6 +16,8 @@
 import com.android.tools.r8.ir.analysis.AbstractError;
 import com.android.tools.r8.ir.analysis.type.Nullability;
 import com.android.tools.r8.ir.analysis.type.TypeLatticeElement;
+import com.android.tools.r8.ir.analysis.value.AbstractValue;
+import com.android.tools.r8.ir.analysis.value.UnknownValue;
 import com.android.tools.r8.ir.conversion.CfBuilder;
 import com.android.tools.r8.ir.conversion.DexBuilder;
 import com.android.tools.r8.ir.optimize.Inliner.ConstraintWithTarget;
@@ -187,4 +189,12 @@
   public void buildCf(CfBuilder builder) {
     builder.add(new CfConstClass(clazz));
   }
+
+  @Override
+  public AbstractValue getAbstractValue(AppView<?> appView, DexType context) {
+    if (!instructionMayHaveSideEffects(appView, context)) {
+      return appView.abstractValueFactory().createSingleConstClassValue(context);
+    }
+    return UnknownValue.getInstance();
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 09c260c..2f5df0d 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -2475,7 +2475,7 @@
           } else {
             DexType context = code.method.method.holder;
             AbstractValue abstractValue = lhs.getAbstractValue(appView, context);
-            if (abstractValue.isSingleEnumValue()) {
+            if (abstractValue.isSingleConstClassValue() || abstractValue.isSingleFieldValue()) {
               AbstractValue otherAbstractValue = rhs.getAbstractValue(appView, context);
               if (abstractValue == otherAbstractValue) {
                 simplifyIfWithKnownCondition(code, block, theIf, 0);
diff --git a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
index 39ac985..db80105 100644
--- a/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
+++ b/src/test/java/com/android/tools/r8/R8RunExamplesAndroidOTest.java
@@ -102,7 +102,7 @@
         .withOptionConsumer(opts -> opts.enableClassInlining = false)
         .withBuilderTransformation(
             b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
-        .withDexCheck(inspector -> checkLambdaCount(inspector, 119, "lambdadesugaring"))
+        .withDexCheck(inspector -> checkLambdaCount(inspector, 118, "lambdadesugaring"))
         .run();
 
     test("lambdadesugaring", "lambdadesugaring", "LambdaDesugaring")
@@ -110,7 +110,7 @@
         .withOptionConsumer(opts -> opts.enableClassInlining = true)
         .withBuilderTransformation(
             b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
-        .withDexCheck(inspector -> checkLambdaCount(inspector, 8, "lambdadesugaring"))
+        .withDexCheck(inspector -> checkLambdaCount(inspector, 7, "lambdadesugaring"))
         .run();
   }
 
@@ -142,7 +142,7 @@
         .withOptionConsumer(opts -> opts.enableClassInlining = false)
         .withBuilderTransformation(
             b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
-        .withDexCheck(inspector -> checkLambdaCount(inspector, 119, "lambdadesugaring"))
+        .withDexCheck(inspector -> checkLambdaCount(inspector, 118, "lambdadesugaring"))
         .run();
 
     test("lambdadesugaring", "lambdadesugaring", "LambdaDesugaring")
@@ -150,7 +150,7 @@
         .withOptionConsumer(opts -> opts.enableClassInlining = true)
         .withBuilderTransformation(
             b -> b.addProguardConfiguration(PROGUARD_OPTIONS, Origin.unknown()))
-        .withDexCheck(inspector -> checkLambdaCount(inspector, 8, "lambdadesugaring"))
+        .withDexCheck(inspector -> checkLambdaCount(inspector, 7, "lambdadesugaring"))
         .run();
   }
 
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/membervaluepropagation/ConstClassMemberValuePropagationTest.java b/src/test/java/com/android/tools/r8/ir/optimize/membervaluepropagation/ConstClassMemberValuePropagationTest.java
new file mode 100644
index 0000000..af4af19
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/membervaluepropagation/ConstClassMemberValuePropagationTest.java
@@ -0,0 +1,98 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.optimize.membervaluepropagation;
+
+import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertTrue;
+
+import com.android.tools.r8.NeverInline;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import com.android.tools.r8.utils.codeinspector.ClassSubject;
+import com.android.tools.r8.utils.codeinspector.CodeInspector;
+import com.android.tools.r8.utils.codeinspector.InstructionSubject;
+import com.android.tools.r8.utils.codeinspector.MethodSubject;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class ConstClassMemberValuePropagationTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  public ConstClassMemberValuePropagationTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void test() throws Exception {
+    testForR8(parameters.getBackend())
+        .addInnerClasses(ConstClassMemberValuePropagationTest.class)
+        .addKeepMainRule(TestClass.class)
+        .enableInliningAnnotations()
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .inspect(this::inspect)
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutputLines("Hello world!");
+  }
+
+  private void inspect(CodeInspector inspector) {
+    ClassSubject testClassSubject = inspector.clazz(TestClass.class);
+    assertThat(testClassSubject, isPresent());
+    assertThat(
+        testClassSubject.uniqueMethodWithName("deadDueToFieldValuePropagation"), not(isPresent()));
+    assertThat(
+        testClassSubject.uniqueMethodWithName("deadDueToReturnValuePropagation"), not(isPresent()));
+
+    // Verify that there are no more conditional instructions.
+    MethodSubject mainMethodSubject = testClassSubject.mainMethod();
+    assertThat(mainMethodSubject, isPresent());
+    assertTrue(mainMethodSubject.streamInstructions().noneMatch(InstructionSubject::isIf));
+  }
+
+  static class TestClass {
+
+    static Class<?> INSTANCE = TestClass.class;
+
+    public static void main(String[] args) {
+      if (INSTANCE == TestClass.class) {
+        System.out.print("Hello");
+      } else {
+        deadDueToFieldValuePropagation();
+      }
+      if (get() == TestClass.class) {
+        System.out.println(" world!");
+      } else {
+        deadDueToReturnValuePropagation();
+      }
+    }
+
+    @NeverInline
+    static Class<?> get() {
+      return TestClass.class;
+    }
+
+    @NeverInline
+    static void deadDueToFieldValuePropagation() {
+      throw new RuntimeException();
+    }
+
+    @NeverInline
+    static void deadDueToReturnValuePropagation() {
+      throw new RuntimeException();
+    }
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/reflection/GetClassTest.java b/src/test/java/com/android/tools/r8/ir/optimize/reflection/GetClassTest.java
index cf4a324..c2fa7fb 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/reflection/GetClassTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/reflection/GetClassTest.java
@@ -4,6 +4,7 @@
 package com.android.tools.r8.ir.optimize.reflection;
 
 import static com.android.tools.r8.utils.codeinspector.Matchers.isPresent;
+import static org.hamcrest.CoreMatchers.not;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assume.assumeTrue;
@@ -155,7 +156,7 @@
     assertThat(mainMethod, isPresent());
     int expectedCount = isR8 ? (isRelease ? 0 : 5) : 6;
     assertEquals(expectedCount, countGetClass(mainMethod));
-    expectedCount = isR8 ? (isRelease ? (parameters.isCfRuntime() ? 7 : 5) : 1) : 0;
+    expectedCount = isR8 ? (isRelease ? (parameters.isCfRuntime() ? 8 : 6) : 1) : 0;
     assertEquals(expectedCount, countConstClass(mainMethod));
 
     boolean expectToBeOptimized = isR8 && isRelease;
@@ -167,10 +168,14 @@
     assertEquals(0, countConstClass(getMainClass));
 
     MethodSubject call = mainClass.method("java.lang.Class", "call", ImmutableList.of());
-    assertThat(call, isPresent());
-    // Because of local, only R8 release mode can rewrite getClass() to const-class.
-    assertEquals(expectToBeOptimized ? 0 : 1, countGetClass(call));
-    assertEquals(expectToBeOptimized ? 1 : 0, countConstClass(call));
+    if (isR8 && isRelease) {
+      assertThat(call, not(isPresent()));
+    } else {
+      assertThat(call, isPresent());
+      // Because of local, only R8 release mode can rewrite getClass() to const-class.
+      assertEquals(expectToBeOptimized ? 0 : 1, countGetClass(call));
+      assertEquals(expectToBeOptimized ? 1 : 0, countConstClass(call));
+    }
   }
 
   @Test