Use valid target for field rebinding

Bug: 171413213
Change-Id: Ib8f2fb364203a0874dd3044f0c4c355624357345
diff --git a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
index 0512e69..8e1ee32 100644
--- a/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
+++ b/src/main/java/com/android/tools/r8/ir/conversion/LensCodeRewriter.java
@@ -89,6 +89,7 @@
 import com.android.tools.r8.ir.code.ValueType;
 import com.android.tools.r8.ir.optimize.enums.EnumUnboxer;
 import com.android.tools.r8.logging.Log;
+import com.android.tools.r8.optimize.MemberRebindingAnalysis;
 import com.google.common.collect.Sets;
 import java.util.ArrayList;
 import java.util.HashSet;
@@ -594,7 +595,7 @@
       if (definition != null) {
         DexClassAndField field = DexClassAndField.create(holder, definition);
         if (AccessControl.isMemberAccessible(field, holder, context, appView).isTrue()) {
-          return lookup.getReboundReference();
+          return MemberRebindingAnalysis.validMemberRebindingTargetFor(appView, field, reference);
         }
       }
     }
diff --git a/src/main/java/com/android/tools/r8/optimize/MemberRebindingAnalysis.java b/src/main/java/com/android/tools/r8/optimize/MemberRebindingAnalysis.java
index 9d4121c..fb120e0 100644
--- a/src/main/java/com/android/tools/r8/optimize/MemberRebindingAnalysis.java
+++ b/src/main/java/com/android/tools/r8/optimize/MemberRebindingAnalysis.java
@@ -5,7 +5,8 @@
 
 import com.android.tools.r8.graph.AppView;
 import com.android.tools.r8.graph.DexClass;
-import com.android.tools.r8.graph.DexEncodedField;
+import com.android.tools.r8.graph.DexClassAndField;
+import com.android.tools.r8.graph.DexDefinitionSupplier;
 import com.android.tools.r8.graph.DexEncodedMethod;
 import com.android.tools.r8.graph.DexField;
 import com.android.tools.r8.graph.DexMethod;
