Extend staticizer to support for singleton getter.

Bug: 111832046
Change-Id: Id3bd068ceba1aa4ab08844d54cc4193ba80b55fb
diff --git a/src/main/java/com/android/tools/r8/ir/optimize/staticizer/ClassStaticizer.java b/src/main/java/com/android/tools/r8/ir/optimize/staticizer/ClassStaticizer.java
index c8e9133..8682fc1 100644
--- a/src/main/java/com/android/tools/r8/ir/optimize/staticizer/ClassStaticizer.java
+++ b/src/main/java/com/android/tools/r8/ir/optimize/staticizer/ClassStaticizer.java
@@ -21,6 +21,7 @@
 import com.android.tools.r8.ir.code.Instruction;
 import com.android.tools.r8.ir.code.InvokeDirect;
 import com.android.tools.r8.ir.code.InvokeMethodWithReceiver;
+import com.android.tools.r8.ir.code.InvokeStatic;
 import com.android.tools.r8.ir.code.NewInstance;
 import com.android.tools.r8.ir.code.StaticGet;
 import com.android.tools.r8.ir.code.StaticPut;
@@ -31,6 +32,7 @@
 import com.android.tools.r8.utils.ListUtils;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
@@ -63,6 +65,7 @@
     final AtomicInteger instancesCreated = new AtomicInteger();
     final Set<DexEncodedMethod> referencedFrom = Sets.newConcurrentHashSet();
     final AtomicReference<DexEncodedMethod> constructor = new AtomicReference<>();
+    final AtomicReference<DexEncodedMethod> getter = new AtomicReference<>();
 
     CandidateInfo(DexProgramClass candidate, DexEncodedField singletonField) {
       assert candidate != null;
@@ -131,11 +134,9 @@
                 notEligible.add(field.field.type);
               }
 
-              // Let's also assume no methods should take or return a
-              // value of this type.
+              // Don't allow methods that take a value of this type.
               for (DexEncodedMethod method : cls.methods()) {
                 DexProto proto = method.method.proto;
-                notEligible.add(proto.returnType);
                 notEligible.addAll(Arrays.asList(proto.parameters.values));
               }
 
@@ -213,20 +214,46 @@
 
     CandidateInfo receiverClassCandidateInfo = candidates.get(method.method.holder);
     Value receiverValue = code.getThis(); // NOTE: is null for static methods.
