Spaces:
Sleeping
Sleeping
File size: 11,910 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 |
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
import os
import re
import sys
import uuid
from itertools import chain
from typing import Callable, Iterable, Optional
import onnx.onnx_cpp2py_export.checker as c_checker
from onnx.onnx_pb import AttributeProto, GraphProto, ModelProto, TensorProto
class ExternalDataInfo:
    """Parsed view of a TensorProto's `external_data` key/value entries.

    Well-known keys (`location`, `offset`, `length`, `checksum`, `basepath`)
    get defaults; every entry present on the tensor is attached as an
    attribute verbatim via ``setattr``, then `offset`/`length` are normalized
    from their string form to ``int``.
    """

    def __init__(self, tensor: TensorProto) -> None:
        # Defaults for the standard keys.
        self.location = ""
        self.basepath = ""
        self.offset = None
        self.length = None
        self.checksum = None
        # Overlay whatever entries the tensor actually carries.
        for entry in tensor.external_data:
            setattr(self, entry.key, entry.value)
        # Entries arrive as strings; convert the numeric ones when non-empty.
        for int_key in ("offset", "length"):
            raw = getattr(self, int_key)
            if raw:
                setattr(self, int_key, int(raw))
def load_external_data_for_tensor(tensor: TensorProto, base_dir: str) -> None:
    """Loads data from an external file for tensor.

    Ideally TensorProto should not hold any raw data but if it does it will be ignored.

    Arguments:
        tensor: a TensorProto object.
        base_dir: directory that contains the external data.
    """
    info = ExternalDataInfo(tensor)
    # Resolution (and path validation) is delegated to the C++ checker helper.
    external_data_file_path = c_checker._resolve_external_data_location(  # type: ignore[attr-defined]
        base_dir, info.location, tensor.name
    )
    with open(external_data_file_path, "rb") as data_file:
        # A non-zero offset positions the read at this tensor's slice of the file.
        if info.offset:
            data_file.seek(info.offset)
        # With no recorded length, the tensor owns the rest of the file.
        tensor.raw_data = (
            data_file.read(info.length) if info.length else data_file.read()
        )
def load_external_data_for_model(model: ModelProto, base_dir: str) -> None:
    """Loads external tensors into model

    Arguments:
        model: ModelProto to load external data to
        base_dir: directory that contains external data
    """
    for tensor in _get_all_tensors(model):
        if not uses_external_data(tensor):
            continue
        load_external_data_for_tensor(tensor, base_dir)
        # Raw data now lives inline, so flip the location flag back to
        # default and drop the now-stale external_data entries.
        tensor.data_location = TensorProto.DEFAULT
        del tensor.external_data[:]
def set_external_data(
    tensor: TensorProto,
    location: str,
    offset: Optional[int] = None,
    length: Optional[int] = None,
    checksum: Optional[str] = None,
    basepath: Optional[str] = None,
) -> None:
    """Marks the tensor as externally stored by filling its `external_data` field.

    Note: this only records the metadata; the raw bytes themselves are written
    separately (see `save_external_data`).

    Arguments:
        tensor (TensorProto): Tensor to be marked. Must have a raw_data field.
        location: Path of the external data file, relative to the model.
        offset: Byte offset of this tensor's data inside the external file.
        length: Number of bytes of this tensor's data in the external file.
        checksum: Optional checksum of the stored data.
        basepath: Optional base path for the external file.

    Raises:
        ValueError: if the tensor has no raw_data field.
    """
    if not tensor.HasField("raw_data"):
        # BUGFIX: the original message was missing the space before "does",
        # producing e.g. "Tensor foodoes not have raw_data field...".
        raise ValueError(
            "Tensor "
            + tensor.name
            + " does not have raw_data field. Cannot set external data for this tensor."
        )
    # Rebuild external_data from scratch so no stale entries linger.
    del tensor.external_data[:]
    tensor.data_location = TensorProto.EXTERNAL
    # Only keys with an actual value are recorded; all values are stored as strings.
    for k, v in {
        "location": location,
        "offset": int(offset) if offset is not None else None,
        "length": int(length) if length is not None else None,
        "checksum": checksum,
        "basepath": basepath,
    }.items():
        if v is not None:
            entry = tensor.external_data.add()
            entry.key = k
            entry.value = str(v)
