Kano001's picture
Upload 2707 files
dc2106c verified
raw
history blame
9.83 kB
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
import io
import os
import pathlib
import tempfile
import unittest
import google.protobuf.message
import google.protobuf.text_format
import parameterized
import onnx
from onnx import serialization
def _simple_model() -> onnx.ModelProto:
    """Build the small ModelProto fixture shared by the tests in this file.

    Only ``ir_version``, ``producer_name`` and ``graph.name`` are populated.
    """
    fixture = onnx.ModelProto(
        ir_version=onnx.IR_VERSION,
        producer_name="onnx-test",
    )
    fixture.graph.name = "test"
    return fixture
def _simple_tensor() -> onnx.TensorProto:
    """Build the FLOAT TensorProto fixture shared by the tensor tests.

    The tensor is named ``test-tensor``, has dims (2, 3, 4) and holds the
    24 values 0.5, 1.5, ..., 23.5.
    """
    shape = (2, 3, 4)
    values = [index + 0.5 for index in range(24)]
    return onnx.helper.make_tensor(
        name="test-tensor",
        data_type=onnx.TensorProto.FLOAT,
        dims=shape,
        vals=values,
    )
@parameterized.parameterized_class(
    [
        {"format": format_name}
        for format_name in ("protobuf", "textproto", "json", "onnxtxt")
    ]
)
class TestIO(unittest.TestCase):
    """Round-trip a ModelProto through save/load for each listed format.

    ``format`` is supplied per generated class by ``parameterized_class``.
    """

    format: str

    def test_load_model_when_input_is_bytes(self) -> None:
        """A model serialized by the registry loads back equal to the original."""
        expected = _simple_model()
        serializer = serialization.registry.get(self.format)
        payload = serializer.serialize_proto(expected)
        actual = onnx.load_model_from_string(payload, format=self.format)
        self.assertEqual(expected, actual)

    def test_save_and_load_model_when_input_has_read_function(self) -> None:
        """Saving to and loading from an in-memory stream preserves the model."""
        expected = _simple_model()
        # Bytes handed to `save_model` must always be a serialized binary
        # protobuf (i.e. format="protobuf"). The `format` argument only
        # selects the format of the saved output.
        binary_serializer = serialization.registry.get("protobuf")
        binary_payload = binary_serializer.serialize_proto(expected)
        sink = io.BytesIO()
        onnx.save_model(binary_payload, sink, format=self.format)
        source = io.BytesIO(sink.getvalue())
        actual = onnx.load_model(source, format=self.format)
        self.assertEqual(expected, actual)

    def test_save_and_load_model_when_input_is_file_name(self) -> None:
        """Saving to and loading from a ``str`` path preserves the model."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.onnx")
            onnx.save_model(expected, target, format=self.format)
            actual = onnx.load_model(target, format=self.format)
            self.assertEqual(expected, actual)

    def test_save_and_load_model_when_input_is_pathlike(self) -> None:
        """Saving to and loading from a ``pathlib.Path`` preserves the model."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = pathlib.Path(workdir) / "model.onnx"
            onnx.save_model(expected, target, format=self.format)
            actual = onnx.load_model(target, format=self.format)
            self.assertEqual(expected, actual)
@parameterized.parameterized_class(
    [
        {"format": format_name}
        # The onnxtxt format does not support saving/loading tensors yet
        for format_name in ("protobuf", "textproto", "json")
    ]
)
class TestIOTensor(unittest.TestCase):
    """Test loading and saving of TensorProto."""

    format: str

    def test_load_tensor_when_input_is_bytes(self) -> None:
        """A tensor serialized by the registry loads back equal to the original."""
        expected = _simple_tensor()
        serializer = serialization.registry.get(self.format)
        payload = serializer.serialize_proto(expected)
        actual = onnx.load_tensor_from_string(payload, format=self.format)
        self.assertEqual(expected, actual)

    def test_save_and_load_tensor_when_input_has_read_function(self) -> None:
        """Saving to and loading from an in-memory stream preserves the tensor."""
        expected = _simple_tensor()
        sink = io.BytesIO()
        onnx.save_tensor(expected, sink, format=self.format)
        source = io.BytesIO(sink.getvalue())
        actual = onnx.load_tensor(source, format=self.format)
        self.assertEqual(expected, actual)

    def test_save_and_load_tensor_when_input_is_file_name(self) -> None:
        """Saving to and loading from a ``str`` path preserves the tensor."""
        expected = _simple_tensor()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.onnx")
            onnx.save_tensor(expected, target, format=self.format)
            actual = onnx.load_tensor(target, format=self.format)
            self.assertEqual(expected, actual)

    def test_save_and_load_tensor_when_input_is_pathlike(self) -> None:
        """Saving to and loading from a ``pathlib.Path`` preserves the tensor."""
        expected = _simple_tensor()
        with tempfile.TemporaryDirectory() as workdir:
            target = pathlib.Path(workdir) / "model.onnx"
            onnx.save_tensor(expected, target, format=self.format)
            actual = onnx.load_tensor(target, format=self.format)
            self.assertEqual(expected, actual)
