# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
import collections.abc
import numbers
import struct
from cmath import isnan
from typing import (
    Any,
    Callable,
    Dict,
    KeysView,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    cast,
)

import google.protobuf.message
import numpy as np

from onnx import (
    IR_VERSION,
    AttributeProto,
    FunctionProto,
    GraphProto,
    MapProto,
    ModelProto,
    NodeProto,
    OperatorSetIdProto,
    OptionalProto,
    SequenceProto,
    SparseTensorProto,
    TensorProto,
    TensorShapeProto,
    TrainingInfoProto,
    TypeProto,
    ValueInfoProto,
    defs,
    mapping,
    subbyte,
)

VersionRowType = Union[Tuple[str, int, int, int], Tuple[str, int, int, int, int]]
VersionTableType = List[VersionRowType]
AssignmentBindingType = List[Tuple[str, str]]

# This is a copy of the documented versions in https://github.com/onnx/onnx/blob/main/docs/Versioning.md#released-versions
# Both must be updated whenever a new version of ONNX is released.
VERSION_TABLE: VersionTableType = [
    # Release-version, IR version, ai.onnx version, ai.onnx.ml version, (optional) ai.onnx.training version
    ("1.0", 3, 1, 1),
    ("1.1", 3, 5, 1),
    ("1.1.2", 3, 6, 1),
    ("1.2", 3, 7, 1),
    ("1.3", 3, 8, 1),
    ("1.4.1", 4, 9, 1),
    ("1.5.0", 5, 10, 1),
    ("1.6.0", 6, 11, 2),
    ("1.7.0", 7, 12, 2, 1),
    ("1.8.0", 7, 13, 2, 1),
    ("1.8.1", 7, 13, 2, 1),
    ("1.9.0", 7, 14, 2, 1),
    ("1.10.0", 8, 15, 2, 1),
    ("1.10.1", 8, 15, 2, 1),
    ("1.10.2", 8, 15, 2, 1),
    ("1.11.0", 8, 16, 3, 1),
    ("1.12.0", 8, 17, 3, 1),
    ("1.13.0", 8, 18, 3, 1),
    ("1.13.1", 8, 18, 3, 1),
    ("1.14.0", 9, 19, 3, 1),
    ("1.14.1", 9, 19, 3, 1),
    ("1.15.0", 9, 20, 4, 1),
    ("1.16.0", 10, 21, 5, 1),
]

VersionMapType = Dict[Tuple[str, int], int]


def create_op_set_id_version_map(table: VersionTableType) -> VersionMapType:
    """Create a map from (opset-domain, opset-version) to ir-version from the table above."""
    result: VersionMapType = {}

    def process(release_version: str, ir_version: int, *args: Any) -> None:
        del release_version  # Unused
        for pair in zip(["ai.onnx", "ai.onnx.ml", "ai.onnx.training"], args):
            if pair not in result:
                result[pair] = ir_version
            if pair[0] == "ai.onnx.training":
                result["ai.onnx.preview.training", pair[1]] = ir_version

    for row in table:
        process(*row)
    return result


OP_SET_ID_VERSION_MAP = create_op_set_id_version_map(VERSION_TABLE)


def find_min_ir_version_for(
    opsetidlist: Sequence[OperatorSetIdProto], ignore_unknown: bool = False
) -> int:
    """Given a list of opset ids, determine the minimum IR version required.

    Args:
        opsetidlist: A sequence of OperatorSetIdProto.
        ignore_unknown: If True, ignore unknown domains and return the default
            minimum version for those domains.

    Returns:
        The minimum IR version required (integer)
    """
    default_min_version = 3

    def find_min(domain: Union[str, None], version: int) -> int:
        key = (domain or "ai.onnx", version)
        if key in OP_SET_ID_VERSION_MAP:
            return OP_SET_ID_VERSION_MAP[key]
        if ignore_unknown:
            return default_min_version
        raise ValueError("Unsupported opset-version.")

    if opsetidlist:
        return max(find_min(x.domain, x.version) for x in opsetidlist)
    return default_min_version  # if no opsets specified
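

# Hedged usage sketch (illustrative addition, not part of the upstream module):
# per VERSION_TABLE above, an import of the default domain at opset 21 first
# appears in release 1.16.0, which requires IR version 10.
def _example_find_min_ir_version() -> None:
    opset = make_opsetid("", 21)  # "" resolves to the default "ai.onnx" domain
    assert find_min_ir_version_for([opset]) == 10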


def make_node(
    op_type: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    name: Optional[str] = None,
    doc_string: Optional[str] = None,
    domain: Optional[str] = None,
    overload: Optional[str] = None,
    **kwargs: Any,
) -> NodeProto:
    """Construct a NodeProto.

    Args:
        op_type (string): The name of the operator to construct
        inputs (list of string): list of input names
        outputs (list of string): list of output names
        name (string, default None): optional unique identifier for NodeProto
        doc_string (string, default None): optional documentation string for NodeProto
        domain (string, default None): optional domain for NodeProto.
            If it's None, we will just use default domain (which is empty)
        overload (string, default None): optional field, used to
            resolve calls to model-local functions
        **kwargs (dict): the attributes of the node. The acceptable values
            are documented in :func:`make_attribute`.

    Returns:
        NodeProto
    """
    node = NodeProto()
    node.op_type = op_type
    node.input.extend(inputs)
    node.output.extend(outputs)
    if name:
        node.name = name
    if doc_string:
        node.doc_string = doc_string
    if domain is not None:
        node.domain = domain
    if overload is not None:
        node.overload = overload
    if kwargs:
        node.attribute.extend(
            make_attribute(key, value)
            for key, value in sorted(kwargs.items())
            if value is not None
        )
    return node
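

# Hedged usage sketch (illustrative addition, not part of the upstream module):
# attribute keyword arguments are routed through make_attribute, so plain
# Python values such as axis=0 work directly.
def _example_make_node() -> NodeProto:
    return make_node(
        "Concat", inputs=["a", "b"], outputs=["c"], name="concat_0", axis=0
    )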


def make_operatorsetid(
    domain: str,
    version: int,
) -> OperatorSetIdProto:
    """Construct an OperatorSetIdProto.

    Args:
        domain (string): The domain of the operator set id
        version (integer): Version of operator set id
    Returns:
        OperatorSetIdProto
    """
    operatorsetid = OperatorSetIdProto()
    operatorsetid.domain = domain
    operatorsetid.version = version
    return operatorsetid


def make_graph(
    nodes: Sequence[NodeProto],
    name: str,
    inputs: Sequence[ValueInfoProto],
    outputs: Sequence[ValueInfoProto],
    initializer: Optional[Sequence[TensorProto]] = None,
    doc_string: Optional[str] = None,
    value_info: Optional[Sequence[ValueInfoProto]] = None,
    sparse_initializer: Optional[Sequence[SparseTensorProto]] = None,
) -> GraphProto:
    """Construct a GraphProto

    Args:
        nodes: list of NodeProto
        name (string): graph name
        inputs: list of ValueInfoProto
        outputs: list of ValueInfoProto
        initializer: list of TensorProto
        doc_string (string): graph documentation
        value_info: list of ValueInfoProto
        sparse_initializer: list of SparseTensorProto
    Returns:
        GraphProto
    """
    if initializer is None:
        initializer = []
    if sparse_initializer is None:
        sparse_initializer = []
    if value_info is None:
        value_info = []
    graph = GraphProto()
    graph.node.extend(nodes)
    graph.name = name
    graph.input.extend(inputs)
    graph.output.extend(outputs)
    graph.initializer.extend(initializer)
    graph.sparse_initializer.extend(sparse_initializer)
    graph.value_info.extend(value_info)
    if doc_string:
        graph.doc_string = doc_string
    return graph
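

# Hedged usage sketch (illustrative addition, not part of the upstream module):
# a one-node graph wired up with value infos from make_tensor_value_info
# (defined later in this module; resolved at call time).
def _example_make_graph() -> GraphProto:
    x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2])
    y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 2])
    relu = make_node("Relu", ["x"], ["y"])
    return make_graph([relu], "relu_graph", inputs=[x], outputs=[y])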


def make_opsetid(domain: str, version: int) -> OperatorSetIdProto:
    """Construct an OperatorSetIdProto.

    Args:
        domain (string): The domain of the operator set id
        version (integer): Version of operator set id
    Returns:
        OperatorSetIdProto
    """
    opsetid = OperatorSetIdProto()
    opsetid.domain = domain
    opsetid.version = version
    return opsetid


def make_function(
    domain: str,
    fname: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    nodes: Sequence[NodeProto],
    opset_imports: Sequence[OperatorSetIdProto],
    attributes: Optional[Sequence[str]] = None,
    attribute_protos: Optional[Sequence[AttributeProto]] = None,
    doc_string: Optional[str] = None,
    overload: Optional[str] = None,
    value_info: Optional[Sequence[ValueInfoProto]] = None,
) -> FunctionProto:
    """Construct a FunctionProto from the given nodes, formal parameters, attributes and opset imports."""
    if attributes is None:
        attributes = []
    if attribute_protos is None:
        attribute_protos = []
    if value_info is None:
        value_info = []
    f = FunctionProto()
    f.domain = domain
    f.name = fname
    f.input.extend(inputs)
    f.output.extend(outputs)
    f.node.extend(nodes)
    f.opset_import.extend(opset_imports)
    f.attribute.extend(attributes)
    f.attribute_proto.extend(attribute_protos)
    if doc_string:
        f.doc_string = doc_string
    if overload is not None:
        f.overload = overload
    f.value_info.extend(value_info)
    return f
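

# Hedged usage sketch (illustrative addition, not part of the upstream module):
# a model-local "Square" function in a hypothetical "custom" domain.
def _example_make_function() -> FunctionProto:
    square = make_node("Mul", ["x", "x"], ["y"])
    return make_function(
        domain="custom",
        fname="Square",
        inputs=["x"],
        outputs=["y"],
        nodes=[square],
        opset_imports=[make_opsetid("", 21)],
    )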


def make_model(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Construct a ModelProto

    Args:
        graph (GraphProto): the graph to wrap, e.g. as returned by *make_graph*
        **kwargs: any attribute to add to the returned instance
    Returns:
        ModelProto
    """
    model = ModelProto()
    # Touch model.ir_version so it is stored as the version from which it is
    # generated.
    model.ir_version = IR_VERSION
    model.graph.CopyFrom(graph)

    opset_imports: Optional[Sequence[OperatorSetIdProto]] = None
    opset_imports = kwargs.pop("opset_imports", None)  # type: ignore
    if opset_imports is not None:
        model.opset_import.extend(opset_imports)
    else:
        # Default import
        imp = model.opset_import.add()
        imp.version = defs.onnx_opset_version()

    functions: Optional[Sequence[FunctionProto]] = None
    functions = kwargs.pop("functions", None)  # type: ignore
    if functions is not None:
        model.functions.extend(functions)

    for k, v in kwargs.items():
        # TODO: Does this work with repeated fields?
        setattr(model, k, v)
    return model


# An extension of make_model that infers an IR_VERSION for the model,
# if not specified, on a best-effort basis.
def make_model_gen_version(graph: GraphProto, **kwargs: Any) -> ModelProto:
    ir_version_field = "ir_version"
    if ir_version_field not in kwargs:
        opset_imports_field = "opset_imports"
        imports = kwargs.get(opset_imports_field, [])
        kwargs[ir_version_field] = find_min_ir_version_for(imports)
    return make_model(graph, **kwargs)
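

# Hedged usage sketch (illustrative addition, not part of the upstream module):
# assembling a complete model; make_model_gen_version picks the minimum IR
# version that matches the declared opset imports, and leftover kwargs such as
# producer_name are set directly on the ModelProto.
def _example_make_model() -> ModelProto:
    x = make_tensor_value_info("x", TensorProto.FLOAT, [1])
    y = make_tensor_value_info("y", TensorProto.FLOAT, [1])
    graph = make_graph([make_node("Relu", ["x"], ["y"])], "g", [x], [y])
    return make_model_gen_version(
        graph, opset_imports=[make_opsetid("", 21)], producer_name="example"
    )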


def set_metadata_props(
    proto: Union[
        ModelProto, GraphProto, FunctionProto, NodeProto, TensorProto, ValueInfoProto
    ],
    dict_value: Dict[str, str],
) -> None:
    """Replace the metadata_props field of the given proto with the key/value pairs of *dict_value*."""
    del proto.metadata_props[:]
    for k, v in dict_value.items():
        entry = proto.metadata_props.add()
        entry.key = k
        entry.value = v


def set_model_props(model: ModelProto, dict_value: Dict[str, str]) -> None:
    """Set the metadata_props field of a ModelProto from a string dictionary."""
    set_metadata_props(model, dict_value)


def split_complex_to_pairs(ca: Sequence[np.complex64]) -> Sequence[int]:
    """Flatten a sequence of complex values into interleaved (real, imaginary) components."""
    return [
        (ca[i // 2].real if (i % 2 == 0) else ca[i // 2].imag)  # type: ignore[misc]
        for i in range(len(ca) * 2)
    ]


# Convert a float32 value to a bfloat16 (as int).
# By default, this conversion rounds-to-nearest-even and supports NaN.
# Setting `truncate` to True enables a simpler conversion. In this mode the
# conversion is performed by simply dropping the 2 least significant bytes of
# the significand. In this mode an error of up to 1 bit may be introduced and
# preservation of NaN values is not guaranteed.
def float32_to_bfloat16(fval: float, truncate: bool = False) -> int:
    ival = int.from_bytes(struct.pack("<f", fval), "little")
    if truncate:
        return ival >> 16
    # NaN requires at least 1 significand bit set
    if isnan(fval):
        return 0x7FC0  # sign=0, exp=all-ones, sig=0b1000000
    # drop bottom 16-bits
    # round remaining bits using round-to-nearest-even
    rounded = ((ival >> 16) & 1) + 0x7FFF
    return (ival + rounded) >> 16
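

# Hedged sanity check (illustrative addition, not part of the upstream module):
# 1.0 is exactly representable in bfloat16; its 16-bit pattern is 0x3F80, the
# top half of the float32 pattern 0x3F800000, so both modes agree here.
def _example_bfloat16() -> None:
    assert float32_to_bfloat16(1.0) == 0x3F80
    assert float32_to_bfloat16(1.0, truncate=True) == 0x3F80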


def float32_to_float8e4m3(  # noqa: PLR0911
    fval: float,
    scale: float = 1.0,
    fn: bool = True,
    uz: bool = False,
    saturate: bool = True,
) -> int:
    """Convert a float32 value to a float8, e4m3 (as int).

    See :ref:`onnx-detail-float8` for technical details.

    Args:
        fval: float to convert
        scale: scale, divide *fval* by *scale* before casting it
        fn: no infinite values
        uz: no negative zero
        saturate: if True, any value out of range, including infinity,
            becomes the maximum value; otherwise, it becomes NaN. The
            description of operator Cast fully describes the differences.

    Returns:
        converted value (as an int)
    """
    if not fn:
        raise NotImplementedError(
            "float32_to_float8e4m3 not implemented with fn=False."
        )
    x = fval / scale
    b = int.from_bytes(struct.pack("<f", np.float32(x)), "little")
    ret = (b & 0x80000000) >> 24  # sign
    if uz:
        if (b & 0x7FC00000) == 0x7FC00000:  # noqa: PLR2004
            return 0x80
        if np.isinf(x):
            if saturate:
                return ret | 127
            return 0x80
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e < 116:  # noqa: PLR2004
            ret = 0
        elif e < 120:  # noqa: PLR2004
            # denormalized number
            ex = e - 119
            if ex >= -2:  # noqa: PLR2004
                ret |= 1 << (2 + ex)
                ret |= m >> (21 - ex)
            elif m > 0:
                ret |= 1
            else:
                ret = 0
            mask = 1 << (20 - ex)
            if m & mask and (
                ret & 1
                or m & (mask - 1) > 0
                or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
            ):
                # rounding
                ret += 1
        elif e < 135:  # noqa: PLR2004
            # normalized number
            ex = e - 119  # 127 - 8
            if ex == 0:
                ret |= 0x4
                ret |= m >> 21
            else:
                ret |= ex << 3
                ret |= m >> 20
            if m & 0x80000 and ((m & 0x100000) or (m & 0x7FFFF)):
                if (ret & 0x7F) < 0x7F:  # noqa: PLR2004
                    # rounding
                    ret += 1
                elif not saturate:
                    return 0x80
        elif saturate:
            ret |= 0x7F  # 01111111
        else:
            ret = 0x80
        return int(ret)
    else:
        if (b & 0x7FC00000) == 0x7FC00000:  # noqa: PLR2004
            return 0x7F | ret
        if np.isinf(x):
            if saturate:
                return ret | 126
            return 0x7F | ret
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e != 0:
            if e < 117:  # noqa: PLR2004
                pass
            elif e < 121:  # noqa: PLR2004
                # denormalized number
                ex = e - 120
                if ex >= -2:  # noqa: PLR2004
                    ret |= 1 << (2 + ex)
                    ret |= m >> (21 - ex)
                elif m > 0:
                    ret |= 1
                mask = 1 << (20 - ex)
                if m & mask and (
                    ret & 1
                    or m & (mask - 1) > 0
                    or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
                ):
                    # rounding
                    ret += 1
            elif e < 136:  # noqa: PLR2004
                # normalized number
                ex = e - 120
                if ex == 0:
                    ret |= 0x4
                    ret |= m >> 21
                else:
                    ret |= ex << 3
                    ret |= m >> 20
                    if (ret & 0x7F) == 0x7F:  # noqa: PLR2004
                        ret &= 0xFE
                if (m & 0x80000) and ((m & 0x100000) or (m & 0x7FFFF)):
                    if (ret & 0x7F) < 0x7E:  # noqa: PLR2004
                        # rounding
                        ret += 1
                    elif not saturate:
                        ret |= 0x7F
            elif saturate:
                ret |= 126  # 01111110
            else:
                ret |= 0x7F
        return int(ret)
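

# Hedged sanity check (illustrative addition, not part of the upstream module):
# 1.0 encodes in e4m3fn as exponent 7 (bias 7) with a zero mantissa, i.e. the
# bit pattern 0b0_0111_000 == 0x38.
def _example_float8e4m3() -> None:
    assert float32_to_float8e4m3(1.0) == 0x38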


def float32_to_float8e5m2(  # noqa: PLR0911
    fval: float,
    scale: float = 1.0,
    fn: bool = False,
    uz: bool = False,
    saturate: bool = True,
) -> int:
    """Convert a float32 value to a float8, e5m2 (as int).

    Args:
        fval: float to convert
        scale: scale, divide *fval* by *scale* before casting it
        fn: no infinite values
        uz: no negative zero
        saturate: if True, any value out of range, including infinity,
            becomes the maximum value; otherwise, it becomes NaN. The
            description of operator Cast fully describes the differences.

    Returns:
        converted value (as an int)
    """
    x = fval / scale
    b = int.from_bytes(struct.pack("<f", np.float32(x)), "little")
    ret = (b & 0x80000000) >> 24  # sign
    if fn and uz:
        if (b & 0x7FC00000) == 0x7FC00000:  # noqa: PLR2004
            return 0x80
        if (b & 0x7FFFFFFF) == 0x7F800000:  # noqa: PLR2004
            # inf
            if saturate:
                return ret | 0x7F
            return 0x80
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e < 109:  # noqa: PLR2004
            ret = 0
        elif e < 112:  # noqa: PLR2004
            # denormalized number
            ex = e - 111
            if ex >= -1:
                ret |= 1 << (1 + ex)
                ret |= m >> (22 - ex)
            elif m > 0:
                ret |= 1
            else:
                ret = 0
            mask = 1 << (21 - ex)
            if m & mask and (
                ret & 1
                or m & (mask - 1) > 0
                or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
            ):
                # rounding
                ret += 1
        elif e < 143:  # noqa: PLR2004
            # normalized number
            ex = e - 111
            ret |= ex << 2
            ret |= m >> 21
            if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
                if (ret & 0x7F) < 0x7F:  # noqa: PLR2004
                    # rounding
                    ret += 1
                elif not saturate:
                    ret = 0x80
        elif e == 255 and m == 0:  # inf  # noqa: PLR2004
            ret = 0x80
        elif saturate:
            ret |= 0x7F  # last possible number
        else:
            ret = 0x80
        return int(ret)
    elif not fn and not uz:
        if (b & 0x7FC00000) == 0x7FC00000:  # noqa: PLR2004
            return 0x7F | ret
        if np.isinf(x):
            if saturate:
                return 0x7B | ret
            return 0x7C | ret
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e != 0:
            if e < 110:  # noqa: PLR2004
                pass
            elif e < 113:  # noqa: PLR2004
                # denormalized number
                ex = e - 112
                if ex >= -1:
                    ret |= 1 << (1 + ex)
                    ret |= m >> (22 - ex)
                elif m > 0:
                    ret |= 1
                mask = 1 << (21 - ex)
                if m & mask and (
                    ret & 1
                    or m & (mask - 1) > 0
                    or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
                ):
                    # rounding
                    ret += 1
            elif e < 143:  # noqa: PLR2004
                # normalized number
                ex = e - 112
                ret |= ex << 2
                ret |= m >> 21
                if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
                    if (ret & 0x7F) < 0x7B:  # noqa: PLR2004
                        # rounding
                        ret += 1
                    elif saturate:
                        ret |= 0x7B
                    else:
                        ret |= 0x7C
            elif saturate:
                ret |= 0x7B
            else:
                ret |= 0x7C
        return int(ret)
    else:
        raise NotImplementedError("fn and uz must be both False or True.")


def pack_float32_to_4bit(
    array: Union[np.ndarray, Sequence], signed: bool
) -> np.ndarray:
    """Convert an array of float32 values to a 4-bit data type and pack every two consecutive elements into a byte.

    See :ref:`onnx-detail-int4` for technical details.

    Args:
        array: array of floats to convert and pack
        signed: Whether the 4-bit variant is signed or unsigned

    Returns:
        Packed array with size `ceil(array.size/2)` (single dimension).
    """
    if not isinstance(array, np.ndarray):
        array = np.asarray(array, dtype=np.float32)

    array_flat = array.ravel()
    is_odd_volume = np.prod(array.shape) % 2 == 1
    if is_odd_volume:
        array_flat = np.append(array_flat, np.array([0]))

    single_func = lambda x, y: subbyte.float32x2_to_4bitx2(x, y, signed)  # noqa: E731
    func = np.frompyfunc(single_func, 2, 1)

    arr = func(array_flat[0::2], array_flat[1::2])
    return arr.astype(np.uint8)  # type: ignore[no-any-return]
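

# Hedged usage sketch (illustrative addition, not part of the upstream module):
# an odd-length input is padded with a zero nibble, so three values pack into
# two bytes.
def _example_pack_int4() -> None:
    packed = pack_float32_to_4bit([1.0, -2.0, 3.0], signed=True)
    assert packed.shape == (2,)
    assert packed.dtype == np.uint8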


def make_tensor(
    name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = False
) -> TensorProto:
    """Make a TensorProto with specified arguments. If raw is False, this
    function will choose the corresponding proto field to store the
    values based on data_type. If raw is True, use "raw_data" proto
    field to store the values, and values should be of type bytes in
    this case.

    Args:
        name (string): tensor name
        data_type (int): a value such as onnx.TensorProto.FLOAT
        dims (List[int]): shape
        vals: values
        raw (bool): if True, vals contains the serialized content of the tensor,
            otherwise, vals should be a list of values of the type defined by *data_type*

    Returns:
        TensorProto
    """
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name

    if data_type == TensorProto.STRING and raw:
        raise TypeError("Can not use raw_data to store string type.")

    np_dtype = tensor_dtype_to_np_dtype(data_type)

    # Check number of vals specified equals tensor size
    expected_size = 1
    if raw:
        # NumPy doesn't have BFLOAT16. TENSOR_TYPE_MAP maps it to float32, which has the wrong itemsize.
        if data_type == TensorProto.BFLOAT16:
            expected_size = 2
        elif data_type in (
            TensorProto.FLOAT8E4M3FN,
            TensorProto.FLOAT8E4M3FNUZ,
            TensorProto.FLOAT8E5M2,
            TensorProto.FLOAT8E5M2FNUZ,
        ):
            expected_size = 1
        # NumPy doesn't have INT4. It is packed in couples to UINT8 buffers.
        elif data_type in (TensorProto.UINT4, TensorProto.INT4):
            expected_size = 0.5  # type: ignore[assignment]
        else:
            expected_size = np_dtype.itemsize

    if type(vals) is np.ndarray and len(vals.shape) > 1:
        vals = vals.flatten()
    for d in dims:
        expected_size *= d

    if len(vals) != expected_size:
        # padding of half a byte is acceptable for 4bit types
        if not (
            data_type in (TensorProto.UINT4, TensorProto.INT4)
            and len(vals) == expected_size + 0.5
        ):
            raise ValueError(
                f"Number of values does not match tensor's size. Expected {expected_size}, but it is {len(vals)}. "
            )

    if raw:
        tensor.raw_data = vals
    else:
        if data_type in (TensorProto.COMPLEX64, TensorProto.COMPLEX128):
            vals = split_complex_to_pairs(vals)
        elif data_type == TensorProto.FLOAT16:
            vals = (
                np.array(vals).astype(np_dtype).view(dtype=np.uint16).flatten().tolist()
            )
        elif data_type in (
            TensorProto.BFLOAT16,
            TensorProto.FLOAT8E4M3FN,
            TensorProto.FLOAT8E4M3FNUZ,
            TensorProto.FLOAT8E5M2,
            TensorProto.FLOAT8E5M2FNUZ,
        ):
            fcast = {
                TensorProto.BFLOAT16: float32_to_bfloat16,
                TensorProto.FLOAT8E4M3FN: float32_to_float8e4m3,
                TensorProto.FLOAT8E4M3FNUZ: lambda *args: float32_to_float8e4m3(  # type: ignore[misc]
                    *args, uz=True
                ),
                TensorProto.FLOAT8E5M2: float32_to_float8e5m2,
                TensorProto.FLOAT8E5M2FNUZ: lambda *args: float32_to_float8e5m2(  # type: ignore[misc]
                    *args, fn=True, uz=True
                ),
            }[
                data_type  # type: ignore[index]
            ]
            vals = list(
                map(  # type: ignore[call-overload]
                    fcast,
                    np.array(vals).astype(np_dtype).flatten().tolist(),
                )
            )
        elif data_type in (
            TensorProto.UINT4,
            TensorProto.INT4,
        ):
            signed = data_type == TensorProto.INT4
            vals = (
                pack_float32_to_4bit(vals, signed=signed)
                .astype(np_dtype)
                .flatten()
                .tolist()
            )
        elif data_type == TensorProto.BOOL:
            vals = np.array(vals).astype(int)
        elif data_type == TensorProto.STRING:
            vals = np.array(vals).astype(bytes)
        field = tensor_dtype_to_field(data_type)
        getattr(tensor, field).extend(vals)
    tensor.dims.extend(dims)
    return tensor
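

# Hedged usage sketch (illustrative addition, not part of the upstream module):
# a 2x2 float tensor; with raw=False the values land in the dtype-specific
# field (here float_data, per tensor_dtype_to_field).
def _example_make_tensor() -> TensorProto:
    return make_tensor("W", TensorProto.FLOAT, [2, 2], [1.0, 2.0, 3.0, 4.0])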


def make_sparse_tensor(
    values: TensorProto, indices: TensorProto, dims: Sequence[int]
) -> SparseTensorProto:
    """Construct a SparseTensorProto

    Args:
        values (TensorProto): the values
        indices (TensorProto): the indices
        dims: the shape
    Returns:
        SparseTensorProto
    """
    sparse = SparseTensorProto()
    sparse.values.CopyFrom(values)
    sparse.indices.CopyFrom(indices)
    sparse.dims.extend(dims)
    return sparse


def make_sequence(
    name: str,
    elem_type: SequenceProto.DataType,
    values: Sequence[Any],
) -> SequenceProto:
    """Make a Sequence with specified value arguments."""
    sequence = SequenceProto()
    sequence.name = name
    sequence.elem_type = elem_type

    if elem_type == SequenceProto.UNDEFINED:
        return sequence

    if elem_type == SequenceProto.TENSOR:
        attribute = sequence.tensor_values
    elif elem_type == SequenceProto.SPARSE_TENSOR:
        attribute = sequence.sparse_tensor_values  # type: ignore[assignment]
    elif elem_type == SequenceProto.SEQUENCE:
        attribute = sequence.sequence_values  # type: ignore[assignment]
    elif elem_type == SequenceProto.MAP:
        attribute = sequence.map_values  # type: ignore[assignment]
    elif elem_type == SequenceProto.OPTIONAL:
        attribute = sequence.optional_values  # type: ignore[assignment]
    else:
        raise TypeError("The element type in the input sequence is not supported.")

    attribute.extend(values)
    return sequence


def make_map(
    name: str, key_type: int, keys: List[Any], values: SequenceProto
) -> MapProto:
    """Make a Map with specified key-value pair arguments.

    Criteria for conversion:
    - Keys and Values must have the same number of elements
    - Every key in keys must be of the same type
    - Every value in values must be of the same type
    """
    map_proto = MapProto()
    valid_key_int_types = [
        TensorProto.INT8,
        TensorProto.INT16,
        TensorProto.INT32,
        TensorProto.INT64,
        TensorProto.UINT8,
        TensorProto.UINT16,
        TensorProto.UINT32,
        TensorProto.UINT64,
    ]
    map_proto.name = name
    map_proto.key_type = key_type
    if key_type == TensorProto.STRING:
        map_proto.string_keys.extend(keys)
    elif key_type in valid_key_int_types:
        map_proto.keys.extend(keys)
    map_proto.values.CopyFrom(values)
    return map_proto


def make_optional(
    name: str,
    elem_type: OptionalProto.DataType,
    value: Optional[Any],
) -> OptionalProto:
    """Make an Optional with specified value arguments."""
    optional = OptionalProto()
    optional.name = name
    optional.elem_type = elem_type

    if elem_type == OptionalProto.UNDEFINED:
        return optional

    if elem_type == OptionalProto.TENSOR:
        attribute = optional.tensor_value
    elif elem_type == OptionalProto.SPARSE_TENSOR:
        attribute = optional.sparse_tensor_value  # type: ignore[assignment]
    elif elem_type == OptionalProto.SEQUENCE:
        attribute = optional.sequence_value  # type: ignore[assignment]
    elif elem_type == OptionalProto.MAP:
        attribute = optional.map_value  # type: ignore[assignment]
    elif elem_type == OptionalProto.OPTIONAL:
        attribute = optional.optional_value  # type: ignore[assignment]
    else:
        raise TypeError("The element type in the input optional is not supported.")

    attribute.CopyFrom(value)  # type: ignore[arg-type]
    return optional
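

# Hedged usage sketch (illustrative addition, not part of the upstream module):
# wrapping a tensor in an OptionalProto; an UNDEFINED elem_type would instead
# produce an empty (None-like) optional.
def _example_make_optional() -> OptionalProto:
    t = make_tensor("t", TensorProto.INT64, [2], [1, 2])
    return make_optional("maybe_t", OptionalProto.TENSOR, t)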


def _to_bytes(value: Union[str, bytes]) -> bytes:
    """Coerce a string (or bytes) value into UTF-8 bytes."""
    return value if isinstance(value, bytes) else value.encode("utf-8")


def make_attribute(
    key: str,
    value: Any,
    doc_string: Optional[str] = None,
    attr_type: Optional[int] = None,
) -> AttributeProto:
    """Makes an AttributeProto based on the value type."""
    attr = AttributeProto()
    attr.name = key
    if doc_string:
        attr.doc_string = doc_string

    # Singular cases
    if isinstance(value, numbers.Integral):
        attr.i = int(value)
        attr.type = AttributeProto.INT
    elif isinstance(value, numbers.Real):
        attr.f = float(value)
        attr.type = AttributeProto.FLOAT
    elif isinstance(value, (str, bytes)):
        # Encode strings into utf-8
        attr.s = _to_bytes(value)
        attr.type = AttributeProto.STRING
    elif isinstance(value, TensorProto):
        attr.t.CopyFrom(value)
        attr.type = AttributeProto.TENSOR
    elif isinstance(value, SparseTensorProto):
        attr.sparse_tensor.CopyFrom(value)
        attr.type = AttributeProto.SPARSE_TENSOR
    elif isinstance(value, GraphProto):
        attr.g.CopyFrom(value)
        attr.type = AttributeProto.GRAPH
    elif isinstance(value, TypeProto):
        attr.tp.CopyFrom(value)
        attr.type = AttributeProto.TYPE_PROTO
    # Iterable cases
    elif isinstance(value, collections.abc.Iterable):
        value = list(value)
        if len(value) == 0 and attr_type is None:
            raise ValueError(
                f"Could not infer attribute `{key}` type from empty iterator"
            )
        if attr_type is None:
            types = {type(v) for v in value}
            for exp_t, exp_enum in (
                (numbers.Integral, AttributeProto.INTS),
                (numbers.Real, AttributeProto.FLOATS),
                ((str, bytes), AttributeProto.STRINGS),
                (TensorProto, AttributeProto.TENSORS),
                (SparseTensorProto, AttributeProto.SPARSE_TENSORS),
                (GraphProto, AttributeProto.GRAPHS),
                (TypeProto, AttributeProto.TYPE_PROTOS),
            ):
                if all(issubclass(t, exp_t) for t in types):  # type: ignore[arg-type]
                    attr_type = exp_enum
                    break
            if attr_type is None:
                raise ValueError(
                    "Could not infer the attribute type from the elements of the passed Iterable value."
                )
        if attr_type == AttributeProto.INTS:
            attr.ints.extend(value)
            attr.type = AttributeProto.INTS
        elif attr_type == AttributeProto.FLOATS:
            attr.floats.extend(value)
            attr.type = AttributeProto.FLOATS
        elif attr_type == AttributeProto.STRINGS:
            attr.strings.extend(_to_bytes(v) for v in value)
            attr.type = AttributeProto.STRINGS
        elif attr_type == AttributeProto.TENSORS:
            attr.tensors.extend(value)
            attr.type = AttributeProto.TENSORS
        elif attr_type == AttributeProto.SPARSE_TENSORS:
            attr.sparse_tensors.extend(value)
            attr.type = AttributeProto.SPARSE_TENSORS
        elif attr_type == AttributeProto.GRAPHS:
            attr.graphs.extend(value)
            attr.type = AttributeProto.GRAPHS
        elif attr_type == AttributeProto.TYPE_PROTOS:
            attr.type_protos.extend(value)
            attr.type = AttributeProto.TYPE_PROTOS
        else:
            raise AssertionError()  # Should not reach since `ValueError` must be raised in attr_type checking
    else:
        raise TypeError(f"'{value}' is not an accepted attribute value.")

    if attr_type is not None and attr.type != attr_type:
        raise TypeError(
            f"Inferred attribute type '{_attr_type_to_str(attr.type)}'({attr.type}) mismatched with specified type '{_attr_type_to_str(attr_type)}'({attr_type})"
        )
    return attr
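

# Hedged sanity check (illustrative addition, not part of the upstream module):
# make_attribute infers the attribute type from the Python value; a mismatch
# with an explicit attr_type raises TypeError.
def _example_make_attribute() -> None:
    assert make_attribute("axis", 1).type == AttributeProto.INT
    assert make_attribute("scales", [1.0, 2.0]).type == AttributeProto.FLOATS
    assert make_attribute("mode", "constant").type == AttributeProto.STRING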


def make_attribute_ref(
    name: str, attr_type: AttributeProto.AttributeType, doc_string: Optional[str] = None
) -> AttributeProto:
    """Make an AttributeProto holding a reference to the parent function's attribute of given name and type."""
    attr = AttributeProto()
    attr.name = name
    attr.type = attr_type
    if doc_string:
        attr.doc_string = doc_string
    return attr


def get_attribute_value(attr: AttributeProto) -> Any:  # noqa: PLR0911
    """Return the Python value stored in an AttributeProto, based on its type."""
    if attr.ref_attr_name:
        raise ValueError(f"Cannot get value of reference attribute: {attr}")
    if attr.type == AttributeProto.FLOAT:
        return attr.f
    if attr.type == AttributeProto.INT:
        return attr.i
    if attr.type == AttributeProto.STRING:
        return attr.s
    if attr.type == AttributeProto.TENSOR:
        return attr.t
    if attr.type == AttributeProto.SPARSE_TENSOR:
        return attr.sparse_tensor
    if attr.type == AttributeProto.GRAPH:
        return attr.g
    if attr.type == AttributeProto.TYPE_PROTO:
        return attr.tp
    if attr.type == AttributeProto.FLOATS:
        return list(attr.floats)
    if attr.type == AttributeProto.INTS:
        return list(attr.ints)
    if attr.type == AttributeProto.STRINGS:
        return list(attr.strings)
    if attr.type == AttributeProto.TENSORS:
        return list(attr.tensors)
    if attr.type == AttributeProto.SPARSE_TENSORS:
        return list(attr.sparse_tensors)
    if attr.type == AttributeProto.GRAPHS:
        return list(attr.graphs)
    if attr.type == AttributeProto.TYPE_PROTOS:
        return list(attr.type_protos)
    if attr.type == AttributeProto.UNDEFINED:
        return None
    raise ValueError(f"Unsupported ONNX attribute: {attr}")


def get_node_attr_value(node: NodeProto, attr_name: str) -> Any:
    """Return the value of the unique node attribute with the given name."""
    matching = [x for x in node.attribute if x.name == attr_name]
    if len(matching) > 1:
        raise ValueError(f"Node has multiple attributes with name {attr_name}")
    if len(matching) < 1:
        raise ValueError(f"Node has no attribute with name {attr_name}")
    return get_attribute_value(matching[0])


def make_empty_tensor_value_info(name: str) -> ValueInfoProto:
    """Make a ValueInfoProto with only the name field set."""
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    return value_info_proto


def make_tensor_type_proto(
    elem_type: int,
    shape: Optional[Sequence[Union[str, int, None]]],
    shape_denotation: Optional[List[str]] = None,
) -> TypeProto:
    """Makes a Tensor TypeProto based on the data type and shape."""
    type_proto = TypeProto()
    tensor_type_proto = type_proto.tensor_type
    tensor_type_proto.elem_type = elem_type
    tensor_shape_proto = tensor_type_proto.shape
    if shape is not None:
        # You might think this is a no-op (extending a normal Python
        # list by [] certainly is), but protobuf lists work a little
        # differently; if a field is never set, it is omitted from the
        # resulting protobuf; a list that is explicitly set to be
        # empty will get an (empty) entry in the protobuf. This
        # difference is visible to our consumers, so make sure we emit
        # an empty shape!
        tensor_shape_proto.dim.extend([])

        if shape_denotation and len(shape_denotation) != len(shape):
            raise ValueError(
                "Invalid shape_denotation. Must be of the same length as shape."
            )

        for i, d in enumerate(shape):
            dim = tensor_shape_proto.dim.add()
            if d is None:
                pass
            elif isinstance(d, int):
                dim.dim_value = d
            elif isinstance(d, str):
                dim.dim_param = d
            else:
                raise ValueError(
                    f"Invalid item in shape: {d}. Needs to be of int or str."
                )
            if shape_denotation:
                dim.denotation = shape_denotation[i]
    return type_proto


def make_tensor_value_info(
    name: str,
    elem_type: int,
    shape: Optional[Sequence[Union[str, int, None]]],
    doc_string: str = "",
    shape_denotation: Optional[List[str]] = None,
) -> ValueInfoProto:
    """Makes a ValueInfoProto based on the data type and shape."""
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    if doc_string:
        value_info_proto.doc_string = doc_string
    tensor_type_proto = make_tensor_type_proto(elem_type, shape, shape_denotation)
    value_info_proto.type.CopyFrom(tensor_type_proto)
    return value_info_proto
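

# Hedged usage sketch (illustrative addition, not part of the upstream module):
# ints are fixed dimensions, strings are symbolic dimensions, and None leaves
# a dimension unknown.
def _example_make_tensor_value_info() -> ValueInfoProto:
    return make_tensor_value_info("x", TensorProto.FLOAT, ["batch", 3, None])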


def make_sparse_tensor_type_proto(
    elem_type: int,
    shape: Optional[Sequence[Union[str, int, None]]],
    shape_denotation: Optional[List[str]] = None,
) -> TypeProto:
    """Makes a SparseTensor TypeProto based on the data type and shape."""
    type_proto = TypeProto()
    sparse_tensor_type_proto = type_proto.sparse_tensor_type
    sparse_tensor_type_proto.elem_type = elem_type
    sparse_tensor_shape_proto = sparse_tensor_type_proto.shape
    if shape is not None:
        # You might think this is a no-op (extending a normal Python
        # list by [] certainly is), but protobuf lists work a little
        # differently; if a field is never set, it is omitted from the
        # resulting protobuf; a list that is explicitly set to be
        # empty will get an (empty) entry in the protobuf. This
        # difference is visible to our consumers, so make sure we emit
        # an empty shape!
        sparse_tensor_shape_proto.dim.extend([])

        if shape_denotation and len(shape_denotation) != len(shape):
            raise ValueError(
                "Invalid shape_denotation. Must be of the same length as shape."
            )

        for i, d in enumerate(shape):
            dim = sparse_tensor_shape_proto.dim.add()
            if d is None:
                pass
            elif isinstance(d, int):
                dim.dim_value = d
            elif isinstance(d, str):
                dim.dim_param = d
            else:
                raise ValueError(
                    f"Invalid item in shape: {d}. Needs to be of int or str."
                )
            if shape_denotation:
                dim.denotation = shape_denotation[i]
    return type_proto


def make_sparse_tensor_value_info(
    name: str,
    elem_type: int,
    shape: Optional[Sequence[Union[str, int, None]]],
    doc_string: str = "",
    shape_denotation: Optional[List[str]] = None,
) -> ValueInfoProto:
    """Makes a SparseTensor ValueInfoProto based on the data type and shape."""
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    if doc_string:
        value_info_proto.doc_string = doc_string
    sparse_tensor_type_proto = make_sparse_tensor_type_proto(
        elem_type, shape, shape_denotation
    )
    value_info_proto.type.sparse_tensor_type.CopyFrom(
        sparse_tensor_type_proto.sparse_tensor_type
    )
    return value_info_proto


def make_sequence_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes a sequence TypeProto."""
    type_proto = TypeProto()
    type_proto.sequence_type.elem_type.CopyFrom(inner_type_proto)
    return type_proto


def make_optional_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes an optional TypeProto."""
    type_proto = TypeProto()
    type_proto.optional_type.elem_type.CopyFrom(inner_type_proto)
    return type_proto


def make_map_type_proto(
    key_type: int,
    value_type: TypeProto,
) -> TypeProto:
    """Makes a map TypeProto."""
    type_proto = TypeProto()
    type_proto.map_type.key_type = key_type
    type_proto.map_type.value_type.CopyFrom(value_type)
    return type_proto


def make_value_info(
    name: str,
    type_proto: TypeProto,
    doc_string: str = "",
) -> ValueInfoProto:
    """Makes a ValueInfoProto with the given type_proto."""
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    if doc_string:
        value_info_proto.doc_string = doc_string
    value_info_proto.type.CopyFrom(type_proto)
    return value_info_proto


def _sanitize_str(s: Union[str, bytes]) -> str:
    if isinstance(s, str):
        sanitized = s
    elif isinstance(s, bytes):
        sanitized = s.decode("utf-8", errors="ignore")
    else:
        sanitized = str(s)
    if len(sanitized) < 64:  # noqa: PLR2004
        return sanitized
    return sanitized[:64] + f"...<+len={(len(sanitized) - 64)}>"


def make_tensor_sequence_value_info(
    name: str,
    elem_type: int,
    shape: Optional[Sequence[Union[str, int, None]]],
    doc_string: str = "",
    elem_shape_denotation: Optional[List[str]] = None,
) -> ValueInfoProto:
    """Makes a ValueInfoProto for a sequence of tensors, based on the element data type and shape."""
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    if doc_string:
        value_info_proto.doc_string = doc_string
    tensor_type_proto = make_tensor_type_proto(elem_type, shape, elem_shape_denotation)
    sequence_type_proto = make_sequence_type_proto(tensor_type_proto)
    value_info_proto.type.sequence_type.CopyFrom(sequence_type_proto.sequence_type)
    return value_info_proto


def printable_attribute(
    attr: AttributeProto, subgraphs: bool = False
) -> Union[str, Tuple[str, List[GraphProto]]]:
    content = []
    content.append(attr.name)
    content.append("=")

    def str_float(f: float) -> str:
        # NB: Different Python versions print different numbers of trailing
        # decimals, specifying this explicitly keeps it consistent for all
        # versions
        return f"{f:.15g}"

    def str_int(i: int) -> str:
        return str(i)

    _T = TypeVar("_T")

    def str_list(str_elem: Callable[[_T], str], xs: Sequence[_T]) -> str:
        return "[" + ", ".join(map(str_elem, xs)) + "]"

    # For now, this logic should continue to work as long as we are running on
    # a proto2 implementation (where HasField is available for these fields).
    # If/when we switch to proto3, we will need to use attr.type.

    # To support printing subgraphs, if we find a graph attribute, print out
    # its name here and pass the graph itself up to the caller for later
    # printing.
    graphs = []
    if attr.HasField("f"):
        content.append(str_float(attr.f))
    elif attr.HasField("i"):
        content.append(str_int(attr.i))
    elif attr.HasField("s"):
        # TODO: Bit nervous about Python 2 / Python 3 determinism implications
        content.append(repr(_sanitize_str(attr.s)))
    elif attr.HasField("t"):
        if len(attr.t.dims) > 0:
            content.append("<Tensor>")
        else:
            # special case to print scalars
            field = tensor_dtype_to_field(attr.t.data_type)
            content.append(f"<Scalar Tensor {getattr(attr.t, field)}>")
    elif attr.HasField("g"):
        content.append(f"<graph {attr.g.name}>")
        graphs.append(attr.g)
    elif attr.HasField("tp"):
        content.append(f"<Type Proto {attr.tp}>")
    elif attr.floats:
        content.append(str_list(str_float, attr.floats))
    elif attr.ints:
        content.append(str_list(str_int, attr.ints))
    elif attr.strings:
        # TODO: Bit nervous about Python 2 / Python 3 determinism implications
        content.append(str(list(map(_sanitize_str, attr.strings))))
    elif attr.tensors:
        content.append("[<Tensor>, ...]")
    elif attr.type_protos:
        content.append("[")
        for i, tp in enumerate(attr.type_protos):
            comma = "," if i != len(attr.type_protos) - 1 else ""
            content.append(f"<Type Proto {tp}>{comma}")
        content.append("]")
    elif attr.graphs:
        content.append("[")
        for i, g in enumerate(attr.graphs):
            comma = "," if i != len(attr.graphs) - 1 else ""
            content.append(f"<graph {g.name}>{comma}")
        content.append("]")
        graphs.extend(attr.graphs)
    else:
        content.append("<Unknown>")
    if subgraphs:
        return " ".join(content), graphs
    return " ".join(content)


def printable_dim(dim: TensorShapeProto.Dimension) -> str:
    which = dim.WhichOneof("value")
    if which is None:
        return "?"
    return str(getattr(dim, which))


def printable_type(t: TypeProto) -> str:
    if t.WhichOneof("value") == "tensor_type":
        s = TensorProto.DataType.Name(t.tensor_type.elem_type)
        if t.tensor_type.HasField("shape"):
            if len(t.tensor_type.shape.dim):
                s += str(", " + "x".join(map(printable_dim, t.tensor_type.shape.dim)))
            else:
                s += ", scalar"
        return s  # type: ignore[no-any-return]
    if t.WhichOneof("value") is None:
        return ""
    return f"Unknown type {t.WhichOneof('value')}"


def printable_value_info(v: ValueInfoProto) -> str:
    s = f"%{v.name}"
    if v.type:
        s = f"{s}[{printable_type(v.type)}]"
    return s


def printable_tensor_proto(t: TensorProto) -> str:
    s = f"%{t.name}["
    s += TensorProto.DataType.Name(t.data_type)
    if t.dims is not None:
        if len(t.dims):
            s += str(", " + "x".join(map(str, t.dims)))
        else:
            s += ", scalar"
    s += "]"
    return s


def printable_node(
    node: NodeProto, prefix: str = "", subgraphs: bool = False
) -> Union[str, Tuple[str, List[GraphProto]]]:
    content = []
    if len(node.output):
        content.append(", ".join([f"%{name}" for name in node.output]))
        content.append("=")
    # To deal with nested graphs
    graphs: List[GraphProto] = []
    printed_attrs = []
    for attr in node.attribute:
        if subgraphs:
            printed_attr_subgraphs = printable_attribute(attr, subgraphs)
            if not isinstance(printed_attr_subgraphs[1], list):
                raise TypeError(
                    f"printed_attr_subgraphs[1] must be an instance of {list}."
                )
            graphs.extend(printed_attr_subgraphs[1])
            printed_attrs.append(printed_attr_subgraphs[0])
        else:
            printed = printable_attribute(attr)
            if not isinstance(printed, str):
                raise TypeError(f"printed must be an instance of {str}.")
            printed_attrs.append(printed)
    printed_attributes = ", ".join(sorted(printed_attrs))
    printed_inputs = ", ".join([f"%{name}" for name in node.input])
    if node.attribute:
        content.append(f"{node.op_type}[{printed_attributes}]({printed_inputs})")
    else:
        content.append(f"{node.op_type}({printed_inputs})")
    if subgraphs:
        return prefix + " ".join(content), graphs
    return prefix + " ".join(content)


def printable_graph(graph: GraphProto, prefix: str = "") -> str:
    """Display a GraphProto as a string.

    Args:
        graph (GraphProto): the graph to display
        prefix (string): prefix of every line
    Returns:
        string
    """
    content = []
    indent = prefix + "  "
    # header
    header = ["graph", graph.name]
    initializers = {t.name for t in graph.initializer}
    if len(graph.input):
        header.append("(")
        in_strs = []  # required inputs
        # optional inputs with initializer providing default value
        in_with_init_strs = []
        for inp in graph.input:
            if inp.name not in initializers:
                in_strs.append(printable_value_info(inp))
            else:
                in_with_init_strs.append(printable_value_info(inp))
        if in_strs:
            content.append(prefix + " ".join(header))
            header = []
            for line in in_strs:
                content.append(prefix + "  " + line)
        header.append(")")

        if in_with_init_strs:
            header.append("optional inputs with matching initializers (")
            content.append(prefix + " ".join(header))
            header = []
            for line in in_with_init_strs:
                content.append(prefix + "  " + line)
            header.append(")")

        # from IR 4 onwards an initializer is not required to have a matching graph input
        # so output the name, type and shape of those as well
        if len(in_with_init_strs) < len(initializers):
            graph_inputs = {i.name for i in graph.input}
            init_strs = [
                printable_tensor_proto(i)
                for i in graph.initializer
                if i.name not in graph_inputs
            ]
            header.append("initializers (")
            content.append(prefix + " ".join(header))
            header = []
            for line in init_strs:
                content.append(prefix + "  " + line)
            header.append(")")

    header.append("{")
    content.append(prefix + " ".join(header))
    graphs: List[GraphProto] = []
    # body
    for node in graph.node:
        contents_subgraphs = printable_node(node, indent, subgraphs=True)
        if not isinstance(contents_subgraphs[1], list):
            raise TypeError(f"contents_subgraphs[1] must be an instance of {list}.")
        content.append(contents_subgraphs[0])
        graphs.extend(contents_subgraphs[1])
    # tail
    tail = ["return"]
    if len(graph.output):
        tail.append(", ".join([f"%{out.name}" for out in graph.output]))
    content.append(indent + " ".join(tail))
    # closing bracket
    content.append(prefix + "}")
    for g in graphs:
        content.append("\n" + printable_graph(g))
    return "\n".join(content)


def strip_doc_string(proto: google.protobuf.message.Message) -> None:
    """Empty the `doc_string` field on the given proto and any nested protobuf messages."""
    if not isinstance(proto, google.protobuf.message.Message):
        raise TypeError(
            f"proto must be an instance of {google.protobuf.message.Message}."
        )
    for descriptor in proto.DESCRIPTOR.fields:
        if descriptor.name == "doc_string":
            proto.ClearField(descriptor.name)
        elif descriptor.type == descriptor.TYPE_MESSAGE:
            if descriptor.label == descriptor.LABEL_REPEATED:
                for x in getattr(proto, descriptor.name):
                    strip_doc_string(x)
            elif proto.HasField(descriptor.name):
                strip_doc_string(getattr(proto, descriptor.name))


def make_training_info(
    algorithm: GraphProto,
    algorithm_bindings: AssignmentBindingType,
    initialization: Optional[GraphProto],
    initialization_bindings: Optional[AssignmentBindingType],
) -> TrainingInfoProto:
    """Construct a TrainingInfoProto from an algorithm graph, an optional initialization graph, and their key/value bindings."""
    training_info = TrainingInfoProto()
    training_info.algorithm.CopyFrom(algorithm)
    for k, v in algorithm_bindings:
        binding = training_info.update_binding.add()
        binding.key = k
        binding.value = v

    if initialization:
        training_info.initialization.CopyFrom(initialization)
    if initialization_bindings:
        for k, v in initialization_bindings:
            binding = training_info.initialization_binding.add()
            binding.key = k
            binding.value = v

    return training_info


# The following functions are used for mapping between type systems.
def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
    """Convert a TensorProto's data_type to corresponding numpy dtype. It can be used while making tensors.

    Args:
        tensor_dtype: TensorProto's data_type
    Returns:
        numpy's data_type
    """
    return mapping.TENSOR_TYPE_MAP[tensor_dtype].np_dtype


def tensor_dtype_to_storage_tensor_dtype(tensor_dtype: int) -> int:
    """Convert a TensorProto's data_type to corresponding data_type for storage.

    Args:
        tensor_dtype: TensorProto's data_type
    Returns:
        data_type for storage
    """
    return mapping.TENSOR_TYPE_MAP[tensor_dtype].storage_dtype


def tensor_dtype_to_string(tensor_dtype: int) -> str:
    """Get the name of given TensorProto's data_type.

    Args:
        tensor_dtype: TensorProto's data_type
    Returns:
        the name of data_type
    """
    return mapping.TENSOR_TYPE_MAP[tensor_dtype].name


def tensor_dtype_to_field(tensor_dtype: int) -> str:
    """Convert a TensorProto's data_type to corresponding field name for storage. It can be used while making tensors.

    Args:
        tensor_dtype: TensorProto's data_type
    Returns:
        field name
    """
    return mapping._STORAGE_TENSOR_TYPE_TO_FIELD[
        mapping.TENSOR_TYPE_MAP[tensor_dtype].storage_dtype
    ]


def np_dtype_to_tensor_dtype(np_dtype: np.dtype) -> int:
    """Convert a numpy dtype to the corresponding tensor type. It can be used while converting numpy arrays to tensors.

    Args:
        np_dtype: numpy's data_type
    Returns:
        TensorProto's data_type
    """
    return cast(
        int,
        mapping._NP_TYPE_TO_TENSOR_TYPE[np_dtype],
    )


def get_all_tensor_dtypes() -> KeysView[int]:
    """Get all tensor types from TensorProto.

    Returns:
        all tensor types from TensorProto
    """
    return mapping.TENSOR_TYPE_MAP.keys()
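

# Hedged sanity check (illustrative addition, not part of the upstream module):
# the dtype mappings are mutual inverses for types NumPy supports natively.
def _example_dtype_mappings() -> None:
    assert tensor_dtype_to_np_dtype(TensorProto.FLOAT) == np.dtype("float32")
    assert np_dtype_to_tensor_dtype(np.dtype("float32")) == TensorProto.FLOAT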


_ATTRIBUTE_TYPE_TO_STR = {k: v for v, k in AttributeProto.AttributeType.items()}


def _attr_type_to_str(attr_type: int) -> str:
    """Convert AttributeProto type to string.

    Args:
        attr_type: AttributeProto type.
    Returns:
        String representing the supplied attr_type.
    """
    if attr_type in AttributeProto.AttributeType.values():
        return _ATTRIBUTE_TYPE_TO_STR[attr_type]  # type: ignore[no-any-return]
    return AttributeProto.AttributeType.keys()[0]  # type: ignore[no-any-return]