Towards handling -if rules in presence of class merging

Bug: 110141157
Change-Id: Icc54f842d3e4143ec55b0b0cf89521b4de2fda32
diff --git a/src/main/java/com/android/tools/r8/graph/DexClass.java b/src/main/java/com/android/tools/r8/graph/DexClass.java
index 75ce373..7538bb1 100644
--- a/src/main/java/com/android/tools/r8/graph/DexClass.java
+++ b/src/main/java/com/android/tools/r8/graph/DexClass.java
@@ -9,7 +9,8 @@
 import com.android.tools.r8.kotlin.KotlinInfo;
 import com.android.tools.r8.origin.Origin;
 import com.google.common.base.MoreObjects;
-import com.google.common.collect.Iterators;
+import com.google.common.base.Predicates;
+import com.google.common.collect.Iterables;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Set;
@@ -99,13 +100,23 @@
   }
 
   public Iterable<DexEncodedField> fields() {
-    return () ->
-        Iterators.concat(Iterators.forArray(instanceFields), Iterators.forArray(staticFields));
+    return fields(Predicates.alwaysTrue());
+  }
+
+  public Iterable<DexEncodedField> fields(final Predicate<? super DexEncodedField> predicate) {
+    return Iterables.concat(
+        Iterables.filter(Arrays.asList(instanceFields), predicate::test),
+        Iterables.filter(Arrays.asList(staticFields), predicate::test));
   }
 
   public Iterable<DexEncodedMethod> methods() {
-    return () ->
-        Iterators.concat(Iterators.forArray(directMethods), Iterators.forArray(virtualMethods));
+    return methods(Predicates.alwaysTrue());
+  }
+
+  public Iterable<DexEncodedMethod> methods(Predicate<? super DexEncodedMethod> predicate) {
+    return Iterables.concat(
+        Iterables.filter(Arrays.asList(directMethods), predicate::test),
+        Iterables.filter(Arrays.asList(virtualMethods), predicate::test));
   }
 
   @Override
diff --git a/src/main/java/com/android/tools/r8/shaking/ProguardMemberRule.java b/src/main/java/com/android/tools/r8/shaking/ProguardMemberRule.java
index 39737ea..d1bfc9a 100644
--- a/src/main/java/com/android/tools/r8/shaking/ProguardMemberRule.java
+++ b/src/main/java/com/android/tools/r8/shaking/ProguardMemberRule.java
@@ -4,8 +4,12 @@
 package com.android.tools.r8.shaking;
 
 import com.android.tools.r8.errors.Unreachable;
+import com.android.tools.r8.graph.AppInfo;
+import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexEncodedField;
 import com.android.tools.r8.graph.DexEncodedMethod;
+import com.android.tools.r8.graph.DexField;
+import com.android.tools.r8.graph.DexMethod;
 import com.android.tools.r8.graph.DexType;
 import com.android.tools.r8.shaking.ProguardConfigurationParser.IdentifierPatternWithWildcards;
 import com.android.tools.r8.utils.StringUtils;
@@ -160,11 +164,14 @@
     return type;
   }
 
