blob: 63475b7c82ee2df54110bc1b66be6348387e3b0d [file] [log] [blame]
// Copyright (c) 2017, the R8 project authors. Please see the AUTHORS file
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
package com.android.tools.r8.shaking;
import com.android.tools.r8.errors.CompilationError;
import com.android.tools.r8.graph.DexApplication;
import com.android.tools.r8.graph.DexClass;
import com.android.tools.r8.graph.DexEncodedField;
import com.android.tools.r8.graph.DexEncodedMethod;
import com.android.tools.r8.graph.DexField;
import com.android.tools.r8.graph.DexMethod;
import com.android.tools.r8.graph.DexProgramClass;
import com.android.tools.r8.graph.DexProto;
import com.android.tools.r8.graph.DexString;
import com.android.tools.r8.graph.DexType;
import com.android.tools.r8.graph.DexTypeList;
import com.android.tools.r8.graph.GraphLense;
import com.android.tools.r8.graph.GraphLense.Builder;
import com.android.tools.r8.graph.KeyedDexItem;
import com.android.tools.r8.graph.PresortedComparable;
import com.android.tools.r8.logging.Log;
import com.android.tools.r8.optimize.InvokeSingleTargetExtractor;
import com.android.tools.r8.shaking.Enqueuer.AppInfoWithLiveness;
import com.android.tools.r8.utils.FieldSignatureEquivalence;
import com.android.tools.r8.utils.MethodSignatureEquivalence;
import com.android.tools.r8.utils.Timing;
import com.google.common.base.Equivalence;
import com.google.common.base.Equivalence.Wrapper;
import com.google.common.collect.Iterators;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.Reference2IntMap;
import it.unimi.dsi.fastutil.objects.Reference2IntOpenHashMap;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
/**
* Merges supertypes with a single implementation into their single subtype.
* <p>
* A common use-case for this is to merge an interface into its single implementation.
* <p>
* The class merger only fixes the structure of the graph but leaves the actual instructions
* untouched. Fixup of instructions is deferred via a {@link GraphLense} to the IR building phase.
*/
public class SimpleClassMerger {
private final DexApplication application;
private final AppInfoWithLiveness appInfo;
private final GraphLense graphLense;
private final GraphLense.Builder renamedMembersLense = GraphLense.builder();
private final Map<DexType, DexType> mergedClasses = new IdentityHashMap<>();
private final Timing timing;
private Collection<DexMethod> invokes;
private int numberOfMerges = 0;
/**
 * Creates a merger operating on the given application.
 *
 * @param application the application whose classes are inspected and merged
 * @param appInfo liveness information consulted for pinning, instantiation and definitions
 * @param graphLense the lense handed to the merge phase when {@link #run()} is invoked
 * @param timing receives the begin/end events of the merge and fixup phases
 */
public SimpleClassMerger(
    DexApplication application,
    AppInfoWithLiveness appInfo,
    GraphLense graphLense,
    Timing timing) {
  this.timing = timing;
  this.graphLense = graphLense;
  this.appInfo = appInfo;
  this.application = application;
}
/**
 * Decides whether {@code clazz} may be merged into a subtype.
 *
 * <p>A class qualifies when it is not a library class, is neither instantiated nor pinned
 * according to {@code appInfo}, and its type reports a single subtype.
 *
 * @param clazz the program class to inspect
 * @return {@code true} if all of the above conditions hold
 */
private boolean isMergeCandidate(DexProgramClass clazz) {
  if (clazz.isLibraryClass()) {
    return false;
  }
  DexType candidateType = clazz.type;
  if (appInfo.instantiatedTypes.contains(candidateType)) {
    // Instantiated classes have to stay around.
    return false;
  }
  if (appInfo.isPinned(candidateType)) {
    // We were asked to keep this class.
    return false;
  }
  return candidateType.getSingleSubtype() != null;
}
/**
 * Adds {@code method}, wrapped with {@code equivalence}, to {@code set} if the holder of the
 * method has a definition that is a program class.
 *
 * @param set the set receiving the wrapped method
 * @param method the method to consider
 * @param equivalence the equivalence used to wrap the method before insertion
 */
private void addProgramMethods(Set<Wrapper<DexMethod>> set, DexMethod method,
    Equivalence<DexMethod> equivalence) {
  DexClass holderClass = appInfo.definitionFor(method.holder);
  if (holderClass == null || !holderClass.isProgramClass()) {
    // Unknown or non-program holder: nothing to record.
    return;
  }
  set.add(equivalence.wrap(method));
}
/**
 * Returns the targeted methods that are relevant for collision detection, computing and caching
 * them in {@code invokes} on first use.
 *
 * <p>Only methods whose holder is a program class are considered, deduplicated by method
 * signature equivalence. Of those, only methods that mention a program class in their
 * parameters or return type are retained.
 *
 * @return the cached collection of relevant methods
 */
private Collection<DexMethod> getInvokes() {
  if (invokes != null) {
    return invokes;
  }
  Equivalence<DexMethod> signatureEquivalence = MethodSignatureEquivalence.get();
  Set<Wrapper<DexMethod>> uniqueSignatures = new HashSet<>();
  // Methods defined on library classes are known not to have program classes in their signature.
  appInfo.targetedMethods.forEach(
      targeted -> addProgramMethods(uniqueSignatures, targeted, signatureEquivalence));
  // Methods that only use non-program types in their signature cannot conflict, so drop them.
  invokes = uniqueSignatures.stream()
      .map(Wrapper::get)
      .filter(candidate -> removeNonProgram(candidate))
      .collect(Collectors.toList());
  return invokes;
}
/**
 * Tests whether {@code type}, or its base type if it is an array type, is a class type whose
 * definition is a program class.
 *
 * @param type the type to inspect
 * @return {@code true} if a program class definition is found, {@code false} otherwise
 */
private boolean isProgramClass(DexType type) {
  DexType elementType = type.isArrayType() ? type.toBaseType(appInfo.dexItemFactory) : type;
  if (!elementType.isClassType()) {
    return false;
  }
  DexClass definition = appInfo.definitionFor(elementType);
  return definition != null && definition.isProgramClass();
}
/**
 * Stream filter predicate: returns {@code true} (keep) when the signature of {@code dexMethod}
 * mentions a program class in one of its parameters or in its return type.
 *
 * @param dexMethod the method whose prototype is inspected
 * @return {@code true} if any parameter type or the return type is a program class
 */
private boolean removeNonProgram(DexMethod dexMethod) {
  DexType[] parameterTypes = dexMethod.proto.parameters.values;
  for (int i = 0; i < parameterTypes.length; i++) {
    if (isProgramClass(parameterTypes[i])) {
      return true;
    }
  }
  // No parameter matched, so the return type decides.
  return isProgramClass(dexMethod.proto.returnType);
}
/**
 * Runs the class merger: first merges classes, then fixes up type references via a
 * {@code TreeFixer}. Each phase is bracketed by a {@code timing} begin/end pair.
 *
 * @return the lense produced by the fixup phase
 */
public GraphLense run() {
  // Phase 1: structural merging of classes.
  timing.begin("merge");
  GraphLense afterMerge = mergeClasses(graphLense);
  timing.end();

  // Phase 2: rewrite references to merged types.
  timing.begin("fixup");
  GraphLense afterFixup = new TreeFixer().fixupTypeReferences(afterMerge);
  timing.end();

  return afterFixup;
}
/**
 * Attempts to merge every merge candidate among the application's classes into its single
 * subtype.
 *
 * <p>A candidate is skipped when no definition is found for its single subtype, when the
 * subtype is pinned, when the subtype was itself already merged away, when both classes have a
 * class initializer, or when the {@code CollisionDetector} reports a potential signature
 * collision. Otherwise a {@code ClassMerger} performs the merge, which may still abort.
 *
 * @param graphLense the lense passed to the builder of {@code renamedMembersLense} together
 *     with the item factory
 * @return the lense built from the member renamings recorded during merging
 */
private GraphLense mergeClasses(GraphLense graphLense) {
  for (DexProgramClass clazz : application.classes()) {
    if (!isMergeCandidate(clazz)) {
      continue;
    }
    DexClass targetClass = appInfo.definitionFor(clazz.type.getSingleSubtype());
    if (targetClass == null) {
      // There is no definition for the single subtype, so there is nothing to merge into.
      continue;
    }
    if (appInfo.isPinned(targetClass.type)) {
      // We have to keep the target class intact, so we cannot merge it.
      continue;
    }
    if (mergedClasses.containsKey(targetClass.type)) {
      // TODO(herhut): Traverse top-down.
      continue;
    }
    if (clazz.hasClassInitializer() && targetClass.hasClassInitializer()) {
      // TODO(herhut): Handle class initializers.
      if (Log.ENABLED) {
        Log.info(getClass(), "Cannot merge %s into %s due to static initializers.",
            clazz.toSourceString(), targetClass.toSourceString());
      }
      continue;
    }
    // Guard against the case where we have two methods that may get the same signature
    // if we replace types. This is rare, so we approximate and err on the safe side here.
    if (new CollisionDetector(clazz.type, targetClass.type, getInvokes(), mergedClasses)
        .mayCollide()) {
      if (Log.ENABLED) {
        Log.info(getClass(), "Cannot merge %s into %s due to conflict.", clazz.toSourceString(),
            targetClass.toSourceString());
      }
      continue;
    }
    boolean merged = new ClassMerger(clazz, targetClass).merge();
    if (merged) {
      // Count independently of logging so the counter always reflects the actual merges.
      numberOfMerges++;
    }
    if (Log.ENABLED) {
      if (merged) {
        Log.info(getClass(), "Merged class %s into %s.", clazz.toSourceString(),
            targetClass.toSourceString());
      } else {
        Log.info(getClass(), "Aborted merge for class %s into %s.",
            clazz.toSourceString(), targetClass.toSourceString());
      }
    }
  }
  if (Log.ENABLED) {
    Log.debug(getClass(), "Merged %d classes.", numberOfMerges);
  }
  return renamedMembersLense.build(application.dexItemFactory, graphLense);
}
private class ClassMerger {
private static final String CONSTRUCTOR_NAME = "constructor";
private final DexClass source;
private final DexClass target;
private final Map<DexEncodedMethod, DexEncodedMethod> deferredRenamings = new HashMap<>();
private boolean abortMerge = false;
/**
 * Creates a merger for a single pair of classes.
 *
 * @param source the class whose members are moved
 * @param target the class receiving the members of {@code source}
 */
private ClassMerger(DexClass source, DexClass target) {
  this.target = target;
  this.source = source;
}
/**
 * Merges {@code source} into {@code target}.
 *
 * <p>The merged member collections are computed first; only if no conflict handler requested an
 * abort are the two classes destructively updated. On success the pair is recorded in
 * {@code mergedClasses} and the renamings collected in {@code deferredRenamings} are registered
 * with {@code renamedMembersLense}.
 *
 * @return {@code true} if the classes were merged, {@code false} if the merge was not attempted
 *     (enclosing method or inner-class attributes present) or was aborted because of a method
 *     conflict
 */
public boolean merge() {
if (source.getEnclosingMethod() != null || !source.getInnerClasses().isEmpty()
|| target.getEnclosingMethod() != null || !target.getInnerClasses().isEmpty()) {
// TODO(herhut): Consider supporting merging of inner-class attributes.
return false;
}
// Merge the class [clazz] into [targetClass] by adding all methods to
// targetClass that are not currently contained.
// Step 1: Merge methods
// The set of existing signatures covers both direct and virtual methods of the target, so a
// source method conflicts with a target method of either kind.
Set<Wrapper<DexMethod>> existingMethods = new HashSet<>();
addAll(existingMethods, target.directMethods(), MethodSignatureEquivalence.get());
addAll(existingMethods, target.virtualMethods(), MethodSignatureEquivalence.get());
// Source direct methods are passed through renameConstructors as they are iterated;
// remaining signature conflicts are resolved by renameMethod.
Collection<DexEncodedMethod> mergedDirectMethods = mergeItems(
Iterators.transform(Iterators.forArray(source.directMethods()), this::renameConstructors),
target.directMethods(),
MethodSignatureEquivalence.get(),
existingMethods,
this::renameMethod
);
Iterator<DexEncodedMethod> methods = Iterators.forArray(source.virtualMethods());
if (source.accessFlags.isInterface()) {
// If merging an interface, only merge methods that are not otherwise defined in the
// target class.
// Filtered-out methods show up as null elements, which addNonShadowed skips.
methods = Iterators.transform(methods, this::filterShadowedInterfaceMethods);
}
Collection<DexEncodedMethod> mergedVirtualMethods = mergeItems(
methods,
target.virtualMethods(),
MethodSignatureEquivalence.get(),
existingMethods,
this::abortOnNonAbstract);
// abortOnNonAbstract may have flagged an unresolvable conflict. Bail out before any of the
// destructive class updates below are performed.
if (abortMerge) {
return false;
}
// Step 2: Merge fields
// As for methods, conflicts are detected against both instance and static fields of the target
// and resolved by renaming the incoming field.
Set<Wrapper<DexField>> existingFields = new HashSet<>();
addAll(existingFields, target.instanceFields(), FieldSignatureEquivalence.get());
addAll(existingFields, target.staticFields(), FieldSignatureEquivalence.get());
Collection<DexEncodedField> mergedStaticFields = mergeItems(
Iterators.forArray(source.staticFields()),
target.staticFields(),
FieldSignatureEquivalence.get(),
existingFields,
this::renameField);
Collection<DexEncodedField> mergedInstanceFields = mergeItems(
Iterators.forArray(source.instanceFields()),
target.instanceFields(),
FieldSignatureEquivalence.get(),
existingFields,
this::renameField);
// Step 3: Merge interfaces
// Target interfaces come first, followed by source interfaces not already present.
Set<DexType> interfaces = mergeArrays(target.interfaces.values, source.interfaces.values);
// Now destructively update the class.
// Step 1: Update supertype or fix interfaces.
if (source.isInterface()) {
// The target no longer implements the interface that is being merged into it.
interfaces.remove(source.type);
} else {
assert !target.isInterface();
// The target takes the place of the source in the class hierarchy.
target.superType = source.superType;
}
target.interfaces = interfaces.isEmpty()
? DexTypeList.empty()
: new DexTypeList(interfaces.toArray(new DexType[interfaces.size()]));
// Step 2: replace fields and methods.
target.setDirectMethods(mergedDirectMethods
.toArray(new DexEncodedMethod[mergedDirectMethods.size()]));
target.setVirtualMethods(mergedVirtualMethods
.toArray(new DexEncodedMethod[mergedVirtualMethods.size()]));
target.setStaticFields(mergedStaticFields
.toArray(new DexEncodedField[mergedStaticFields.size()]));
target.setInstanceFields(mergedInstanceFields
.toArray(new DexEncodedField[mergedInstanceFields.size()]));
// Step 3: Unlink old class to ease tree shaking.
source.superType = application.dexItemFactory.objectType;
source.setDirectMethods(null);
source.setVirtualMethods(null);
source.setInstanceFields(null);
source.setStaticFields(null);
source.interfaces = DexTypeList.empty();
// Step 4: Record merging.
mergedClasses.put(source.type, target.type);
// Step 5: Make deferred renamings final.
// These were only collected so far; they are published now that the merge has succeeded.
deferredRenamings.forEach((from, to) -> renamedMembersLense.map(from.method, to.method));
return true;
}
/**
 * Preprocessing step for virtual methods of an interface that is being merged.
 *
 * <p>Resolves {@code m} on the target type. If resolution yields a different method, the
 * interface method is dropped and a deferred renaming from {@code m} to the resolved method is
 * recorded. Otherwise the method is kept.
 *
 * @param m a virtual method of the source interface
 * @return {@code m} if it is to be merged into the target, or {@code null} if it is dropped
 */
private DexEncodedMethod filterShadowedInterfaceMethods(DexEncodedMethod m) {
  DexEncodedMethod resolved = appInfo.resolveMethod(target.type, m.method).asSingleTarget();
  assert resolved != null;
  if (resolved == m) {
    // We will keep the method, so the class better be abstract.
    assert target.accessFlags.isAbstract();
    return m;
  }
  // We will drop a method here, so record it as a potential renaming.
  deferredRenamings.put(m, resolved);
  return null;
}
/**
 * Wraps the key of every item with {@code equivalence} and adds it to {@code collection}.
 *
 * @param collection the collection receiving the wrapped keys
 * @param items the items whose keys are added
 * @param equivalence the equivalence used for wrapping
 */
private <T extends KeyedDexItem<S>, S extends PresortedComparable<S>> void addAll(
    Collection<Wrapper<S>> collection, T[] items, Equivalence<S> equivalence) {
  for (int i = 0; i < items.length; i++) {
    S key = items[i].getKey();
    collection.add(equivalence.wrap(key));
  }
}
/**
 * Builds the union of two arrays as an insertion-ordered set: elements of {@code one} first,
 * followed by elements of {@code other} that were not already present.
 *
 * @param one the first array
 * @param other the second array
 * @return a new mutable set holding the elements of both arrays
 */
private <T> Set<T> mergeArrays(T[] one, T[] other) {
  Set<T> union = new LinkedHashSet<>();
  for (T element : one) {
    union.add(element);
  }
  for (T element : other) {
    union.add(element);
  }
  return union;
}
/**
 * Merges the items produced by {@code fromItems} with {@code toItems}.
 *
 * <p>All target items are entered first, keyed by their wrapped key. The incoming items are then
 * added via {@code addNonShadowed}, which consults {@code onConflict} for keys contained in
 * {@code existing}.
 *
 * @param fromItems the incoming items; {@code null} elements are skipped
 * @param toItems the items already present in the target
 * @param equivalence the equivalence used to wrap item keys
 * @param existing the wrapped keys that count as conflicts
 * @param onConflict resolves a conflict between the mapped target item and the incoming item
 * @return the values view of the map holding the merged items
 */
private <T extends PresortedComparable<T>, S extends KeyedDexItem<T>> Collection<S> mergeItems(
    Iterator<S> fromItems,
    S[] toItems,
    Equivalence<T> equivalence,
    Set<Wrapper<T>> existing,
    BiFunction<S, S, S> onConflict) {
  HashMap<Wrapper<T>, S> bySignature = new HashMap<>();
  // Items of the target class go in first and are taken as they are.
  for (int i = 0; i < toItems.length; i++) {
    S targetItem = toItems[i];
    bySignature.put(equivalence.wrap(targetItem.getKey()), targetItem);
  }
  // Then the incoming items, with shadowing resolved.
  addNonShadowed(fromItems, bySignature, equivalence, existing, onConflict);
  return bySignature.values();
}
/**
 * Adds the remaining elements of {@code items} to {@code map}.
 *
 * <p>An item whose wrapped key is contained in {@code existing} is first handed to
 * {@code onConflict}, together with whatever {@code map} currently holds for that key; the
 * result of the handler is stored under its own wrapped key. All other items are stored as is.
 *
 * @param items the items to add; {@code null} elements are skipped
 * @param map the map receiving the items
 * @param equivalence the equivalence used to wrap item keys
 * @param existing the wrapped keys that count as conflicts
 * @param onConflict the conflict handler
 */
private <T extends PresortedComparable<T>, S extends KeyedDexItem<T>> void addNonShadowed(
    Iterator<S> items,
    HashMap<Wrapper<T>, S> map,
    Equivalence<T> equivalence,
    Set<Wrapper<T>> existing,
    BiFunction<S, S, S> onConflict) {
  while (items.hasNext()) {
    S candidate = items.next();
    if (candidate == null) {
      // A preprocessing step filtered this item out.
      continue;
    }
    Wrapper<T> key = equivalence.wrap(candidate.getKey());
    S toInsert = candidate;
    if (existing.contains(key)) {
      toInsert = onConflict.apply(map.get(key), candidate);
      key = equivalence.wrap(toInsert.getKey());
    }
    map.put(key, toInsert);
  }
}
/**
 * Creates the name for a renamed member: {@code nameString}, a {@code '$'}, and the source
 * string of {@code holder} with every {@code '.'} replaced by {@code '$'}.
 *
 * @param nameString the original member name
 * @param holder the type whose source string is appended
 * @return the string created by the item factory for the combined name
 */
private DexString makeMergedName(String nameString, DexType holder) {
  String holderSuffix = holder.toSourceString().replace('.', '$');
  String mergedName = nameString + "$" + holderSuffix;
  return application.dexItemFactory.createString(mergedName);
}
/**
 * Conflict handler for virtual methods.
 *
 * <ul>
 * <li>No mapped target method: the merge is aborted and {@code method} is returned.
 * <li>{@code method} is abstract: a deferred renaming to {@code existing} is recorded and
 *     {@code existing} wins.
 * <li>{@code existing} is a bridge: {@code method} wins; the merge is aborted unless the target
 *     extracted from the bridge's code is {@code method}.
 * <li>Anything else: the merge is aborted and {@code existing} is returned.
 * </ul>
 *
 * @param existing the method currently mapped for the conflicting signature, possibly
 *     {@code null}
 * @param method the incoming method of the source class
 * @return the method to store for the signature
 */
private DexEncodedMethod abortOnNonAbstract(DexEncodedMethod existing,
    DexEncodedMethod method) {
  if (existing == null) {
    // This is a conflict between a static and virtual method. Abort.
    abortMerge = true;
    return method;
  }
  if (method.accessFlags.isAbstract()) {
    // The abstract method disappears; record the renaming so that calls to it get forwarded
    // to the existing method.
    deferredRenamings.put(method, existing);
    return existing;
  }
  if (!existing.accessFlags.isBridge()) {
    abortMerge = true;
    return existing;
  }
  // A bridge may be overridden if it would bridge to the method that is merged in.
  InvokeSingleTargetExtractor extractor = new InvokeSingleTargetExtractor();
  existing.getCode().registerInstructionsReferences(extractor);
  if (extractor.getTarget() != method.method) {
    abortMerge = true;
  }
  return method;
}
/**
 * Preprocessing step for direct methods of the source class: turns an instance initializer into
 * a renamed private method that is marked for force inlining, and records the renaming as
 * deferred. Other methods are returned unchanged.
 *
 * @param method a direct method of the source class
 * @return the renamed method for instance initializers, otherwise {@code method} itself
 */
private DexEncodedMethod renameConstructors(DexEncodedMethod method) {
  if (method.isInstanceInitializer()) {
    DexString mergedName = makeMergedName(CONSTRUCTOR_NAME, method.method.holder);
    DexEncodedMethod renamed = method.toRenamedMethod(mergedName, application.dexItemFactory);
    renamed.markForceInline();
    deferredRenamings.put(method, renamed);
    // Renamed constructors turn into ordinary private functions. They can be private, as
    // they are only referenced from their direct subclass, which they were merged into.
    renamed.accessFlags.unsetConstructor();
    renamed.accessFlags.unsetPublic();
    renamed.accessFlags.unsetProtected();
    renamed.accessFlags.setPrivate();
    return renamed;
  }
  // Only instance initializers are renamed.
  return method;
}
/**
 * Conflict handler for direct methods: resolves a signature clash by renaming the incoming
 * method using {@code makeMergedName}.
 *
 * <p>The renaming is recorded in {@code deferredRenamings} rather than written to the shared
 * {@code renamedMembersLense} right away. This handler runs before {@code merge()} checks
 * whether the merge has to be aborted; registering the mapping eagerly would leave a mapping to
 * a method that never exists when the merge is aborted afterwards. {@code merge()} publishes
 * the deferred renamings once the merge has succeeded.
 *
 * @param existing the method currently mapped for the conflicting signature (unused)
 * @param method the incoming direct method of the source class
 * @return the renamed copy of {@code method}
 */
private DexEncodedMethod renameMethod(DexEncodedMethod existing, DexEncodedMethod method) {
  // We cannot handle renaming static initializers yet and constructors should have been
  // renamed already.
  assert !method.accessFlags.isConstructor();
  DexType holder = method.method.holder;
  String name = method.method.name.toSourceString();
  DexEncodedMethod result = method
      .toRenamedMethod(makeMergedName(name, holder), application.dexItemFactory);
  // Only becomes visible in the lense if the whole merge goes through.
  deferredRenamings.put(method, result);
  return result;
}
/**
 * Conflict handler for fields: renames the incoming field using {@code makeMergedName} and
 * registers the mapping from the old to the new field with {@code renamedMembersLense}.
 *
 * @param existing the field currently mapped for the conflicting signature (unused)
 * @param field the incoming field of the source class
 * @return the renamed copy of {@code field}
 */
private DexEncodedField renameField(DexEncodedField existing, DexEncodedField field) {
  DexField original = field.field;
  DexString mergedName = makeMergedName(original.name.toSourceString(), original.clazz);
  DexEncodedField renamed = field.toRenamedField(mergedName, application.dexItemFactory);
  renamedMembersLense.map(original, renamed.field);
  return renamed;
}
}
private class TreeFixer {
private final Builder lense = GraphLense.builder();
Map<DexProto, DexProto> protoFixupCache = new IdentityHashMap<>();
/**
 * Replaces merged class types in the methods and fields of all classes known to
 * {@code appInfo}, removes duplicate virtual methods, and records a type mapping for every
 * merged class.
 *
 * @param graphLense the lense passed to the lense builder together with the item factory
 * @return the lense built from the recorded method, field and type mappings
 */
private GraphLense fixupTypeReferences(GraphLense graphLense) {
  // Globally substitute merged class types in protos and holders.
  for (DexProgramClass programClass : appInfo.classes()) {
    DexEncodedMethod[] directMethods = substituteTypesIn(programClass.directMethods());
    programClass.setDirectMethods(directMethods);

    DexEncodedMethod[] virtualMethods = substituteTypesIn(programClass.virtualMethods());
    programClass.setVirtualMethods(virtualMethods);
    DexEncodedMethod[] uniqueVirtualMethods = removeDupes(programClass.virtualMethods());
    programClass.setVirtualMethods(uniqueVirtualMethods);

    DexEncodedField[] staticFields = substituteTypesIn(programClass.staticFields());
    programClass.setStaticFields(staticFields);
    DexEncodedField[] instanceFields = substituteTypesIn(programClass.instanceFields());
    programClass.setInstanceFields(instanceFields);
  }
  // Record type renamings so instanceof and checkcast checks are also fixed.
  for (DexType mergedType : mergedClasses.keySet()) {
    lense.map(mergedType, fixupType(mergedType));
  }
  return lense.build(application.dexItemFactory, graphLense);
}
/**
 * Removes methods that ended up with the same method reference.
 *
 * <p>Of two methods with the same reference, a non-bridge method is preferred over a bridge
 * method. Two non-bridge methods with the same reference are an error.
 *
 * @param methods the methods to filter, possibly {@code null}
 * @return {@code null} for {@code null} input, the input array itself if nothing was removed,
 *     otherwise a new array with the surviving methods
 * @throws CompilationError if two non-bridge methods share the same reference
 */
private DexEncodedMethod[] removeDupes(DexEncodedMethod[] methods) {
  if (methods == null) {
    return null;
  }
  Map<DexMethod, DexEncodedMethod> filtered = new IdentityHashMap<>();
  for (DexEncodedMethod method : methods) {
    DexEncodedMethod previous = filtered.put(method.method, method);
    if (previous == null || previous.accessFlags.isBridge()) {
      // Either no duplicate, or the earlier bridge is simply replaced by the later method.
      continue;
    }
    if (!method.accessFlags.isBridge()) {
      throw new CompilationError(
          "Class merging produced invalid result on: " + previous.toSourceString());
    }
    // The earlier non-bridge method wins over the later bridge; restore it.
    filtered.put(previous.method, previous);
  }
  return filtered.size() == methods.length
      ? methods
      : filtered.values().toArray(new DexEncodedMethod[filtered.size()]);
}
/**
 * Rewrites, in place, every method whose holder or prototype mentions a merged class, and
 * records a mapping from the old to the new method reference in {@code lense}.
 *
 * @param methods the methods to rewrite, possibly {@code null}
 * @return the same array instance, or {@code null} for {@code null} input
 */
private DexEncodedMethod[] substituteTypesIn(DexEncodedMethod[] methods) {
  if (methods == null) {
    return null;
  }
  for (int index = 0; index < methods.length; index++) {
    DexEncodedMethod current = methods[index];
    DexMethod oldReference = current.method;
    DexProto fixedProto = getUpdatedProto(oldReference.proto);
    DexType fixedHolder = fixupType(oldReference.holder);
    DexMethod newReference =
        application.dexItemFactory.createMethod(fixedHolder, fixedProto, oldReference.name);
    if (newReference == oldReference) {
      // Nothing in this signature was affected by merging.
      continue;
    }
    lense.map(oldReference, newReference);
    methods[index] = current.toTypeSubstitutedMethod(newReference);
  }
  return methods;
}
/**
 * Rewrites, in place, every field whose holder or type mentions a merged class, and records a
 * mapping from the old to the new field reference in {@code lense}.
 *
 * @param fields the fields to rewrite, possibly {@code null}
 * @return the same array instance, or {@code null} for {@code null} input
 */
private DexEncodedField[] substituteTypesIn(DexEncodedField[] fields) {
  if (fields == null) {
    return null;
  }
  for (int index = 0; index < fields.length; index++) {
    DexEncodedField current = fields[index];
    DexField oldReference = current.field;
    DexType fixedType = fixupType(oldReference.type);
    DexType fixedHolder = fixupType(oldReference.clazz);
    DexField newReference =
        application.dexItemFactory.createField(fixedHolder, fixedType, oldReference.name);
    if (newReference == oldReference) {
      // Neither holder nor type was affected by merging.
      continue;
    }
    lense.map(oldReference, newReference);
    fields[index] = current.toTypeSubstitutedField(newReference);
  }
  return fields;
}
/**
 * Returns the prototype obtained from {@code proto} by fixing up its return type and parameter
 * types. Results are memoized in {@code protoFixupCache}.
 *
 * @param proto the prototype to fix up
 * @return the fixed-up prototype
 */
private DexProto getUpdatedProto(DexProto proto) {
  DexProto cached = protoFixupCache.get(proto);
  if (cached != null) {
    return cached;
  }
  DexType fixedReturnType = fixupType(proto.returnType);
  DexType[] fixedParameters = fixupTypes(proto.parameters.values);
  DexProto fixedProto = application.dexItemFactory.createProto(fixedReturnType, fixedParameters);
  protoFixupCache.put(proto, fixedProto);
  return fixedProto;
}
/**
 * Maps a type to the type it was merged into, following chains of merges in
 * {@code mergedClasses}. For array types the base type is fixed up and the array type is
 * rebuilt only if the base type changed.
 *
 * @param type the type to fix up
 * @return the replacement type, or {@code type} itself if it is unaffected
 */
private DexType fixupType(DexType type) {
  if (type.isArrayType()) {
    DexType originalBase = type.toBaseType(application.dexItemFactory);
    DexType fixedBase = fixupType(originalBase);
    return originalBase == fixedBase
        ? type
        : type.replaceBaseType(fixedBase, application.dexItemFactory);
  }
  DexType current = type;
  // A merge target may itself have been merged, so follow the chain to its end.
  while (mergedClasses.containsKey(current)) {
    current = mergedClasses.get(current);
  }
  return current;
}
/**
 * Applies {@code fixupType} to every element of {@code types}.
 *
 * @param types the types to fix up
 * @return a new array of the same length holding the fixed-up types
 */
private DexType[] fixupTypes(DexType[] types) {
  DexType[] fixed = new DexType[types.length];
  int index = 0;
  for (DexType original : types) {
    fixed[index++] = fixupType(original);
  }
  return fixed;
}
}
private static class CollisionDetector {
private static final int NOT_FOUND = 1 << (Integer.SIZE - 1);
// TODO(herhut): Maybe cache seenPositions for target classes.
private final Map<DexString, Int2IntMap> seenPositions = new IdentityHashMap<>();
private final Reference2IntMap<DexProto> targetProtoCache;
private final Reference2IntMap<DexProto> sourceProtoCache;
private final DexType source, target;
private final Collection<DexMethod> invokes;
private final Map<DexType, DexType> substituions;
private CollisionDetector(DexType source, DexType target, Collection<DexMethod> invokes,
Map<DexType, DexType> substitutions) {
this.source = source;
this.target = target;
this.invokes = invokes;
this.substituions = substitutions;
this.targetProtoCache = new Reference2IntOpenHashMap<>(invokes.size() / 2);
this.targetProtoCache.defaultReturnValue(NOT_FOUND);
this.sourceProtoCache = new Reference2IntOpenHashMap<>(invokes.size() / 2);
this.sourceProtoCache.defaultReturnValue(NOT_FOUND);
}
/**
 * Checks whether replacing {@code source} by {@code target} may make two method signatures
 * coincide.
 *
 * <p>A potential collision is reported when some method uses {@code source} at a position where
 * a method of the same name and arity uses {@code target}.
 *
 * @return {@code true} if a collision is possible, {@code false} if none was found
 */
boolean mayCollide() {
  fillSeenPositions(invokes);
  // If the type is not used in methods at all, there cannot be any conflict.
  if (seenPositions.isEmpty()) {
    return false;
  }
  for (DexMethod method : invokes) {
    Int2IntMap byArity = seenPositions.get(method.name);
    if (byArity == null) {
      continue;
    }
    int targetPositions = byArity.get(method.getArity());
    if (targetPositions == NOT_FOUND) {
      continue;
    }
    assert targetPositions != 0;
    int sourcePositions =
        computePositionsFor(method.proto, source, sourceProtoCache, substituions);
    if ((sourcePositions & targetPositions) != 0) {
      return true;
    }
  }
  return false;
}
/**
 * Records, per method name and arity, the union of all signature positions at which
 * {@code target} occurs in the given methods. Methods that do not mention {@code target} are
 * not recorded.
 *
 * @param invokes the methods to scan
 */
private void fillSeenPositions(Collection<DexMethod> invokes) {
  for (DexMethod method : invokes) {
    int positions = computePositionsFor(method.proto, target, targetProtoCache, substituions);
    if (positions == 0) {
      continue;
    }
    int arity = method.proto.parameters.values.length;
    Int2IntMap byArity = seenPositions.computeIfAbsent(method.name, name -> {
      Int2IntMap fresh = new Int2IntOpenHashMap();
      fresh.defaultReturnValue(NOT_FOUND);
      return fresh;
    });
    int previous = byArity.get(arity);
    int combined = previous == NOT_FOUND ? positions : (previous | positions);
    byArity.put(arity, combined);
  }
}
/**
 * Computes a bit mask describing where {@code type} occurs in {@code proto}, after applying
 * {@code substitutions} to the parameter and return types.
 *
 * <p>One bit is shifted in per parameter, followed by one bit for the return type. After every
 * 31 parameters the bits gathered so far are folded into the result and gathering starts over.
 * Results are memoized in {@code cache}.
 *
 * @param proto the prototype to inspect
 * @param type the type to look for
 * @param cache memoization map, expected to return {@code NOT_FOUND} for unknown prototypes
 * @param substitutions already performed merges, or {@code null} for none
 * @return the position mask; {@code 0} if {@code type} does not occur
 */
private int computePositionsFor(DexProto proto, DexType type,
    Reference2IntMap<DexProto> cache, Map<DexType, DexType> substitutions) {
  int cached = cache.getInt(proto);
  if (cached != NOT_FOUND) {
    return cached;
  }
  int folded = 0;
  int window = 0;
  int windowBits = 0;
  for (DexType parameter : proto.parameters.values) {
    DexType effective = applySubstitutions(parameter, substitutions);
    window = (window << 1) | (effective == type ? 1 : 0);
    windowBits++;
    // Handle overflow on 31 bit boundary.
    if (windowBits == Integer.SIZE - 1) {
      folded |= window;
      window = 0;
      windowBits = 0;
    }
  }
  // We also take the return type into account for potential conflicts.
  DexType effectiveReturn = applySubstitutions(proto.returnType, substitutions);
  window = (window << 1) | (effectiveReturn == type ? 1 : 0);
  folded |= window;
  cache.put(proto, folded);
  return folded;
}

/**
 * Follows {@code substitutions} from {@code type} until a type without an entry is reached, to
 * estimate what the type will look like after the already performed merges.
 */
private static DexType applySubstitutions(DexType type, Map<DexType, DexType> substitutions) {
  if (substitutions != null) {
    while (substitutions.containsKey(type)) {
      type = substitutions.get(type);
    }
  }
  return type;
}
}
/**
 * Returns the types of all classes that were merged into another class.
 *
 * @return an unmodifiable view of the key set of {@code mergedClasses}
 */
public Collection<DexType> getRemovedClasses() {
  Set<DexType> removed = mergedClasses.keySet();
  return Collections.unmodifiableCollection(removed);
}
}