Spaces:
Sleeping
Sleeping
# Copyright (c) ONNX Project Contributors | |
# SPDX-License-Identifier: Apache-2.0 | |
from typing import List, Optional, Union | |
import numpy as np | |
from onnx import ( | |
AttributeProto, | |
FunctionProto, | |
GraphProto, | |
ModelProto, | |
NodeProto, | |
SparseTensorProto, | |
TensorProto, | |
) | |
from onnx.helper import ( | |
make_attribute, | |
make_function, | |
make_graph, | |
make_model, | |
make_node, | |
make_tensor, | |
make_tensor_value_info, | |
set_model_props, | |
tensor_dtype_to_np_dtype, | |
) | |
from onnx.numpy_helper import from_array | |
def _replace_constant(
    node: NodeProto, threshold: int, value_constant_of_shape: float
) -> List[NodeProto]:
    """Replaces a Constant node holding a large tensor by a dummy constant of the same shape.

    A *Constant* node whose ``value`` tensor has more than ``threshold``
    elements is replaced by two nodes: a *Constant* producing the original
    shape as an int64 tensor, followed by a *ConstantOfShape* filled with
    ``value_constant_of_shape`` (converted to the original element type)
    which writes the original output.

    Args:
        node: node of type *Constant*
        threshold: tensors with at most this many elements are kept unchanged
        value_constant_of_shape: fill value of the dummy tensor

    Returns:
        ``[node]`` if the node is kept, ``[shape_node, constant_of_shape_node]``
        otherwise

    Raises:
        TypeError: if *node* is not a *Constant*
        NotImplementedError: for a sparse constant or for any attribute other
            than ``value``
    """
    if node.op_type != "Constant":
        raise TypeError(f"Node type must be 'Constant' not {node.op_type!r}.")
    # A Constant node has a single attribute: the first one decides the outcome.
    for att in node.attribute:
        if att.name == "sparse_value":
            raise NotImplementedError(
                f"This feature is not yet implemented for a sparse constant "
                f"(node name={node.name!r})."
            )
        if att.name != "value":
            raise NotImplementedError(
                f"Replacement of constant with attribute {att.name!r}"
            )
        value = att.t
        dims = value.dims
        if np.prod(dims) <= threshold:
            return [node]
        # The tensor stored in the attribute is frequently unnamed; deriving the
        # name from it would give every replaced constant the same "__SHAPE"
        # output. ONNX requires node outputs to be unique, so deriving the name
        # from the output gives distinct Constant nodes distinct shape names.
        new_name = f"{node.output[0]}__SHAPE"
        shape_tensor = from_array(np.array(list(dims), dtype=np.int64), name=new_name)
        dtype = tensor_dtype_to_np_dtype(value.data_type)
        node_shape = make_node("Constant", [], [new_name], value=shape_tensor)
        new_node = make_node(
            "ConstantOfShape",
            [new_name],
            node.output,
            value=from_array(np.array([value_constant_of_shape], dtype=dtype)),
        )
        return [node_shape, new_node]
    return [node]
