File size: 3,288 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
import os
import tempfile
import unittest

import onnx

# Fixture model written in the ONNX textual syntax and parsed by the tests via
# ``onnx.parser.parse_model``. It declares a graph ``agraph`` that calls the
# "local"-domain function ``foo``; ``foo`` computes ``Add(x, x)`` and passes
# the result to ``bar``, which computes ``Mul(x, x)``.
# NOTE: the blank lines below are inside the string literal and are therefore
# part of the text handed to the parser.
_TEST_MODEL = """\

<

    ir_version: 8,

    opset_import: ["" : 17, "local" : 1]

>

agraph (float[N] X) => (float[N] Y) {

    Y = local.foo (X)

}



<opset_import: ["" : 17, "local" : 1], domain: "local">

foo (x) => (y) {

    temp = Add(x, x)

    y = local.bar(temp)

}



<opset_import: ["" : 17], domain: "local">

bar (x) => (y) {

    y = Mul (x, x)

}"""


class _OnnxTestTextualSerializer(onnx.serialization.ProtoSerializer):
    """Test-only serializer that round-trips protos through ONNX text.

    Protos are turned into text with ``onnx.printer.to_text`` and read back
    with the matching ``onnx.parser.parse_*`` function. UTF-8 is the byte
    encoding in both directions.
    """

    supported_format = "onnxtext"
    file_extensions = frozenset({".onnxtext"})

    def serialize_proto(self, proto) -> bytes:
        """Return the textual form of ``proto`` encoded as UTF-8 bytes."""
        return onnx.printer.to_text(proto).encode("utf-8")

    def deserialize_proto(self, serialized: bytes, proto):
        """Parse UTF-8 encoded ONNX text and return the resulting proto.

        ``proto`` is only inspected to decide which parser to call; it is not
        modified here. The value returned is whatever the selected
        ``onnx.parser`` function produces.

        Raises:
            ValueError: If ``proto`` is not a ``ModelProto``, ``GraphProto``,
                ``FunctionProto`` or ``NodeProto`` instance.
        """
        source = serialized.decode("utf-8")

        # Pick the parser first, then invoke it from a single return point.
        if isinstance(proto, onnx.ModelProto):
            parse = onnx.parser.parse_model
        elif isinstance(proto, onnx.GraphProto):
            parse = onnx.parser.parse_graph
        elif isinstance(proto, onnx.FunctionProto):
            parse = onnx.parser.parse_function
        elif isinstance(proto, onnx.NodeProto):
            parse = onnx.parser.parse_node
        else:
            raise ValueError(f"Unsupported proto type: {type(proto)}")

        return parse(source)


class TestRegistry(unittest.TestCase):
    """Tests for ``onnx.serialization.registry`` with the text serializer."""

    def setUp(self) -> None:
        """Create a text serializer and register it before every test."""
        self.serializer = _OnnxTestTextualSerializer()
        registry = onnx.serialization.registry
        registry.register(self.serializer)

    def test_get_returns_the_registered_instance(self) -> None:
        """``registry.get`` must return the exact object that was registered."""
        registry = onnx.serialization.registry
        self.assertIs(registry.get("onnxtext"), self.serializer)

    def test_get_raises_for_unsupported_format(self) -> None:
        """Requesting a format that was never registered raises ValueError."""
        registry = onnx.serialization.registry
        with self.assertRaises(ValueError):
            registry.get("unsupported")

    def test_onnx_save_load_model_uses_the_custom_serializer(self) -> None:
        """Save and load with ``format="onnxtext"`` and compare the results."""
        original = onnx.parser.parse_model(_TEST_MODEL)
        with tempfile.TemporaryDirectory() as workdir:
            # The explicit format argument is passed even though the file
            # name ends in ".onnx".
            target = os.path.join(workdir, "model.onnx")
            onnx.save_model(original, target, format="onnxtext")

            # The file on disk should hold exactly the printer's text output.
            with open(target, encoding="utf-8") as handle:
                written = handle.read()
                expected_text = onnx.printer.to_text(original)
                self.assertEqual(written, expected_text)

            reloaded = onnx.load_model(target, format="onnxtext")

            # Compare deterministic binary encodings of both models.
            original_bytes = original.SerializeToString(deterministic=True)
            reloaded_bytes = reloaded.SerializeToString(deterministic=True)
            self.assertEqual(original_bytes, reloaded_bytes)


class TestCustomSerializer(unittest.TestCase):
    """Direct round-trip test of the text serializer, bypassing the registry."""

    def test_serialize_deserialize_model(self) -> None:
        """A model serialized to text and parsed back must encode identically."""
        original = onnx.parser.parse_model(_TEST_MODEL)
        text_serializer = _OnnxTestTextualSerializer()

        payload = text_serializer.serialize_proto(original)
        # An empty ModelProto is passed so the serializer selects the model parser.
        restored = text_serializer.deserialize_proto(payload, onnx.ModelProto())

        original_bytes = original.SerializeToString(deterministic=True)
        restored_bytes = restored.SerializeToString(deterministic=True)
        self.assertEqual(original_bytes, restored_bytes)