Refactor on-demand link on ReferenceTypeLatticeElement

This is a precursor to fix not widening when narrowing.

The current implementation bakes in duality by having primary and
variant as on-demand fields. The current nullability lattice have
three values and we will add Bottom shortly, so it will be hard to
maintain the fields and have on-demand links back and forth.

The on-demand links is changed to be a shared pointer struct. This
encapsulates creating and linking all similar reference lattice
elements. This makes the synchronicity at bit easier to digest as well.

Bug: 125492155
Change-Id: Iada100551f2865ee23684e9d56a775e75a049e81
diff --git a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
index 243557c..963d503 100644
--- a/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
+++ b/src/main/java/com/android/tools/r8/graph/DexItemFactory.java
@@ -3,8 +3,6 @@
 // BSD-style license that can be found in the LICENSE file.
 package com.android.tools.r8.graph;
 
-import static com.android.tools.r8.ir.analysis.type.Nullability.maybeNull;
-
 import com.android.tools.r8.dex.Constants;
 import com.android.tools.r8.dex.Marker;
 import com.android.tools.r8.graph.DexDebugEvent.AdvanceLine;
@@ -1074,42 +1072,53 @@
 
   public ReferenceTypeLatticeElement createReferenceTypeLatticeElement(
       DexType type, Nullability nullability, DexDefinitionSupplier definitions) {
-    ReferenceTypeLatticeElement primary = referenceTypeLatticeElements.get(type);
-    if (primary != null) {
-      return nullability == primary.nullability()
-          ? primary
-          : primary.getOrCreateVariant(nullability);
-    }
-    synchronized (type) {
-      primary = referenceTypeLatticeElements.get(type);
-      if (primary == null) {
-        if (type.isClassType()) {
-          if (!type.isUnknown() && type.isInterface()) {
-            primary = new ClassTypeLatticeElement(objectType, maybeNull(), ImmutableSet.of(type));
-          } else {
-            // In theory, `interfaces` is the least upper bound of implemented interfaces.
-            // It is expensive to walk through type hierarchy; collect implemented interfaces; and
-            // compute the least upper bound of two interface sets. Hence, lazy computations.
-            // Most likely during lattice join. See {@link ClassTypeLatticeElement#getInterfaces}.
-            primary = new ClassTypeLatticeElement(type, maybeNull(), definitions);
-          }
-        } else {
-          assert type.isArrayType();
-          DexType elementType = type.toArrayElementType(this);
-          TypeLatticeElement elementTypeLattice =
-              TypeLatticeElement.fromDexType(elementType, maybeNull(), definitions, true);
-          primary = new ArrayTypeLatticeElement(elementTypeLattice, maybeNull());
-        }
-        referenceTypeLatticeElements.put(type, primary);
+    // Class case:
+    // If two concurrent threads will try to create the same class-type the concurrent hash map will
+    // synchronize on the type in .computeIfAbsent and only a single class type is created.
+    //
+    // Array case:
+    // Arrays will create a lattice element for its base type thus we take special care here.
+    // Multiple threads may race recursively to create a base type. We have two cases:
+    // (i)  If base type is class type and the threads will race to create the class type but only a
+    //      single one will be created (Class case).
+    // (ii) If base is ArrayLattice case we can use our induction hypothesis to get that only one
+    //      element is created for us up to this case. Threads will now race to return from the
+    //      latest recursive call and fight to get access to .computeIfAbsent to add the
+    //      ArrayTypeLatticeElement but only one will enter. The property that only one
+    //      ArrayTypeLatticeElement is created per level therefore holds inductively.
+    TypeLatticeElement memberType = null;
+    if (type.isArrayType()) {
+      ReferenceTypeLatticeElement existing = referenceTypeLatticeElements.get(type);
+      if (existing != null) {
+        return existing.getOrCreateVariant(nullability);
       }
-      // Make sure that canonicalized version is MAYBE_NULL variant.
-      assert primary.nullability().isMaybeNull();
+      memberType =
+          TypeLatticeElement.fromDexType(
+              type.toArrayElementType(this), Nullability.maybeNull(), definitions, true);
     }
-    // The call to getOrCreateVariant can't be under the DexType synchronized block, since that
-    // can create deadlocks with ClassTypeLatticeElement::getInterfaces (both lock on the lattice).
-    return nullability == primary.nullability()
-        ? primary
-        : primary.getOrCreateVariant(nullability);
+    TypeLatticeElement finalMemberType = memberType;
+    return referenceTypeLatticeElements
+        .computeIfAbsent(
+            type,
+            t -> {
+              if (type.isClassType()) {
+                if (!type.isUnknown() && type.isInterface()) {
+                  return ClassTypeLatticeElement.create(
+                      objectType, nullability, ImmutableSet.of(type));
+                } else {
+                  // In theory, `interfaces` is the least upper bound of implemented interfaces.
+                  // It is expensive to walk through type hierarchy; collect implemented interfaces;
+                  // and compute the least upper bound of two interface sets. Hence, lazy
+                  // computations. Most likely during lattice join. See {@link
+                  // ClassTypeLatticeElement#getInterfaces}.
+                  return ClassTypeLatticeElement.create(type, nullability, definitions);
+                }
+              } else {
+                assert type.isArrayType();
+                return ArrayTypeLatticeElement.create(finalMemberType, nullability);
+              }
+            })
+        .getOrCreateVariant(nullability);
   }
 
   private static <S extends PresortedComparable<S>> void assignSortedIndices(Collection<S> items,
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/type/ArrayTypeLatticeElement.java b/src/main/java/com/android/tools/r8/ir/analysis/type/ArrayTypeLatticeElement.java
index 1c16837..b610850 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/type/ArrayTypeLatticeElement.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/type/ArrayTypeLatticeElement.java
@@ -14,10 +14,24 @@
 
   private final TypeLatticeElement memberTypeLattice;
 
-  public ArrayTypeLatticeElement(
+  // On-demand link between other nullability-variants.
+  private final NullabilityVariants<ArrayTypeLatticeElement> variants;
+
+  public static ArrayTypeLatticeElement create(
       TypeLatticeElement memberTypeLattice, Nullability nullability) {
-    super(nullability, null);
+    return NullabilityVariants.create(
+        nullability,
+        (variants) -> new ArrayTypeLatticeElement(memberTypeLattice, nullability, variants));
+  }
+
+  private ArrayTypeLatticeElement(
+      TypeLatticeElement memberTypeLattice,
+      Nullability nullability,
+      NullabilityVariants<ArrayTypeLatticeElement> variants) {
+    super(nullability);
+    assert memberTypeLattice.isPrimitive() || memberTypeLattice.nullability().isMaybeNull();
     this.memberTypeLattice = memberTypeLattice;
+    this.variants = variants;
   }
 
   public DexType getArrayType(DexItemFactory factory) {
@@ -58,22 +72,29 @@
     return base;
   }
 
-  @Override
-  ReferenceTypeLatticeElement createVariant(Nullability nullability) {
-    if (this.nullability == nullability) {
-      return this;
-    }
-    return new ArrayTypeLatticeElement(memberTypeLattice, nullability);
+  private ArrayTypeLatticeElement createVariant(
+      Nullability nullability, NullabilityVariants<ArrayTypeLatticeElement> variants) {
+    assert this.nullability != nullability;
+    return new ArrayTypeLatticeElement(memberTypeLattice, nullability, variants);
   }
 
   @Override
   public TypeLatticeElement asNullable() {
-    return nullability.isNullable() ? this : getOrCreateVariant(maybeNull());
+    return getOrCreateVariant(maybeNull());
+  }
+
+  @Override
+  public ReferenceTypeLatticeElement getOrCreateVariant(Nullability nullability) {
+    ArrayTypeLatticeElement variant = variants.get(nullability);
+    if (variant != null) {
+      return variant;
+    }
+    return variants.getOrCreateElement(nullability, this::createVariant);
   }
 
   @Override
   public TypeLatticeElement asNonNullable() {
-    return nullability.isDefinitelyNotNull() ? this : getOrCreateVariant(definitelyNotNull());
+    return getOrCreateVariant(definitelyNotNull());
   }
 
   @Override
@@ -93,7 +114,7 @@
 
   @Override
   public String toString() {
-    return memberTypeLattice.toString() + "[]";
+    return nullability.toString() + " (" + memberTypeLattice.toString() + "[])";
   }
 
   @Override
@@ -108,9 +129,6 @@
     if (nullability() != other.nullability()) {
       return false;
     }
-    if (type != null && other.type != null && !type.equals(other.type)) {
-      return false;
-    }
     return memberTypeLattice.equals(other.memberTypeLattice);
   }
 
@@ -121,31 +139,50 @@
 
   ReferenceTypeLatticeElement join(
       ArrayTypeLatticeElement other, DexDefinitionSupplier definitions) {
-    TypeLatticeElement aMember = getArrayMemberTypeAsMemberType();
-    TypeLatticeElement bMember = other.getArrayMemberTypeAsMemberType();
+    Nullability nullability = nullability().join(other.nullability());
+    ReferenceTypeLatticeElement join =
+        joinMember(this.memberTypeLattice, other.memberTypeLattice, definitions, nullability);
+    if (join == null) {
+      // Check if other has the right nullability before creating it.
+      if (other.nullability == nullability) {
+        return other;
+      } else {
+        return getOrCreateVariant(nullability);
+      }
+    } else {
+      assert join.nullability == nullability;
+      return join;
+    }
+  }
+
+  private static ReferenceTypeLatticeElement joinMember(
+      TypeLatticeElement aMember,
+      TypeLatticeElement bMember,
+      DexDefinitionSupplier definitions,
+      Nullability nullability) {
     if (aMember.equals(bMember)) {
       // Return null indicating the join is the same as the member to avoid object allocation.
       return null;
     }
-    Nullability nullability = nullability().join(other.nullability());
     if (aMember.isArrayType() && bMember.isArrayType()) {
-      ReferenceTypeLatticeElement join =
-          aMember
-              .asArrayTypeLatticeElement()
-              .join(bMember.asArrayTypeLatticeElement(), definitions);
-      return join == null ? null : new ArrayTypeLatticeElement(join, nullability);
+      TypeLatticeElement join =
+          joinMember(
+              aMember.asArrayTypeLatticeElement().memberTypeLattice,
+              bMember.asArrayTypeLatticeElement().memberTypeLattice,
+              definitions,
+              maybeNull());
+      return join == null ? null : ArrayTypeLatticeElement.create(join, nullability);
     }
     if (aMember.isClassType() && bMember.isClassType()) {
-      ClassTypeLatticeElement join =
+      ReferenceTypeLatticeElement join =
           aMember
               .asClassTypeLatticeElement()
               .join(bMember.asClassTypeLatticeElement(), definitions);
-      return join == null ? null : new ArrayTypeLatticeElement(join, nullability);
+      return ArrayTypeLatticeElement.create(join, nullability);
     }
     if (aMember.isPrimitive() || bMember.isPrimitive()) {
-      return objectClassType(definitions, nullability);
+      return aMember.objectClassType(definitions, nullability);
     }
-    return objectArrayType(definitions, nullability);
+    return aMember.objectArrayType(definitions, nullability);
   }
-
 }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/type/ClassTypeLatticeElement.java b/src/main/java/com/android/tools/r8/ir/analysis/type/ClassTypeLatticeElement.java
index 09dda42..257c664 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/type/ClassTypeLatticeElement.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/type/ClassTypeLatticeElement.java
@@ -22,33 +22,44 @@
 
   private Set<DexType> lazyInterfaces;
   private DexDefinitionSupplier definitionsForLazyInterfacesComputation;
+  // On-demand link between other nullability-variants.
+  private final NullabilityVariants<ClassTypeLatticeElement> variants;
+  private final DexType type;
 
-  public ClassTypeLatticeElement(
+  public static ClassTypeLatticeElement create(
       DexType classType, Nullability nullability, Set<DexType> interfaces) {
-    this(classType, nullability, interfaces, null);
+    return NullabilityVariants.create(
+        nullability,
+        (variants) ->
+            new ClassTypeLatticeElement(classType, nullability, interfaces, variants, null));
   }
 
-  public ClassTypeLatticeElement(
+  public static ClassTypeLatticeElement create(
       DexType classType, Nullability nullability, DexDefinitionSupplier definitions) {
-    this(classType, nullability, null, definitions);
+    return NullabilityVariants.create(
+        nullability,
+        (variants) ->
+            new ClassTypeLatticeElement(classType, nullability, null, variants, definitions));
   }
 
   private ClassTypeLatticeElement(
       DexType classType,
       Nullability nullability,
       Set<DexType> interfaces,
+      NullabilityVariants<ClassTypeLatticeElement> variants,
       DexDefinitionSupplier definitions) {
-    super(nullability, classType);
+    super(nullability);
     assert classType.isClassType();
+    type = classType;
     definitionsForLazyInterfacesComputation = definitions;
     lazyInterfaces = interfaces;
+    this.variants = variants;
   }
 
   public DexType getClassType() {
     return type;
   }
 
-  @Override
   public Set<DexType> getInterfaces() {
     if (lazyInterfaces != null) {
       return lazyInterfaces;
@@ -64,23 +75,30 @@
     return lazyInterfaces;
   }
 
-  @Override
-  ReferenceTypeLatticeElement createVariant(Nullability nullability) {
-    if (this.nullability == nullability) {
-      return this;
-    }
+  private ClassTypeLatticeElement createVariant(
+      Nullability nullability, NullabilityVariants<ClassTypeLatticeElement> variants) {
+    assert this.nullability != nullability;
     return new ClassTypeLatticeElement(
-        type, nullability, lazyInterfaces, definitionsForLazyInterfacesComputation);
+        type, nullability, lazyInterfaces, variants, definitionsForLazyInterfacesComputation);
   }
 
   @Override
   public TypeLatticeElement asNullable() {
-    return nullability.isNullable() ? this : getOrCreateVariant(maybeNull());
+    return getOrCreateVariant(maybeNull());
+  }
+
+  @Override
+  public ReferenceTypeLatticeElement getOrCreateVariant(Nullability nullability) {
+    ClassTypeLatticeElement variant = variants.get(nullability);
+    if (variant != null) {
+      return variant;
+    }
+    return variants.getOrCreateElement(nullability, this::createVariant);
   }
 
   @Override
   public TypeLatticeElement asNonNullable() {
-    return nullability.isDefinitelyNotNull() ? this : getOrCreateVariant(definitelyNotNull());
+    return getOrCreateVariant(definitelyNotNull());
   }
 
   @Override
@@ -103,7 +121,9 @@
   @Override
   public String toString() {
     StringBuilder builder = new StringBuilder();
-    builder.append(super.toString());
+    builder.append(nullability);
+    builder.append(" ");
+    builder.append(type);
     builder.append(" {");
     builder.append(
         getInterfaces().stream().map(DexType::toString).collect(Collectors.joining(", ")));
@@ -130,7 +150,7 @@
       lubItfs = computeLeastUpperBoundOfInterfaces(definitions, c1lubItfs, c2lubItfs);
     }
     Nullability nullability = nullability().join(other.nullability());
-    return new ClassTypeLatticeElement(lubType, nullability, lubItfs);
+    return ClassTypeLatticeElement.create(lubType, nullability, lubItfs);
   }
 
   private enum InterfaceMarker {
@@ -228,4 +248,27 @@
     }
     return lub;
   }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (!(o instanceof ClassTypeLatticeElement)) {
+      return false;
+    }
+    ClassTypeLatticeElement other = (ClassTypeLatticeElement) o;
+    if (nullability() != other.nullability()) {
+      return false;
+    }
+    if (!type.equals(other.type)) {
+      return false;
+    }
+    Set<DexType> thisInterfaces = getInterfaces();
+    Set<DexType> otherInterfaces = other.getInterfaces();
+    if (thisInterfaces.size() != otherInterfaces.size()) {
+      return false;
+    }
+    return thisInterfaces.containsAll(otherInterfaces);
+  }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/type/NullabilityVariants.java b/src/main/java/com/android/tools/r8/ir/analysis/type/NullabilityVariants.java
new file mode 100644
index 0000000..e9baec4
--- /dev/null
+++ b/src/main/java/com/android/tools/r8/ir/analysis/type/NullabilityVariants.java
@@ -0,0 +1,63 @@
+// Copyright (c) 2019, the R8 project authors. Please see the AUTHORS file
+// for details. All rights reserved. Use of this source code is governed by a
+// BSD-style license that can be found in the LICENSE file.
+
+package com.android.tools.r8.ir.analysis.type;
+
+import java.util.function.BiFunction;
+import java.util.function.Function;
+
+public class NullabilityVariants<T extends ReferenceTypeLatticeElement> {
+
+  private T maybeNullVariant;
+  private T definitelyNullVariant;
+  private T definitelyNotNullVariant;
+
+  public static <T extends ReferenceTypeLatticeElement> T create(
+      Nullability nullability, Function<NullabilityVariants<T>, T> callback) {
+    NullabilityVariants<T> variants = new NullabilityVariants<>();
+    T newElement = callback.apply(variants);
+    variants.set(nullability, newElement);
+    return newElement;
+  }
+
+  private void set(Nullability nullability, T element) {
+    if (nullability == Nullability.maybeNull()) {
+      maybeNullVariant = element;
+    } else if (nullability == Nullability.definitelyNull()) {
+      definitelyNullVariant = element;
+    } else {
+      assert nullability == Nullability.definitelyNotNull();
+      definitelyNotNullVariant = element;
+    }
+  }
+
+  T get(Nullability nullability) {
+    if (nullability == Nullability.maybeNull()) {
+      return maybeNullVariant;
+    } else if (nullability == Nullability.definitelyNull()) {
+      return definitelyNullVariant;
+    } else {
+      assert nullability == Nullability.definitelyNotNull();
+      return definitelyNotNullVariant;
+    }
+  }
+
+  T getOrCreateElement(
+      Nullability nullability, BiFunction<Nullability, NullabilityVariants<T>, T> creator) {
+    T element = get(nullability);
+    if (element != null) {
+      return element;
+    }
+    synchronized (this) {
+      element = get(nullability);
+      if (element != null) {
+        return element;
+      }
+      element = creator.apply(nullability, this);
+      assert element != null;
+      set(nullability, element);
+      return element;
+    }
+  }
+}
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/type/ReferenceTypeLatticeElement.java b/src/main/java/com/android/tools/r8/ir/analysis/type/ReferenceTypeLatticeElement.java
index 2baaa52..a7ae56c 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/type/ReferenceTypeLatticeElement.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/type/ReferenceTypeLatticeElement.java
@@ -5,76 +5,57 @@
 
 import com.android.tools.r8.errors.Unreachable;
 import com.android.tools.r8.graph.DexItemFactory;
-import com.android.tools.r8.graph.DexType;
-import java.util.Collections;
-import java.util.Set;
 
-public class ReferenceTypeLatticeElement extends TypeLatticeElement {
-  private static final ReferenceTypeLatticeElement NULL_INSTANCE =
-      new ReferenceTypeLatticeElement(
-          Nullability.definitelyNull(), DexItemFactory.nullValueType);
+public abstract class ReferenceTypeLatticeElement extends TypeLatticeElement {
 
-  // TODO(b/72693244): Consider moving this to ClassTypeLatticeElement.
-  final DexType type;
+  private static class NullLatticeElement extends ReferenceTypeLatticeElement {
+
+    NullLatticeElement(Nullability nullability) {
+      super(nullability);
+    }
+
+    @Override
+    public ReferenceTypeLatticeElement getOrCreateVariant(Nullability nullability) {
+      throw new Unreachable("This should not be called on NullLaticeElement");
+    }
+
+    static NullLatticeElement create() {
+      return new NullLatticeElement(Nullability.definitelyNull());
+    }
+
+    @Override
+    public boolean isNullType() {
+      return true;
+    }
+
+    @Override
+    public String toString() {
+      return nullability.toString() + " " + DexItemFactory.nullValueType.toString();
+    }
+
+    @Override
+    public int hashCode() {
+      return System.identityHashCode(this);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (!(o instanceof NullLatticeElement)) {
+        return false;
+      }
+      return true;
+    }
+  }
+
+  private static final ReferenceTypeLatticeElement NULL_INSTANCE = NullLatticeElement.create();
 
   final Nullability nullability;
-  // On-demand link between maybe-null (primary) and definitely-null reference type lattices.
-  private ReferenceTypeLatticeElement primaryOrNullVariant;
-  // On-demand link between maybe-null (primary) and definitely-not-null reference type lattices.
-  // This link will be null for non-primary variants.
-  private ReferenceTypeLatticeElement nonNullVariant;
 
-  ReferenceTypeLatticeElement(Nullability nullability, DexType type) {
+  ReferenceTypeLatticeElement(Nullability nullability) {
     this.nullability = nullability;
-    this.type = type;
-  }
-
-  public ReferenceTypeLatticeElement getOrCreateVariant(Nullability variantNullability) {
-    if (nullability == variantNullability) {
-      return this;
-    }
-    ReferenceTypeLatticeElement primary = nullability.isMaybeNull() ? this : primaryOrNullVariant;
-    synchronized (this) {
-      // If the link towards the factory-created, canonicalized MAYBE_NULL variant doesn't exist,
-      // we are in the middle of join() computation.
-      if (primary == null) {
-        primary = createVariant(Nullability.maybeNull());
-        linkVariant(primary, this);
-      }
-    }
-    if (variantNullability.isMaybeNull()) {
-      return primary;
-    }
-    synchronized (primary) {
-      ReferenceTypeLatticeElement variant =
-          variantNullability.isDefinitelyNull()
-              ? primary.primaryOrNullVariant
-              : primary.nonNullVariant;
-      if (variant == null) {
-        variant = createVariant(variantNullability);
-        linkVariant(primary, variant);
-      }
-      return variant;
-    }
-  }
-
-  ReferenceTypeLatticeElement createVariant(Nullability nullability) {
-    throw new Unreachable("Should be defined by class/array type lattice element");
-  }
-
-  private static void linkVariant(
-      ReferenceTypeLatticeElement primary, ReferenceTypeLatticeElement variant) {
-    assert primary.nullability().isMaybeNull();
-    assert variant.primaryOrNullVariant == null && variant.nonNullVariant == null;
-    variant.primaryOrNullVariant = primary;
-    if (variant.nullability().isDefinitelyNotNull()) {
-      assert primary.nonNullVariant == null;
-      primary.nonNullVariant = variant;
-    } else {
-      assert variant.nullability().isDefinitelyNull();
-      assert primary.primaryOrNullVariant == null;
-      primary.primaryOrNullVariant = variant;
-    }
   }
 
   @Override
@@ -86,20 +67,7 @@
     return NULL_INSTANCE;
   }
 
-  public Set<DexType> getInterfaces() {
-    return Collections.emptySet();
-  }
-
-  @Override
-  public boolean isNullType() {
-    return type == DexItemFactory.nullValueType;
-  }
-
-  @Override
-  public TypeLatticeElement asNullable() {
-    assert isNullType();
-    return this;
-  }
+  public abstract ReferenceTypeLatticeElement getOrCreateVariant(Nullability nullability);
 
   @Override
   public boolean isReference() {
@@ -107,36 +75,12 @@
   }
 
   @Override
-  public String toString() {
-    return nullability.toString() + " " + type.toString();
+  public ReferenceTypeLatticeElement asReferenceTypeLatticeElement() {
+    return this;
   }
 
   @Override
   public boolean equals(Object o) {
-    if (this == o) {
-      return true;
-    }
-    if (!(o instanceof ReferenceTypeLatticeElement)) {
-      return false;
-    }
-    ReferenceTypeLatticeElement other = (ReferenceTypeLatticeElement) o;
-    if (nullability() != other.nullability()) {
-      return false;
-    }
-    if (!type.equals(other.type)) {
-      return false;
-    }
-    Set<DexType> thisInterfaces = getInterfaces();
-    Set<DexType> otherInterfaces = other.getInterfaces();
-    if (thisInterfaces.size() != otherInterfaces.size()) {
-      return false;
-    }
-    return thisInterfaces.containsAll(otherInterfaces);
-  }
-
-  @Override
-  public int hashCode() {
-    assert isNullType();
-    return System.identityHashCode(this);
+    throw new Unreachable("Should be implemented on each sub type");
   }
 }
diff --git a/src/main/java/com/android/tools/r8/ir/analysis/type/TypeLatticeElement.java b/src/main/java/com/android/tools/r8/ir/analysis/type/TypeLatticeElement.java
index 6384088..a438264 100644
--- a/src/main/java/com/android/tools/r8/ir/analysis/type/TypeLatticeElement.java
+++ b/src/main/java/com/android/tools/r8/ir/analysis/type/TypeLatticeElement.java
@@ -72,12 +72,6 @@
     if (isTop() || other.isTop()) {
       return TOP;
     }
-    if (isNullType()) {
-      return other.asNullable();
-    }
-    if (other.isNullType()) {
-      return asNullable();
-    }
     if (isPrimitive()) {
       return other.isPrimitive()
           ? asPrimitiveTypeLatticeElement().join(other.asPrimitiveTypeLatticeElement())
@@ -90,15 +84,20 @@
     // From now on, this and other are precise reference types, i.e., either ArrayType or ClassType.
     assert isReference() && other.isReference();
     assert isPreciseType() && other.isPreciseType();
+    Nullability nullabilityJoin = nullability().join(other.nullability());
+    if (isNullType()) {
+      return other.asReferenceTypeLatticeElement().getOrCreateVariant(nullabilityJoin);
+    }
+    if (other.isNullType()) {
+      return this.asReferenceTypeLatticeElement().getOrCreateVariant(nullabilityJoin);
+    }
     if (getClass() != other.getClass()) {
-      return objectClassType(definitions, nullability().join(other.nullability()));
+      return objectClassType(definitions, nullabilityJoin);
     }
     // From now on, getClass() == other.getClass()
     if (isArrayType()) {
       assert other.isArrayType();
-      TypeLatticeElement join =
-          asArrayTypeLatticeElement().join(other.asArrayTypeLatticeElement(), definitions);
-      return join != null ? join : (isNullable() ? this : other);
+      return asArrayTypeLatticeElement().join(other.asArrayTypeLatticeElement(), definitions);
     }
     if (isClassType()) {
       assert other.isClassType();
@@ -175,6 +174,10 @@
     return false;
   }
 
+  public ReferenceTypeLatticeElement asReferenceTypeLatticeElement() {
+    return null;
+  }
+
   public boolean isArrayType() {
     return false;
   }