# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
import io | |
import os | |
import pathlib | |
import tempfile | |
import unittest | |
import google.protobuf.message | |
import google.protobuf.text_format | |
import parameterized | |
import onnx | |
from onnx import serialization | |
def _simple_model() -> onnx.ModelProto:
    """Build the small ``ModelProto`` fixture shared by the model I/O tests.

    Returns:
        A model whose IR version is ``onnx.IR_VERSION``, whose producer name
        is ``"onnx-test"`` and whose (otherwise empty) graph is named
        ``"test"``.
    """
    proto = onnx.ModelProto(
        ir_version=onnx.IR_VERSION,
        producer_name="onnx-test",
    )
    # The graph is a nested message; set its name through attribute access.
    proto.graph.name = "test"
    return proto
def _simple_tensor() -> onnx.TensorProto:
    """Build the small ``TensorProto`` fixture shared by the tensor I/O tests.

    Returns:
        A FLOAT tensor named ``"test-tensor"`` with dims ``(2, 3, 4)`` holding
        the 24 values ``0.5, 1.5, ..., 23.5``.
    """
    shape = (2, 3, 4)
    values = [index + 0.5 for index in range(24)]
    return onnx.helper.make_tensor(
        name="test-tensor",
        data_type=onnx.TensorProto.FLOAT,
        dims=shape,
        vals=values,
    )
