# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
"""Unit tests for ``onnx.inliner`` (inlining of model-local functions)."""

import unittest

from onnx import inliner, parser

# Model shared by the selective-inlining tests.
#
# The main graph calls ``local.square`` and then ``local.double_and_square``;
# ``double_and_square`` itself calls ``local.square``.  The ``<...>`` headers
# are required: the model header declares the opset imports, and each function
# header places the function in the "local" domain so that the ``local.xxx``
# call sites refer to these definitions.
_SQUARE_MODEL_TEXT = """
    <ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]>
    agraph (float[N] X) => (float[N] Y)
    {
        T = local.square (X)
        Y = local.double_and_square (T)
    }

    <opset_import: [ "" : 17, "local" : 1 ], domain: "local">
    double_and_square (x) => (y) {
        double = Add(x, x)
        y = local.square(double)
    }

    <opset_import: [ "" : 17 ], domain: "local">
    square (x) => (y) {
        y = Mul (x, x)
    }
"""


def _op_types(nodes):
    """Return the ``op_type`` of every node in *nodes*, in order."""
    return [node.op_type for node in nodes]


class InlinerTest(unittest.TestCase):
    """Checks on ``inline_local_functions`` and ``inline_selected_functions``."""

    def test_basic(self):
        """Nested local calls (foo -> bar) are fully expanded in the main graph."""
        model = parser.parse_model(
            """
            <ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]>
            agraph (float[N] X) => (float[N] Y)
            {
                Y = local.foo (X)
            }

            <opset_import: [ "" : 17, "local" : 1 ], domain: "local">
            foo (x) => (y) {
                temp = Add(x, x)
                y = local.bar(temp)
            }

            <opset_import: [ "" : 17 ], domain: "local">
            bar (x) => (y) {
                y = Mul (x, x)
            }
            """
        )
        inlined = inliner.inline_local_functions(model)
        # The function call should be replaced by Add (from foo), followed by
        # Mul (from bar, which foo calls).
        self.assertEqual(_op_types(inlined.graph.node), ["Add", "Mul"])

    def test_selective_inlining(self):
        """With ``exclude=False`` only the listed function (square) is inlined."""
        model = parser.parse_model(_SQUARE_MODEL_TEXT)
        inlined = inliner.inline_selected_functions(
            model, [("local", "square")], exclude=False
        )
        # The call to square should be replaced by Mul, but the call to
        # double_and_square must remain a function call.
        self.assertEqual(
            _op_types(inlined.graph.node), ["Mul", "double_and_square"]
        )

        # Check that the call to square inside double_and_square was inlined.
        self.assertEqual(
            _op_types(inlined.functions[0].node), ["Add", "Mul"]
        )

    def test_selective_exclusion(self):
        """With ``exclude=True`` every function except the listed one is inlined."""
        model = parser.parse_model(_SQUARE_MODEL_TEXT)
        inlined = inliner.inline_selected_functions(
            model, [("local", "double_and_square")], exclude=True
        )
        # double_and_square is excluded, so its call must remain; the call to
        # square (not excluded) should be replaced by Mul.
        self.assertEqual(
            _op_types(inlined.graph.node), ["Mul", "double_and_square"]
        )

        # Check that the call to square inside double_and_square was inlined.
        self.assertEqual(
            _op_types(inlined.functions[0].node), ["Add", "Mul"]
        )


if __name__ == "__main__":
    unittest.main()