def _replace_constant_of_shape_with_range(
    onx: Union[GraphProto, FunctionProto]
) -> Union[GraphProto, FunctionProto]:
    """Replaces all *ConstantOfShape* by node *Range* to avoid constant tensors.

    Every node ``y = ConstantOfShape(shape)`` becomes::

        N = ReduceProd(shape, keepdims=0)
        y = Reshape(Cast(Range(0, N, 1)) / Cast(N), shape)

    Both casts target the element type of the node's ``value`` attribute, or
    float when the node has no attribute. Two int64 scalar *Constant* nodes
    (0 and 1) are prepended, only when at least one node was replaced.

    New names are made unique among the names visible in *onx*: node
    inputs/outputs, plus graph inputs/outputs/initializers or function
    inputs/outputs. Names of enclosing or nested scopes are not checked.

    The function is not recursive. The recursivity is done by
    *replace_initializer_by_constant_of_shape*.

    Args:
        onx: graph or function to rewrite

    Returns:
        a modified copy of *onx*; every field other than the node list is
        preserved

    Raises:
        TypeError: if *onx* is neither a GraphProto nor a FunctionProto
    """
    if isinstance(onx, GraphProto):
        new_onx = GraphProto()
        existing_names = {i.name for i in onx.input}
        existing_names.update(o.name for o in onx.output)
        existing_names.update(t.name for t in onx.initializer)
        existing_names.update(sp.values.name for sp in onx.sparse_initializer)
    elif isinstance(onx, FunctionProto):
        new_onx = FunctionProto()
        existing_names = set(onx.input) | set(onx.output)
    else:
        raise TypeError(f"Not implemented for type {type(onx)}.")
    for node in onx.node:
        existing_names.update(node.input)
        existing_names.update(node.output)

    def _find_name(prefix: str) -> str:
        # Returns the first free name among prefix, prefix_2, prefix_3, ...
        name, i = prefix, 2
        while name in existing_names:
            name = f"{prefix}_{i}"
            i += 1
        existing_names.add(name)
        return name

    # Scalar int64 tensors passed through ``value``: the ``value_int`` attribute
    # only exists from Constant opset 12 while Range only needs opset 11.
    cst0 = make_node(
        "Constant",
        [],
        [_find_name("zero")],
        value=from_array(np.array(0, dtype=np.int64)),
    )
    cst1 = make_node(
        "Constant",
        [],
        [_find_name("one")],
        value=from_array(np.array(1, dtype=np.int64)),
    )

    new_nodes: List[NodeProto] = []
    replaced = False
    for node in onx.node:
        if node.op_type != "ConstantOfShape":
            new_nodes.append(node)
            continue
        replaced = True
        shape = node.input[0]
        # keepdims=0 makes N a scalar: Range expects scalar start/limit/delta.
        n = make_node(
            "ReduceProd", [shape], [_find_name(f"{shape}_N")], keepdims=0
        )
        a = make_node(
            "Range",
            [cst0.output[0], n.output[0], cst1.output[0]],
            [_find_name(f"{shape}_RANGE")],
        )
        if len(node.attribute) == 1:
            to = node.attribute[0].t.data_type
        else:
            to = TensorProto.FLOAT
        ac = make_node("Cast", [a.output[0]], [_find_name(f"{shape}_RANGEf")], to=to)
        cl = make_node("Cast", [n.output[0]], [_find_name(f"{shape}_Nf")], to=to)
        d = make_node(
            "Div", [ac.output[0], cl.output[0]], [_find_name(f"{shape}_FLAT")]
        )
        resh = make_node("Reshape", [d.output[0], shape], node.output)
        new_nodes.extend([n, a, ac, cl, d, resh])

    if replaced:
        new_nodes = [cst0, cst1, *new_nodes]

    # Copying the proto keeps value_info, doc_string, function attributes...
    # which make_graph / make_function would otherwise drop.
    new_onx.CopyFrom(onx)
    del new_onx.node[:]
    new_onx.node.extend(new_nodes)
    return new_onx
def _replace_constant_of_shape_value(
    onx: Union[GraphProto, FunctionProto], value_constant_of_shape: float
) -> Union[GraphProto, FunctionProto]:
    """Replaces the fill value of every *ConstantOfShape* node of *onx*.

    The new fill value keeps the element type of the original ``value``
    attribute. It is truncated to an integer for integer element types. A
    node without ``value`` attribute is treated as the operator default, a
    float tensor. Nodes whose value refers to an attribute of the enclosing
    function (``ref_attr_name``) are left unchanged. The function is not
    recursive.

    Args:
        onx: graph or function to rewrite
        value_constant_of_shape: new fill value

    Returns:
        a modified copy of *onx*; every field other than the node list is
        preserved

    Raises:
        TypeError: if *onx* is neither a GraphProto nor a FunctionProto
    """
    if isinstance(onx, GraphProto):
        new_onx = GraphProto()
    elif isinstance(onx, FunctionProto):
        new_onx = FunctionProto()
    else:
        raise TypeError(f"Not implemented for type {type(onx)}.")

    nodes = list(onx.node)
    for inode, node in enumerate(nodes):
        if node.op_type != "ConstantOfShape":
            continue
        if len(node.attribute) > 0:
            att = node.attribute[0]
            if att.ref_attr_name:
                # The value is only known when the function is instantiated.
                continue
            att_name, tensor_name, data_type = att.name, att.t.name, att.t.data_type
        else:
            # Without attribute, ConstantOfShape defaults to a float tensor.
            att_name, tensor_name, data_type = "value", "", TensorProto.FLOAT
        fill = value_constant_of_shape
        if np.issubdtype(tensor_dtype_to_np_dtype(data_type), np.integer):
            # make_tensor stores integer types in integer fields which reject
            # Python floats such as 0.5.
            fill = int(fill)
        new_tensor = make_tensor(tensor_name, data_type, [1], [fill])
        # Copying the node keeps its name, domain and doc_string.
        new_node = NodeProto()
        new_node.CopyFrom(node)
        del new_node.attribute[:]
        new_node.attribute.append(make_attribute(att_name, new_tensor))
        nodes[inode] = new_node

    # Copying the proto keeps value_info, doc_string, function attributes...
    new_onx.CopyFrom(onx)
    del new_onx.node[:]
    new_onx.node.extend(nodes)
    return new_onx
