# Extracted from a Spaces file viewer (Space status: Sleeping).
# File size: 5,284 bytes; revision dc2106c.
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
import os
import tempfile
import unittest
import numpy as np
import onnx
import onnx.external_data_helper as ext_data
import onnx.helper
import onnx.model_container
import onnx.numpy_helper
def _linear_regression():
    """Build and check a small model made of three chained MatMul nodes.

    The graph is named "mm", takes input ``X`` and produces output ``Y``
    through the chain ``X -> XA -> XB -> Y``.  The right-hand operands
    ``A``, ``B`` and ``C`` are float32 initializers reshaped to three
    columns and created with ``onnx.numpy_helper.from_array``.

    Returns:
        The model returned by ``onnx.helper.make_model``, after it has been
        passed through ``onnx.checker.check_model``.
    """
    float_type = onnx.TensorProto.FLOAT
    graph_input = onnx.helper.make_tensor_value_info("X", float_type, [None, None])
    graph_output = onnx.helper.make_tensor_value_info("Y", float_type, [None])

    # (left operand, right operand, result) for each MatMul in the chain.
    wiring = (("X", "A", "XA"), ("XA", "B", "XB"), ("XB", "C", "Y"))
    nodes = [
        onnx.helper.make_node("MatMul", [lhs, rhs], [out])
        for lhs, rhs, out in wiring
    ]

    # B and C deliberately share the same values (0, 10, ..., 80).
    base = np.arange(9)
    weights = (("A", base), ("B", base * 10), ("C", base * 10))
    initializers = [
        onnx.numpy_helper.from_array(
            values.astype(np.float32).reshape((-1, 3)), name=name
        )
        for name, values in weights
    ]

    graph = onnx.helper.make_graph(
        nodes, "mm", [graph_input], [graph_output], initializers
    )
    model = onnx.helper.make_model(graph)
    onnx.checker.check_model(model)
    return model
def _large_linear_regression():
    """Build and check a three-MatMul model wrapped by ``make_large_model``.

    The graph is named "mm" and chains ``X -> XA -> XB -> Y``.  Initializer
    ``B`` is a regular tensor created with ``onnx.numpy_helper.from_array``;
    ``A`` and ``C`` are created with
    ``onnx.model_container.make_large_tensor_proto`` under the keys
    ``"#loc0"`` and ``"#loc1"``, and their actual values are handed to
    ``onnx.model_container.make_large_model`` in a dictionary keyed by those
    same strings.

    Returns:
        The object returned by ``onnx.model_container.make_large_model``,
        after its ``check_model()`` method has been called.
    """
    float_type = onnx.TensorProto.FLOAT
    graph_input = onnx.helper.make_tensor_value_info("X", float_type, [None, None])
    graph_output = onnx.helper.make_tensor_value_info("Y", float_type, [None])

    # (left operand, right operand, result) for each MatMul in the chain.
    wiring = (("X", "A", "XA"), ("XA", "B", "XB"), ("XB", "C", "Y"))
    nodes = [
        onnx.helper.make_node("MatMul", [lhs, rhs], [out])
        for lhs, rhs, out in wiring
    ]

    base = np.arange(9)
    initializers = [
        onnx.model_container.make_large_tensor_proto(
            "#loc0", "A", float_type, (3, 3)
        ),
        onnx.numpy_helper.from_array(
            base.astype(np.float32).reshape((-1, 3)), name="B"
        ),
        onnx.model_container.make_large_tensor_proto(
            "#loc1", "C", float_type, (3, 3)
        ),
    ]
    graph = onnx.helper.make_graph(
        nodes, "mm", [graph_input], [graph_output], initializers
    )
    model = onnx.helper.make_model(graph)

    # Values for the two tensors declared through make_large_tensor_proto,
    # keyed by the same "#loc…" strings used above.
    large_values = {
        "#loc0": (base * 100).astype(np.float32).reshape((-1, 3)),
        "#loc1": (base + 10).astype(np.float32).reshape((-1, 3)),
    }
    container = onnx.model_container.make_large_model(model.graph, large_values)
    container.check_model()
    return container
class TestLargeOnnx(unittest.TestCase):
    """Save/load round-trip tests for ``onnx.model_container.ModelContainer``.

    Checks use ``unittest`` assertion methods rather than bare ``assert``
    statements so that they still run when Python is started with ``-O``.
    """

    def test_large_onnx_no_large_initializer(self):
        """A container built from a plain graph can be saved and reloaded."""
        model_proto = _linear_regression()
        self.assertIsInstance(model_proto, onnx.ModelProto)
        large_model = onnx.model_container.make_large_model(model_proto.graph)
        self.assertIsInstance(large_model, onnx.model_container.ModelContainer)
        with tempfile.TemporaryDirectory() as temp:
            filename = os.path.join(temp, "model.onnx")
            large_model.save(filename)
            copy = onnx.model_container.ModelContainer()
            # Reading model_proto before load() is expected to raise.  The
            # attribute is read in a plain statement (not inside `assert`) so
            # the access still happens under `python -O`.
            with self.assertRaises(RuntimeError):
                _ = copy.model_proto
            copy.load(filename)
            self.assertIsNotNone(copy.model_proto)
            onnx.checker.check_model(copy.model_proto)

    def test_large_one_weight_file(self):
        """Save with the second argument set to True, then reload two ways."""
        large_model = _large_linear_regression()
        self.assertIsInstance(large_model, onnx.model_container.ModelContainer)
        with tempfile.TemporaryDirectory() as temp:
            filename = os.path.join(temp, "model.onnx")
            # NOTE(review): the positional True presumably selects "all
            # tensors in one file" (per the test name) -- confirm against
            # ModelContainer.save.
            saved_proto = large_model.save(filename, True)
            self.assertIsInstance(saved_proto, onnx.ModelProto)

            # Reload through the container API.
            copy = onnx.model_container.ModelContainer()
            copy.load(filename)
            copy.check_model()

            # Reload through the regular loader, external data included.
            loaded_model = onnx.load_model(filename, load_external_data=True)
            onnx.checker.check_model(loaded_model)

    def test_large_multi_files(self):
        """Save with the second argument set to False, then inspect and reload."""
        large_model = _large_linear_regression()
        self.assertIsInstance(large_model, onnx.model_container.ModelContainer)
        with tempfile.TemporaryDirectory() as temp:
            filename = os.path.join(temp, "model.onnx")
            saved_proto = large_model.save(filename, False)
            self.assertIsInstance(saved_proto, onnx.ModelProto)
            copy = onnx.load_model(filename)
            onnx.checker.check_model(copy)
            # NOTE(review): load_model is called with its default
            # external-data setting; if that already inlines the external
            # tensors, the loop below visits nothing -- confirm.
            for tensor in ext_data._get_all_tensors(copy):
                if ext_data.uses_external_data(tensor):
                    # Every external tensor must carry exactly one "location"
                    # entry, and that path must exist.
                    tested = 0
                    for ext in tensor.external_data:
                        if ext.key == "location":  # type: ignore[attr-defined]
                            self.assertTrue(os.path.exists(ext.value))
                            tested += 1
                    self.assertEqual(tested, 1)
            loaded_model = onnx.load_model(filename, load_external_data=True)
            onnx.checker.check_model(loaded_model)
# Script entry point: when executed directly, run this module's tests with
# verbose (per-test) output.
if __name__ == "__main__":
    unittest.main(verbosity=2)
|