Spaces:
Sleeping
Sleeping
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
# A library and utility for drawing ONNX nets. Most of this implementation has
# been borrowed from the caffe2 implementation:
# https://github.com/pytorch/pytorch/blob/master/caffe2/python/net_drawer.py
#
# The script takes two required arguments:
#   --input: a path to a serialized ModelProto .pb file.
#   --output: a path to write a dot file representation of the graph.
# Optional flags: --rankdir (layout direction of the graph, default "LR") and
# --embed_docstring (embed each op's doc_string as a JavaScript alert URL).
#
# Given this dot file representation, you can, for example, export it to SVG
# with the graphviz `dot` utility, like so:
#
# $ dot -Tsvg my_output.dot -o my_output.svg
import argparse
import json
from collections import defaultdict
from typing import Any, Callable, Dict, Optional, Sequence

import pydot

from onnx import GraphProto, ModelProto, NodeProto
# pydot attributes applied to every operator node: a filled green box with
# white text. Key order is significant because it is splatted into
# pydot.Node(**OP_STYLE) and therefore fixes the attribute order in the output.
OP_STYLE = dict(
    shape="box",
    color="#0F9D58",
    style="filled",
    fontcolor="#FFFFFF",
)

# pydot attributes applied to every tensor ("blob") node.
BLOB_STYLE = dict(shape="octagon")
# Signature of an operator-node factory: it receives an ONNX NodeProto and
# the node's index within the graph and returns the pydot.Node to draw.
_NodeProducer = Callable[
    [NodeProto, int],
    pydot.Node,
]
def _escape_label(name: str) -> str: | |
# json.dumps is poor man's escaping | |
return json.dumps(name) | |
def _form_and_sanitize_docstring(s: str) -> str: | |
url = "javascript:alert(" | |
url += _escape_label(s).replace('"', "'").replace("<", "").replace(">", "") | |
url += ")" | |
return url | |
def GetOpNodeProducer(  # noqa: N802
    embed_docstring: bool = False, **kwargs: Any
) -> _NodeProducer:
    """Build a factory that turns ONNX operator nodes into pydot nodes.

    Args:
        embed_docstring: If true, every produced node gets a URL attribute
            built from the operator's ``doc_string`` by
            ``_form_and_sanitize_docstring``.
        **kwargs: pydot node attributes (for example ``OP_STYLE``) forwarded
            unchanged to every ``pydot.Node`` that is created.

    Returns:
        A callable ``(op, op_id) -> pydot.Node``. The node's name contains
        the op name (when set), the op type, ``op_id``, and one line per
        input and output tensor name.
    """

    def _produce(op: NodeProto, op_id: int) -> pydot.Node:
        header = f"{op.name}/{op.op_type}" if op.name else op.op_type
        pieces = [f"{header} (op#{op_id})"]
        pieces.extend(f"\n input{idx} {tensor}" for idx, tensor in enumerate(op.input))
        pieces.extend(f"\n output{idx} {tensor}" for idx, tensor in enumerate(op.output))
        node = pydot.Node("".join(pieces), **kwargs)
        if embed_docstring:
            node.set_URL(_form_and_sanitize_docstring(op.doc_string))
        return node

    return _produce
def _new_blob_node(blob_name: str, version: int) -> pydot.Node:
    """Create the pydot node for one version of a tensor ("blob").

    The node ID is the escaped ``blob_name`` with ``version`` appended, so a
    tensor name that is produced again gets a distinct node. The visible
    label is the escaped tensor name alone.
    """
    return pydot.Node(
        _escape_label(blob_name + str(version)),
        label=_escape_label(blob_name),
        **BLOB_STYLE,
    )


def GetPydotGraph(  # noqa: N802
    graph: GraphProto,
    name: Optional[str] = None,
    rankdir: str = "LR",
    node_producer: Optional[_NodeProducer] = None,
    embed_docstring: bool = False,
) -> pydot.Dot:
    """Convert an ONNX ``GraphProto`` into a ``pydot.Dot`` graph.

    Each entry of ``graph.node`` becomes an operator node built by
    ``node_producer``. Each tensor name becomes a ``BLOB_STYLE`` node, with
    an edge from its producing op and edges to each consuming op. If a name
    is produced again after it already has a node, a fresh node is created
    for the new version. Empty names, which ONNX uses to mark omitted
    optional inputs/outputs, are not drawn as tensor nodes.

    Args:
        graph: The ONNX graph to draw.
        name: Graph name passed to ``pydot.Dot``.
        rankdir: Graphviz ``rankdir`` layout direction (e.g. ``"LR"``).
        node_producer: Factory for operator nodes. It defaults to
            ``GetOpNodeProducer(embed_docstring=..., **OP_STYLE)``.
        embed_docstring: Forwarded to the default node producer. It is ignored
            when ``node_producer`` is supplied.

    Returns:
        The populated ``pydot.Dot`` graph.
    """
    if node_producer is None:
        node_producer = GetOpNodeProducer(embed_docstring=embed_docstring, **OP_STYLE)
    pydot_graph = pydot.Dot(name, rankdir=rankdir)
    # Latest pydot node created for each tensor name.
    pydot_nodes: Dict[str, pydot.Node] = {}
    # Number of times each tensor name was produced after it already had a node.
    pydot_node_counts: Dict[str, int] = defaultdict(int)
    for op_id, op in enumerate(graph.node):
        op_node = node_producer(op, op_id)
        pydot_graph.add_node(op_node)
        for input_name in op.input:
            if not input_name:
                # Omitted optional input: there is no tensor to link.
                continue
            input_node = pydot_nodes.get(input_name)
            if input_node is None:
                # The name was not produced by any earlier op in graph.node.
                input_node = _new_blob_node(input_name, pydot_node_counts[input_name])
                pydot_nodes[input_name] = input_node
                # Register the node once. Re-adding it on later uses would
                # emit duplicate node statements in the DOT output.
                pydot_graph.add_node(input_node)
            pydot_graph.add_edge(pydot.Edge(input_node, op_node))
        for output_name in op.output:
            if not output_name:
                # Omitted optional output.
                continue
            if output_name in pydot_nodes:
                pydot_node_counts[output_name] += 1
            output_node = _new_blob_node(output_name, pydot_node_counts[output_name])
            pydot_nodes[output_name] = output_node
            pydot_graph.add_node(output_node)
            pydot_graph.add_edge(pydot.Edge(op_node, output_node))
    return pydot_graph
def main(argv: Optional[Sequence[str]] = None) -> None:
    """Command-line entry point: read an ONNX model and write its graph as DOT.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behavior.
    """
    parser = argparse.ArgumentParser(description="ONNX net drawer")
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="The input protobuf file.",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        # The output is written with write_dot, so it is a DOT file, not protobuf.
        help="The output DOT file.",
    )
    parser.add_argument(
        "--rankdir",
        type=str,
        default="LR",
        help="The rank direction of the pydot graph.",
    )
    parser.add_argument(
        "--embed_docstring",
        action="store_true",
        help="Embed docstring as javascript alert. Useful for SVG format.",
    )
    args = parser.parse_args(argv)
    model = ModelProto()
    with open(args.input, "rb") as fid:
        content = fid.read()
    model.ParseFromString(content)
    # GetPydotGraph builds the default OP_STYLE node producer from
    # embed_docstring, so there is no need to construct one here.
    pydot_graph = GetPydotGraph(
        model.graph,
        name=model.graph.name,
        rankdir=args.rankdir,
        embed_docstring=args.embed_docstring,
    )
    pydot_graph.write_dot(args.output)
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a library.
if __name__ == "__main__":
    main()