Spaces:
Running
Running
File size: 7,902 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 |
# mypy: ignore-errors
import functools
import operator
from collections import defaultdict
from typing import Dict, List, Optional
import torch
from torch._dynamo.backends.debugging import boxed_nop
from torch._inductor.cudagraph_trees import cudagraphify_impl
from torch._inductor.cudagraph_utils import (
BoxedDeviceIndex,
check_multiple_devices_or_any_cpu_nodes,
get_mutation_stack_trace,
)
from torch._inductor.utils import (
BoxedBool,
count_tangents,
has_incompatible_cudagraph_ops,
num_fw_fixed_arguments,
output_node,
)
from torch.multiprocessing.reductions import StorageWeakRef
from .common import aot_autograd
from .registry import register_backend
# Artifact logger for the "perf_hints" artifact; this module uses it to report
# when CUDA graphs are skipped for a graph.
perf_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
def find_input_mutations(g):
    """Return the indices of graph inputs that some op in ``g`` writes to.

    Placeholders are numbered in graph order (non-tensor placeholders still
    consume an index). Each tensor placeholder is keyed by the storage of its
    ``meta["val"]`` (or ``meta["fake_result"]``), so placeholders that share a
    storage are all reported when any one of them is written.

    An argument counts as written when the op schema marks it with a write
    alias annotation (``alias_info.is_write``).

    Args:
        g: a graph object exposing ``nodes`` (e.g. ``torch.fx.Graph``).

    Returns:
        A set of integer placeholder indices whose storage is mutated.
    """

    def meta_fk(meta):
        return meta["val"] if "val" in meta else meta["fake_result"]

    def storage_key(tensor):
        return StorageWeakRef(tensor._typed_storage())

    def iter_nodes(argument):
        # A written argument may be a single node or a flat list/tuple of
        # nodes (e.g. ``Tensor(a!)[]``); anything else (``None``, scalars)
        # carries no tensor and is ignored.
        items = argument if isinstance(argument, (list, tuple)) else (argument,)
        for item in items:
            if isinstance(item, torch.fx.Node):
                yield item

    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == "placeholder":
            value = meta_fk(n.meta)
            if isinstance(value, torch.Tensor):
                inputs[storage_key(value)].add(input_idx)
            # Every placeholder consumes an index, tensor or not.
            input_idx += 1
        elif n.op == "call_function":
            # Targets without a schema (operator.getitem, other plain Python
            # callables) give no alias information to inspect, so skip them
            # instead of failing on the missing attribute.
            schema = getattr(n.target, "_schema", None)
            if schema is None:
                continue
            for i, arg in enumerate(schema.arguments):
                if not (arg.alias_info and arg.alias_info.is_write):
                    continue
                if i < len(n.args):
                    argument = n.args[i]
                elif arg.name in n.kwargs:
                    argument = n.kwargs[arg.name]
                else:
                    continue
                for node in iter_nodes(argument):
                    value = meta_fk(node.meta)
                    if isinstance(value, torch.Tensor):
                        # .get() avoids inserting empty entries for storages
                        # that do not belong to any graph input.
                        mutated_inputs |= inputs.get(storage_key(value), set())
        # TODO: error on unrecognized nodes
    return mutated_inputs
def get_device_node_mapping(gm: torch.fx.GraphModule):
    """Map each device seen in ``gm``'s graph to the first node that uses it.

    A node contributes only when its ``meta["val"]`` is a ``torch.Tensor``;
    later nodes on an already-recorded device are ignored.
    """
    first_node_by_device: Dict[torch.device, torch.fx.Node] = {}
    for node in gm.graph.nodes:
        value = node.meta.get("val")
        if not isinstance(value, torch.Tensor):
            continue
        # setdefault keeps the earliest node recorded for the device.
        first_node_by_device.setdefault(value.device, node)
    return first_node_by_device
