File size: 26,885 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0


from typing import Dict, List, MutableMapping, Optional, Set, Tuple

from onnx import GraphProto, ModelProto, TensorProto, checker, helper, utils


def check_overlapping_names(
    g1: GraphProto, g2: GraphProto, io_map: Optional[List[Tuple[str, str]]] = None
) -> List[Tuple[str, List[str]]]:
    """Checks whether there are name collisions between two graphs.

    Returns a list of tuples where the first element represents the member containing overlapping names
    (One of: "node", "edge", "value_info", "initializer", "sparse_initializer"), and the
    second element contains a list of names that appear in both graphs on that category.

    Optionally, it takes an io_map, representing the outputs/inputs to be connected.
    If provided, overlapping names present in the io_map argument will be ignored.
    """
    if type(g1) is not GraphProto:
        raise ValueError("g1 argument is not an ONNX graph")
    if type(g2) is not GraphProto:
        raise ValueError("g2 argument is not an ONNX graph")

    def _overlapping(c1: List[str], c2: List[str]) -> List[str]:
        # Names appearing in both collections; order is unspecified (set-based).
        return list(set(c1) & set(c2))

    def _edge_names(graph: GraphProto, exclude: Optional[Set[str]] = None) -> List[str]:
        # All non-empty node input/output names, minus the excluded set.
        if exclude is None:
            exclude = set()
        edges = []
        for n in graph.node:
            for i in n.input:
                if i != "" and i not in exclude:
                    edges.append(i)
            for o in n.output:
                if o != "" and o not in exclude:
                    edges.append(o)
        return edges

    result = []

    if not io_map:
        io_map = []
    # g2 inputs listed in io_map are meant to be connected to g1 outputs,
    # so a collision on those names is expected and ignored.
    io_map_inputs = {elem[1] for elem in io_map}

    # Edges already cover input/output
    overlap = _overlapping(_edge_names(g1), _edge_names(g2, exclude=io_map_inputs))
    if overlap:
        result.append(("edge", overlap))

    overlap = _overlapping(
        [e.name for e in g1.value_info], [e.name for e in g2.value_info]
    )
    if overlap:
        result.append(("value_info", overlap))

    overlap = _overlapping(
        [e.name for e in g1.initializer], [e.name for e in g2.initializer]
    )
    if overlap:
        result.append(("initializer", overlap))

    # Sparse initializers carry two names each (values and indices tensors).
    overlap = _overlapping(
        [e.values.name for e in g1.sparse_initializer],
        [e.values.name for e in g2.sparse_initializer],
    ) + _overlapping(
        [e.indices.name for e in g1.sparse_initializer],
        [e.indices.name for e in g2.sparse_initializer],
    )
    if overlap:
        result.append(("sparse_initializer", overlap))

    return result


