# openvino_notebooks / utils / model_upcast_utils.py
# Uploaded by malvika2003 with huggingface_hub ("Upload folder using huggingface_hub"), commit db5855f (verified).
from collections import defaultdict, OrderedDict
from dataclasses import dataclass
from typing import List, Dict, Union
import numpy as np
from openvino.runtime import Model, Node
from openvino.runtime.op import Parameter, Constant
import openvino.runtime.opset12 as opset
from openvino.runtime.utils.types import get_element_type
import openvino as ov
from tqdm.auto import tqdm
# Operation types that can be tracked for upcasting, mapped to the opset factory used to rebuild
# a single operation of that type in a standalone model.
OPERATION_TYPE_MAP = {
    "MatMul": opset.matmul,
    "Convolution": opset.convolution,
}
# Runtime-info key set on nodes that are marked for full precision execution.
ORIGINAL_PRECISION_RT_INFO_NAME = "precise_0"
@dataclass
class TrackedNodeInfo:
    """
    Data associated with a node tracked for upcasting.

    Only ``node`` is given at construction. The other fields start as ``None`` and are filled in while the node is
    processed (``insert_outputs_for_tracked_ops``, ``infer_full_net``, ``infer_tracked_op``).
    """

    node: Node  # Target node to track
    snr: Union[float, None] = None  # SNR of the target node
    input_nodes: Union[List[Node], None] = None  # Producer node of each input of the target node, in input order
    result_node: Union[Node, None] = None  # Result node exposing the target node's output
    # Result nodes of non-const inputs of the target node, keyed by the producer node
    input_result_nodes: Union[Dict[Node, Node], None] = None
    node_value_full_precision: Union[np.ndarray, None] = None  # Result of the node in full precision
    node_value_half_precision: Union[np.ndarray, None] = None  # Result of the node in half precision
    # Full precision values of the target node inputs, one per input in input order
    input_values_full_precision: Union[List[np.ndarray], None] = None
def partially_upcast_nodes_to_fp32(
    orig_model: Model,
    example_input: Union[List, Dict],
    half_type: str = "f16",
    batch_size: int = 50,
    operation_types: List[str] = None,
    upcast_ratio: float = 0.1,
    verbose: bool = False,
) -> Model:
    """
    Transform a model to upcast some nodes to be executed in full precision instead of half precision. These nodes are
    marked with runtime info flag.
    Nodes are selected based on Signal-to-Noise Ratio (SNR) metric: upcast_ratio fraction of tracked nodes with the
    lowest SNR are marked for full precision execution.
    Note: Input model should have fp16 weights (i.e. saved with compress_to_fp16=True) in order to conserve
    calibration memory.
    :param orig_model: Model to process. It is not modified; the marks are applied to a clone that is returned.
    :param example_input: Example input for model inference
    :param half_type: Either "f16" or "bf16"
    :param batch_size: Number of nodes to process together during a single model inference. The lower the value is,
        the less memory footprint is, but the larger is the processing time. The value of -1 is used to disable
        batching. Must be a positive integer or -1.
    :param operation_types: Types of operations to consider. If None, MatMuls and Convolutions are considered.
    :param upcast_ratio: Fraction of nodes to upcast (with the lowest SNR), within [0, 1]. 0 - do not upcast anything,
        1 - upcast every operation of the given types. SNR calibration (model inference) only runs for values strictly
        between 0 and 1.
    :param verbose: If True, prints progress output.
    :return: Upcasted OV model with some nodes marked for full precision execution.
    :raises ValueError: If half_type, batch_size, operation_types or upcast_ratio has an invalid value.
    """
    if half_type not in ("f16", "bf16"):
        raise ValueError(f"Half type must be either 'f16' or 'bf16'. Got {half_type}.")
    if half_type == "bf16":
        print("Warning! Calibration currently does not provide any improvement for bf16 type.")
    # Written as a negated range check so that NaN is rejected as well
    if not 0.0 <= upcast_ratio <= 1.0:
        raise ValueError(f"Upcast ratio must be within [0, 1]. Got {upcast_ratio}.")
    if batch_size != -1 and batch_size < 1:
        raise ValueError(f"Batch size must be a positive integer or -1. Got {batch_size}.")
    if operation_types is None:
        operation_types = ["MatMul", "Convolution"]
    for op_type in operation_types:
        if op_type not in OPERATION_TYPE_MAP:
            raise ValueError(f"Operation type must be one of the following {list(OPERATION_TYPE_MAP.keys())}. " f"Got {op_type}.")
    if verbose:
        print(f"The following operation types will be considered: {operation_types}")

    nodes_to_track_names = get_nodes_to_track(orig_model, operation_types)
    if not nodes_to_track_names:
        if verbose:
            print("Warning. Not found any operations of the given type(s).")
        return orig_model.clone()

    if upcast_ratio == 0.0:
        if verbose:
            print("Skipping algorithm because upcast ratio equals 0.0. Nothing to upcast.")
        node_to_upcast_names = []
    elif upcast_ratio == 1.0:
        if verbose:
            print("Skipping algorithm because upcast ratio equals 1.0. Upcasting all nodes of the given type(s).")
        node_to_upcast_names = nodes_to_track_names
    else:
        # Half precision inference device: GPU for f16, CPU for bf16
        device = "GPU" if half_type == "f16" else "CPU"
        node_names_and_snrs = _compute_tracked_node_snrs(
            orig_model, example_input, nodes_to_track_names, batch_size, device, half_type, verbose
        )
        node_to_upcast_names = _select_lowest_snr_nodes(node_names_and_snrs, upcast_ratio, verbose)

    new_model = orig_model.clone()
    mark_nodes_to_upcast_to_fp32(new_model, node_to_upcast_names)
    return new_model


