# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
import unittest
from typing import Callable, List, Optional, Sequence, Tuple

import numpy as np

from onnx import (
    FunctionProto,
    GraphProto,
    ModelProto,
    NodeProto,
    SparseTensorProto,
    TensorProto,
    ValueInfoProto,
    checker,
    compose,
    helper,
    parser,
    version_converter,
)


def _load_model(m_def: str) -> ModelProto:
    """Parse a model from its textual representation and run the checker on it.

    Args:
        m_def: Model definition in the ONNX textual syntax.

    Returns:
        The parsed model, after ``checker.check_model`` accepted it.
    """
    m = parser.parse_model(m_def)
    checker.check_model(m)
    return m


def _prefixed(prefix: str, s: str) -> str:
    """Return ``prefix + s``, or ``s`` unchanged when ``s`` is empty."""
    return prefix + s if s else s


def _get_shape(value_info: ValueInfoProto) -> List[int]:
    """Return the ``dim_value`` of every dimension of a tensor ValueInfoProto."""
    return [dim.dim_value for dim in value_info.type.tensor_type.shape.dim]


def _make_sparse_tensor(name: str) -> SparseTensorProto:
    """Build a small 3x3 sparse tensor used as a sparse initializer in the tests.

    The values tensor is named ``name + "_values"`` and the indices tensor
    ``name + "_idx"``.
    """
    dense_shape = [3, 3]
    linear_indices = [2, 3, 5]
    sparse_values = [1.7, 0.4, 0.9]
    values_tensor = helper.make_tensor(
        name=name + "_values",
        data_type=TensorProto.FLOAT,
        dims=[len(sparse_values)],
        vals=np.array(sparse_values).astype(np.float32),
        raw=False,
    )
    indices_tensor = helper.make_tensor(
        name=name + "_idx",
        data_type=TensorProto.INT64,
        dims=[len(linear_indices)],
        vals=np.array(linear_indices).astype(np.int64),
        raw=False,
    )
    return helper.make_sparse_tensor(values_tensor, indices_tensor, dense_shape)


# First model: three inputs (one of them, "_A", unused) and three outputs.
M1_DEF = """
    <
        ir_version: 7,
        opset_import: [ "": 10, "com.microsoft": 1]
    >
    agraph (float[N, M] A0, float[N, M] A1, float[N, M] _A) => (float[N, M] B00, float[N, M] B10, float[N, M] B20)
    {
        B00 = Add(A0, A1)
        B10 = Sub(A0, A1)
        B20 = Mul(A0, A1)
    }
    """

# Second model: three inputs and a single output "D0", produced by the last node.
M2_DEF = """
    <
        ir_version: 7,
        opset_import: [ "": 10, "com.microsoft": 1]
    >
    agraph (float[N, M] B01, float[N, M] B11, float[N, M] B21) => (float[N, M] D0)
    {
        C0 = Add(B01, B11)
        C1 = Sub(B11, B21)
        D0 = Mul(C0, C1)
    }
    """


