File size: 15,484 Bytes
db5855f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
from collections import defaultdict, OrderedDict
from dataclasses import dataclass
from typing import List, Dict, Union

import numpy as np
from openvino.runtime import Model, Node
from openvino.runtime.op import Parameter, Constant
import openvino.runtime.opset12 as opset
from openvino.runtime.utils.types import get_element_type

import openvino as ov
from tqdm.auto import tqdm

OPERATION_TYPE_MAP = {"MatMul": opset.matmul, "Convolution": opset.convolution}

ORIGINAL_PRECISION_RT_INFO_NAME = "precise_0"


@dataclass
class TrackedNodeInfo:
    """
    Data associated with a node tracked for upcasting.

    Only `node` is set at construction time; the remaining fields are filled in
    progressively by `insert_outputs_for_tracked_ops`, `infer_full_net` and
    `infer_tracked_op` during calibration.
    """

    node: Node  # Target node to track
    snr: Union[float, None] = None  # SNR of the target node (computed from full- vs half-precision results)
    input_nodes: Union[List[Node], None] = None  # Input nodes of the target node
    result_node: Union[Node, None] = None  # Result node added for the target node's output
    input_result_nodes: Union[Dict[Node, Node], None] = None  # Result nodes of non-const inputs of the target node
    node_value_full_precision: Union[np.ndarray, None] = None  # Result of the node in full precision
    node_value_half_precision: Union[np.ndarray, None] = None  # Result of the node in half precision
    input_values_full_precision: Union[List[np.ndarray], None] = None  # Results of the target node inputs in full precision


def partially_upcast_nodes_to_fp32(
    orig_model: Model,
    example_input: Union[List, Dict],
    half_type: str = "f16",
    batch_size: int = 50,
    operation_types: List[str] = None,
    upcast_ratio: float = 0.1,
    verbose: bool = False,
) -> Model:
    """
    Transform a model to upcast some nodes to be executed in full precision instead of half precision. These nodes are
    marked with runtime info flag.
    Nodes are selected based on Signal-to-Noise Ratio (SNR) metric: upcast_ratio fraction of tracked nodes with the
    lowest SNR are marked for full precision execution.

    Note: Input model should have fp16 weights (i.e. saved with compress_to_fp16=True) in order to conserve
    calibration memory.

    :param orig_model: Model to process
    :param example_input: Example input for model inference
    :param half_type: Either "f16" or "bf16"
    :param batch_size: Number of nodes to process together during a single model inference. The lower the value is,
        the less memory footprint is, but the larger is the processing time. The value of -1 is used to disable
        batching.
    :param operation_types: Types of operations to consider. If None, MatMuls and Convolutions are considered.
    :param upcast_ratio: Fraction of nodes to upcast (with the lowest SNR). 0 - do not upcast anything, 1 - upcast every
        operation of the given types.
    :param verbose: If True, prints progress output.
    :return: Upcasted OV model with some nodes marked for full precision execution.
    """
    if half_type not in ("f16", "bf16"):
        raise ValueError(f"Half type must be either 'f16' or 'bf16'. Got {half_type}.")
    if half_type == "bf16":
        print("Warning! Calibration currently does not provide any improvement for bf16 type.")
    if operation_types is None:
        operation_types = ["MatMul", "Convolution"]
    for op_type in operation_types:
        if op_type not in OPERATION_TYPE_MAP:
            raise ValueError(f"Operation type must be one of the following {list(OPERATION_TYPE_MAP.keys())}. " f"Got {op_type}.")
    if verbose:
        print(f"The following operation types will be considered: {operation_types}")

    # f16 calibration is run on GPU; bf16 on CPU
    device = "GPU" if half_type == "f16" else "CPU"

    nodes_to_track_names = get_nodes_to_track(orig_model, operation_types)
    if len(nodes_to_track_names) == 0:
        if verbose:
            print("Warning. Not found any operations of the given type(s).")
        return orig_model.clone()

    # Calibration is only meaningful for a strict subset: for ratios of exactly 0.0 or 1.0
    # the SNR ranking would be discarded, so the (expensive) batch loop is skipped entirely.
    # (Previously the loop still iterated over every batch doing nothing for these ratios.)
    run_calibration = upcast_ratio != 0.0 and upcast_ratio != 1.0

    node_names_and_snrs = []
    if run_calibration:
        batch_size = len(nodes_to_track_names) if batch_size == -1 or batch_size > len(nodes_to_track_names) else batch_size
        if verbose:
            print("Started upcasting")
        for i in tqdm(
            range(0, len(nodes_to_track_names), batch_size),
            desc="Processing batches",
            disable=not verbose,
        ):
            model = orig_model.clone()
            name_to_node_map = {op.get_friendly_name(): op for op in model.get_ops()}
            nodes_to_track_batch = [TrackedNodeInfo(name_to_node_map[node_name]) for node_name in nodes_to_track_names[i : i + batch_size]]

            # Add outputs for non-constant inputs of tracked nodes
            insert_outputs_for_tracked_ops(model, nodes_to_track_batch)
            # Infer model to collect tracked operation results and results of their inputs in full precision
            infer_full_net(nodes_to_track_batch, model, example_input)
            # Infer nodes in half precision one by one using full precision inputs, collect half precision results
            infer_nodes(nodes_to_track_batch, device, half_type)

            # Compute operation SNR based on full precision and half precision results
            for node_info in nodes_to_track_batch:
                try:
                    snr = compute_snr(
                        node_info.node_value_full_precision,
                        node_info.node_value_half_precision,
                    )
                except RuntimeError as e:
                    # TODO: find the reason behind this
                    if node_info.node.get_type_name() in [
                        "Add",
                        "Concat",
                    ] and "Shape mismatch" in str(e):
                        print(
                            "Warning.",
                            str(e),
                            node_info.node.get_friendly_name(),
                            node_info.node.get_type_name(),
                            [(inp_node.get_friendly_name(), inp_node.get_type_name()) for inp_node in node_info.input_nodes],
                        )
                        # Assign the maximal SNR so the node is never selected for upcasting
                        snr = np.finfo(np.float32).max
                    else:
                        raise e
                node_names_and_snrs.append((node_info.node.get_friendly_name(), snr))

    if run_calibration:
        # Rank nodes by ascending SNR and take the worst upcast_ratio fraction
        node_names_and_snrs = sorted(node_names_and_snrs, key=lambda it: it[1])
        node_names, node_snrs = tuple(zip(*node_names_and_snrs))

        n_nodes = len(node_names)
        nodes_to_upcast_cnt = int(np.ceil(n_nodes * upcast_ratio))
        node_to_upcast_names = node_names[:nodes_to_upcast_cnt]

        if verbose:
            snr_quantile = node_snrs[nodes_to_upcast_cnt - 1]
            print(f"Upcasted {nodes_to_upcast_cnt}/{n_nodes} nodes with SNR less than {snr_quantile:.2f}.")
            for node_name, node_snr in node_names_and_snrs[:nodes_to_upcast_cnt]:
                print(node_name, node_snr)
    elif upcast_ratio == 0.0:
        if verbose:
            print("Skipping algorithm because upcast ratio equals 0.0. Nothing to upcast.")
        node_to_upcast_names = []
    else:
        if verbose:
            print("Skipping algorithm because upcast ratio equals 1.0. Upcasting all nodes of the given type(s).")
        node_to_upcast_names = nodes_to_track_names

    new_model = orig_model.clone()
    mark_nodes_to_upcast_to_fp32(new_model, node_to_upcast_names)
    return new_model


