# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
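
"""Tests for saving and loading ONNX protos across the supported serialization formats."""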

import io
import os
import pathlib
import tempfile
import unittest

import google.protobuf.message
import google.protobuf.text_format
import parameterized

import onnx
from onnx import serialization


def _simple_model() -> onnx.ModelProto:
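    """Create a minimal ModelProto for round-trip testing."""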
    model = onnx.ModelProto()
    model.ir_version = onnx.IR_VERSION
    model.producer_name = "onnx-test"
    model.graph.name = "test"
    return model


def _simple_tensor() -> onnx.TensorProto:
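    """Create a small float TensorProto with known dims and values."""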
    tensor = onnx.helper.make_tensor(
        name="test-tensor",
        data_type=onnx.TensorProto.FLOAT,
        dims=(2, 3, 4),
        vals=[x + 0.5 for x in range(24)],
    )
    return tensor


@parameterized.parameterized_class(
    [
        {"format": "protobuf"},
        {"format": "textproto"},
        {"format": "json"},
        {"format": "onnxtxt"},
    ]
)
class TestIO(unittest.TestCase):
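    """Test loading and saving of ModelProto in the supported serialization formats."""
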
    format: str

    def test_load_model_when_input_is_bytes(self) -> None:
        proto = _simple_model()
        proto_string = serialization.registry.get(self.format).serialize_proto(proto)
        loaded_proto = onnx.load_model_from_string(proto_string, format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_model_when_input_has_read_function(self) -> None:
        proto = _simple_model()
        # When a bytes object is provided to `save_model`, it must always be a
        # serialized binary protobuf representation (i.e. format="protobuf").
        # The format of the saved file is controlled by the `format` argument.
        proto_string = serialization.registry.get("protobuf").serialize_proto(proto)
        f = io.BytesIO()
        onnx.save_model(proto_string, f, format=self.format)
        loaded_proto = onnx.load_model(io.BytesIO(f.getvalue()), format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_model_when_input_is_file_name(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.onnx")
            onnx.save_model(proto, model_path, format=self.format)
            loaded_proto = onnx.load_model(model_path, format=self.format)
            self.assertEqual(proto, loaded_proto)

    def test_save_and_load_model_when_input_is_pathlike(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = pathlib.Path(temp_dir, "model.onnx")
            onnx.save_model(proto, model_path, format=self.format)
            loaded_proto = onnx.load_model(model_path, format=self.format)
            self.assertEqual(proto, loaded_proto)


@parameterized.parameterized_class(
    [
        {"format": "protobuf"},
        {"format": "textproto"},
        {"format": "json"},
        # The onnxtxt format does not support saving/loading tensors yet
    ]
)
class TestIOTensor(unittest.TestCase):
    """Test loading and saving of TensorProto."""

    format: str

    def test_load_tensor_when_input_is_bytes(self) -> None:
        proto = _simple_tensor()
        proto_string = serialization.registry.get(self.format).serialize_proto(proto)
        loaded_proto = onnx.load_tensor_from_string(proto_string, format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_tensor_when_input_has_read_function(self) -> None:
        # Test when the input is a file-like object (supports read/write)
        proto = _simple_tensor()
        f = io.BytesIO()
        onnx.save_tensor(proto, f, format=self.format)
        loaded_proto = onnx.load_tensor(io.BytesIO(f.getvalue()), format=self.format)
        self.assertEqual(proto, loaded_proto)

    def test_save_and_load_tensor_when_input_is_file_name(self) -> None:
        # Test if input is a file name
        proto = _simple_tensor()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.onnx")
            onnx.save_tensor(proto, model_path, format=self.format)
            loaded_proto = onnx.load_tensor(model_path, format=self.format)
            self.assertEqual(proto, loaded_proto)

    def test_save_and_load_tensor_when_input_is_pathlike(self) -> None:
        # Test if input is a path-like object
        proto = _simple_tensor()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = pathlib.Path(temp_dir, "model.onnx")
            onnx.save_tensor(proto, model_path, format=self.format)
            loaded_proto = onnx.load_tensor(model_path, format=self.format)
            self.assertEqual(proto, loaded_proto)


class TestSaveAndLoadFileExtensions(unittest.TestCase):
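    """Test how the file extension and the `format` argument interact when saving and loading."""
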
    def test_save_model_picks_correct_format_from_extension(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.textproto")
            # No format is specified, so the extension should be used to determine the format
            onnx.save_model(proto, model_path)
            loaded_proto = onnx.load_model(model_path, format="textproto")
            self.assertEqual(proto, loaded_proto)

    def test_load_model_picks_correct_format_from_extension(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.textproto")
            onnx.save_model(proto, model_path, format="textproto")
            # No format is specified, so the extension should be used to determine the format
            loaded_proto = onnx.load_model(model_path)
            self.assertEqual(proto, loaded_proto)

    def test_save_model_uses_format_when_it_is_specified(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.textproto")
            # `format` is specified. It should take precedence over the extension
            onnx.save_model(proto, model_path, format="protobuf")
            loaded_proto = onnx.load_model(model_path, format="protobuf")
            self.assertEqual(proto, loaded_proto)
            with self.assertRaises(google.protobuf.text_format.ParseError):
                # Loading it as textproto (by file extension) should fail
                onnx.load_model(model_path)

    def test_load_model_uses_format_when_it_is_specified(self) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.protobuf")
            onnx.save_model(proto, model_path)
            with self.assertRaises(google.protobuf.text_format.ParseError):
                # `format` is specified. It should take precedence over the extension
                # Loading it as textproto should fail
                onnx.load_model(model_path, format="textproto")

            loaded_proto = onnx.load_model(model_path, format="protobuf")
            self.assertEqual(proto, loaded_proto)

    def test_load_and_save_model_to_path_without_specifying_extension_succeeds(
        self,
    ) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            # No extension is specified
            model_path = os.path.join(temp_dir, "model")
            onnx.save_model(proto, model_path, format="textproto")
            with self.assertRaises(google.protobuf.message.DecodeError):
                # `format` is not specified. load_model should assume protobuf
                # and fail to load it
                onnx.load_model(model_path)

            loaded_proto = onnx.load_model(model_path, format="textproto")
            self.assertEqual(proto, loaded_proto)

    def test_load_and_save_model_without_specifying_extension_or_format_defaults_to_protobuf(
        self,
    ) -> None:
        proto = _simple_model()
        with tempfile.TemporaryDirectory() as temp_dir:
            # No extension is specified
            model_path = os.path.join(temp_dir, "model")
            onnx.save_model(proto, model_path)
            with self.assertRaises(google.protobuf.text_format.ParseError):
                # The model is saved as protobuf, so loading it as textproto should fail
                onnx.load_model(model_path, format="textproto")

            loaded_proto = onnx.load_model(model_path)
            self.assertEqual(proto, loaded_proto)
            loaded_proto_as_explicitly_protobuf = onnx.load_model(
                model_path, format="protobuf"
            )
            self.assertEqual(proto, loaded_proto_as_explicitly_protobuf)


class TestBasicFunctions(unittest.TestCase):
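    """Sanity checks for the exposed proto classes and IR version handling."""
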
    def test_protos_exist(self) -> None:
        # The proto classes should exist
        _ = onnx.AttributeProto
        _ = onnx.NodeProto
        _ = onnx.GraphProto
        _ = onnx.ModelProto

    def test_version_exists(self) -> None:
        model = onnx.ModelProto()
        # A freshly created model should not have an ir_version set.
        self.assertFalse(model.HasField("ir_version"))
        # Set the version so the model is annotated with the IR version of the
        # running ONNX installation.
        model.ir_version = onnx.IR_VERSION
        model_string = model.SerializeToString()
        model.ParseFromString(model_string)
        self.assertTrue(model.HasField("ir_version"))
        # Check if the version is correct.
        self.assertEqual(model.ir_version, onnx.IR_VERSION)


if __name__ == "__main__":
    unittest.main()