Spaces:
Sleeping
Sleeping
File size: 5,055 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
# A library and utility for drawing ONNX nets. Most of this implementation has
# been borrowed from the caffe2 implementation
# https://github.com/pytorch/pytorch/blob/master/caffe2/python/net_drawer.py
#
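# The script takes two required arguments:
#   --input: a path to a serialized ModelProto .pb file.
#   --output: a path to write a dot file representation of the graph.
# Optional arguments: --rankdir (graph layout direction, default "LR") and
# --embed_docstring (attach node doc strings as JavaScript alerts; useful for SVG).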
#
# Given this dot file representation you can, for example, export it to SVG
# with the graphviz `dot` utility, like so:
#
# $ dot -Tsvg my_output.dot -o my_output.svg
import argparse
import json
from collections import defaultdict
from typing import Any, Callable, Dict, Optional
import pydot
from onnx import GraphProto, ModelProto, NodeProto
# Default pydot attributes for operator nodes: filled boxes in #0F9D58 with
# white (#FFFFFF) label text.
OP_STYLE = dict(
    shape="box",
    color="#0F9D58",
    style="filled",
    fontcolor="#FFFFFF",
)
# Default pydot attributes for tensor ("blob") nodes.
BLOB_STYLE = dict(shape="octagon")
# Signature of a callable that turns an ONNX node plus its index within the
# graph into the pydot.Node used to draw it (GetOpNodeProducer returns one).
_NodeProducer = Callable[[NodeProto, int], pydot.Node]
def _escape_label(name: str) -> str:
# json.dumps is poor man's escaping
return json.dumps(name)
def _form_and_sanitize_docstring(s: str) -> str:
    """Wrap *s* in a ``javascript:alert(...)`` URL.

    The text is escaped with ``_escape_label`` (JSON string escaping). The
    enclosing double quotes are then turned into single quotes, and angle
    brackets are stripped.

    Args:
        s: Free-form text; GetOpNodeProducer passes a node's ``doc_string``.

    Returns:
        A string of the form ``javascript:alert('...')``.
    """
    # JSON escaping leaves apostrophes untouched. Escape them here, otherwise
    # text such as "it's" would end the single-quoted JavaScript string early
    # and yield a syntactically broken URL.
    escaped = _escape_label(s).replace("'", "\\'")
    # JSON wraps the text in double quotes and escapes inner ones as \";
    # swapping " for ' turns that into a valid single-quoted JS literal.
    escaped = escaped.replace('"', "'").replace("<", "").replace(">", "")
    return f"javascript:alert({escaped})"
def GetOpNodeProducer(  # noqa: N802
    embed_docstring: bool = False, **kwargs: Any
) -> _NodeProducer:
    """Return a callable that builds the pydot node for one ONNX node.

    The produced node's name is "<name>/<op_type> (op#<index>)". The "<name>/"
    prefix appears only when the ONNX node has a name. It is followed by one
    line per input ("input<i> <tensor>") and per output ("output<i> <tensor>").

    Args:
        embed_docstring: If true, the node's ``doc_string`` is attached as a
            ``javascript:alert`` URL via ``set_URL``.
        **kwargs: Extra pydot attributes (e.g. OP_STYLE) applied to every node.
    """

    def build_op_node(op: NodeProto, op_id: int) -> pydot.Node:
        prefix = f"{op.name}/" if op.name else ""
        parts = [f"{prefix}{op.op_type} (op#{op_id})"]
        parts.extend(
            f"\n input{idx} {tensor}" for idx, tensor in enumerate(op.input)
        )
        parts.extend(
            f"\n output{idx} {tensor}" for idx, tensor in enumerate(op.output)
        )
        drawn = pydot.Node("".join(parts), **kwargs)
        if embed_docstring:
            drawn.set_URL(_form_and_sanitize_docstring(op.doc_string))
        return drawn

    return build_op_node
def GetPydotGraph(  # noqa: N802
    graph: GraphProto,
    name: Optional[str] = None,
    rankdir: str = "LR",
    node_producer: Optional[_NodeProducer] = None,
    embed_docstring: bool = False,
) -> pydot.Dot:
    """Build a pydot graph of *graph*'s nodes and the tensors flowing between them.

    Each ONNX node becomes an op node created by *node_producer*. Each tensor
    name becomes a "blob" node styled with BLOB_STYLE. Edges run from input
    blobs into an op and from the op to its output blobs. When a tensor name is
    produced again after it has already been drawn, a new blob node with a
    bumped version suffix in its id is created, and later consumers link to it.
    Empty tensor names (omitted optional inputs/outputs) are not drawn.

    Args:
        graph: The ONNX graph to draw.
        name: Graph name passed to ``pydot.Dot``.
        rankdir: Graphviz ``rankdir`` layout direction (e.g. "LR", "TB").
        node_producer: Callable building the pydot node for each ONNX node.
            Defaults to ``GetOpNodeProducer(embed_docstring, **OP_STYLE)``.
        embed_docstring: Only used when *node_producer* is None; forwarded to
            the default producer.

    Returns:
        The assembled ``pydot.Dot`` graph.
    """
    if node_producer is None:
        node_producer = GetOpNodeProducer(embed_docstring=embed_docstring, **OP_STYLE)
    pydot_graph = pydot.Dot(name, rankdir=rankdir)
    # Most recently drawn blob node for each tensor name.
    pydot_nodes: Dict[str, pydot.Node] = {}
    # Number of times each tensor name has been re-produced; gives every
    # version of a reassigned name its own node id.
    pydot_node_counts: Dict[str, int] = defaultdict(int)

    def make_blob_node(tensor_name: str) -> pydot.Node:
        # Create, register and add to the graph the blob node for the current
        # version of tensor_name. The id carries the version; the label does not.
        blob = pydot.Node(
            _escape_label(tensor_name + str(pydot_node_counts[tensor_name])),
            label=_escape_label(tensor_name),
            **BLOB_STYLE,
        )
        pydot_nodes[tensor_name] = blob
        # Added exactly once: re-adding an existing node would emit a duplicate
        # declaration for it in the DOT output.
        pydot_graph.add_node(blob)
        return blob

    for op_id, op in enumerate(graph.node):
        op_node = node_producer(op, op_id)
        pydot_graph.add_node(op_node)
        for input_name in op.input:
            # An empty name marks an omitted optional input; drawing it would
            # join unrelated ops through one shared, unlabeled blob.
            if not input_name:
                continue
            input_node = pydot_nodes.get(input_name)
            if input_node is None:
                input_node = make_blob_node(input_name)
            pydot_graph.add_edge(pydot.Edge(input_node, op_node))
        for output_name in op.output:
            if not output_name:  # omitted optional output
                continue
            if output_name in pydot_nodes:
                pydot_node_counts[output_name] += 1
            output_node = make_blob_node(output_name)
            pydot_graph.add_edge(pydot.Edge(op_node, output_node))
    return pydot_graph
def main() -> None:
    """Command-line entry point: read an ONNX model and write its graph as DOT.

    Options:
        --input: path of a serialized ModelProto to read (required).
        --output: path of the dot file to write (required).
        --rankdir: graph layout direction, default "LR".
        --embed_docstring: attach each node's doc_string as a JavaScript alert.
    """
    parser = argparse.ArgumentParser(description="ONNX net drawer")
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="The input protobuf file.",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        # The tool writes a Graphviz dot file (see write_dot below), not protobuf.
        help="The output dot file.",
    )
    parser.add_argument(
        "--rankdir",
        type=str,
        default="LR",
        help="The rank direction of the pydot graph.",
    )
    parser.add_argument(
        "--embed_docstring",
        action="store_true",
        help="Embed docstring as javascript alert. Useful for SVG format.",
    )
    args = parser.parse_args()
    model = ModelProto()
    with open(args.input, "rb") as fid:
        model.ParseFromString(fid.read())
    # GetPydotGraph builds the default OP_STYLE node producer itself when given
    # embed_docstring, so there is no need to construct one here.
    pydot_graph = GetPydotGraph(
        model.graph,
        name=model.graph.name,
        rankdir=args.rankdir,
        embed_docstring=args.embed_docstring,
    )
    pydot_graph.write_dot(args.output)
# Run the command-line tool when this module is executed as a script.
if __name__ == "__main__":
    main()
|