Spaces:
Sleeping
Sleeping
File size: 4,084 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
import unittest
from onnx import inliner, parser
class InlinerTest(unittest.TestCase):
    """Tests for inlining model-local functions with ``onnx.inliner``.

    Covers inlining every local function, inlining only an explicit
    selection, and inlining everything except an explicit selection.
    """

    # Model shared by the two selective tests: the main graph calls
    # ``local.square`` and then ``local.double_and_square``, and
    # ``double_and_square`` itself calls ``square``.
    _SQUARE_MODEL_TEXT = """
        <ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]>
        agraph (float[N] X) => (float[N] Y)
        {
            T = local.square (X)
            Y = local.double_and_square (T)
        }
        <opset_import: [ "" : 17, "local" : 1 ], domain: "local">
        double_and_square (x) => (y) {
            double = Add(x, x)
            y = local.square(double)
        }
        <opset_import: [ "" : 17 ], domain: "local">
        square (x) => (y) {
            y = Mul (x, x)
        }
    """

    def _assert_only_square_inlined(self, inlined):
        """Assert ``square`` was inlined everywhere and ``double_and_square`` kept.

        Args:
            inlined: The model returned by the inliner for
                ``_SQUARE_MODEL_TEXT``.
        """
        graph_nodes = inlined.graph.node
        # The call to square is replaced by its body (a single Mul); the call
        # to double_and_square must still be a function call.
        self.assertEqual(len(graph_nodes), 2)
        self.assertEqual(graph_nodes[0].op_type, "Mul")
        self.assertEqual(graph_nodes[1].op_type, "double_and_square")

        # The nested call to square inside double_and_square (expected to be
        # the first function of the resulting model) was inlined as well.
        function_nodes = inlined.functions[0].node
        self.assertEqual(len(function_nodes), 2)
        self.assertEqual(function_nodes[0].op_type, "Add")
        self.assertEqual(function_nodes[1].op_type, "Mul")

    def test_basic(self) -> None:
        """Inlining all local functions flattens nested calls into the graph."""
        model = parser.parse_model(
            """
            <ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]>
            agraph (float[N] X) => (float[N] Y)
            {
                Y = local.foo (X)
            }
            <opset_import: [ "" : 17, "local" : 1 ], domain: "local">
            foo (x) => (y) {
                temp = Add(x, x)
                y = local.bar(temp)
            }
            <opset_import: [ "" : 17 ], domain: "local">
            bar (x) => (y) {
                y = Mul (x, x)
            }
            """
        )
        inlined = inliner.inline_local_functions(model)
        inlined_nodes = inlined.graph.node
        # The call to foo is replaced by its Add, followed by the Mul that
        # comes from the nested call to bar.
        self.assertEqual(len(inlined_nodes), 2)
        self.assertEqual(inlined_nodes[0].op_type, "Add")
        self.assertEqual(inlined_nodes[1].op_type, "Mul")

    def test_selective_inlining(self) -> None:
        """Listing ``square`` with ``exclude=False`` inlines only ``square``."""
        model = parser.parse_model(self._SQUARE_MODEL_TEXT)
        inlined = inliner.inline_selected_functions(
            model, [("local", "square")], exclude=False
        )
        self._assert_only_square_inlined(inlined)

    def test_selective_exclusion(self) -> None:
        """Listing ``double_and_square`` with ``exclude=True`` inlines the rest.

        The only other local function is ``square``, so the outcome is the
        same as selecting ``square`` explicitly.
        """
        model = parser.parse_model(self._SQUARE_MODEL_TEXT)
        inlined = inliner.inline_selected_functions(
            model, [("local", "double_and_square")], exclude=True
        )
        self._assert_only_square_inlined(inlined)
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner,
    # which discovers and runs the test cases defined in this module.
    unittest.main()
|