Spaces:
Running
Running
# Copyright (c) ONNX Project Contributors | |
# | |
# SPDX-License-Identifier: Apache-2.0 | |
import os | |
import tempfile | |
import unittest | |
import numpy as np | |
import numpy.testing as npt | |
import onnx | |
import onnx.helper | |
import onnx.model_container | |
import onnx.numpy_helper | |
import onnx.reference | |
def _linear_regression():
    """Build a small ModelProto computing ``Y = X @ A @ B @ C``.

    All three 3x3 float32 weights are stored as ordinary (in-proto)
    initializers: ``A = arange(9)``, ``B = 100 * arange(9)`` and
    ``C = arange(9) + 10``, each reshaped to (3, 3).

    Returns:
        onnx.ModelProto: the model, already validated with
        ``onnx.checker.check_model``.
    """
    X = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None])
    # MatMul of two 2-D operands is 2-D, so the declared output must be rank 2
    # as well (a rank-1 declaration contradicts shape inference).
    Y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [None, None])
    graph = onnx.helper.make_graph(
        [
            onnx.helper.make_node("MatMul", ["X", "A"], ["XA"]),
            onnx.helper.make_node("MatMul", ["XA", "B"], ["XB"]),
            onnx.helper.make_node("MatMul", ["XB", "C"], ["Y"]),
        ],
        "mm",
        [X],
        [Y],
        [
            onnx.numpy_helper.from_array(
                np.arange(9).astype(np.float32).reshape((-1, 3)), name="A"
            ),
            onnx.numpy_helper.from_array(
                (np.arange(9) * 100).astype(np.float32).reshape((-1, 3)),
                name="B",
            ),
            onnx.numpy_helper.from_array(
                (np.arange(9) + 10).astype(np.float32).reshape((-1, 3)),
                name="C",
            ),
        ],
    )
    onnx_model = onnx.helper.make_model(graph)
    onnx.checker.check_model(onnx_model)
    return onnx_model
def _large_linear_regression():
    """Build a model container computing ``Y = X @ A @ B @ C`` with external weights.

    ``A`` and ``C`` are declared with ``make_large_tensor_proto`` under the
    locations ``"#loc0"`` and ``"#loc1"``. Their data is passed to
    ``make_large_model`` in a dict keyed by those locations. ``B`` is an
    ordinary in-proto initializer.

    Returns:
        The object returned by ``onnx.model_container.make_large_model``,
        after its ``check_model()`` has been called.
    """
    X = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None])
    # MatMul of two 2-D operands is 2-D, so declare Y as rank 2.
    Y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [None, None])
    graph = onnx.helper.make_graph(
        [
            onnx.helper.make_node("MatMul", ["X", "A"], ["XA"]),
            onnx.helper.make_node("MatMul", ["XA", "B"], ["XB"]),
            onnx.helper.make_node("MatMul", ["XB", "C"], ["Y"]),
        ],
        "mm",
        [X],
        [Y],
        [
            onnx.model_container.make_large_tensor_proto(
                "#loc0", "A", onnx.TensorProto.FLOAT, (3, 3)
            ),
            onnx.numpy_helper.from_array(
                np.arange(9).astype(np.float32).reshape((-1, 3)), name="B"
            ),
            onnx.model_container.make_large_tensor_proto(
                "#loc1", "C", onnx.TensorProto.FLOAT, (3, 3)
            ),
        ],
    )
    onnx_model = onnx.helper.make_model(graph)
    large_model = onnx.model_container.make_large_model(
        onnx_model.graph,
        {
            # A = 100 * arange and B = arange (swapped relative to
            # _linear_regression). The scalar factor commutes through the
            # MatMuls, so the final output is identical.
            "#loc0": (np.arange(9) * 100).astype(np.float32).reshape((-1, 3)),
            "#loc1": (np.arange(9) + 10).astype(np.float32).reshape((-1, 3)),
        },
    )
    large_model.check_model()
    return large_model
class TestLargeOnnxReferenceEvaluator(unittest.TestCase):
    """Evaluate ModelContainer-based models with ``onnx.reference.ReferenceEvaluator``.

    Each test runs the same three-MatMul graph in memory, then again after a
    save/load round trip, and compares every result to one fixed expected
    output.
    """

    def common_check_reference_evaluator(self, container):
        """Run *container* on a fixed 3x3 input and check the output.

        Args:
            container: a model accepted by ``ReferenceEvaluator``. The tests
                in this class pass both model containers and ``ModelProto``s.
        """
        X = np.arange(9).astype(np.float32).reshape((-1, 3))
        ref = onnx.reference.ReferenceEvaluator(container)
        got = ref.run(None, {"X": X})
        expected = np.array(
            [
                [945000, 1015200, 1085400],
                [2905200, 3121200, 3337200],
                [4865400, 5227200, 5589000],
            ],
            dtype=np.float32,
        )
        # assert_allclose(actual, desired): rtol is applied relative to the
        # desired (expected) value, so the computed result goes first.
        npt.assert_allclose(got[0], expected)

    def test_large_onnx_no_large_initializer(self):
        """A container without external tensors evaluates before and after a round trip."""
        model_proto = _linear_regression()
        large_model = onnx.model_container.make_large_model(model_proto.graph)
        self.common_check_reference_evaluator(large_model)

        # Looking up a location that was never registered must raise.
        with self.assertRaises(ValueError):
            large_model["#anymissingkey"]

        with tempfile.TemporaryDirectory() as temp:
            filename = os.path.join(temp, "model.onnx")
            large_model.save(filename)
            copy = onnx.model_container.ModelContainer()
            copy.load(filename)
            self.common_check_reference_evaluator(copy)

    def test_large_one_weight_file(self):
        """External tensors saved into a single weight file reload correctly."""
        large_model = _large_linear_regression()
        self.common_check_reference_evaluator(large_model)

        with tempfile.TemporaryDirectory() as temp:
            filename = os.path.join(temp, "model.onnx")
            large_model.save(filename, True)
            # Reload through ModelContainer and evaluate it. Previously this
            # copy was loaded but never checked.
            copy = onnx.model_container.ModelContainer()
            copy.load(filename)
            self.common_check_reference_evaluator(copy)
            loaded_model = onnx.load_model(filename, load_external_data=True)
            self.common_check_reference_evaluator(loaded_model)

    def test_large_multi_files(self):
        """External tensors saved one-file-per-tensor reload correctly."""
        large_model = _large_linear_regression()
        self.common_check_reference_evaluator(large_model)

        with tempfile.TemporaryDirectory() as temp:
            filename = os.path.join(temp, "model.onnx")
            large_model.save(filename, False)
            # Exercise the ModelContainer.load path, consistent with the other
            # tests. The old onnx.load_model(filename) call duplicated the
            # load_external_data=True check below, since that is its default.
            copy = onnx.model_container.ModelContainer()
            copy.load(filename)
            self.common_check_reference_evaluator(copy)
            loaded_model = onnx.load_model(filename, load_external_data=True)
            self.common_check_reference_evaluator(loaded_model)
if "__main__" == __name__:
    # Executed directly (not imported by a test runner): run this module's
    # tests, reporting each test individually.
    unittest.main(verbosity=2)