Merge "Restrict vertical class merging"
diff --git a/src/main/java/com/android/tools/r8/graph/DexClass.java b/src/main/java/com/android/tools/r8/graph/DexClass.java
index 23139c5..2c188da 100644
--- a/src/main/java/com/android/tools/r8/graph/DexClass.java
+++ b/src/main/java/com/android/tools/r8/graph/DexClass.java
@@ -10,6 +10,7 @@
 import com.android.tools.r8.origin.Origin;
 import com.android.tools.r8.utils.ThrowingConsumer;
 import com.google.common.base.MoreObjects;
+import com.google.common.collect.Iterators;
 import java.util.Arrays;
 import java.util.List;
 import java.util.function.Consumer;
@@ -96,6 +97,30 @@
     }
   }
 
+  /**
+   * Returns a view of all fields of this class: the instance fields followed by the static
+   * fields.
+   *
+   * <p>The field arrays are read each time {@link Iterable#iterator()} is called, so every
+   * iteration reflects the arrays this class holds at that moment.
+   */
+  public Iterable<DexEncodedField> fields() {
+    return () ->
+        Iterators.concat(Iterators.forArray(instanceFields), Iterators.forArray(staticFields));
+  }
+
+  /**
+   * Returns a view of all methods of this class: the direct methods followed by the virtual
+   * methods.
+   *
+   * <p>The method arrays are read each time {@link Iterable#iterator()} is called, so every
+   * iteration reflects the arrays this class holds at that moment.
+   */
+  public Iterable<DexEncodedMethod> methods() {
+    return () ->
+        Iterators.concat(Iterators.forArray(directMethods), Iterators.forArray(virtualMethods));
+  }
+
   @Override
   void collectMixedSectionItems(MixedSectionCollection mixedItems) {
     throw new Unreachable();
diff --git a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
index d54c95c..2d80d83 100644
--- a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
@@ -22,6 +22,7 @@
 import com.android.tools.r8.graph.MethodAccessFlags;
 import com.android.tools.r8.graph.ParameterAnnotationsList;
 import com.android.tools.r8.graph.PresortedComparable;
+import com.android.tools.r8.graph.UseRegistry;
 import com.android.tools.r8.ir.code.Invoke.Type;
 import com.android.tools.r8.ir.synthetic.ForwardMethodSourceCode;
 import com.android.tools.r8.ir.synthetic.SynthesizedCode;
@@ -84,19 +85,63 @@
 
   // Returns a set of types that must not be merged into other types.
   private Set<DexType> getPinnedTypes(Iterable<DexProgramClass> classes) {
-    // TODO(christofferqa): Compute types from [classes] that must be pinned in order for vertical
-    // class merging to work.
-    return new HashSet<>();
+    Set<DexType> pinnedTypes = new HashSet<>();
+    for (DexProgramClass clazz : classes) {
+      for (DexEncodedMethod method : clazz.methods()) {
+        // TODO(christofferqa): Remove the invariant that the graph lense should not modify any
+        // methods from the sets alwaysInline and noSideEffects (see use of assertNotModifiedBy-
+        // Lense).
+        if (appInfo.alwaysInline.contains(method) || appInfo.noSideEffects.containsKey(method)) {
+          // If we were to merge a program type that occurs in the signature of this method into
+          // its sub class, then we would implicitly change the signature of this method, and
+          // therefore break the invariant.
+          pinProgramType(method.method.proto.returnType, pinnedTypes);
+          for (DexType parameterType : method.method.proto.parameters.values) {
+            pinProgramType(parameterType, pinnedTypes);
+          }
+        }
+      }
+    }
+    return pinnedTypes;
   }
+
+  // Adds [type] to [pinnedTypes] if it is defined by a program class. For an array type, its base
+  // type is considered instead, because merging the base type into its sub class would change the
+  // array type (and thereby the signature of any method mentioning it) as well.
+  private void pinProgramType(DexType type, Set<DexType> pinnedTypes) {
+    DexType baseType = type.isArrayType() ? type.toBaseType(appInfo.dexItemFactory) : type;
+    DexClass definition = appInfo.definitionFor(baseType);
+    if (definition != null && definition.isProgramClass()) {
+      pinnedTypes.add(definition.type);
+    }
+  }
 
   private boolean isMergeCandidate(DexProgramClass clazz, Set<DexType> pinnedTypes) {
-    // We can merge program classes if they are not instantiated, have a single subtype
-    // and we do not have to keep them.
-    return !clazz.isLibraryClass()
-        && !appInfo.instantiatedTypes.contains(clazz.type)
-        && !appInfo.isPinned(clazz.type)
-        && !pinnedTypes.contains(clazz.type)
-        && clazz.type.getSingleSubtype() != null;
+    // We can merge program classes if they are not instantiated, have a single subtype, and we do
+    // not have to keep them or any of their members. In addition, none of their instance
+    // initializers may be rejected by disallowInlining().
+    if (clazz.isLibraryClass()
+        || appInfo.instantiatedTypes.contains(clazz.type)
+        || appInfo.isPinned(clazz.type)
+        || pinnedTypes.contains(clazz.type)
+        || clazz.type.getSingleSubtype() == null) {
+      return false;
+    }
+    for (DexEncodedField field : clazz.fields()) {
+      if (appInfo.isPinned(field.field)) {
+        return false;
+      }
+    }
+    for (DexEncodedMethod method : clazz.methods()) {
+      if (appInfo.isPinned(method.method)) {
+        return false;
+      }
+      if (method.isInstanceInitializer() && disallowInlining(method)) {
+        // Cannot guarantee that markForceInline() will work.
+        return false;
+      }
+    }
+    return true;
   }
 
   private void addProgramMethods(Set<Wrapper<DexMethod>> set, DexMethod method,
@@ -298,8 +343,7 @@
       // targetClass that are not currently contained.
       // Step 1: Merge methods
       Set<Wrapper<DexMethod>> existingMethods = new HashSet<>();
-      addAll(existingMethods, target.directMethods(), MethodSignatureEquivalence.get());
-      addAll(existingMethods, target.virtualMethods(), MethodSignatureEquivalence.get());
+      addAll(existingMethods, target.methods(), MethodSignatureEquivalence.get());
 
       List<DexEncodedMethod> directMethods = new ArrayList<>();
       for (DexEncodedMethod directMethod : source.directMethods()) {
@@ -380,8 +424,7 @@
       }
       // Step 2: Merge fields
       Set<Wrapper<DexField>> existingFields = new HashSet<>();
-      addAll(existingFields, target.instanceFields(), FieldSignatureEquivalence.get());
-      addAll(existingFields, target.staticFields(), FieldSignatureEquivalence.get());
+      addAll(existingFields, target.fields(), FieldSignatureEquivalence.get());
       Collection<DexEncodedField> mergedStaticFields = mergeItems(
           Iterators.forArray(source.staticFields()),
           target.staticFields(),
@@ -468,7 +511,7 @@
     }
 
     private <T extends KeyedDexItem<S>, S extends PresortedComparable<S>> void addAll(
-        Collection<Wrapper<S>> collection, T[] items, Equivalence<S> equivalence) {
+        Collection<Wrapper<S>> collection, Iterable<T> items, Equivalence<S> equivalence) {
       for (T item : items) {
         collection.add(equivalence.wrap(item.getKey()));
       }
@@ -836,6 +879,96 @@
     }
   }
 
+  // Returns true if [method] must be assumed not to be force-inlinable. When this holds for an
+  // instance initializer, the declaring class is not a merge candidate (see isMergeCandidate()).
+  // Methods without code and methods containing a virtual or interface invocation are rejected.
+  private static boolean disallowInlining(DexEncodedMethod method) {
+    // TODO(christofferqa): Determine the situations where markForceInline() may fail, and ensure
+    // that we always return true here in these cases.
+    if (method.getCode() == null) {
+      // There is no code that could be inlined.
+      return true;
+    }
+    MethodInlineDecision registry = new MethodInlineDecision();
+    method.getCode().registerCodeReferences(registry);
+    return registry.isInliningDisallowed();
+  }
+
+  // Use registry that records whether the visited code contains a virtual or interface
+  // invocation; disallowInlining() treats such code as not force-inlinable. Every other kind of
+  // reference handled below is accepted.
+  private static class MethodInlineDecision extends UseRegistry {
+    private boolean disallowInlining = false;
+
+    public boolean isInliningDisallowed() {
+      return disallowInlining;
+    }
+
+    private boolean allowInlining() {
+      return true;
+    }
+
+    private boolean disallowInlining() {
+      disallowInlining = true;
+      return true;
+    }
+
+    @Override
+    public boolean registerInvokeInterface(DexMethod method) {
+      return disallowInlining();
+    }
+
+    @Override
+    public boolean registerInvokeVirtual(DexMethod method) {
+      return disallowInlining();
+    }
+
+    @Override
+    public boolean registerInvokeDirect(DexMethod method) {
+      return allowInlining();
+    }
+
+    @Override
+    public boolean registerInvokeStatic(DexMethod method) {
+      return allowInlining();
+    }
+
+    @Override
+    public boolean registerInvokeSuper(DexMethod method) {
+      return allowInlining();
+    }
+
+    @Override
+    public boolean registerInstanceFieldWrite(DexField field) {
+      return allowInlining();
+    }
+
+    @Override
+    public boolean registerInstanceFieldRead(DexField field) {
+      return allowInlining();
+    }
+
+    @Override
+    public boolean registerNewInstance(DexType type) {
+      return allowInlining();
+    }
+
+    @Override
+    public boolean registerStaticFieldRead(DexField field) {
+      return allowInlining();
+    }
+
+    @Override
+    public boolean registerStaticFieldWrite(DexField field) {
+      return allowInlining();
+    }
+
+    @Override
+    public boolean registerTypeReference(DexType type) {
+      return allowInlining();
+    }
+  }
+
   public Collection<DexType> getRemovedClasses() {
     return Collections.unmodifiableCollection(mergedClasses.keySet());
   }