def convert_model_to_external_data(
    model: ModelProto,
    all_tensors_to_one_file: bool = True,
    location: Optional[str] = None,
    size_threshold: int = 1024,
    convert_attribute: bool = False,
) -> None:
    """Call to set all tensors with raw data as external data. This call should precede 'save_model'.
    'save_model' saves all the tensors data as external data after calling this function.

    Arguments:
        model (ModelProto): Model to be converted.
        all_tensors_to_one_file (bool): If true, save all tensors to one external file specified by location.
            If false, save each tensor to a file named with the tensor name.
        location: specify the external file relative to the model that all tensors to save to.
            Path is relative to the model path.
            If not specified, will use the model name.
        size_threshold: Threshold for size of data. Only when tensor's data is >= the size_threshold
            it will be converted to external data. To convert every tensor with raw data to external data set size_threshold=0.
        convert_attribute (bool): If true, convert all tensors to external data
            If false, convert only non-attribute tensors to external data
    """
    if convert_attribute:
        tensors = _get_all_tensors(model)
    else:
        tensors = _get_initializer_tensors(model)

    # Resolve the single shared file name up front (None means one file per tensor).
    shared_file_name = None
    if all_tensors_to_one_file:
        if location:
            if os.path.isabs(location):
                raise ValueError(
                    "location must be a relative path that is relative to the model path."
                )
            shared_file_name = location
        else:
            shared_file_name = str(uuid.uuid1())

    for tensor in tensors:
        # NOTE: sys.getsizeof measures the Python bytes object (payload plus
        # object overhead), not len(raw_data) alone.
        if not (
            tensor.HasField("raw_data")
            and sys.getsizeof(tensor.raw_data) >= size_threshold
        ):
            continue
        if shared_file_name is not None:
            set_external_data(tensor, shared_file_name)
        else:
            # One file per tensor, named after the tensor when the name is
            # filesystem-safe; otherwise fall back to a generated UUID.
            file_name = tensor.name
            if not _is_valid_filename(file_name):
                file_name = str(uuid.uuid1())
            set_external_data(tensor, file_name)
def convert_model_from_external_data(model: ModelProto) -> None:
    """Call to set all tensors which use external data as embedded data.
    save_model saves all the tensors data as embedded data after
    calling this function.

    Arguments:
        model (ModelProto): Model to be converted.
    """
    for tensor in _get_all_tensors(model):
        if not uses_external_data(tensor):
            continue
        # The raw bytes must already be present inline (e.g. loaded via
        # load_external_data_for_model) before the metadata can be dropped.
        if not tensor.HasField("raw_data"):
            raise ValueError("raw_data field doesn't exist.")
        del tensor.external_data[:]
        tensor.data_location = TensorProto.DEFAULT
def save_external_data(tensor: TensorProto, base_path: str) -> None:
    """Writes tensor data to an external file according to information in the `external_data` field.

    After writing, the tensor's `external_data` entries are refreshed with the
    actual offset and length used.

    Arguments:
        tensor (TensorProto): Tensor object to be serialized
        base_path: System path of a folder where tensor data is to be stored
    """
    info = ExternalDataInfo(tensor)
    target_path = os.path.join(base_path, info.location)
    # The bytes to write must live in raw_data.
    if not tensor.HasField("raw_data"):
        raise ValueError("raw_data field doesn't exist.")
    # Touch the file so it can be opened in update ("r+b") mode below.
    if not os.path.isfile(target_path):
        with open(target_path, "ab"):
            pass
    # Open for reading and writing at arbitrary positions.
    with open(target_path, "r+b") as data_file:
        data_file.seek(0, 2)  # default: append at end of file
        if info.offset is not None:
            # Zero-pad the file when it is shorter than the requested offset,
            # then position the write there.
            current_size = data_file.tell()
            if info.offset > current_size:
                data_file.write(b"\0" * (info.offset - current_size))
            data_file.seek(info.offset)
        start = data_file.tell()
        data_file.write(tensor.raw_data)
        # Record where the bytes actually landed.
        set_external_data(tensor, info.location, start, data_file.tell() - start)