-    if (receiverClassCandidateInfo != null && receiverValue != null) {
-      // We are inside an instance method of candidate class (not an instance initializer
-      // which we will check later), check if all the references to 'this' are valid
-      // (the call will invalidate the candidate if some of them are not valid).
-      analyzeAllValueUsers(
-          receiverClassCandidateInfo, receiverValue, factory.isConstructor(method.method));
+    if (receiverClassCandidateInfo != null) {
+      if (receiverValue != null) {
+        // We are inside an instance method of candidate class (not an instance initializer
+        // which we will check later), check if all the references to 'this' are valid
+        // (the call will invalidate the candidate if some of them are not valid).
+        analyzeAllValueUsers(
+            receiverClassCandidateInfo, receiverValue, factory.isConstructor(method.method));
 
-      // If the candidate is still valid, ignore all instructions
-      // we treat as valid usages on receiver.
-      if (candidates.get(method.method.holder) != null) {
-        alreadyProcessed.addAll(receiverValue.uniqueUsers());
+        // If the candidate is still valid, ignore all instructions
+        // we treat as valid usages on receiver.
+        if (candidates.get(method.method.holder) != null) {
+          alreadyProcessed.addAll(receiverValue.uniqueUsers());
+        }
+      } else {
+        // We are inside a static method of candidate class.
+        // Check if this is a valid getter of the singleton field.
+        if (method.method.proto.returnType == method.method.holder) {
+          List<Instruction> examined = isValidGetter(receiverClassCandidateInfo, code);
+          if (examined != null) {
+            DexEncodedMethod getter = receiverClassCandidateInfo.getter.get();
+            if (getter == null) {
+              receiverClassCandidateInfo.getter.set(method);
+              // Except for static-get and return, iterate other remaining instructions if any.
+              alreadyProcessed.addAll(examined);
+            } else {
+              assert getter != method;
+              // Not sure how to deal with many getters.
+              receiverClassCandidateInfo.invalidate();
+            }
+          } else {
+            // Invalidate the candidate if it has a static method whose return type is a candidate
+            // type but doesn't return the singleton field (in a trivial way).
+            receiverClassCandidateInfo.invalidate();
+          }
+        }
       }
     }
 
+    // TODO(b/143375203): if fully implemented, the following iterator could be:
+    //   InstructionListIterator iterator = code.instructionListIterator();
     ListIterator<Instruction> iterator =
         Lists.newArrayList(code.instructionIterator()).listIterator();
     while (iterator.hasNext()) {
@@ -273,7 +300,21 @@
         CandidateInfo info = processStaticFieldRead(instruction.asStaticGet());
         if (info != null) {
           info.referencedFrom.add(method);
-          // If the candidate still valid, ignore all usages in further analysis.
+          // If the candidate is still valid, ignore all usages in further analysis.
+          Value value = instruction.outValue();
+          if (value != null) {
+            alreadyProcessed.addAll(value.aliasedUsers());
+          }
+        }
+        continue;
+      }
+
+      if (instruction.isInvokeStatic()) {
+        // Check if it is a static singleton getter.
+        CandidateInfo info = processInvokeStatic(instruction.asInvokeStatic());
+        if (info != null) {
+          info.referencedFrom.add(method);
+          // If the candidate is still valid, ignore all usages in further analysis.
           Value value = instruction.outValue();
           if (value != null) {
             alreadyProcessed.addAll(value.aliasedUsers());
@@ -450,6 +491,41 @@
     return fieldAccessed == info.singletonField;
   }
 
+  // Only allow a very trivial pattern: load the singleton field and return it, which looks like:
+  //
+  //   v <- static-get singleton-field
+  //   <assume instructions on v> // (optional)
+  //   return v // or aliased value
+  //
+  // Returns a list of instructions that are examined (as long as the method is a trivial getter).
+  private List<Instruction> isValidGetter(CandidateInfo info, IRCode code) {
+    List<Instruction> instructions = new ArrayList<>();
+    StaticGet staticGet = null;
+    for (Instruction instr : code.instructions()) {
+      if (instr.isStaticGet()) {
+        staticGet = instr.asStaticGet();
+        DexEncodedField fieldAccessed =
+            appView.appInfo().lookupStaticTarget(staticGet.getField().holder, staticGet.getField());
+        if (fieldAccessed != info.singletonField) {
+          return null;
+        }
+        instructions.add(instr);
+        continue;
+      }
+      if (instr.isAssume() || instr.isReturn()) {
+        Value v = instr.inValues().get(0).getAliasedValue();
+        if (v.isPhi() || v.definition != staticGet) {
+          return null;
+        }
+        instructions.add(instr);
+        continue;
+      }
+      // All other instructions are not allowed.
+      return null;
+    }
+    return instructions;
+  }
+
   // Static field get: can be a valid singleton field for a
   // candidate in which case we should check if all the usages of the
   // value read are eligible.
@@ -471,6 +547,23 @@
     return candidateInfo;
   }
 
+  // Static getter: if this invokes a registered getter, treat it as static field get.
+  // That is, we should check if all the usages of the out value are eligible.
+  private CandidateInfo processInvokeStatic(InvokeStatic invoke) {
+    DexType candidateType = invoke.getInvokedMethod().proto.returnType;
+    CandidateInfo candidateInfo = candidates.get(candidateType);
+    if (candidateInfo == null) {
+      return null;
+    }
+
+    if (invoke.hasOutValue()
+        && candidateInfo.getter.get() != null
+        && candidateInfo.getter.get().method == invoke.getInvokedMethod()) {
+      candidateInfo = analyzeAllValueUsers(candidateInfo, invoke.outValue(), false);
+    }
+    return candidateInfo;
+  }
+
   private CandidateInfo analyzeAllValueUsers(
       CandidateInfo candidateInfo, Value value, boolean ignoreSuperClassInitInvoke) {
     assert value != null && value == value.getAliasedValue();
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/staticizer/ClassStaticizerTest.java b/src/test/java/com/android/tools/r8/ir/optimize/staticizer/ClassStaticizerTest.java
index 99d27fa..d0a4064 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/staticizer/ClassStaticizerTest.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/staticizer/ClassStaticizerTest.java
@@ -42,9 +42,11 @@
 import com.android.tools.r8.ir.optimize.staticizer.movetohost.MoveToHostTestClass;
 import com.android.tools.r8.ir.optimize.staticizer.trivial.Simple;
 import com.android.tools.r8.ir.optimize.staticizer.trivial.SimpleWithGetter;
+import com.android.tools.r8.ir.optimize.staticizer.trivial.SimpleWithLazyInit;
 import com.android.tools.r8.ir.optimize.staticizer.trivial.SimpleWithParams;
 import com.android.tools.r8.ir.optimize.staticizer.trivial.SimpleWithPhi;
 import com.android.tools.r8.ir.optimize.staticizer.trivial.SimpleWithSideEffects;
+import com.android.tools.r8.ir.optimize.staticizer.trivial.SimpleWithThrowingGetter;
 import com.android.tools.r8.ir.optimize.staticizer.trivial.TrivialTestClass;
 import com.android.tools.r8.naming.MemberNaming.MethodSignature;
 import com.android.tools.r8.utils.InternalOptions;
@@ -79,11 +81,13 @@
         NeverInline.class,
         TrivialTestClass.class,
         Simple.class,
-        SimpleWithSideEffects.class,
-        SimpleWithParams.class,
         SimpleWithGetter.class,
+        SimpleWithLazyInit.class,
+        SimpleWithParams.class,
         SimpleWithPhi.class,
-        SimpleWithPhi.Companion.class
+        SimpleWithPhi.Companion.class,
+        SimpleWithSideEffects.class,
+        SimpleWithThrowingGetter.class
     };
     String javaOutput = runOnJava(main);
     TestRunResult result =
@@ -92,7 +96,7 @@
             .enableInliningAnnotations()
             .addKeepMainRule(main)
             .noMinification()
-            .addKeepRules("-keepattributes InnerClasses,EnclosingMethod")
+            .addKeepAttributes("InnerClasses", "EnclosingMethod")
             .addOptionsModification(this::configure)
             .allowAccessModification()
             .setMinApi(parameters.getApiLevel())
@@ -104,9 +108,9 @@
 
     assertEquals(
         Lists.newArrayList(
-            "STATIC: String trivial.Simple.bar(String)",
-            "STATIC: String trivial.Simple.foo()",
-            "STATIC: String trivial.TrivialTestClass.next()"),
+            "STATIC: String Simple.bar(String)",
+            "STATIC: String Simple.foo()",
+            "STATIC: String TrivialTestClass.next()"),
         references(clazz, "testSimple", "void"));
 
     ClassSubject simple = inspector.clazz(Simple.class);
@@ -115,10 +119,10 @@
 
     assertEquals(
         Lists.newArrayList(
-            "STATIC: String trivial.SimpleWithPhi.bar(String)",
-            "STATIC: String trivial.SimpleWithPhi.foo()",
-            "STATIC: String trivial.SimpleWithPhi.foo()",
-            "STATIC: String trivial.TrivialTestClass.next()"),
+            "STATIC: String SimpleWithPhi.bar(String)",
+            "STATIC: String SimpleWithPhi.foo()",
+            "STATIC: String SimpleWithPhi.foo()",
+            "STATIC: String TrivialTestClass.next()"),
         references(clazz, "testSimpleWithPhi", "void", "int"));
 
     ClassSubject simpleWithPhi = inspector.clazz(SimpleWithPhi.class);
@@ -127,9 +131,9 @@
 
     assertEquals(
         Lists.newArrayList(
-            "STATIC: String trivial.SimpleWithParams.bar(String)",
-            "STATIC: String trivial.SimpleWithParams.foo()",
-            "STATIC: String trivial.TrivialTestClass.next()"),
+            "STATIC: String SimpleWithParams.bar(String)",
+            "STATIC: String SimpleWithParams.foo()",
+            "STATIC: String TrivialTestClass.next()"),
         references(clazz, "testSimpleWithParams", "void"));
 
     ClassSubject simpleWithParams = inspector.clazz(SimpleWithParams.class);
