File size: 69,361 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
from collections import defaultdict
from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name
import torch.utils._pytree as pytree
from . import _pytree as fx_pytree
from ._compatibility import compatibility

import contextlib
from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type
from dataclasses import dataclass
from contextlib import contextmanager
import copy
import enum
import torch
import keyword
import re
import builtins
import math
import warnings
import inspect

__all__ = ["PythonCode", "CodeGen", "Graph"]

# Imported only for static type checking; these names are not bound at runtime.
if TYPE_CHECKING:
    from .graph_module import GraphModule  # noqa: F401
    from ._symbolic_trace import Tracer   # noqa: F401


# Mapping of builtins to their `typing` equivalent.
# Used when rendering generic type annotations: the `__origin__` of a generic
# (e.g. `list` for `List[int]`) is looked up here so the emitted source spells
# it through `typing` (`List[...]`); origins not listed are used as-is.
_origin_type_map = {
    list: List,
    dict: Dict,
    set: Set,
    frozenset: FrozenSet,
    tuple: Tuple,
}


# Signature for functions that transform the body (`list[str]`) of the
# generated code.
TransformCodeFunc = Callable[[List[str]], List[str]]


class _CustomBuiltin(NamedTuple):
    """Additional objs that we add to every graph's globals.



    The repr() for some standard library objects is not valid Python code without

    an import. For common objects of this sort, we bundle them in the globals of

    every FX graph.

    """
    # How to import this object from the standard library.
    import_str: str
    # The actual object, produced from that import string.
    obj: Any

_custom_builtins: Dict[str, _CustomBuiltin] = {}


def _register_custom_builtin(name: str, import_str: str, obj: Any):
    _custom_builtins[name] = _CustomBuiltin(import_str, obj)


_register_custom_builtin('inf', 'from math import inf', math.inf)
_register_custom_builtin('nan', 'from math import nan', math.nan)
_register_custom_builtin('NoneType', 'NoneType = type(None)', type(None))
_register_custom_builtin('torch', 'import torch', torch)
_register_custom_builtin('device', 'from torch import device', torch.device)
_register_custom_builtin('fx_pytree', 'import torch.fx._pytree as fx_pytree', fx_pytree)
_register_custom_builtin('pytree', 'import torch.utils._pytree as pytree', pytree)


def _is_magic(x: str) -> bool:
    return x.startswith('__') and x.endswith('__')


def _snake_case(s: str) -> str:
    """

    Transforms the given string ``s`` to a Python-style variable name



    Examples:

        ``mod.snake_case`` -> ``mod.snake_case``

        ``mod.pascalCase``-> ``mod.pascal_case``

        ``mod.ALL_CAPS`` -> ``mod.all_caps``

    """
    chars = []
    prev_lower = False
    for c in s:
        if prev_lower and c.isupper():
            chars.append('_')
        chars.append(c.lower())
        prev_lower = c.islower()
    return ''.join(chars)


def _is_from_torch(obj: Any) -> bool:
    module_name = getattr(obj, '__module__', None)
    if module_name is not None:
        base_module = module_name.partition('.')[0]
        return (
            base_module == 'torch' and
            not module_name.startswith("torch._dynamo.") and
            not module_name.startswith("torch._inductor.")
        )

    name = getattr(obj, '__name__', None)
    # exclude torch because torch.torch.torch.torch works. idk mang
    if name is not None and name != 'torch':
        for guess in [torch, torch.nn.functional]:
            if getattr(guess, name, None) is obj:
                return True

    return False


class _Namespace:
    """A context for associating names uniquely with objects.

    The following invariants are enforced:

    - Each object gets a single name.
    - Each name is unique within a given namespace.
    - Names generated do not shadow builtins, unless the object is indeed that builtin.
    """
    def __init__(self):
        # Object -> the unique name assigned to it (objects must be hashable).
        self._obj_to_name: Dict[Any, str] = {}
        # Names created with ``obj=None`` that are not yet bound to an object
        # (see ``associate_name_with_obj``).
        self._unassociated_names: Set[str] = set()
        # Every name handed out by this namespace; used to guarantee uniqueness.
        self._used_names: Set[str] = set()
        # Base name -> numeric suffix recorded by the latest ``create_name``
        # call for that base; the next suffix search starts from this value.
        self._base_count: Dict[str, int] = defaultdict(int)

        # Runs of characters outside [0-9a-zA-Z_]; each run is replaced by '_'.
        self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+')
        # Splits a trailing "_<digits>" suffix, e.g. "add_3" -> ("add", "3").
        self._name_suffix_regex = re.compile(r"(.*)_(\d+)$")

    def create_name(self, candidate: str, obj: Optional[Any]) -> str:
        """Create a unique name.

        Arguments:
            candidate: used as the basis for the unique name, relevant to the user.
            obj: If not None, an object that will be associated with the unique name.
                It is used as a dict key, so it must be hashable.

        Returns:
            ``obj``'s existing name if it already has one. Otherwise, a name
            derived from ``candidate`` that is sanitized, unused in this
            namespace, and not illegal (see ``_is_illegal_name``). A ``_<n>``
            suffix is appended or bumped as needed to make it unique.
        """
        # An object that already has a name always keeps it.
        if obj is not None and obj in self._obj_to_name:
            return self._obj_to_name[obj]

        # replace each run of characters that are illegal in a Python identifier with '_'
        candidate = self._illegal_char_regex.sub('_', candidate)

        if not candidate:
            candidate = '_unnamed'

        # Identifiers cannot start with a digit.
        if candidate[0].isdigit():
            candidate = f'_{candidate}'

        # Split "foo_12" into base "foo" and suffix 12, so uniquification bumps
        # the number instead of producing "foo_12_1".
        match = self._name_suffix_regex.match(candidate)
        if match is None:
            base = candidate
            num = None
        else:
            base, num_str = match.group(1, 2)
            num = int(num_str)

        # Re-render the suffix canonically (e.g. "x_007" -> "x_7").
        candidate = base if num is None else f'{base}_{num}'
        # No suffix (or a suffix of 0): continue counting from the number last
        # recorded for this base.
        if not num:
            num = self._base_count[base]

        # Bump the suffix until the name is both unused and legal.
        while candidate in self._used_names or self._is_illegal_name(candidate, obj):
            num += 1
            candidate = f'{base}_{num}'

        self._used_names.add(candidate)
        self._base_count[base] = num
        if obj is None:
            self._unassociated_names.add(candidate)
        else:
            self._obj_to_name[obj] = candidate
        return candidate

    def associate_name_with_obj(self, name: str, obj: Any):
        """Associate a unique name with an object.

        Neither `name` nor `obj` should be associated already: ``name`` must
        have come from ``create_name(..., obj=None)`` and ``obj`` must not be
        named yet (both enforced with asserts).
        """
        assert obj not in self._obj_to_name
        assert name in self._unassociated_names
        self._obj_to_name[obj] = name
        self._unassociated_names.remove(name)

    def _is_illegal_name(self, name: str, obj: Any) -> bool:
        # 1. keywords are never allowed as names.
        if name in keyword.kwlist:
            return True

        # 2. Can't shadow a builtin name, unless you *are* that builtin.
        if name in builtins.__dict__:
            return obj is not builtins.__dict__[name]

        # 3. Can't shadow our custom builtins either
        if name in _custom_builtins:
            return obj is not _custom_builtins[name].obj

        return False

    def _rename_object(self, obj: Any, name: str):
        # Re-point an already-named ``obj`` at ``name``. The previous name is
        # NOT removed from ``_used_names``, and ``name`` is not checked for
        # uniqueness or legality here.
        assert obj in self._obj_to_name
        self._obj_to_name[obj] = name
        self._used_names.add(name)