def _compute_tracked_node_snrs(
    orig_model: Model,
    example_input: Union[List, Dict],
    nodes_to_track_names: List[str],
    batch_size: int,
    device: str,
    half_type: str,
    verbose: bool,
) -> List:
    """Return (friendly name, SNR) pairs for all tracked nodes, running one full model inference per batch."""
    if batch_size == -1 or batch_size > len(nodes_to_track_names):
        batch_size = len(nodes_to_track_names)
    if verbose:
        print("Started upcasting")
    node_names_and_snrs = []
    for start in tqdm(
        range(0, len(nodes_to_track_names), batch_size),
        desc="Processing batches",
        disable=not verbose,
    ):
        # Each batch works on a fresh clone; the extra outputs below are added to that clone only
        model = orig_model.clone()
        name_to_node_map = {op.get_friendly_name(): op for op in model.get_ops()}
        batch = [TrackedNodeInfo(name_to_node_map[name]) for name in nodes_to_track_names[start : start + batch_size]]
        # Add outputs for non-constant inputs of tracked nodes
        insert_outputs_for_tracked_ops(model, batch)
        # Infer model to collect tracked operation results and results of their inputs in full precision
        infer_full_net(batch, model, example_input)
        # Infer nodes in half precision one by one using full precision inputs, collect half precision results
        infer_nodes(batch, device, half_type)
        node_names_and_snrs.extend((info.node.get_friendly_name(), _tracked_node_snr(info)) for info in batch)
    return node_names_and_snrs


def _tracked_node_snr(node_info: TrackedNodeInfo) -> float:
    """Return the SNR of a tracked node; a shape mismatch on Add/Concat nodes is reported and scored as float32 max."""
    try:
        return compute_snr(node_info.node_value_full_precision, node_info.node_value_half_precision)
    except RuntimeError as e:
        # TODO: find the reason behind this
        if node_info.node.get_type_name() in ("Add", "Concat") and "Shape mismatch" in str(e):
            print(
                "Warning.",
                str(e),
                node_info.node.get_friendly_name(),
                node_info.node.get_type_name(),
                [(inp_node.get_friendly_name(), inp_node.get_type_name()) for inp_node in node_info.input_nodes],
            )
            return np.finfo(np.float32).max
        raise


def _select_lowest_snr_nodes(node_names_and_snrs: List, upcast_ratio: float, verbose: bool) -> List[str]:
    """Return the names of the ceil(n_nodes * upcast_ratio) nodes with the lowest SNR."""
    ranked = sorted(node_names_and_snrs, key=lambda it: it[1])
    n_nodes = len(ranked)
    nodes_to_upcast_cnt = int(np.ceil(n_nodes * upcast_ratio))
    selected = ranked[:nodes_to_upcast_cnt]
    if verbose and selected:
        snr_quantile = selected[-1][1]
        print(f"Upcasted {nodes_to_upcast_cnt}/{n_nodes} nodes with SNR less than {snr_quantile:.2f}.")
        for node_name, node_snr in selected:
            print(node_name, node_snr)
    return [node_name for node_name, _ in selected]