def merge_graphs(
    g1: GraphProto,
    g2: GraphProto,
    io_map: List[Tuple[str, str]],
    inputs: Optional[List[str]] = None,
    outputs: Optional[List[str]] = None,
    prefix1: Optional[str] = None,
    prefix2: Optional[str] = None,
    name: Optional[str] = None,
    doc_string: Optional[str] = None,
) -> GraphProto:
    """Combines two ONNX graphs into a single one.

    The combined graph is defined by connecting the specified set of outputs/inputs. Those inputs/outputs
    not specified in the io_map argument will remain as inputs/outputs of the combined graph.

    Arguments:
        g1 (GraphProto): First graph
        g2 (GraphProto): Second graph
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
                                          representing outputs of the first graph and inputs of the second
                                          to be connected
        inputs (list of string): Optional list of inputs to be included in the combined graph
                                 By default, all inputs not present in the ``io_map`` argument will be
                                 included in the combined model
        outputs (list of string): Optional list of outputs to be included in the combined graph
                                  By default, all outputs not present in the ``io_map`` argument will be
                                  included in the combined model
        prefix1 (string): Optional prefix to be added to all names in g1
        prefix2 (string): Optional prefix to be added to all names in g2
        name (string): Optional name for the combined graph
                       By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
                             If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used

    Returns:
        GraphProto

    Raises:
        ValueError: If an io_map name is not a valid g1 output / g2 input, or if
                    the two graphs have overlapping names outside the io_map.
    """
    if type(g1) is not GraphProto:
        raise ValueError("g1 argument is not an ONNX graph")
    if type(g2) is not GraphProto:
        raise ValueError("g2 argument is not an ONNX graph")

    # Prefixing names in the graph if requested, adjusting io_map accordingly
    if prefix1 or prefix2:
        if prefix1:
            # Copy first so the caller's graph is not mutated.
            g1_copy = GraphProto()
            g1_copy.CopyFrom(g1)
            g1 = g1_copy
            g1 = add_prefix_graph(g1, prefix=prefix1)
        if prefix2:
            g2_copy = GraphProto()
            g2_copy.CopyFrom(g2)
            g2 = g2_copy
            g2 = add_prefix_graph(g2, prefix=prefix2)
        io_map = [
            (
                prefix1 + io[0] if prefix1 else io[0],
                prefix2 + io[1] if prefix2 else io[1],
            )
            for io in io_map
        ]

    io_map_g1_outs = {io[0] for io in io_map}
    io_map_g2_ins = {io[1] for io in io_map}
    reversed_io_map = {in_name: out_name for out_name, in_name in io_map}
    g1_outs = {o.name for o in g1.output}
    g2_ins = {i.name for i in g2.input}

    # If necessary extract subgraphs
    if inputs or outputs:
        if not inputs:
            g1_inputs = [i.name for i in g1.input]
            g2_inputs = [i.name for i in g2.input]
        else:
            input_set = set(inputs)
            g1_inputs = [i.name for i in g1.input if i.name in input_set]
            # io_map-connected inputs must be kept even if not requested,
            # since they are fed by g1's outputs.
            g2_inputs = [
                i.name
                for i in g2.input
                if i.name in input_set or i.name in io_map_g2_ins
            ]

        if not outputs:
            g1_outputs = [o.name for o in g1.output]
            g2_outputs = [o.name for o in g2.output]
        else:
            output_set = set(outputs)
            # io_map-connected outputs must be kept even if not requested,
            # since they feed g2's inputs.
            g1_outputs = [
                o.name
                for o in g1.output
                if o.name in output_set or o.name in io_map_g1_outs
            ]
            g2_outputs = [o.name for o in g2.output if o.name in output_set]

        # Only run the extractor when something was actually dropped.
        if len(g1_inputs) < len(g1.input) or len(g1_outputs) < len(g1.output):
            e1 = utils.Extractor(helper.make_model(g1))
            g1 = e1.extract_model(g1_inputs, g1_outputs).graph

        if len(g2_inputs) < len(g2.input) or len(g2_outputs) < len(g2.output):
            e2 = utils.Extractor(helper.make_model(g2))
            g2 = e2.extract_model(g2_inputs, g2_outputs).graph

    # Check that input/output names specified in the io_map argument are valid input/output names
    for g1_out_name, g2_in_name in io_map:
        if g1_out_name not in g1_outs:
            raise ValueError(f"Output {g1_out_name} is not present in g1")
        if g2_in_name not in g2_ins:
            raise ValueError(f"Input {g2_in_name} is not present in g2")

    # Check for name collision
    overlapping_names = check_overlapping_names(g1, g2, io_map)
    if len(overlapping_names) > 0:
        category, names = overlapping_names[0]
        raise ValueError(
            "Can't merge two graphs with overlapping names. "
            f"Found repeated {category} names: "
            + ", ".join(names)
            + "\n"
            + "Consider using ``onnx.compose.add_prefix`` to add a prefix to names in one of the graphs."
        )

    g = GraphProto()

    g.node.extend(g1.node)
    g2_nodes_begin = len(g.node)
    g.node.extend(g2.node)
    g2_nodes_end = len(g.node)

    # Connecting outputs of the first graph with the inputs of the second
    for node_idx in range(g2_nodes_begin, g2_nodes_end):
        node = g.node[node_idx]
        for index, name_ in enumerate(node.input):
            if name_ in reversed_io_map:
                node.input[index] = reversed_io_map[name_]

    if inputs:
        input_set = set(inputs)
        g.input.extend([i for i in g1.input if i.name in input_set])
        g.input.extend([i for i in g2.input if i.name in input_set])
    else:
        g.input.extend(g1.input)
        # Connected g2 inputs are now internal edges, not graph inputs.
        g.input.extend([i for i in g2.input if i.name not in io_map_g2_ins])

    if outputs:
        output_set = set(outputs)
        g.output.extend([o for o in g1.output if o.name in output_set])
        g.output.extend([o for o in g2.output if o.name in output_set])
    else:
        # Connected g1 outputs are now internal edges, not graph outputs.
        g.output.extend([o for o in g1.output if o.name not in io_map_g1_outs])
        g.output.extend(g2.output)

    g.initializer.extend(g1.initializer)
    g.initializer.extend(
        [init for init in g2.initializer if init.name not in io_map_g2_ins]
    )

    g.sparse_initializer.extend(g1.sparse_initializer)
    g.sparse_initializer.extend(
        [
            init
            for init in g2.sparse_initializer
            if init.values.name not in io_map_g2_ins
        ]
    )

    g.value_info.extend(g1.value_info)
    g.value_info.extend([vi for vi in g2.value_info if vi.name not in io_map_g2_ins])

    g.name = name if name is not None else "_".join([g1.name, g2.name])

    if doc_string is None:
        doc_string = (
            f"Graph combining {g1.name} and {g2.name}\n"
            + g1.name
            + "\n\n"
            + g1.doc_string
            + "\n\n"
            + g2.name
            + "\n\n"
            + g2.doc_string
        )
    g.doc_string = doc_string

    return g


