Spaces:
Running
Running
File size: 5,006 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import dataclasses
import os
from typing import Any, List
import torch
from .utils import print_once
@dataclasses.dataclass
class ProfileMetrics:
    """Counters summarizing one profiled run.

    After dividing by another ProfileMetrics (``__truediv__``), the fields
    hold ratios in [0, 1] rather than raw counts; ``__str__`` and ``tocsv``
    assume that ratio form.
    """

    microseconds: float = 0.0  # time spent in counted operator events
    operators: int = 0  # number of top-level operator events
    fusions: int = 0  # operators minus graph calls (ops folded into graphs)
    graphs: int = 0  # number of captured-graph invocations

    def __iadd__(self, other: "ProfileMetrics") -> "ProfileMetrics":
        """Accumulate all counters from *other* in place.

        Bug fix: ``graphs`` was previously dropped here (and in ``__add__``),
        so aggregating ProfileResult objects lost the graph-call count.
        """
        self.microseconds += other.microseconds
        self.operators += other.operators
        self.fusions += other.fusions
        self.graphs += other.graphs
        return self

    def __add__(self, other: "ProfileMetrics") -> "ProfileMetrics":
        """Return a new ProfileMetrics with element-wise summed counters."""
        assert isinstance(other, ProfileMetrics)
        return ProfileMetrics(
            self.microseconds + other.microseconds,
            self.operators + other.operators,
            self.fusions + other.fusions,
            self.graphs + other.graphs,
        )

    def __truediv__(self, other) -> "ProfileMetrics":
        """Element-wise ratio of two metrics; an int divisor broadcasts.

        Denominators are clamped to >= 1 to avoid division by zero.
        """
        if isinstance(other, int):
            other = ProfileMetrics(other, other, other)
        return ProfileMetrics(
            self.microseconds / max(1, other.microseconds),
            self.operators / max(1, other.operators),
            self.fusions / max(1, other.fusions),
        )

    def __str__(self) -> str:
        """Render as percentages (meaningful only after ``__truediv__``)."""
        return f"{self.operators:4.0%} ops {self.microseconds:4.0%} time"

    def tocsv(self):
        """CSV columns: [operators, microseconds]."""
        return [self.operators, self.microseconds]
class ProfileResult:
    """Pairs the metrics captured inside TORCHDYNAMO regions with the
    totals for the whole run, plus the count of distinct compiled graphs."""

    def __init__(self, captured, total, unique_graphs):
        # Substitute empty metrics for None so the arithmetic below is safe.
        self.captured: "ProfileMetrics" = captured if captured else ProfileMetrics()
        self.total: "ProfileMetrics" = total if total else ProfileMetrics()
        self.unique_graphs: int = unique_graphs

    def __iadd__(self, other: "ProfileResult"):
        """Accumulate another result into this one."""
        self.captured += other.captured
        self.total += other.total
        self.unique_graphs += other.unique_graphs
        return self

    def percent(self):
        """Element-wise fraction of the run that was captured."""
        return self.captured / self.total

    def __str__(self):
        header = f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls "
        ratio = f"{self.captured.operators:4}/{self.total.operators:4} = "
        return header + ratio + str(self.percent())

    def tocsv(self):
        """CSV row: graph counts, op counts, then the percent columns."""
        row = [
            self.unique_graphs,
            self.captured.graphs,
            self.captured.operators,
            self.total.operators,
        ]
        return row + self.percent().tocsv()
def should_print_missing():
    """Return True when the TORCHDYNAMO_PRINT_MISSING env var is set to "1"."""
    flag = os.environ.get("TORCHDYNAMO_PRINT_MISSING")
    return flag == "1"
def print_missing(stack):
    """Log (once per unique message) the tail of a call stack that fell
    outside a captured region.

    Stacks that pass through the torch autograd profiler are ignored
    entirely; built-in and torch-internal frames are filtered out before
    printing the last three remaining frames.
    """
    for frame in stack:
        if "/torch/autograd/profiler.py" in frame:
            return
    interesting = [
        frame
        for frame in stack
        if "<built-in" not in frame and "site-packages/torch/" not in frame
    ]
    print_once("MISSING", " >> ".join(interesting[-3:]))
class Profiler:
    """Wraps torch.profiler to measure how much of a run executed inside
    TORCHDYNAMO-captured regions versus eagerly."""

    # Incremented by fx_insert_profiling once per compiled graph;
    # read and reset by results().
    unique_graphs = 0

    def __init__(self):
        # with_stack is only enabled when we will print missing frames,
        # since stack capture adds profiling overhead.
        self.prof = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            with_stack=should_print_missing(),
        )

    def results(self):
        """Walk profiler events in start order and split top-level ops into
        captured (ending inside a TORCHDYNAMO region) vs. total.

        Returns a ProfileResult and resets the class-level unique_graphs
        counter. Nested ops (those starting before the previous top-level
        op ended) are skipped so each op is counted once.
        """
        captured_regions = 0
        captured_ops = 0
        captured_microseconds = 0
        total_ops = 0
        total_microseconds = 0
        # End time of the most recent top-level op; used to skip nested ops.
        last_op_end_time = -1
        # End time of the most recent TORCHDYNAMO region; ops ending at or
        # before this time are considered captured.
        captured_region_end_time = -1
        events = sorted(self.prof.events(), key=lambda x: x.time_range.start)
        for e in events:
            if e.name == "TORCHDYNAMO":
                captured_region_end_time = e.time_range.end
                captured_regions += 1
                # ignore `handle = torch.zeros(1)` in record_function.__init__()
                total_ops -= 1
            elif e.time_range.start >= last_op_end_time:
                # Top-level op: starts after the previous top-level op ended.
                last_op_end_time = e.time_range.end
                if e.time_range.end <= captured_region_end_time:
                    captured_ops += 1
                    captured_microseconds += e.time_range.elapsed_us()
                elif should_print_missing():
                    # Op escaped capture; report where it came from.
                    print_missing(e.stack)
                total_ops += 1
                total_microseconds += e.time_range.elapsed_us()
            else:
                pass # ops recursively called from other ops (ignored)
        unique_graphs = Profiler.unique_graphs
        Profiler.unique_graphs = 0
        # we counted one extra op that is part of the profiler setup code
        total_ops -= 1
        return ProfileResult(
            captured=ProfileMetrics(
                microseconds=captured_microseconds,
                operators=captured_ops,
                # Each region replaces N ops with 1 graph call; the difference
                # is the number of "fused" ops.
                fusions=captured_ops - captured_regions,
                graphs=captured_regions,
            ),
            total=ProfileMetrics(
                microseconds=total_microseconds,
                operators=total_ops,
                fusions=total_ops - 1,
            ),
            unique_graphs=unique_graphs,
        )
def fx_insert_profiling(gm: torch.fx.GraphModule, example_inputs: List[Any]):
    """Dynamo backend hook: wrap graph execution in a TORCHDYNAMO
    record_function region so Profiler.results() can attribute ops to it.

    ``example_inputs`` is unused but required by the backend interface.
    """

    def _profiled_forward(*args):
        # Everything inside this region counts as "captured" in the profile.
        with torch.profiler.record_function("TORCHDYNAMO"):
            return gm.forward(*args)

    # Count each compiled graph exactly once, at compile time.
    Profiler.unique_graphs += 1
    return _profiled_forward
|