class TestComposeFunctions(unittest.TestCase):
    """Tests for the graph/model composition utilities in ``onnx.compose``."""

    def _test_merge_models(
        self,
        m1def: str,
        m2def: str,
        io_map: List[Tuple[str, str]],
        check_expectations: Callable[[GraphProto, GraphProto, GraphProto], None],
        inputs: Optional[List[str]] = None,
        outputs: Optional[List[str]] = None,
        prefix1: Optional[str] = None,
        prefix2: Optional[str] = None,
    ) -> None:
        """Merge two textual models both as graphs and as models and verify the result.

        Args:
            m1def: Textual definition of the first model.
            m2def: Textual definition of the second model.
            io_map: Pairs of (output of the first model, input of the second model).
            check_expectations: Callback receiving (graph1, graph2, merged graph).
            inputs: Forwarded to the merge functions as ``inputs``.
            outputs: Forwarded to the merge functions as ``outputs``.
            prefix1: Forwarded to the merge functions as ``prefix1``.
            prefix2: Forwarded to the merge functions as ``prefix2``.
        """
        m1, m2 = _load_model(m1def), _load_model(m2def)

        # Graph-level merge
        g3 = compose.merge_graphs(
            m1.graph,
            m2.graph,
            io_map=io_map,
            inputs=inputs,
            outputs=outputs,
            prefix1=prefix1,
            prefix2=prefix2,
        )
        checker.check_graph(g3)
        check_expectations(m1.graph, m2.graph, g3)

        # Model-level merge, checked against the same expectations
        m3 = compose.merge_models(
            m1,
            m2,
            io_map=io_map,
            inputs=inputs,
            outputs=outputs,
            prefix1=prefix1,
            prefix2=prefix2,
        )
        checker.check_model(m3)
        check_expectations(m1.graph, m2.graph, m3.graph)

    def test_case_connect_all_no_name_collision(self) -> None:
        """Tests a simple scenario where two models without overlapping names are merged by
        connecting all the outputs in the first models to all the inputs in the second model
        """

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
            self.assertEqual(g3.input, g1.input)
            self.assertEqual(g3.output, g2.output)
            self.assertEqual(
                ["Add", "Sub", "Mul", "Add", "Sub", "Mul"],
                [item.op_type for item in g3.node],
            )

        io_map = [("B00", "B01"), ("B10", "B11"), ("B20", "B21")]
        self._test_merge_models(M1_DEF, M2_DEF, io_map, check_expectations)

    def test_case_connect_same_output_twice(self) -> None:
        """Tests a scenario where we merge two models by connecting a single output in the first model
        to all the inputs in the second
        """

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
            del g2  # Unused
            self.assertEqual(g3.input, g1.input)
            self.assertEqual(["B10", "B20", "D0"], [elem.name for elem in g3.output])
            self.assertEqual(
                ["Add", "Sub", "Mul", "Add", "Sub", "Mul"],
                [item.op_type for item in g3.node],
            )

        io_map = [("B00", "B01"), ("B00", "B11"), ("B00", "B21")]
        self._test_merge_models(M1_DEF, M2_DEF, io_map, check_expectations)

    def test_case_connect_same_output_drop_outputs(self) -> None:
        """Tests a scenario where we merge two models by connecting a single output in the first model
        to all the inputs in the second, while dropping the rest of the outputs in the first model
        """

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
            del g2  # Unused
            self.assertEqual(g3.input, g1.input)
            self.assertEqual(["D0"], [elem.name for elem in g3.output])
            self.assertEqual(
                ["Add", "Add", "Sub", "Mul"], [item.op_type for item in g3.node]
            )

        io_map = [("B00", "B01"), ("B00", "B11"), ("B00", "B21")]
        outputs = ["D0"]
        self._test_merge_models(
            M1_DEF, M2_DEF, io_map, check_expectations, outputs=outputs
        )

    def test_case_connect_same_input_output_name(self) -> None:
        """Tests a scenario where we merge two models, where the inputs/outputs connected
        are named exactly the same
        """
        m1_def = """
            <
                ir_version: 7,
                opset_import: [ "": 10]
            >
            agraph (float[N, M] A) => (float[N, M] B)
            {
                B = Add(A, A)
            }
            """
        m2_def = """
            <
                ir_version: 7,
                opset_import: [ "": 10]
            >
            agraph (float[N, M] B) => (float[N, M] C)
            {
                C = Add(B, B)
            }
            """
        io_map = [("B", "B")]

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
            del g1, g2  # Unused
            self.assertEqual(["A"], [elem.name for elem in g3.input])
            self.assertEqual(["C"], [elem.name for elem in g3.output])

        self._test_merge_models(m1_def, m2_def, io_map, check_expectations)

    def test_case_drop_inputs_outputs(self) -> None:
        """Tests a scenario where we merge two models, not including some of the inputs/outputs"""
        m1_def = """
            <
                ir_version: 7,
                opset_import: [ "": 10]
            >
            agraph (float[N] A0, float[N] B0) => (float[N] A1, float[N] B1)
            {
                A1 = Add(A0, A0)
                B1 = Sub(B0, B0)
            }
            """
        m2_def = """
            <
                ir_version: 7,
                opset_import: [ "": 10]
            >
            agraph (float[N] A2, float[N] B2) => (float[N] A3, float[N] B3)
            {
                A3 = Add(A2, A2)
                B3 = Sub(B2, B2)
            }
            """
        io_map = [("A1", "B2")]

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
            del g1, g2  # Unused
            self.assertEqual(["A0"], [elem.name for elem in g3.input])
            self.assertEqual(["B3"], [elem.name for elem in g3.output])
            self.assertEqual(["Add", "Sub"], [elem.op_type for elem in g3.node])

        inputs = ["A0"]
        outputs = ["B3"]
        self._test_merge_models(
            m1_def, m2_def, io_map, check_expectations, inputs=inputs, outputs=outputs
        )

    def test_case_name_collision_prefix(self) -> None:
        """Tests a scenario where we merge two models that have name collisions, but they
        are avoided by prefixing the names in each model.
        """
        m1_def = """
            <
                ir_version: 7,
                opset_import: [ "": 10]
            >
            agraph (float[N] A, float[N] B) => (float[N] C)
            {
                C = Add(A, B)
            }
            """
        io_map = [("C", "A")]

        def check_expectations(g1: GraphProto, g2: GraphProto, g3: GraphProto) -> None:
            del g1, g2  # Unused
            self.assertEqual(["m1/A", "m1/B", "m2/B"], [elem.name for elem in g3.input])
            self.assertEqual(["m2/C"], [elem.name for elem in g3.output])
            self.assertEqual(["Add", "Add"], [elem.op_type for elem in g3.node])

        # The same definition is used for both models; only the prefixes differ.
        self._test_merge_models(
            m1_def, m1_def, io_map, check_expectations, prefix1="m1/", prefix2="m2/"
        )

    def test_case_connect_partially_no_name_collision(self) -> None:
        """Tests a scenario where two models without overlapping names are merged by
        connecting some outputs from the first model to some inputs in the second.
        The remaining inputs/outputs should be present in the combined model
        """

        def check_expectations(g1: GraphProto, g2: GraphProto, g4: GraphProto) -> None:
            del g1, g2  # Unused
            # B20 <-> B21 not connected. They should still be present
            # in the inputs and outputs of the combined graph
            self.assertEqual(
                ["A0", "A1", "_A", "B21"], [elem.name for elem in g4.input]
            )
            self.assertEqual(["B20", "D0"], [elem.name for elem in g4.output])

        io_map = [("B00", "B01"), ("B10", "B11")]
        self._test_merge_models(M1_DEF, M2_DEF, io_map, check_expectations)

    def test_merge_models_with_metadata_props(self) -> None:
        """Tests how metadata_props of both models are combined when merging."""
        m1 = _load_model(M1_DEF)
        helper.set_model_props(m1, {"p1": "v1", "p2": "v2"})

        m2 = _load_model(M2_DEF)
        helper.set_model_props(m2, {"p3": "v3", "p4": "v4"})

        io_map = [("B00", "B01")]
        m3 = compose.merge_models(m1, m2, io_map=io_map)
        self.assertEqual(4, len(m3.metadata_props))

        # Overlap, but same value
        helper.set_model_props(m2, {"p1": "v1", "p4": "v4"})
        m3 = compose.merge_models(m1, m2, io_map=io_map)
        self.assertEqual(3, len(m3.metadata_props))

        # Same keys but not same value. Error
        helper.set_model_props(m2, {"p1": "v5", "p4": "v4"})
        self.assertRaises(ValueError, compose.merge_models, m1, m2, io_map=io_map)

    def test_error_wrong_input_output_name(self) -> None:
        """Tests that providing a non existing output/input name in the io_map argument produces an error."""
        m1, m2 = _load_model(M1_DEF), _load_model(M2_DEF)

        # Wrong output name
        self.assertRaises(
            ValueError,
            compose.merge_models,
            m1,
            m2,
            io_map=[("wrong_outname", "B01"), ("B10", "B11"), ("B20", "B21")],
        )

        # Wrong input name
        self.assertRaises(
            ValueError,
            compose.merge_models,
            m1,
            m2,
            io_map=[("B00", "wrong_input"), ("B10", "B11"), ("B20", "B21")],
        )

    def test_error_ir_version_mismatch(self) -> None:
        """Tests that merging models with different IR versions produces an error."""
        m1 = _load_model(
            """
    <
        ir_version: 7,
        opset_import: [ "": 13]
    >
    agraph (float[N, M] X0) => (float[N, M] Y0)
    {
        Y0 = Add(X0, X0)
    }
    """
        )

        m2 = _load_model(
            """
    <
        ir_version: 6,
        opset_import: [ "": 13]
    >
    agraph (float[N, M] X1) => (float[N, M] Y1)
    {
        Y1 = Add(X1, X1)
    }
    """
        )
        # Wrong IR version name
        self.assertRaises(
            ValueError, compose.merge_models, m1, m2, io_map=[("Y0", "X1")]
        )

    def test_error_opset_import_mismatch(self) -> None:
        """Tests that providing models with different operator set imported produces an error."""
        m1, m2 = _load_model(M1_DEF), _load_model(M2_DEF)
        m1 = helper.make_model(
            m1.graph, producer_name="test", opset_imports=[helper.make_opsetid("", 10)]
        )
        m2 = helper.make_model(
            m2.graph, producer_name="test", opset_imports=[helper.make_opsetid("", 15)]
        )

        io_map = [("B00", "B01"), ("B10", "B11"), ("B20", "B21")]
        self.assertRaises(ValueError, compose.merge_models, m1, m2, io_map)

        # Converting to the same Operator set version, should work
        m1 = version_converter.convert_version(m1, 15)
        m3 = compose.merge_models(m1, m2, io_map=io_map)
        checker.check_model(m3)

    # FIXME: This function should be removed, as tests should not contain a copy of the tested logic.
    def _test_add_prefix(
        self,
        rename_nodes: bool = False,
        rename_edges: bool = False,
        rename_inputs: bool = False,
        rename_outputs: bool = False,
        rename_initializers: bool = False,
        rename_value_infos: bool = False,
        inplace: bool = False,
    ) -> None:
        """Apply ``compose.add_prefix`` to M1 and compare names against a locally built mapping.

        Each ``rename_*`` flag is forwarded to ``compose.add_prefix`` and also
        decides which names are entered in the expected-name mapping. With
        ``inplace=True`` the prefix is applied to a copy of the model with
        ``inplace=True``; otherwise the returned model is inspected.
        """
        m1 = _load_model(M1_DEF)

        prefix = "pre/"

        if inplace:
            # Work on a copy so the original stays available for comparison.
            m2 = ModelProto()
            m2.CopyFrom(m1)
            compose.add_prefix(
                m2,
                prefix,
                rename_nodes=rename_nodes,
                rename_edges=rename_edges,
                rename_inputs=rename_inputs,
                rename_outputs=rename_outputs,
                rename_initializers=rename_initializers,
                rename_value_infos=rename_value_infos,
                inplace=True,
            )
        else:
            m2 = compose.add_prefix(
                m1,
                prefix,
                rename_nodes=rename_nodes,
                rename_edges=rename_edges,
                rename_inputs=rename_inputs,
                rename_outputs=rename_outputs,
                rename_initializers=rename_initializers,
                rename_value_infos=rename_value_infos,
            )
        g_in = m1.graph
        g_out = m2.graph

        if (
            rename_edges
            or rename_inputs
            or rename_outputs
            or rename_initializers
            or rename_value_infos
        ):
            name_mapping = {}

            # Rename inputs/outputs/edges. Propagate name changes from and to edges
            if rename_edges:
                for n in g_in.node:
                    for e in n.input:
                        name_mapping[e] = _prefixed(prefix, e)
                    for e in n.output:
                        name_mapping[e] = _prefixed(prefix, e)

            if rename_inputs:
                for elem in g_in.input:
                    name_mapping[elem.name] = _prefixed(prefix, elem.name)

            if rename_outputs:
                for elem in g_in.output:
                    name_mapping[elem.name] = _prefixed(prefix, elem.name)

            if rename_initializers:
                for init in g_in.initializer:
                    name_mapping[init.name] = _prefixed(prefix, init.name)
                for sparse_init in g_in.sparse_initializer:
                    name_mapping[sparse_init.values.name] = _prefixed(
                        prefix, sparse_init.values.name
                    )
                    name_mapping[sparse_init.indices.name] = _prefixed(
                        prefix, sparse_init.indices.name
                    )

            if rename_value_infos:
                # This flag covers the value_info entries; graph outputs are
                # handled by rename_outputs above.
                for value_info in g_in.value_info:
                    name_mapping[value_info.name] = _prefixed(prefix, value_info.name)

            # Names absent from the mapping are expected to be left untouched.
            for n1, n0 in zip(g_out.node, g_in.node):
                for e1, e0 in zip(n1.input, n0.input):
                    self.assertEqual(name_mapping.get(e0, e0), e1)
                for e1, e0 in zip(n1.output, n0.output):
                    self.assertEqual(name_mapping.get(e0, e0), e1)
            for i1, i0 in zip(g_out.input, g_in.input):
                self.assertEqual(name_mapping.get(i0.name, i0.name), i1.name)
            for o1, o0 in zip(g_out.output, g_in.output):
                self.assertEqual(name_mapping.get(o0.name, o0.name), o1.name)

            for init1, init0 in zip(g_out.initializer, g_in.initializer):
                self.assertEqual(name_mapping.get(init0.name, init0.name), init1.name)

            for sparse_init1, sparse_init0 in zip(
                g_out.sparse_initializer, g_in.sparse_initializer
            ):
                self.assertEqual(
                    name_mapping.get(
                        sparse_init0.values.name, sparse_init0.values.name
                    ),
                    sparse_init1.values.name,
                )
                self.assertEqual(
                    name_mapping.get(
                        sparse_init0.indices.name, sparse_init0.indices.name
                    ),
                    sparse_init1.indices.name,
                )

            for vi1, vi0 in zip(g_out.value_info, g_in.value_info):
                self.assertEqual(name_mapping.get(vi0.name, vi0.name), vi1.name)

        if rename_nodes:
            for n1, n0 in zip(g_out.node, g_in.node):
                self.assertEqual(_prefixed(prefix, n0.name), n1.name)

    def test_add_prefix_nodes(self) -> None:
        """Tests renaming nodes only"""
        self._test_add_prefix(rename_nodes=True)

    def test_add_prefix_edges(self) -> None:
        """Tests prefixing nodes edges. This will also rename inputs/outputs, since the names are shared"""
        self._test_add_prefix(rename_edges=True)

    def test_add_prefix_inputs(self) -> None:
        """Tests prefixing graph inputs only. Relevant node edges should be renamed as well"""
        self._test_add_prefix(rename_inputs=True)

    def test_add_prefix_outputs(self) -> None:
        """Tests prefixing graph outputs only. Relevant node edges should be renamed as well"""
        self._test_add_prefix(rename_outputs=True)

    def test_add_prefix_attribute_subgraph(self) -> None:
        """Tests prefixing attribute's subgraph. Relevant subgraph should be renamed as well"""
        C = helper.make_tensor_value_info("C", TensorProto.BOOL, [1])
        X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, 1])
        Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, 1])
        Z = helper.make_tensor_value_info("Z", TensorProto.FLOAT, [None, 1])
        Out = helper.make_tensor_value_info("Out", TensorProto.FLOAT, [None, 1])

        XY = helper.make_node("Mul", inputs=["X", "Y"], outputs=["XY"])
        add = helper.make_node("Add", inputs=["XY", "Z"], outputs=["Out"])
        sub = helper.make_node("Sub", inputs=["XY", "Z"], outputs=["Out"])

        # Both branches of the If node read "XY" and "Z" from the outer graph.
        cond = helper.make_node(
            "If",
            inputs=["C"],
            outputs=["Out"],
            then_branch=helper.make_graph(
                nodes=[add], name="then", inputs=[], outputs=[Out]
            ),
            else_branch=helper.make_graph(
                nodes=[sub], name="else", inputs=[], outputs=[Out]
            ),
        )
        graph = helper.make_graph(
            nodes=[XY, cond], name="graph", inputs=[C, X, Y, Z], outputs=[Out]
        )

        prefix = "prefix."
        prefixed_graph = compose.add_prefix_graph(graph, prefix)
        checker.check_graph(prefixed_graph)
        for n1, n0 in zip(prefixed_graph.node, graph.node):
            self.assertEqual(_prefixed(prefix, n0.name), n1.name)
            for attribute1, attribute0 in zip(n1.attribute, n0.attribute):
                if attribute1.g:
                    for subgraph_n1, subgraph_n0 in zip(
                        attribute1.g.node, attribute0.g.node
                    ):
                        for input_n1, input_n0 in zip(
                            subgraph_n1.input, subgraph_n0.input
                        ):
                            self.assertEqual(_prefixed(prefix, input_n0), input_n1)
                        for output_n1, output_n0 in zip(
                            subgraph_n1.output, subgraph_n0.output
                        ):
                            self.assertEqual(_prefixed(prefix, output_n0), output_n1)

    def test_add_prefix_all(self) -> None:
        """Tests prefixing all names in the graph"""
        self._test_add_prefix(True, True, True, True, True, True)

    def test_add_prefix_inplace(self) -> None:
        """Tests prefixing inplace"""
        self._test_add_prefix(inplace=True)

    def test_expand_out_dim(self) -> None:
        """Tests expanding output dimensions. The resulting graph should have the same output names,
        but with one more dimension at the specified index.
        """
        m1 = _load_model(M1_DEF)

        def _check_model(m1: ModelProto, m2: ModelProto, dim_idx: int) -> None:
            for out_g2, out_g1 in zip(m2.graph.output, m1.graph.output):
                self.assertEqual(out_g2.name, out_g1.name)
                self.assertEqual(
                    out_g2.type.tensor_type.elem_type, out_g1.type.tensor_type.elem_type
                )
                # Expected shape: the original shape with a 1 inserted at dim_idx.
                expected_out_shape = _get_shape(out_g1)
                expected_out_shape.insert(dim_idx, 1)
                self.assertEqual(_get_shape(out_g2), expected_out_shape)

        for dim_idx in [0, 2, -1, -3]:
            m2 = compose.expand_out_dim(m1, dim_idx)
            _check_model(m1, m2, dim_idx)

        # Test inplace
        m2 = ModelProto()
        m2.CopyFrom(m1)
        dim_idx = 0
        compose.expand_out_dim(m2, dim_idx, inplace=True)
        _check_model(m1, m2, dim_idx)

    def _test_overlapping_names(
        self,
        inputs0: Sequence[str] = ("i0", "i1"),
        inputs1: Sequence[str] = ("i2", "i3"),
        outputs0: Sequence[str] = ("o0", "o1"),
        outputs1: Sequence[str] = ("o2", "o3"),
        value_info0: Sequence[str] = ("v0", "v1"),
        value_info1: Sequence[str] = ("v2", "v3"),
        initializer0: Sequence[str] = ("init0", "init1"),
        initializer1: Sequence[str] = ("init2", "init3"),
        sparse_initializer0: Sequence[str] = ("sparse_init0", "sparse_init1"),
        sparse_initializer1: Sequence[str] = ("sparse_init2", "sparse_init3"),
    ) -> None:
        """Build two models from the given names and verify ``check_overlapping_names``.

        Arguments ending in ``0`` name the entries of the first model, those
        ending in ``1`` the entries of the second. ``inputsN`` and ``outputsN``
        are paired element-wise into Identity nodes. After checking the reported
        overlaps, the first model is prefixed with ``"g0/"`` and no overlap is
        expected any more.
        """
        n0 = [
            helper.make_node("Identity", inputs=[in_name], outputs=[out_name])
            for in_name, out_name in zip(inputs0, outputs0)
        ]
        i0 = [
            helper.make_tensor_value_info(name, TensorProto.FLOAT, [])
            for name in inputs0
        ]
        o0 = [
            helper.make_tensor_value_info(name, TensorProto.FLOAT, [])
            for name in outputs0
        ]
        vi0 = [
            helper.make_tensor_value_info(name, TensorProto.FLOAT, [])
            for name in value_info0
        ]
        init0 = [
            helper.make_tensor(
                name=name, data_type=TensorProto.INT64, dims=(), vals=[1]
            )
            for name in initializer0
        ]
        sparse_init0 = [_make_sparse_tensor(name) for name in sparse_initializer0]

        n1 = [
            helper.make_node("Identity", inputs=[in_name], outputs=[out_name])
            for in_name, out_name in zip(inputs1, outputs1)
        ]
        i1 = [
            helper.make_tensor_value_info(name, TensorProto.FLOAT, [])
            for name in inputs1
        ]
        o1 = [
            helper.make_tensor_value_info(name, TensorProto.FLOAT, [])
            for name in outputs1
        ]
        vi1 = [
            helper.make_tensor_value_info(name, TensorProto.FLOAT, [])
            for name in value_info1
        ]
        init1 = [
            helper.make_tensor(
                name=name, data_type=TensorProto.INT64, dims=(), vals=[1]
            )
            for name in initializer1
        ]
        sparse_init1 = [_make_sparse_tensor(name) for name in sparse_initializer1]

        ops = [helper.make_opsetid("", 10)]
        m0 = helper.make_model(
            helper.make_graph(
                nodes=n0,
                name="g0",
                inputs=i0,
                outputs=o0,
                value_info=vi0,
                initializer=init0,
                sparse_initializer=sparse_init0,
            ),
            producer_name="test",
            opset_imports=ops,
        )
        m1 = helper.make_model(
            helper.make_graph(
                nodes=n1,
                name="g1",
                inputs=i1,
                outputs=o1,
                value_info=vi1,
                initializer=init1,
                sparse_initializer=sparse_init1,
            ),
            producer_name="test",
            opset_imports=ops,
        )

        overlap = compose.check_overlapping_names(m0.graph, m1.graph)
        # Index of the next expected entry in `overlap`; categories without
        # overlapping names are expected to be absent from the result.
        i = 0

        overlapping_inputs = list(set(inputs0) & set(inputs1))
        overlapping_outputs = list(set(outputs0) & set(outputs1))
        overlapping_edges = list(set(overlapping_inputs + overlapping_outputs))
        if overlapping_edges:
            self.assertEqual(overlap[i], ("edge", overlapping_edges))
            i += 1

        overlapping_vis = list(set(value_info0) & set(value_info1))
        if overlapping_vis:
            self.assertEqual(overlap[i], ("value_info", overlapping_vis))
            i += 1

        overlapping_init = list(set(initializer0) & set(initializer1))
        if overlapping_init:
            self.assertEqual(overlap[i], ("initializer", overlapping_init))
            i += 1

        overlapping_sparse_init = list(
            set(sparse_initializer0) & set(sparse_initializer1)
        )
        if overlapping_sparse_init:
            # Each sparse tensor contributes its values and indices tensor names.
            expected_overlap = []
            for overlapping_name in overlapping_sparse_init:
                expected_overlap.append(overlapping_name + "_values")
                expected_overlap.append(overlapping_name + "_idx")
            self.assertEqual(overlap[i], ("sparse_initializer", expected_overlap))
            i += 1

        m0_new = compose.add_prefix(m0, prefix="g0/")
        overlap = compose.check_overlapping_names(m0_new.graph, m1.graph)
        self.assertEqual(0, len(overlap))

    def test_overlapping_input_names(self) -> None:
        """Tests error checking when the name of the inputs overlaps"""
        self._test_overlapping_names(inputs0=["i0", "i1"], inputs1=["i1", "i2"])

    def test_overlapping_output_names(self) -> None:
        """Tests error checking when the name of the output overlaps"""
        self._test_overlapping_names(outputs0=["o0", "o1"], outputs1=["o1", "o2"])

    def test_overlapping_value_info_names(self) -> None:
        """Tests error checking when the name of value_info entries overlaps"""
        self._test_overlapping_names(
            value_info0=["vi0", "vi1"], value_info1=["vi1", "vi2"]
        )

    def test_overlapping_initializer_names(self) -> None:
        """Tests error checking when the name of initializer entries overlaps"""
        self._test_overlapping_names(
            initializer0=["init0", "init1"], initializer1=["init1", "init2"]
        )

    def test_overlapping_sparse_initializer_names(self) -> None:
        """Tests error checking when the name of sparse_initializer entries overlaps"""
        self._test_overlapping_names(
            sparse_initializer0=["sparse_init0", "sparse_init1"],
            sparse_initializer1=["sparse_init1", "sparse_init2"],
        )

    def test_overlapping_function_names(self) -> None:
        """Tests error checking when the name of local function entries overlaps"""
        ops = [helper.make_opsetid("", 10), helper.make_opsetid("local", 10)]

        def _make_function(
            domain: str,
            fname: str,
            inputs: List[str],
            outputs: List[str],
            nodes: List[NodeProto],
        ) -> FunctionProto:
            """Assemble a FunctionProto that imports the opsets in ``ops``."""
            f = FunctionProto()
            f.domain = domain
            f.name = fname
            f.input.extend(inputs)
            f.output.extend(outputs)
            f.node.extend(nodes)
            f.opset_import.extend(ops)
            return f

        # Template graph: a single call to the local function "f1".
        g = GraphProto()
        g.input.extend(
            [
                helper.make_tensor_value_info("x0", TensorProto.FLOAT, []),
                helper.make_tensor_value_info("x1", TensorProto.FLOAT, []),
            ]
        )
        g.output.extend(
            [
                helper.make_tensor_value_info("y", TensorProto.FLOAT, []),
            ]
        )
        g.node.extend(
            [helper.make_node("f1", domain="local", inputs=["x0", "x1"], outputs=["y"])]
        )

        g1 = GraphProto()
        g1.CopyFrom(g)
        g1.name = "g1"
        m1 = helper.make_model(g1, producer_name="test", opset_imports=ops)
        m1.functions.extend(
            [
                _make_function(
                    "local",
                    "f1",
                    ["x0", "x1"],
                    ["y"],
                    [helper.make_node("Add", inputs=["x0", "x1"], outputs=["y"])],
                )
            ]
        )
        checker.check_model(m1)

        # Same function name "f1" as in m1, but a different body (Mul instead of Add).
        g2 = GraphProto()
        g2.CopyFrom(g)
        g2.name = "g2"
        m2 = helper.make_model(g2, producer_name="test", opset_imports=ops)
        m2.functions.extend(
            [
                _make_function(
                    "local",
                    "f1",
                    ["x0", "x1"],
                    ["y"],
                    [helper.make_node("Mul", inputs=["x0", "x1"], outputs=["y"])],
                )
            ]
        )
        checker.check_model(m2)

        m = compose.merge_models(
            m1, m2, io_map=[("y", "x0"), ("y", "x1")], prefix1="m1/", prefix2="m2/"
        )
        checker.check_model(m)

        nodes = [n.op_type for n in m.graph.node]
        self.assertEqual(["m1/f1", "m2/f1"], nodes)

        functions = [f.name for f in m.functions]
        self.assertEqual(["m1/f1", "m2/f1"], functions)

        # Third model: the graph calls "f2", which in turn calls the local "f1".
        g3 = GraphProto()
        g3.CopyFrom(g)
        g3.name = "g3"
        g3.node[0].op_type = "f2"
        m3 = helper.make_model(g3, producer_name="test", opset_imports=ops)
        m3.functions.extend(
            [
                _make_function(
                    "local",
                    "f1",
                    ["x0", "x1"],
                    ["y"],
                    [
                        helper.make_node("Add", inputs=["x0", "x1"], outputs=["y0"]),
                        helper.make_node("Mul", inputs=["x0", "x1"], outputs=["y1"]),
                        helper.make_node("Add", inputs=["y0", "y1"], outputs=["y"]),
                    ],
                ),
                _make_function(
                    "local",
                    "f2",
                    ["x0", "x1"],
                    ["y"],
                    [
                        helper.make_node(
                            "f1", domain="local", inputs=["x0", "x1"], outputs=["y0"]
                        ),
                        helper.make_node("Mul", inputs=["x0", "x1"], outputs=["y1"]),
                        helper.make_node("Add", inputs=["y0", "y1"], outputs=["y"]),
                    ],
                ),
            ]
        )
        checker.check_model(m3)

        m = compose.merge_models(
            m1, m3, io_map=[("y", "x0"), ("y", "x1")], prefix1="m1/", prefix2="m3/"
        )
        checker.check_model(m)

        nodes = [n.op_type for n in m.graph.node]
        self.assertEqual(["m1/f1", "m3/f2"], nodes)

        functions = [f.name for f in m.functions]
        self.assertEqual(["m1/f1", "m3/f1", "m3/f2"], functions)

        self.assertEqual(["Add"], [n.op_type for n in m.functions[0].node])
        self.assertEqual(
            ["Add", "Mul", "Add"], [n.op_type for n in m.functions[1].node]
        )
        # The call to "f1" inside "f2" is expected to carry the prefix as well.
        self.assertEqual(
            ["m3/f1", "Mul", "Add"], [n.op_type for n in m.functions[2].node]
        )

    def test_merge_drop_unnecessary_initializers_and_value_info(self) -> None:
        """Tests automatic removal of initializers when merging graphs"""
        ops = [helper.make_opsetid("", 10)]

        # Template graph: y = Identity(x)
        g = GraphProto()
        g.input.extend([helper.make_tensor_value_info("x", TensorProto.FLOAT, [])])
        g.output.extend([helper.make_tensor_value_info("y", TensorProto.FLOAT, [])])
        g.node.extend([helper.make_node("Identity", inputs=["x"], outputs=["y"])])

        g1 = GraphProto()
        g1.CopyFrom(g)
        g1.name = "g1"
        m1 = helper.make_model(g1, producer_name="test", opset_imports=ops)
        checker.check_model(m1)

        # m2: template plus an initializer named after the input "x"
        g2 = GraphProto()
        g2.CopyFrom(g)
        g2.name = "g2"
        g2.initializer.extend(
            [
                helper.make_tensor(
                    name="x", data_type=TensorProto.FLOAT, dims=(), vals=[0]
                )
            ]
        )
        m2 = helper.make_model(g2, producer_name="test", opset_imports=ops)
        checker.check_model(m2)

        # m3: template plus a sparse initializer built from the name "x"
        g3 = GraphProto()
        g3.CopyFrom(g)
        g3.name = "g3"
        g3.sparse_initializer.extend([_make_sparse_tensor("x")])
        m3 = helper.make_model(g3, producer_name="test", opset_imports=ops)
        checker.check_model(m3)

        # m4: template plus a value_info entry named after the input "x"
        g4 = GraphProto()
        g4.CopyFrom(g)
        g4.name = "g4"
        g4.value_info.extend(
            [helper.make_tensor_value_info("x", TensorProto.FLOAT, [])]
        )
        m4 = helper.make_model(g4, producer_name="test", opset_imports=ops)
        checker.check_model(m4)

        # Initializer 'x' from m2 is removed, because there is no longer an input with that name
        out_m1 = compose.merge_models(m1, m2, prefix1="m1/", io_map=[("y", "x")])
        self.assertEqual(0, len(out_m1.graph.initializer))

        # Sparse initializer case (m3).
        # NOTE(review): this asserts on `initializer`, not `sparse_initializer`;
        # the sparse tensors here are named "x_values"/"x_idx" - confirm which
        # field is meant to be checked.
        out_m2 = compose.merge_models(m1, m3, prefix1="m1/", io_map=[("y", "x")])
        self.assertEqual(0, len(out_m2.graph.initializer))

        # Value info 'x' from m4 is removed, because there is no longer an input with that name
        out_m3 = compose.merge_models(m1, m4, prefix1="m1/", io_map=[("y", "x")])
        self.assertEqual(0, len(out_m3.graph.value_info))


if __name__ == "__main__":
    unittest.main()