@@ -138,11 +142,11 @@
 
     assertEquals(
         Lists.newArrayList(
-            "STATIC: String trivial.SimpleWithSideEffects.bar(String)",
-            "STATIC: String trivial.SimpleWithSideEffects.foo()",
-            "STATIC: String trivial.TrivialTestClass.next()",
-            "trivial.SimpleWithSideEffects trivial.SimpleWithSideEffects.INSTANCE",
-            "trivial.SimpleWithSideEffects trivial.SimpleWithSideEffects.INSTANCE"),
+            "STATIC: String SimpleWithSideEffects.bar(String)",
+            "STATIC: String SimpleWithSideEffects.foo()",
+            "STATIC: String TrivialTestClass.next()",
+            "SimpleWithSideEffects SimpleWithSideEffects.INSTANCE",
+            "SimpleWithSideEffects SimpleWithSideEffects.INSTANCE"),
         references(clazz, "testSimpleWithSideEffects", "void"));
 
     ClassSubject simpleWithSideEffects = inspector.clazz(SimpleWithSideEffects.class);
@@ -150,19 +154,49 @@
     // As its name implies, its clinit has side effects.
     assertThat(simpleWithSideEffects.clinit(), isPresent());
 
-    // TODO(b/111832046): add support for singleton instance getters.
     assertEquals(
         Lists.newArrayList(
-            "STATIC: String trivial.TrivialTestClass.next()",
-            "VIRTUAL: String trivial.SimpleWithGetter.bar(String)",
-            "VIRTUAL: String trivial.SimpleWithGetter.foo()",
-            "trivial.SimpleWithGetter trivial.SimpleWithGetter.INSTANCE",
-            "trivial.SimpleWithGetter trivial.SimpleWithGetter.INSTANCE"),
+            "STATIC: String SimpleWithGetter.bar(String)",
+            "STATIC: String SimpleWithGetter.foo()",
+            "STATIC: String TrivialTestClass.next()"),
         references(clazz, "testSimpleWithGetter", "void"));
 
     ClassSubject simpleWithGetter = inspector.clazz(SimpleWithGetter.class);