# Short dtype spellings used by verbose code generation, where tensor
# annotations are rendered as e.g. "f32[2, 3]" (abbreviation + shape).
# The lookup there is a plain index, so a dtype missing from this table raises
# KeyError when emitting verbose annotations.
dtype_abbrs = {
    torch.bfloat16: 'bf16',
    torch.float64: 'f64',
    torch.float32: 'f32',
    torch.float16: 'f16',
    torch.float8_e4m3fn: 'f8e4m3fn',
    torch.float8_e5m2: 'f8e5m2',
    torch.float8_e4m3fnuz: 'f8e4m3fnuz',
    torch.float8_e5m2fnuz: 'f8e5m2fnuz',
    torch.complex32: 'c32',
    torch.complex64: 'c64',
    torch.complex128: 'c128',
    torch.int8: 'i8',
    torch.int16: 'i16',
    torch.int32: 'i32',
    torch.int64: 'i64',
    torch.bool: 'b8',
    torch.uint8: 'u8',
    torch.uint32: 'u32',
    torch.uint64: 'u64',
}

@compatibility(is_backward_compatible=True)
@dataclass
class PythonCode:
    """
    Represents all the information necessary to exec or save a graph as Python code.
    """
    # Python source code for the forward function definition.
    src: str
    # Values in global scope during execution of `src`.
    globals: Dict[str, Any]
    # Optional mapping from the forward function's line number to
    # node index.
    _lineno_map: Optional[Dict[int, Optional[int]]]


def _format_target(base: str, target: str) -> str:
    elems = target.split('.')
    r = base
    for e in elems:
        if not e.isidentifier():
            r = f'getattr({r}, "{e}")'
        else:
            r = f'{r}.{e}'
    return r

class _InsertPoint:
    def __init__(self, graph, new_insert):
        self.graph = graph
        self.orig_insert, graph._insert = graph._insert, new_insert

    def __enter__(self):
        pass

    def __exit__(self, type, value, tb):
        self.graph._insert = self.orig_insert

class _node_list:
    def __init__(self, graph: 'Graph', direction: str = '_next'):
        assert direction in ['_next', '_prev']
        self.graph = graph
        self.direction = direction

    def __len__(self):
        return self.graph._len

    def __iter__(self):
        root = self.graph._root
        if self.direction == "_next":
            cur = root._next
            while cur is not root:
                if not cur._erased:
                    yield cur
                cur = cur._next
        else:
            assert self.direction == "_prev"
            cur = root._prev
            while cur is not root:
                if not cur._erased:
                    yield cur
                cur = cur._prev

    def __reversed__(self):
        return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')

class _PyTreeInfo(NamedTuple):
    """
    Contains extra info stored when we're using Pytrees.
    """
    # Argument names of the original function (presumably the un-flattened
    # signature; TODO confirm against the code that builds this).
    orig_args: List[str]
    # Pytree spec presumably describing the input structure — verify at caller.
    in_spec: pytree.TreeSpec
    # Pytree spec presumably describing the output structure; may be None.
    out_spec: Optional[pytree.TreeSpec]

@dataclass(frozen=True)
class _ParsedStackTrace:
    """

    Represents the top-most frame of a parsed stack trace

    """
    file: str
    lineno: str
    name: str
    code: str

# get File:lineno code from stack_trace
def _parse_stack_trace(stack_trace: str):
    if stack_trace is None:
        return None
    pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$")
    lines = stack_trace.strip().split('\n')
    # stacktrace should have innermost frame last, so we
    # iterate backwards to find the first line that starts
    # with 'File '
    summary_str = ""
    for idx in range(len(lines) - 2, -1, -1):
        line = lines[idx].strip()
        matches = pattern.match(line)
        if matches:
            file = matches.group(1)
            lineno = matches.group(2)
            name = matches.group(3)
            # next line should be the code
            code = lines[idx + 1].strip()
            return _ParsedStackTrace(file, lineno, name, code)
    return None