-  public boolean matches(DexEncodedField field, DexStringCache stringCache) {
+  public boolean matches(
+      DexEncodedField field, AppView<? extends AppInfo> appView, DexStringCache stringCache) {
+    DexField originalSignature = appView.graphLense().getOriginalFieldSignature(field.field);
     switch (getRuleType()) {
       case ALL:
       case ALL_FIELDS:
         // Access flags check.
+        // TODO(b/117330692): The access flags may have changed as a result of access relaxation.
         if (!getAccessFlags().containsAll(field.accessFlags)
             || !getNegatedAccessFlags().containsNone(field.accessFlags)) {
           break;
@@ -173,17 +180,18 @@
         return RootSetBuilder.containsAnnotation(annotation, field.annotations);
       case FIELD:
         // Name check.
-        String name = stringCache.lookupString(field.field.name);
+        String name = stringCache.lookupString(originalSignature.name);
         if (!getName().matches(name)) {
           break;
         }
         // Access flags check.
+        // TODO(b/117330692): The access flags may have changed as a result of access relaxation.
         if (!getAccessFlags().containsAll(field.accessFlags)
             || !getNegatedAccessFlags().containsNone(field.accessFlags)) {
           break;
         }
         // Type check.
-        if (!this.type.matches(field.field.type)) {
+        if (!getType().matches(originalSignature.type)) {
           break;
         }
         // Annotations check
@@ -200,7 +208,9 @@
     return false;
   }
 
-  public boolean matches(DexEncodedMethod method, DexStringCache stringCache) {
+  public boolean matches(
+      DexEncodedMethod method, AppView<? extends AppInfo> appView, DexStringCache stringCache) {
+    DexMethod originalSignature = appView.graphLense().getOriginalMethodSignature(method.method);
     switch (getRuleType()) {
       case ALL_METHODS:
         if (method.isClassInitializer()) {
@@ -209,6 +219,7 @@
         // Fall through for all other methods.
       case ALL:
         // Access flags check.
+        // TODO(b/117330692): The access flags may have changed as a result of access relaxation.
         if (!getAccessFlags().containsAll(method.accessFlags)
             || !getNegatedAccessFlags().containsNone(method.accessFlags)) {
           break;
@@ -217,18 +228,21 @@
         return RootSetBuilder.containsAnnotation(annotation, method.annotations);
       case METHOD:
         // Check return type.
-        if (!type.matches(method.method.proto.returnType)) {
+        // TODO(b/110141157): The name of the return type may have changed as a result of vertical
+        // class merging. We should use the original type name.
+        if (!type.matches(originalSignature.proto.returnType)) {
           break;
         }
         // Fall through for access flags, name and arguments.
       case CONSTRUCTOR:
       case INIT:
         // Name check.
-        String name = stringCache.lookupString(method.method.name);
+        String name = stringCache.lookupString(originalSignature.name);
         if (!getName().matches(name)) {
           break;
         }
         // Access flags check.
+        // TODO(b/117330692): The access flags may have changed as a result of access relaxation.
         if (!getAccessFlags().containsAll(method.accessFlags)
             || !getNegatedAccessFlags().containsNone(method.accessFlags)) {
           break;
@@ -241,23 +255,20 @@
         List<ProguardTypeMatcher> arguments = getArguments();
         if (arguments.size() == 1 && arguments.get(0).isTripleDotPattern()) {
           return true;
-        } else {
-          DexType[] parameters = method.method.proto.parameters.values;
-          if (parameters.length != arguments.size()) {
-            break;
-          }
-          int i = 0;
-          for (; i < parameters.length; i++) {
-            if (!arguments.get(i).matches(parameters[i])) {
-              break;
-            }
-          }
-          if (i == parameters.length) {
-            // All parameters matched.
-            return true;
+        }
+        DexType[] parameters = originalSignature.proto.parameters.values;
+        if (parameters.length != arguments.size()) {
+          break;
+        }
+        for (int i = 0; i < parameters.length; i++) {
+          // TODO(b/110141157): The names of the parameter types may have changed as a result of
+          // vertical class merging. We should use the original type names.
+          if (!arguments.get(i).matches(parameters[i])) {
+            return false;
           }
         }
-        break;
+        // All parameters matched.
+        return true;
       case ALL_FIELDS:
       case FIELD:
         break;
diff --git a/src/main/java/com/android/tools/r8/shaking/RootSetBuilder.java b/src/main/java/com/android/tools/r8/shaking/RootSetBuilder.java
index e6f6a17..d5d5422 100644
--- a/src/main/java/com/android/tools/r8/shaking/RootSetBuilder.java
+++ b/src/main/java/com/android/tools/r8/shaking/RootSetBuilder.java
@@ -29,6 +29,7 @@
 import com.android.tools.r8.utils.ThreadUtils;
 import com.google.common.base.Equivalence.Wrapper;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
 import java.io.PrintStream;
 import java.util.ArrayList;
@@ -164,39 +165,9 @@
     // seems not to care, so users have started to use this inconsistently. We are thus
     // inconsistent, as well, but tell them.
     // TODO(herhut): One day make this do what it says.
-    if (rule.hasInheritanceClassName()) {
-      boolean extendsExpected =
-          anySuperTypeMatches(
-              clazz.superType,
-              application::definitionFor,
-              rule.getInheritanceClassName(),
-              rule.getInheritanceAnnotation());
-      boolean implementsExpected = false;
-      if (!extendsExpected) {
-        implementsExpected =
-            anyImplementedInterfaceMatches(
-                clazz,
-                application::definitionFor,
-                rule.getInheritanceClassName(),
-                rule.getInheritanceAnnotation());
-      }
-      if (!extendsExpected && !implementsExpected) {
-        return;
-      }
-      // Warn if users got it wrong, but only warn once.
-      if (extendsExpected && !rule.getInheritanceIsExtends()) {
-        if (rulesThatUseExtendsOrImplementsWrong.add(rule)) {
-          options.reporter.warning(
-              new StringDiagnostic(
-                  "The rule `" + rule + "` uses implements but actually matches extends."));
-        }
-      } else if (implementsExpected && rule.getInheritanceIsExtends()) {
-        if (rulesThatUseExtendsOrImplementsWrong.add(rule)) {
-          options.reporter.warning(
-              new StringDiagnostic(
-                  "The rule `" + rule + "` uses extends but actually matches implements."));
-        }
-      }
+    if (rule.hasInheritanceClassName()
+        && !satisfyInheritanceRule(clazz, application::definitionFor, rule)) {
+      return;
     }
 
     if (rule.getClassNames().matches(clazz.type)) {
@@ -340,95 +311,39 @@
       Set<DexEncodedMethod> liveMethods,
       Set<DexEncodedField> liveFields) throws ExecutionException {
     application.timing.begin("Find consequent items for -if rules...");
-    Function<DexType, DexClass> definitionForWithLiveTypes =
-        type -> {
-          DexClass clazz = appView.appInfo().definitionFor(type);
-          if (clazz != null && liveTypes.contains(clazz.type)) {
-            return clazz;
-          }
-          return null;
-        };
     try {
-      List<Future<?>> futures = new ArrayList<>();
       if (rules != null) {
+        IfRuleEvaluator evaluator =
+            new IfRuleEvaluator(liveTypes, liveMethods, liveFields, executorService);
         for (ProguardConfigurationRule rule : rules) {
           assert rule instanceof ProguardIfRule;
           ProguardIfRule ifRule = (ProguardIfRule) rule;
-          // Depending on which types trigger the -if rule, the application of the subsequent
+          // Depending on which types that trigger the -if rule, the application of the subsequent
           // -keep rule may vary (due to back references). So, we need to try all pairs of -if rule
           // and live types.
-          for (DexType currentLiveType : liveTypes) {
-            DexClass currentLiveClass = appView.appInfo().definitionFor(currentLiveType);
-            if (currentLiveClass == null) {
+          for (DexType type : liveTypes) {
+            DexClass clazz = appView.appInfo().definitionFor(type);
+            if (clazz == null) {
               continue;
             }
-            if (!satisfyClassType(rule, currentLiveClass)) {
-              continue;
-            }
-            if (!satisfyAccessFlag(rule, currentLiveClass)) {
-              continue;
-            }
-            if (!satisfyAnnotation(rule, currentLiveClass)) {
-              continue;
-            }
-            if (ifRule.hasInheritanceClassName()) {
-              if (!satisfyInheritanceRule(currentLiveType, definitionForWithLiveTypes, ifRule)) {
-                // Try another live type since the current one doesn't satisfy the inheritance rule.
-                continue;
-              }
-            }
-            if (ifRule.getClassNames().matches(currentLiveType)) {
-              Collection<ProguardMemberRule> memberKeepRules = ifRule.getMemberRules();
-              if (memberKeepRules.isEmpty()) {
-                ProguardIfRule materializedRule = ifRule.materialize();
-                runPerRule(
-                    executorService, futures, materializedRule.subsequentRule, materializedRule);
-                // No member rule to satisfy. Move on to the next live type.
-                continue;
-              }
-              Set<DexDefinition> filteredFields = liveFields.stream()
-                  .filter(f -> f.field.getHolder() == currentLiveType)
-                  .collect(Collectors.toSet());
-              Set<DexDefinition> filteredMethods = liveMethods.stream()
-                  .filter(m -> m.method.getHolder() == currentLiveType)
-                  .collect(Collectors.toSet());
-              // If the number of member rules to hold is more than live members, we can't make it.
-              if (filteredFields.size() + filteredMethods.size() < memberKeepRules.size()) {
-                continue;
-              }
-              // Depending on which members trigger the -if rule, the application of the subsequent
-              // -keep rule may vary (due to back references). So, we need to try literally all
-              // combinations of live members.
-              // TODO(b/79486261): Some of those are equivalent from the point of view of -if rule.
-              Set<Set<DexDefinition>> combinationsOfMembers = Sets.combinations(
-                  Sets.union(filteredFields, filteredMethods), memberKeepRules.size());
-              for (Set<DexDefinition> combination : combinationsOfMembers) {
-                Set<DexEncodedField> fieldsInCombination =
-                    DexDefinition.filterDexEncodedField(combination.stream())
-                        .collect(Collectors.toSet());
-                Set<DexEncodedMethod> methodsInCombination =
-                    DexDefinition.filterDexEncodedMethod(combination.stream())
-                        .collect(Collectors.toSet());
-                // Member rules are combined as AND logic: if found unsatisfied member rule, this
-                // combination of live members is not a good fit.
-                boolean satisfied = true;
-                for (ProguardMemberRule memberRule : memberKeepRules) {
-                  if (!ruleSatisfiedByFields(memberRule, fieldsInCombination)
-                      && !ruleSatisfiedByMethods(memberRule, methodsInCombination)) {
-                    satisfied = false;
-                    break;
-                  }
-                }
-                if (satisfied) {
-                  ProguardIfRule materializedRule = ifRule.materialize();
-                  runPerRule(
-                      executorService, futures, materializedRule.subsequentRule, materializedRule);
-                }
+
+            // Check if the class matches the if-rule.
+            evaluator.evaluateIfRule(ifRule, clazz, clazz);
+
+            // Check if one of the types that have been merged into `clazz` satisfies the if-rule.
+            if (options.enableVerticalClassMerging && appView.verticallyMergedClasses() != null) {
+              for (DexType sourceType : appView.verticallyMergedClasses().getSourcesFor(type)) {
+                // Note that, although `sourceType` has been merged into `type`, the dex class for
+                // `sourceType` is still available until the second round of tree shaking. This way
+                // we can still retrieve the access flags of `sourceType`.
+                DexClass sourceClass = appView.appInfo().definitionFor(sourceType);
+                assert sourceClass != null;
+                evaluator.evaluateIfRule(ifRule, sourceClass, clazz);
               }
             }
           }
         }
-        ThreadUtils.awaitFutures(futures);
+        ThreadUtils.awaitFutures(evaluator.futures);
       }
     } finally {
       application.timing.end();
@@ -436,6 +351,115 @@
     return new ConsequentRootSet(noShrinking, noOptimization, noObfuscation, dependentNoShrinking);
   }
 
+  private class IfRuleEvaluator {
+
+    private final Set<DexType> liveTypes;
+    private final Set<DexEncodedMethod> liveMethods;
+    private final Set<DexEncodedField> liveFields;
+    private final ExecutorService executorService;
+
+    private final List<Future<?>> futures = new ArrayList<>();
+
+    public IfRuleEvaluator(
+        Set<DexType> liveTypes,
+        Set<DexEncodedMethod> liveMethods,
+        Set<DexEncodedField> liveFields,
+        ExecutorService executorService) {
+      this.liveTypes = liveTypes;
+      this.liveMethods = liveMethods;
+      this.liveFields = liveFields;
+      this.executorService = executorService;
+    }
+
+    /**
+     * Determines if `sourceClass` satisfies the given if-rule. If `sourceClass` has not been merged
+     * into another class, then `targetClass` is the same as `sourceClass`. Otherwise, `targetClass`
+     * denotes the class that `sourceClass` has been merged into.
+     */
+    private void evaluateIfRule(ProguardIfRule rule, DexClass sourceClass, DexClass targetClass) {
+      if (!satisfyClassType(rule, sourceClass)) {
+        return;
+      }
+      if (!satisfyAccessFlag(rule, sourceClass)) {
+        return;
+      }
+      if (!satisfyAnnotation(rule, sourceClass)) {
+        return;
+      }
+      // TODO(b/110141157): Handle the situation where the class in the extends/implements clause
+      // has been merged.
+      if (rule.hasInheritanceClassName()
+          && !satisfyInheritanceRule(sourceClass, this::definitionForWithLiveTypes, rule)) {
+        // Try another live type since the current one doesn't satisfy the inheritance rule.
+        return;
+      }
+      if (!rule.getClassNames().matches(sourceClass.type)) {
+        return;
+      }
+      Collection<ProguardMemberRule> memberKeepRules = rule.getMemberRules();
+      if (memberKeepRules.isEmpty()) {
+        materializeIfRule(rule);
+        return;
+      }
+
+      Set<DexDefinition> filteredMembers = Sets.newIdentityHashSet();
+      Iterables.addAll(
+          filteredMembers,
+          targetClass.fields(
+              f ->
+                  liveFields.contains(f)
+                      && appView.graphLense().getOriginalFieldSignature(f.field).getHolder()
+                          == sourceClass.type));
+      Iterables.addAll(
+          filteredMembers,
+          targetClass.methods(
+              m ->
+                  liveMethods.contains(m)
+                      && appView.graphLense().getOriginalMethodSignature(m.method).getHolder()
+                          == sourceClass.type));
+
+      // If the number of member rules to hold is more than live members, we can't make it.
+      if (filteredMembers.size() < memberKeepRules.size()) {
+        return;
+      }
+
+      // Depending on which members trigger the -if rule, the application of the subsequent
+      // -keep rule may vary (due to back references). So, we need to try literally all
+      // combinations of live members.
+      // TODO(b/79486261): Some of those are equivalent from the point of view of -if rule.
+      Sets.combinations(filteredMembers, memberKeepRules.size())
+          .forEach(
+              combination -> {
+                Collection<DexEncodedField> fieldsInCombination =
+                    DexDefinition.filterDexEncodedField(combination.stream())
+                        .collect(Collectors.toList());
+                Collection<DexEncodedMethod> methodsInCombination =
+                    DexDefinition.filterDexEncodedMethod(combination.stream())
+                        .collect(Collectors.toList());
+                // Member rules are combined as AND logic: if found unsatisfied member rule, this
+                // combination of live members is not a good fit.
+                boolean satisfied =
+                    memberKeepRules.stream()
+                        .allMatch(
+                            memberRule ->
+                                ruleSatisfiedByFields(memberRule, fieldsInCombination)
+                                    || ruleSatisfiedByMethods(memberRule, methodsInCombination));
+                if (satisfied) {
+                  materializeIfRule(rule);
+                }
+              });
+    }
+
+    private void materializeIfRule(ProguardIfRule rule) {
+      ProguardIfRule materializedRule = rule.materialize();
+      runPerRule(executorService, futures, materializedRule.subsequentRule, materializedRule);
+    }
+
+    private DexClass definitionForWithLiveTypes(DexType type) {
+      return liveTypes.contains(type) ? appView.appInfo().definitionFor(type) : null;
+    }
+  }
+
   private static DexDefinition testAndGetPrecondition(
       DexDefinition definition, Map<Predicate<DexDefinition>, DexDefinition> preconditionSupplier) {
     if (preconditionSupplier == null) {
@@ -497,7 +521,7 @@
     });
   }
 
-  // TODO(67934426): Test this code.
+  // TODO(b/67934426): Test this code.
   public static void writeSeeds(
       AppInfoWithLiveness appInfo, PrintStream out, Predicate<DexType> include) {
     for (DexReference seed : appInfo.getPinnedItems()) {
@@ -564,24 +588,34 @@
   }
 
   private boolean satisfyInheritanceRule(
-      DexType type,
-      Function<DexType, DexClass> definitionFor,
-      ProguardConfigurationRule rule) {
-    DexClass clazz = definitionFor.apply(type);
-    if (clazz == null) {
-      return false;
-    }
-    return
+      DexClass clazz, Function<DexType, DexClass> definitionFor, ProguardConfigurationRule rule) {
+    ProguardTypeMatcher inheritanceClassName = rule.getInheritanceClassName();
+    ProguardTypeMatcher inheritanceAnnotation = rule.getInheritanceAnnotation();
+    boolean extendsExpected =
         anySuperTypeMatches(
-            clazz.superType,
-            definitionFor,
-            rule.getInheritanceClassName(),
-            rule.getInheritanceAnnotation())
-        || anyImplementedInterfaceMatches(
-            clazz,
-            definitionFor,
-            rule.getInheritanceClassName(),
-            rule.getInheritanceAnnotation());
+            clazz.superType, definitionFor, inheritanceClassName, inheritanceAnnotation);
+    boolean implementsExpected = false;
+    if (!extendsExpected) {
+      implementsExpected =
+          anyImplementedInterfaceMatches(
+              clazz, definitionFor, inheritanceClassName, inheritanceAnnotation);
+    }
+    if (extendsExpected || implementsExpected) {
+      // Warn if users got it wrong, but only warn once.
+      if (rule.getInheritanceIsExtends()) {
+        if (implementsExpected && rulesThatUseExtendsOrImplementsWrong.add(rule)) {
+          options.reporter.warning(
+              new StringDiagnostic(
+                  "The rule `" + rule + "` uses extends but actually matches implements."));
+        }
+      } else if (extendsExpected && rulesThatUseExtendsOrImplementsWrong.add(rule)) {
+        options.reporter.warning(
+            new StringDiagnostic(
+                "The rule `" + rule + "` uses implements but actually matches extends."));
+      }
+      return true;
+    }
+    return false;
   }
 
   private boolean allRulesSatisfied(Collection<ProguardMemberRule> memberKeepRules,
@@ -606,11 +640,10 @@
   }
 
   private boolean ruleSatisfiedByMethods(
-      ProguardMemberRule rule,
-      Iterable<DexEncodedMethod> methods) {
+      ProguardMemberRule rule, Iterable<DexEncodedMethod> methods) {
     if (rule.getRuleType().includesMethods()) {
       for (DexEncodedMethod method : methods) {
-        if (rule.matches(method, dexStringCache)) {
+        if (rule.matches(method, appView, dexStringCache)) {
           return true;
         }
       }
@@ -619,22 +652,13 @@
   }
 
   private boolean ruleSatisfiedByMethods(ProguardMemberRule rule, DexEncodedMethod[] methods) {
-    if (rule.getRuleType().includesMethods()) {
-      for (DexEncodedMethod method : methods) {
-        if (rule.matches(method, dexStringCache)) {
-          return true;
-        }
-      }
-    }
-    return false;
+    return ruleSatisfiedByMethods(rule, Arrays.asList(methods));
   }
 
-  private boolean ruleSatisfiedByFields(
-      ProguardMemberRule rule,
-      Iterable<DexEncodedField> fields) {
+  private boolean ruleSatisfiedByFields(ProguardMemberRule rule, Iterable<DexEncodedField> fields) {
     if (rule.getRuleType().includesFields()) {
       for (DexEncodedField field : fields) {
-        if (rule.matches(field, dexStringCache)) {
+        if (rule.matches(field, appView, dexStringCache)) {
           return true;
         }
       }
@@ -643,14 +667,7 @@
   }
 
   private boolean ruleSatisfiedByFields(ProguardMemberRule rule, DexEncodedField[] fields) {
-    if (rule.getRuleType().includesFields()) {
-      for (DexEncodedField field : fields) {
-        if (rule.matches(field, dexStringCache)) {
-          return true;
-        }
-      }
-    }
-    return false;
+    return ruleSatisfiedByFields(rule, Arrays.asList(fields));
   }
 
   static boolean containsAnnotation(ProguardTypeMatcher classAnnotation,
@@ -680,7 +697,7 @@
       return;
     }
     for (ProguardMemberRule rule : rules) {
-      if (rule.matches(method, dexStringCache)) {
+      if (rule.matches(method, appView, dexStringCache)) {
         if (Log.ENABLED) {
           Log.verbose(getClass(), "Marking method `%s` due to `%s { %s }`.", method, context,
               rule);
@@ -699,7 +716,7 @@
       ProguardConfigurationRule context,
       DexDefinition precondition) {
     for (ProguardMemberRule rule : rules) {
-      if (rule.matches(field, dexStringCache)) {
+      if (rule.matches(field, appView, dexStringCache)) {
         if (Log.ENABLED) {
           Log.verbose(getClass(), "Marking field `%s` due to `%s { %s }`.", field, context,
               rule);
diff --git a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
index e592a9a..d41bd69 100644
--- a/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
+++ b/src/main/java/com/android/tools/r8/shaking/VerticalClassMerger.java
@@ -49,7 +49,9 @@
 import com.android.tools.r8.utils.Timing;
 import com.google.common.base.Equivalence;
 import com.google.common.base.Equivalence.Wrapper;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Maps;
 import it.unimi.dsi.fastutil.ints.Int2IntMap;
 import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
 import it.unimi.dsi.fastutil.objects.Reference2BooleanOpenHashMap;
@@ -91,9 +93,19 @@
   public static class VerticallyMergedClasses {
 
     private final Map<DexType, DexType> mergedClasses;
+    private final Map<DexType, List<DexType>> sources;
 
     private VerticallyMergedClasses(Map<DexType, DexType> mergedClasses) {
+      Map<DexType, List<DexType>> sources = Maps.newIdentityHashMap();
+      mergedClasses.forEach(
+          (source, target) ->
+              sources.computeIfAbsent(target, key -> new ArrayList<>()).add(source));
       this.mergedClasses = mergedClasses;
+      this.sources = sources;
+    }
+
+    public List<DexType> getSourcesFor(DexType type) {
+      return sources.getOrDefault(type, ImmutableList.of());
     }
 
     public DexType getTargetFor(DexType type) {
diff --git a/src/test/java/com/android/tools/r8/TestBase.java b/src/test/java/com/android/tools/r8/TestBase.java
index 6f93df4..94cc54b 100644
--- a/src/test/java/com/android/tools/r8/TestBase.java
+++ b/src/test/java/com/android/tools/r8/TestBase.java
@@ -12,10 +12,8 @@
 import com.android.tools.r8.ToolHelper.ArtCommandBuilder;
 import com.android.tools.r8.ToolHelper.DexVm;
 import com.android.tools.r8.ToolHelper.ProcessResult;
-import com.android.tools.r8.cf.code.CfInstruction;
 import com.android.tools.r8.code.Instruction;
 import com.android.tools.r8.errors.Unreachable;
-import com.android.tools.r8.graph.CfCode;
 import com.android.tools.r8.graph.DexCode;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.SmaliWriter;
@@ -443,7 +441,7 @@
       String proguardConfig,
       Consumer<InternalOptions> optionsConsumer,
       Backend backend)
-      throws IOException, CompilationFailedException {
+      throws CompilationFailedException {
     R8Command command =
         ToolHelper.prepareR8CommandBuilder(app, emptyConsumer(backend))
             .addProguardConfiguration(ImmutableList.of(proguardConfig), Origin.unknown())
@@ -455,7 +453,7 @@
   /** Compile an application with R8 using the supplied proguard configuration. */
   protected AndroidApp compileWithR8(
       AndroidApp app, Path proguardConfig, Consumer<InternalOptions> optionsConsumer)
-      throws IOException, CompilationFailedException {
+      throws CompilationFailedException {
     R8Command command =
         ToolHelper.prepareR8CommandBuilder(app)
             .addProguardConfigurationFiles(proguardConfig)
diff --git a/src/test/java/com/android/tools/r8/shaking/ifrule/verticalclassmerging/IfRuleWithVerticalClassMerging.java b/src/test/java/com/android/tools/r8/shaking/ifrule/verticalclassmerging/IfRuleWithVerticalClassMerging.java
index 9cada6f..9d2e352 100644
--- a/src/test/java/com/android/tools/r8/shaking/ifrule/verticalclassmerging/IfRuleWithVerticalClassMerging.java
+++ b/src/test/java/com/android/tools/r8/shaking/ifrule/verticalclassmerging/IfRuleWithVerticalClassMerging.java
@@ -8,17 +8,14 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
 
-import com.android.tools.r8.OutputMode;
-import com.android.tools.r8.ToolHelper;
-import com.android.tools.r8.ToolHelper.ProcessResult;
-import com.android.tools.r8.shaking.forceproguardcompatibility.ProguardCompatibilityTestBase;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.ir.optimize.Inliner.Reason;
 import com.android.tools.r8.utils.AndroidApp;
 import com.android.tools.r8.utils.InternalOptions;
 import com.android.tools.r8.utils.codeinspector.ClassSubject;
 import com.android.tools.r8.utils.codeinspector.CodeInspector;
 import com.google.common.collect.ImmutableList;
-import java.io.File;
-import java.nio.file.Path;
+import com.google.common.collect.ImmutableSet;
 import java.util.Collection;
 import java.util.List;
 import org.junit.Test;
@@ -31,7 +28,7 @@
   int a() throws ClassNotFoundException {
     // Class D is expected to be kept - vertical class merging or not. The -if rule say that if
     // the method A.a is in the output, then class D is needed.
-    String p =getClass().getPackage().getName();
+    String p = getClass().getPackage().getName();
     Class.forName(p + ".D");
     return 4;
   }
@@ -61,72 +58,67 @@
   }
 }
 
+// TODO(b/110141157):
+// - Add tests where the return type of a kept method changes.
+// - Add tests where the parameter type of a kept method changes.
+// - Add tests where the type of a kept field changes.
+// - Add tests where fields and methods get renamed due to naming conflicts.
+// - Add tests where the type in a implements/extends clause has changed.
 @RunWith(Parameterized.class)
-public class IfRuleWithVerticalClassMerging extends ProguardCompatibilityTestBase {
-  private final static List<Class> CLASSES = ImmutableList.of(
-      A.class, B.class, C.class, D.class, Main.class);
+public class IfRuleWithVerticalClassMerging extends TestBase {
 
-  private final Shrinker shrinker;
-  private final boolean enableClassMerging;
+  private static final List<Class> CLASSES =
+      ImmutableList.of(A.class, B.class, C.class, D.class, Main.class);
 
-  public IfRuleWithVerticalClassMerging(Shrinker shrinker, boolean enableClassMerging) {
-    this.shrinker = shrinker;
-    this.enableClassMerging = enableClassMerging;
+  private final Backend backend;
+  private final boolean enableVerticalClassMerging;
+
+  public IfRuleWithVerticalClassMerging(Backend backend, boolean enableVerticalClassMerging) {
+    this.backend = backend;
+    this.enableVerticalClassMerging = enableVerticalClassMerging;
   }
 
-  @Parameters(name = "shrinker: {0} classMerging: {1}")
+  @Parameters(name = "Backend: {0}, vertical class merging: {1}")
   public static Collection<Object[]> data() {
     // We don't run this on Proguard, as Proguard does not merge A into B.
     return ImmutableList.of(
-        new Object[] {Shrinker.R8, true},
-        new Object[] {Shrinker.R8, false},
-        new Object[] {Shrinker.R8_CF, true},
-        new Object[] {Shrinker.R8_CF, false});
+        new Object[] {Backend.DEX, true},
+        new Object[] {Backend.DEX, false},
+        new Object[] {Backend.CF, true},
+        new Object[] {Backend.CF, false});
   }
 
   private void configure(InternalOptions options) {
-    options.enableVerticalClassMerging = enableClassMerging;
+    options.enableVerticalClassMerging = enableVerticalClassMerging;
+
+    // TODO(b/110148109): Allow ordinary method inlining when -if rules work with inlining.
+    options.testing.validInliningReasons = ImmutableSet.of(Reason.FORCE);
   }
 
   private void check(AndroidApp app) throws Exception {
     CodeInspector inspector = new CodeInspector(app);
     ClassSubject clazzA = inspector.clazz(A.class);
-    assertEquals(!enableClassMerging, clazzA.isPresent());
+    assertEquals(!enableVerticalClassMerging, clazzA.isPresent());
     ClassSubject clazzB = inspector.clazz(B.class);
     assertThat(clazzB, isPresent());
     ClassSubject clazzD = inspector.clazz(D.class);
-    // TODO(110141157): Class D should be kept - vertical class merging or not.
-    assertEquals(!enableClassMerging, clazzD.isPresent());
-
-    ProcessResult result;
-    if (shrinker == Shrinker.R8) {
-      result = runOnArtRaw(app, Main.class.getName());
-    } else {
-      assert shrinker == Shrinker.R8_CF;
-      Path file = File.createTempFile("junit", ".zip", temp.getRoot()).toPath();
-      app.writeToZip(file, OutputMode.ClassFile);
-      result = ToolHelper.runJava(file, Main.class.getName());
-    }
-    // TODO(110141157): The code should run - vertical class merging or not.
-    assertEquals(enableClassMerging ? 1 : 0, result.exitCode);
-    if (!enableClassMerging) {
-      assertEquals("123456", result.stdout);
-    }
+    assertThat(clazzD, isPresent());
+    assertEquals("123456", runOnVM(app, Main.class, backend));
   }
 
   @Test
   public void testMergedClassInIfRule() throws Exception {
     // Class C is kept, meaning that it will not be touched.
     // Class A will be merged into class B.
-    List<String> config = ImmutableList.of(
-        "-keep class **.Main { public static void main(java.lang.String[]); }",
-        "-keep class **.C",
-        "-if class **.A",
-        "-keep class **.D",
-        "-dontobfuscate"
-    );
-
-    check(runShrinker(shrinker, CLASSES, config, this::configure));
+    String config =
+        String.join(
+            System.lineSeparator(),
+            "-keep class **.Main { public static void main(java.lang.String[]); }",
+            "-keep class **.C",
+            "-if class **.A",
+            "-keep class **.D",
+            "-dontobfuscate");
+    check(compileWithR8(readClasses(CLASSES), config, this::configure, backend));
   }
 
   @Test
@@ -134,15 +126,15 @@
     // Class C is kept, meaning that it will not be touched.
     // Class A will be merged into class B.
     // Main.main access A.x, so that field exists satisfying the if rule.
-    List<String> config = ImmutableList.of(
-        "-keep class **.Main { public static void main(java.lang.String[]); }",
-        "-keep class **.C",
-        "-if class **.A { int x; }",
-        "-keep class **.D",
-        "-dontobfuscate"
-    );
-
-    check(runShrinker(shrinker, CLASSES, config, this::configure));
+    String config =
+        String.join(
+            System.lineSeparator(),
+            "-keep class **.Main { public static void main(java.lang.String[]); }",
+            "-keep class **.C",
+            "-if class **.A { int x; }",
+            "-keep class **.D",
+            "-dontobfuscate");
+    check(compileWithR8(readClasses(CLASSES), config, this::configure, backend));
   }
 
   @Test
@@ -150,14 +142,14 @@
     // Class C is kept, meaning that it will not be touched.
     // Class A will be merged into class B.
     // Main.main access A.a(), that method exists satisfying the if rule.
-    List<String> config = ImmutableList.of(
-        "-keep class **.Main { public static void main(java.lang.String[]); }",
-        "-keep class **.C",
-        "-if class **.A { int a(); }",
-        "-keep class **.D",
-        "-dontobfuscate"
-    );
-
-    check(runShrinker(shrinker, CLASSES, config, this::configure));
+    String config =
+        String.join(
+            System.lineSeparator(),
+            "-keep class **.Main { public static void main(java.lang.String[]); }",
+            "-keep class **.C",
+            "-if class **.A { int a(); }",
+            "-keep class **.D",
+            "-dontobfuscate");
+    check(compileWithR8(readClasses(CLASSES), config, this::configure, backend));
   }
 }