# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0

import unittest
from typing import Callable, List, Optional, Sequence, Tuple

import numpy as np
from onnx import (
FunctionProto,
GraphProto,
ModelProto,
NodeProto,
SparseTensorProto,
TensorProto,
ValueInfoProto,
checker,
compose,
helper,
parser,
version_converter,
)


def _load_model(m_def: str) -> ModelProto:
"""Parses a model from a string representation, including checking the model for correctness"""
m = parser.parse_model(m_def)
checker.check_model(m)
return m


def _prefixed(prefix: str, s: str) -> str:
"""Prefixes a string (if not empty)"""
return prefix + s if len(s) > 0 else s


def _get_shape(value_info: ValueInfoProto) -> List[int]:
"""Returns a list of integers representing the shape of the provided ValueInfoProto"""
return [
value_info.type.tensor_type.shape.dim[d].dim_value
for d in range(len(value_info.type.tensor_type.shape.dim))
]


def _make_sparse_tensor(name: str) -> SparseTensorProto:
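    """Returns a 3x3 SparseTensorProto with three nonzero values. The indices are
    linearized (row-major), so linear index i addresses dense position
    (i // 3, i % 3): here [2, 3, 5] map to (0, 2), (1, 0) and (1, 2)."""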
dense_shape = [3, 3]
linear_indices = [2, 3, 5]
sparse_values = [1.7, 0.4, 0.9]
values_tensor = helper.make_tensor(
name=name + "_values",
data_type=TensorProto.FLOAT,
dims=[len(sparse_values)],
vals=np.array(sparse_values).astype(np.float32),
raw=False,
)
indices_tensor = helper.make_tensor(
name=name + "_idx",
data_type=TensorProto.INT64,
dims=[len(linear_indices)],
vals=np.array(linear_indices).astype(np.int64),
raw=False,
)
return helper.make_sparse_tensor(values_tensor, indices_tensor, dense_shape)
M1_DEF = """
<
ir_version: 7,
opset_import: [ "": 10, "com.microsoft": 1]
>
agraph (float[N, M] A0, float[N, M] A1, float[N, M] _A) => (float[N, M] B00, float[N, M] B10, float[N, M] B20)
{
B00 = Add(A0, A1)
B10 = Sub(A0, A1)
B20 = Mul(A0, A1)
}
"""
M2_DEF = """
<
ir_version: 7,
opset_import: [ "": 10, "com.microsoft": 1]
>
agraph (float[N, M] B01, float[N, M] B11, float[N, M] B21) => (float[N, M] D0)
{
C0 = Add(B01, B11)
C1 = Sub(B11, B21)
M1 = Mul(C0, C1)
}
"""


class TestComposeFunctions(unittest.TestCase):
def _test_merge_models(
self,
m1def: str,
m2def: str,
io_map: List[Tuple[str, str]],
check_expectations: Callable[[GraphProto, GraphProto, GraphProto], None],
inputs: Optional[List[str]] = None,
outputs: Optional[List[str]] = None,
prefix1: Optional[str] = None,
prefix2: Optional[str] = None,
) -> None:
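        """Merges the two given models with both compose.merge_graphs and
        compose.merge_models, validates the results with the checker, and runs
        check_expectations against each merged graph."""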
m1, m2 = _load_model(m1def), _load_model(m2def)
g3 = compose.merge_graphs(
m1.graph,
m2.graph,
io_map=io_map,
inputs=inputs,
outputs=outputs,
prefix1=prefix1,
prefix2=prefix2,
)
checker.check_graph(g3)
check_expectations(m1.graph, m2.graph, g3)
m3 = compose.merge_models(
m1,
m2,
io_map=io_map,
inputs=inputs,
outputs=outputs,
prefix1=prefix1,
prefix2=prefix2,
)
checker.check_model(m3)
check_expectations(m1.graph, m2.graph, m3.graph)

    def test_case_connect_all_no_name_collision(self) -> None:
"""Tests a simple scenario where two models without overlapping names are merged by
connecting all the outputs in the first models to all the inputs in the second model
"""

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
self.assertEqual(g3.input, g1.input)
self.assertEqual(g3.output, g2.output)
self.assertEqual(
["Add", "Sub", "Mul", "Add", "Sub", "Mul"],
[item.op_type for item in g3.node],
)
io_map = [("B00", "B01"), ("B10", "B11"), ("B20", "B21")]
self._test_merge_models(M1_DEF, M2_DEF, io_map, check_expectations)

    def test_case_connect_same_output_twice(self) -> None:
"""Tests a scenario where we merge two models by connecting a single output in the first model
to all the inputs in the second
"""

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
del g2 # Unused
self.assertEqual(g3.input, g1.input)
self.assertEqual(["B10", "B20", "D0"], [elem.name for elem in g3.output])
self.assertEqual(
["Add", "Sub", "Mul", "Add", "Sub", "Mul"],
[item.op_type for item in g3.node],
)
io_map = [("B00", "B01"), ("B00", "B11"), ("B00", "B21")]
self._test_merge_models(M1_DEF, M2_DEF, io_map, check_expectations)

    def test_case_connect_same_output_drop_outputs(self) -> None:
"""Tests a scenario where we merge two models by connecting a single output in the first model
to all the inputs in the second, while dropping the rest of the outputs in the first model
"""

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
del g2 # Unused
self.assertEqual(g3.input, g1.input)
self.assertEqual(["D0"], [elem.name for elem in g3.output])
self.assertEqual(
["Add", "Add", "Sub", "Mul"], [item.op_type for item in g3.node]
)
io_map = [("B00", "B01"), ("B00", "B11"), ("B00", "B21")]
outputs = ["D0"]
self._test_merge_models(
M1_DEF, M2_DEF, io_map, check_expectations, outputs=outputs
)

    def test_case_connect_same_input_output_name(self) -> None:
"""Tests a scenario where we merge two models, where the inputs/outputs connected
are named exactly the same
"""
m1_def = """
<
ir_version: 7,
opset_import: [ "": 10]
>
agraph (float[N, M] A) => (float[N, M] B)
{
B = Add(A, A)
}
"""
m2_def = """
<
ir_version: 7,
opset_import: [ "": 10]
>
agraph (float[N, M] B) => (float[N, M] C)
{
C = Add(B, B)
}
"""
io_map = [("B", "B")]

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
del g1, g2 # Unused
self.assertEqual(["A"], [elem.name for elem in g3.input])
self.assertEqual(["C"], [elem.name for elem in g3.output])

        self._test_merge_models(m1_def, m2_def, io_map, check_expectations)

    def test_case_drop_inputs_outputs(self) -> None:
"""Tests a scenario where we merge two models, not including some of the inputs/outputs"""
m1_def = """
<
ir_version: 7,
opset_import: [ "": 10]
>
agraph (float[N] A0, float[N] B0) => (float[N] A1, float[N] B1)
{
A1 = Add(A0, A0)
B1 = Sub(B0, B0)
}
"""
m2_def = """
<
ir_version: 7,
opset_import: [ "": 10]
>
agraph (float[N] A2, float[N] B2) => (float[N] A3, float[N] B3)
{
A3 = Add(A2, A2)
B3 = Sub(B2, B2)
}
"""
io_map = [("A1", "B2")]

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
del g1, g2 # Unused
self.assertEqual(["A0"], [elem.name for elem in g3.input])
self.assertEqual(["B3"], [elem.name for elem in g3.output])
self.assertEqual(["Add", "Sub"], [elem.op_type for elem in g3.node])
inputs = ["A0"]
outputs = ["B3"]
self._test_merge_models(
m1_def, m2_def, io_map, check_expectations, inputs=inputs, outputs=outputs
)

    def test_case_name_collision_prefix(self) -> None:
        """Tests a scenario where we merge two models that have name collisions, which
        are avoided by prefixing the names in each model.
        """
m1_def = """
<
ir_version: 7,
opset_import: [ "": 10]
>
agraph (float[N] A, float[N] B) => (float[N] C)
{
C = Add(A, B)
}
"""
io_map = [("C", "A")]

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
del g1, g2 # Unused
self.assertEqual(["m1/A", "m1/B", "m2/B"], [elem.name for elem in g3.input])
self.assertEqual(["m2/C"], [elem.name for elem in g3.output])
self.assertEqual(["Add", "Add"], [elem.op_type for elem in g3.node])

        self._test_merge_models(
m1_def, m1_def, io_map, check_expectations, prefix1="m1/", prefix2="m2/"
)

    def test_case_connect_partially_no_name_collision(self) -> None:
"""Tests a scenario where two models without overlapping names are merged by
connecting some outputs from the first model to some inputs in the second.
The remaining inputs/outputs should be present in the combined model
"""

        def check_expectations(g1: GraphProto, g2: GraphProto, g4: GraphProto) -> None:
del g1, g2 # Unused
# B20 <-> B21 not connected. They should still be present
# in the inputs and outputs of the combined graph
self.assertEqual(
["A0", "A1", "_A", "B21"], [elem.name for elem in g4.input]
)
self.assertEqual(["B20", "D0"], [elem.name for elem in g4.output])
io_map = [("B00", "B01"), ("B10", "B11")]
self._test_merge_models(M1_DEF, M2_DEF, io_map, check_expectations)

    def test_merge_models_with_metadata_props(self) -> None:
m1 = _load_model(M1_DEF)
helper.set_model_props(m1, {"p1": "v1", "p2": "v2"})
m2 = _load_model(M2_DEF)
helper.set_model_props(m2, {"p3": "v3", "p4": "v4"})
io_map = [("B00", "B01")]
m3 = compose.merge_models(m1, m2, io_map=io_map)
assert len(m3.metadata_props) == 4
# Overlap, but same value
helper.set_model_props(m2, {"p1": "v1", "p4": "v4"})
m3 = compose.merge_models(m1, m2, io_map=io_map)
assert len(m3.metadata_props) == 3
        # Same key but a different value: error
helper.set_model_props(m2, {"p1": "v5", "p4": "v4"})
self.assertRaises(ValueError, compose.merge_models, m1, m2, io_map=io_map)

    def test_error_wrong_input_output_name(self) -> None:
        """Tests that providing a nonexistent output/input name in the io_map argument produces an error."""
m1, m2 = _load_model(M1_DEF), _load_model(M2_DEF)
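        # Wrong output name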
self.assertRaises(
ValueError,
compose.merge_models,
m1,
m2,
io_map=[("wrong_outname", "B01"), ("B10", "B11"), ("B20", "B21")],
)
        # Wrong input name
self.assertRaises(
ValueError,
compose.merge_models,
m1,
m2,
io_map=[("B00", "wrong_input"), ("B10", "B11"), ("B20", "B21")],
)

    def test_error_ir_version_mismatch(self) -> None:
m1 = _load_model(
"""
<
ir_version: 7,
opset_import: [ "": 13]
>
agraph (float[N, M] X0) => (float[N, M] Y0)
{
Y0 = Add(X0, X0)
}
"""
)
m2 = _load_model(
"""
<
ir_version: 6,
opset_import: [ "": 13]
>
agraph (float[N, M] X1) => (float[N, M] Y1)
{
Y1 = Add(X1, X1)
}
"""
)
        # Mismatched IR versions should produce an error
self.assertRaises(
ValueError, compose.merge_models, m1, m2, io_map=[("Y0", "X1")]
)

    def test_error_opset_import_mismatch(self) -> None:
        """Tests that providing models importing different versions of the same operator set produces an error."""
m1, m2 = _load_model(M1_DEF), _load_model(M2_DEF)
m1 = helper.make_model(
m1.graph, producer_name="test", opset_imports=[helper.make_opsetid("", 10)]
)
m2 = helper.make_model(
m2.graph, producer_name="test", opset_imports=[helper.make_opsetid("", 15)]
)
io_map = [("B00", "B01"), ("B10", "B11"), ("B20", "B21")]
self.assertRaises(ValueError, compose.merge_models, m1, m2, io_map)
        # After converting to the same operator set version, merging should work
m1 = version_converter.convert_version(m1, 15)
m3 = compose.merge_models(m1, m2, io_map=io_map)
checker.check_model(m3)

    # FIXME: This function should be removed, as tests should not contain a copy of the tested logic.
def _test_add_prefix(
self,
rename_nodes: bool = False,
rename_edges: bool = False,
rename_inputs: bool = False,
rename_outputs: bool = False,
rename_initializers: bool = False,
rename_value_infos: bool = False,
inplace: bool = False,
) -> None:
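        """Runs compose.add_prefix with the given flags (optionally in place) and
        verifies the renaming by rebuilding the expected name mapping from the
        original graph."""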
m1 = _load_model(M1_DEF)
prefix = "pre/"
if inplace:
m2 = ModelProto()
m2.CopyFrom(m1)
compose.add_prefix(
m2,
prefix,
rename_nodes=rename_nodes,
rename_edges=rename_edges,
rename_inputs=rename_inputs,
rename_outputs=rename_outputs,
rename_initializers=rename_initializers,
rename_value_infos=rename_value_infos,
inplace=True,
)
else:
m2 = compose.add_prefix(
m1,
prefix,
rename_nodes=rename_nodes,
rename_edges=rename_edges,
rename_inputs=rename_inputs,
rename_outputs=rename_outputs,
rename_initializers=rename_initializers,
rename_value_infos=rename_value_infos,
)
g_in = m1.graph
g_out = m2.graph
if (
rename_edges
or rename_inputs
or rename_outputs
or rename_initializers
or rename_value_infos
):
name_mapping = {}
# Rename inputs/outputs/edges. Propagate name changes from and to edges
if rename_edges:
for n in g_in.node:
for e in n.input:
name_mapping[e] = _prefixed(prefix, e)
for e in n.output:
name_mapping[e] = _prefixed(prefix, e)
if rename_inputs:
for elem in g_in.input:
name_mapping[elem.name] = _prefixed(prefix, elem.name)
if rename_outputs:
for elem in g_in.output:
name_mapping[elem.name] = _prefixed(prefix, elem.name)
if rename_initializers:
for init in g_in.initializer:
name_mapping[init.name] = _prefixed(prefix, init.name)
for sparse_init in g_in.sparse_initializer:
name_mapping[sparse_init.values.name] = _prefixed(
prefix, sparse_init.values.name
)
name_mapping[sparse_init.indices.name] = _prefixed(
prefix, sparse_init.indices.name
)
if rename_value_infos:
for value_info in g_in.output:
name_mapping[value_info.name] = _prefixed(prefix, value_info.name)
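            # Walk the output graph in lockstep with the input graph: names in the
            # mapping must be renamed, all others must be left unchanged (hence
            # the name_mapping.get(name, name) fallback below).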
for n1, n0 in zip(g_out.node, g_in.node):
for e1, e0 in zip(n1.input, n0.input):
self.assertEqual(name_mapping.get(e0, e0), e1)
for e1, e0 in zip(n1.output, n0.output):
self.assertEqual(name_mapping.get(e0, e0), e1)
for i1, i0 in zip(g_out.input, g_in.input):
self.assertEqual(name_mapping.get(i0.name, i0.name), i1.name)
for o1, o0 in zip(g_out.output, g_in.output):
self.assertEqual(name_mapping.get(o0.name, o0.name), o1.name)
for init1, init0 in zip(g_out.initializer, g_in.initializer):
self.assertEqual(name_mapping.get(init0.name, init0.name), init1.name)
for sparse_init1, sparse_init0 in zip(
g_out.sparse_initializer, g_in.sparse_initializer
):
self.assertEqual(
name_mapping.get(
sparse_init0.values.name, sparse_init0.values.name
),
sparse_init1.values.name,
)
self.assertEqual(
name_mapping.get(
sparse_init0.indices.name, sparse_init0.indices.name
),
sparse_init1.indices.name,
)
for vi1, vi0 in zip(g_out.value_info, g_in.value_info):
self.assertEqual(name_mapping.get(vi0.name, vi0.name), vi1.name)
if rename_nodes:
for n1, n0 in zip(g_out.node, g_in.node):
self.assertEqual(_prefixed(prefix, n0.name), n1.name)

    def test_add_prefix_nodes(self) -> None:
"""Tests renaming nodes only"""
self._test_add_prefix(rename_nodes=True)

    def test_add_prefix_edges(self) -> None:
        """Tests prefixing node edges. This will also rename inputs/outputs, since the names are shared"""
self._test_add_prefix(rename_edges=True)

    def test_add_prefix_inputs(self) -> None:
"""Tests prefixing graph inputs only. Relevant node edges should be renamed as well"""
self._test_add_prefix(rename_inputs=True)

    def test_add_prefix_outputs(self) -> None:
"""Tests prefixing graph outputs only. Relevant node edges should be renamed as well"""
self._test_add_prefix(rename_outputs=True)

    def test_add_prefix_attribute_subgraph(self) -> None:
        """Tests prefixing a node attribute's subgraph. Names within the subgraph should be renamed as well"""
C = helper.make_tensor_value_info("C", TensorProto.BOOL, [1])
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, 1])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, 1])
Z = helper.make_tensor_value_info("Z", TensorProto.FLOAT, [None, 1])
Out = helper.make_tensor_value_info("Out", TensorProto.FLOAT, [None, 1])
XY = helper.make_node("Mul", inputs=["X", "Y"], outputs=["XY"])
add = helper.make_node("Add", inputs=["XY", "Z"], outputs=["Out"])
sub = helper.make_node("Sub", inputs=["XY", "Z"], outputs=["Out"])
cond = helper.make_node(
"If",
inputs=["C"],
outputs=["Out"],
then_branch=helper.make_graph(
nodes=[add], name="then", inputs=[], outputs=[Out]
),
else_branch=helper.make_graph(
nodes=[sub], name="else", inputs=[], outputs=[Out]
),
)
graph = helper.make_graph(
nodes=[XY, cond], name="graph", inputs=[C, X, Y, Z], outputs=[Out]
)
prefix = "prefix."
prefixed_graph = compose.add_prefix_graph(graph, prefix)
checker.check_graph(prefixed_graph)
for n1, n0 in zip(prefixed_graph.node, graph.node):
self.assertEqual(_prefixed(prefix, n0.name), n1.name)
for attribute1, attribute0 in zip(n1.attribute, n0.attribute):
if attribute1.g:
for subgraph_n1, subgraph_n0 in zip(
attribute1.g.node, attribute0.g.node
):
for input_n1, input_n0 in zip(
subgraph_n1.input, subgraph_n0.input
):
self.assertEqual(_prefixed(prefix, input_n0), input_n1)
for output_n1, output_n0 in zip(
subgraph_n1.output, subgraph_n0.output
):
self.assertEqual(_prefixed(prefix, output_n0), output_n1)

    def test_add_prefix_all(self) -> None:
"""Tests prefixing all names in the graph"""
self._test_add_prefix(True, True, True, True, True, True)

    def test_add_prefix_inplace(self) -> None:
"""Tests prefixing inplace"""
self._test_add_prefix(inplace=True)

    def test_expand_out_dim(self) -> None:
"""Tests expanding output dimensions. The resulting graph should have the same output names,
but with one more dimension at the specified index.
"""
m1 = _load_model(M1_DEF)

        def _check_model(m1: ModelProto, m2: ModelProto, dim_idx: int) -> None:
for out_g2, out_g1 in zip(m2.graph.output, m1.graph.output):
self.assertEqual(out_g2.name, out_g1.name)
self.assertEqual(
out_g2.type.tensor_type.elem_type, out_g1.type.tensor_type.elem_type
)
expected_out_shape = _get_shape(out_g1)
expected_out_shape.insert(dim_idx, 1)
self.assertEqual(_get_shape(out_g2), expected_out_shape)

        for dim_idx in [0, 2, -1, -3]:
m2 = compose.expand_out_dim(m1, dim_idx)
_check_model(m1, m2, dim_idx)
# Test inplace
m2 = ModelProto()
m2.CopyFrom(m1)
dim_idx = 0
compose.expand_out_dim(m2, dim_idx, inplace=True)
_check_model(m1, m2, dim_idx)

    def _test_overlapping_names(
self,
inputs0: Sequence[str] = ("i0", "i1"),
inputs1: Sequence[str] = ("i2", "i3"),
outputs0: Sequence[str] = ("o0", "o1"),
outputs1: Sequence[str] = ("o2", "o3"),
value_info0: Sequence[str] = ("v0", "v1"),
value_info1: Sequence[str] = ("v2", "v3"),
initializer0: Sequence[str] = ("init0", "init1"),
initializer1: Sequence[str] = ("init2", "init3"),
sparse_initializer0: Sequence[str] = ("sparse_init0", "sparse_init1"),
sparse_initializer1: Sequence[str] = ("sparse_init2", "sparse_init3"),
) -> None:
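        """Builds two models out of Identity nodes with the given names and checks
        that compose.check_overlapping_names reports exactly the expected collisions,
        in the order: edge, value_info, initializer, sparse_initializer."""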
n0 = [
helper.make_node("Identity", inputs=[inputs0[i]], outputs=[outputs0[i]])
for i in range(len(inputs0))
]
i0 = [
helper.make_tensor_value_info(inputs0[i], TensorProto.FLOAT, [])
for i in range(len(inputs0))
]
o0 = [
helper.make_tensor_value_info(outputs0[i], TensorProto.FLOAT, [])
for i in range(len(outputs0))
]
vi0 = [
helper.make_tensor_value_info(value_info0[i], TensorProto.FLOAT, [])
for i in range(len(value_info0))
]
init0 = [
helper.make_tensor(
name=initializer0[i], data_type=TensorProto.INT64, dims=(), vals=[1]
)
for i in range(len(initializer0))
]
sparse_init0 = [
_make_sparse_tensor(sparse_initializer0[i])
for i in range(len(sparse_initializer0))
]
n1 = [
helper.make_node("Identity", inputs=[inputs1[i]], outputs=[outputs1[i]])
for i in range(len(inputs1))
]
i1 = [
helper.make_tensor_value_info(inputs1[i], TensorProto.FLOAT, [])
for i in range(len(inputs1))
]
o1 = [
helper.make_tensor_value_info(outputs1[i], TensorProto.FLOAT, [])
for i in range(len(outputs1))
]
vi1 = [
helper.make_tensor_value_info(value_info1[i], TensorProto.FLOAT, [])
for i in range(len(value_info1))
]
init1 = [
helper.make_tensor(
name=initializer1[i], data_type=TensorProto.INT64, dims=(), vals=[1]
)
for i in range(len(initializer1))
]
sparse_init1 = [
_make_sparse_tensor(sparse_initializer1[i])
for i in range(len(sparse_initializer1))
]
ops = [helper.make_opsetid("", 10)]
m0 = helper.make_model(
helper.make_graph(
nodes=n0,
name="g0",
inputs=i0,
outputs=o0,
value_info=vi0,
initializer=init0,
sparse_initializer=sparse_init0,
),
producer_name="test",
opset_imports=ops,
)
m1 = helper.make_model(
helper.make_graph(
nodes=n1,
name="g1",
inputs=i1,
outputs=o1,
value_info=vi1,
initializer=init1,
sparse_initializer=sparse_init1,
),
producer_name="test",
opset_imports=ops,
)
overlap = compose.check_overlapping_names(m0.graph, m1.graph)
i = 0
overlapping_inputs = list(set(inputs0) & set(inputs1))
overlapping_outputs = list(set(outputs0) & set(outputs1))
overlapping_edges = list(set(overlapping_inputs + overlapping_outputs))
if overlapping_edges:
self.assertEqual(overlap[i], ("edge", overlapping_edges))
i += 1
overlapping_vis = list(set(value_info0) & set(value_info1))
if overlapping_vis:
self.assertEqual(overlap[i], ("value_info", overlapping_vis))
i += 1
overlapping_init = list(set(initializer0) & set(initializer1))
if overlapping_init:
self.assertEqual(overlap[i], ("initializer", overlapping_init))
i += 1
overlapping_sparse_init = list(
set(sparse_initializer0) & set(sparse_initializer1)
)
if overlapping_sparse_init:
expected_overlap = []
for overlapping_name in overlapping_sparse_init:
expected_overlap.append(overlapping_name + "_values")
expected_overlap.append(overlapping_name + "_idx")
self.assertEqual(overlap[i], ("sparse_initializer", expected_overlap))
i += 1
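
        # After prefixing every name in m0, no collisions should remain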
m0_new = compose.add_prefix(m0, prefix="g0/")
overlap = compose.check_overlapping_names(m0_new.graph, m1.graph)
self.assertEqual(0, len(overlap))

    def test_overlapping_input_names(self) -> None:
        """Tests error checking when input names overlap"""
self._test_overlapping_names(inputs0=["i0", "i1"], inputs1=["i1", "i2"])

    def test_overlapping_output_names(self) -> None:
        """Tests error checking when output names overlap"""
self._test_overlapping_names(outputs0=["o0", "o1"], outputs1=["o1", "o2"])

    def test_overlapping_value_info_names(self) -> None:
        """Tests error checking when value_info names overlap"""
self._test_overlapping_names(
value_info0=["vi0", "vi1"], value_info1=["vi1", "vi2"]
)

    def test_overlapping_initializer_names(self) -> None:
        """Tests error checking when initializer names overlap"""
self._test_overlapping_names(
initializer0=["init0", "init1"], initializer1=["init1", "init2"]
)

    def test_overlapping_sparse_initializer_names(self) -> None:
        """Tests error checking when sparse_initializer names overlap"""
self._test_overlapping_names(
sparse_initializer0=["sparse_init0", "sparse_init1"],
sparse_initializer1=["sparse_init1", "sparse_init2"],
)

    def test_overlapping_function_names(self) -> None:
        """Tests error checking when local function names overlap"""
ops = [helper.make_opsetid("", 10), helper.make_opsetid("local", 10)]

        def _make_function(
domain: str,
fname: str,
inputs: List[str],
outputs: List[str],
nodes: List[NodeProto],
) -> FunctionProto:
f = FunctionProto()
f.domain = domain
f.name = fname
f.input.extend(inputs)
f.output.extend(outputs)
f.node.extend(nodes)
f.opset_import.extend(ops)
return f

        g = GraphProto()
g.input.extend(
[
helper.make_tensor_value_info("x0", TensorProto.FLOAT, []),
helper.make_tensor_value_info("x1", TensorProto.FLOAT, []),
]
)
g.output.extend(
[
helper.make_tensor_value_info("y", TensorProto.FLOAT, []),
]
)
g.node.extend(
[helper.make_node("f1", domain="local", inputs=["x0", "x1"], outputs=["y"])]
)
g1 = GraphProto()
g1.CopyFrom(g)
g1.name = "g1"
m1 = helper.make_model(g1, producer_name="test", opset_imports=ops)
m1.functions.extend(
[
_make_function(
"local",
"f1",
["x0", "x1"],
["y"],
[helper.make_node("Add", inputs=["x0", "x1"], outputs=["y"])],
)
]
)
checker.check_model(m1)
g2 = GraphProto()
g2.CopyFrom(g)
g2.name = "g2"
m2 = helper.make_model(g2, producer_name="test", opset_imports=ops)
m2.functions.extend(
[
_make_function(
"local",
"f1",
["x0", "x1"],
["y"],
[helper.make_node("Mul", inputs=["x0", "x1"], outputs=["y"])],
)
]
)
checker.check_model(m2)
m = compose.merge_models(
m1, m2, io_map=[("y", "x0"), ("y", "x1")], prefix1="m1/", prefix2="m2/"
)
checker.check_model(m)
nodes = [n.op_type for n in m.graph.node]
self.assertEqual(["m1/f1", "m2/f1"], nodes)
functions = [f.name for f in m.functions]
self.assertEqual(["m1/f1", "m2/f1"], functions)
g3 = GraphProto()
g3.CopyFrom(g)
g3.name = "g3"
g3.node[0].op_type = "f2"
m3 = helper.make_model(g3, producer_name="test", opset_imports=ops)
m3.functions.extend(
[
_make_function(
"local",
"f1",
["x0", "x1"],
["y"],
[
helper.make_node("Add", inputs=["x0", "x1"], outputs=["y0"]),
helper.make_node("Mul", inputs=["x0", "x1"], outputs=["y1"]),
helper.make_node("Add", inputs=["y0", "y1"], outputs=["y"]),
],
),
_make_function(
"local",
"f2",
["x0", "x1"],
["y"],
[
helper.make_node(
"f1", domain="local", inputs=["x0", "x1"], outputs=["y0"]
),
helper.make_node("Mul", inputs=["x0", "x1"], outputs=["y1"]),
helper.make_node("Add", inputs=["y0", "y1"], outputs=["y"]),
],
),
]
)
checker.check_model(m3)
m = compose.merge_models(
m1, m3, io_map=[("y", "x0"), ("y", "x1")], prefix1="m1/", prefix2="m3/"
)
checker.check_model(m)
nodes = [n.op_type for n in m.graph.node]
self.assertEqual(["m1/f1", "m3/f2"], nodes)
functions = [f.name for f in m.functions]
self.assertEqual(["m1/f1", "m3/f1", "m3/f2"], functions)
self.assertEqual(["Add"], [n.op_type for n in m.functions[0].node])
self.assertEqual(
["Add", "Mul", "Add"], [n.op_type for n in m.functions[1].node]
)
self.assertEqual(
["m3/f1", "Mul", "Add"], [n.op_type for n in m.functions[2].node]
)

    def test_merge_drop_unnecessary_initializers_and_value_info(self) -> None:
        """Tests automatic removal of unused initializers, sparse initializers and value_info entries when merging graphs"""
ops = [helper.make_opsetid("", 10)]
g = GraphProto()
g.input.extend([helper.make_tensor_value_info("x", TensorProto.FLOAT, [])])
g.output.extend([helper.make_tensor_value_info("y", TensorProto.FLOAT, [])])
g.node.extend([helper.make_node("Identity", inputs=["x"], outputs=["y"])])
g1 = GraphProto()
g1.CopyFrom(g)
g1.name = "g1"
m1 = helper.make_model(g1, producer_name="test", opset_imports=ops)
checker.check_model(m1)
g2 = GraphProto()
g2.CopyFrom(g)
g2.name = "g2"
g2.initializer.extend(
[
helper.make_tensor(
name="x", data_type=TensorProto.FLOAT, dims=(), vals=[0]
)
]
)
m2 = helper.make_model(g2, producer_name="test", opset_imports=ops)
checker.check_model(m2)
g3 = GraphProto()
g3.CopyFrom(g)
g3.name = "g3"
g3.sparse_initializer.extend([_make_sparse_tensor("x")])
m3 = helper.make_model(g3, producer_name="test", opset_imports=ops)
checker.check_model(m3)
g4 = GraphProto()
g4.CopyFrom(g)
g4.name = "g3"
g4.value_info.extend(
[helper.make_tensor_value_info("x", TensorProto.FLOAT, [])]
)
m4 = helper.make_model(g4, producer_name="test", opset_imports=ops)
checker.check_model(m4)
        # Initializer 'x' from m2 is removed, because there is no longer an input with that name
out_m1 = compose.merge_models(m1, m2, prefix1="m1/", io_map=[("y", "x")])
self.assertEqual(0, len(out_m1.graph.initializer))
        # Sparse initializer 'x' from m3 is removed, because there is no longer an input with that name
        out_m2 = compose.merge_models(m1, m3, prefix1="m1/", io_map=[("y", "x")])
        self.assertEqual(0, len(out_m2.graph.sparse_initializer))
        # Value info 'x' from m4 is removed, because there is no longer an input with that name
out_m3 = compose.merge_models(m1, m4, prefix1="m1/", io_map=[("y", "x")])
self.assertEqual(0, len(out_m3.graph.value_info))


if __name__ == "__main__":
unittest.main()