def get_nodes_to_track(model: Model, operation_types: List[str]) -> List:
    """
    Collect friendly names of operations of the given types that should be tracked.

    Operations whose output feeds a Result node directly are skipped.

    :param model: Model to scan.
    :param operation_types: Operation type names to consider (e.g. "MatMul").
    :return: List of friendly names of the operations to track.
    """
    nodes_to_track = []
    # Note: the original loop used enumerate() with an unused index
    for op in model.get_ordered_ops():
        if op.get_type_name() not in operation_types:
            continue
        # Skip ops that directly produce a model output
        if any(inp.get_node().get_type_name() == "Result" for inp in op.output(0).get_target_inputs()):
            continue
        nodes_to_track.append(op.get_friendly_name())
    return nodes_to_track


def insert_outputs_for_tracked_ops(model: Model, nodes_to_track: List[TrackedNodeInfo]) -> None:
    """
    Add model outputs for every tracked node and for each of its non-constant inputs, and
    record the created Result nodes on the corresponding TrackedNodeInfo objects
    (`result_node` for the node itself, `input_result_nodes` for its inputs).
    Also fills `node_info.input_nodes` for every tracked node.
    """
    # Output values to add, keyed by node; OrderedDict keeps insertion order aligned
    # with the list returned by model.add_outputs below
    node_to_output_map = OrderedDict()
    # For each node that gets an output, the TrackedNodeInfo entries interested in it,
    # tagged with whether the node is the tracked op itself ("parent") or one of its inputs ("child")
    node_to_node_info_map = defaultdict(list)
    for node_info in nodes_to_track:
        node = node_info.node
        node_to_node_info_map[node].append((node_info, "parent"))  # add as a parent node
        if node not in node_to_output_map:
            node_to_output_map[node] = node.output(0)
        node_info.input_nodes = []
        for inp_value in node.input_values():
            child_node = inp_value.get_node()
            node_info.input_nodes.append(child_node)
            # Do not add outputs for constant nodes
            if child_node.get_type_name() != "Constant" and not is_constant_path(child_node):
                node_to_node_info_map[child_node].append((node_info, "child"))  # add as a child node
                if child_node not in node_to_output_map:
                    node_to_output_map[child_node] = child_node.output(0)

    # add_outputs returns outputs in the same order as the values passed in,
    # so zipping with node_to_output_map.keys() pairs each Result with its source node
    outputs = model.add_outputs(list(node_to_output_map.values()))
    for output, node in zip(outputs, node_to_output_map.keys()):
        # Value matching will be done later based on result node friendly names
        result_node = output.node
        for node_info, parent_label in node_to_node_info_map[node]:
            is_parent = parent_label == "parent"
            if is_parent:
                node_info.result_node = result_node
            else:
                if node_info.input_result_nodes is None:
                    node_info.input_result_nodes = {}
                node_info.input_result_nodes[node] = result_node