def get_nodes_to_track(model: Model, operation_types: List[str]) -> List:
    """
    Collect friendly names of the nodes that are candidates for upcasting.

    A node is a candidate when its type name is in ``operation_types`` and none of the consumers of its first output
    is a Result node.

    :param model: Model to scan.
    :param operation_types: Operation type names (as returned by ``get_type_name()``) to select.
    :return: Friendly names of the candidate nodes, in the order returned by ``model.get_ordered_ops()``.
    """

    def feeds_result(op) -> bool:
        # True when any consumer of the op's first output is a Result node
        return any(target.get_node().get_type_name() == "Result" for target in op.output(0).get_target_inputs())

    return [
        op.get_friendly_name()
        for op in model.get_ordered_ops()
        if op.get_type_name() in operation_types and not feeds_result(op)
    ]
def insert_outputs_for_tracked_ops(model: Model, nodes_to_track: List[TrackedNodeInfo]) -> None:
    """
    Add model outputs for every tracked node and for each of its non-constant inputs.

    For each tracked node info this sets:
      * ``input_nodes``: the producer node of each input, in input order;
      * ``result_node``: the Result node exposing the tracked node's output 0;
      * ``input_result_nodes``: producer node -> Result node exposing the port that feeds the tracked node.
    Values are matched to these Result nodes by friendly name after inference.

    :param model: Model to modify in place (outputs are added to it).
    :param nodes_to_track: Tracked node infos whose ``node`` belongs to ``model``.
    """
    # Keyed by (producer node, output index). A multi-output producer such as Split must expose the port that
    # actually feeds the tracked node, not always its output 0.
    # NOTE(review): input_result_nodes is keyed by producer node only, so if one tracked node reads two different
    # ports of the same producer, only the last one is kept.
    port_to_output_map = OrderedDict()
    port_to_node_info_map = defaultdict(list)
    for node_info in nodes_to_track:
        node = node_info.node
        parent_key = (node, 0)
        port_to_node_info_map[parent_key].append((node_info, "parent"))  # add as a parent node
        if parent_key not in port_to_output_map:
            port_to_output_map[parent_key] = node.output(0)
        node_info.input_nodes = []
        for inp_value in node.input_values():
            child_node = inp_value.get_node()
            node_info.input_nodes.append(child_node)
            # Do not add outputs for constant nodes; their values are read from the model directly
            if child_node.get_type_name() == "Constant" or is_constant_path(child_node):
                continue
            child_key = (child_node, inp_value.get_index())
            port_to_node_info_map[child_key].append((node_info, "child"))  # add as a child node
            if child_key not in port_to_output_map:
                port_to_output_map[child_key] = inp_value
    outputs = model.add_outputs(list(port_to_output_map.values()))
    # The zip relies on add_outputs returning the new outputs in the order they were requested
    for output, port_key in zip(outputs, port_to_output_map.keys()):
        result_node = output.node
        producer = port_key[0]
        for node_info, role in port_to_node_info_map[port_key]:
            if role == "parent":
                node_info.result_node = result_node
            else:
                if node_info.input_result_nodes is None:
                    node_info.input_result_nodes = {}
                node_info.input_result_nodes[producer] = result_node
def get_const_value_from_ovmodel(node: Union[Constant, Node]) -> np.ndarray:
    """
    Read the constant data behind a node input.

    :param node: Either a ``Constant`` node, or a ``Convert`` node for which ``is_constant_path`` is true.
    :return: The node's data. For a constant path this is the data of the ``Constant`` feeding the ``Convert``,
        i.e. the value before conversion (e.g. the f16 weight of a compressed model).
    :raises ValueError: If ``node`` is neither a Constant nor a Convert on a constant path.
    """
    if node.get_type_name() == "Constant":
        # Direct constants must not be f16/bf16; compressed weights are expected through the Convert branch below
        assert node.get_element_type() not in [
            ov.Type.f16,
            ov.Type.bf16,
        ], f"{node.get_friendly_name()}, {node.get_element_type()}"
        return node.get_data()
    if is_constant_path(node):
        # If model is compressed and constant values flow through decompression convert
        const_node = node.input_value(0).get_node()
        assert const_node.get_type_name() == "Constant"
        assert const_node.get_element_type().is_real(), const_node.get_element_type()
        return const_node.get_data()  # return f16 weight
    raise ValueError(f"Cannot get const values from ov.Model for {node.get_friendly_name()} with type {node.get_type_name()}")
