Spaces:
Running
Running
File size: 9,688 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 |
# mypy: ignore-errors
import dataclasses
import functools
from importlib import import_module
from typing import Any, List, Optional
from functorch.compile import min_cut_rematerialization_partition
import torch
from torch import _guards
from torch._functorch.compilers import ts_compile
from .common import aot_autograd
from .registry import register_debug_backend as register_backend
"""
This file contains TorchDynamo backends intended for debugging use.
"""
@register_backend
def eager(gm, fake_tensor_inputs):
    """Debug backend that performs no compilation.

    The captured graph module ``gm`` is handed back unchanged;
    ``fake_tensor_inputs`` is accepted but never read.
    """
    return gm
@register_backend
def pre_dispatch_eager(gm, fake_tensor_inputs):
    """Re-trace ``gm`` with ``make_fx(..., pre_dispatch=True)``.

    The re-traced graph module is printed with ``print_readable()`` and then
    returned in place of the original ``gm``.
    """
    from torch.fx.experimental.proxy_tensor import make_fx

    # Interpreting ``gm`` gives make_fx a plain callable to trace.
    def runnable_gm(*args):
        interpreter = torch.fx.Interpreter(gm)
        return interpreter.run(*args)

    trace = make_fx(runnable_gm, pre_dispatch=True)
    pre_dispatch_gm = trace(*fake_tensor_inputs)
    pre_dispatch_gm.print_readable()
    return pre_dispatch_gm
@register_backend
def eager_debug(gm, fake_tensor_inputs):
    """Return a callable that interprets ``gm`` under ``SchemaCheckMode``.

    Currently this backend is useful for detecting (and erroring on) custom
    dispatcher ops whose schemas are incorrect. Further debugging checks
    could be added here later.
    """
    from torch._subclasses.schema_check_mode import SchemaCheckMode

    def inner(*args):
        # Both construction and execution of the interpreter happen while
        # the schema-checking mode is active.
        with SchemaCheckMode():
            interpreter = torch.fx.Interpreter(gm)
            return interpreter.run(*args)

    return inner
@register_backend(name="ts")
def torchscript(gm, fake_tensor_inputs):
    """Backend registered as ``"ts"``: pass ``gm`` through ``torch.jit.script``."""
    scripted = torch.jit.script(gm)
    return scripted
# used boxed call to discard inputs when they are no longer needed
def boxed_nop(fx_g, example_inputs):
def run(args):
return torch.fx.Interpreter(fx_g).boxed_run(args)
run._boxed_call = True
return run
# Debugging aid: AOT Autograd driven by the no-op ``boxed_nop`` compiler,
# paired with the min-cut rematerialization partitioner.
aot_eager = aot_autograd(
    partition_fn=min_cut_rematerialization_partition,
    fw_compiler=boxed_nop,
)
register_backend(compiler_fn=aot_eager, name="aot_eager")

# Same no-op compiler, but without overriding the partitioner.
aot_eager_default_partitioner = aot_autograd(fw_compiler=boxed_nop)
register_backend(
    compiler_fn=aot_eager_default_partitioner,
    name="aot_eager_default_partitioner",
)
def _inductor_decomp_table():
    """Fetch inductor's decomposition table, importing inductor only when called."""
    compile_fx = import_module("torch._inductor.compile_fx")
    return compile_fx.select_decomp_table()


# Runs AOT Autograd with TorchInductor's decompositions and partitioner but
# swaps the inductor compiler for a nop, which helps tell inductor failures
# apart from aot_eager failures.
aot_eager_decomp_partition = aot_autograd(
    # these are taken from memory_efficient_fusion()
    fw_compiler=boxed_nop,
    bw_compiler=boxed_nop,
    partition_fn=functools.partial(
        min_cut_rematerialization_partition, compiler="inductor"
    ),
    # NB: a callable is passed here to delay the import of inductor
    decompositions=_inductor_decomp_table,
)
register_backend(
    compiler_fn=aot_eager_decomp_partition,
    name="aot_eager_decomp_partition",
)
# AOT Autograd whose forward compiler is TorchScript (``ts_compile``); the
# partitioner is left at its default. It can be combined with either nnc or
# nvfuser by selecting the fuser through torch.jit.fuser(...).
aot_ts = aot_autograd(fw_compiler=ts_compile)
register_backend(compiler_fn=aot_ts, name="aot_ts")
# The deliberately buggy backends below exist to induce failures so that the
# repro extraction / minifier scripts can be tested.
class ReluCompileError(Exception):
    """Raised by ``relu_compile_error_TESTING_ONLY`` when the graph contains ``torch.relu``."""