-    assertFalse(instanceMethods(simpleWithGetter).isEmpty());
-    assertThat(simpleWithGetter.clinit(), isPresent());
+    assertTrue(instanceMethods(simpleWithGetter).isEmpty());
+    assertThat(simpleWithGetter.clinit(), not(isPresent()));
+
+    assertEquals(
+        Lists.newArrayList(
+            "STATIC: SimpleWithThrowingGetter SimpleWithThrowingGetter.getInstance()",
+            "STATIC: SimpleWithThrowingGetter SimpleWithThrowingGetter.getInstance()",
+            "STATIC: String TrivialTestClass.next()",
+            "VIRTUAL: String SimpleWithThrowingGetter.bar(String)",
+            "VIRTUAL: String SimpleWithThrowingGetter.foo()"),
+        references(clazz, "testSimpleWithThrowingGetter", "void"));
+
+    ClassSubject simpleWithThrowingGetter = inspector.clazz(SimpleWithThrowingGetter.class);
+    assertFalse(instanceMethods(simpleWithThrowingGetter).isEmpty());
+    assertThat(simpleWithThrowingGetter.clinit(), isPresent());
+
+    // TODO(b/143389508): add support for lazy init in singleton instance getter.
+    assertEquals(
+        Lists.newArrayList(
+            "DIRECT: void SimpleWithLazyInit.<init>()",
+            "DIRECT: void SimpleWithLazyInit.<init>()",
+            "STATIC: String TrivialTestClass.next()",
+            "SimpleWithLazyInit SimpleWithLazyInit.INSTANCE",
+            "SimpleWithLazyInit SimpleWithLazyInit.INSTANCE",
+            "SimpleWithLazyInit SimpleWithLazyInit.INSTANCE",
+            "SimpleWithLazyInit SimpleWithLazyInit.INSTANCE",
+            "SimpleWithLazyInit SimpleWithLazyInit.INSTANCE",
+            "SimpleWithLazyInit SimpleWithLazyInit.INSTANCE",
+            "VIRTUAL: String SimpleWithLazyInit.bar(String)",
+            "VIRTUAL: String SimpleWithLazyInit.foo()"),
+        references(clazz, "testSimpleWithLazyInit", "void"));
+
+    ClassSubject simpleWithLazyInit = inspector.clazz(SimpleWithLazyInit.class);
+    assertFalse(instanceMethods(simpleWithLazyInit).isEmpty());
+    assertThat(simpleWithLazyInit.clinit(), not(isPresent()));
   }
 
   @Test