def merge_models(
    m1: ModelProto,
    m2: ModelProto,
    io_map: List[Tuple[str, str]],
    inputs: Optional[List[str]] = None,
    outputs: Optional[List[str]] = None,
    prefix1: Optional[str] = None,
    prefix2: Optional[str] = None,
    name: Optional[str] = None,
    doc_string: Optional[str] = None,
    producer_name: Optional[str] = "onnx.compose.merge_models",
    producer_version: Optional[str] = "1.0",
    domain: Optional[str] = "",
    model_version: Optional[int] = 1,
) -> ModelProto:
    """Combines two ONNX models into a single one.

    The combined model is defined by connecting the specified set of outputs/inputs.
    Those inputs/outputs not specified in the io_map argument will remain as
    inputs/outputs of the combined model.

    Both models should have the same IR version, and same operator sets imported.

    Arguments:
        m1 (ModelProto): First model
        m2 (ModelProto): Second model
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
                                          representing outputs of the first graph and inputs of the second
                                          to be connected
        inputs (list of string): Optional list of inputs to be included in the combined graph
                                 By default, all inputs not present in the ``io_map`` argument will be
                                 included in the combined model
        outputs (list of string): Optional list of outputs to be included in the combined graph
                                  By default, all outputs not present in the ``io_map`` argument will be
                                  included in the combined model
        prefix1 (string): Optional prefix to be added to all names in m1
        prefix2 (string): Optional prefix to be added to all names in m2
        name (string): Optional name for the combined graph
                       By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
                             If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used
        producer_name (string): Optional producer name for the combined model.
                                Default: "onnx.compose.merge_models"
        producer_version (string): Optional producer version for the combined model. Default: "1.0"
        domain (string): Optional domain of the combined model. Default: ""
        model_version (int): Optional version of the graph encoded. Default: 1

    Returns:
        ModelProto

    Raises:
        ValueError: On IR version mismatch, conflicting operator set versions for the
                    same domain, conflicting metadata property values, or overlapping
                    local function names.
    """
    if type(m1) is not ModelProto:
        raise ValueError("m1 argument is not an ONNX model")
    if type(m2) is not ModelProto:
        raise ValueError("m2 argument is not an ONNX model")

    if m1.ir_version != m2.ir_version:
        raise ValueError(
            f"IR version mismatch {m1.ir_version} != {m2.ir_version}."
            " Both models should have the same IR version"
        )
    ir_version = m1.ir_version

    # Union the opset imports; the same domain must map to a single version.
    opset_import_map: MutableMapping[str, int] = {}
    opset_imports = list(m1.opset_import) + list(m2.opset_import)

    for entry in opset_imports:
        if entry.domain in opset_import_map:
            found_version = opset_import_map[entry.domain]
            if entry.version != found_version:
                raise ValueError(
                    "Can't merge two models with different operator set ids for a given domain. "
                    f"Got: {m1.opset_import} and {m2.opset_import}"
                )
        else:
            opset_import_map[entry.domain] = entry.version

    # Prefixing names in the graph if requested, adjusting io_map accordingly
    if prefix1 or prefix2:
        if prefix1:
            # Copy first so the caller's model is not mutated.
            m1_copy = ModelProto()
            m1_copy.CopyFrom(m1)
            m1 = m1_copy
            m1 = add_prefix(m1, prefix=prefix1)
        if prefix2:
            m2_copy = ModelProto()
            m2_copy.CopyFrom(m2)
            m2 = m2_copy
            m2 = add_prefix(m2, prefix=prefix2)
        # Prefixes are applied here, so they are NOT forwarded to merge_graphs.
        io_map = [
            (
                prefix1 + io[0] if prefix1 else io[0],
                prefix2 + io[1] if prefix2 else io[1],
            )
            for io in io_map
        ]

    graph = merge_graphs(
        m1.graph,
        m2.graph,
        io_map,
        inputs=inputs,
        outputs=outputs,
        name=name,
        doc_string=doc_string,
    )
    model = helper.make_model(
        graph,
        producer_name=producer_name,
        producer_version=producer_version,
        domain=domain,
        model_version=model_version,
        opset_imports=opset_imports,
        ir_version=ir_version,
    )

    # Merging model metadata props
    model_props = {}
    for meta_entry in m1.metadata_props:
        model_props[meta_entry.key] = meta_entry.value
    for meta_entry in m2.metadata_props:
        if meta_entry.key in model_props:
            value = model_props[meta_entry.key]
            if value != meta_entry.value:
                raise ValueError(
                    "Can't merge models with different values for the same model metadata property."
                    f" Found: property = {meta_entry.key}, with values {value} and {meta_entry.value}."
                )
        else:
            model_props[meta_entry.key] = meta_entry.value
    helper.set_model_props(model, model_props)

    # Merging functions
    function_overlap = list(
        {f.name for f in m1.functions} & {f.name for f in m2.functions}
    )
    if function_overlap:
        raise ValueError(
            "Can't merge models with overlapping local function names."
            " Found in both graphs: " + ", ".join(function_overlap)
        )
    model.functions.MergeFrom(m1.functions)
    model.functions.MergeFrom(m2.functions)

    checker.check_model(model)
    return model


