# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
import os
import re
import sys
import uuid
from itertools import chain
from typing import Callable, Iterable, Optional

import onnx.onnx_cpp2py_export.checker as c_checker
from onnx.onnx_pb import AttributeProto, GraphProto, ModelProto, TensorProto


class ExternalDataInfo:
    def __init__(self, tensor: TensorProto) -> None:
        self.location = ""
        self.offset = None
        self.length = None
        self.checksum = None
        self.basepath = ""

        for entry in tensor.external_data:
            setattr(self, entry.key, entry.value)

        if self.offset:
            self.offset = int(self.offset)

        if self.length:
            self.length = int(self.length)


def load_external_data_for_tensor(tensor: TensorProto, base_dir: str) -> None:
    """Loads data from an external file for tensor.

    Ideally the TensorProto should not hold any raw data, but if it does, it is ignored.

    Arguments:
        tensor: a TensorProto object.
        base_dir: directory that contains the external data.
    """
    info = ExternalDataInfo(tensor)
    external_data_file_path = c_checker._resolve_external_data_location(  # type: ignore[attr-defined]
        base_dir, info.location, tensor.name
    )
    with open(external_data_file_path, "rb") as data_file:
        if info.offset:
            data_file.seek(info.offset)

        if info.length:
            tensor.raw_data = data_file.read(info.length)
        else:
            tensor.raw_data = data_file.read()


def load_external_data_for_model(model: ModelProto, base_dir: str) -> None:
    """Loads external tensors into model



    Arguments:

        model: ModelProto to load external data to

        base_dir: directory that contains external data

    """
    for tensor in _get_all_tensors(model):
        if uses_external_data(tensor):
            load_external_data_for_tensor(tensor, base_dir)
            # After loading raw_data from external_data, change the state of tensors
            tensor.data_location = TensorProto.DEFAULT
            # and remove external data
            del tensor.external_data[:]
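

# A minimal usage sketch (not part of the ONNX public API): read the model graph
# without its weights via onnx.load(..., load_external_data=False), then pull the
# raw bytes in from the directory that holds the external tensor files.
def _example_load_model_with_external_data(model_path: str) -> ModelProto:
    import onnx  # imported locally to avoid a circular import at module load time

    model = onnx.load(model_path, load_external_data=False)
    load_external_data_for_model(model, os.path.dirname(model_path))
    return model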


def set_external_data(
    tensor: TensorProto,
    location: str,
    offset: Optional[int] = None,
    length: Optional[int] = None,
    checksum: Optional[str] = None,
    basepath: Optional[str] = None,
) -> None:
    """Marks the tensor as externally stored by filling its external_data entries.

    The tensor must already hold raw_data; only the given non-None fields
    (location, offset, length, checksum, basepath) are recorded and
    data_location is set to TensorProto.EXTERNAL.
    """
    if not tensor.HasField("raw_data"):
        raise ValueError(
            "Tensor "
            + tensor.name
            + " does not have raw_data field. Cannot set external data for this tensor."
        )

    del tensor.external_data[:]
    tensor.data_location = TensorProto.EXTERNAL
    for k, v in {
        "location": location,
        "offset": int(offset) if offset is not None else None,
        "length": int(length) if length is not None else None,
        "checksum": checksum,
        "basepath": basepath,
    }.items():
        if v is not None:
            entry = tensor.external_data.add()
            entry.key = k
            entry.value = str(v)
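

# A minimal illustration (the file name is hypothetical): point a tensor that
# already carries raw_data at a named external file; the bytes themselves are
# written out later by save_external_data / write_external_data_tensors.
def _example_mark_tensor_external(tensor: TensorProto) -> None:
    set_external_data(tensor, location="weights.bin")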


def convert_model_to_external_data(
    model: ModelProto,
    all_tensors_to_one_file: bool = True,
    location: Optional[str] = None,
    size_threshold: int = 1024,
    convert_attribute: bool = False,
) -> None:
    """Sets all tensors with raw data at or above the size threshold to be stored as external data.

    This call should precede 'save_model', which then writes the tensor data to
    external files.

    Arguments:
        model (ModelProto): Model to be converted.
        all_tensors_to_one_file (bool): If true, save all tensors to one external file
            specified by location. If false, save each tensor to a file named after the tensor.
        location: The external file, relative to the model path, that all tensors are saved to.
            If not specified, a randomly generated file name is used.
        size_threshold: Threshold for the size of the data. Only a tensor whose data is
            >= size_threshold is converted to external data. Set size_threshold=0 to convert
            every tensor with raw data.
        convert_attribute (bool): If true, convert all tensors to external data.
            If false, convert only non-attribute (initializer) tensors to external data.
    """
    tensors = _get_initializer_tensors(model)
    if convert_attribute:
        tensors = _get_all_tensors(model)

    if all_tensors_to_one_file:
        file_name = str(uuid.uuid1())
        if location:
            if os.path.isabs(location):
                raise ValueError(
                    "location must be a relative path that is relative to the model path."
                )
            file_name = location
        for tensor in tensors:
            if (
                tensor.HasField("raw_data")
                and sys.getsizeof(tensor.raw_data) >= size_threshold
            ):
                set_external_data(tensor, file_name)
    else:
        for tensor in tensors:
            if (
                tensor.HasField("raw_data")
                and sys.getsizeof(tensor.raw_data) >= size_threshold
            ):
                tensor_location = tensor.name
                if not _is_valid_filename(tensor_location):
                    tensor_location = str(uuid.uuid1())
                set_external_data(tensor, tensor_location)
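

# A minimal sketch of the documented flow (the weight-file name is hypothetical):
# mark large initializers as external, then save the model, which writes the
# marked tensors next to the model file.
def _example_convert_and_save(model: ModelProto, output_path: str) -> None:
    import onnx  # imported locally to avoid a circular import at module load time

    convert_model_to_external_data(
        model, all_tensors_to_one_file=True, location="weights.bin", size_threshold=1024
    )
    onnx.save_model(model, output_path)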


def convert_model_from_external_data(model: ModelProto) -> None:
    """Call to set all tensors which use external data as embedded data.

    save_model saves all the tensors data as embedded data after

    calling this function.



    Arguments:

        model (ModelProto): Model to be converted.

    """
    for tensor in _get_all_tensors(model):
        if uses_external_data(tensor):
            if not tensor.HasField("raw_data"):
                raise ValueError("raw_data field doesn't exist.")
            del tensor.external_data[:]
            tensor.data_location = TensorProto.DEFAULT


def save_external_data(tensor: TensorProto, base_path: str) -> None:
    """Writes tensor data to an external file according to information in the `external_data` field.



    Arguments:

        tensor (TensorProto): Tensor object to be serialized

        base_path: System path of a folder where tensor data is to be stored

    """
    info = ExternalDataInfo(tensor)
    external_data_file_path = os.path.join(base_path, info.location)

    # Retrieve the tensor's data from raw_data or load external file
    if not tensor.HasField("raw_data"):
        raise ValueError("raw_data field doesn't exist.")

    # Create file if it doesn't exist
    if not os.path.isfile(external_data_file_path):
        with open(external_data_file_path, "ab"):
            pass

    # Open file for reading and writing at random locations ('r+b')
    with open(external_data_file_path, "r+b") as data_file:
        data_file.seek(0, 2)
        if info.offset is not None:
            # Pad file to required offset if needed
            file_size = data_file.tell()
            if info.offset > file_size:
                data_file.write(b"\0" * (info.offset - file_size))

            data_file.seek(info.offset)
        offset = data_file.tell()
        data_file.write(tensor.raw_data)
        set_external_data(tensor, info.location, offset, data_file.tell() - offset)


def _get_all_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Scan an ONNX model for all tensors and return as an iterator."""
    return chain(
        _get_initializer_tensors(onnx_model_proto),
        _get_attribute_tensors(onnx_model_proto),
    )


def _recursive_attribute_processor(
    attribute: AttributeProto, func: Callable[[GraphProto], Iterable[TensorProto]]
) -> Iterable[TensorProto]:
    """Create an iterator through processing ONNX model attributes with functor."""
    if attribute.type == AttributeProto.GRAPH:
        yield from func(attribute.g)
    if attribute.type == AttributeProto.GRAPHS:
        for graph in attribute.graphs:
            yield from func(graph)


def _get_initializer_tensors_from_graph(
    onnx_model_proto_graph: GraphProto,
) -> Iterable[TensorProto]:
    """Create an iterator of initializer tensors from ONNX model graph."""
    yield from onnx_model_proto_graph.initializer
    for node in onnx_model_proto_graph.node:
        for attribute in node.attribute:
            yield from _recursive_attribute_processor(
                attribute, _get_initializer_tensors_from_graph
            )


def _get_initializer_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Create an iterator of initializer tensors from ONNX model."""
    yield from _get_initializer_tensors_from_graph(onnx_model_proto.graph)


def _get_attribute_tensors_from_graph(
    onnx_model_proto_graph: GraphProto,
) -> Iterable[TensorProto]:
    """Create an iterator of tensors from node attributes of an ONNX model graph."""
    for node in onnx_model_proto_graph.node:
        for attribute in node.attribute:
            if attribute.HasField("t"):
                yield attribute.t
            yield from attribute.tensors
            yield from _recursive_attribute_processor(
                attribute, _get_attribute_tensors_from_graph
            )


def _get_attribute_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Create an iterator of tensors from node attributes of an ONNX model."""
    yield from _get_attribute_tensors_from_graph(onnx_model_proto.graph)


def _is_valid_filename(filename: str) -> bool:
    """Utility to check whether the provided filename is valid."""
    exp = re.compile('^[^<>:;,?"*|/]+$')
    match = exp.match(filename)
    return bool(match)


def uses_external_data(tensor: TensorProto) -> bool:
    """Returns true if the tensor stores data in an external location."""
    return (  # type: ignore[no-any-return]
        tensor.HasField("data_location")
        and tensor.data_location == TensorProto.EXTERNAL
    )


def remove_external_data_field(tensor: TensorProto, field_key: str) -> None:
    """Removes a field from a Tensor's external_data key-value store.



    Modifies tensor object in place.



    Arguments:

        tensor (TensorProto): Tensor object from which value will be removed

        field_key (string): The key of the field to be removed

    """
    for i, field in enumerate(tensor.external_data):
        if field.key == field_key:
            del tensor.external_data[i]


def write_external_data_tensors(model: ModelProto, filepath: str) -> ModelProto:
    """Serializes data for all the tensors which have data location set to TensorProto.External.



    Note: This function also strips basepath information from all tensors' external_data fields.



    Arguments:

        model (ModelProto): Model object which is the source of tensors to serialize.

        filepath: System path to the directory which should be treated as base path for external data.



    Returns:

        ModelProto: The modified model object.

    """
    for tensor in _get_all_tensors(model):
        # Writing to external data happens in 2 passes:
        # 1. Tensors with raw data which pass the necessary conditions (size threshold etc) are marked for serialization
        # 2. The raw data in these tensors is serialized to a file
        # Thus serialize only if tensor has raw data and it was marked for serialization
        if uses_external_data(tensor) and tensor.HasField("raw_data"):
            save_external_data(tensor, filepath)
            tensor.ClearField("raw_data")

    return model
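

# A minimal sketch of the two-pass flow noted above (directory and file names are
# hypothetical): mark every raw-data tensor for external storage, write the bytes
# into the target directory, then serialize the slimmed-down model yourself.
def _example_manual_external_write(model: ModelProto, model_dir: str) -> None:
    convert_model_to_external_data(model, location="weights.bin", size_threshold=0)
    write_external_data_tensors(model, model_dir)
    with open(os.path.join(model_dir, "model.onnx"), "wb") as model_file:
        model_file.write(model.SerializeToString())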