import dataclasses
import math
import operator
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.export import ExportedProgram
from torch.utils._pytree import (
    _register_pytree_node,
    Context,
    FlattenFunc,
    FromDumpableContextFn,
    KeyPath,
    keystr,
    MappingKey,
    SequenceKey,
    ToDumpableContextFn,
    UnflattenFunc,
)


def _check_input_constraints_for_graph(
    input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints
):
    def get_keystr(key_path: KeyPath) -> str:
        """For a given index into the flat_args, return a human readable string
        describing how to access it, e.g. "*args["foo"][0].bar"
        """
        # Prefix the keypath with "*args" or "**kwargs" to make it clearer where
        # the arguments come from. Ultimately we ought to serialize the
        # original arg names for the best error message here.
        args_kwargs_key_path = key_path[0]
        assert isinstance(args_kwargs_key_path, SequenceKey)
        if args_kwargs_key_path.idx == 0:
            return f"*args{keystr(key_path[1:])}"
        else:
            kwarg_key = key_path[1]
            assert isinstance(kwarg_key, MappingKey)
            name = str(kwarg_key)[1:-1]  # get rid of the enclosed []
            return f"{name}{keystr(key_path[2:])}"

    import sympy

    from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
        _convert_range_to_int,
    )
    from torch.utils._sympy.solve import try_solve

    if len(flat_args_with_path) != len(input_placeholders):
        raise RuntimeError(
            "Unexpected number of inputs "
            f"(expected {len(input_placeholders)}, got {len(flat_args_with_path)})"
        )
    # NOTE: export already guarantees that the same symbol is used in metadata
    # for all InputDims related by equality constraints, so we can just unify
    # symbols with given input dimension values to check equality constraints.
    unification_map: "Dict[sympy.Symbol, Any]" = {}
    for (key_path, arg), node in zip(flat_args_with_path, input_placeholders):
        node_val = node.meta.get("val")
        if isinstance(node_val, FakeTensor):
            if not isinstance(arg, torch.Tensor):
                raise RuntimeError(
                    f"Expected input at {get_keystr(key_path)} to be a tensor, but got {type(arg)}",
                )

            if len(node_val.shape) != len(arg.shape):
                raise RuntimeError(
                    f"Unexpected number of dimensions in input at {get_keystr(key_path)}.shape "
                    f"(expected {node_val.shape}, got {arg.shape})"
                )

            for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)):
                # TODO(avik): Assert the following property in the IR verifier:
                # node_dim is either an int or a SymInt containing an int or a unary sympy.Expr
                if (
                    isinstance(node_dim, torch.SymInt)
                    and len(node_dim.node.expr.free_symbols) == 1
                ):
                    symbol = next(iter(node_dim.node.expr.free_symbols))
                    if symbol in unification_map:
                        existing_dim = node_dim.node.expr.subs(unification_map)
                        if arg_dim != existing_dim:
                            raise RuntimeError(
                                f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to "
                                f"{existing_dim}, but got {arg_dim}",
                            )
                    else:
                        if (
                            isinstance(arg_dim, torch.SymInt)
                            and not arg_dim.node.expr.is_number
                        ):
                            # This can happen when, say, arg is a fake tensor.
                            # We do not run checks on symbolic shapes of fake inputs as
                            # such checks can affect the shape env.
                            pass
                        else:
                            solution = try_solve(
                                sympy.Eq(node_dim.node.expr, arg_dim), symbol
                            )
                            if solution is None:
                                raise RuntimeError(  # noqa: TRY200
                                    f"Expected input {node.name}.shape[{j}] = {arg_dim} to be "
                                    f"of the form {node_dim.node.expr}, where {symbol} is an integer"
                                )
                            else:
                                unification_map[symbol] = int(solution[1])

                    if node_dim.node.expr in range_constraints:
                        min_val, max_val = _convert_range_to_int(
                            range_constraints[node_dim.node.expr]
                        )
                        # NOTE: we allow dimensions to be 0/1 at runtime
                        if min_val > 2:
                            if arg_dim < min_val:
                                raise RuntimeError(
                                    f"Expected input at {get_keystr(key_path)}.shape[{j}] to be >= "
                                    f"{min_val}, but got {arg_dim}",
                                )
                        if max_val < math.inf:
                            if arg_dim > max_val:
                                raise RuntimeError(
                                    f"Expected input at {get_keystr(key_path)}.shape[{j}] to be <= "
                                    f"{max_val}, but got {arg_dim}",
                                )
                else:
                    if arg_dim != node_dim:
                        raise RuntimeError(
                            f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to "
                            f"{node_dim}, but got {arg_dim}",
                        )
        elif isinstance(node_val, (int, float, str)):
            if type(arg) != type(node_val) or arg != node_val:
                raise RuntimeError(
                    f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}",
                )
def register_dataclass_as_pytree_node(
    cls: Type[Any],
    flatten_fn: Optional[FlattenFunc] = None,
    unflatten_fn: Optional[UnflattenFunc] = None,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    return_none_fields: bool = False,
) -> None:
    assert dataclasses.is_dataclass(
        cls
    ), f"Only dataclasses can be registered with this function: {cls}"

    def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
        flattened = []
        flat_names = []
        none_names = []
        for f in dataclasses.fields(obj):
            name, val = f.name, getattr(obj, f.name)
            if val is not None or return_none_fields:
                flattened.append(val)
                flat_names.append(name)
            else:
                none_names.append(name)
        return flattened, [flat_names, none_names]

    def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
        flat_names, none_names = context
        return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))

    flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
    unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn

    if (to_dumpable_context is None) ^ (from_dumpable_context is None):
        raise ValueError(
            f"Both to_dumpable_context and from_dumpable_context for {cls} must "
            "be None or registered."
        )

    _register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
    )
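
# Illustrative example (not executed by this module; ``Output`` is a hypothetical
# dataclass): registering a dataclass so that pytree flattening treats its
# non-None fields as leaves.
#
#     import torch.utils._pytree as pytree
#
#     @dataclasses.dataclass
#     class Output:
#         logits: torch.Tensor
#         loss: Optional[torch.Tensor] = None
#
#     register_dataclass_as_pytree_node(Output, serialized_type_name="my_pkg.Output")
#
#     leaves, spec = pytree.tree_flatten(Output(logits=torch.randn(2, 4)))
#     # With return_none_fields=False (the default), only ``logits`` is a leaf;
#     # unflattening restores ``loss`` as None.
#     restored = pytree.tree_unflatten(leaves, spec)
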
def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Checks if the given node is a parameter within the exported program
    """
    return node.name in program.graph_signature.inputs_to_parameters


def get_param(
    program: ExportedProgram,
    node: torch.fx.Node,
) -> Optional[torch.nn.Parameter]:
    """
    Returns the parameter associated with the given node in the exported program.
    Returns None if the node is not a parameter within the exported program
    """
    if is_param(program, node):
        parameter_name = program.graph_signature.inputs_to_parameters[node.name]
        return program.state_dict[parameter_name]
    return None


def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Checks if the given node is a buffer within the exported program
    """
    return node.name in program.graph_signature.inputs_to_buffers


def get_buffer(
    program: ExportedProgram,
    node: torch.fx.Node,
) -> Optional[torch.Tensor]:
    """
    Returns the buffer associated with the given node in the exported program.
    Returns None if the node is not a buffer within the exported program
    """
    if is_buffer(program, node):
        buffer_name = program.graph_signature.inputs_to_buffers[node.name]
        if buffer_name in program.graph_signature.non_persistent_buffers:
            return program.constants[buffer_name]
        else:
            return program.state_dict[buffer_name]
    return None


def is_lifted_tensor_constant(
    program: ExportedProgram,
    node: torch.fx.Node,
) -> bool:
    """
    Checks if the given node is a lifted tensor constant within the exported program
    """
    return node.name in program.graph_signature.inputs_to_lifted_tensor_constants


def get_lifted_tensor_constant(
    program: ExportedProgram,
    node: torch.fx.Node,
) -> Optional[torch.Tensor]:
    """
    Returns the lifted tensor constant associated with the given node in the exported program.
    Returns None if the node is not a lifted tensor constant within the exported program
    """
    if is_lifted_tensor_constant(program, node):
        lifted_tensor_name = program.graph_signature.inputs_to_lifted_tensor_constants[
            node.name
        ]
        return program.constants[lifted_tensor_name]
    return None
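
# Illustrative sketch (not executed by this module; ``ep`` is a hypothetical
# ExportedProgram from ``torch.export.export``): classifying lifted placeholder
# inputs with the accessor helpers above.
#
#     for node in ep.graph.nodes:
#         if node.op != "placeholder":
#             continue
#         if is_param(ep, node):
#             print(node.name, "-> parameter", get_param(ep, node).shape)
#         elif is_buffer(ep, node):
#             print(node.name, "-> buffer", get_buffer(ep, node).shape)
#         elif is_lifted_tensor_constant(ep, node):
#             print(node.name, "-> constant", get_lifted_tensor_constant(ep, node).shape)
#         else:
#             print(node.name, "-> user input")
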
def sequential_split(gm: torch.fx.GraphModule, node_call_back) -> torch.fx.GraphModule:
    """
    Splits the graph module into multiple submodules based on the node_call_back.
    The node_call_back should return True if the node is a delimiter. A delimiter
    becomes the first node in the next submodule.
    """
    from torch.fx.passes.split_module import split_module

    split_map = {}
    split_id = 0
    for node in gm.graph.nodes:
        if node_call_back(node):
            split_id += 1
        split_map[node] = split_id

    new_gm = split_module(
        gm,
        gm,
        lambda node: split_map[node],
        keep_original_order=True,
        keep_original_node_name=True,
    )
    # Keep the codegen from the original graph module to preserve e.g. pytree info.
    new_gm.graph._codegen = gm.graph._codegen
    new_gm.recompile()
    return new_gm
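
# Illustrative sketch (not executed by this module; ``M`` is a hypothetical
# module): splitting a traced graph at every ``operator.add`` call.
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return (x + 1) * 2 + 3
#
#     gm = torch.fx.symbolic_trace(M())
#     new_gm = sequential_split(
#         gm, lambda n: n.op == "call_function" and n.target == operator.add
#     )
#     # new_gm now calls a chain of submodules (``submod_*``), each beginning at
#     # one of the matched ``add`` nodes.
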
def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
    """Returns the nodes that match the node_call_back as a list."""
    return [node for node in nodes if node_call_back(node)]


def nodes_first(
    nodes: List[torch.fx.Node], node_call_back=None
) -> Optional[torch.fx.Node]:
    """
    Returns the first node that matches the node_call_back. If no node matches, returns None.
    When node_call_back is None, returns the first node in the node list.
    """
    ret = nodes_filter(nodes, node_call_back if node_call_back else lambda node: True)
    if len(ret) > 0:
        return ret[0]
    return None


def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int:
    """Returns the number of nodes that match the node_call_back."""
    return len(nodes_filter(nodes, node_call_back))


def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
    """
    Sequentially visit the nodes list and invoke node_call_back on each element.
    Returns the nodes list after the node_call_back is invoked on each element.
    """
    for node in nodes:
        node_call_back(node)
    return nodes
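
# Illustrative sketch (not executed by this module; ``gm`` is a hypothetical
# GraphModule): composing the node-list helpers above with ordinary graph
# traversal.
#
#     nodes = list(gm.graph.nodes)
#     placeholders = nodes_filter(nodes, lambda n: n.op == "placeholder")
#     first_call = nodes_first(nodes, lambda n: n.op == "call_function")
#     num_outputs = nodes_count(nodes, lambda n: n.op == "output")
#     nodes_map(placeholders, lambda n: n.meta.setdefault("tag", "input"))
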
def node_replace_(
    old_node: torch.fx.Node, new_node: torch.fx.Node, delete_old: bool = False
) -> None:
    """
    Replace all uses of old_node with new_node.
    """
    old_node.replace_all_uses_with(new_node)
    if delete_old:
        old_node.users.clear()
        old_node.graph.erase_node(old_node)
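
# Illustrative sketch (not executed by this module; ``gm`` and ``old`` are
# hypothetical, and ``old`` is assumed to take a tensor node as its first
# argument): rerouting all users of ``old`` to a newly inserted node and
# erasing ``old``.
#
#     with gm.graph.inserting_after(old):
#         new = gm.graph.call_function(torch.ops.aten.relu.default, (old.args[0],))
#     node_replace_(old, new, delete_old=True)
#     gm.recompile()
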
def node_inline_(call_mod_node: torch.fx.Node) -> torch.fx.GraphModule:
    """
    Inline the submodule of the given node into the parent module.
    Note: we only support the case where the submodule takes tensor inputs.
    """
    assert call_mod_node.op == "call_module"
    gm = call_mod_node.graph.owning_module

    assert isinstance(call_mod_node.target, str)
    sub_gm = getattr(gm, call_mod_node.target)

    phs = (node for node in sub_gm.graph.nodes if node.op == "placeholder")
    body = (
        node for node in sub_gm.graph.nodes if node.op not in ("placeholder", "output")
    )
    output = [node for node in sub_gm.graph.nodes if node.op == "output"]

    for ph, arg in zip(phs, call_mod_node.args):
        assert isinstance(arg, torch.fx.Node)
        node_replace_(ph, arg, delete_old=True)

    with gm.graph.inserting_before(call_mod_node):
        for node in body:
            new_node = gm.graph.node_copy(node)
            node_replace_(node, new_node, delete_old=True)

        if len(output) > 0:
            assert len(output) == 1 and len(output[0].args) == 1
            new_output = output[0].args[0]

            if isinstance(new_output, torch.fx.Node):
                node_replace_(call_mod_node, new_output, delete_old=True)
            elif isinstance(new_output, (list, tuple)):
                # Inline the get_item calls for the output node.
                get_item_users = nodes_filter(
                    list(call_mod_node.users.keys()),
                    lambda node: node.op == "call_function"
                    and node.target == operator.getitem,
                )
                # get_item_node.args[1] is the idx referring to new_output[idx]
                nodes_map(
                    get_item_users,
                    lambda get_item_node: node_replace_(
                        get_item_node,
                        new_output[get_item_node.args[1]],
                        delete_old=True,
                    ),
                )
                call_mod_node.graph.erase_node(call_mod_node)
            else:
                raise NotImplementedError(
                    f"Unsupported output type {type(new_output)}. Expect it to be a Node or a list/tuple of Nodes."
                )
        else:
            call_mod_node.graph.erase_node(call_mod_node)

    gm.delete_all_unused_submodules()
    gm.recompile()
    return gm
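
# Illustrative sketch (not executed by this module; ``new_gm`` is a hypothetical
# GraphModule, e.g. one produced by ``sequential_split`` above): flattening all
# submodule calls back into the parent graph. ``node_inline_`` recompiles the
# owning module itself, so no extra recompile is needed here.
#
#     for node in list(new_gm.graph.nodes):
#         if node.op == "call_module":
#             node_inline_(node)
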