def add_prefix_graph(
    graph: GraphProto,
    prefix: str,
    rename_nodes: Optional[bool] = True,
    rename_edges: Optional[bool] = True,
    rename_inputs: Optional[bool] = True,
    rename_outputs: Optional[bool] = True,
    rename_initializers: Optional[bool] = True,
    rename_value_infos: Optional[bool] = True,
    inplace: Optional[bool] = False,
    name_map: Optional[Dict[str, str]] = None,
) -> GraphProto:
    """Adds a prefix to names of elements in a graph: nodes, edges, inputs, outputs,
    initializers, sparse initializer, value infos.

    It can be used as a utility before merging graphs that have overlapping names.
    Empty names are not prefixed.

    The rename is done in two passes: first the complete old-name -> new-name map is
    collected, then it is applied to every element, so that renames are consistent
    across the whole graph (and across subgraphs, via the shared ``name_map``).

    Arguments:
        graph (GraphProto): Graph
        prefix (str): Prefix to be added to each name in the graph
        rename_nodes (bool): Whether to prefix node names
        rename_edges (bool): Whether to prefix node edge names
        rename_inputs (bool): Whether to prefix input names
        rename_outputs (bool): Whether to prefix output names
        rename_initializers (bool): Whether to prefix initializer and sparse initializer names
        rename_value_infos (bool): Whether to prefix value info names
        inplace (bool): If True, mutates the graph directly.
                        Otherwise, a copy will be created
        name_map: (Dict): shared name_map in subgraph

    Returns:
        GraphProto
    """
    if type(graph) is not GraphProto:
        raise ValueError("graph argument is not an ONNX graph")

    if not inplace:
        g = GraphProto()
        g.CopyFrom(graph)
    else:
        g = graph

    def _prefixed(prefix: str, name: str) -> str:
        # Leave empty names (optional inputs/outputs) untouched.
        return prefix + name if len(name) > 0 else name

    if name_map is None:
        name_map = {}
    # --- Pass 1: collect the old-name -> new-name mapping. ---
    if rename_edges:
        for n in g.node:
            for e in n.input:
                name_map[e] = _prefixed(prefix, e)
            for e in n.output:
                name_map[e] = _prefixed(prefix, e)

    if rename_inputs:
        for entry in g.input:
            name_map[entry.name] = _prefixed(prefix, entry.name)
    if rename_outputs:
        for entry in g.output:
            name_map[entry.name] = _prefixed(prefix, entry.name)

    if rename_nodes:
        for n in g.node:
            n.name = _prefixed(prefix, n.name)
            for attribute in n.attribute:
                # Recurse into single-graph attributes (e.g. If/Loop bodies),
                # sharing name_map so outer renames apply inside the subgraph.
                # NOTE(review): attributes holding a *list* of graphs
                # (attribute.graphs) are not visited here — confirm intended.
                if attribute.g:
                    add_prefix_graph(
                        attribute.g, prefix, inplace=True, name_map=name_map
                    )

    if rename_initializers:
        for init in g.initializer:
            name_map[init.name] = _prefixed(prefix, init.name)
        for sparse_init in g.sparse_initializer:
            # Sparse initializers carry two named tensors: values and indices.
            name_map[sparse_init.values.name] = _prefixed(
                prefix, sparse_init.values.name
            )
            name_map[sparse_init.indices.name] = _prefixed(
                prefix, sparse_init.indices.name
            )

    if rename_value_infos:
        for entry in g.value_info:
            name_map[entry.name] = _prefixed(prefix, entry.name)

    # --- Pass 2: apply the mapping to every element of the graph. ---
    for n in g.node:
        for i, output in enumerate(n.output):
            if n.output[i] in name_map:
                n.output[i] = name_map[output]
        for i, input_ in enumerate(n.input):
            if n.input[i] in name_map:
                n.input[i] = name_map[input_]

    for in_desc in g.input:
        if in_desc.name in name_map:
            in_desc.name = name_map[in_desc.name]
    for out_desc in g.output:
        if out_desc.name in name_map:
            out_desc.name = name_map[out_desc.name]

    for initializer in g.initializer:
        if initializer.name in name_map:
            initializer.name = name_map[initializer.name]
    for sparse_initializer in g.sparse_initializer:
        if sparse_initializer.values.name in name_map:
            sparse_initializer.values.name = name_map[sparse_initializer.values.name]
        if sparse_initializer.indices.name in name_map:
            sparse_initializer.indices.name = name_map[sparse_initializer.indices.name]

    for value_info in g.value_info:
        if value_info.name in name_map:
            value_info.name = name_map[value_info.name]

    return g