def _post_process_constant_of_shape(
    onx: Union[GraphProto, FunctionProto],
    use_range: bool,
    value_constant_of_shape: float,
) -> Union[GraphProto, FunctionProto]:
    """Applies the final, non-recursive rewrite of the *ConstantOfShape* nodes.

    With ``use_range``, every *ConstantOfShape* is replaced by a *Range*
    construction. Otherwise, when ``value_constant_of_shape != 1``, the fill
    value of every *ConstantOfShape* is replaced. Otherwise *onx* itself is
    returned.
    """
    if use_range:
        return _replace_constant_of_shape_with_range(onx)
    if value_constant_of_shape != 1:
        return _replace_constant_of_shape_value(onx, value_constant_of_shape)
    return onx


def _replace_constants_in_nodes(
    nodes: List[NodeProto],
    threshold: int,
    ir_version: Optional[int],
    use_range: bool,
    value_constant_of_shape: float,
) -> Optional[List[NodeProto]]:
    """Replaces large *Constant* nodes and processes subgraph attributes of *nodes*.

    Every attribute of type GRAPH is processed recursively by
    *replace_initializer_by_constant_of_shape*. A node with a changed subgraph
    is rebuilt with the same op_type, inputs, outputs, name, domain and
    doc_string.

    Returns:
        the new list of nodes, or None if no node was changed
    """
    modified = False
    new_nodes: List[NodeProto] = []
    for node in nodes:
        if node.op_type == "Constant":
            replacement = _replace_constant(node, threshold, value_constant_of_shape)
            if len(replacement) == 2:  # noqa: PLR2004
                modified = True
            new_nodes.extend(replacement)
            continue
        node_modified = False
        atts = []
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:
                subgraph = att.g
                g = replace_initializer_by_constant_of_shape(
                    subgraph,
                    threshold=threshold,
                    ir_version=ir_version,
                    use_range=use_range,
                    value_constant_of_shape=value_constant_of_shape,
                )
                # The recursive call returns its input when nothing changed.
                if g is not subgraph:
                    node_modified = True
                    att = make_attribute(att.name, g)  # noqa: PLW2901
            atts.append(att)
        if not node_modified:
            new_nodes.append(node)
            continue
        modified = True
        new_node = make_node(
            node.op_type,
            node.input,
            node.output,
            name=node.name,
            domain=node.domain,
            doc_string=node.doc_string,
        )
        new_node.attribute.extend(atts)
        new_nodes.append(new_node)
    return new_nodes if modified else None


def _replace_initializer_in_function(
    onx: FunctionProto,
    threshold: int,
    ir_version: Optional[int],
    use_range: bool,
    value_constant_of_shape: float,
) -> Union[GraphProto, FunctionProto]:
    """Rewrites the nodes of a FunctionProto, then post-processes *ConstantOfShape*."""
    new_nodes = _replace_constants_in_nodes(
        list(onx.node), threshold, ir_version, use_range, value_constant_of_shape
    )
    if new_nodes is not None:
        # make_function would drop the declared attributes and doc_string of
        # the function; copying the proto keeps every field.
        new_onx = FunctionProto()
        new_onx.CopyFrom(onx)
        del new_onx.node[:]
        new_onx.node.extend(new_nodes)
        onx = new_onx
    return _post_process_constant_of_shape(onx, use_range, value_constant_of_shape)


def _replace_initializer_in_graph(
    onx: GraphProto,
    threshold: int,
    ir_version: Optional[int],
    use_range: bool,
    value_constant_of_shape: float,
) -> Union[GraphProto, FunctionProto]:
    """Replaces large initializers and Constant nodes of a GraphProto.

    Every initializer with more than ``threshold`` elements becomes an int64
    shape initializer ``<name>__SHAPE`` plus a *ConstantOfShape* node
    producing ``<name>``. The rebuilt graph keeps name, outputs, value_info
    and doc_string. Inputs named after a replaced initializer are removed.
    """
    removed = set()
    additional_inputs = []
    shape_nodes: List[NodeProto] = []
    new_inits: List[TensorProto] = []
    for init in onx.initializer:
        dims = tuple(init.dims)
        if np.prod(dims) <= threshold:
            new_inits.append(init)
            continue
        new_name = f"{init.name}__SHAPE"
        new_inits.append(
            from_array(np.array(list(dims), dtype=np.int64), name=new_name)
        )
        dtype = tensor_dtype_to_np_dtype(init.data_type)
        shape_nodes.append(
            make_node(
                "ConstantOfShape",
                [new_name],
                [init.name],
                # Same fill value as the one used for replaced Constant nodes.
                value=from_array(np.array([value_constant_of_shape], dtype=dtype)),
            )
        )
        removed.add(init.name)
        if ir_version is not None and ir_version <= 3:  # noqa: PLR2004
            # Initializers must also be declared as inputs for ir_version <= 3.
            additional_inputs.append(
                make_tensor_value_info(new_name, TensorProto.INT64, [len(dims)])
            )

    new_sparse_inits: List[SparseTensorProto] = []
    for sp_init in onx.sparse_initializer:
        if np.prod(tuple(sp_init.dims)) <= threshold:
            new_sparse_inits.append(sp_init)
            continue
        raise NotImplementedError(
            f"This feature is not yet implemented for a sparse initializer "
            f"(indices.name={sp_init.indices.name!r}, "
            f"values.name={sp_init.values.name!r})."
        )

    new_nodes = _replace_constants_in_nodes(
        list(onx.node), threshold, ir_version, use_range, value_constant_of_shape
    )
    if not removed and new_nodes is None:
        return _post_process_constant_of_shape(onx, use_range, value_constant_of_shape)

    graph = make_graph(
        shape_nodes + (list(onx.node) if new_nodes is None else new_nodes),
        onx.name,
        [i for i in onx.input if i.name not in removed] + additional_inputs,
        onx.output,
        initializer=new_inits,
        doc_string=onx.doc_string,
        value_info=onx.value_info,
        sparse_initializer=new_sparse_inits,
    )
    return _post_process_constant_of_shape(graph, use_range, value_constant_of_shape)