def _get_all_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Scan an ONNX model for all tensors and return as an iterator."""
    # Initializers first, then tensors embedded in node attributes.
    yield from _get_initializer_tensors(onnx_model_proto)
    yield from _get_attribute_tensors(onnx_model_proto)
def _recursive_attribute_processor(
    attribute: AttributeProto, func: Callable[[GraphProto], Iterable[TensorProto]]
) -> Iterable[TensorProto]:
    """Create an iterator through processing ONNX model attributes with functor.

    Applies `func` to any subgraph(s) the attribute holds (GRAPH / GRAPHS)
    and yields whatever tensors it produces.
    """
    if attribute.type == AttributeProto.GRAPH:
        yield from func(attribute.g)
    elif attribute.type == AttributeProto.GRAPHS:
        for subgraph in attribute.graphs:
            yield from func(subgraph)
def _get_initializer_tensors_from_graph(
onnx_model_proto_graph: GraphProto,
) -> Iterable[TensorProto]:
"""Create an iterator of initializer tensors from ONNX model graph."""
yield from onnx_model_proto_graph.initializer
for node in onnx_model_proto_graph.node:
for attribute in node.attribute:
yield from _recursive_attribute_processor(
attribute, _get_initializer_tensors_from_graph
)
def _get_initializer_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Create an iterator of initializer tensors from ONNX model."""
    # Delegate to the graph-level walker starting from the main graph.
    return _get_initializer_tensors_from_graph(onnx_model_proto.graph)
def _get_attribute_tensors_from_graph(
onnx_model_proto_graph: GraphProto,
) -> Iterable[TensorProto]:
"""Create an iterator of tensors from node attributes of an ONNX model graph."""
for node in onnx_model_proto_graph.node:
for attribute in node.attribute:
if attribute.HasField("t"):
yield attribute.t
yield from attribute.tensors
yield from _recursive_attribute_processor(
attribute, _get_attribute_tensors_from_graph
)
def _get_attribute_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Create an iterator of tensors from node attributes of an ONNX model."""
    # Delegate to the graph-level walker starting from the main graph.
    return _get_attribute_tensors_from_graph(onnx_model_proto.graph)
def _is_valid_filename(filename: str) -> bool:
"""Utility to check whether the provided filename is valid."""
exp = re.compile('^[^<>:;,?"*|/]+$')
match = exp.match(filename)
return bool(match)
def uses_external_data(tensor: TensorProto) -> bool:
    """Returns true if the tensor stores data in an external location."""
    # data_location must be explicitly present AND set to EXTERNAL.
    if not tensor.HasField("data_location"):
        return False
    return tensor.data_location == TensorProto.EXTERNAL  # type: ignore[no-any-return]
def remove_external_data_field(tensor: TensorProto, field_key: str) -> None:
    """Removes a field from a Tensor's external_data key-value store.

    Modifies tensor object in place.

    Arguments:
        tensor (TensorProto): Tensor object from which value will be removed
        field_key (string): The key of the field to be removed
    """
    # BUGFIX: iterate indices in reverse so a deletion does not shift the
    # positions of entries not yet examined. The previous forward
    # enumerate-and-delete skipped the element following each removal, so
    # consecutive entries with the same key could survive.
    for i in reversed(range(len(tensor.external_data))):
        if tensor.external_data[i].key == field_key:
            del tensor.external_data[i]
def write_external_data_tensors(model: ModelProto, filepath: str) -> ModelProto:
    """Serializes data for all the tensors which have data location set to TensorProto.External.

    Note: This function also strips basepath information from all tensors' external_data fields.

    Arguments:
        model (ModelProto): Model object which is the source of tensors to serialize.
        filepath: System path to the directory which should be treated as base path for external data.

    Returns:
        ModelProto: The modified model object.
    """
    for tensor in _get_all_tensors(model):
        # Serialization happens in 2 passes: tensors are first *marked* for
        # external storage (set_external_data), then written here. Only write
        # tensors that were marked AND still hold their raw bytes.
        if not (uses_external_data(tensor) and tensor.HasField("raw_data")):
            continue
        save_external_data(tensor, filepath)
        # Once the bytes are on disk, drop the in-memory copy.
        tensor.ClearField("raw_data")
    return model
|