class TestingOnlyCompileError(Exception):
    """Generic compile-time failure raised by the testing-only backends."""
@register_backend
def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    """Testing-only backend that refuses to "compile" graphs containing relu.

    Raises:
        ReluCompileError: if any node in ``gm.graph`` targets ``torch.relu``.

    Returns:
        ``gm`` unchanged when no such node exists.
    """
    has_relu = any(node.target == torch.relu for node in gm.graph.nodes)
    if has_relu:
        raise ReluCompileError()
    return gm
@register_backend
def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    """Testing-only backend that swaps every relu for a failing assertion.

    Each node targeting ``torch.relu`` is rewritten in place to call
    ``torch._assert(False, "ReluRuntimeError")``; the module is then
    recompiled and returned.
    """
    relu_nodes = [node for node in gm.graph.nodes if node.target == torch.relu]
    for relu_node in relu_nodes:
        relu_node.target = torch._assert
        relu_node.args = (False, "ReluRuntimeError")
    gm.recompile()
    return gm
@register_backend
def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    """Testing-only backend that silently computes the wrong result for relu.

    Each node targeting ``torch.relu`` is rewritten in place to
    ``torch.add(<original first argument>, 1)``; the module is then
    recompiled and returned.
    """
    relu_nodes = [node for node in gm.graph.nodes if node.target == torch.relu]
    for relu_node in relu_nodes:
        original_input = relu_node.args[0]
        relu_node.target = torch.add
        relu_node.args = (original_input, 1)
    gm.recompile()
    return gm
@register_backend
def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    """Testing-only backend that fails to compile when given a non-leaf input.

    Raises:
        TestingOnlyCompileError: if the graph has at least one
            ``call_function`` node and some example input has
            ``is_leaf`` false.

    Returns:
        ``gm`` unchanged otherwise.
    """
    # Require at least one non-trivial thing in the graph,
    # see https://github.com/pytorch/pytorch/issues/102898
    has_call_function = any(node.op == "call_function" for node in gm.graph.nodes)
    if not has_call_function:
        return gm

    if any(not t.is_leaf for t in example_inputs):
        raise TestingOnlyCompileError()
    return gm
@dataclasses.dataclass
class ExplainOutput:
    """
    Result object produced by :func:`torch._dynamo.explain()`.

    Not meant to be instantiated directly by users. ``str()`` of an instance
    renders a multi-section, human-readable report.
    """

    graphs: List[torch.fx.GraphModule]
    graph_count: int
    graph_break_count: int
    # Elements are GraphCompileReason objects, but the precise type doesn't
    # matter for this purpose.
    break_reasons: List[Any]
    op_count: int
    ops_per_graph: Optional[List[torch.fx.Node]] = None
    out_guards: Optional[List[_guards.Guard]] = None
    compile_times: Optional[str] = None

    def __str__(self):
        # NOTE(review): the leading whitespace inside the literals below is
        # reproduced exactly as found; confirm it was not collapsed upstream.
        pieces = [
            f"Graph Count: {self.graph_count}\n",
            f"Graph Break Count: {self.graph_break_count}\n",
            f"Op Count: {self.op_count}\n",
            "Break Reasons:\n",
        ]

        for number, break_reason in enumerate(self.break_reasons, start=1):
            pieces.append(f" Break Reason {number}:\n")
            pieces.append(f" Reason: {break_reason.reason}\n")
            pieces.append(" User Stack:\n")
            pieces.extend(
                f" {frame_summary}\n" for frame_summary in break_reason.user_stack
            )

        if self.ops_per_graph is not None:
            pieces.append("Ops per Graph:\n")
            for number, ops in enumerate(self.ops_per_graph, start=1):
                pieces.append(f" Ops {number}:\n")
                pieces.extend(f" {op}\n" for op in ops)

        if self.out_guards is not None:
            pieces.append("Out Guards:\n")
            for number, guard in enumerate(self.out_guards, start=1):
                pieces.append(f" Guard {number}:\n")
                # No newline is appended after the guard text; presumably
                # str(guard) ends with its own -- verify.
                pieces.append(f" {guard!s}")

        if self.compile_times is not None:
            pieces.append(f"Compile Times: {self.compile_times}\n")

        return "".join(pieces)