def is_constant_path(node: Node) -> bool:
    """
    Check whether ``node`` is a Convert node treated as part of a constant path (e.g. weight decompression).

    A node qualifies when it is a ``Convert`` and either its runtime info holds a non-empty
    ``is_decompression_0`` entry, or its first input is produced directly by a ``Constant``.

    :param node: Node to check.
    :return: True for such Convert nodes, False for every other node.
    """
    if node.get_type_name() != "Convert":
        return False
    # NOTE(review): rt_info is indexed without first checking that the key exists; confirm how the OpenVINO
    # RTMap behaves (and whether it inserts an entry) for a missing key.
    decompression_marks = node.get_rt_info()["is_decompression_0"].aslist()
    return len(decompression_marks) > 0 or node.input_value(0).get_node().get_type_name() == "Constant"
def infer_full_net(nodes_to_track: List[TrackedNodeInfo], orig_model: Model, example_inputs: Union[List, Dict]) -> None:
    """
    Infer the model on CPU in f32 and store full precision values on each tracked node info.

    Sets ``node_value_full_precision`` and ``input_values_full_precision`` (one value per input, in input order).
    Constant inputs are read from the model; other inputs are taken from the results of the extra outputs
    recorded in ``input_result_nodes``.

    :param nodes_to_track: Tracked node infos with ``result_node``, ``input_nodes`` and ``input_result_nodes`` set.
    :param orig_model: Model that already contains the extra outputs for the tracked nodes.
    :param example_inputs: Inputs passed to ``InferRequest.infer``.
    """
    core = ov.Core()
    exec_net = core.compile_model(orig_model, "CPU", config={"INFERENCE_PRECISION_HINT": "f32"})
    request = exec_net.create_infer_request()
    results = request.infer(example_inputs, share_inputs=True, share_outputs=True)
    # Result values are matched to tracked nodes by the friendly names of their Result nodes
    friendly_name_to_result_map = {output.node.get_friendly_name(): value for output, value in results.items()}
    for node_info in nodes_to_track:
        node_info.node_value_full_precision = friendly_name_to_result_map[node_info.result_node.get_friendly_name()]
        input_values = []
        for input_node in node_info.input_nodes:
            if input_node.get_type_name() == "Constant" or is_constant_path(input_node):
                # If input is constant, retrieve its value from model
                input_values.append(get_const_value_from_ovmodel(input_node))
            else:
                # If input is not constant, retrieve its value from inference results
                result_name = node_info.input_result_nodes[input_node].get_friendly_name()
                input_values.append(friendly_name_to_result_map[result_name])
        node_info.input_values_full_precision = input_values
def infer_nodes(nodes_to_track: List[TrackedNodeInfo], device: str, precision: str) -> None:
    """
    Run each tracked node on its own via ``infer_tracked_op``, in list order.

    :param nodes_to_track: Tracked node infos whose full precision input values are already collected.
    :param device: OpenVINO device name used to compile each single-node model.
    :param precision: Value of the ``INFERENCE_PRECISION_HINT`` compile option.
    """
    for tracked in nodes_to_track:
        infer_tracked_op(tracked, device, precision)
def infer_tracked_op(node_info: TrackedNodeInfo, device: str, precision: str) -> None:
    """
    Re-run a single tracked node in isolation at the requested inference precision.

    Builds a one-op model whose Parameters receive the node's full precision input values (f16 inputs are converted
    to f32 inside the model). It is compiled on ``device`` with ``INFERENCE_PRECISION_HINT=precision``, and its only
    output is stored in ``node_info.node_value_half_precision``.

    :param node_info: Tracked node info with ``input_values_full_precision`` populated.
    :param device: OpenVINO device name.
    :param precision: Inference precision hint, e.g. "f16" or "bf16".
    :raises Exception: Any error from building, compiling or inferring the one-op model is re-raised after printing
        diagnostic information.
    """
    parameters = []
    inputs = []
    input_values = node_info.input_values_full_precision
    for input_value in input_values:
        parameter = Parameter(get_element_type(input_value.dtype), ov.PartialShape(input_value.shape))
        parameters.append(parameter)
        if input_value.dtype == np.float16:
            # Convert f16 weight to f32
            inputs.append(opset.convert(parameter, "f32"))
        else:
            inputs.append(parameter)
    node = node_info.node
    try:
        call_attributes = node.get_attributes()
        # Below are some op workarounds
        if node.get_type_name() == "Divide" and "m_pythondiv" in call_attributes:
            del call_attributes["m_pythondiv"]
        if node.get_type_name() == "Broadcast" and "mode" in call_attributes:
            call_attributes["broadcast_spec"] = call_attributes.pop("mode")
        if node.get_type_name() == "Concat":
            # The Concat factory receives the input list as a single argument
            new_op = OPERATION_TYPE_MAP[node.get_type_name()](inputs, **call_attributes)
        else:
            new_op = OPERATION_TYPE_MAP[node.get_type_name()](*inputs, **call_attributes)
        ov_model = ov.Model([new_op], parameters=parameters)
        exec_net = ov.Core().compile_model(ov_model, device, config={"INFERENCE_PRECISION_HINT": precision})
        request = exec_net.create_infer_request()
        result = request.infer(input_values, share_inputs=True, share_outputs=True)
    except Exception:
        print(
            "Operation inference error",
            node.get_type_name(),
            node.get_friendly_name(),
            inputs,
            node.get_attributes(),
        )
        raise
    # The one-op model must produce exactly one output; check before reading it
    assert len(result) == 1, f"Expected a single output for {node.get_friendly_name()}, got {len(result)}."
    node_info.node_value_half_precision = result[0]
