# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
"""Unit tests for the helpers in ``onnx.tools``.

Covers ``update_model_dims.update_inputs_outputs_dims`` and
``replace_constants.replace_initializer_by_constant_of_shape`` (both the
ConstantOfShape and the ``use_range=True`` variants) on plain graphs, on
model-local functions and on ``If`` subgraphs.
"""
import unittest

import numpy as np
from numpy.testing import assert_allclose

import onnx
from onnx import TensorProto, helper, numpy_helper
from onnx.defs import onnx_opset_version
from onnx.reference import ReferenceEvaluator
from onnx.tools import update_model_dims
from onnx.tools.replace_constants import replace_initializer_by_constant_of_shape


def _lr_input() -> np.ndarray:
    """Return a fresh (3, 2) float32 input ``[[1, 2], [4, 5], [5, 4]]``."""
    return np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))


def _lr_tensors() -> tuple:
    """Return the ``(A, C)`` tensors of the linear model ``Y = X @ A - C``.

    ``A`` is a (2, 100) float32 tensor of (seeded) random values and ``C`` is a
    single-element float32 tensor holding 1.  The random values never reach an
    assertion -- they only give ``A`` a realistic payload to be rewritten.
    """
    rng = np.random.default_rng(0)
    A = numpy_helper.from_array(
        rng.standard_normal((2, 100)).astype(np.float32), name="A"
    )
    C = numpy_helper.from_array(np.array([1], dtype=np.float32), name="C")
    return A, C


def _lr_io() -> tuple:
    """Return the ``(X, Y)`` value infos shared by every linear-model fixture."""
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
    return X, Y


def _make_lr_model(a_as_constant: bool) -> onnx.ModelProto:
    """Build ``Y = MatMul(X, A) - C`` as a plain graph.

    Args:
        a_as_constant: when False, ``A`` and ``C`` are both graph initializers;
            when True, ``A`` is produced by a ``Constant`` node instead and only
            ``C`` remains an initializer.
    """
    A, C = _lr_tensors()
    X, Y = _lr_io()
    nodes = [
        helper.make_node("MatMul", ["X", "A"], ["AX"]),
        helper.make_node("Sub", ["AX", "C"], ["Y"]),
    ]
    if a_as_constant:
        nodes.insert(0, helper.make_node("Constant", [], ["A"], value=A))
        initializers = [C]
    else:
        initializers = [A, C]
    graph = helper.make_graph(nodes, "lr", [X], [Y], initializers)
    return helper.make_model(graph)


def _make_lr_function_model() -> onnx.ModelProto:
    """Build the same linear model wrapped in a local function ``custom.unittest``.

    Inside the function both ``A`` and ``C`` come from ``Constant`` nodes; the
    main graph is a single call to that function (``C`` is also passed as a
    graph initializer, as in the original fixture).
    """
    A, C = _lr_tensors()
    X, Y = _lr_io()
    opset_imports = [
        helper.make_opsetid("", onnx_opset_version()),
        helper.make_opsetid("custom", 1),
    ]
    fct = helper.make_function(
        "custom",
        "unittest",
        ["X"],
        ["Y"],
        [
            helper.make_node("Constant", [], ["C"], value=C),
            helper.make_node("Constant", [], ["A"], value=A),
            helper.make_node("MatMul", ["X", "A"], ["AX"]),
            helper.make_node("Sub", ["AX", "C"], ["Y"]),
        ],
        opset_imports,
    )
    node = helper.make_node("unittest", ["X"], ["Y"], domain="custom")
    graph = helper.make_graph([node], "lr", [X], [Y], [C])
    return helper.make_model(graph, functions=[fct], opset_imports=opset_imports)


def _expected_lr_output(reference: np.ndarray) -> np.ndarray:
    """Expected lr-model output once ``A`` is filled with 0.5 and ``C`` is kept.

    With the input from ``_lr_input`` every column of ``X @ A`` equals the row
    sum times 0.5 (1.5, 4.5, 4.5); subtracting ``C = 1`` gives 0.5, 3.5, 3.5.
    ``reference`` only supplies the shape and dtype.
    """
    expected = np.full_like(reference, 3.5)
    expected[0, :] = 0.5
    return expected


def _make_if_model() -> onnx.ModelProto:
    """Build ``Y = If(ReduceSum(X) > 0)`` whose branches return 129-element constants.

    The then-branch yields 129 ones and the else-branch 129 minus-ones, each
    from a ``Constant`` node inside the subgraph, so the large constants live
    only in the ``If`` subgraphs.
    """
    zero = numpy_helper.from_array(np.array([0], dtype=np.float32), name="zero")
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
    rsum = helper.make_node("ReduceSum", ["X"], ["rsum"])
    cond = helper.make_node("Greater", ["rsum", "zero"], ["cond"])

    def branch(graph_name: str, out_name: str, fill: int, node_name: str):
        # One-node subgraph returning a 129-element float32 constant.
        out = helper.make_tensor_value_info(out_name, TensorProto.FLOAT, None)
        cst = numpy_helper.from_array(np.array([fill] * 129).astype(np.float32))
        const_node = helper.make_node(
            "Constant", inputs=[], outputs=[out_name], value=cst, name=node_name
        )
        return helper.make_graph([const_node], graph_name, [], [out])

    if_node = helper.make_node(
        "If",
        ["cond"],
        ["Y"],
        then_branch=branch("then_body", "then_out", 1, "cst1"),
        else_branch=branch("else_body", "else_out", -1, "cst2"),
    )
    graph = helper.make_graph([rsum, cond, if_node], "if", [X], [Y], [zero])
    return helper.make_model(
        graph, opset_imports=[helper.make_opsetid("", onnx_opset_version())]
    )