def add_prefix(
    model: ModelProto,
    prefix: str,
    rename_nodes: Optional[bool] = True,
    rename_edges: Optional[bool] = True,
    rename_inputs: Optional[bool] = True,
    rename_outputs: Optional[bool] = True,
    rename_initializers: Optional[bool] = True,
    rename_value_infos: Optional[bool] = True,
    rename_functions: Optional[bool] = True,
    inplace: Optional[bool] = False,
) -> ModelProto:
    """Adds a prefix to names of elements in a graph: nodes, edges, inputs, outputs,
    initializers, sparse initializer, value infos, and local functions.

    It can be used as a utility before merging graphs that have overlapping names.
    Empty names are not prefixed.

    Arguments:
        model (ModelProto): Model
        prefix (str): Prefix to be added to each name in the graph
        rename_nodes (bool): Whether to prefix node names
        rename_edges (bool): Whether to prefix node edge names
        rename_inputs (bool): Whether to prefix input names
        rename_outputs (bool): Whether to prefix output names
        rename_initializers (bool): Whether to prefix initializer and sparse initializer names
        rename_value_infos (bool): Whether to prefix value info names
        rename_functions (bool): Whether to prefix local function names
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        ModelProto
    """
    if type(model) is not ModelProto:
        raise ValueError("model argument is not an ONNX model")

    if not inplace:
        m = ModelProto()
        m.CopyFrom(model)
        model = m

    add_prefix_graph(
        model.graph,
        prefix,
        rename_nodes=rename_nodes,
        rename_edges=rename_edges,
        rename_inputs=rename_inputs,
        rename_outputs=rename_outputs,
        rename_initializers=rename_initializers,
        rename_value_infos=rename_value_infos,
        inplace=True,  # No need to create a copy, since it's a new model
    )

    if rename_functions:
        f_name_map = {}
        for f in model.functions:
            new_f_name = prefix + f.name
            f_name_map[f.name] = new_f_name
            f.name = new_f_name
        # Adjust references to local functions in other local function
        # definitions
        for f in model.functions:
            for n in f.node:
                if n.op_type in f_name_map:
                    n.op_type = f_name_map[n.op_type]
        # Adjust references to local functions in the graph
        for n in model.graph.node:
            if n.op_type in f_name_map:
                n.op_type = f_name_map[n.op_type]

    return model


