# Utilities for selectively marking nodes of a half-precision OpenVINO model
# for execution in full (FP32) precision, based on a per-node SNR calibration.
from collections import defaultdict, OrderedDict | |
from dataclasses import dataclass | |
from typing import List, Dict, Union | |
import numpy as np | |
from openvino.runtime import Model, Node | |
from openvino.runtime.op import Parameter, Constant | |
import openvino.runtime.opset12 as opset | |
from openvino.runtime.utils.types import get_element_type | |
import openvino as ov | |
from tqdm.auto import tqdm | |
# Maps supported operation type names to the opset constructors used to rebuild
# a single node as a standalone model for isolated half-precision inference.
OPERATION_TYPE_MAP = {"MatMul": opset.matmul, "Convolution": opset.convolution}
# rt_info key used to mark a node so that it is executed in its original (full) precision.
ORIGINAL_PRECISION_RT_INFO_NAME = "precise_0"
@dataclass
class TrackedNodeInfo:
    """
    Data associated with a node tracked for upcasting.

    The missing ``@dataclass`` decorator is restored: ``dataclass`` is imported at the
    top of the file, and instances are created positionally (``TrackedNodeInfo(node)``),
    which requires the generated ``__init__``.
    """

    node: Node  # Target node to track
    snr: float = None  # SNR of the target node
    input_nodes: List[Node] = None  # Input nodes of the target node
    result_node: Node = None  # Result node of the target node
    input_result_nodes: Dict[Node, Node] = None  # Result nodes of non-const inputs of the target node
    node_value_full_precision: np.ndarray = None  # Result of the node in full precision
    node_value_half_precision: np.ndarray = None  # Result of the node in half precision
    # Results of the target node inputs in full precision (filled as a list of arrays)
    input_values_full_precision: List[np.ndarray] = None
def partially_upcast_nodes_to_fp32(
    orig_model: Model,
    example_input: Union[List, Dict],
    half_type: str = "f16",
    batch_size: int = 50,
    operation_types: List[str] = None,
    upcast_ratio: float = 0.1,
    verbose: bool = False,
) -> Model:
    """
    Transform a model to upcast some nodes to be executed in full precision instead of half precision. These nodes are
    marked with runtime info flag.
    Nodes are selected based on Signal-to-Noise Ratio (SNR) metric: upcast_ratio fraction of tracked nodes with the
    lowest SNR are marked for full precision execution.
    Note: Input model should have fp16 weights (i.e. saved with compress_to_fp16=True) in order to conserve
    calibration memory.

    :param orig_model: Model to process
    :param example_input: Example input for model inference
    :param half_type: Either "f16" or "bf16"
    :param batch_size: Number of nodes to process together during a single model inference. The lower the value is,
        the less memory footprint is, but the larger is the processing time. The value of -1 is used to disable
        batching.
    :param operation_types: Types of operations to consider. If None, MatMuls and Convolutions are considered.
    :param upcast_ratio: Fraction of nodes to upcast (with the lowest SNR). 0 - do not upcast anything, 1 - upcast every
        operation of the given types.
    :param verbose: If True, prints progress output.
    :return: Upcasted OV model with some nodes marked for full precision execution.
    """
    if half_type not in ("f16", "bf16"):
        raise ValueError(f"Half type must be either 'f16' or 'bf16'. Got {half_type}.")
    if half_type == "bf16":
        print("Warning! Calibration currently does not provide any improvement for bf16 type.")
    if operation_types is None:
        operation_types = ["MatMul", "Convolution"]
    for op_type in operation_types:
        if op_type not in OPERATION_TYPE_MAP:
            raise ValueError(f"Operation type must be one of the following {list(OPERATION_TYPE_MAP.keys())}. " f"Got {op_type}.")
    if verbose:
        print(f"The following operation types will be considered: {operation_types}")
    # bf16 inference runs on CPU; f16 on GPU.
    device = "GPU" if half_type == "f16" else "CPU"

    nodes_to_track_names = get_nodes_to_track(orig_model, operation_types)
    if len(nodes_to_track_names) == 0:
        if verbose:
            print("Warning. Not found any operations of the given type(s).")
        return orig_model.clone()

    # The expensive SNR calibration is only needed when a strict subset of the
    # tracked nodes will be upcast. Previously this check sat inside the batch
    # loop, spinning through every (empty) iteration for ratios 0.0 and 1.0.
    if upcast_ratio == 0.0:
        if verbose:
            print("Skipping algorithm because upcast ratio equals 0.0. Nothing to upcast.")
        node_to_upcast_names = []
    elif upcast_ratio == 1.0:
        if verbose:
            print("Skipping algorithm because upcast ratio equals 1.0. Upcasting all nodes of the given type(s).")
        node_to_upcast_names = nodes_to_track_names
    else:
        node_to_upcast_names = _calibrate_nodes_to_upcast(
            orig_model,
            example_input,
            nodes_to_track_names,
            batch_size,
            device,
            half_type,
            upcast_ratio,
            verbose,
        )

    new_model = orig_model.clone()
    mark_nodes_to_upcast_to_fp32(new_model, node_to_upcast_names)
    return new_model