def _explain_graph_detail(
gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons
):
"""
This function is a utility which processes a torch.fx.GraphModule and
accumulates information about its ops, graph breaks, and other details. It
is intended to be used by the ExplainWithBackend class and
`torch._dynamo.explain()` to provide details from Dynamo's graph capture.
Parameters:
gm (torch.fx.GraphModule): The GraphModule to be processed.
graphs (list): A list that accumulates all the GraphModules processed.
op_count (int): The total count of operations in all GraphModules processed so far.
ops_per_graph (list): A list that accumulates the operations of each GraphModule.
break_reasons (list): A list that accumulates the reasons for breaks in each GraphModule.
Returns:
tuple: A tuple containing the processed GraphModule, the updated lists of graphs,
operations per graph, and break reasons, and the updated operation count.
"""
graphs.append(gm)
ops = [node.target for node in gm.graph.nodes if node.op == "call_function"]
op_count += len(ops)
ops_per_graph.append(ops)
if gm.compile_subgraph_reason.graph_break:
break_reasons.append(gm.compile_subgraph_reason)
return gm, graphs, op_count, ops_per_graph, break_reasons
class ExplainWithBackend:
    """
    This class is intended to be used as a backend for `torch.compile`. It is
    composable with other backends. When used in this way, it accumulates
    information about graph breaks, ops, and other info and provides a string
    representation summarizing this information.

    Attributes:
        backend: The backend obtained by passing the constructor argument to
            ``lookup_backend``; it is invoked as ``backend(gm, example_inputs)``.
        graphs (list): A list of the graphs captured by TorchDynamo.
        op_count (int): The total number of operations in all optimized graphs.
        break_reasons (list): A list of graph break reasons with stack traces.

    Example Usage:
        def fn(x):
            x = torch.sigmoid(x)
            return x

        torch._dynamo.reset()
        eb = ExplainWithBackend("inductor")
        optimized_fn = torch.compile(fn, backend=eb)
        result = optimized_fn(torch.randn(5))
        print(eb.output())
    """

    def __init__(self, backend):
        """Resolve ``backend`` through the registry and reset all accumulators."""
        from .registry import lookup_backend

        self.backend = lookup_backend(backend)
        self.graphs = []
        self.op_count = 0
        self.break_reasons = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        """Record statistics for ``gm``, then delegate compilation to the wrapped backend.

        Returns:
            Whatever ``self.backend(gm, example_inputs)`` returns.
        """
        # Per-graph op lists are not retained by this class, so a throwaway
        # list is supplied and the returned one is ignored.
        gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail(
            gm, self.graphs, self.op_count, [], self.break_reasons
        )
        return self.backend(gm, example_inputs)

    def output(self) -> ExplainOutput:
        """Summarize everything recorded so far as an ``ExplainOutput``.

        The graph break count is reported as one less than the number of
        captured graphs, clamped at zero so that calling this before any
        graph has been captured does not report ``-1`` breaks.
        """
        graph_count = len(self.graphs)
        graph_break_count = max(graph_count - 1, 0)
        return ExplainOutput(
            self.graphs,
            graph_count,
            graph_break_count,
            self.break_reasons,
            self.op_count,
        )
|