def check_for_mutation(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
    """Return a message describing mutated non-fixed inputs, or ``None``.

    Inputs with index below ``num_fixed`` are excluded from the check. The
    message is whatever ``get_mutation_stack_trace`` produces for the
    remaining mutated indices.
    """
    mutated = find_input_mutations(aot_model.graph)
    fixed_indices = set(range(num_fixed))
    offending = mutated - fixed_indices
    if offending:
        return get_mutation_stack_trace(aot_model, offending)
    return None
def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
    """Return the first reason CUDA graphs should be skipped, or ``None``.

    Checks run in this order: mutation of non-fixed inputs, the device check
    (``check_multiple_devices_or_any_cpu_nodes``), then incompatible ops.
    """
    mutation_msg = check_for_mutation(aot_model, num_fixed)
    if mutation_msg:
        return mutation_msg

    device_msg = check_multiple_devices_or_any_cpu_nodes(
        get_device_node_mapping(aot_model)
    )
    if device_msg:
        return device_msg

    if has_incompatible_cudagraph_ops(aot_model):
        return "skipping cudagraphs due to incompatible op"

    return None
def get_device_index(gm) -> int:
    """Return the index of the first device found in ``gm``'s graph.

    Asserts that this device is a CUDA device. Raises ``StopIteration`` when
    no node in the graph carries a tensor value.
    """
    devices = get_device_node_mapping(gm)
    first_device = next(iter(devices))
    assert first_device.type == "cuda"
    return first_device.index
def get_stack_traces(gm) -> List[Optional[str]]:
output = output_node(gm)
assert len(output.args) == 1
return [
(arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
for arg in output.args[0]
]
def cudagraphs(dynamo_model, dynamo_inputs):
    """Compile ``dynamo_model`` through ``aot_autograd`` with CUDA graph wrapping.

    Forward, backward and inference graphs are each wrapped with
    ``cudagraphify_impl`` unless ``check_for_skip`` reports a reason not to.
    When a forward graph is skipped, CUDA graphs are also disabled for the
    backward graph via the shared ``do_cudagraphs`` flag.
    """
    do_cudagraphs = BoxedBool(True)
    boxed_device_index = BoxedDeviceIndex(None)

    def forward_cudagraphs(aot_model, aot_inputs, is_inference=False):
        # NOTE(review): ``is_inference`` is accepted but not forwarded;
        # cudagraphify_impl always receives is_inference=False here.
        # Confirm this is intended.
        eager_runner = boxed_nop(aot_model, aot_inputs)
        num_fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
        skip_msg = check_for_skip(aot_model, num_fixed)
        if skip_msg:
            # Also turns off CUDA graphs for the matching backward graph.
            BoxedBool.disable(do_cudagraphs)
            perf_log.warning("skipping cudagraphs due to %s", skip_msg)
            return eager_runner

        # Remember the device so the backward compiler can look up the manager.
        boxed_device_index.set(get_device_index(aot_model))
        graphed = cudagraphify_impl(
            eager_runner,
            aot_inputs,
            range(num_fixed),
            device_index=boxed_device_index.value,
            is_backward=False,
            is_inference=False,
            stack_traces=get_stack_traces(aot_model),
        )
        graphed._boxed_call = True
        return graphed

    def backward_cudagraphs(aot_model, aot_inputs):
        eager_runner = boxed_nop(aot_model, aot_inputs)
        if not do_cudagraphs:
            return aot_model

        num_fixed = count_tangents(aot_model)
        skip_msg = check_for_skip(aot_model, num_fixed)
        if not skip_msg:
            graphed = cudagraphify_impl(
                eager_runner,
                aot_inputs,
                range(num_fixed),
                device_index=get_device_index(aot_model),
                is_backward=True,
                is_inference=False,
                stack_traces=get_stack_traces(aot_model),
            )
            graphed._boxed_call = True
            return graphed

        perf_log.warning("skipping cudagraphs due to %s", skip_msg)

        # See [Backward Generation Handling]
        manager = torch._inductor.cudagraph_trees.get_manager(
            boxed_device_index.value, create_if_none_exists=False
        )
        assert manager is not None

        def run_backward_without_graph(inputs):
            manager.set_to_running_backward()
            return aot_model(inputs)

        run_backward_without_graph._boxed_call = True
        return run_backward_without_graph

    inference_compiler = functools.partial(forward_cudagraphs, is_inference=True)
    aot_cudagraphs = aot_autograd(
        fw_compiler=forward_cudagraphs,
        bw_compiler=backward_cudagraphs,
        inference_compiler=inference_compiler,
        keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
    )
    return aot_cudagraphs(dynamo_model, dynamo_inputs)
class CudagraphsBackend:
    """Callable backend object that forwards to ``cudagraphs``."""

    compiler_name = "cudagraphs"

    @staticmethod
    def reset():
        """Reset CUDA graph tree state via ``reset_cudagraph_trees``."""
        # Imported lazily, at call time rather than module import time.
        from torch._inductor import cudagraph_trees

        cudagraph_trees.reset_cudagraph_trees()

    @staticmethod
    def __call__(model, inputs):
        """Compile ``model`` for ``inputs`` with the ``cudagraphs`` function."""
        return cudagraphs(model, inputs)
# Register the backend under the name "cudagraphs". It only applies CUDA
# graphs to the graph. It is also helpful for debugging and can serve as a
# perf baseline.
register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())
def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
    """This isn't registered as a backend, but is used in some benchmarks

    Runs ``model`` once on a side stream as warmup, captures one call into a
    ``torch.cuda.CUDAGraph``, and returns a ``run(*new_inputs)`` callable that
    replays the captured graph.

    Args:
        model: callable invoked as ``model(*inputs)``.
        inputs: list or tuple of example inputs.
        copy_outputs: when true, ``run`` returns clones of the captured
            outputs; otherwise it returns the captured outputs themselves.
        copy_inputs: when true, capture uses ``zeros_like`` copies of
            ``inputs`` and ``run`` copies new inputs into them before replay;
            otherwise capture uses ``inputs`` directly and ``run`` does not
            copy its arguments.
    """
    assert isinstance(inputs, (list, tuple))
    static_inputs = (
        [torch.zeros_like(x) for x in inputs] if copy_inputs else list(inputs)
    )

    # warmup
    torch.cuda.synchronize()
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        model(*inputs)
    side_stream.synchronize()
    torch.cuda.current_stream().wait_stream(side_stream)
    torch.cuda.synchronize()

    # record
    captured_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(captured_graph, stream=side_stream):
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*new_inputs):
        assert len(static_inputs) == len(new_inputs)
        if copy_inputs:
            for static_in, new_in in zip(static_inputs, new_inputs):
                static_in.copy_(new_in)
        captured_graph.replay()
        if not copy_outputs:
            return static_outputs
        return [out.clone() for out in static_outputs]

    return run
|