def _calibrate_nodes_to_upcast(
    orig_model: Model,
    example_input: Union[List, Dict],
    nodes_to_track_names: List[str],
    batch_size: int,
    device: str,
    half_type: str,
    upcast_ratio: float,
    verbose: bool,
) -> List[str]:
    """Collect per-node SNRs batch by batch and return the names of the
    upcast_ratio fraction of tracked nodes with the lowest SNR."""
    node_names_and_snrs = []
    if batch_size == -1 or batch_size > len(nodes_to_track_names):
        batch_size = len(nodes_to_track_names)
    if verbose:
        print("Started upcasting")
    for i in tqdm(
        range(0, len(nodes_to_track_names), batch_size),
        desc="Processing batches",
        disable=not verbose,
    ):
        model = orig_model.clone()
        name_to_node_map = {op.get_friendly_name(): op for op in model.get_ops()}
        nodes_to_track_batch = [TrackedNodeInfo(name_to_node_map[node_name]) for node_name in nodes_to_track_names[i : i + batch_size]]
        # Add outputs for non-constant inputs of tracked nodes
        insert_outputs_for_tracked_ops(model, nodes_to_track_batch)
        # Infer model to collect tracked operation results and results of their inputs in full precision
        infer_full_net(nodes_to_track_batch, model, example_input)
        # Infer nodes in half precision one by one using full precision inputs, collect half precision results
        infer_nodes(nodes_to_track_batch, device, half_type)
        # Compute operation SNR based on full precision and half precision results
        for node_info in nodes_to_track_batch:
            node_names_and_snrs.append((node_info.node.get_friendly_name(), _node_snr(node_info)))

    node_names_and_snrs.sort(key=lambda it: it[1])
    node_names, node_snrs = tuple(zip(*node_names_and_snrs))
    n_nodes = len(node_names)
    nodes_to_upcast_cnt = int(np.ceil(n_nodes * upcast_ratio))
    if verbose:
        snr_quantile = node_snrs[nodes_to_upcast_cnt - 1]
        print(f"Upcasted {nodes_to_upcast_cnt}/{n_nodes} nodes with SNR less than {snr_quantile:.2f}.")
        for node_name, node_snr in node_names_and_snrs[:nodes_to_upcast_cnt]:
            print(node_name, node_snr)
    return node_names[:nodes_to_upcast_cnt]


def _node_snr(node_info: TrackedNodeInfo) -> float:
    """Compute the SNR of one tracked node; known Add/Concat shape-mismatch
    failures are demoted to a warning and the maximal SNR (never upcast)."""
    try:
        return compute_snr(
            node_info.node_value_full_precision,
            node_info.node_value_half_precision,
        )
    except RuntimeError as e:
        # TODO: find the reason behind this
        if node_info.node.get_type_name() in [
            "Add",
            "Concat",
        ] and "Shape mismatch" in str(e):
            print(
                "Warning.",
                str(e),
                node_info.node.get_friendly_name(),
                node_info.node.get_type_name(),
                [(inp_node.get_friendly_name(), inp_node.get_type_name()) for inp_node in node_info.input_nodes],
            )
            return np.finfo(np.float32).max
        raise
def get_nodes_to_track(model: Model, operation_types: List[str]) -> List:
    """
    Return friendly names of all ops in ``model`` whose type is in ``operation_types``
    and none of whose consumers is a Result node (ops feeding a model output
    directly are excluded -- presumably to avoid clashing with existing Results;
    TODO confirm).
    """
    nodes_to_track = []
    for op in model.get_ordered_ops():
        if op.get_type_name() not in operation_types:
            continue
        feeds_result = any(inp.get_node().get_type_name() == "Result" for inp in op.output(0).get_target_inputs())
        if not feeds_result:
            nodes_to_track.append(op.get_friendly_name())
    return nodes_to_track