def expand_out_dim_graph(
    graph: GraphProto,
    dim_idx: int,
    inplace: Optional[bool] = False,
) -> GraphProto:
    """Inserts an extra dimension with extent 1 to each output in the graph.

    Inserts an Unsqueeze node for each output. It can be used as a utility before merging graphs,
    for example when the second one expects a batch dimension.

    Arguments:
        graph (GraphProto): Graph
        dim_idx (int): Index of the dimension to be inserted.
                       A negative value means counting dimensions from the back.
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        GraphProto
    """
    if type(graph) is not GraphProto:
        raise ValueError("graph argument is not an ONNX graph")

    if not inplace:
        g = GraphProto()
        g.CopyFrom(graph)
    else:
        g = graph

    orig_out_names = [output.name for output in g.output]

    # Rename every reference to an original output so the original name can be
    # reused as the Unsqueeze result (keeping the graph's output names stable).
    for n in g.node:
        for i, out in enumerate(n.output):
            if out in orig_out_names:
                n.output[i] = out + f"_collapsed_dim_{dim_idx}"
        for i, inp in enumerate(n.input):
            if inp in orig_out_names:
                n.input[i] = inp + f"_collapsed_dim_{dim_idx}"

    # Single shared Constant holding the Unsqueeze axis.
    expand_dim_k = g.name + "_expand_out_dim_idx"
    g.node.append(
        helper.make_node(
            "Constant",
            inputs=[],
            outputs=[expand_dim_k],
            name=f"{expand_dim_k}-constant",
            value=helper.make_tensor(
                name=f"{expand_dim_k}-value",
                data_type=TensorProto.INT64,
                dims=[
                    1,
                ],
                vals=[
                    dim_idx,
                ],
            ),
        )
    )

    for _ in range(len(g.output)):
        o = g.output.pop(0)
        prev_output = o.name + f"_collapsed_dim_{dim_idx}"
        g.node.append(
            helper.make_node(
                "Unsqueeze",
                inputs=[prev_output, expand_dim_k],
                outputs=[o.name],
                name=f"unsqueeze-{o.name}",
            )
        )
        new_shape = [d.dim_value for d in o.type.tensor_type.shape.dim]
        # Unsqueeze interprets a negative axis relative to the *output* rank
        # (len(new_shape) + 1), so dim_idx == -1 must append the new dim.
        # A raw list.insert with a negative index would place it one position
        # too early (insert(-1, ...) inserts before the last element), so
        # normalize negative indices before inserting.
        insert_idx = dim_idx if dim_idx >= 0 else dim_idx + len(new_shape) + 1
        new_shape.insert(insert_idx, 1)
        g.output.append(
            helper.make_tensor_value_info(
                o.name, o.type.tensor_type.elem_type, new_shape
            )
        )
    return g


def expand_out_dim(
    model: ModelProto,
    dim_idx: int,
    inplace: Optional[bool] = False,
) -> ModelProto:
    """Inserts an extra dimension with extent 1 to each output in the graph.

    Inserts an Unsqueeze node for each output. It can be used as a utility before merging graphs,
    for example when the second one expects a batch dimension.

    Arguments:
        model (ModelProto): Model
        dim_idx (int): Index of the dimension to be inserted.
                       A negative value means counting dimensions from the back.
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        ModelProto
    """
    if type(model) is not ModelProto:
        raise ValueError("model argument is not an ONNX model")

    # Work on a copy unless the caller opted into in-place mutation.
    target = model
    if not inplace:
        target = ModelProto()
        target.CopyFrom(model)

    # The graph is mutated in place; no second copy is needed.
    expand_out_dim_graph(target.graph, dim_idx, inplace=True)
    return target