Relax assertions in proto optimization
Change-Id: I1e3a70782f5d016a54a99bd00e786e595cb75cbd
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java
index 932e81d..812c0ee 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/ProtoReferences.java
@@ -196,11 +196,16 @@
return isDynamicMethodBridge(method.getReference());
}
- @SuppressWarnings("ReferenceEquality")
public boolean isFindLiteExtensionByNumberMethod(DexMethod method) {
- return method.proto == findLiteExtensionByNumberProto
- && method.name.startsWith(findLiteExtensionByNumberName)
- && method.holder != extensionRegistryLiteType;
+ return method.getProto().isIdenticalTo(findLiteExtensionByNumberProto)
+ && method.getName().startsWith(findLiteExtensionByNumberName)
+ && method.getHolderType().isNotIdenticalTo(extensionRegistryLiteType);
+ }
+
+ public boolean isFindLiteExtensionByNumberBridgeMethod(DexMethod method) {
+ return method.getProto().isIdenticalTo(findLiteExtensionByNumberProto)
+ && method.getName().startsWith(findLiteExtensionByNumberName)
+ && method.getHolderType().isIdenticalTo(extensionRegistryLiteType);
}
public boolean isFindLiteExtensionByNumberMethod(ProgramMethod method) {
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java
index c81b099..c9a6021 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/proto/schema/ProtoEnqueuerExtension.java
@@ -5,11 +5,11 @@
package com.android.tools.r8.ir.analysis.proto.schema;
import static com.android.tools.r8.graph.DexProgramClass.asProgramClassOrNull;
+import static com.android.tools.r8.utils.MapUtils.ignoreKey;
import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
import com.android.tools.r8.graph.AppView;
import com.android.tools.r8.graph.DexClassAndField;
-import com.android.tools.r8.graph.DexEncodedField;
import com.android.tools.r8.graph.DexEncodedMethod;
import com.android.tools.r8.graph.DexField;
import com.android.tools.r8.graph.DexMethod;
@@ -45,6 +45,8 @@
import com.android.tools.r8.utils.BitUtils;
import com.android.tools.r8.utils.OptionalBool;
import com.android.tools.r8.utils.Timing;
+import com.android.tools.r8.utils.collections.ProgramFieldMap;
+import com.android.tools.r8.utils.collections.ProgramFieldSet;
import com.android.tools.r8.utils.collections.ProgramMethodSet;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
@@ -52,7 +54,6 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.function.Consumer;
import java.util.function.Predicate;
// TODO(b/112437944): Handle cycles in the graph + add a test that fails with the current
@@ -224,7 +225,7 @@
* inlining.
*/
private void populateExtensionGraph(Enqueuer enqueuer) {
- collectExtensionFields()
+ collectProgramExtensionFields()
.forEach(
(clazz, extensionFields) -> {
ProgramMethod clinit = clazz.getProgramClassInitializer();
@@ -234,9 +235,9 @@
}
IRCode code = clinit.buildIR(appView, MethodConversionOptions.nonConverting());
- Map<DexEncodedField, StaticPut> uniqueStaticPuts =
+ ProgramFieldMap<StaticPut> uniqueStaticPuts =
IRCodeUtils.findUniqueStaticPuts(appView, code, extensionFields);
- for (DexEncodedField extensionField : extensionFields) {
+ for (ProgramField extensionField : extensionFields) {
StaticPut staticPut = uniqueStaticPuts.get(extensionField);
if (staticPut == null) {
// Could happen after we have optimized the code.
@@ -252,37 +253,33 @@
}
/**
- * Finds the extension fields referenced in the methods in {@link
+ * Finds the program extension fields referenced in the methods in {@link
* #findLiteExtensionByNumberMethods}.
*/
- private Map<DexProgramClass, Set<DexEncodedField>> collectExtensionFields() {
- Map<DexProgramClass, Set<DexEncodedField>> extensionFieldsByClass = new IdentityHashMap<>();
+ private Map<DexProgramClass, ProgramFieldSet> collectProgramExtensionFields() {
+ Map<DexProgramClass, ProgramFieldSet> extensionFieldsByClass = new IdentityHashMap<>();
for (ProgramMethod findLiteExtensionByNumberMethod : findLiteExtensionByNumberMethods) {
IRCode code =
findLiteExtensionByNumberMethod.buildIR(appView, MethodConversionOptions.nonConverting());
Set<Phi> seenPhis = Sets.newIdentityHashSet();
for (BasicBlock block : code.blocks(BasicBlock::isReturnBlock)) {
Value returnValue = block.exit().asReturn().returnValue();
- collectExtensionFieldsFromValue(
- returnValue,
- seenPhis,
- field ->
- extensionFieldsByClass
- .computeIfAbsent(field.getHolder(), ignore -> Sets.newIdentityHashSet())
- .add(field.getDefinition()));
+ collectExtensionFieldsFromValue(returnValue, seenPhis, extensionFieldsByClass);
}
}
return extensionFieldsByClass;
}
private void collectExtensionFieldsFromValue(
- Value returnValue, Set<Phi> seenPhis, Consumer<ProgramField> consumer) {
+ Value returnValue,
+ Set<Phi> seenPhis,
+ Map<DexProgramClass, ProgramFieldSet> extensionFieldsByClass) {
Value root = returnValue.getAliasedValue();
if (root.isPhi()) {
Phi phi = root.asPhi();
if (seenPhis.add(phi)) {
for (Value operand : phi.getOperands()) {
- collectExtensionFieldsFromValue(operand, seenPhis, consumer);
+ collectExtensionFieldsFromValue(operand, seenPhis, extensionFieldsByClass);
}
}
return;
@@ -292,22 +289,34 @@
return;
}
- Instruction definition = root.definition;
+ Instruction definition = root.getDefinition();
if (definition.isStaticGet()) {
StaticGet staticGet = definition.asStaticGet();
DexClassAndField field =
appView.appInfo().resolveField(staticGet.getField()).getResolutionPair();
- if (field == null || !field.isProgramField()) {
- assert false;
+ if (field == null) {
+ assert staticGet.getField().getType().isIdenticalTo(references.generatedExtensionType);
return;
}
- consumer.accept(field.asProgramField());
- return;
+ if (field.isProgramField()) {
+ ProgramField programField = field.asProgramField();
+ extensionFieldsByClass
+ .computeIfAbsent(programField.getHolder(), ignoreKey(ProgramFieldSet::create))
+ .add(programField);
+ } else {
+ assert false;
+ }
+ } else {
+ assert verifyIsCallToFindLiteExtensionByNumberMethod(definition);
}
+ }
- assert definition.isInvokeMethod()
- && references.isFindLiteExtensionByNumberMethod(
- definition.asInvokeMethod().getInvokedMethod());
+ private boolean verifyIsCallToFindLiteExtensionByNumberMethod(Instruction instruction) {
+ assert instruction.isInvokeMethod();
+ DexMethod invokedMethod = instruction.asInvokeMethod().getInvokedMethod();
+ assert references.isFindLiteExtensionByNumberMethod(invokedMethod)
+ || references.isFindLiteExtensionByNumberBridgeMethod(invokedMethod);
+ return true;
}
/**
diff --git a/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java b/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java
index 7caf4ec..e918ad4 100644
--- a/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java
+++ b/src/main/java/com/android/tools/r8/ir/code/IRCodeUtils.java
@@ -6,14 +6,14 @@
import com.android.tools.r8.graph.AppInfoWithClassHierarchy;
import com.android.tools.r8.graph.AppView;
-import com.android.tools.r8.graph.DexEncodedField;
import com.android.tools.r8.graph.DexItemFactory;
+import com.android.tools.r8.graph.ProgramField;
import com.android.tools.r8.utils.DequeUtils;
+import com.android.tools.r8.utils.collections.ProgramFieldMap;
+import com.android.tools.r8.utils.collections.ProgramFieldSet;
import com.google.common.collect.Sets;
import java.util.ArrayDeque;
import java.util.Deque;
-import java.util.IdentityHashMap;
-import java.util.Map;
import java.util.Set;
public class IRCodeUtils {
@@ -44,15 +44,12 @@
* Finds the single assignment to the fields in {@param fields} in {@param code}. Note that this
* does not guarantee that the assignments found dominate all the normal exits.
*/
- public static Map<DexEncodedField, StaticPut> findUniqueStaticPuts(
- AppView<? extends AppInfoWithClassHierarchy> appView,
- IRCode code,
- Set<DexEncodedField> fields) {
- Set<DexEncodedField> writtenMoreThanOnce = Sets.newIdentityHashSet();
- Map<DexEncodedField, StaticPut> uniqueStaticPuts = new IdentityHashMap<>();
+ public static ProgramFieldMap<StaticPut> findUniqueStaticPuts(
+ AppView<? extends AppInfoWithClassHierarchy> appView, IRCode code, ProgramFieldSet fields) {
+ ProgramFieldSet writtenMoreThanOnce = ProgramFieldSet.create();
+ ProgramFieldMap<StaticPut> uniqueStaticPuts = ProgramFieldMap.create();
for (StaticPut staticPut : code.<StaticPut>instructions(Instruction::isStaticPut)) {
- DexEncodedField field =
- appView.appInfo().resolveField(staticPut.getField()).getResolvedField();
+ ProgramField field = appView.appInfo().resolveField(staticPut.getField()).getProgramField();
if (field == null || !fields.contains(field) || writtenMoreThanOnce.contains(field)) {
continue;
}