class TestSaveAndLoadFileExtensions(unittest.TestCase):
    """Check how the file extension and the ``format`` argument interact."""

    def test_save_model_picks_correct_format_from_extension(self) -> None:
        """Saving without ``format`` to a ``.textproto`` path writes textproto."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.textproto")
            # No format is given here, so the extension has to decide it.
            onnx.save_model(expected, target)
            actual = onnx.load_model(target, format="textproto")
            self.assertEqual(expected, actual)

    def test_load_model_picks_correct_format_from_extension(self) -> None:
        """Loading without ``format`` from a ``.textproto`` path reads textproto."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.textproto")
            onnx.save_model(expected, target, format="textproto")
            # No format is given here, so the extension has to decide it.
            actual = onnx.load_model(target)
            self.assertEqual(expected, actual)

    def test_save_model_uses_format_when_it_is_specified(self) -> None:
        """An explicit ``format`` on save takes precedence over the extension."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.textproto")
            # The explicit format should win over the ".textproto" extension.
            onnx.save_model(expected, target, format="protobuf")
            actual = onnx.load_model(target, format="protobuf")
            self.assertEqual(expected, actual)
            # Loading it as textproto (chosen by file extension) should fail.
            self.assertRaises(
                google.protobuf.text_format.ParseError,
                onnx.load_model,
                target,
            )

    def test_load_model_uses_format_when_it_is_specified(self) -> None:
        """An explicit ``format`` on load takes precedence over the extension."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            target = os.path.join(workdir, "model.protobuf")
            onnx.save_model(expected, target)
            # The explicit format should win over the extension, so loading
            # the file as textproto should fail.
            self.assertRaises(
                google.protobuf.text_format.ParseError,
                onnx.load_model,
                target,
                format="textproto",
            )
            actual = onnx.load_model(target, format="protobuf")
            self.assertEqual(expected, actual)

    def test_load_and_save_model_to_path_without_specifying_extension_succeeds(
        self,
    ) -> None:
        """An extension-less path works when ``format`` is passed explicitly."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            # The path deliberately has no extension.
            target = os.path.join(workdir, "model")
            onnx.save_model(expected, target, format="textproto")
            # Without `format`, load_model should assume protobuf and
            # therefore fail on the textproto content.
            self.assertRaises(
                google.protobuf.message.DecodeError,
                onnx.load_model,
                target,
            )
            actual = onnx.load_model(target, format="textproto")
            self.assertEqual(expected, actual)

    def test_load_and_save_model_without_specifying_extension_or_format_defaults_to_protobuf(
        self,
    ) -> None:
        """With neither extension nor ``format``, protobuf is used."""
        expected = _simple_model()
        with tempfile.TemporaryDirectory() as workdir:
            # The path deliberately has no extension.
            target = os.path.join(workdir, "model")
            onnx.save_model(expected, target)
            # The file was saved as protobuf, so reading it as textproto
            # should fail.
            self.assertRaises(
                google.protobuf.text_format.ParseError,
                onnx.load_model,
                target,
                format="textproto",
            )
            implicit = onnx.load_model(target)
            self.assertEqual(expected, implicit)
            explicit = onnx.load_model(target, format="protobuf")
            self.assertEqual(expected, explicit)
class TestBasicFunctions(unittest.TestCase):
    """Basic checks on names exposed by the top-level ``onnx`` module."""

    def test_protos_exist(self) -> None:
        """The proto classes are reachable as attributes of ``onnx``."""
        proto_names = ("AttributeProto", "NodeProto", "GraphProto", "ModelProto")
        for proto_name in proto_names:
            # The lookup raises AttributeError if the class is missing.
            getattr(onnx, proto_name)

    def test_version_exists(self) -> None:
        """``ir_version`` is unset on a new model and survives a round trip."""
        fresh = onnx.ModelProto()
        # A newly created model should not have the version field set.
        self.assertFalse(fresh.HasField("ir_version"))
        # Set the version to the current IR version of the running ONNX,
        # then serialize and parse the model back in place.
        fresh.ir_version = onnx.IR_VERSION
        fresh.ParseFromString(fresh.SerializeToString())
        self.assertTrue(fresh.HasField("ir_version"))
        # The parsed value must match what was written.
        self.assertEqual(fresh.ir_version, onnx.IR_VERSION)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()