Spaces:
Running
Running
import hashlib | |
import torch | |
import torch.fx | |
from typing import Any, Dict, Optional, TYPE_CHECKING | |
from torch.fx.node import _get_qualified_name, _format_arg | |
from torch.fx.graph import _parse_stack_trace | |
from torch.fx.passes.shape_prop import TensorMetadata | |
from torch.fx._compatibility import compatibility | |
from itertools import chain | |
__all__ = ['FxGraphDrawer'] | |
try: | |
import pydot | |
HAS_PYDOT = True | |
except ImportError: | |
HAS_PYDOT = False | |
_COLOR_MAP = { | |
"placeholder": '"AliceBlue"', | |
"call_module": "LemonChiffon1", | |
"get_param": "Yellow2", | |
"get_attr": "LightGrey", | |
"output": "PowderBlue", | |
} | |
_HASH_COLOR_MAP = [ | |
"CadetBlue1", | |
"Coral", | |
"DarkOliveGreen1", | |
"DarkSeaGreen1", | |
"GhostWhite", | |
"Khaki1", | |
"LavenderBlush1", | |
"LightSkyBlue", | |
"MistyRose1", | |
"MistyRose2", | |
"PaleTurquoise2", | |
"PeachPuff1", | |
"Salmon", | |
"Thistle1", | |
"Thistle3", | |
"Wheat1", | |
] | |
_WEIGHT_TEMPLATE = { | |
"fillcolor": "Salmon", | |
"style": '"filled,rounded"', | |
"fontcolor": "#000000", | |
} | |
if HAS_PYDOT: | |
class FxGraphDrawer: | |
""" | |
Visualize a torch.fx.Graph with graphviz | |
Basic usage: | |
g = FxGraphDrawer(symbolic_traced, "resnet18") | |
g.get_dot_graph().write_svg("a.svg") | |
""" | |
def __init__( | |
self, | |
graph_module: torch.fx.GraphModule, | |
name: str, | |
ignore_getattr: bool = False, | |
ignore_parameters_and_buffers: bool = False, | |
skip_node_names_in_args: bool = True, | |
parse_stack_trace: bool = False, | |
dot_graph_shape: Optional[str] = None, | |
): | |
self._name = name | |
self.dot_graph_shape = ( | |
dot_graph_shape if dot_graph_shape is not None else "record" | |
) | |
_WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape | |
self._dot_graphs = { | |
name: self._to_dot( | |
graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace | |
) | |
} | |
for node in graph_module.graph.nodes: | |
if node.op != "call_module": | |
continue | |
leaf_node = self._get_leaf_node(graph_module, node) | |
if not isinstance(leaf_node, torch.fx.GraphModule): | |
continue | |
self._dot_graphs[f"{name}_{node.target}"] = self._to_dot( | |
leaf_node, | |
f"{name}_{node.target}", | |
ignore_getattr, | |
ignore_parameters_and_buffers, | |
skip_node_names_in_args, | |
parse_stack_trace, | |
) | |
def get_dot_graph(self, submod_name=None) -> pydot.Dot: | |
""" | |
Visualize a torch.fx.Graph with graphviz | |
Example: | |
>>> # xdoctest: +REQUIRES(module:pydot) | |
>>> # define module | |
>>> class MyModule(torch.nn.Module): | |
>>> def __init__(self): | |
>>> super().__init__() | |
>>> self.linear = torch.nn.Linear(4, 5) | |
>>> def forward(self, x): | |
>>> return self.linear(x).clamp(min=0.0, max=1.0) | |
>>> module = MyModule() | |
>>> # trace the module | |
>>> symbolic_traced = torch.fx.symbolic_trace(module) | |
>>> # setup output file | |
>>> import ubelt as ub | |
>>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir() | |
>>> fpath = dpath / 'linear.svg' | |
>>> # draw the graph | |
>>> g = FxGraphDrawer(symbolic_traced, "linear") | |
>>> g.get_dot_graph().write_svg(fpath) | |
""" | |
if submod_name is None: | |
return self.get_main_dot_graph() | |
else: | |
return self.get_submod_dot_graph(submod_name) | |
def get_main_dot_graph(self) -> pydot.Dot: | |
return self._dot_graphs[self._name] | |
def get_submod_dot_graph(self, submod_name) -> pydot.Dot: | |
return self._dot_graphs[f"{self._name}_{submod_name}"] | |
def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]: | |
return self._dot_graphs | |
def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]: | |
template = { | |
"shape": self.dot_graph_shape, | |
"fillcolor": "#CAFFE3", | |
"style": '"filled,rounded"', | |
"fontcolor": "#000000", | |
} | |
if node.op in _COLOR_MAP: | |
template["fillcolor"] = _COLOR_MAP[node.op] | |
else: | |
# Use a random color for each node; based on its name so it's stable. | |
target_name = node._pretty_print_target(node.target) | |
target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16) | |
template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)] | |
return template | |
def _get_leaf_node( | |
self, module: torch.nn.Module, node: torch.fx.Node | |
) -> torch.nn.Module: | |
py_obj = module | |
assert isinstance(node.target, str) | |
atoms = node.target.split(".") | |
for atom in atoms: | |
if not hasattr(py_obj, atom): | |
raise RuntimeError( | |
str(py_obj) + " does not have attribute " + atom + "!" | |
) | |
py_obj = getattr(py_obj, atom) | |
return py_obj | |
def _typename(self, target: Any) -> str: | |
if isinstance(target, torch.nn.Module): | |
ret = torch.typename(target) | |
elif isinstance(target, str): | |
ret = target | |
else: | |
ret = _get_qualified_name(target) | |
# Escape "{" and "}" to prevent dot files like: | |
# https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc | |
# which triggers `Error: bad label format (...)` from dot | |
return ret.replace("{", r"\{").replace("}", r"\}") | |
# shorten path to avoid drawing long boxes | |
# for full path = '/home/weif/pytorch/test.py' | |
# return short path = 'pytorch/test.py' | |
def _shorten_file_name( | |
self, | |
full_file_name: str, | |
truncate_to_last_n: int = 2, | |
): | |
splits = full_file_name.split('/') | |
if len(splits) >= truncate_to_last_n: | |
return '/'.join(splits[-truncate_to_last_n:]) | |
return full_file_name | |
def _get_node_label( | |
self, | |
module: torch.fx.GraphModule, | |
node: torch.fx.Node, | |
skip_node_names_in_args: bool, | |
parse_stack_trace: bool, | |
) -> str: | |
def _get_str_for_args_kwargs(arg): | |
if isinstance(arg, tuple): | |
prefix, suffix = r"|args=(\l", r",\n)\l" | |
arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg] | |
elif isinstance(arg, dict): | |
prefix, suffix = r"|kwargs={\l", r",\n}\l" | |
arg_strs_list = [ | |
f"{k}: {_format_arg(v, max_list_len=8)}" | |
for k, v in arg.items() | |
] | |
else: # Fall back to nothing in unexpected case. | |
return "" | |
# Strip out node names if requested. | |
if skip_node_names_in_args: | |
arg_strs_list = [a for a in arg_strs_list if "%" not in a] | |
if len(arg_strs_list) == 0: | |
return "" | |
arg_strs = prefix + r",\n".join(arg_strs_list) + suffix | |
if len(arg_strs_list) == 1: | |
arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "") | |
return arg_strs.replace("{", r"\{").replace("}", r"\}") | |
label = "{" + f"name=%{node.name}|op_code={node.op}\n" | |
if node.op == "call_module": | |
leaf_module = self._get_leaf_node(module, node) | |
label += r"\n" + self._typename(leaf_module) + r"\n|" | |
extra = "" | |
if hasattr(leaf_module, "__constants__"): | |
extra = r"\n".join( | |
[f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__] # type: ignore[union-attr] | |
) | |
label += extra + r"\n" | |
else: | |
label += f"|target={self._typename(node.target)}" + r"\n" | |
if len(node.args) > 0: | |
label += _get_str_for_args_kwargs(node.args) | |
if len(node.kwargs) > 0: | |
label += _get_str_for_args_kwargs(node.kwargs) | |
label += f"|num_users={len(node.users)}" + r"\n" | |
tensor_meta = node.meta.get('tensor_meta') | |
label += self._tensor_meta_to_label(tensor_meta) | |
# for original fx graph | |
# print buf=buf0, n_origin=6 | |
buf_meta = node.meta.get('buf_meta', None) | |
if buf_meta is not None: | |
label += f"|buf={buf_meta.name}" + r"\n" | |
label += f"|n_origin={buf_meta.n_origin}" + r"\n" | |
# for original fx graph | |
# print file:lineno code | |
if parse_stack_trace and node.stack_trace is not None: | |
parsed_stack_trace = _parse_stack_trace(node.stack_trace) | |
fname = self._shorten_file_name(parsed_stack_trace.file) | |
label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n" | |
return label + "}" | |
def _tensor_meta_to_label(self, tm) -> str: | |
if tm is None: | |
return "" | |
elif isinstance(tm, TensorMetadata): | |
return self._stringify_tensor_meta(tm) | |
elif isinstance(tm, list): | |
result = "" | |
for item in tm: | |
result += self._tensor_meta_to_label(item) | |
return result | |
elif isinstance(tm, dict): | |
result = "" | |
for v in tm.values(): | |
result += self._tensor_meta_to_label(v) | |
return result | |
elif isinstance(tm, tuple): | |
result = "" | |
for item in tm: | |
result += self._tensor_meta_to_label(item) | |
return result | |
else: | |
raise RuntimeError(f"Unsupported tensor meta type {type(tm)}") | |
def _stringify_tensor_meta(self, tm: TensorMetadata) -> str: | |
result = "" | |
if not hasattr(tm, "dtype"): | |
print("tm", tm) | |
result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n" | |
result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n" | |
result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n" | |
result += "|" + "stride" + "=" + str(tm.stride) + r"\n" | |
if tm.is_quantized: | |
assert tm.qparams is not None | |
assert "qscheme" in tm.qparams | |
qscheme = tm.qparams["qscheme"] | |
if qscheme in { | |
torch.per_tensor_affine, | |
torch.per_tensor_symmetric, | |
}: | |
result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n" | |
result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" | |
elif qscheme in { | |
torch.per_channel_affine, | |
torch.per_channel_symmetric, | |
torch.per_channel_affine_float_qparams, | |
}: | |
result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n" | |
result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" | |
result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n" | |
else: | |
raise RuntimeError(f"Unsupported qscheme: {qscheme}") | |
result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n" | |
return result | |
def _get_tensor_label(self, t: torch.Tensor) -> str: | |
return str(t.dtype) + str(list(t.shape)) + r"\n" | |
# when parse_stack_trace=True | |
# print file:lineno code | |
def _to_dot( | |
self, | |
graph_module: torch.fx.GraphModule, | |
name: str, | |
ignore_getattr: bool, | |
ignore_parameters_and_buffers: bool, | |
skip_node_names_in_args: bool, | |
parse_stack_trace: bool, | |
) -> pydot.Dot: | |
""" | |
Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph. | |
If ignore_parameters_and_buffers is True, the parameters and buffers | |
created with the module will not be added as nodes and edges. | |
""" | |
# "TB" means top-to-bottom rank direction in layout | |
dot_graph = pydot.Dot(name, rankdir="TB") | |
buf_name_to_subgraph = {} | |
for node in graph_module.graph.nodes: | |
if ignore_getattr and node.op == "get_attr": | |
continue | |
style = self._get_node_style(node) | |
dot_node = pydot.Node( | |
node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style | |
) | |
current_graph = dot_graph | |
buf_meta = node.meta.get('buf_meta', None) | |
if buf_meta is not None and buf_meta.n_origin > 1: | |
buf_name = buf_meta.name | |
if buf_name not in buf_name_to_subgraph: | |
buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name) | |
current_graph = buf_name_to_subgraph.get(buf_name) | |
current_graph.add_node(dot_node) | |
def get_module_params_or_buffers(): | |
for pname, ptensor in chain( | |
leaf_module.named_parameters(), leaf_module.named_buffers() | |
): | |
pname1 = node.name + "." + pname | |
label1 = ( | |
pname1 + "|op_code=get_" + "parameter" | |
if isinstance(ptensor, torch.nn.Parameter) | |
else "buffer" + r"\l" | |
) | |
dot_w_node = pydot.Node( | |
pname1, | |
label="{" + label1 + self._get_tensor_label(ptensor) + "}", | |
**_WEIGHT_TEMPLATE, | |
) | |
dot_graph.add_node(dot_w_node) | |
dot_graph.add_edge(pydot.Edge(pname1, node.name)) | |
if node.op == "call_module": | |
leaf_module = self._get_leaf_node(graph_module, node) | |
if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule): | |
get_module_params_or_buffers() | |
for subgraph in buf_name_to_subgraph.values(): | |
subgraph.set('color', 'royalblue') | |
subgraph.set('penwidth', '2') | |
dot_graph.add_subgraph(subgraph) | |
for node in graph_module.graph.nodes: | |
if ignore_getattr and node.op == "get_attr": | |
continue | |
for user in node.users: | |
dot_graph.add_edge(pydot.Edge(node.name, user.name)) | |
return dot_graph | |
else: | |
if not TYPE_CHECKING: | |
class FxGraphDrawer: | |
def __init__( | |
self, | |
graph_module: torch.fx.GraphModule, | |
name: str, | |
ignore_getattr: bool = False, | |
ignore_parameters_and_buffers: bool = False, | |
skip_node_names_in_args: bool = True, | |
parse_stack_trace: bool = False, | |
dot_graph_shape: Optional[str] = None, | |
): | |
raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install ' | |
'pydot through your favorite Python package manager.') | |