def get_const_value_from_ovmodel(node: Union[Constant, Node]) -> np.ndarray:
    """
    Return the constant data held by `node`: either a Constant op directly, or the
    Constant feeding a decompression Convert (see is_constant_path).
    Raises if the node is neither.
    """
    type_name = node.get_type_name()
    if type_name == "Constant":
        assert node.get_element_type() not in [
            ov.Type.f16,
            ov.Type.bf16,
        ], f"{node.get_friendly_name()}, {node.get_element_type()}"
        return node.get_data()
    if is_constant_path(node):
        # If model is compressed and constant values flow through decompression convert
        source_const = node.input_value(0).get_node()
        assert source_const.get_type_name() == "Constant"
        assert source_const.get_element_type().is_real(), source_const.get_element_type()
        return source_const.get_data()  # return f16 weight
    raise Exception(f"Cannot get const values from ov.Model for {node.get_friendly_name()} with type {type_name}")


def is_constant_path(node: Node) -> bool:
    """
    Return True if `node` is a Convert that carries constant data: either a
    decompression Convert (per rt_info) or a Convert applied directly to a Constant.
    """
    if node.get_type_name() != "Convert":
        return False
    # NOTE(review): this indexes rt_info unconditionally for every Convert; a Convert
    # without an "is_decompression_0" entry would presumably raise here -- confirm
    # against OpenVINO rt_info semantics for Convert nodes.
    if len(node.get_rt_info()["is_decompression_0"].aslist()) > 0:
        return True
    if node.input_value(0).get_node().get_type_name() == "Constant":
        return True
    return False


def infer_full_net(nodes_to_track: List[TrackedNodeInfo], orig_model: Model, example_inputs: List) -> None:
    """
    Infer the output-augmented model once in full precision on CPU and store, for every
    tracked node, its own result value (`node_value_full_precision`) and the values of all
    its inputs (`input_values_full_precision`).

    :param nodes_to_track: Tracked nodes whose result/input Result nodes were set by
        insert_outputs_for_tracked_ops.
    :param orig_model: Model with the extra outputs already inserted.
    :param example_inputs: Example input passed to the inference request.
    """
    core = ov.Core()
    # Force f32 execution so the collected values serve as the full precision reference
    exec_net = core.compile_model(orig_model, "CPU", config={"INFERENCE_PRECISION_HINT": "f32"})
    request = exec_net.create_infer_request()
    results = request.infer(example_inputs, share_inputs=True, share_outputs=True)

    # Match inference results to nodes via Result-node friendly names
    # (original loop used enumerate() with an unused index)
    friendly_name_to_result_map = {key.node.get_friendly_name(): val for key, val in results.items()}

    for node_info in nodes_to_track:
        node_info.node_value_full_precision = friendly_name_to_result_map[node_info.result_node.get_friendly_name()]
        node_info.input_values_full_precision = []
        for input_node in node_info.input_nodes:
            if input_node.get_type_name() == "Constant" or is_constant_path(input_node):
                # If input is constant, retrieve its value from model
                input_value = get_const_value_from_ovmodel(input_node)
            else:
                # If input is not constant, retrieve its input from inference results
                input_value = friendly_name_to_result_map[node_info.input_result_nodes[input_node].get_friendly_name()]
            node_info.input_values_full_precision.append(input_value)


def infer_nodes(nodes_to_track: List[TrackedNodeInfo], device: str, precision: str) -> None:
    """Infer every tracked node individually in the given half precision on `device`."""
    for tracked_node in nodes_to_track:
        infer_tracked_op(tracked_node, device, precision)


