Add memory-tracking to timing output

Bug: 131858815
Bug: 128350147
Change-Id: I1ce4bcef2dccc7dcff96351b3d61ac226b223c25
diff --git a/src/main/java/com/android/tools/r8/R8.java b/src/main/java/com/android/tools/r8/R8.java
index 2924ba8..47fe32e 100644
--- a/src/main/java/com/android/tools/r8/R8.java
+++ b/src/main/java/com/android/tools/r8/R8.java
@@ -130,11 +130,15 @@
 @Keep
 public class R8 {
 
-  private final Timing timing = new Timing("R8");
+  private final Timing timing;
   private final InternalOptions options;
 
   private R8(InternalOptions options) {
     this.options = options;
+    if (options.printMemory) {
+      System.gc();
+    }
+    this.timing = new Timing("R8", options.printMemory);
     options.itemFactory.resetSortedIndices();
   }
 
diff --git a/src/main/java/com/android/tools/r8/utils/InternalOptions.java b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
index 4d0de15..e93de6a 100644
--- a/src/main/java/com/android/tools/r8/utils/InternalOptions.java
+++ b/src/main/java/com/android/tools/r8/utils/InternalOptions.java
@@ -137,6 +137,8 @@
   }
 
   public boolean printTimes = System.getProperty("com.android.tools.r8.printtimes") != null;
+  // To print memory one also have to enable printtimes.
+  public boolean printMemory = System.getProperty("com.android.tools.r8.printmemory") != null;
 
   // Flag to toggle if DEX code objects should pass-through without IR processing.
   public boolean passthroughDexCode = false;
diff --git a/src/main/java/com/android/tools/r8/utils/Timing.java b/src/main/java/com/android/tools/r8/utils/Timing.java
index cbb5f6a..df575a6 100644
--- a/src/main/java/com/android/tools/r8/utils/Timing.java
+++ b/src/main/java/com/android/tools/r8/utils/Timing.java
@@ -13,44 +13,67 @@
 // Finally a report is printed by:
 //     t.report();
 
+import java.lang.management.ManagementFactory;
+import java.lang.management.MemoryMXBean;
+import java.lang.management.MemoryPoolMXBean;
+import java.util.ArrayList;
 import java.util.LinkedHashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Stack;
 
 public class Timing {
 
   private final Stack<Node> stack;
+  private final boolean trackMemory;
 
   public Timing() {
     this("<no title>");
   }
 
   public Timing(String title) {
+    this(title, false);
+  }
+
+  public Timing(String title, boolean trackMemory) {
+    this.trackMemory = trackMemory;
     stack = new Stack<>();
     stack.push(new Node("Recorded timings for " + title));
   }
 
-  static class Node {
+  class Node {
     final String title;
 
     final Map<String, Node> children = new LinkedHashMap<>();
     long duration = 0;
     long start_time;
+    List<String> startMemory;
+    List<String> endMemory;
 
     Node(String title) {
       this.title = title;
       this.start_time = System.nanoTime();
+      if (trackMemory) {
+        startMemory = computeMemoryInformation();
+      }
     }
 
     void restart() {
       assert start_time == -1;
       start_time = System.nanoTime();
+      if (trackMemory) {
+        startMemory = computeMemoryInformation();
+      }
     }
 
     void end() {
       duration += System.nanoTime() - start_time;
       start_time = -1;
       assert duration() >= 0;
+      if (trackMemory) {
+        System.gc();
+        endMemory = computeMemoryInformation();
+      }
     }
 
     long duration() {
@@ -77,7 +100,41 @@
         System.out.print("- ");
       }
       System.out.println(toString(top));
+      System.out.println();
+      if (trackMemory) {
+        printMemoryStart(depth);
+        System.out.println();
+      }
       children.values().forEach(p -> p.report(depth + 1, top));
+      if (trackMemory) {
+        printMemoryEnd(depth);
+        System.out.println();
+      }
+    }
+
+    private void printMemoryStart(int depth) {
+      if (startMemory != null) {
+        printMemory(depth, title + "(Memory) Start: ", startMemory);
+      }
+    }
+
+    private void printMemoryEnd(int depth) {
+      if (endMemory != null) {
+        printMemory(depth, title + "(Memory) End: ", endMemory);
+      }
+    }
+
+    private void printMemory(int depth, String header, List<String> strings) {
+      for (int i = 0; i <= depth; i++) {
+        System.out.print("  ");
+      }
+      System.out.println(header);
+      for (String memoryInfo : strings) {
+        for (int i = 0; i <= depth; i++) {
+          System.out.print("  ");
+        }
+        System.out.println(memoryInfo);
+      }
     }
   }
 
@@ -119,4 +176,23 @@
   public interface TimingScope {
     void apply();
   }
+
+  private List<String> computeMemoryInformation() {
+    List<String> strings = new ArrayList<>();
+    strings.add(
+        "Free memory: "
+            + Runtime.getRuntime().freeMemory()
+            + "\tTotal memory: "
+            + Runtime.getRuntime().totalMemory()
+            + "\tMax memory: "
+            + Runtime.getRuntime().maxMemory());
+    MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean();
+    strings.add("Heap summary: " + memoryMXBean.getHeapMemoryUsage().toString());
+    strings.add("Non-heap summary: " + memoryMXBean.getNonHeapMemoryUsage().toString());
+    // Print out the memory information for all managed memory pools.
+    for (MemoryPoolMXBean memoryPoolMXBean : ManagementFactory.getMemoryPoolMXBeans()) {
+      strings.add(memoryPoolMXBean.getName() + ": " + memoryPoolMXBean.getUsage().toString());
+    }
+    return strings;
+  }
 }