GameServerO / MLPY /Lib /site-packages /onnx /test /serialization_test.py
Kano001's picture
Upload 2707 files
dc2106c verified
raw
history blame
3.29 kB
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
import os
import tempfile
import unittest
import onnx
# Shared fixture: a model written in the ONNX textual syntax, parsed by the
# tests below with ``onnx.parser.parse_model``. It declares ir_version 8 and
# imports opset 17 of the default domain plus version 1 of the "local" domain.
# The graph ``agraph`` calls the local function ``foo``, which computes
# ``Add(x, x)`` and passes the result to the local function ``bar``
# (``Mul(x, x)``).
# The backslash right after the opening quotes keeps the string from starting
# with a newline.
_TEST_MODEL = """\
<
ir_version: 8,
opset_import: ["" : 17, "local" : 1]
>
agraph (float[N] X) => (float[N] Y) {
Y = local.foo (X)
}
<opset_import: ["" : 17, "local" : 1], domain: "local">
foo (x) => (y) {
temp = Add(x, x)
y = local.bar(temp)
}
<opset_import: ["" : 17], domain: "local">
bar (x) => (y) {
y = Mul (x, x)
}"""
class _OnnxTestTextualSerializer(onnx.serialization.ProtoSerializer):
    """Test-only serializer that round-trips protos through ONNX text.

    Serialization renders the proto with ``onnx.printer.to_text`` and
    deserialization parses the text back with the matching ``onnx.parser``
    function. Both directions use UTF-8 encoded bytes as the wire form.
    """

    supported_format = "onnxtext"
    file_extensions = frozenset({".onnxtext"})

    def serialize_proto(self, proto) -> bytes:
        """Return the textual rendering of ``proto`` as UTF-8 bytes."""
        return onnx.printer.to_text(proto).encode("utf-8")

    def deserialize_proto(self, serialized: bytes, proto):
        """Parse UTF-8 encoded ONNX text into a newly created proto.

        ``proto`` is only inspected to decide which parser to use; the parsed
        result is returned rather than written into ``proto``.

        Raises:
            ValueError: If ``proto`` is not a ModelProto, GraphProto,
                FunctionProto or NodeProto.
        """
        source = serialized.decode("utf-8")

        # Select the parser that matches the kind of proto requested.
        if isinstance(proto, onnx.ModelProto):
            parse = onnx.parser.parse_model
        elif isinstance(proto, onnx.GraphProto):
            parse = onnx.parser.parse_graph
        elif isinstance(proto, onnx.FunctionProto):
            parse = onnx.parser.parse_function
        elif isinstance(proto, onnx.NodeProto):
            parse = onnx.parser.parse_node
        else:
            raise ValueError(f"Unsupported proto type: {type(proto)}")

        return parse(source)
class TestRegistry(unittest.TestCase):
    """Tests for looking up and using a serializer via the ONNX registry."""

    def setUp(self) -> None:
        # Every test starts with a freshly registered textual serializer.
        self.serializer = _OnnxTestTextualSerializer()
        registry = onnx.serialization.registry
        registry.register(self.serializer)

    def test_get_returns_the_registered_instance(self) -> None:
        """``get`` hands back the very object registered in ``setUp``."""
        self.assertIs(onnx.serialization.registry.get("onnxtext"), self.serializer)

    def test_get_raises_for_unsupported_format(self) -> None:
        """Asking for a format nobody registered raises ``ValueError``."""
        self.assertRaises(ValueError, onnx.serialization.registry.get, "unsupported")

    def test_onnx_save_load_model_uses_the_custom_serializer(self) -> None:
        """``save_model``/``load_model`` with format="onnxtext" use the text form."""
        original = onnx.parser.parse_model(_TEST_MODEL)

        with tempfile.TemporaryDirectory() as workdir:
            # The format is passed explicitly even though the path ends in ".onnx".
            target = os.path.join(workdir, "model.onnx")
            onnx.save_model(original, target, format="onnxtext")

            # The file on disk must hold exactly the printer's textual output.
            with open(target, encoding="utf-8") as stream:
                on_disk = stream.read()
            self.assertEqual(on_disk, onnx.printer.to_text(original))

            reloaded = onnx.load_model(target, format="onnxtext")

        # Compare deterministic binary encodings of the two models.
        expected_bytes = original.SerializeToString(deterministic=True)
        reloaded_bytes = reloaded.SerializeToString(deterministic=True)
        self.assertEqual(expected_bytes, reloaded_bytes)
class TestCustomSerializer(unittest.TestCase):
    """Tests that call the textual serializer directly, without the registry."""

    def test_serialize_deserialize_model(self) -> None:
        """A model serialized to text and parsed back encodes to the same bytes."""
        codec = _OnnxTestTextualSerializer()
        original = onnx.parser.parse_model(_TEST_MODEL)

        # Round trip: proto -> UTF-8 text bytes -> proto.
        round_tripped = codec.deserialize_proto(
            codec.serialize_proto(original), onnx.ModelProto()
        )

        expected_bytes = original.SerializeToString(deterministic=True)
        actual_bytes = round_tripped.SerializeToString(deterministic=True)
        self.assertEqual(expected_bytes, actual_bytes)