@compatibility(is_backward_compatible=False)
class CodeGen:
    def __init__(self):
        # Name given to the generated function definition.
        self._func_name: str = "forward"
        # Optional callable that transforms the generated body lines
        # (see TransformCodeFunc); unset by default.
        self._body_transformer: Optional[TransformCodeFunc] = None

    def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str:
        """
        Given the free variables and a return annotation, generates the beginning of the FX function.

        ``'self'`` is prepended to ``free_vars`` unless it is already the first
        entry. So by default (``self._func_name == 'forward'``),
        ``gen_fn_def(['a', 'b'], '') == 'def forward(self, a, b):'``.

        Note: ``free_vars`` is modified in place when ``'self'`` is inserted.

        Args:
            free_vars: Parameter strings for the signature (names, optionally
                carrying annotations and defaults).
            maybe_return_annotation: Text placed between the parameter list
                and the colon, e.g. ``' -> int'`` or ``''``.

        Returns:
            The ``def`` line, without a trailing newline.
        """
        # Ensure 'self' leads the parameter list.
        if not free_vars or free_vars[0] != 'self':
            free_vars.insert(0, 'self')
        return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:"

    def generate_output(self, output_args: Argument) -> str:
        """
        Build the (unindented) return statement of the generated FX function.

        The default spelling is ``return <repr(output_args)>``.
        """
        return f'return {output_args!r}'

    def process_inputs(self, *args: Any) -> Any:
        """
        Map the inputs of the generated function to the inputs of the graph.

        A non-default codegen may emit a function whose signature differs from
        the graph's, and this hook bridges the two. For a directly runnable
        graph the round trip below should hold:
        ``f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)``

        The base implementation returns the positional arguments unchanged,
        as the tuple they arrived in.
        """
        return args

    def process_outputs(self, outputs: Any) -> Any:
        """
        Map the graph's outputs to what the generated function returns.

        This is the counterpart of ``process_inputs``. The base implementation
        hands ``outputs`` back untouched.
        """
        return outputs

    def additional_globals(self) -> List[Tuple[str, Any]]:
        """
        Extra global values the generated code needs, as ``(identifier, value)`` tuples.

        For example, return ``[('List', typing.List)]`` if you need ``List``
        in the global context. The base implementation adds nothing.
        """
        return []

    def _gen_python_code(

        self, nodes, root_module: str, namespace: _Namespace, *, verbose: bool = False,

    ) -> PythonCode:
        free_vars: List[str] = []
        body: List[str] = []
        globals_: Dict[str, Any] = {}
        wrapped_fns: Dict[str, None] = {}

        # Wrap string in list to pass by reference
        maybe_return_annotation : List[str] = ['']

        def add_global(name_hint: str, obj: Any):
            """Add an obj to be tracked as a global.



            We call this for names that reference objects external to the

            Graph, like functions or types.



            Returns: the global name that should be used to reference 'obj' in generated source.

            """
            if _is_from_torch(obj) and obj != torch.device:  # to support registering torch.device
                # HACK: workaround for how torch custom ops are registered. We
                # can't import them like normal modules so they must retain their
                # fully qualified name.
                return _get_qualified_name(obj)

            # normalize the name hint to get a proper identifier
            global_name = namespace.create_name(name_hint, obj)

            if global_name in globals_:
                assert globals_[global_name] is obj
                return global_name
            globals_[global_name] = obj
            return global_name

        # Pre-fill the globals table with registered builtins.
        for name, (_, obj) in _custom_builtins.items():
            add_global(name, obj)

        def type_repr(o : Any):
            if o == ():
                # Empty tuple is used for empty tuple type annotation Tuple[()]
                return '()'

            typename = _type_repr(o)

            if hasattr(o, '__origin__'):
                # This is a generic type, e.g. typing.List[torch.Tensor]
                origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
                origin_typename = add_global(_type_repr(origin_type), origin_type)

                if hasattr(o, '__args__'):
                    # Assign global names for each of the inner type variables.
                    args = [type_repr(arg) for arg in o.__args__]

                    if len(args) == 0:
                        # Bare type, such as `typing.Tuple` with no subscript
                        # This code-path used in Python < 3.9
                        return origin_typename

                    return f'{origin_typename}[{",".join(args)}]'
                else:
                    # Bare type, such as `typing.Tuple` with no subscript
                    # This code-path used in Python 3.9+
                    return origin_typename

            # Common case: this is a regular module name like 'foo.bar.baz'
            return add_global(typename, o)

        def _get_repr(arg: Any) -> str:
            # Handle NamedTuples (if it has `_fields`) via add_global.
            if isinstance(arg, tuple) and hasattr(arg, '_fields'):
                qualified_name = _get_qualified_name(type(arg))
                global_name = add_global(qualified_name, type(arg))
                return f"{global_name}{repr(tuple(arg))}"
            elif isinstance(arg, torch._ops.OpOverload):
                qualified_name = _get_qualified_name(arg)
                global_name = add_global(qualified_name, arg)
                return f"{global_name}"
            elif isinstance(arg, enum.Enum):
                cls = arg.__class__
                clsname = add_global(cls.__name__, cls)
                return f"{clsname}.{arg.name}"
            return repr(arg)

        def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
            args_s = ', '.join(_get_repr(a) for a in args)
            kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
            if args_s and kwargs_s:
                return f'{args_s}, {kwargs_s}'
            return args_s or kwargs_s

        # Run through reverse nodes and record the first instance of a use
        # of a given node. This represents the *last* use of the node in the
        # execution order of the program, which we will use to free unused
        # values
        node_to_last_use : Dict[Node, Node] = {}
        user_to_last_uses : Dict[Node, List[Node]] = {}

        def register_last_uses(n : Node, user : Node):
            if n not in node_to_last_use:
                node_to_last_use[n] = user
                user_to_last_uses.setdefault(user, []).append(n)

        for node in reversed(nodes):
            map_arg(node.args, lambda n: register_last_uses(n, node))
            map_arg(node.kwargs, lambda n: register_last_uses(n, node))

        def delete_unused_values(user : Node):
            """

            Delete values after their last use. This ensures that values that are

            not used in the remainder of the code are freed and the memory usage

            of the code is optimal.

            """
            if user.op == 'placeholder':
                return
            if user.op == 'output':
                body.append('\n')
                return
            nodes_to_delete = user_to_last_uses.get(user, [])
            if len(nodes_to_delete):
                to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                body.append(f';  {to_delete_str}\n')
            else:
                body.append('\n')

        prev_stacktrace = None

        def append_stacktrace_summary(node : Node):
            """

            Append a summary of the stacktrace to the generated code. This is

            useful for debugging.

            """
            nonlocal prev_stacktrace

            if node.op not in {'placeholder', 'output'}:
                if node.stack_trace:
                    if node.stack_trace != prev_stacktrace:
                        prev_stacktrace = node.stack_trace
                        summary_str = ""

                        parsed_stack_trace = _parse_stack_trace(node.stack_trace)

                        if parsed_stack_trace is not None:
                            lineno = parsed_stack_trace.lineno
                            code = parsed_stack_trace.code
                            name = parsed_stack_trace.name
                            summary_str = f'File: {parsed_stack_trace.file}:{lineno} in {name}, code: {code}'

                        body.append(f'\n# {summary_str}\n')
                elif prev_stacktrace != "":
                    prev_stacktrace = ""
                    body.append('\n# No stacktrace found for following nodes\n')

        def stringify_shape(shape : torch.Size) -> str:
            return f"[{', '.join(str(x) for x in shape)}]"

        def emit_node(node : Node):
            maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'

            if verbose:
                # override annotation with more detailed information
                from torch._subclasses.fake_tensor import FakeTensor
                from torch.fx.experimental.proxy_tensor import py_sym_types
                from torch.fx.passes.shape_prop import TensorMetadata

                meta_val = node.meta.get('val', node.meta.get('tensor_meta', None))

                # use string as annotation, to make it valid python code
                if isinstance(meta_val, FakeTensor):
                    maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"'
                elif isinstance(meta_val, py_sym_types):
                    maybe_type_annotation = f': "Sym({meta_val})"'
                elif isinstance(meta_val, TensorMetadata):
                    maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"'

            if node.op == 'placeholder':
                assert isinstance(node.target, str)
                maybe_default_arg = '' if not node.args else f' = {_get_repr(node.args[0])}'
                free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
                raw_name = node.target.replace('*', '')
                if raw_name != repr(node):
                    body.append(f'{repr(node)} = {raw_name}\n')
                return
            elif node.op == 'call_method':
                assert isinstance(node.target, str)
                body.append(
                    f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}'
                    f'({_format_args(node.args[1:], node.kwargs)})')
                return
            elif node.op == 'call_function':
                assert callable(node.target)
                # pretty print operators
                if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in magic_methods:
                    assert isinstance(node.args, tuple)
                    body.append(f'{repr(node)}{maybe_type_annotation} = '
                                f'{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}')
                    return

                # pretty print inplace operators; required for jit.script to work properly
                # not currently supported in normal FX graphs, but generated by torchdynamo
                if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in inplace_methods:
                    body.append(f'{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))};  '
                                f'{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}')
                    return

                qualified_name = _get_qualified_name(node.target)
                global_name = add_global(qualified_name, node.target)
                # special case for getattr: node.args could be 2-argument or 3-argument
                # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
                if global_name == 'getattr' and \
                   isinstance(node.args, tuple) and \
                   isinstance(node.args[1], str) and \
                   node.args[1].isidentifier() and \
                   len(node.args) == 2:
                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}')
                    return
                body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
                if node.meta.get('is_wrapped', False):
                    wrapped_fns.setdefault(global_name)
                return
            elif node.op == 'call_module':
                assert isinstance(node.target, str)
                body.append(f'{repr(node)}{maybe_type_annotation} = '
                            f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
                return
            elif node.op == 'get_attr':
                assert isinstance(node.target, str)
                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
                return
            elif node.op == 'output':
                if node.type is not None:
                    maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
                body.append(self.generate_output(node.args[0]))
                return
            raise NotImplementedError(f'node: {node.op} {node.target}')

        for i, node in enumerate(nodes):
            # NOTE: emit_node does not emit a string with newline. It depends
            # on delete_unused_values to append one
            if verbose:
                append_stacktrace_summary(node)
            # emit a counter comment to keep track of
            # node index, which will be deleted later
            # after going through _body_transformer
            body.append(f"# COUNTER: {i}\n")
            emit_node(node)
            delete_unused_values(node)

        if len(body) == 0:
            # If the Graph has no non-placeholder nodes, no lines for the body
            # have been emitted. To continue to have valid Python code, emit a
            # single pass statement
            body.append('pass\n')



        if len(wrapped_fns) > 0:
            wrap_name = add_global('wrap', torch.fx.wrap)
            wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
        else:
            wrap_stmts = ''

        if self._body_transformer:
            body = self._body_transformer(body)

        for name, value in self.additional_globals():
            add_global(name, value)

        prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])

        # remove counter and generate lineno to node index mapping
        lineno_map: Dict[int, Optional[int]] = {}
        prologue_len = prologue.count('\n') + 1
        new_lines: List[str] = []
        cur_idx = None
        for line in ''.join(body).split('\n'):
            counter = re.search(r"# COUNTER: (\d+)", line)
            if counter and counter.group(1) is not None:
                cur_idx = int(counter.group(1))
            else:
                lineno_map[len(new_lines) + prologue_len] = cur_idx
                new_lines.append(line)

        code = "\n".join(new_lines).lstrip('\n')
        code = '\n'.join('    ' + line for line in code.split('\n'))

        fn_code = f"""

{wrap_stmts}



{prologue}

{code}"""
        return PythonCode(fn_code, globals_, _lineno_map=lineno_map)


