from __future__ import annotations

import itertools
import logging
import weakref
from typing import Any, List, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code
from torch._functorch.aot_autograd import MutationType
from torch._functorch.compile_utils import fx_graph_cse
from torch._inductor.constant_folding import constant_fold, replace_node_with_constant
from torch._inductor.fx_passes.freezing_patterns import freezing_passes
from torch._inductor.fx_passes.post_grad import view_to_reshape

from . import config

aten = torch.ops.aten
prims = torch.ops.prims
log = logging.getLogger(__name__)


def replace_params_with_constants(
    gm: torch.fx.GraphModule,
    flat_params: List[Any],
    fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta,
) -> List[int]:
    """
    Replaces the parameters of a PyTorch GraphModule with constants wherever possible.

    Returns a list of indices representing the input parameters that were not
    converted to constants.
    """
    params = [node for node in gm.graph.nodes if node.op == "placeholder"]
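    # The first len(flat_params) placeholders correspond one-to-one to the
    # flattened parameters; any trailing placeholders are regular graph inputs.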
    fake_inp_nodes = params[: len(flat_params)]
    preserved_arg_indices = []
    aliased_input_args = [
        out_info.base_idx
        for out_info in fw_metadata.output_info
        if out_info.base_idx is not None
    ]

    # TODO (tmanlaibaatar) figure out why this is different
    # from mutated_inp_runtime_indices
    mutated_inps = [
        i
        for i, m in enumerate(fw_metadata.input_info)
        if m.mutation_type
        in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH)
    ]
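
    # Parameters that are mutated in the graph or aliased by an output must stay
    # as runtime inputs: folding them into constants would bake in stale values
    # or break the aliasing relationship.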
    for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)):
        if i in mutated_inps or i in aliased_input_args:
            preserved_arg_indices.append(i)
            continue
        replace_node_with_constant(gm, node, real_input)

    # add on the non-parameter inputs
    preserved_arg_indices.extend(range(len(flat_params), len(params)))

    # is this necessary?
    gm.recompile()
    return preserved_arg_indices


@dynamo_timed
def freeze(
    dynamo_gm: torch.fx.GraphModule,
    aot_autograd_gm: torch.fx.GraphModule,
    example_inputs: List[torch._subclasses.FakeTensor],
) -> Tuple[torch.fx.GraphModule, List[int]]:
    """
    Inlines parameters that are not mutated into constants and optimizes the graph
    through constant propagation and other techniques. If enabled, the function
    also discards the original parameters of the module for memory efficiency.

    Assumes that this function is run inside Dynamo tracing, after aot_autograd.

    Args:
        dynamo_gm (torch.fx.GraphModule): The Dynamo-constructed GraphModule.
        aot_autograd_gm (torch.fx.GraphModule): The aot_autograd-constructed
            GraphModule to be frozen.
        example_inputs (List[torch._subclasses.FakeTensor]): A list of example
            input tensors to be used in the freezing process.

    Returns:
        Tuple[torch.fx.GraphModule, List[int]]: A tuple containing the frozen
        GraphModule and a list of indices of the inputs that were preserved
        (not turned into constants).
    """
    # Conv weights may have been converted to channels-last, which can cause errors
    # for .view during fake_tensor_prop, so convert view ops to reshape first.
    # See the details in fx_codegen_and_compile of compile_fx.py.
    view_to_reshape(aot_autograd_gm)

    if tracing_context := torch._guards.TracingContext.try_get():
        fw_metadata = tracing_context.fw_metadata
        params_flat = tracing_context.params_flat
        assert fw_metadata is not None and params_flat is not None

        preserved_arg_indices = replace_params_with_constants(
            aot_autograd_gm, params_flat, fw_metadata
        )
    else:
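        # No TracingContext is available to map placeholders back to parameters,
        # so conservatively preserve every input.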
        inputs = [
            node for node in aot_autograd_gm.graph.nodes if node.op == "placeholder"
        ]
        preserved_arg_indices = list(range(len(inputs)))

    # TODO: further restrict CSE? Right now it is needed to dedup aliasing ops.
    cse_graph = fx_graph_cse(aot_autograd_gm.graph)
    aot_autograd_gm.graph = cse_graph
    aot_autograd_gm.recompile()

    aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices]
    freezing_passes(aot_autograd_gm, aot_example_inputs)

    constant_fold(aot_autograd_gm)

    # invalidate nn Modules
    if config.freezing_discard_parameters:
        invalidate_eager_modules()
        discard_traced_gm_params(dynamo_gm)

    log.debug("%s", lazy_format_graph_code("FROZEN GRAPH", aot_autograd_gm))
    return aot_autograd_gm, preserved_arg_indices
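
# A minimal sketch (illustrative, not part of this module) of how a caller can
# use the returned indices. The compiled graph only accepts the preserved
# inputs, so runtime arguments must be filtered before invocation; `wrapper`
# below is a hypothetical name:
#
#   opt_gm, preserved_arg_indices = freeze(dynamo_gm, aot_autograd_gm, example_inputs)
#
#   def wrapper(args):
#       # drop the inputs that were inlined as constants
#       return opt_gm(*[args[i] for i in preserved_arg_indices])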


class ErasedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, name, owning_mod):
        return super().__new__(cls, elem.to(device="meta"))

    def __init__(self, elem, name: Optional[str], mod):
        self.erased_name = name
        self.owning_mod_ref = weakref.ref(mod)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        erased_tensors = [
            e
            for e in pytree.arg_tree_leaves(*args, **kwargs)
            if isinstance(e, ErasedTensor)
        ]
        assert len(erased_tensors) > 0
        e = erased_tensors[0]

        raise RuntimeError(
            "Trying to run Pytorch Eager Module after Dynamo Freezing. "
            "The original parameters have been discarded for memory efficiency. "
            f"Found in op {func} for erased parameter {e.erased_name} of {e.owning_mod_ref()}"
        )
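
# Any attempted op on an ErasedTensor fails loudly. Illustrative behavior after
# freezing with config.freezing_discard_parameters enabled:
#
#   mod(inp)  # the original eager module
#   RuntimeError: Trying to run Pytorch Eager Module after Dynamo Freezing.
#   The original parameters have been discarded for memory efficiency. ...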


def invalidate_eager_modules():
    for mod in torch._guards.TracingContext.get().module_context.nn_modules.values():
        if not isinstance(mod, torch.nn.Module):
            continue

        for attr_name, tensor in list(
            itertools.chain(
                mod.named_parameters(recurse=False), mod.named_buffers(recurse=False)
            )
        ):
            with torch._dispatch.python.no_python_dispatcher():
                e_t = ErasedTensor(tensor, attr_name, mod)
            if isinstance(tensor, torch.nn.Parameter):
                e_t.requires_grad_(True)
                e_t._is_param = True  # type: ignore[attr-defined]
            setattr(mod, attr_name, e_t)
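

# Mirrors invalidate_eager_modules, but erases the parameters and buffers stored
# directly on the Dynamo-traced GraphModule itself.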
def discard_traced_gm_params(mod: torch.fx.GraphModule):
    for attr_name, tensor in list(
        itertools.chain(
            mod.named_parameters(recurse=False), mod.named_buffers(recurse=False)
        )
    ):
        with torch._dispatch.python.no_python_dispatcher():
            e_t = ErasedTensor(tensor, attr_name, mod)
        if isinstance(tensor, torch.nn.Parameter):
            e_t.requires_grad_(True)
            e_t._is_param = True  # type: ignore[attr-defined]
        setattr(mod, attr_name, e_t)


def enforce_output_layout(gm: torch.fx.GraphModule):
    """
    Make sure the output node's layout does not change due to compiler optimizations
    by adding aten.as_strided nodes with the expected strides.

    Only used for inference so we can assume all graph outputs are model outputs.
    """
    *_, output_node = gm.graph.nodes
    out_list = output_node.args[0]
    with gm.graph.inserting_before(output_node):
        for n in out_list:
            if not isinstance(
                n.meta["val"], torch.Tensor
            ) or not torch._prims_common.is_non_overlapping_and_dense(n.meta["val"]):
                continue

            # add a node to enforce eager layout
            ft = n.meta["val"]
            new_node = gm.graph.call_function(
                prims.inductor_force_stride_order.default, (n, ft.stride())
            )

            # Cannot call n.replace_all_uses_with(new_node) here, since that would
            # also replace the use of n inside new_node itself.
            output_node.replace_input_with(n, new_node)

    gm.graph.lint()
    gm.recompile()
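
# Illustrative effect on the graph (sketch): for an output tensor `out` with
# eager strides `s`,
#
#   before: return (out,)
#   after:  forced = prims.inductor_force_stride_order(out, s)
#           return (forced,)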


def enforce_as_strided_input_layout(gm: torch.fx.GraphModule):
    """
    Make sure the layout of an as_strided node's input does not change due to
    compiler optimizations, because the as_strided strides depend on the input
    tensor's strides.
    """
    as_strided_ops = [
        torch.ops.aten.as_strided.default,
        torch.ops.aten.as_strided_.default,
        torch.ops.aten.as_strided_scatter.default,
    ]
    strided_nodes = [n for n in gm.graph.nodes if n.target in as_strided_ops]
    for n in strided_nodes:
        with gm.graph.inserting_before(n):
            # add a node to enforce eager layout
            ft = n.args[0].meta["val"]
            new_node = gm.graph.call_function(
                prims.inductor_force_stride_order.default, (n.args[0], ft.stride())
            )
            n.replace_input_with(n.args[0], new_node)

    gm.graph.lint()
    gm.recompile()


def convert_conv_weights_to_channels_last(gm: torch.fx.GraphModule):
    """
    Convert 4d convolution weight tensors to channels-last format.

    This pass is performed before freezing so that the added nodes can be
    constant-folded by freezing.
    """
    convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
    for conv in convs:
        weight_node = conv.args[1]
        if len(weight_node.meta["val"].size()) != 4 or weight_node.meta[
            "val"
        ].is_contiguous(memory_format=torch.channels_last):
            # not a 4d tensor or already channels last, skip
            continue

        with gm.graph.inserting_before(conv):
            new_node = gm.graph.call_function(
                aten.clone.default,
                (weight_node,),
                {"memory_format": torch.channels_last},
            )
            conv.replace_input_with(weight_node, new_node)

    enforce_as_strided_input_layout(gm)
    enforce_output_layout(gm)
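
# Illustrative rewrite (sketch): for a 4d weight `w` that is not already
# channels last,
#
#   before: conv = aten.convolution(x, w, ...)
#   after:  w_cl = aten.clone(w, memory_format=torch.channels_last)
#           conv = aten.convolution(x, w_cl, ...)
#
# The inserted clone is constant-folded away during freezing, so the frozen
# graph simply carries a channels-last constant weight.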