@@ -316,6 +350,7 @@
             .filter(method -> isTypeOfInterest(method.holder))
             .map(method -> "DIRECT: " + method.toSourceString()))
         .map(txt -> txt.replace("java.lang.", ""))
+        .map(txt -> txt.replace("com.android.tools.r8.ir.optimize.staticizer.trivial.", ""))
         .map(txt -> txt.replace("com.android.tools.r8.ir.optimize.staticizer.", ""))
         .sorted()
         .collect(Collectors.toList());
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/staticizer/trivial/SimpleWithLazyInit.java b/src/test/java/com/android/tools/r8/ir/optimize/staticizer/trivial/SimpleWithLazyInit.java
new file mode 100644
index 0000000..056f452
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/staticizer/trivial/SimpleWithLazyInit.java
@@ -0,0 +1,27 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.optimize.staticizer.trivial;
+
+import com.android.tools.r8.NeverInline;
+
+public class SimpleWithLazyInit {
+  private static SimpleWithLazyInit INSTANCE = null;
+
+  static SimpleWithLazyInit getInstance() {
+    if (INSTANCE == null) {
+      INSTANCE = new SimpleWithLazyInit();
+    }
+    return INSTANCE;
+  }
+
+  @NeverInline
+  String foo() {
+    return bar("Simple::foo()");
+  }
+
+  @NeverInline
+  String bar(String other) {
+    return "Simple::bar(" + other + ")";
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/staticizer/trivial/SimpleWithThrowingGetter.java b/src/test/java/com/android/tools/r8/ir/optimize/staticizer/trivial/SimpleWithThrowingGetter.java
new file mode 100644
index 0000000..099d692
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/ir/optimize/staticizer/trivial/SimpleWithThrowingGetter.java
@@ -0,0 +1,27 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+package com.android.tools.r8.ir.optimize.staticizer.trivial;
+
+import com.android.tools.r8.NeverInline;
+
+public class SimpleWithThrowingGetter {
+  private static SimpleWithThrowingGetter INSTANCE = new SimpleWithThrowingGetter();
+
+  static SimpleWithThrowingGetter getInstance() {
+    if (System.currentTimeMillis() < 0) {
+      throw new AssertionError("This should not happen!");
+    }
+    return INSTANCE;
+  }
+
+  @NeverInline
+  String foo() {
+    return bar("Simple::foo()");
+  }
+
+  @NeverInline
+  String bar(String other) {
+    return "Simple::bar(" + other + ")";
+  }
+}
diff --git a/src/test/java/com/android/tools/r8/ir/optimize/staticizer/trivial/TrivialTestClass.java b/src/test/java/com/android/tools/r8/ir/optimize/staticizer/trivial/TrivialTestClass.java
index 964567c..1dca1e0 100644
--- a/src/test/java/com/android/tools/r8/ir/optimize/staticizer/trivial/TrivialTestClass.java
+++ b/src/test/java/com/android/tools/r8/ir/optimize/staticizer/trivial/TrivialTestClass.java
@@ -20,6 +20,8 @@
     test.testSimpleWithSideEffects();
     test.testSimpleWithParams();
     test.testSimpleWithGetter();
+    test.testSimpleWithThrowingGetter();
+    test.testSimpleWithLazyInit();
   }
 
   @NeverInline
@@ -60,5 +62,17 @@
     System.out.println(SimpleWithGetter.getInstance().foo());
     System.out.println(SimpleWithGetter.getInstance().bar(next()));
   }
+
+  @NeverInline
+  private void testSimpleWithThrowingGetter() {
+    System.out.println(SimpleWithThrowingGetter.getInstance().foo());
+    System.out.println(SimpleWithThrowingGetter.getInstance().bar(next()));
+  }
+
+  @NeverInline
+  private void testSimpleWithLazyInit() {
+    System.out.println(SimpleWithLazyInit.getInstance().foo());
+    System.out.println(SimpleWithLazyInit.getInstance().bar(next()));
+  }
 }