# Ideally, we'd like to refactor all of the pytree logic into this codegen
# class. Unfortunately, there are 3 areas we currently need extra logic in FX.
# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`.
# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec.
#    Since we can't access .graph within the FX forward, we need to copy the attribute to the module.
# 3. We currently can't register the pytree imports with `add_global` - not sure why.
class _PyTreeCodeGen(CodeGen):
    """
    ``CodeGen`` variant that emits a function taking the original argument
    names (``pytree_info.orig_args``) and flattening them into the graph's
    placeholders with ``fx_pytree.tree_flatten_spec``; when an ``out_spec`` is
    available the outputs are re-nested with ``pytree.tree_unflatten``.

    When ``pytree_info`` is ``None``, ``gen_fn_def``/``generate_output``/
    ``process_outputs`` fall back to plain ``CodeGen`` behavior
    (``process_inputs`` always flattens).
    """
    def __init__(self, pytree_info: _PyTreeInfo):
        super().__init__()
        # Carries orig_args / in_spec / out_spec consumed by the hooks below.
        self.pytree_info: _PyTreeInfo = pytree_info

    def process_inputs(self, *inputs: Any) -> Any:
        """Return the pytree leaves of ``inputs`` (via ``pytree.arg_tree_leaves``)."""
        flat_args = pytree.arg_tree_leaves(*inputs)
        return flat_args

    def process_outputs(self, out: Any) -> Any:
        """
        Re-nest the flat graph output ``out``.

        Returns ``out`` unchanged when there is no ``pytree_info`` or no
        ``out_spec``. Otherwise ``out`` is wrapped in a list if it is not already
        a list/tuple, then unflattened according to ``pytree_info.out_spec``.
        """
        if self.pytree_info is None or self.pytree_info.out_spec is None:
            return out
        if not isinstance(out, (list, tuple)):
            out = [out]
        assert self.pytree_info.out_spec is not None
        return pytree.tree_unflatten(out, self.pytree_info.out_spec)

    def gen_fn_def(self, free_vars, maybe_return_annotation):
        """
        Build the ``def`` line plus a prologue that flattens the original
        arguments into ``free_vars`` (the graph's placeholder names).

        Note: ``free_vars`` is modified in place when ``orig_args`` starts with
        ``'self'``.
        """
        # Given a user function/model:
        #   myargs = [myargs0, myargs1]
        #   mykwargs = {'mykwargs0': ..., 'mykwargs1': ...}
        #   def forward(self, mypos, *myargs, mykey=None, **mykwargs):
        #
        # The generated code flattens all keywords into positional arguments for `forward()`
        #   e.g. forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1):
        #
        # Within `forward`, `tree_flatten_spec` still parses args and kwargs separately
        #   e.g. tree_flatten_spec(([mypos, myargs0, myargs1],
        #                           {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}),
        #                          self._in_spec)
        #
        # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec
        #   e.g. tree_flatten_spec([mypos, myargs0, myargs1], self._in_spec)
        if self.pytree_info is None:
            return super().gen_fn_def(free_vars, maybe_return_annotation)

        fn_args = self.pytree_info.orig_args
        has_orig_self = (fn_args[0] == 'self') if len(fn_args) > 0 else False
        if has_orig_self:
            # Also list 'self' on the left-hand side of the flattening
            # assignment emitted below (mutates the caller's list).
            free_vars.insert(0, 'self')
        # The def line uses the original argument names. A copy is passed,
        # presumably because the base implementation may modify it in place.
        fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation)

        if len(free_vars) > 0:  # pytree has placeholders in it
            # when kwargs is present, in_spec is tuple(args, kwargs)
            has_args_kwargs_tuple = self.pytree_info.in_spec.type == tuple and \
                self.pytree_info.in_spec.num_children == 2 and \
                self.pytree_info.in_spec.children_specs[0].type == tuple and \
                self.pytree_info.in_spec.children_specs[1].type == dict
            fn_kwargs = '{}'
            fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
            if has_args_kwargs_tuple:
                # The first `count_args` original names are positional; the rest
                # are paired with the kwargs keys stored in the dict spec's context.
                count_args = self.pytree_info.in_spec.children_specs[0].num_children
                fn_args = self.pytree_info.orig_args[:count_args]
                fn_kwargs = '{' + ', '.join(f"'{k}':{v}" for k, v in zip(
                                  self.pytree_info.in_spec.children_specs[1].context,
                                  self.pytree_info.orig_args[count_args:])) + '}'
                fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"

            # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid.
            # we need to split it to two lines:
            # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
            # one for code: `var1, var2, = function_call()`
            without_annotation = [x.split(":")[0] for x in free_vars]
            has_annotation = [x + "; " for x in free_vars if ":" in x]
            if len(has_annotation) > 0:
                fn_definition += "\n    " + "".join(has_annotation) + "\n"
            fn_definition += f"""

    {', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
        return fn_definition

    def generate_output(self, output_args):
        """Emit the return statement, re-nesting via ``self._out_spec`` when an out_spec exists."""
        if self.pytree_info and self.pytree_info.out_spec:
            return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)'
        else:
            return super().generate_output(output_args)

@compatibility(is_backward_compatible=True)
class Graph:
    """

    ``Graph`` is the main data structure used in the FX Intermediate Representation.

    It consists of a series of ``Node`` s, each representing callsites (or other

    syntactic constructs). The list of ``Node`` s, taken together, constitutes a

    valid Python function.



    For example, the following code



    .. code-block:: python



        import torch

        import torch.fx



        class MyModule(torch.nn.Module):

            def __init__(self):

                super().__init__()

                self.param = torch.nn.Parameter(torch.rand(3, 4))

                self.linear = torch.nn.Linear(4, 5)



            def forward(self, x):

                return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)



        m = MyModule()

        gm = torch.fx.symbolic_trace(m)



    Will produce the following Graph::



        print(gm.graph)



    .. code-block:: text



        graph(x):

            %linear_weight : [num_users=1] = self.linear.weight

            %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})

            %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})

            %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})

            %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})

            %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})

            return topk_1



    For the semantics of operations represented in the ``Graph``, please see :class:`Node`.

    """

    @compatibility(is_backward_compatible=True)
    def __init__(self,
                 owning_module: Optional["GraphModule"] = None,
                 tracer_cls: Optional[Type["Tracer"]] = None,
                 tracer_extras: Optional[Dict[str, Any]] = None):
        """
        Construct an empty Graph.

        Args:
            owning_module: the ``GraphModule`` this graph belongs to, if any.
                ``get_attr`` and ``call_module`` consult it to check references.
            tracer_cls: the tracer class associated with this graph, if any.
            tracer_extras: extra tracer-related data stored alongside ``tracer_cls``.
        """
        # Sentinel node anchoring the doubly-linked node list; an empty graph
        # consists of just this node.
        self._root : Node = Node(self, '', 'root', '', (), {})
        # By default new nodes are inserted before the sentinel, i.e. appended.
        self._insert = self._root.prepend
        self._len = 0

        # Naming state.
        self._used_names : Dict[str, int] = {}  # base name -> number
        self._graph_namespace = _Namespace()

        # Caller-supplied context.
        self._owning_module = owning_module
        self._tracer_cls = tracer_cls
        self._tracer_extras = tracer_extras

        # Code generation settings.
        self._codegen = CodeGen()
        self._co_fields : Dict[str, Any] = {}

    @property
    def owning_module(self):
        """The ``GraphModule`` this graph belongs to, or ``None`` if it has none."""
        owner = self._owning_module
        return owner

    @owning_module.setter
    def owning_module(self, mod: Optional["GraphModule"]):
        # Plain reassignment; ``mod`` is not validated.
        self._owning_module = mod

    @property
    def nodes(self) -> _node_list:
        """
        The ``Node`` s that make up this Graph, in order.

        The returned view is backed by a doubly-linked list, so mutating the
        graph while iterating (e.g. deleting or adding a Node) is safe. Call
        ``reversed`` on the view to iterate from the end of the graph.

        Returns:
            A ``_node_list`` view over this graph's nodes.
        """
        node_view = _node_list(self)
        return node_view

    @compatibility(is_backward_compatible=True)
    def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]':
        """
        Copy all nodes of ``g`` into ``self``.

        Args:
            g (Graph): the graph whose nodes are copied.
            val_map (Dict[Node, Node]): populated with a mapping from nodes of
                ``g`` to their counterparts in ``self``. Nodes already present as
                keys are not copied; their mapped value is used wherever they are
                referenced.
            return_output_node (bool): when True, return ``(value, output_node)``
                where ``output_node`` is the ``output`` node of ``g``.

        Returns:
            The value in ``self`` equivalent to the output value of ``g`` (or the
            tuple described above) if ``g`` had an ``output`` node that was not
            already in ``val_map``; ``None`` otherwise.
        """
        def remap(src_node):
            return val_map[src_node]

        for src in g.nodes:
            if src in val_map:
                # Caller-provided override: reuse the existing value.
                continue
            if src.op != 'output':
                val_map[src] = self.node_copy(src, remap)
                continue
            # The output node itself is not copied; its argument is remapped
            # and handed back to the caller.
            result = map_arg(src.args[0], remap)
            if return_output_node:
                return (result, src)
            return result
        return None

    def __deepcopy__(self, memo=None) -> 'Graph':
        """
        Explicitly implement __deepcopy__ to prevent excessive recursion depth
        from the default implementation. This uses graph_copy to copy the nodes
        in an iterative way, rather than recursively. It also populates the
        memoization table to prevent unnecessary copies (e.g. references to
        nodes or other parts of the Graph from a custom GraphModule implementation).

        Args:
            memo: the ``copy.deepcopy`` memo dict. It is updated in place with the
                old-node -> new-node mapping; a fresh dict is used only when
                ``memo`` is ``None``.

        Returns:
            A new ``Graph`` with copies of all nodes, a deep copy of the codegen,
            and (if ``self`` has one) an equivalent ``output`` node.
        """
        # Test for None rather than falsiness: copy.deepcopy passes an *empty*
        # dict on a top-level call, and replacing it with a new dict would drop
        # the node mapping recorded below instead of returning it to the caller.
        if memo is None:
            memo = {}
        g = Graph(tracer_cls=self._tracer_cls)
        output_vals = g.graph_copy(self, val_map=memo, return_output_node=True)
        # Share the memo so objects already copied in this deepcopy are reused.
        g._codegen = copy.deepcopy(self._codegen, memo)
        if output_vals is None:
            # ``self`` has no output node (e.g. it is still being built); the
            # copy faithfully has none either, instead of failing an assertion.
            return g
        assert isinstance(output_vals, tuple)
        output_val, old_output_node = output_vals
        new_output_node = g.output(output_val, type_expr=getattr(old_output_node, 'type', None))
        new_output_node.meta = copy.copy(old_output_node.meta)
        return g

    @compatibility(is_backward_compatible=True)
    def create_node(self, op: str, target: 'Target',
                    args: Optional[Tuple['Argument', ...]] = None,
                    kwargs: Optional[Dict[str, 'Argument']] = None,
                    name: Optional[str] = None,
                    type_expr: Optional[Any] = None) -> Node:
        """
        Create a ``Node`` and add it to the ``Graph`` at the current insert-point.

        The insert-point can be changed with :meth:`Graph.inserting_before` and
        :meth:`Graph.inserting_after`.

        Args:
            op (str): the opcode for this Node. One of 'call_function', 'call_method',
                'get_attr', 'call_module', 'placeholder', or 'output'. See :class:`Node`
                for the semantics of these opcodes.
            target (Target): what the node calls or refers to: a callable for
                'call_function', otherwise a string (method name, qualified
                attribute/module name, placeholder name, or 'output').
            args (Optional[Tuple[Argument, ...]]): positional arguments of this node;
                ``None`` is treated as ``()``.
            kwargs (Optional[Dict[str, Argument]]): keyword arguments of this node;
                ``None`` is treated as ``{}``.
            name (Optional[str]): preferred name for the ``Node``; it is made unique
                within the graph and influences the name of the value assigned to
                in generated Python code. Defaults to a name derived from ``target``.
            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have.

        Returns:
            The newly-created and inserted node.
        """
        assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output')
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}
        assert isinstance(args, tuple), "args must be a tuple"
        assert isinstance(kwargs, dict), "kwargs must be a dict"

        # Pick a base name, then make it unique within this graph's namespace.
        base_name = self._target_to_str(target) if name is None else name
        unique_name = self._graph_namespace.create_name(base_name, None)
        new_node = Node(self, unique_name, op, target, args, kwargs, type_expr)
        self._graph_namespace.associate_name_with_obj(unique_name, new_node)

        # Splice the node in at the current insertion point.
        self._insert(new_node)
        self._len += 1
        return new_node

    @compatibility(is_backward_compatible=False)
    def process_inputs(self, *args):
        """
        Prepare ``args`` so that they can be passed to the FX graph.

        The work is delegated to this graph's ``CodeGen`` object.
        """
        codegen = self._codegen
        return codegen.process_inputs(*args)

    @compatibility(is_backward_compatible=False)
    def process_outputs(self, out):
        """
        Post-process ``out``, the value produced by running the FX graph.

        The work is delegated to this graph's ``CodeGen`` object.
        """
        codegen = self._codegen
        return codegen.process_outputs(out)


    @compatibility(is_backward_compatible=True)
    def erase_node(self, to_erase : Node) -> None:
        """
        Erases a ``Node`` from the ``Graph``. Throws an exception if
        there are still users of that node in the ``Graph``.

        Args:
            to_erase (Node): The ``Node`` to erase from the ``Graph``.

        Raises:
            RuntimeError: if ``to_erase`` still has users, or if it belongs to a
                different graph.

        Erasing a node that is already marked as erased only emits a warning.
        """
        # A node that others still consume cannot be removed.
        if len(to_erase.users) > 0:
            raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} '
                               f'users in the graph: {to_erase.users}!')
        if to_erase.graph != self:
            raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!")
        # Returning early on a double erase avoids decrementing ``_len`` twice.
        if to_erase._erased:
            warnings.warn(f"erase_node({to_erase}) on an already erased node")
            return

        # Detach from the node list, mark as erased, and keep the length in sync.
        to_erase._remove_from_list()
        to_erase._erased = True  # iterators may retain handles to erased nodes
        self._len -= 1

        # Null out this Node's argument nodes so that the Nodes referred to
        # can update their ``users`` accordingly
        new_args = map_arg(to_erase.args, lambda n: None)
        assert isinstance(new_args, tuple)
        to_erase.args = new_args
        new_kwargs = map_arg(to_erase.kwargs, lambda n: None)
        assert isinstance(new_kwargs, dict)
        to_erase.kwargs = new_kwargs

    @compatibility(is_backward_compatible=True)
    def inserting_before(self, n: Optional[Node] = None):
        """Set the point at which create_node and companion methods will insert into the graph.

        When used within a 'with' statement, this will temporarily set the insert point and
        then restore it when the with statement exits::

            with g.inserting_before(n):
                ... # inserting before node n
            ... # insert point restored to what it was previously
            g.inserting_before(n) #  set the insert point permanently

        Args:
            n (Optional[Node]): The node before which to insert. If None this will insert
                at the beginning of the entire graph.

        Returns:
            A resource manager that will restore the insert point on ``__exit__``.
        """
        if n is not None:
            assert n.graph == self, "Node to insert before is not in graph."
            return _InsertPoint(self, n.prepend)
        # Inserting right after the sentinel root places new nodes first.
        return self.inserting_after(self._root)

    @compatibility(is_backward_compatible=True)
    def inserting_after(self, n: Optional[Node] = None):
        """Set the point at which create_node and companion methods will insert into the graph.

        When used within a 'with' statement, this will temporarily set the insert point and
        then restore it when the with statement exits::

            with g.inserting_after(n):
                ... # inserting after node n
            ... # insert point restored to what it was previously
            g.inserting_after(n) #  set the insert point permanently

        Args:
            n (Optional[Node]): The node after which to insert. If None this will insert
                at the end of the entire graph.

        Returns:
            A resource manager that will restore the insert point on ``__exit__``.
        """
        if n is None:
            # ``_root`` is the sentinel of the node list: prepending to it (the
            # same default insertion ``__init__`` installs) appends to the graph.
            return self.inserting_before(self._root)
        assert n.graph == self, "Node to insert after is not in graph."
        return _InsertPoint(self, n.append)

    @compatibility(is_backward_compatible=True)
    def placeholder(self, name: str, type_expr: Optional[Any] = None,
                    default_value : Any = inspect.Signature.empty) -> Node:
        """
        Insert a ``placeholder`` node, representing a function input, into the Graph.

        Args:
            name (str): A name for the input value. This corresponds to the name
                of the positional argument to the function this ``Graph`` represents.
            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have. This is needed in some
                cases for proper code generation (e.g. when the function is used
                subsequently in TorchScript compilation).
            default_value (Any): The default value this function argument should take
                on. Since ``None`` is a legitimate default, pass
                ``inspect.Signature.empty`` (the default) to mean the parameter has
                *no* default value.

        .. note::
            The same insertion point and type expression rules apply for this method
            as ``Graph.create_node``.
        """
        has_default = default_value is not inspect.Signature.empty
        # A default value, when present, is stored as the node's single argument;
        # code generation emits it as ``name = <default>``.
        node_args = (default_value,) if has_default else ()
        return self.create_node('placeholder', name, args=node_args, type_expr=type_expr)

    @compatibility(is_backward_compatible=True)
    def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node:
        """
        Insert a ``get_attr`` ``Node``, representing the fetch of an attribute from
        the ``Module`` hierarchy, into the Graph.

        Args:
            qualified_name (str): the fully-qualified name of the attribute to be retrieved.
                For example, if the traced Module has a submodule named ``foo``, which has a
                submodule named ``bar``, which has an attribute named ``baz``, the qualified
                name ``foo.bar.baz`` should be passed as ``qualified_name``.
            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have.

        Returns:
            The newly-created and inserted ``get_attr`` node.

        If the graph has an owning module in which ``qualified_name`` does not
        resolve to a submodule, Parameter or registered buffer, a warning is
        issued (the node is still inserted).

        .. note::
            The same insertion point and type expression rules apply for this method
            as ``Graph.create_node``.
        """
        def _reference_exists(root: torch.nn.Module, target: str) -> bool:
            # "foo.bar.baz" -> owner path "foo.bar", attribute "baz"; a bare name
            # yields an empty path, which get_submodule resolves to ``root``.
            owner_path, _, attr = target.rpartition(".")

            try:
                owner: torch.nn.Module = root.get_submodule(owner_path)
            except AttributeError:
                warnings.warn(f"Failed to fetch module {owner_path}!")
                return False

            if not hasattr(owner, attr):
                return False

            value = getattr(owner, attr)
            # Only submodules, Parameters and registered buffers count as references.
            return (isinstance(value, torch.nn.Module)
                    or isinstance(value, torch.nn.Parameter)
                    or attr in owner._buffers)

        owner_module = self.owning_module
        if owner_module and not _reference_exists(owner_module, qualified_name):
            warnings.warn("Attempted to insert a get_attr Node with no "
                          "underlying reference in the owning "
                          "GraphModule! Call "
                          "GraphModule.add_submodule to add the "
                          "necessary submodule, "
                          "GraphModule.add_parameter to add the "
                          "necessary Parameter, or "
                          "nn.Module.register_buffer to add the "
                          "necessary buffer", stacklevel=2)
        return self.create_node('get_attr', qualified_name, type_expr=type_expr)

    @compatibility(is_backward_compatible=True)
    def call_module(self,
                    module_name: str,
                    args: Optional[Tuple['Argument', ...]] = None,
                    kwargs: Optional[Dict[str, 'Argument']] = None,
                    type_expr: Optional[Any] = None) -> Node:
        """
        Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node
        represents a call to the forward() function of a ``Module`` in the ``Module``
        hierarchy.

        Args:
            module_name (str): The qualified name of the ``Module`` in the ``Module``
                hierarchy to be called. For example, if the traced ``Module`` has a
                submodule named ``foo``, which has a submodule named ``bar``, the
                qualified name ``foo.bar`` should be passed as ``module_name`` to
                call that module.
            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
                to the called method. Note that this should *not* include a ``self`` argument.
            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
                to the called method
            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have.

        Returns:
            The newly-created and inserted ``call_module`` node.

        If the graph has an owning module in which ``module_name`` cannot be
        resolved, a warning is issued and the node is still inserted.

        .. note::
            The same insertion point and type expression rules apply for this method
            as :meth:`Graph.create_node`.
        """
        if self.owning_module:
            # ``get_submodule`` raises AttributeError for an unresolvable path
            # (it never returns None), so the previous ``is None`` test could not
            # trigger the warning and instead let the exception escape.
            try:
                self.owning_module.get_submodule(module_name)
            except AttributeError:
                warnings.warn("Attempted to insert a call_module Node with "
                              "no underlying reference in the owning "
                              "GraphModule! Call "
                              "GraphModule.add_submodule to add the "
                              "necessary submodule", stacklevel=2)
        return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr)

    @compatibility(is_backward_compatible=True)
    def call_method(self,
                    method_name: str,
                    args: Optional[Tuple['Argument', ...]] = None,
                    kwargs: Optional[Dict[str, 'Argument']] = None,
                    type_expr: Optional[Any] = None) -> Node:
        """
        Insert a ``call_method`` ``Node``, representing a call to a method on the
        0th element of ``args``, into the ``Graph``.

        Args:
            method_name (str): The name of the method to apply to the self argument.
                For example, if args[0] is a ``Node`` representing a ``Tensor``,
                then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``.
            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
                to the called method. Note that this *should* include a ``self`` argument.
            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
                to the called method.
            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have.

        Returns:
            The newly created and inserted ``call_method`` node.

        .. note::
            The same insertion point and type expression rules apply for this method
            as :meth:`Graph.create_node`.
        """
        return self.create_node(op='call_method', target=method_name,
                                args=args, kwargs=kwargs, type_expr=type_expr)

    @compatibility(is_backward_compatible=True)
    def call_function(self,
                      the_function: Callable[..., Any],
                      args: Optional[Tuple['Argument', ...]] = None,
                      kwargs: Optional[Dict[str, 'Argument']] = None,
                      type_expr: Optional[Any] = None) -> Node:
        """
        Insert a ``call_function`` ``Node``, representing a call to the Python
        callable ``the_function``, into the ``Graph``.

        Args:
            the_function (Callable[..., Any]): The function to be called. Can be any PyTorch
                operator, Python function, or member of the ``builtins`` or ``operator``
                namespaces.
            args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
                to the called function.
            kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
                to the called function.
            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have.

        Returns:
            The newly created and inserted ``call_function`` node.

        .. note::
            The same insertion point and type expression rules apply for this method
            as :meth:`Graph.create_node`.
        """
        return self.create_node(op='call_function', target=the_function,
                                args=args, kwargs=kwargs, type_expr=type_expr)

    @compatibility(is_backward_compatible=True)
    def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node:
        """
        Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from
        the graph of node to the graph of self. Example::

            # Copying all the nodes in `g` into `new_graph`
            g : torch.fx.Graph = ...
            new_graph = torch.fx.Graph()
            value_remap = {}
            for node in g.nodes:
                value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])

        Args:
            node (Node): The node to copy into ``self``.
            arg_transform (Callable[[Node], Argument]): A function that transforms
                ``Node`` arguments in node's ``args`` and ``kwargs`` into the
                equivalent argument in ``self``. In the simplest case, this should
                retrieve a value out of a table mapping Nodes in the original
                graph to ``self``.

        Returns:
            The newly created node in ``self``; its ``meta`` is a shallow copy of
            ``node.meta``.
        """
        new_args = map_arg(node.args, arg_transform)
        new_kwargs = map_arg(node.kwargs, arg_transform)
        assert isinstance(new_args, tuple)
        assert isinstance(new_kwargs, dict)
        copied = self.create_node(node.op, node.target, new_args, new_kwargs,
                                  name=node.name, type_expr=node.type)
        # Shallow copy: the copy gets its own dict, but the values are shared.
        copied.meta = copy.copy(node.meta)
        return copied

    @compatibility(is_backward_compatible=True)
    def output(self, result: 'Argument', type_expr: Optional[Any] = None):
        """
        Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node
        corresponds to a ``return`` statement in Python code; ``result`` is the
        value that should be returned.

        Args:
            result (Argument): The value to be returned.
            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have.

        .. note::
            The same insertion point and type expression rules apply for this method
            as ``Graph.create_node``.
        """
        # The returned value is the single positional argument of the node.
        return self.create_node('output', 'output', (result,), None, type_expr=type_expr)

    def _target_to_str(self, target : Target) -> str:
        """
        Derive a snake_case base name for a node from its ``target``.

        A callable target contributes its ``__name__``. A string target is used
        as-is, except that names accepted by ``_is_magic`` (presumably dunder
        names such as ``__add__``) lose their first and last two characters.
        """
        if not callable(target):
            assert isinstance(target, str)
            raw_name = target[2:-2] if _is_magic(target) else target
        else:
            raw_name = target.__name__
        return _snake_case(raw_name)

    @compatibility(is_backward_compatible=True)
    def python_code(self, root_module: str, *, verbose: bool = False) -> PythonCode:
        """
        Turn this ``Graph`` into valid Python code.

        Args:
            root_module (str): The name of the root module on which to look-up
                qualified name targets. This is usually 'self'.
            verbose (bool): forwarded unchanged to the code generator.

        Returns:
            A PythonCode object, consisting of two fields:
                src: the Python source code representing the object
                globals: a dictionary of global names in `src` -> the objects that they reference.
        """
        # NOTE: [Graph Namespaces]
        #
        # There are two types of symbols in generated Python source code:
        # locals and globals.
        #   Locals are locally defined by the output of a node in the Graph.
        #   Globals are references to external objects, like functions or types.
        #
        # When generating Python code, we need to make sure to name things
        # appropriately. In particular:
        # - All names should be unique, to avoid weird shadowing bugs.
        # - These names need to be consistent, e.g. an object should always be
        #   referenced by the same name.
        #
        # To do this, we create a new namespace just for this source. All names
        # that get printed must come from this namespace.
        #
        # Why can't we re-use node.name? Because it was generated within the
        # namespace `self._graph_namespace`. In order to provide uniqueness
        # over both locals (node.name) *and* globals, we create a completely
        # new namespace to put all identifiers in.
        codegen_namespace = _Namespace()

        # While generating, every Node's repr() is redirected to a name drawn
        # from ``codegen_namespace``. repr() is meant to yield a valid Python
        # expression, so containers such as Tuple[Node, Node] print correctly
        # for free (Node's __repr__ cooperates via ``_repr_fn``).
        def _name_in_codegen_namespace(n: Node):
            return codegen_namespace.create_name(n.name, n)

        @contextmanager
        def _redirect_node_reprs(graph: Graph):
            saved_fns = {}
            for n in graph.nodes:
                saved_fns[n] = n._repr_fn
                n._repr_fn = _name_in_codegen_namespace
            try:
                yield None
            finally:
                # Put every node's original repr function back.
                for n in graph.nodes:
                    n._repr_fn = saved_fns[n]

        with _redirect_node_reprs(self):
            return self._python_code(root_module, codegen_namespace, verbose=verbose)

    def _python_code(self, root_module: str, namespace: _Namespace, *, verbose: bool = False) -> PythonCode:
        """
        Hand this graph's nodes to the currently installed ``CodeGen``.

        ``python_code`` calls this while node reprs are redirected into
        ``namespace``; all arguments are forwarded unchanged.
        """
        codegen = self._codegen
        return codegen._gen_python_code(self.nodes, root_module, namespace, verbose=verbose)


    def __str__(self) -> str:
        """
        Return a human-readable (not machine-readable) string representation
        of this Graph.

        The result has a ``graph(<placeholders>)<return type>:`` header followed
        by one indented line per node that ``format_node`` renders to a
        non-empty string.
        """
        # Filled in by ``format_node`` as placeholder nodes are formatted.
        placeholder_names : List[str] = []
        # One-element list so ``format_node`` can write the return-type suffix
        # (e.g. `` -> <type>``) back into it when formatting the output node.
        # It must actually be passed in, or the header never shows the type.
        maybe_return_typename : List[str] = ['']

        # Materialize every node string first: formatting is what populates the
        # two lists above, and the header reads them.
        node_strs = [
            node.format_node(placeholder_names, maybe_return_typename=maybe_return_typename)
            for node in self.nodes
        ]
        param_str = ', '.join(placeholder_names)
        header = f'graph({param_str}){maybe_return_typename[0]}:'
        # Nodes that format to an empty/None string are omitted.
        return '\n    '.join([header, *(s for s in node_strs if s)])

    @compatibility(is_backward_compatible=True)
    def print_tabular(self):
        """
        Print the graph's intermediate representation as a table.

        Each row shows a node's opcode, name, target, args and kwargs. This API
        requires the third-party ``tabulate`` module; if it cannot be imported,
        an explanatory message is printed and the ``ImportError`` is re-raised.
        """
        try:
            from tabulate import tabulate
        except ImportError:
            print("`print_tabular` relies on the library `tabulate`, "
                  "which could not be found on this machine. Run `pip "
                  "install tabulate` to install the library.")
            raise

        column_titles = ['opcode', 'name', 'target', 'args', 'kwargs']
        rows = []
        for n in self.nodes:
            rows.append([n.op, n.name, n.target, n.args, n.kwargs])
        print(tabulate(rows, headers=column_titles))

    @compatibility(is_backward_compatible=True)
    def lint(self):
        """
        Runs various checks on this Graph to make sure it is well-formed. In
        particular:

        - Checks Nodes have a known opcode
        - Checks Nodes have correct ownership (owned by this graph)
        - Checks Nodes appear in topological order
        - Checks Node names are unique
        - If this Graph has an owning GraphModule, checks that targets
          exist in that GraphModule

        Raises:
            RuntimeError: on an unknown opcode, a node or argument owned by
                another graph, an argument used before its definition, a
                duplicate node name, a ``get_attr``/``call_module`` target that
                cannot be resolved, or a ``call_module`` target that does not
                resolve to an ``nn.Module``.
            ValueError: if a ``call_function`` target is not callable, or any
                other node's target is not a ``str``.
        """

        # Check topo order
        def check_arg(arg : Node, n : Optional[Node] = None) -> None:
            context_str = f' of Node \'{n}\' ' if n else ' '
            if arg.graph is not self:
                raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, '
                                   f'but was used as an argument! If you are copying nodes from another graph, make '
                                   f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}')
            if arg not in seen_values:
                raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been '
                                   f'defined! Please check that Nodes in the graph are topologically ordered\n{self}')

        seen_names : Set[str] = set()
        # Nodes already visited; an argument not in here is used before definition.
        seen_values : Set[Node] = set()
        for node in self.nodes:
            if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']:
                raise RuntimeError(f'Node {node} had unknown opcode {node.op}!')
            if node.graph is not self:
                raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!')
            map_arg(node.args, lambda arg: check_arg(arg, node))
            map_arg(node.kwargs, lambda arg: check_arg(arg, node))
            seen_values.add(node)

            if node.name in seen_names:
                raise RuntimeError(f'Node redefined name {node.name}!')
            seen_names.add(node.name)

        # Check targets are legit
        if self.owning_module:
            for node in self.nodes:
                if node.op == 'call_function':
                    if not callable(node.target):
                        raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but '
                                         'a Callable is expected')
                else:
                    if not isinstance(node.target, str):
                        raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but '
                                         'a str is expected')
                if node.op in ['get_attr', 'call_module']:
                    # Walk the dotted qualified name one attribute at a time.
                    target_atoms = node.target.split('.')
                    m_itr = self.owning_module
                    # Emit the "unusual get_attr target" warning at most once per node.
                    warned = False
                    for i, atom in enumerate(target_atoms):
                        new_m_itr = getattr(m_itr, atom, None)
                        seen_qualname = '.'.join(target_atoms[:i])
                        if new_m_itr is None:
                            raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute '
                                               f'{atom} of {seen_qualname}')
                        if (node.op == "call_module"
                                and not isinstance(new_m_itr, torch.nn.Module)):
                            raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
                                               'not reference an nn.Module')
                        # ``m_itr`` need not be a Module here (e.g. after descending into a
                        # Parameter or plain attribute), so it may have no ``_buffers``.
                        if (node.op == "get_attr"
                                and not warned
                                and not isinstance(new_m_itr, torch.nn.Module)
                                and not isinstance(new_m_itr, torch.nn.Parameter)
                                and atom not in getattr(m_itr, '_buffers', {})):
                            warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
                                          'not reference an nn.Module, nn.Parameter, or buffer, which is '
                                          'what \'get_attr\' Nodes typically target')
                            warned = True
                        # Always descend: the next atom must be resolved against the object
                        # this atom names, not against its parent.
                        m_itr = new_m_itr

    @compatibility(is_backward_compatible=True)
    def eliminate_dead_code(self):
        """
        Remove all dead code from the graph.

        A node is removed when it has no users and ``Node.is_impure`` reports it
        as free of side effects. The graph must be topologically sorted before
        calling; ``lint()`` is run first and raises if it is not.

        Returns:
          bool: Whether the graph was changed as a result of the pass.

        Example:

        Before dead code is eliminated, `a` from `a = x + 1` below has no users
        and thus can be eliminated from the graph without having an effect.

        .. code-block:: python

            def forward(self, x):
                a = x + 1
                return x + self.attr_1

        After dead code is eliminated, `a = x + 1` has been removed, and the rest
        of `forward` remains.

        .. code-block:: python

            def forward(self, x):
                return x + self.attr_1

        .. warning::

            Dead code elimination has some heuristics to avoid removing
            side-effectful nodes (see Node.is_impure) but in general coverage
            is very bad, so you should assume that this method is not sound
            to call unless you know that your FX graph consists entirely
            of functional operations.
        """
        # The walk below relies on topological order, so validate it up front.
        self.lint()

        # Walk from the end: erasing a node drops it from its inputs' user
        # lists, so those inputs can be recognized as dead later in this pass.
        removed_any = False
        for candidate in reversed(self.nodes):
            if candidate.is_impure() or len(candidate.users) > 0:
                continue
            self.erase_node(candidate)
            removed_any = True

        return removed_any

    @compatibility(is_backward_compatible=False)
    def set_codegen(self, codegen: CodeGen):
        """
        Replace the ``CodeGen`` used when this graph is turned into Python code.

        The new object is stored as ``self._codegen``; later code generation
        (``_python_code``) and ``on_generate_code`` operate on it.
        """
        self._codegen = codegen

    @compatibility(is_backward_compatible=False)
    def on_generate_code(
        self,
        make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]
    ):
        """Register a transformer function when python code is generated

        Args:
            make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):
                a function that returns a code transformer to be registered.
                This function is called by `on_generate_code` to obtain the
                code transformer.

                This function is also given as its input the currently
                registered code transformer (or None if nothing is registered),
                in case it is not desirable to overwrite it. This is useful to
                chain code transformers together.

        Returns:
            a context manager that, when used in a `with` statement, automatically
            restores the previously registered code transformer on exit. The new
            transformer is installed immediately, whether or not the context
            manager is used.

        Example:

        .. code-block:: python

            gm: fx.GraphModule = ...

            # This is a code transformer we want to register. This code
            # transformer prepends a pdb import and trace statement at the very
            # beginning of the generated torch.fx code to allow for manual
            # debugging with the PDB library.
            def insert_pdb(body):
                return ["import pdb; pdb.set_trace()\\n", *body]

            # Registers `insert_pdb`, and overwrites the current registered
            # code transformer (given by `_` to the lambda):
            gm.graph.on_generate_code(
                lambda _: insert_pdb
            )

            # Or alternatively, registers a code transformer which first
            # runs `body` through existing registered transformer, then
            # through `insert_pdb`:
            gm.graph.on_generate_code(
                lambda current_trans: (
                    lambda body: insert_pdb(
                        current_trans(body) if current_trans
                        else body
                    )
                )
            )

            gm.recompile()
            gm(*inputs)  # drops into pdb

        This function can also be used as a context manager, with the benefit of
        automatically restoring the previously registered code transformer:

        .. code-block:: python

            # ... continue from previous example

            with gm.graph.on_generate_code(lambda _: insert_pdb):
                # do more stuff with `gm`...
                gm.recompile()
                gm(*inputs)  # drops into pdb

            # now previous code transformer is restored (but `gm`'s code with pdb
            # remains - that means you can run `gm` with pdb here too, until you
            # run next `recompile()`).
        """
        previous_transformer = self._codegen._body_transformer
        self._codegen._body_transformer = make_transformer(previous_transformer)

        @contextlib.contextmanager
        def _restore_previous_transformer():
            try:
                yield
            finally:
                # ``self._codegen`` is looked up again at exit time.
                self._codegen._body_transformer = previous_transformer

        return _restore_previous_transformer()


# Binary operators that also have a reflected (``__r<name>__``) form, mapped to
# the format template used to render them as Python source.
reflectable_magic_methods = dict(
    add='{} + {}',
    sub='{} - {}',
    mul='{} * {}',
    floordiv='{} // {}',
    truediv='{} / {}',
    div='{} / {}',
    mod='{} % {}',
    pow='{} ** {}',
    lshift='{} << {}',
    rshift='{} >> {}',
    and_='{} & {}',
    or_='{} | {}',
    xor='{} ^ {}',
    getitem='{}[{}]',
    matmul='{} @ {}',
)

# Comparison and unary operators, followed by every reflectable operator above.
magic_methods = {
    'eq': '{} == {}',
    'ne': '{} != {}',
    'lt': '{} < {}',
    'gt': '{} > {}',
    'le': '{} <= {}',
    'ge': '{} >= {}',
    'pos': '+{}',
    'neg': '-{}',
    'invert': '~{}',
    **reflectable_magic_methods,
}

# In-place operator statements (plus item assignment) and their templates.
inplace_methods = dict(
    iadd='{} += {}',
    iand='{} &= {}',
    ifloordiv='{} //= {}',
    ilshift='{} <<= {}',
    imod='{} %= {}',
    imul='{} *= {}',
    imatmul='{} @= {}',
    ior='{} |= {}',
    ipow='{} **= {}',
    irshift='{} >>= {}',
    isub='{} -= {}',
    itruediv='{} /= {}',
    ixor='{} ^= {}',
    setitem='{}[{}] = {}',
)