Performing if-zero-test optimization after class inlining

This optimization is particularly interesting for kotlin code
since after inlining methods null-checking their parameters
most (if not all) of the checks are expected to go away since
Kotlin guarantees passing non-null values to such methods.

Bug: 111262189
Bug: 80134059
Change-Id: Ic26d8ec6d7ebb2573a3bca3a2f68cd73a4947ee2
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
index c6ba97d..63843ee 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/IRConverter.java
@@ -739,7 +739,7 @@
       assert options.enableInlining && inliner != null;
       TypeEnvironment effectivelyFinalTypeEnvironment = typeEnvironment;
       classInliner.processMethodCode(
-          appInfo.withLiveness(), method, code, isProcessedConcurrently,
+          appInfo.withLiveness(), codeRewriter, method, code, isProcessedConcurrently,
           methodsToInline -> inliner.performForcedInlining(method, code, methodsToInline),
           Suppliers.memoize(() -> inliner.createDefaultOracle(
               method, code, effectivelyFinalTypeEnvironment,
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
index 121f47d..94e8147 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/CodeRewriter.java
@@ -2362,7 +2362,7 @@
             && (theIf.getType() == Type.EQ || theIf.getType() == Type.NE)) {
           if (inValues.get(0).isNeverNull()) {
             simplifyIfWithKnownCondition(code, block, theIf, 1);
-          } else {
+          } else if (typeEnvironment != null) {
             // TODO(b/72693244): annotate type lattice to value
             TypeLatticeElement l = typeEnvironment.getLatticeElement(inValues.get(0));
             if (!l.isPrimitive() && !l.isNullable()) {
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/classinliner/ClassInliner.java b/src/main/java/com/android/tools/r8/ir/optimize/classinliner/ClassInliner.java
index 219b5b1..5bf2c74 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/classinliner/ClassInliner.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/classinliner/ClassInliner.java
@@ -12,6 +12,7 @@
 import com.android.tools.r8.ir.code.IRCode;
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.ir.code.InvokeMethod;
+import com.android.tools.r8.ir.optimize.CodeRewriter;
 import com.android.tools.r8.ir.optimize.Inliner.InliningInfo;
 import com.android.tools.r8.ir.optimize.InliningOracle;
 import com.android.tools.r8.shaking.Enqueuer.AppInfoWithLiveness;
@@ -117,6 +118,7 @@
   //
   public final void processMethodCode(
       AppInfoWithLiveness appInfo,
+      CodeRewriter codeRewriter,
       DexEncodedMethod method,
       IRCode code,
       Predicate<DexEncodedMethod> isProcessedConcurrently,
@@ -164,11 +166,14 @@
         }
 
         // Inline the class instance.
-        processor.processInlining(code, inliner);
+        boolean anyInlinedMethods = processor.processInlining(code, inliner);
 
         // Restore normality.
         code.removeAllTrivialPhis();
         assert code.isConsistentSSA();
+        if (anyInlinedMethods) {
+          codeRewriter.simplifyIf(code, null);
+        }
         rootsIterator.remove();
         repeat = true;
       }
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/classinliner/InlineCandidateProcessor.java b/src/main/java/com/android/tools/r8/ir/optimize/classinliner/InlineCandidateProcessor.java
index b3bde90..57ce42f 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/classinliner/InlineCandidateProcessor.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/classinliner/InlineCandidateProcessor.java
@@ -271,14 +271,16 @@
   //  * remove field writes
   //  * remove root instruction
   //
-  void processInlining(IRCode code, InlinerAction inliner) {
+  // Returns `true` if at least one method was inlined.
+  boolean processInlining(IRCode code, InlinerAction inliner) {
     replaceUsagesAsUnusedArgument(code);
-    forceInlineExtraMethodInvocations(inliner);
-    forceInlineDirectMethodInvocations(inliner);
+    boolean anyInlinedMethods = forceInlineExtraMethodInvocations(inliner);
+    anyInlinedMethods |= forceInlineDirectMethodInvocations(inliner);
     removeMiscUsages(code);
     removeFieldReads(code);
     removeFieldWrites();
     removeInstruction(root);
+    return anyInlinedMethods;
   }
 
   private void replaceUsagesAsUnusedArgument(IRCode code) {
@@ -298,9 +300,9 @@
     unusedArguments.clear();
   }
 
-  private void forceInlineExtraMethodInvocations(InlinerAction inliner) {
+  private boolean forceInlineExtraMethodInvocations(InlinerAction inliner) {
     if (extraMethodCalls.isEmpty()) {
-      return;
+      return false;
     }
 
     // Inline extra methods.
@@ -320,12 +322,15 @@
     }
     assert extraMethodCalls.isEmpty();
     assert unusedArguments.isEmpty();
+    return true;
   }
 
-  private void forceInlineDirectMethodInvocations(InlinerAction inliner) {
-    if (!methodCallsOnInstance.isEmpty()) {
-      inliner.inline(methodCallsOnInstance);
+  private boolean forceInlineDirectMethodInvocations(InlinerAction inliner) {
+    if (methodCallsOnInstance.isEmpty()) {
+      return false;
     }
+    inliner.inline(methodCallsOnInstance);
+    return true;
   }
 
   // Remove miscellaneous users before handling field reads.
diff --git a/src/test/java/com/android/tools/r8/kotlin/KotlinClassInlinerTest.java b/src/test/java/com/android/tools/r8/kotlin/KotlinClassInlinerTest.java
index 81e12a1..dc2774a 100644
--- a/src/test/java/com/android/tools/r8/kotlin/KotlinClassInlinerTest.java
+++ b/src/test/java/com/android/tools/r8/kotlin/KotlinClassInlinerTest.java
@@ -10,6 +10,7 @@
 import static org.junit.Assert.assertTrue;
 
 import com.android.tools.r8.ToolHelper.KotlinTargetVersion;
+import com.android.tools.r8.code.InvokeStatic;
 import com.android.tools.r8.code.NewInstance;
 import com.android.tools.r8.code.SgetObject;
 import com.android.tools.r8.graph.DexClass;
@@ -19,8 +20,10 @@
 import com.android.tools.r8.utils.DexInspector;
 import com.android.tools.r8.utils.DexInspector.ClassSubject;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 import java.util.Collection;
+import java.util.List;
 import java.util.Set;
 import java.util.function.Predicate;
 import java.util.stream.Collectors;
@@ -80,25 +83,25 @@
 
       assertEquals(
           Sets.newHashSet(),
-          collectAccessedLambdaTypes(lambdaCheck, clazz, "testStateless"));
+          collectAccessedTypes(lambdaCheck, clazz, "testStateless"));
 
       assertEquals(
           Sets.newHashSet(),
-          collectAccessedLambdaTypes(lambdaCheck, clazz, "testStateful"));
+          collectAccessedTypes(lambdaCheck, clazz, "testStateful"));
 
       assertFalse(
           inspector.clazz("class_inliner_lambda_j_style.MainKt$testStateful$1").isPresent());
 
       assertEquals(
           Sets.newHashSet(),
-          collectAccessedLambdaTypes(lambdaCheck, clazz, "testStateful2"));
+          collectAccessedTypes(lambdaCheck, clazz, "testStateful2"));
 
       assertFalse(
           inspector.clazz("class_inliner_lambda_j_style.MainKt$testStateful2$1").isPresent());
 
       assertEquals(
           Sets.newHashSet(),
-          collectAccessedLambdaTypes(lambdaCheck, clazz, "testStateful3"));
+          collectAccessedTypes(lambdaCheck, clazz, "testStateful3"));
 
       assertFalse(
           inspector.clazz("class_inliner_lambda_j_style.MainKt$testStateful3$1").isPresent());
@@ -138,7 +141,7 @@
 
       assertEquals(
           Sets.newHashSet(),
-          collectAccessedLambdaTypes(lambdaCheck, clazz,
+          collectAccessedTypes(lambdaCheck, clazz,
               "testKotlinSequencesStateless", "kotlin.sequences.Sequence"));
 
       assertFalse(inspector.clazz(
@@ -146,7 +149,7 @@
 
       assertEquals(
           Sets.newHashSet(),
-          collectAccessedLambdaTypes(lambdaCheck, clazz,
+          collectAccessedTypes(lambdaCheck, clazz,
               "testKotlinSequencesStateful", "int", "int", "kotlin.sequences.Sequence"));
 
       assertFalse(inspector.clazz(
@@ -154,7 +157,7 @@
 
       assertEquals(
           Sets.newHashSet(),
-          collectAccessedLambdaTypes(lambdaCheck, clazz, "testBigExtraMethod"));
+          collectAccessedTypes(lambdaCheck, clazz, "testBigExtraMethod"));
 
       assertFalse(inspector.clazz(
           "class_inliner_lambda_k_style.MainKt$testBigExtraMethod$1").isPresent());
@@ -165,7 +168,7 @@
 
       assertEquals(
           Sets.newHashSet(),
-          collectAccessedLambdaTypes(lambdaCheck, clazz, "testBigExtraMethodReturningLambda"));
+          collectAccessedTypes(lambdaCheck, clazz, "testBigExtraMethodReturningLambda"));
 
       assertFalse(inspector.clazz(
           "class_inliner_lambda_k_style.MainKt$testBigExtraMethodReturningLambda$1")
@@ -179,8 +182,25 @@
     });
   }
 
-  private Set<String> collectAccessedLambdaTypes(
-      Predicate<DexType> isLambdaType, ClassSubject clazz, String methodName, String... params) {
+  @Test
+  public void testDataClass() throws Exception {
+    final String mainClassName = "class_inliner_data_class.MainKt";
+    runTest("class_inliner_data_class", mainClassName, true, (app) -> {
+      DexInspector inspector = new DexInspector(app);
+      ClassSubject clazz = inspector.clazz(mainClassName);
+      assertTrue(collectAccessedTypes(
+          type -> !type.toSourceString().startsWith("java."),
+          clazz, "main", String[].class.getCanonicalName()).isEmpty());
+      assertEquals(
+          Lists.newArrayList(
+              "void kotlin.jvm.internal.Intrinsics.throwParameterIsNullException(java.lang.String)"
+          ),
+          collectStaticCalls(clazz, "main", String[].class.getCanonicalName()));
+    });
+  }
+
+  private Set<String> collectAccessedTypes(Predicate<DexType> isTypeOfInterest,
+      ClassSubject clazz, String methodName, String... params) {
     assertNotNull(clazz);
     MethodSignature signature = new MethodSignature(methodName, "void", params);
     DexCode code = clazz.method(signature).getMethod().getCode().asDexCode();
@@ -190,7 +210,7 @@
         filterInstructionKind(code, SgetObject.class)
             .map(insn -> insn.getField().getHolder())
     )
-        .filter(isLambdaType)
+        .filter(isTypeOfInterest)
         .map(DexType::toSourceString)
         .collect(Collectors.toSet());
   }
@@ -205,4 +225,14 @@
           options.enableLambdaMerging = false;
         }, inspector);
   }
+
+  private List<String> collectStaticCalls(ClassSubject clazz, String methodName, String... params) {
+    assertNotNull(clazz);
+    MethodSignature signature = new MethodSignature(methodName, "void", params);
+    DexCode code = clazz.method(signature).getMethod().getCode().asDexCode();
+    return filterInstructionKind(code, InvokeStatic.class)
+        .map(insn -> insn.getMethod().toSourceString())
+        .sorted()
+        .collect(Collectors.toList());
+  }
 }
diff --git a/src/test/kotlinR8TestResources/class_inliner_data_class/main.kt b/src/test/kotlinR8TestResources/class_inliner_data_class/main.kt
new file mode 100644
index 0000000..c8ef79a
--- /dev/null
+++ b/src/test/kotlinR8TestResources/class_inliner_data_class/main.kt
@@ -0,0 +1,21 @@
+// Copyright (c) 2018, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package class_inliner_data_class
+
+fun main(args: Array<String>) {
+    val alpha = Alpha("", "m", "")
+    alpha.right = "l"
+    alpha.left = "r"
+    alpha.rotate()
+    println("result: ${alpha.toString()}")
+}
+
+data class Alpha(var left: String, val middle: String, var right: String) {
+    fun rotate() {
+        val t = left
+        left = right
+        right = t
+    }
+}