class TestToolsFunctions(unittest.TestCase):
    def test_update_inputs_outputs_dim(self) -> None:
        node_def = helper.make_node(
            "Conv",
            inputs=["x", "W"],
            outputs=["y"],
            kernel_shape=[3, 3],
            strides=[2, 2],
        )
        graph_def = helper.make_graph(
            [node_def],
            "test",
            [
                helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 1, 5, 5]),
                helper.make_tensor_value_info("W", TensorProto.FLOAT, [1, 1, 3, 3]),
            ],
            [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 1, 2, 2])],
        )
        model_def = helper.make_model(graph_def, producer_name="test")
        # A string keeps the given dim_param, -1 asks for a generated one
        # ("<name>_<axis>", per the assertions below), integers stay fixed.
        updated_def = update_model_dims.update_inputs_outputs_dims(
            model_def,
            {"x": [1, 1, "x1", -1], "W": [1, 1, 3, 3]},
            {"y": [1, 1, -1, -1]},
        )
        onnx.checker.check_model(updated_def)
        input_dims = updated_def.graph.input[0].type.tensor_type.shape.dim
        output_dims = updated_def.graph.output[0].type.tensor_type.shape.dim
        self.assertEqual(input_dims[2].dim_param, "x1")
        self.assertEqual(input_dims[3].dim_param, "x_3")
        self.assertEqual(output_dims[2].dim_param, "y_2")
        self.assertEqual(output_dims[3].dim_param, "y_3")

    def _assert_lr_constant_of_shape(
        self, model_def: onnx.ModelProto, in_function: bool
    ) -> None:
        """Replace constants with ConstantOfShape and check the new lr output.

        ``in_function`` selects whether the rewritten nodes are looked up in
        the first model-local function or in the main graph.
        """
        x = _lr_input()
        # Run the original first in case the rewrite touches its input model.
        y1 = ReferenceEvaluator(model_def).run(None, {"X": x})[0]
        repl = replace_initializer_by_constant_of_shape(model_def)
        nodes = repl.functions[0].node if in_function else repl.graph.node
        self.assertIn("ConstantOfShape", {n.op_type for n in nodes})
        y2 = ReferenceEvaluator(repl).run(None, {"X": x})[0]
        assert_allclose(_expected_lr_output(y1), y2)

    def _assert_lr_range(self, model_def: onnx.ModelProto, in_function: bool) -> None:
        """Replace constants with Range and check only the output shape.

        Range-generated values are not pinned here, so the check is limited to
        node types and to the output keeping its original shape.
        """
        x = _lr_input()
        y1 = ReferenceEvaluator(model_def).run(None, {"X": x})[0]
        repl = replace_initializer_by_constant_of_shape(model_def, use_range=True)
        nodes = repl.functions[0].node if in_function else repl.graph.node
        node_types = {n.op_type for n in nodes}
        self.assertIn("Range", node_types)
        self.assertNotIn("ConstantOfShape", node_types)
        y2 = ReferenceEvaluator(repl).run(None, {"X": x})[0]
        self.assertEqual(y1.shape, y2.shape)

    def test_replace_initializer(self) -> None:
        self._assert_lr_constant_of_shape(
            _make_lr_model(a_as_constant=False), in_function=False
        )

    def test_replace_constant(self) -> None:
        self._assert_lr_constant_of_shape(
            _make_lr_model(a_as_constant=True), in_function=False
        )

    def test_replace_range(self) -> None:
        self._assert_lr_range(_make_lr_model(a_as_constant=True), in_function=False)

    def test_replace_constant_function(self) -> None:
        self._assert_lr_constant_of_shape(_make_lr_function_model(), in_function=True)

    def test_replace_range_function(self) -> None:
        self._assert_lr_range(_make_lr_function_model(), in_function=True)

    def test_replace_constant_graph(self) -> None:
        onnx_model = _make_if_model()
        self.assertNotIn("ConstantOfShape", str(onnx_model))
        # ReduceSum of ones is positive, so the then-branch is taken.
        x = np.ones((3, 2), dtype=np.float32)
        y1 = ReferenceEvaluator(onnx_model).run(None, {"X": x})[0]
        repl = replace_initializer_by_constant_of_shape(onnx_model)
        self.assertIn("ConstantOfShape", str(repl))
        y2 = ReferenceEvaluator(repl).run(None, {"X": x})[0]
        assert_allclose(np.full_like(y1, 0.5), y2)

    def test_replace_range_graph(self) -> None:
        onnx_model = _make_if_model()
        self.assertNotIn("ConstantOfShape", str(onnx_model))
        x = np.ones((3, 2), dtype=np.float32)
        y1 = ReferenceEvaluator(onnx_model).run(None, {"X": x})[0]
        repl = replace_initializer_by_constant_of_shape(onnx_model, use_range=True)
        self.assertNotIn("ConstantOfShape", str(repl))
        self.assertIn("Range", str(repl))
        y2 = ReferenceEvaluator(repl).run(None, {"X": x})[0]
        self.assertEqual(y1.shape, y2.shape)


if __name__ == "__main__":
    unittest.main(verbosity=2)