def insert_outputs_for_tracked_ops(model: Model, nodes_to_track: List[TrackedNodeInfo]) -> None:
    """
    Add auxiliary model outputs for every tracked node and for each of its
    non-constant inputs, and record which new Result node belongs to which
    tracked value so inference results can be matched back later.
    """
    output_by_node = OrderedDict()  # insertion order must line up with add_outputs()
    infos_by_node = defaultdict(list)
    for tracked in nodes_to_track:
        target = tracked.node
        infos_by_node[target].append((tracked, "parent"))  # the tracked node itself
        if target not in output_by_node:
            output_by_node[target] = target.output(0)
        tracked.input_nodes = []
        for in_value in target.input_values():
            producer = in_value.get_node()
            tracked.input_nodes.append(producer)
            # Constant inputs are read straight from the model -- no output needed
            if producer.get_type_name() == "Constant" or is_constant_path(producer):
                continue
            infos_by_node[producer].append((tracked, "child"))  # an input of a tracked node
            if producer not in output_by_node:
                output_by_node[producer] = producer.output(0)
    new_outputs = model.add_outputs(list(output_by_node.values()))
    # Value matching will be done later based on result node friendly names
    for new_output, source_node in zip(new_outputs, output_by_node):
        result_node = new_output.node
        for tracked, role in infos_by_node[source_node]:
            if role == "parent":
                tracked.result_node = result_node
            else:
                if tracked.input_result_nodes is None:
                    tracked.input_result_nodes = {}
                tracked.input_result_nodes[source_node] = result_node
def get_const_value_from_ovmodel(node: Union[Constant, Node]) -> np.ndarray:
    """
    Extract the constant value behind ``node``: either the node itself is a
    Constant, or it is a decompression Convert fed by one.
    """
    if node.get_type_name() == "Constant":
        # A direct constant must not be stored in half precision here.
        assert node.get_element_type() not in [
            ov.Type.f16,
            ov.Type.bf16,
        ], f"{node.get_friendly_name()}, {node.get_element_type()}"
        return node.get_data()
    if is_constant_path(node):
        # If model is compressed and constant values flow through decompression convert
        const_node = node.input_value(0).get_node()
        assert const_node.get_type_name() == "Constant"
        assert const_node.get_element_type().is_real(), const_node.get_element_type()
        return const_node.get_data()  # return f16 weight
    raise Exception(f"Cannot get const values from ov.Model for {node.get_friendly_name()} with type {node.get_type_name()}")
def is_constant_path(node: Node) -> bool:
    """
    Return True if ``node`` is a Convert that merely passes a constant through:
    either marked as a decompression convert in its rt_info, or fed directly
    by a Constant node.
    """
    if node.get_type_name() != "Convert":
        return False
    rt_info = node.get_rt_info()
    # Guard the lookup: a plain Convert carries no "is_decompression_0" entry,
    # and an unchecked subscript would raise KeyError.
    if "is_decompression_0" in rt_info and len(rt_info["is_decompression_0"].aslist()) > 0:
        return True
    if node.input_value(0).get_node().get_type_name() == "Constant":
        return True
    return False
def infer_full_net(nodes_to_track: List[TrackedNodeInfo], orig_model: Model, example_inputs: List) -> None:
    """
    Infer ``orig_model`` on CPU with an f32 precision hint and fill each tracked
    node's full-precision value and the full-precision values of all its inputs.
    """
    core = ov.Core()
    exec_net = core.compile_model(orig_model, "CPU", config={"INFERENCE_PRECISION_HINT": "f32"})
    request = exec_net.create_infer_request()
    results = request.infer(example_inputs, share_inputs=True, share_outputs=True)
    # Match inference results back to nodes via Result friendly names.
    friendly_name_to_result_map = {key.node.get_friendly_name(): val for key, val in results.items()}
    for node_info in nodes_to_track:
        node_info.node_value_full_precision = friendly_name_to_result_map[node_info.result_node.get_friendly_name()]
        node_info.input_values_full_precision = []
        for input_node in node_info.input_nodes:
            if input_node.get_type_name() == "Constant" or is_constant_path(input_node):
                # If input is constant, retrieve its value from model
                input_value = get_const_value_from_ovmodel(input_node)
            else:
                # If input is not constant, retrieve its input from inference results
                input_value = friendly_name_to_result_map[node_info.input_result_nodes[input_node].get_friendly_name()]
            node_info.input_values_full_precision.append(input_value)