def replace_initializer_by_constant_of_shape(
    onx: Union[FunctionProto, GraphProto, ModelProto],
    threshold: int = 128,
    ir_version: Optional[int] = None,
    use_range: bool = False,
    value_constant_of_shape: float = 0.5,
):
    """Replaces large initializers or Constant nodes by *ConstantOfShape* to reduce the size.

    This reduces the cost of writing a unit test about a specific graph
    structure. Subgraphs stored in GRAPH attributes are processed recursively.

    Args:
        onx: ModelProto, GraphProto or FunctionProto to modify
        threshold: initializers and constants with at most this many elements
            are not impacted
        ir_version: initializers must be specified as inputs for
            `ir_version <= 3`; this must be specified if onx is a
            :class:`FunctionProto` or a :class:`GraphProto`
        use_range: if True, uses operator *Range* instead of
            *ConstantOfShape* to avoid constant tensors
        value_constant_of_shape: value used as the fill value of every
            *ConstantOfShape* node; a high value may produce nan or inf
            predictions

    Returns:
        the modified object, of the same type as *onx*. It is *onx* itself
        when a graph or function needs no change, ``use_range`` is False and
        ``value_constant_of_shape == 1``.

    Raises:
        TypeError: if *onx* is not a ModelProto, GraphProto or FunctionProto
        NotImplementedError: for large sparse initializers or sparse constants
        RuntimeError: if the model's main opset is too old for
            *ConstantOfShape* (9) or *Range* (11)

    The function is designed so that it can be reapplied on a modified model
    and either replace every *ConstantOfShape* with *Range* operators, or
    replace the fill value of every *ConstantOfShape*.
    """
    if isinstance(onx, FunctionProto):
        return _replace_initializer_in_function(
            onx, threshold, ir_version, use_range, value_constant_of_shape
        )

    if isinstance(onx, ModelProto):
        new_graph = replace_initializer_by_constant_of_shape(
            onx.graph,
            ir_version=ir_version or onx.ir_version,
            threshold=threshold,
            use_range=use_range,
            value_constant_of_shape=value_constant_of_shape,
        )
        new_functions = [
            replace_initializer_by_constant_of_shape(
                f,
                threshold=threshold,
                ir_version=ir_version or onx.ir_version,
                use_range=use_range,
                value_constant_of_shape=value_constant_of_shape,
            )
            for f in onx.functions
        ]
        model = make_model(
            new_graph,
            functions=new_functions,
            producer_name=onx.producer_name,
            producer_version=onx.producer_version,
            ir_version=ir_version or onx.ir_version,
            doc_string=onx.doc_string,
            domain=onx.domain,
            model_version=onx.model_version,
        )
        if len(onx.metadata_props) > 0:  # pragma: no cover
            values = {p.key: p.value for p in onx.metadata_props}
            set_model_props(model, values)

        # make_model adds a default opset; replace it with the original ones.
        del model.opset_import[:]
        for oimp in onx.opset_import:
            op_set = model.opset_import.add()
            if oimp.domain == "" and oimp.version < 11 and use_range:  # noqa: PLR2004
                raise RuntimeError(
                    f"Range was introduced in opset 11 but opset is {oimp.version}."
                )
            if oimp.domain == "" and oimp.version < 9:  # noqa: PLR2004
                raise RuntimeError(
                    f"ConstantOfShape was introduced in "
                    f"opset 9 but opset is {oimp.version}."
                )
            op_set.domain = oimp.domain
            op_set.version = oimp.version
        return model

    if not isinstance(onx, GraphProto):
        raise TypeError(f"onx should be a GraphProto at this stage not {type(onx)}.")
    return _replace_initializer_in_graph(
        onx, threshold, ir_version, use_range, value_constant_of_shape
    )