@@ -48,7 +49,7 @@
   }
 
   private DexMethod validTargetFor(DexMethod target, DexMethod original) {
-    DexClass clazz = appView.definitionFor(target.holder);
+    DexClass clazz = appView.definitionFor(target.getHolderType());
     assert clazz != null;
     if (clazz.isProgramClass()) {
       return target;
@@ -56,36 +57,38 @@
     DexType newHolder;
     if (clazz.isInterface()) {
       newHolder =
-          firstLibraryClassForInterfaceTarget(target, original.holder, DexClass::lookupMethod);
+          firstLibraryClassForInterfaceTarget(
+              appView, target, original.holder, DexClass::lookupMethod);
     } else {
-      newHolder = firstLibraryClass(target.holder, original.holder);
+      newHolder = firstLibraryClass(appView, target.getHolderType(), original.getHolderType());
     }
     return newHolder == null
         ? original
         : appView.dexItemFactory().createMethod(newHolder, original.proto, original.name);
   }
 
-  private DexField validTargetFor(DexField target, DexField original,
-      BiFunction<DexClass, DexField, DexEncodedField> lookup) {
-    DexClass clazz = appView.definitionFor(target.holder);
-    assert clazz != null;
-    if (clazz.isProgramClass()) {
-      return target;
+  public static DexField validMemberRebindingTargetFor(
+      DexDefinitionSupplier definitions, DexClassAndField field, DexField original) {
+    DexClass clazz = field.getHolder();
+    if (field.isProgramField()) {
+      return field.getReference();
     }
-    DexType newHolder;
-    if (clazz.isInterface()) {
-      newHolder = firstLibraryClassForInterfaceTarget(target, original.holder, lookup);
-    } else {
-      newHolder = firstLibraryClass(target.holder, original.holder);
-    }
-    return newHolder == null
-        ? original
-        : appView.dexItemFactory().createField(newHolder, original.type, original.name);
+    DexType newHolder =
+        clazz.isInterface()
+            ? firstLibraryClassForInterfaceTarget(
+                definitions, field.getReference(), original.getHolderType(), DexClass::lookupField)
+            : firstLibraryClass(definitions, field.getHolderType(), original.getHolderType());
+    return newHolder != null
+        ? field.getReference().withHolder(newHolder, definitions.dexItemFactory())
+        : original;
   }
 
-  private <T> DexType firstLibraryClassForInterfaceTarget(T target, DexType current,
+  private static <T> DexType firstLibraryClassForInterfaceTarget(
+      DexDefinitionSupplier definitions,
+      T target,
+      DexType current,
       BiFunction<DexClass, T, ?> lookup) {
-    DexClass clazz = appView.definitionFor(current);
+    DexClass clazz = definitions.definitionFor(current);
     if (clazz == null) {
       return null;
     }
@@ -95,14 +98,16 @@
       return current;
     }
     if (clazz.superType != null) {
-      DexType matchingSuper = firstLibraryClassForInterfaceTarget(target, clazz.superType, lookup);
+      DexType matchingSuper =
+          firstLibraryClassForInterfaceTarget(definitions, target, clazz.superType, lookup);
       if (matchingSuper != null) {
         // Found in supertype, return first library class.
         return clazz.isNotProgramClass() ? current : matchingSuper;
       }
     }
     for (DexType iface : clazz.interfaces.values) {
-      DexType matchingIface = firstLibraryClassForInterfaceTarget(target, iface, lookup);
+      DexType matchingIface =
+          firstLibraryClassForInterfaceTarget(definitions, target, iface, lookup);
       if (matchingIface != null) {
         // Found in interface, return first library class.
         return clazz.isNotProgramClass() ? current : matchingIface;
@@ -111,11 +116,12 @@
     return null;
   }
 
-  private DexType firstLibraryClass(DexType top, DexType bottom) {
-    assert appView.definitionFor(top).isNotProgramClass();
-    DexClass searchClass = appView.definitionFor(bottom);
+  private static DexType firstLibraryClass(
+      DexDefinitionSupplier definitions, DexType top, DexType bottom) {
+    assert definitions.definitionFor(top).isNotProgramClass();
+    DexClass searchClass = definitions.definitionFor(bottom);
     while (searchClass.isProgramClass()) {
-      searchClass = appView.definitionFor(searchClass.superType);
+      searchClass = definitions.definitionFor(searchClass.superType);
     }
     return searchClass.type;
   }
diff --git a/src/test/java/com/android/tools/r8/memberrebinding/LibraryMemberRebindingTest.java b/src/test/java/com/android/tools/r8/memberrebinding/LibraryMemberRebindingTest.java
new file mode 100644
index 0000000..ed90756
--- /dev/null
+++ b/src/test/java/com/android/tools/r8/memberrebinding/LibraryMemberRebindingTest.java
@@ -0,0 +1,123 @@
+// Copyright (c) 2020, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.memberrebinding;
+
+import com.android.tools.r8.R8TestCompileResult;
+import com.android.tools.r8.TestBase;
+import com.android.tools.r8.TestParameters;
+import com.android.tools.r8.TestParametersCollection;
+import java.nio.file.Path;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class LibraryMemberRebindingTest extends TestBase {
+
+  private final TestParameters parameters;
+
+  @Parameters(name = "{0}")
+  public static TestParametersCollection data() {
+    return TestBase.getTestParameters().withAllRuntimesAndApiLevels().build();
+  }
+
+  public LibraryMemberRebindingTest(TestParameters parameters) {
+    this.parameters = parameters;
+  }
+
+  @Test
+  public void testWithEmptyA() throws Exception {
+    Path compileTimeLibrary = compileLibraryForAppCompilation(LibraryB.class, LibraryA.class);
+    test(compileTimeLibrary, compileTimeLibrary);
+  }
+
+  @Test
+  public void testWithEmptyB() throws Exception {
+    Path compileTimeLibrary = compileLibraryForAppCompilation(LibraryA.class, LibraryB.class);
+    test(compileTimeLibrary, compileTimeLibrary);
+  }
+
+  @Test
+  public void testWithEmptyBOnlyAtCompileTime() throws Exception {
+    Path compileTimeLibrary = compileLibraryForAppCompilation(LibraryA.class, LibraryB.class);
+    Path runtimeLibrary = compileLibraryForAppCompilation(LibraryB.class, LibraryA.class);
+    test(compileTimeLibrary, runtimeLibrary);
+  }
+
+  private void test(Path compileTimeLibrary, Path runtimeLibrary) throws Exception {
+    testForR8(parameters.getBackend())
+        .addProgramClasses(TestClass.class)
+        .addKeepMainRule(TestClass.class)
+        .addLibraryFiles(compileTimeLibrary)
+        .addDefaultRuntimeLibrary(parameters)
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .apply(compileResult -> configureRunClasspath(compileResult, runtimeLibrary))
+        .run(parameters.getRuntime(), TestClass.class)
+        .assertSuccessWithOutputLines("42");
+  }
+
+  private Path compileLibraryForAppCompilation(
+      Class<?> nonEmptyLibraryClass, Class<?> emptyLibraryClass) throws Exception {
+    return testForR8(Backend.CF)
+        .addProgramClasses(nonEmptyLibraryClass)
+        .addProgramClassFileData(
+            transformer(emptyLibraryClass)
+                .removeFields(
+                    (int access, String name, String descriptor, String signature, Object value) ->
+                        true)
+                .removeMethods(
+                    (int access,
+                        String name,
+                        String descriptor,
+                        String signature,
+                        String[] exceptions) -> !name.equals("<init>"))
+                .transform())
+        .addKeepAllClassesRule()
+        .setMinApi(parameters.getApiLevel())
+        .compile()
+        .writeToZip();
+  }
+
+  private void configureRunClasspath(R8TestCompileResult compileResult, Path library)
+      throws Exception {
+    if (parameters.isCfRuntime()) {
+      compileResult.addRunClasspathFiles(library);
+    } else {
+      compileResult.addRunClasspathFiles(
+          testForD8()
+              .addProgramFiles(library)
+              .setMinApi(parameters.getApiLevel())
+              .compile()
+              .writeToZip());
+    }
+  }
+
+  static class TestClass {
+
+    public static void main(String[] args) {
+      System.out.println(LibraryB.f + LibraryB.m());
+    }
+  }
+
+  static class LibraryA {
+
+    public static int f = 21;
+
+    public static int m() {
+      return 21;
+    }
+  }
+
+  static class LibraryB extends LibraryA {
+
+    public static int f = 21;
+
+    public static int m() {
+      return 21;
+    }
+  }
+}