def infer_nodes(nodes_to_track: List[TrackedNodeInfo], device: str, precision: str) -> None:
    """Run every tracked op in isolation under the given device and precision hint."""
    for tracked in nodes_to_track:
        infer_tracked_op(tracked, device, precision)
def infer_tracked_op(node_info: TrackedNodeInfo, device: str, precision: str) -> None:
    """
    Rebuild ``node_info.node`` as a standalone single-op model, feed it the
    previously collected full-precision input values, and store the result
    computed under the given device/precision hint into
    ``node_info.node_value_half_precision``.
    """
    parameters = []
    inputs = []
    input_values = node_info.input_values_full_precision
    for input_value in input_values:
        parameter = Parameter(get_element_type(input_value.dtype), ov.PartialShape(input_value.shape))
        if input_value.dtype == np.float16:
            # Convert f16 weight to f32
            convert_node = opset.convert(parameter, "f32")
            inputs.append(convert_node)
        else:
            inputs.append(parameter)
        parameters.append(parameter)
    node = node_info.node
    try:
        call_attributes = node.get_attributes()
        # Below are some op workarounds
        if node.get_type_name() == "Divide" and "m_pythondiv" in call_attributes:
            del call_attributes["m_pythondiv"]
        if node.get_type_name() == "Broadcast" and "mode" in call_attributes:
            call_attributes["broadcast_spec"] = call_attributes["mode"]
            del call_attributes["mode"]
        # NOTE(review): Concat takes its inputs as a single list argument, but
        # "Concat" is not in OPERATION_TYPE_MAP, so this branch is unreachable
        # unless the map is extended.
        if node.get_type_name() == "Concat":
            new_op = OPERATION_TYPE_MAP[node.get_type_name()](inputs, **call_attributes)
        else:
            new_op = OPERATION_TYPE_MAP[node.get_type_name()](*inputs, **call_attributes)
        ov_model = ov.Model([new_op], parameters=parameters)
        exec_net = ov.Core().compile_model(ov_model, device, config={"INFERENCE_PRECISION_HINT": precision})
        request = exec_net.create_infer_request()
        result = request.infer(input_values, share_inputs=True, share_outputs=True)
    except Exception:
        print(
            "Operation inference error",
            node.get_type_name(),
            node.get_friendly_name(),
            inputs,
            node.get_attributes(),
        )
        raise
    # Validate before use: the auxiliary model must have exactly one output.
    assert len(result) == 1
    node_info.node_value_half_precision = result[0]
def is_model_partially_upcasted(model) -> bool:
    """Return True if any op of a supported type already carries the
    full-precision rt_info mark."""
    return any(
        ORIGINAL_PRECISION_RT_INFO_NAME in op.get_rt_info().keys()
        for op in model.get_ordered_ops()
        if op.get_type_name() in OPERATION_TYPE_MAP.keys()
    )
def mark_nodes_to_upcast_to_fp32(model: ov.Model, nodes_with_errors: List[str]) -> None:
    """Set the full-precision rt_info mark on every node whose friendly name is
    listed; every listed name must be found in the model."""
    remaining = set(nodes_with_errors)
    for op in model.get_ordered_ops():
        name = op.get_friendly_name()
        if name in remaining:
            op.get_rt_info()[ORIGINAL_PRECISION_RT_INFO_NAME] = ""
            remaining.remove(name)
    assert not remaining, remaining
def compute_snr(x, y):
    """
    Signal-to-noise ratio in dB between a reference value and a noisy one.
    x -- original value (full precision), y -- value with noise (half precision).
    Infinities/NaNs are clamped; identical (or both-zero) inputs yield float32 max.
    Raises RuntimeError when the element counts differ.
    """
    signal = x.astype(np.float32)
    noisy = y.astype(np.float32)
    largest = np.finfo(np.float32).max
    if np.prod(signal.shape) != np.prod(noisy.shape):
        raise RuntimeError(f"Shape mismatch: {signal.shape}, {noisy.shape}.")
    signal = np.nan_to_num(signal, posinf=largest)
    noisy = np.nan_to_num(noisy, posinf=largest)
    signal_power = np.linalg.norm(signal)
    noise_power = np.nan_to_num(np.linalg.norm(signal - noisy), posinf=largest)
    if signal_power == 0.0 and noise_power == 0.0:
        # No signal and no noise: treat as perfectly clean.
        return largest
    return np.nan_to_num(20 * np.log10(signal_power / noise_power), posinf=largest)