import operator
import types

import torch
from torch._export import capture_pre_autograd_graph
from torch.fx import (
    GraphModule,
    Node,
)
from torch.nn.utils.fusion import fuse_conv_bn_weights
from typing import Any, Callable, Dict, Optional, Tuple, List, Union
from torch.utils._pytree import LeafSpec
from torch.export.unflatten import _AttrKind, _assign_attr

# Makes sure that quantized_decomposed ops are registered
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

from torch.ao.quantization.quantizer import QuantizationAnnotation


__all__ = [
    "fold_bn_weights_into_conv_node",
    "get_aten_graph_module",
    "remove_tensor_overload_for_qdq_ops",
]

_QUANTIZE_OPS = [
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
]


_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]

# Example inputs for conv-bn1d patterns
_conv1d_bn_example_inputs = (
    torch.randn(1, 1, 3),  # x
    torch.randn(1, 1, 1),  # conv_weight
    torch.randn(1),        # conv_bias
    torch.randn(1),        # bn_weight
    torch.randn(1),        # bn_bias
    torch.randn(1),        # bn_running_mean
    torch.randn(1),        # bn_running_var
)

# Example inputs for conv-bn2d patterns
_conv2d_bn_example_inputs = (
    torch.randn(1, 1, 3, 3),  # x
    torch.randn(1, 1, 1, 1),  # conv_weight
    torch.randn(1),           # conv_bias
    torch.randn(1),           # bn_weight
    torch.randn(1),           # bn_bias
    torch.randn(1),           # bn_running_mean
    torch.randn(1),           # bn_running_var
)

def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
    """

    Assuming dest is one of the ops inserted by quant workflow, this function

    finds if source and dest are connected. Assumption is that only quant workflow

    inserted ops exist between source and dest

    """
    quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS
    quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor)
    while dest.target in quant_workflow_ops:
        if not isinstance(dest.args[0], torch.fx.Node):
            raise ValueError(f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}")
        dest = dest.args[0]
    return (dest == source)


def _find_q_dq_node_for_user(
    producer: torch.fx.Node, user: torch.fx.Node
) -> Tuple[Any, Any]:
    """
    Find the (q, dq) pair corresponding to [producer -> q -> dq -> user].
    This works by finding the dq arg of `user` and ensuring it is connected to
    `producer`.
    """
    dq_node = None
    for n in user.args:
        if isinstance(n, torch.fx.Node) and n.op == "call_function" and n.target in _DEQUANTIZE_OPS:
            if _is_connected(producer, n):
                dq_node = n
                break
    if dq_node is None:
        for n in user.kwargs.values():
            if isinstance(n, torch.fx.Node) and n.op == "call_function" and n.target in _DEQUANTIZE_OPS:
                if _is_connected(producer, n):
                    dq_node = n
                    break
    if dq_node is None:
        return (None, None)

    q_node = None
    if dq_node.args[0].op == "call_function" and dq_node.args[0].target in _QUANTIZE_OPS:
        q_node = dq_node.args[0]
    return (q_node, dq_node)
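
# Illustrative sketch (not executed; node names below are hypothetical). For a graph produced
# by the quant workflow that looks like
#
#   producer -> quantize_per_tensor -> dequantize_per_tensor -> user
#
# `_is_connected(producer, dq)` walks backwards from the dq node through q/dq/choose_qparams
# nodes via args[0] until it reaches a non-quant node and compares that node with `producer`,
# and `_find_q_dq_node_for_user(producer, user)` returns the matching (q, dq) pair:
#
#   q_node, dq_node = _find_q_dq_node_for_user(producer, user)
#   if dq_node is not None:
#       ...  # the q/dq pair between `producer` and `user` can now be rewritten or annotated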



def _is_sym_size_node(node: Node):
    return node.op == "call_function" and node.target in (
        torch.ops.aten.sym_size.default,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten.sym_numel,
        torch.ops.aten.sym_size,
    )


def _filter_sym_size_users(node: torch.fx.Node) -> List[torch.fx.Node]:
    return [user for user in node.users if not _is_sym_size_node(user)]


def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool:
    if annotation is None:
        return False
    input_qspec_map = annotation.input_qspec_map
    output_qspec = annotation.output_qspec
    if len(input_qspec_map) == 0 and output_qspec is None:
        return False
    return True


def _get_tensor_constant_from_node(node, m):
    if node is None:
        return None
    assert node.op == "get_attr"
    target_atoms = node.target.split('.')
    attr_itr = m
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr

def _get_all_arguments(orig_args, orig_kwargs, args_schema):
    all_args = []
    for i, schema in enumerate(args_schema):
        if schema.name in orig_kwargs:
            all_args.append(orig_kwargs[schema.name])
        elif not schema.kwarg_only and i < len(orig_args):
            all_args.append(orig_args[i])
        else:
            all_args.append(schema.default_value)
    return all_args
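
# Minimal sketch of how `_get_all_arguments` normalizes a call site against an op schema.
# The op is real, but the node and its argument values are hypothetical:
#
#   schema = torch.ops.aten._native_batch_norm_legit_no_training.default._schema
#   # node.args = (x, weight, bias), node.kwargs = {"eps": 1e-5}
#   all_args = _get_all_arguments(node.args, node.kwargs, schema.arguments)
#   # positional args are taken by index, kwargs by name, and anything not provided falls
#   # back to the schema default, so `all_args` lines up with
#   # (input, weight, bias, running_mean, running_var, momentum, eps)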

def _is_supported_batch_norm_for_training(node: Node):
    """

    Return True if the given node refers to an aten batch norm op QAT supports.

    """
    supported_ops = [
        torch.ops.aten._native_batch_norm_legit.default,
        # Note: we won't need this op anymore after batch norm consolidation
        # For now, we need to continue to support it because it gives better
        # training numerics than `_native_batch_norm_legit`
        torch.ops.aten.cudnn_batch_norm.default,
        torch.ops.aten.miopen_batch_norm.default,
    ]
    return node.target in supported_ops

# TODO: rename this to _is_conv_node
def _is_conv(n: Node):
    """

    Return whether the node refers to an aten conv op.

    """
    return n.op == "call_function" and n.target in [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
    ]

# TODO: rename this to _is_conv_transpose_node
def _is_conv_transpose(n: Node):
    """

    Return whether the node refers to an aten conv_transpose op.

    """
    return n.op == "call_function" and n.target in [
        torch.ops.aten.conv_transpose1d,
        torch.ops.aten.conv_transpose2d,
    ]

def _is_bn_node(n: Node):
    return (
        _is_supported_batch_norm_for_training(n)
        or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
    )

def fold_bn_weights_into_conv_node(
    conv_node: Node,
    conv_weight_node: Node,
    conv_bias_node: Optional[Node],
    bn_node: Node,
    m: GraphModule,
) -> None:
    # conv args: input, weight, bias, stride, padding, dilation, ...
    conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
    conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
    transpose = _is_conv_transpose(conv_node)

    # eval bn args: input, weight, bias, running mean, running var, momentum, eps
    # train bn args: input, weight, bias, running mean, running var, training, momentum, eps
    bn_args_schema = bn_node.target._schema.arguments  # type: ignore[union-attr]
    bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema)
    bn_w = _get_tensor_constant_from_node(bn_args[1], m)
    bn_b = _get_tensor_constant_from_node(bn_args[2], m)
    bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
    bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
    if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
        eps_arg_index = 6
    elif _is_supported_batch_norm_for_training(bn_node):
        eps_arg_index = 7
    else:
        raise ValueError(f"BN node target is unexpected: {bn_node.target}")
    bn_eps = bn_args[eps_arg_index]

    fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose)

    # update the weight and bias for conv
    conv_args = list(conv_node.args)
    # filling in the default bias argument
    if len(conv_args) == 2:
        conv_args.append(None)

    # calling data since the fused_weight and fused_bias are nn.Parameter
    weight_attr_name = conv_weight_node.target
    assert isinstance(weight_attr_name, str)
    _assign_attr(fused_weight, m, weight_attr_name, _AttrKind.PARAMETER)
    if conv_bias_node is not None:
        bias_attr_name = conv_bias_node.target
        _assign_attr(fused_bias, m, str(bias_attr_name), _AttrKind.PARAMETER)
    else:
        bias_attr_name = weight_attr_name + "_bias"
        _assign_attr(fused_bias, m, bias_attr_name, _AttrKind.PARAMETER)
        with m.graph.inserting_before(conv_node):
            get_bias_node = m.graph.get_attr(bias_attr_name)
        # NOTE: here we assume the bias of conv is not quantized!
        conv_args[2] = get_bias_node
    conv_node.args = tuple(conv_args)

    # native_batch_norm has 3 outputs, we expect getitem calls on the output
    # and we want to replace the uses of getitem 0 with the output of conv
    #
    # Before:
    # conv -> bn - (first output) -> users1
    #          \ - (second output) -> users2
    #          \ - (third output) -> users3
    # After:
    # conv -> (first output) -> users1
    #       bn -
    #          \ - (second output) -> users2
    #          \ - (third output) -> users3
    # if users2 and users3 are empty then bn will be removed through dead code elimination

    for user in bn_node.users:
        if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0:
            continue
        user.replace_all_uses_with(conv_node)

# fuse conv bn weights, inplace modification of the graph_module and graph
def _fuse_conv_bn_(m: GraphModule) -> None:
    has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
    if not has_bn:
        return
    for n in m.graph.nodes:
        if n.op != "call_function" or n.target != torch.ops.aten._native_batch_norm_legit_no_training.default:
            continue
        bn_node = n
        n = bn_node.args[0]
        if not _is_conv(n):
            continue
        conv_node = n
        conv_weight_node = conv_node.args[1]
        conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
        fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m)

    m.graph.eliminate_dead_code()
    m.recompile()
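
# A minimal usage sketch for the fusion above, assuming a hypothetical user module `ConvBN`
# that runs a conv followed by batch norm; capturing it in eval mode produces the
# `_native_batch_norm_legit_no_training` nodes this pass looks for:
#
#   example_inputs = (torch.randn(1, 3, 224, 224),)
#   gm = capture_pre_autograd_graph(ConvBN().eval(), example_inputs)
#   _fuse_conv_bn_(gm)  # in-place: conv weight/bias now absorb the bn statistics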

def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
    # TODO: move this information to fx node itself
    node_name_to_scope: Dict[str, Tuple[str, type]] = {}
    for n in model.graph.nodes:
        nn_module_stack = n.meta.get("nn_module_stack", None)
        current_scope = ("", type(None))
        if nn_module_stack:
            bt = list(nn_module_stack.values())[-1]
            current_scope = (bt[0].split(".")[-1], bt[1])
        node_name_to_scope[n.name] = current_scope
    return node_name_to_scope
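
# For illustration only (the metadata strings are assumptions): if the last entry of a node's
# `nn_module_stack` meta is ("features.0.conv", torch.nn.Conv2d), the recorded scope is
# ("conv", torch.nn.Conv2d); nodes without the meta map to ("", type(None)).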

def get_aten_graph_module(
    pattern: Callable,
    example_inputs: Tuple[Any, ...],
    is_cuda: bool = False,
    **kwargs,
) -> GraphModule:
    """
    Convert the pattern to an FX graph with decomposed aten ops.
    """
    if is_cuda:
        example_inputs = tuple([x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs])
    aten_pattern = capture_pre_autograd_graph(
        pattern,
        example_inputs,
        kwargs,
    )
    aten_pattern.graph.eliminate_dead_code()
    aten_pattern.recompile()
    return aten_pattern
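
# Minimal usage sketch, reusing the example inputs defined at the top of this file;
# `_conv_bn_pattern` is a hypothetical pattern function whose signature matches those inputs:
#
#   def _conv_bn_pattern(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
#       x = torch.nn.functional.conv2d(x, conv_weight, conv_bias)
#       return torch.nn.functional.batch_norm(x, bn_rm, bn_rv, bn_weight, bn_bias, training=True)
#
#   pattern_gm = get_aten_graph_module(_conv_bn_pattern, _conv2d_bn_example_inputs)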

def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
    """ Remove .tensor overload for quantize/dequantize ops so that we can

    use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e

    """
    _MAP = {
        torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel,
        torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel,
        torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp,
    }
    for n in match_pattern.graph.nodes:
        if n.op != "call_function":
            continue
        if n.target in _MAP:
            n.target = _MAP[n.target]
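
# Effect sketch (hypothetical `match_pattern_gm`): after this pass, a call_function node whose
# target was the specific overload `torch.ops.quantized_decomposed.quantize_per_tensor.default`
# points at the overload packet `torch.ops.quantized_decomposed.quantize_per_tensor` instead,
# so the same pattern can match graphs that were exported with different overloads:
#
#   remove_tensor_overload_for_qdq_ops(match_pattern_gm)  # mutates node targets in place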

def _is_literal(arg):
    if isinstance(arg, (int, float)):
        return True
    if isinstance(arg, (tuple, list)):
        return all(map(_is_literal, arg))
    return False

def _replace_literals_with_new_placeholders(
    gm: torch.fx.GraphModule,
    merge_dup: bool = False,
    exclude_literals: Optional[List[Any]] = None,
):
    """Replace the literals in the graph with placeholder nodes that's created on the fly while we

    traverse the graph, so that the literal arguments in the graph can be matched and replaced



    To use this, the pattern and replacement graph should have the exact same number of literal args

    and they should be used in the exact same order in the pattern and replacement graph.



    If the literal arguments are not used in the same order in pattern and replacement graph, please

    use `_replace_literals_with_existing_placeholders` instead



    Args:

        `gm`: input GraphModule that we'll transform

        `merge_dup`: boolean flag to indicate that if the same literal appears multiple times in

         the graph, whether they should correspond to the same placeholder or not

        `exclude_literals`: a list of literals that will not be replaced with placeholders



    Example:



    # 1. Original Graph

    def pattern(self, x):

        return x + 3



    def replacement(self, x):

        return x - 3



    example_inputs = (torch.randn(1, 3, 3, 3),)

    pattern_gm = get_aten_graph_module(pattern, example_inputs)

    replacement_gm = get_aten_graph_module(pattern, example_inptus)



    # 2. Before calling replace literals we'll see the following graph:

    def pattern(self, x):

        return x + 3



    def replacement(self, x):

        return x - 3



    pattern_gm = _replace_literals_with_new_placeholders(pattern_gm)

    replacement_gm = _replace_literals_with_new_placeholders(replacement_gm)



    # 3. After replacing literals with new placeholder nodes



    def pattern(self, x, new_ph):

        return x + new_ph



    def pattern(self, x, new_ph):

        return x - new_ph



    """
    last_ph = None
    cnt = 0
    literal_to_ph: Dict[Union[float, bool, int, torch.dtype], Node] = {}
    if exclude_literals is None:
        exclude_literals = []

    in_spec = gm._in_spec
    args_spec = in_spec.children_specs[0]
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            last_ph = node
            cnt += 1
            continue
        with gm.graph.inserting_after(last_ph):
            new_args = []
            for arg in node.args:
                if _is_literal(arg) and arg not in exclude_literals:
                    if merge_dup and arg in literal_to_ph:
                        new_args.append(literal_to_ph[arg])
                    else:
                        ph_node = gm.graph.placeholder("arg" + str(cnt))
                        new_args.append(ph_node)
                        args_spec.children_specs.append(LeafSpec())
                        cnt += 1
                        if merge_dup:
                            literal_to_ph[arg] = ph_node
                else:
                    new_args.append(arg)
            new_args = tuple(new_args)

        node.args = new_args

    # Update `num_nodes`, `num_leaves`, `num_children`.
    args_spec.__post_init__()
    in_spec.__post_init__()
    return gm


def _replace_literals_with_existing_placeholders(
    gm: torch.fx.GraphModule,
    exclude_literals: Optional[List[Any]] = None,
    literal_to_ph_idx: Optional[Dict[Union[float, int, bool, torch.dtype], int]] = None,
):
    """Replace the literals in the graph with **existing** placeholder nodes, so that the literal arguments

    in the graph can be matched and replaced



    To use this, all literal args in the graph should be unique and each of them should correspond

    to exactly one placeholder node



    # 1. Original Graph

    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):

        return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)



    def replacement(x_i8, scale, zero_point, quant_min, quant_max):

        x_i8 = torch.clamp(x_i8, quant_min, quant_max)

        return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)



    example_inputs = (

        torch.randn(1, 3, 3, 3),

        1.0,

        0,

        -128,

        127,

    )

    pattern_gm = get_aten_graph_module(pattern, example_inputs)

    replacement_gm = get_aten_graph_module(pattern, example_inptus)



    # 2. Before calling replace literals we'll see the following graph:

    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):

        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values

        return torch.dequantize_per_tensor(x_i8, 1.0, 0, -128, 127)



    def replacement(x_i8, scale, zero_point, quant_min, quant_max):

        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values

        x_i8 = torch.clamp(x_i8, -128, 127)

        return ((x_i8.to(torch.float32) - 0) * 1.0).to(dtype=torch.float32)



    # Note that literal args appear in different order in pattern and replacement graph, so

    # we can't use _replace_literals_with_new_placeholders



    literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4}

    pattern_gm = _replace_literals_with_existing_placeholders(pattern_gm, literal_to_ph_idx)

    replacement_gm = _replace_literals_with_existing_placeholders(replacement_gm, literal_to_ph_idx)



    # 3. After replacing literals with existing placeholder nodes



    def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):

        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values

        return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)



    def replacement(x_i8, scale, zero_point, quant_min, quant_max):

        # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values

        x_i8 = torch.clamp(x_i8, quant_min, quant_max)

        return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)

    """
    if exclude_literals is None:
        exclude_literals = []

    if literal_to_ph_idx is None:
        literal_to_ph_idx = {}

    phs = [node for node in gm.graph.nodes if node.op == "placeholder"]

    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        new_args = []
        for arg in node.args:
            if _is_literal(arg) and arg not in exclude_literals and arg in literal_to_ph_idx:
                ph_idx = literal_to_ph_idx[arg]
                ph_node = phs[ph_idx]
                new_args.append(ph_node)
            else:
                new_args.append(arg)
        new_args = tuple(new_args)
        node.args = new_args
    return gm

# TODO: Handle this in export itself and don't wrap the model in another GraphModule
# in prepare and convert
def _disallow_eval_train(model: GraphModule):
    """

    Disallow calling `model.train()` or `model.eval()` on the given GraphModule.

    This is useful for exported models, where these methods don't actually behave as expected.

    """
    error_message = \
        """

        Calling train() or eval() is not supported for exported models.

        Please call `torch.ao.quantization.move_exported_model_to_train(model)` (or eval) instead.



        If you cannot replace the calls to `model.train()` and `model.eval()`, you may override

        the behavior for these methods by calling `torch.ao.quantization.allow_exported_model_train_eval(model)`,

        which does the above automatically for you. Note that this has limited effect on switching

        behavior between train and eval modes, and should be used only for special ops such as dropout

        and batchnorm.

        """

    def _train(self, mode: bool = True):
        raise NotImplementedError(error_message)

    def _eval(self, mode: bool = True):
        raise NotImplementedError(error_message)

    model.train = types.MethodType(_train, model)  # type: ignore[method-assign]
    model.eval = types.MethodType(_eval, model)  # type: ignore[method-assign]
    return model
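
# Usage sketch (hypothetical `exported_gm`): after wrapping, both methods raise with the
# message above, steering users to the supported helpers instead:
#
#   exported_gm = _disallow_eval_train(exported_gm)
#   exported_gm.eval()   # raises NotImplementedError
#   exported_gm.train()  # raises NotImplementedError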