def infer_tracked_op(node_info: TrackedNodeInfo, device: str, precision: str) -> None:
    """
    Rebuild a single tracked operation as a standalone one-op model, infer it in half
    precision with the previously collected full precision input values, and store the
    result in `node_info.node_value_half_precision`.

    :param node_info: Tracked node with input_values_full_precision already filled.
    :param device: Device to run the one-op model on.
    :param precision: Inference precision hint ("f16" or "bf16").
    """
    parameters = []
    inputs = []
    input_values = node_info.input_values_full_precision
    for input_value in input_values:
        parameter = Parameter(get_element_type(input_value.dtype), ov.PartialShape(input_value.shape))
        if input_value.dtype == np.float16:
            # Convert f16 weight to f32
            convert_node = opset.convert(parameter, "f32")
            inputs.append(convert_node)
        else:
            inputs.append(parameter)
        parameters.append(parameter)

    node = node_info.node
    try:
        call_attributes = node.get_attributes()
        # Below are some op workarounds
        if node.get_type_name() == "Divide" and "m_pythondiv" in call_attributes:
            # The python op factory does not accept this attribute
            del call_attributes["m_pythondiv"]
        if node.get_type_name() == "Broadcast" and "mode" in call_attributes:
            # The attribute is named "broadcast_spec" in the python op factory
            call_attributes["broadcast_spec"] = call_attributes["mode"]
            del call_attributes["mode"]
        if node.get_type_name() == "Concat":
            # Concat takes its inputs as a single list argument
            new_op = OPERATION_TYPE_MAP[node.get_type_name()](inputs, **call_attributes)
        else:
            new_op = OPERATION_TYPE_MAP[node.get_type_name()](*inputs, **call_attributes)

        ov_model = ov.Model([new_op], parameters=parameters)
        exec_net = ov.Core().compile_model(ov_model, device, config={"INFERENCE_PRECISION_HINT": precision})
        request = exec_net.create_infer_request()
        result = request.infer(input_values, share_inputs=True, share_outputs=True)
    except Exception as e:
        print(
            "Operation inference error",
            node.get_type_name(),
            node.get_friendly_name(),
            inputs,
            node.get_attributes(),
        )
        raise e

    # Check the single-output invariant before storing the value (previously the assert
    # ran only after the assignment, so a violation would surface one step too late)
    assert len(result) == 1
    node_info.node_value_half_precision = result[0]


def is_model_partially_upcasted(model) -> bool:
    """
    Return True if at least one operation of a considered type is already marked
    with the original-precision rt_info flag.
    """
    for node in model.get_ordered_ops():
        # Membership test directly on the dict; the redundant .keys() call was removed
        if node.get_type_name() not in OPERATION_TYPE_MAP:
            continue
        # .keys() is kept here: rt_info is an OpenVINO RTMap, not a plain dict
        if ORIGINAL_PRECISION_RT_INFO_NAME in node.get_rt_info().keys():
            return True
    return False


def mark_nodes_to_upcast_to_fp32(model: ov.Model, nodes_with_errors: List[str]) -> None:
    """
    Set the original-precision rt_info flag on every node of `model` whose friendly
    name appears in `nodes_with_errors`. Asserts that every requested name was found.
    """
    remaining = set(nodes_with_errors)
    for op in model.get_ordered_ops():
        name = op.get_friendly_name()
        if name in remaining:
            op.get_rt_info()[ORIGINAL_PRECISION_RT_INFO_NAME] = ""
            remaining.discard(name)
    assert len(remaining) == 0, remaining


def compute_snr(x, y):
    """
    Compute the Signal-to-Noise Ratio (in dB) between a reference value and a noisy value.

    x -- original value (full precision), y -- value with noise (half precision)

    :param x: Reference array (full precision result).
    :param y: Noisy array (half precision result); must have the same element count as x.
    :return: SNR in dB as a float32 scalar; float32 max when the noise is exactly zero.
    :raises RuntimeError: If x and y have different element counts.
    """
    x, y = x.astype(np.float32), y.astype(np.float32)
    max_value = np.finfo(np.float32).max

    # Only the element counts are compared, not the exact shapes
    if np.prod(x.shape) != np.prod(y.shape):
        raise RuntimeError(f"Shape mismatch: {x.shape}, {y.shape}.")

    # Clamp NaN/inf so the norms below stay finite
    x = np.nan_to_num(x, posinf=max_value)
    y = np.nan_to_num(y, posinf=max_value)

    Ps = np.linalg.norm(x)  # signal power
    Pn = np.nan_to_num(np.linalg.norm(x - y), posinf=max_value)  # noise power

    # Zero noise means the half precision result is exact: the SNR is maximal.
    # Returning early also avoids the divide-by-zero RuntimeWarning that Ps / Pn
    # used to emit before nan_to_num clamped the resulting inf to max_value.
    if Pn == 0.0:
        return max_value

    snr = np.nan_to_num(20 * np.log10(Ps / Pn), posinf=max_value)

    return snr