Spaces:
Sleeping
Sleeping
# Copyright (c) ONNX Project Contributors | |
# SPDX-License-Identifier: Apache-2.0 | |
import unittest | |
from onnx import inliner, parser | |
class InlinerTest(unittest.TestCase): | |
def test_basic(self): | |
model = parser.parse_model( | |
""" | |
<ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]> | |
agraph (float[N] X) => (float[N] Y) | |
{ | |
Y = local.foo (X) | |
} | |
<opset_import: [ "" : 17, "local" : 1 ], domain: "local"> | |
foo (x) => (y) { | |
temp = Add(x, x) | |
y = local.bar(temp) | |
} | |
<opset_import: [ "" : 17 ], domain: "local"> | |
bar (x) => (y) { | |
y = Mul (x, x) | |
} | |
""" | |
) | |
inlined = inliner.inline_local_functions(model) | |
inlined_nodes = inlined.graph.node | |
# function-call should be replaced by Add, followed by Mul | |
self.assertEqual(len(inlined_nodes), 2) | |
self.assertEqual(inlined_nodes[0].op_type, "Add") | |
self.assertEqual(inlined_nodes[1].op_type, "Mul") | |
def test_selective_inlining(self): | |
model = parser.parse_model( | |
""" | |
<ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]> | |
agraph (float[N] X) => (float[N] Y) | |
{ | |
T = local.square (X) | |
Y = local.double_and_square (T) | |
} | |
<opset_import: [ "" : 17, "local" : 1 ], domain: "local"> | |
double_and_square (x) => (y) { | |
double = Add(x, x) | |
y = local.square(double) | |
} | |
<opset_import: [ "" : 17 ], domain: "local"> | |
square (x) => (y) { | |
y = Mul (x, x) | |
} | |
""" | |
) | |
inlined = inliner.inline_selected_functions( | |
model, [("local", "square")], exclude=False | |
) | |
inlined_nodes = inlined.graph.node | |
# function-call to square should be replaced by Add, but not the one to double_and_square | |
self.assertEqual(len(inlined_nodes), 2) | |
self.assertEqual(inlined_nodes[0].op_type, "Mul") | |
self.assertEqual(inlined_nodes[1].op_type, "double_and_square") | |
# check call to square inside double_and_square was inlined: | |
function_nodes = inlined.functions[0].node | |
self.assertEqual(len(function_nodes), 2) | |
self.assertEqual(function_nodes[0].op_type, "Add") | |
self.assertEqual(function_nodes[1].op_type, "Mul") | |
def test_selective_exclusion(self): | |
model = parser.parse_model( | |
""" | |
<ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]> | |
agraph (float[N] X) => (float[N] Y) | |
{ | |
T = local.square (X) | |
Y = local.double_and_square (T) | |
} | |
<opset_import: [ "" : 17, "local" : 1 ], domain: "local"> | |
double_and_square (x) => (y) { | |
double = Add(x, x) | |
y = local.square(double) | |
} | |
<opset_import: [ "" : 17 ], domain: "local"> | |
square (x) => (y) { | |
y = Mul (x, x) | |
} | |
""" | |
) | |
inlined = inliner.inline_selected_functions( | |
model, [("local", "double_and_square")], exclude=True | |
) | |
inlined_nodes = inlined.graph.node | |
# function-call to square should be replaced by Add, but not the one to double_and_square | |
self.assertEqual(len(inlined_nodes), 2) | |
self.assertEqual(inlined_nodes[0].op_type, "Mul") | |
self.assertEqual(inlined_nodes[1].op_type, "double_and_square") | |
# check call to square inside double_and_square was inlined: | |
function_nodes = inlined.functions[0].node | |
self.assertEqual(len(function_nodes), 2) | |
self.assertEqual(function_nodes[0].op_type, "Add") | |
self.assertEqual(function_nodes[1].op_type, "Mul") | |
if __name__ == "__main__": | |
unittest.main() | |