# `format` is only annotated on the class, never assigned, so without this
# decorator every test below would fail with AttributeError on `self.format`.
# `parameterized_class` generates one subclass per dict and sets `format` on it.
# The two formats listed are the ones exercised elsewhere in this file; extend
# the list with any other format registered in `serialization.registry`.
@parameterized.parameterized_class(
    [
        {"format": "protobuf"},
        {"format": "textproto"},
    ]
)
class TestIO(unittest.TestCase):
    """Round-trip a ``ModelProto`` through the save/load entry points.

    Each test runs once per serialization format supplied by the
    ``parameterized_class`` decorator.
    """

    # Name of the serialization format under test (set by the decorator).
    format: str

    def test_load_model_when_input_is_bytes(self) -> None:
        """A model serialized to bytes loads back equal via ``load_model_from_string``."""
        proto = _simple_model()
        proto_string = serialization.registry.get(self.format).serialize_proto(proto)
        loaded_proto = onnx.load_model_from_string(proto_string, format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_model_when_input_has_read_function(self) -> None:
        """A model saved to a file-like object loads back equal from its contents."""
        proto = _simple_model()
        # When the proto is a bytes representation provided to `save_model`,
        # it should always be a serialized binary protobuf representation,
        # i.e. format="protobuf". The format of the saved file is specified
        # independently by the `format` argument.
        proto_string = serialization.registry.get("protobuf").serialize_proto(proto)
        f = io.BytesIO()
        onnx.save_model(proto_string, f, format=self.format)
        loaded_proto = onnx.load_model(io.BytesIO(f.getvalue()), format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_model_when_input_is_file_name(self) -> None:
        """A model saved to a ``str`` path loads back equal from the same path."""
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.onnx")
            onnx.save_model(proto, model_path, format=self.format)
            loaded_proto = onnx.load_model(model_path, format=self.format)
            self.assertEqual(proto, loaded_proto)

    def test_save_and_load_model_when_input_is_pathlike(self) -> None:
        """A model saved to a ``pathlib.Path`` loads back equal from the same path."""
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = pathlib.Path(temp_dir, "model.onnx")
            onnx.save_model(proto, model_path, format=self.format)
            loaded_proto = onnx.load_model(model_path, format=self.format)
            self.assertEqual(proto, loaded_proto)
# `format` is only annotated on the class, never assigned, so without this
# decorator every test below would fail with AttributeError on `self.format`.
# `parameterized_class` generates one subclass per dict and sets `format` on it.
# The two formats listed are the ones exercised elsewhere in this file; extend
# the list with any other format registered in `serialization.registry`.
@parameterized.parameterized_class(
    [
        {"format": "protobuf"},
        {"format": "textproto"},
    ]
)
class TestIOTensor(unittest.TestCase):
    """Test loading and saving of TensorProto.

    Each test runs once per serialization format supplied by the
    ``parameterized_class`` decorator.
    """

    # Name of the serialization format under test (set by the decorator).
    format: str

    def test_load_tensor_when_input_is_bytes(self) -> None:
        """A tensor serialized to bytes loads back equal via ``load_tensor_from_string``."""
        proto = _simple_tensor()
        proto_string = serialization.registry.get(self.format).serialize_proto(proto)
        loaded_proto = onnx.load_tensor_from_string(proto_string, format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_tensor_when_input_has_read_function(self) -> None:
        """A tensor saved to a file-like object loads back equal from its contents."""
        proto = _simple_tensor()
        f = io.BytesIO()
        onnx.save_tensor(proto, f, format=self.format)
        loaded_proto = onnx.load_tensor(io.BytesIO(f.getvalue()), format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_tensor_when_input_is_file_name(self) -> None:
        """A tensor saved to a ``str`` path loads back equal from the same path."""
        proto = _simple_tensor()
        with tempfile.TemporaryDirectory() as temp_dir:
            # The file name is kept as "model.onnx" to match the original test.
            tensor_path = os.path.join(temp_dir, "model.onnx")
            onnx.save_tensor(proto, tensor_path, format=self.format)
            loaded_proto = onnx.load_tensor(tensor_path, format=self.format)
            self.assertEqual(proto, loaded_proto)

    def test_save_and_load_tensor_when_input_is_pathlike(self) -> None:
        """A tensor saved to a ``pathlib.Path`` loads back equal from the same path."""
        proto = _simple_tensor()
        with tempfile.TemporaryDirectory() as temp_dir:
            # The file name is kept as "model.onnx" to match the original test.
            tensor_path = pathlib.Path(temp_dir, "model.onnx")
            onnx.save_tensor(proto, tensor_path, format=self.format)
            loaded_proto = onnx.load_tensor(tensor_path, format=self.format)
            self.assertEqual(proto, loaded_proto)
class TestSaveAndLoadFileExtensions(unittest.TestCase):
    """Check how the file extension and the ``format`` argument pick a serializer."""

    def test_save_model_picks_correct_format_from_extension(self) -> None:
        """Saving without ``format`` to a ``.textproto`` path writes textproto."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.textproto")
            # Only the extension can tell save_model which format to write.
            onnx.save_model(expected, target)
            actual = onnx.load_model(target, format="textproto")
            self.assertEqual(expected, actual)

    def test_load_model_picks_correct_format_from_extension(self) -> None:
        """Loading without ``format`` from a ``.textproto`` path reads textproto."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.textproto")
            onnx.save_model(expected, target, format="textproto")
            # Only the extension can tell load_model which format to read.
            actual = onnx.load_model(target)
            self.assertEqual(expected, actual)

    def test_save_model_uses_format_when_it_is_specified(self) -> None:
        """An explicit ``format`` on save takes precedence over the extension."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.textproto")
            # Written as binary protobuf despite the .textproto extension.
            onnx.save_model(expected, target, format="protobuf")
            actual = onnx.load_model(target, format="protobuf")
            self.assertEqual(expected, actual)
            # Letting the extension choose textproto must fail to parse the
            # binary content.
            with self.assertRaises(google.protobuf.text_format.ParseError):
                onnx.load_model(target)

    def test_load_model_uses_format_when_it_is_specified(self) -> None:
        """An explicit ``format`` on load takes precedence over the extension."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.protobuf")
            onnx.save_model(expected, target)
            # Forcing textproto on the saved file must fail to parse it.
            with self.assertRaises(google.protobuf.text_format.ParseError):
                onnx.load_model(target, format="textproto")
            actual = onnx.load_model(target, format="protobuf")
            self.assertEqual(expected, actual)

    def test_load_and_save_model_to_path_without_specifying_extension_succeeds(
        self,
    ) -> None:
        """An extension-less path works as long as ``format`` is given explicitly."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            # The file name deliberately has no extension.
            target = os.path.join(workdir, "model")
            onnx.save_model(expected, target, format="textproto")
            # With neither extension nor format, load_model should assume
            # protobuf and therefore fail to decode the textproto content.
            with self.assertRaises(google.protobuf.message.DecodeError):
                onnx.load_model(target)
            actual = onnx.load_model(target, format="textproto")
            self.assertEqual(expected, actual)

    def test_load_and_save_model_without_specifying_extension_or_format_defaults_to_protobuf(
        self,
    ) -> None:
        """With neither extension nor ``format``, both save and load use protobuf."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            # The file name deliberately has no extension.
            target = os.path.join(workdir, "model")
            onnx.save_model(expected, target)
            # The file holds protobuf, so reading it as textproto must fail.
            with self.assertRaises(google.protobuf.text_format.ParseError):
                onnx.load_model(target, format="textproto")
            actual_default = onnx.load_model(target)
            self.assertEqual(expected, actual_default)
            actual_explicit = onnx.load_model(target, format="protobuf")
            self.assertEqual(expected, actual_explicit)
class TestBasicFunctions(unittest.TestCase):
    """Smoke tests for the core proto classes and the IR version field."""

    def test_protos_exist(self) -> None:
        """The core proto classes are reachable as attributes of ``onnx``."""
        # A missing class surfaces as AttributeError from the lookup itself.
        for proto_name in ("AttributeProto", "NodeProto", "GraphProto", "ModelProto"):
            _ = getattr(onnx, proto_name)

    def test_version_exists(self) -> None:
        """``ir_version`` is unset on a new model and survives a serialize/parse cycle."""
        fresh = onnx.ModelProto()
        # A freshly created model must not report an IR version yet.
        self.assertFalse(fresh.HasField("ir_version"))
        # Stamp the model with the IR version of the running ONNX, then push
        # it through a serialize/parse cycle on the same message object.
        fresh.ir_version = onnx.IR_VERSION
        payload = fresh.SerializeToString()
        fresh.ParseFromString(payload)
        self.assertTrue(fresh.HasField("ir_version"))
        # The parsed value must match what was stamped.
        self.assertEqual(fresh.ir_version, onnx.IR_VERSION)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()