def is_model_partially_upcasted(model) -> bool:
    """
    Tell whether any operation of a supported type in ``model`` carries the full precision runtime-info mark.

    :param model: OpenVINO model to inspect.
    :return: True as soon as one operation whose type is in OPERATION_TYPE_MAP has the
        ORIGINAL_PRECISION_RT_INFO_NAME runtime-info key, False otherwise.
    """
    return any(
        ORIGINAL_PRECISION_RT_INFO_NAME in op.get_rt_info().keys()
        for op in model.get_ordered_ops()
        if op.get_type_name() in OPERATION_TYPE_MAP
    )
def mark_nodes_to_upcast_to_fp32(model: ov.Model, nodes_with_errors: List[str]) -> None:
    """
    Mark the named nodes of ``model`` for full precision execution, in place.

    Sets an empty ORIGINAL_PRECISION_RT_INFO_NAME runtime-info entry on the first node (in
    ``get_ordered_ops()`` order) carrying each given friendly name.

    :param model: Model to modify.
    :param nodes_with_errors: Friendly names of the nodes to mark.
    :raises AssertionError: If some name does not match any node; the unmatched names are reported.
    """
    pending = set(nodes_with_errors)
    for op in model.get_ordered_ops():
        name = op.get_friendly_name()
        if name not in pending:
            continue
        op.get_rt_info()[ORIGINAL_PRECISION_RT_INFO_NAME] = ""
        pending.discard(name)
    assert not pending, pending
def compute_snr(x, y):
    """
    Compute the signal-to-noise ratio, in dB, of a noisy tensor with respect to its reference.

    SNR is ``20 * log10(||x|| / ||x - y||)`` over all elements. Both inputs are first cast to float32, with NaN
    replaced by 0 and +/-inf clamped to +/-float32 max. The norms are then accumulated in float64 so that large
    values do not overflow to inf.

    :param x: Original value (full precision), array-like.
    :param y: Value with noise (half precision), array-like with the same number of elements as ``x``. Elements are
        compared in flattened order, so layouts such as ``(n,)`` vs ``(n, 1)`` are matched element-wise instead of
        being broadcast against each other.
    :return: SNR in dB as a Python float. float32 max when there is no noise (including two all-zero inputs);
        minus float32 max when the signal is zero but the noise is not.
    :raises RuntimeError: If ``x`` and ``y`` have a different number of elements. The message starts with
        "Shape mismatch", which callers match on.
    """
    max_value = np.finfo(np.float32).max
    x = np.asarray(x)
    y = np.asarray(y)
    if x.size != y.size:
        raise RuntimeError(f"Shape mismatch: {x.shape}, {y.shape}.")

    def sanitize(values):
        # NaN -> 0 and +/-inf -> +/-float32 max. The explicit neginf keeps the float32 clamp after the float64 upcast.
        clean = np.nan_to_num(values.astype(np.float32).ravel(), posinf=max_value, neginf=-max_value)
        return clean.astype(np.float64)

    x = sanitize(x)
    y = sanitize(y)
    signal_norm = np.linalg.norm(x)
    noise_norm = np.linalg.norm(x - y)
    if noise_norm == 0.0:
        return float(max_value)
    if signal_norm == 0.0:
        return float(-max_value)
    return float(20 * np.log10(signal_norm / noise_norm))