r"""Pruning methods."""
import numbers
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Tuple

import torch


class BasePruningMethod(ABC):
    r"""Abstract base class for creation of new pruning techniques.



    Provides a skeleton for customization requiring the overriding of methods

    such as :meth:`compute_mask` and :meth:`apply`.
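
    Example:
        A minimal sketch of a custom unstructured method; the class name and
        the "zero every other entry" rule are illustrative only:

        >>> # xdoctest: +SKIP
        >>> class FooBarPruningMethod(BasePruningMethod):
        ...     PRUNING_TYPE = "unstructured"
        ...     def compute_mask(self, t, default_mask):
        ...         mask = default_mask.clone()
        ...         mask.view(-1)[::2] = 0  # prune every other entry
        ...         return mask
        >>> m = nn.Linear(3, 4)
        >>> _ = FooBarPruningMethod.apply(m, "weight")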

    """

    _tensor_name: str

    def __call__(self, module, inputs):
        r"""Multiply the mask into original tensor and store the result.



        Multiplies the mask (stored in ``module[name + '_mask']``)

        into the original tensor (stored in ``module[name + '_orig']``)

        and stores the result into ``module[name]`` by using :meth:`apply_mask`.



        Args:

            module (nn.Module): module containing the tensor to prune

            inputs: not used.

        """
        setattr(module, self._tensor_name, self.apply_mask(module))

    @abstractmethod
    def compute_mask(self, t, default_mask):
        r"""Compute and returns a mask for the input tensor ``t``.



        Starting from a base ``default_mask`` (which should be a mask of ones

        if the tensor has not been pruned yet), generate a random mask to

        apply on top of the ``default_mask`` according to the specific pruning

        method recipe.



        Args:

            t (torch.Tensor): tensor representing the importance scores of the

            parameter to prune.

            default_mask (torch.Tensor): Base mask from previous pruning

            iterations, that need to be respected after the new mask is

            applied. Same dims as ``t``.



        Returns:

            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        """
        pass

    def apply_mask(self, module):
        r"""Simply handles the multiplication between the parameter being pruned and the generated mask.



        Fetches the mask and the original tensor from the module

        and returns the pruned version of the tensor.



        Args:

            module (nn.Module): module containing the tensor to prune



        Returns:

            pruned_tensor (torch.Tensor): pruned version of the input tensor

        """
        # to carry out the multiplication, the mask needs to have been computed,
        # so the pruning method must know what tensor it's operating on
        assert self._tensor_name is not None, f"Module {module} has to be pruned"  # this gets set in apply()
        mask = getattr(module, self._tensor_name + "_mask")
        orig = getattr(module, self._tensor_name + "_orig")
        pruned_tensor = mask.to(dtype=orig.dtype) * orig
        return pruned_tensor

    @classmethod
    def apply(cls, module, name, *args, importance_scores=None, **kwargs):
        r"""Add pruning on the fly and reparametrization of a tensor.



        Adds the forward pre-hook that enables pruning on the fly and

        the reparametrization of a tensor in terms of the original tensor

        and the pruning mask.



        Args:

            module (nn.Module): module containing the tensor to prune

            name (str): parameter name within ``module`` on which pruning

                will act.

            args: arguments passed on to a subclass of

                :class:`BasePruningMethod`

            importance_scores (torch.Tensor): tensor of importance scores (of

                same shape as module parameter) used to compute mask for pruning.

                The values in this tensor indicate the importance of the

                corresponding elements in the parameter being pruned.

                If unspecified or None, the parameter will be used in its place.

            kwargs: keyword arguments passed on to a subclass of a

                :class:`BasePruningMethod`
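
        Example:
            A sketch using a concrete subclass (equivalent to calling
            :func:`l1_unstructured` on the module):

            >>> # xdoctest: +SKIP
            >>> m = nn.Linear(2, 3)
            >>> _ = L1Unstructured.apply(m, name="weight", amount=1)
            >>> hasattr(m, "weight_orig") and hasattr(m, "weight_mask")
            True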

        """

        def _get_composite_method(cls, module, name, *args, **kwargs):
            # Check if a pruning method has already been applied to
            # `module[name]`. If so, store that in `old_method`.
            old_method = None
            found = 0
            # there should technically be only 1 hook with hook.name == name
            # assert this using `found`
            hooks_to_remove = []
            for k, hook in module._forward_pre_hooks.items():
                # if it exists, take existing thing, remove hook, then
                # go through normal thing
                if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
                    old_method = hook
                    hooks_to_remove.append(k)
                    found += 1
            assert found <= 1, (
                f"Avoid adding multiple pruning hooks to the same tensor {name} "
                f"of module {module}. Use a PruningContainer."
            )

            for k in hooks_to_remove:
                del module._forward_pre_hooks[k]

            # Apply the new pruning method, either from scratch or on top of
            # the previous one.
            method = cls(*args, **kwargs)  # new pruning
            # Have the pruning method remember what tensor it's been applied to
            method._tensor_name = name

            # combine `method` with `old_method`, if `old_method` exists
            if old_method is not None:  # meaning that there was a hook
                # if the hook is already a pruning container, just add the
                # new pruning method to the container
                if isinstance(old_method, PruningContainer):
                    old_method.add_pruning_method(method)
                    method = old_method  # rename old_method --> method

                # if the hook is simply a single pruning method, create a
                # container, add the old pruning method and the new one
                elif isinstance(old_method, BasePruningMethod):
                    container = PruningContainer(old_method)
                    # Have the pruning method remember the name of its tensor
                    # setattr(container, '_tensor_name', name)
                    container.add_pruning_method(method)
                    method = container  # rename container --> method
            return method

        method = _get_composite_method(cls, module, name, *args, **kwargs)
        # at this point we have no forward_pre_hooks but we could have an
        # active reparametrization of the tensor if another pruning method
        # had been applied (in which case `method` would be a PruningContainer
        # and not a simple pruning method).

        # Pruning is to be applied to the module's tensor named `name`,
        # starting from the state it is found in prior to this iteration of
        # pruning. The pruning mask is calculated based on importance scores.

        orig = getattr(module, name)
        if importance_scores is not None:
            assert importance_scores.shape == orig.shape, (
                "importance_scores should have the same shape as "
                f"parameter {name} of {module}"
            )
        else:
            importance_scores = orig

        # If this is the first time pruning is applied, take care of moving
        # the original tensor to a new parameter called name + '_orig'
        # and deleting the original parameter
        if not isinstance(method, PruningContainer):
            # copy `module[name]` to `module[name + '_orig']`
            module.register_parameter(name + "_orig", orig)
            # temporarily delete `module[name]`
            del module._parameters[name]
            default_mask = torch.ones_like(orig)  # temp
        # If this is not the first time pruning is applied, all of the above
        # has been done before in a previous pruning iteration, so we're good
        # to go
        else:
            default_mask = (
                getattr(module, name + "_mask")
                .detach()
                .clone(memory_format=torch.contiguous_format)
            )

        # Use try/except because if anything goes wrong with the mask
        # computation etc., you'd want to roll back.
        try:
            # get the final mask, computed according to the specific method
            mask = method.compute_mask(importance_scores, default_mask=default_mask)
            # reparameterize by saving mask to `module[name + '_mask']`...
            module.register_buffer(name + "_mask", mask)
            # ... and the new pruned tensor to `module[name]`
            setattr(module, name, method.apply_mask(module))
            # associate the pruning method to the module via a hook to
            # compute the function before every forward() (compile by run)
            module.register_forward_pre_hook(method)

        except Exception as e:
            if not isinstance(method, PruningContainer):
                orig = getattr(module, name + "_orig")
                module.register_parameter(name, orig)
                del module._parameters[name + "_orig"]
            raise e

        return method

    def prune(self, t, default_mask=None, importance_scores=None):
        r"""Compute and returns a pruned version of input tensor ``t``.



        According to the pruning rule specified in :meth:`compute_mask`.



        Args:

            t (torch.Tensor): tensor to prune (of same dimensions as

                ``default_mask``).

            importance_scores (torch.Tensor): tensor of importance scores (of

                same shape as ``t``) used to compute mask for pruning ``t``.

                The values in this tensor indicate the importance of the

                corresponding elements in the ``t`` that is being pruned.

                If unspecified or None, the tensor ``t`` will be used in its place.

            default_mask (torch.Tensor, optional): mask from previous pruning

                iteration, if any. To be considered when determining what

                portion of the tensor that pruning should act on. If None,

                default to a mask of ones.



        Returns:

            pruned version of tensor ``t``.
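
        Example:
            A sketch of standalone use on a plain tensor; ``L1Unstructured``
            is just one concrete choice of subclass:

            >>> # xdoctest: +SKIP
            >>> t = torch.arange(6.0).reshape(2, 3)
            >>> L1Unstructured(amount=2).prune(t)
            tensor([[0., 0., 2.],
                    [3., 4., 5.]])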

        """
        if importance_scores is not None:
            assert (
                importance_scores.shape == t.shape
            ), "importance_scores should have the same shape as tensor t"
        else:
            importance_scores = t
        default_mask = default_mask if default_mask is not None else torch.ones_like(t)
        return t * self.compute_mask(importance_scores, default_mask=default_mask)

    def remove(self, module):
        r"""Remove the pruning reparameterization from a module.



        The pruned parameter named ``name`` remains permanently pruned,

        and the parameter named ``name+'_orig'`` is removed from the parameter list.

        Similarly, the buffer named ``name+'_mask'`` is removed from the buffers.



        Note:

            Pruning itself is NOT undone or reversed!

        """
        # before removing pruning from a tensor, it has to have been applied
        assert self._tensor_name is not None, (
            f"Module {module} has to be pruned before pruning can be removed"
        )  # this gets set in apply()

        # to update module[name] to latest trained weights
        weight = self.apply_mask(module)  # masked weights

        # delete and reset
        if hasattr(module, self._tensor_name):
            delattr(module, self._tensor_name)
        orig = module._parameters[self._tensor_name + "_orig"]
        orig.data = weight.data
        del module._parameters[self._tensor_name + "_orig"]
        del module._buffers[self._tensor_name + "_mask"]
        setattr(module, self._tensor_name, orig)


class PruningContainer(BasePruningMethod):
    """Container holding a sequence of pruning methods for iterative pruning.



    Keeps track of the order in which pruning methods are applied and handles

    combining successive pruning calls.



    Accepts as argument an instance of a BasePruningMethod or an iterable of

    them.
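
    Example:
        A sketch of how a container arises implicitly when the same parameter
        is pruned twice (the amounts are illustrative):

        >>> # xdoctest: +SKIP
        >>> m = nn.Linear(3, 4)
        >>> m = prune.l1_unstructured(m, name="weight", amount=0.5)
        >>> m = prune.random_unstructured(m, name="weight", amount=0.1)
        >>> hook = next(iter(m._forward_pre_hooks.values()))
        >>> isinstance(hook, PruningContainer)
        True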

    """

    def __init__(self, *args):
        self._pruning_methods: Tuple[BasePruningMethod, ...] = tuple()
        if not isinstance(args, Iterable):  # only 1 item
            self._tensor_name = args._tensor_name
            self.add_pruning_method(args)
        elif len(args) == 1:  # only 1 item in a tuple
            self._tensor_name = args[0]._tensor_name
            self.add_pruning_method(args[0])
        else:  # manual construction from list or other iterable (or no args)
            for method in args:
                self.add_pruning_method(method)

    def add_pruning_method(self, method):
        r"""Add a child pruning ``method`` to the container.



        Args:

            method (subclass of BasePruningMethod): child pruning method

                to be added to the container.

        """
        # check that we're adding a pruning method to the container
        if not isinstance(method, BasePruningMethod) and method is not None:
            raise TypeError(
                f"{type(method)} is not a BasePruningMethod subclass"
            )
        elif method is not None and self._tensor_name != method._tensor_name:
            raise ValueError(
                "Can only add pruning methods acting on "
                f"the parameter named '{self._tensor_name}' to PruningContainer {self}."
                + f" Found '{method._tensor_name}'"
            )
        # if all checks passed, add to _pruning_methods tuple
        self._pruning_methods += (method,)  # type: ignore[operator]

    def __len__(self):
        return len(self._pruning_methods)

    def __iter__(self):
        return iter(self._pruning_methods)

    def __getitem__(self, idx):
        return self._pruning_methods[idx]

    def compute_mask(self, t, default_mask):
        r"""Apply the latest ``method`` by computing the new partial masks and returning its combination with the ``default_mask``.



        The new partial mask should be computed on the entries or channels

        that were not zeroed out by the ``default_mask``.

        Which portions of the tensor ``t`` the new mask will be calculated from

        depends on the ``PRUNING_TYPE`` (handled by the type handler):



        * for 'unstructured', the mask will be computed from the raveled

          list of nonmasked entries;



        * for 'structured', the mask will be computed from the nonmasked

          channels in the tensor;



        * for 'global', the mask will be computed across all entries.



        Args:

            t (torch.Tensor): tensor representing the parameter to prune

                (of same dimensions as ``default_mask``).

            default_mask (torch.Tensor): mask from previous pruning iteration.



        Returns:

            mask (torch.Tensor): new mask that combines the effects

            of the ``default_mask`` and the new mask from the current

            pruning ``method`` (of same dimensions as ``default_mask`` and

            ``t``).

        """

        def _combine_masks(method, t, mask):
            r"""Combine the masks from all pruning methods and returns a new mask.



            Args:

                method (a BasePruningMethod subclass): pruning method

                    currently being applied.

                t (torch.Tensor): tensor representing the parameter to prune

                    (of same dimensions as mask).

                mask (torch.Tensor): mask from previous pruning iteration



            Returns:

                new_mask (torch.Tensor): new mask that combines the effects

                    of the old mask and the new mask from the current

                    pruning method (of same dimensions as mask and t).

            """
            new_mask = mask  # start off from existing mask
            new_mask = new_mask.to(dtype=t.dtype)

            # compute a slice of t onto which the new pruning method will operate
            if method.PRUNING_TYPE == "unstructured":
                # prune entries of t where the mask is 1
                slc = mask == 1

            # for struct pruning, exclude channels that have already been
            # entirely pruned
            elif method.PRUNING_TYPE == "structured":
                if not hasattr(method, "dim"):
                    raise AttributeError(
                        "Pruning methods of PRUNING_TYPE "
                        '"structured" need to have the attribute `dim` defined.'
                    )

                # find the channels to keep by removing the ones that have been
                # zeroed out already (i.e. where sum(entries) == 0)
                n_dims = t.dim()  # "is this a 2D tensor? 3D? ..."
                dim = method.dim
                # convert negative indexing
                if dim < 0:
                    dim = n_dims + dim
                # if dim is still negative after subtracting it from n_dims
                if dim < 0:
                    raise IndexError(
                        f"Index is out of bounds for tensor with dimensions {n_dims}"
                    )
                # find channels along dim=dim that aren't already totally zeroed out
                keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0
                # create slice to identify what to prune
                slc = [slice(None)] * n_dims
                slc[dim] = keep_channel

            elif method.PRUNING_TYPE == "global":
                n_dims = len(t.shape)  # "is this a 2D tensor? 3D? ..."
                slc = [slice(None)] * n_dims

            else:
                raise ValueError(
                    f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}"
                )

            # compute the new mask on the unpruned slice of the tensor t
            partial_mask = method.compute_mask(t[slc], default_mask=mask[slc])
            new_mask[slc] = partial_mask.to(dtype=new_mask.dtype)

            return new_mask

        method = self._pruning_methods[-1]
        mask = _combine_masks(method, t, default_mask)
        return mask


class Identity(BasePruningMethod):
    r"""Utility pruning method that does not prune any units but generates the pruning parametrization with a mask of ones."""

    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        mask = default_mask
        return mask

    @classmethod
    def apply(cls, module, name):
        r"""Add pruning on the fly and reparametrization of a tensor.



        Adds the forward pre-hook that enables pruning on the fly and

        the reparametrization of a tensor in terms of the original tensor

        and the pruning mask.



        Args:

            module (nn.Module): module containing the tensor to prune

            name (str): parameter name within ``module`` on which pruning

                will act.

        """
        return super().apply(module, name)


class RandomUnstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor at random.



    Args:

        name (str): parameter name within ``module`` on which pruning

            will act.

        amount (int or float): quantity of parameters to prune.

            If ``float``, should be between 0.0 and 1.0 and represent the

            fraction of parameters to prune. If ``int``, it represents the

            absolute number of parameters to prune.

    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        # Check that the number of units to prune is not greater than the
        # number of parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # k=0 not supported by torch.kthvalue
            prob = torch.rand_like(t)
            topk = torch.topk(prob.view(-1), k=nparams_toprune)
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount):
        r"""Add pruning on the fly and reparametrization of a tensor.



        Adds the forward pre-hook that enables pruning on the fly and

        the reparametrization of a tensor in terms of the original tensor

        and the pruning mask.



        Args:

            module (nn.Module): module containing the tensor to prune

            name (str): parameter name within ``module`` on which pruning

                will act.

            amount (int or float): quantity of parameters to prune.

                If ``float``, should be between 0.0 and 1.0 and represent the

                fraction of parameters to prune. If ``int``, it represents the

                absolute number of parameters to prune.

        """
        return super().apply(module, name, amount=amount)


class L1Unstructured(BasePruningMethod):
    r"""Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm.



    Args:

        amount (int or float): quantity of parameters to prune.

            If ``float``, should be between 0.0 and 1.0 and represent the

            fraction of parameters to prune. If ``int``, it represents the

            absolute number of parameters to prune.

    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        # Check that the number of units to prune is not greater than the
        # number of parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # k=0 not supported by torch.kthvalue
            # largest=True --> top k; largest=False --> bottom k
            # Prune the smallest k
            topk = torch.topk(torch.abs(t).view(-1), k=nparams_toprune, largest=False)
            # topk will have .indices and .values
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount, importance_scores=None):
        r"""Add pruning on the fly and reparametrization of a tensor.



        Adds the forward pre-hook that enables pruning on the fly and

        the reparametrization of a tensor in terms of the original tensor

        and the pruning mask.



        Args:

            module (nn.Module): module containing the tensor to prune

            name (str): parameter name within ``module`` on which pruning

                will act.

            amount (int or float): quantity of parameters to prune.

                If ``float``, should be between 0.0 and 1.0 and represent the

                fraction of parameters to prune. If ``int``, it represents the

                absolute number of parameters to prune.

            importance_scores (torch.Tensor): tensor of importance scores (of same

                shape as module parameter) used to compute mask for pruning.

                The values in this tensor indicate the importance of the corresponding

                elements in the parameter being pruned.

                If unspecified or None, the module parameter will be used in its place.

        """
        return super().apply(
            module, name, amount=amount, importance_scores=importance_scores
        )


class RandomStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor at random.



    Args:

        amount (int or float): quantity of parameters to prune.

            If ``float``, should be between 0.0 and 1.0 and represent the

            fraction of parameters to prune. If ``int``, it represents the

            absolute number of parameters to prune.

        dim (int, optional): index of the dim along which we define

            channels to prune. Default: -1.

    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and returns a mask for the input tensor ``t``.



        Starting from a base ``default_mask`` (which should be a mask of ones

        if the tensor has not been pruned yet), generate a random mask to

        apply on top of the ``default_mask`` by randomly zeroing out channels

        along the specified dim of the tensor.



        Args:

            t (torch.Tensor): tensor representing the parameter to prune

            default_mask (torch.Tensor): Base mask from previous pruning

                iterations, that need to be respected after the new mask is

                applied. Same dims as ``t``.



        Returns:

            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``



        Raises:

            IndexError: if ``self.dim >= len(t.shape)``

        """
        # Check that tensor has structure (i.e. more than 1 dimension) such
        # that the concept of "channels" makes sense
        _validate_structured_pruning(t)

        # Check that self.dim is a valid dim to index t, else raise IndexError
        _validate_pruning_dim(t, self.dim)

        # Check that the number of channels to prune is not greater than the
        # number of channels in t along the dim to prune
        tensor_size = t.shape[self.dim]
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # Compute binary mask by initializing it to all 0s and then filling in
        # 1s for the channels that survive the random draw, along self.dim.
        # mask has the same shape as tensor t
        def make_mask(t, dim, nchannels, nchannels_toprune):
            # generate a random number in [0, 1] to associate to each channel
            prob = torch.rand(nchannels)
            # generate mask for each channel by 0ing out the channels that
            # got assigned the k = nchannels_toprune lowest values in prob
            threshold = torch.kthvalue(prob, k=nchannels_toprune).values
            channel_mask = prob > threshold

            mask = torch.zeros_like(t)
            slc = [slice(None)] * len(t.shape)
            slc[dim] = channel_mask
            mask[slc] = 1
            return mask

        if nparams_toprune == 0:  # k=0 not supported by torch.kthvalue
            mask = default_mask
        else:
            # apply the new structured mask on top of prior (potentially
            # unstructured) mask
            mask = make_mask(t, self.dim, tensor_size, nparams_toprune)
            mask *= default_mask.to(dtype=mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, amount, dim=-1):
        r"""Add pruning on the fly and reparametrization of a tensor.



        Adds the forward pre-hook that enables pruning on the fly and

        the reparametrization of a tensor in terms of the original tensor

        and the pruning mask.



        Args:

            module (nn.Module): module containing the tensor to prune

            name (str): parameter name within ``module`` on which pruning

                will act.

            amount (int or float): quantity of parameters to prune.

                If ``float``, should be between 0.0 and 1.0 and represent the

                fraction of parameters to prune. If ``int``, it represents the

                absolute number of parameters to prune.

            dim (int, optional): index of the dim along which we define

                channels to prune. Default: -1.

        """
        return super().apply(module, name, amount=amount, dim=dim)


class LnStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor based on their L\ ``n``-norm.



    Args:

        amount (int or float): quantity of channels to prune.

            If ``float``, should be between 0.0 and 1.0 and represent the

            fraction of parameters to prune. If ``int``, it represents the

            absolute number of parameters to prune.

        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid

            entries for argument ``p`` in :func:`torch.norm`.

        dim (int, optional): index of the dim along which we define

            channels to prune. Default: -1.

    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, n, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.n = n
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and returns a mask for the input tensor ``t``.



        Starting from a base ``default_mask`` (which should be a mask of ones

        if the tensor has not been pruned yet), generate a mask to apply on

        top of the ``default_mask`` by zeroing out the channels along the

        specified dim with the lowest L\ ``n``-norm.



        Args:

            t (torch.Tensor): tensor representing the parameter to prune

            default_mask (torch.Tensor): Base mask from previous pruning

                iterations, that need to be respected after the new mask is

                applied.  Same dims as ``t``.



        Returns:

            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``



        Raises:

            IndexError: if ``self.dim >= len(t.shape)``
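
        Example:
            A sketch on a small 2D tensor, pruning the one column with the
            lowest L1 norm (values are illustrative):

            >>> # xdoctest: +SKIP
            >>> t = torch.tensor([[1., 2., 3.], [1., 2., 3.]])
            >>> LnStructured(amount=1, n=1, dim=-1).compute_mask(
            ...     t, default_mask=torch.ones_like(t)
            ... )
            tensor([[0., 1., 1.],
                    [0., 1., 1.]])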

        """
        # Check that tensor has structure (i.e. more than 1 dimension) such
        # that the concept of "channels" makes sense
        _validate_structured_pruning(t)
        # Check that self.dim is a valid dim to index t, else raise IndexError
        _validate_pruning_dim(t, self.dim)

        # Check that the number of channels to prune is not greater than the
        # number of channels in t along the dim to prune
        tensor_size = t.shape[self.dim]
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        nparams_tokeep = tensor_size - nparams_toprune
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # Structured pruning prunes entire channels so we need to know the
        # L_n norm along each channel to then find the topk based on this
        # metric
        norm = _compute_norm(t, self.n, self.dim)
        # largest=True --> top k; largest=False --> bottom k
        # Keep the largest k channels along dim=self.dim
        topk = torch.topk(norm, k=nparams_tokeep, largest=True)
        # topk will have .indices and .values

        # Compute binary mask by initializing it to all 0s and then filling in
        # 1s wherever topk.indices indicates, along self.dim.
        # mask has the same shape as tensor t
        def make_mask(t, dim, indices):
            # init mask to 0
            mask = torch.zeros_like(t)
            # e.g.: slc = [None, None, None], if len(t.shape) = 3
            slc = [slice(None)] * len(t.shape)
            # replace a None at position=dim with indices
            # e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3]
            slc[dim] = indices
            # use slc to slice mask and replace all its entries with 1s
            # e.g.: mask[:, :, [0, 2, 3]] = 1
            mask[slc] = 1
            return mask

        if nparams_toprune == 0:  # k=0 not supported by torch.kthvalue
            mask = default_mask
        else:
            mask = make_mask(t, self.dim, topk.indices)
            mask *= default_mask.to(dtype=mask.dtype)

        return mask

    @classmethod
    def apply(cls, module, name, amount, n, dim, importance_scores=None):
        r"""Add pruning on the fly and reparametrization of a tensor.



        Adds the forward pre-hook that enables pruning on the fly and

        the reparametrization of a tensor in terms of the original tensor

        and the pruning mask.



        Args:

            module (nn.Module): module containing the tensor to prune

            name (str): parameter name within ``module`` on which pruning

                will act.

            amount (int or float): quantity of parameters to prune.

                If ``float``, should be between 0.0 and 1.0 and represent the

                fraction of parameters to prune. If ``int``, it represents the

                absolute number of parameters to prune.

            n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid

                entries for argument ``p`` in :func:`torch.norm`.

            dim (int): index of the dim along which we define channels to

                prune.

            importance_scores (torch.Tensor): tensor of importance scores (of same

                shape as module parameter) used to compute mask for pruning.

                The values in this tensor indicate the importance of the corresponding

                elements in the parameter being pruned.

                If unspecified or None, the module parameter will be used in its place.

        """
        return super().apply(
            module,
            name,
            amount=amount,
            n=n,
            dim=dim,
            importance_scores=importance_scores,
        )


class CustomFromMask(BasePruningMethod):
    r"""Prune a tensor using a user-provided mask.

    Args:
        mask (torch.Tensor): binary mask to be applied to the parameter,
            of the same shape as the parameter to prune.
    """

    PRUNING_TYPE = "global"

    def __init__(self, mask):
        self.mask = mask

    def compute_mask(self, t, default_mask):
        assert default_mask.shape == self.mask.shape
        mask = default_mask * self.mask.to(dtype=default_mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, mask):
        r"""Add pruning on the fly and reparametrization of a tensor.



        Adds the forward pre-hook that enables pruning on the fly and

        the reparametrization of a tensor in terms of the original tensor

        and the pruning mask.



        Args:

            module (nn.Module): module containing the tensor to prune

            name (str): parameter name within ``module`` on which pruning

                will act.

        """
        return super().apply(module, name, mask=mask)


def identity(module, name):
    r"""Apply pruning reparametrization without pruning any units.



    Applies pruning reparametrization to the tensor corresponding to the

    parameter called ``name`` in ``module`` without actually pruning any

    units. Modifies module in place (and also return the modified module)

    by:



    1) adding a named buffer called ``name+'_mask'`` corresponding to the

       binary mask applied to the parameter ``name`` by the pruning method.

    2) replacing the parameter ``name`` by its pruned version, while the

       original (unpruned) parameter is stored in a new parameter named

       ``name+'_orig'``.



    Note:

        The mask is a tensor of ones.



    Args:

        module (nn.Module): module containing the tensor to prune.

        name (str): parameter name within ``module`` on which pruning

                will act.



    Returns:

        module (nn.Module): modified (i.e. pruned) version of the input module



    Examples:

        >>> # xdoctest: +SKIP

        >>> m = prune.identity(nn.Linear(2, 3), 'bias')

        >>> print(m.bias_mask)

        tensor([1., 1., 1.])

    """
    Identity.apply(module, name)
    return module


def random_unstructured(module, name, amount):
    r"""Prune tensor by removing random (currently unpruned) units.



    Prunes tensor corresponding to parameter called ``name`` in ``module``

    by removing the specified ``amount`` of (currently unpruned) units

    selected at random.

    Modifies module in place (and also return the modified module) by:



    1) adding a named buffer called ``name+'_mask'`` corresponding to the

       binary mask applied to the parameter ``name`` by the pruning method.

    2) replacing the parameter ``name`` by its pruned version, while the

       original (unpruned) parameter is stored in a new parameter named

       ``name+'_orig'``.



    Args:

        module (nn.Module): module containing the tensor to prune

        name (str): parameter name within ``module`` on which pruning

                will act.

        amount (int or float): quantity of parameters to prune.

            If ``float``, should be between 0.0 and 1.0 and represent the

            fraction of parameters to prune. If ``int``, it represents the

            absolute number of parameters to prune.



    Returns:

        module (nn.Module): modified (i.e. pruned) version of the input module



    Examples:

        >>> # xdoctest: +SKIP

        >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)

        >>> torch.sum(m.weight_mask == 0)

        tensor(1)



    """
    RandomUnstructured.apply(module, name, amount)
    return module


def l1_unstructured(module, name, amount, importance_scores=None):
    r"""Prune tensor by removing units with the lowest L1-norm.



    Prunes tensor corresponding to parameter called ``name`` in ``module``

    by removing the specified `amount` of (currently unpruned) units with the

    lowest L1-norm.

    Modifies module in place (and also return the modified module)

    by:



    1) adding a named buffer called ``name+'_mask'`` corresponding to the

       binary mask applied to the parameter ``name`` by the pruning method.

    2) replacing the parameter ``name`` by its pruned version, while the

       original (unpruned) parameter is stored in a new parameter named

       ``name+'_orig'``.



    Args:

        module (nn.Module): module containing the tensor to prune

        name (str): parameter name within ``module`` on which pruning

                will act.

        amount (int or float): quantity of parameters to prune.

            If ``float``, should be between 0.0 and 1.0 and represent the

            fraction of parameters to prune. If ``int``, it represents the

            absolute number of parameters to prune.

        importance_scores (torch.Tensor): tensor of importance scores (of same

            shape as module parameter) used to compute mask for pruning.

            The values in this tensor indicate the importance of the corresponding

            elements in the parameter being pruned.

            If unspecified or None, the module parameter will be used in its place.



    Returns:

        module (nn.Module): modified (i.e. pruned) version of the input module



    Examples:

        >>> # xdoctest: +SKIP

        >>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2)

        >>> m.state_dict().keys()

        odict_keys(['bias', 'weight_orig', 'weight_mask'])
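
        A hedged variant passing explicit ``importance_scores`` (random
        scores here, purely for illustration):

        >>> # xdoctest: +SKIP
        >>> m = nn.Linear(2, 3)
        >>> m = prune.l1_unstructured(
        ...     m, 'weight', amount=0.2,
        ...     importance_scores=torch.rand_like(m.weight)
        ... )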

    """
    L1Unstructured.apply(
        module, name, amount=amount, importance_scores=importance_scores
    )
    return module


def random_structured(module, name, amount, dim):
    r"""Prune tensor by removing random channels along the specified dimension.



    Prunes tensor corresponding to parameter called ``name`` in ``module``

    by removing the specified ``amount`` of (currently unpruned) channels

    along the specified ``dim`` selected at random.

    Modifies module in place (and also return the modified module)

    by:



    1) adding a named buffer called ``name+'_mask'`` corresponding to the

       binary mask applied to the parameter ``name`` by the pruning method.

    2) replacing the parameter ``name`` by its pruned version, while the

       original (unpruned) parameter is stored in a new parameter named

       ``name+'_orig'``.



    Args:

        module (nn.Module): module containing the tensor to prune

        name (str): parameter name within ``module`` on which pruning

                will act.

        amount (int or float): quantity of parameters to prune.

            If ``float``, should be between 0.0 and 1.0 and represent the

            fraction of parameters to prune. If ``int``, it represents the

            absolute number of parameters to prune.

        dim (int): index of the dim along which we define channels to prune.



    Returns:

        module (nn.Module): modified (i.e. pruned) version of the input module



    Examples:

        >>> # xdoctest: +SKIP

        >>> m = prune.random_structured(

        ...     nn.Linear(5, 3), 'weight', amount=3, dim=1

        ... )

        >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0))

        >>> print(columns_pruned)

        3

    """
    RandomStructured.apply(module, name, amount, dim)
    return module


def ln_structured(module, name, amount, n, dim, importance_scores=None):
    r"""Prune tensor by removing channels with the lowest L\ ``n``-norm along the specified dimension.



    Prunes tensor corresponding to parameter called ``name`` in ``module``

    by removing the specified ``amount`` of (currently unpruned) channels

    along the specified ``dim`` with the lowest L\ ``n``-norm.

    Modifies module in place (and also return the modified module)

    by:



    1) adding a named buffer called ``name+'_mask'`` corresponding to the

       binary mask applied to the parameter ``name`` by the pruning method.

    2) replacing the parameter ``name`` by its pruned version, while the

       original (unpruned) parameter is stored in a new parameter named

       ``name+'_orig'``.



    Args:

        module (nn.Module): module containing the tensor to prune

        name (str): parameter name within ``module`` on which pruning

                will act.

        amount (int or float): quantity of parameters to prune.

            If ``float``, should be between 0.0 and 1.0 and represent the

            fraction of parameters to prune. If ``int``, it represents the

            absolute number of parameters to prune.

        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid

            entries for argument ``p`` in :func:`torch.norm`.

        dim (int): index of the dim along which we define channels to prune.

        importance_scores (torch.Tensor): tensor of importance scores (of same

            shape as module parameter) used to compute mask for pruning.

            The values in this tensor indicate the importance of the corresponding

            elements in the parameter being pruned.

            If unspecified or None, the module parameter will be used in its place.



    Returns:

        module (nn.Module): modified (i.e. pruned) version of the input module



    Examples:

        >>> from torch.nn.utils import prune

        >>> m = prune.ln_structured(

        ...     nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf')

        ... )

    """
    LnStructured.apply(
        module, name, amount, n, dim, importance_scores=importance_scores
    )
    return module


def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):
    r"""

    Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``.



    Modifies modules in place by:



    1) adding a named buffer called ``name+'_mask'`` corresponding to the

       binary mask applied to the parameter ``name`` by the pruning method.

    2) replacing the parameter ``name`` by its pruned version, while the

       original (unpruned) parameter is stored in a new parameter named

       ``name+'_orig'``.



    Args:

        parameters (Iterable of (module, name) tuples): parameters of

            the model to prune in a global fashion, i.e. by aggregating all

            weights prior to deciding which ones to prune. module must be of

            type :class:`nn.Module`, and name must be a string.

        pruning_method (function): a valid pruning function from this module,

            or a custom one implemented by the user that satisfies the

            implementation guidelines and has ``PRUNING_TYPE='unstructured'``.

        importance_scores (dict): a dictionary mapping (module, name) tuples to

            the corresponding parameter's importance scores tensor. The tensor

            should be the same shape as the parameter, and is used for computing

            mask for pruning.

            If unspecified or None, the parameter will be used in place of its

            importance scores.

        kwargs: other keyword arguments such as:

            amount (int or float): quantity of parameters to prune across the

            specified parameters.

            If ``float``, should be between 0.0 and 1.0 and represent the

            fraction of parameters to prune. If ``int``, it represents the

            absolute number of parameters to prune.



    Raises:

        TypeError: if ``PRUNING_TYPE != 'unstructured'``



    Note:

        Since global structured pruning doesn't make much sense unless the

        norm is normalized by the size of the parameter, we now limit the

        scope of global pruning to unstructured methods.



    Examples:

        >>> from torch.nn.utils import prune

        >>> from collections import OrderedDict

        >>> net = nn.Sequential(OrderedDict([

        ...     ('first', nn.Linear(10, 4)),

        ...     ('second', nn.Linear(4, 1)),

        ... ]))

        >>> parameters_to_prune = (

        ...     (net.first, 'weight'),

        ...     (net.second, 'weight'),

        ... )

        >>> prune.global_unstructured(

        ...     parameters_to_prune,

        ...     pruning_method=prune.L1Unstructured,

        ...     amount=10,

        ... )

        >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))

        tensor(10)
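
        A hedged follow-up showing explicit ``importance_scores`` (absolute
        weights here, which matches what L1 pruning uses by default):

        >>> # xdoctest: +SKIP
        >>> prune.global_unstructured(
        ...     parameters_to_prune,
        ...     pruning_method=prune.L1Unstructured,
        ...     amount=10,
        ...     importance_scores={
        ...         (net.first, 'weight'): net.first.weight.detach().abs(),
        ...         (net.second, 'weight'): net.second.weight.detach().abs(),
        ...     },
        ... )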

    """
    # ensure parameters is a list or generator of tuples
    if not isinstance(parameters, Iterable):
        raise TypeError("global_unstructured(): parameters is not an Iterable")

    importance_scores = importance_scores if importance_scores is not None else {}
    if not isinstance(importance_scores, dict):
        raise TypeError("global_unstructured(): importance_scores must be of type dict")

    # flatten importance scores to consider them all at once in global pruning
    relevant_importance_scores = torch.nn.utils.parameters_to_vector(
        [
            importance_scores.get((module, name), getattr(module, name))
            for (module, name) in parameters
        ]
    )
    # similarly, flatten the masks (if they exist), or use a flattened vector
    # of 1s of the same dimensions as t
    default_mask = torch.nn.utils.parameters_to_vector(
        [
            getattr(module, name + "_mask", torch.ones_like(getattr(module, name)))
            for (module, name) in parameters
        ]
    )

    # use the canonical pruning methods to compute the new mask, even if the
    # parameter is now a flattened out version of `parameters`
    container = PruningContainer()
    container._tensor_name = "temp"  # to make it match that of `method`
    method = pruning_method(**kwargs)
    method._tensor_name = "temp"  # to make it match that of `container`
    if method.PRUNING_TYPE != "unstructured":
        raise TypeError(
            'Only "unstructured" PRUNING_TYPE supported for '
            f"the `pruning_method`. Found method {pruning_method} of type {method.PRUNING_TYPE}"
        )

    container.add_pruning_method(method)

    # use the `compute_mask` method from `PruningContainer` to combine the
    # mask computed by the new method with the pre-existing mask
    final_mask = container.compute_mask(relevant_importance_scores, default_mask)

    # Pointer for slicing the mask to match the shape of each parameter
    pointer = 0
    for module, name in parameters:
        param = getattr(module, name)
        # The number of elements in the parameter
        num_param = param.numel()
        # Slice the mask, reshape it
        param_mask = final_mask[pointer : pointer + num_param].view_as(param)
        # Assign the correct pre-computed mask to each parameter and add it
        # to the forward_pre_hooks like any other pruning method
        custom_from_mask(module, name, mask=param_mask)

        # Increment the pointer to continue slicing the final_mask
        pointer += num_param


def custom_from_mask(module, name, mask):
    r"""Prune tensor corresponding to parameter called ``name`` in ``module`` by applying the pre-computed mask in ``mask``.

    Modifies module in place (and also returns the modified module) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` with its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        mask (Tensor): binary mask to be applied to the parameter.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.custom_from_mask(
        ...     nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
        ... )
        >>> print(m.bias_mask)
        tensor([0., 1., 0.])
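
        The pruned parameter is the element-wise product of the original
        parameter and the mask, so (barring an exact zero in the freshly
        initialized bias) exactly the masked entries end up as zeros. A
        quick check, purely for illustration:

        >>> int((m.bias == 0).sum())
        2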

    """
    CustomFromMask.apply(module, name, mask)
    return module


def remove(module, name):
    r"""Remove the pruning reparameterization from a module and the pruning method from the forward hook.

    The pruned parameter named ``name`` remains permanently pruned, and the parameter
    named ``name+'_orig'`` is removed from the parameter list. Similarly,
    the buffer named ``name+'_mask'`` is removed from the buffers.

    Note:
        Pruning itself is NOT undone or reversed!

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.

    Examples:
        >>> m = random_unstructured(nn.Linear(5, 7), name='weight', amount=0.2)
        >>> m = remove(m, name='weight')
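
        After removal, the reparameterization buffers and the ``'_orig'``
        parameter are gone; for instance (an illustrative sanity check,
        not part of the API contract):

        >>> hasattr(m, 'weight_orig')
        False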

    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError(
        f"Parameter '{name}' of module {module} has to be pruned before pruning can be removed"
    )


def is_pruned(module):
    r"""Check if a module is pruned by looking for pruning pre-hooks.

    Check whether ``module`` is pruned by looking for
    ``forward_pre_hooks`` in its modules that inherit from the
    :class:`BasePruningMethod`.

    Args:
        module (nn.Module): object that is either pruned or unpruned

    Returns:
        binary answer to whether ``module`` is pruned.

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = nn.Linear(5, 7)
        >>> print(prune.is_pruned(m))
        False
        >>> prune.random_unstructured(m, name='weight', amount=0.2)
        >>> print(prune.is_pruned(m))
        True

    """
    for _, submodule in module.named_modules():
        for hook in submodule._forward_pre_hooks.values():
            if isinstance(hook, BasePruningMethod):
                return True
    return False


def _validate_pruning_amount_init(amount):
    r"""Validate helper to check the range of amount at init.

    Args:
        amount (int or float): quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If int, it represents the
            absolute number of parameters to prune.

    Raises:
        ValueError: if amount is a float not in [0, 1], or if it's a negative
            integer.
        TypeError: if amount is neither a float nor an integer.

    Note:
        This does not take into account the number of parameters in the
        tensor to be pruned, which is known only at pruning time.
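
    Examples:
        >>> # floats in [0, 1] and non-negative ints pass silently,
        >>> # anything else raises (shown here purely as illustration)
        >>> _validate_pruning_amount_init(0.5)
        >>> _validate_pruning_amount_init(3)
        >>> _validate_pruning_amount_init(-1)
        Traceback (most recent call last):
            ...
        ValueError: amount=-1 should either be a float in the range [0, 1] or a non-negative integer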

    """
    if not isinstance(amount, numbers.Real):
        raise TypeError(
            f"Invalid type for amount: {amount}. Must be int or float."
        )

    if (isinstance(amount, numbers.Integral) and amount < 0) or (
        not isinstance(amount, numbers.Integral)  # so it's a float
        and (float(amount) > 1.0 or float(amount) < 0.0)
    ):
        raise ValueError(
            f"amount={amount} should either be a float in the range [0, 1] or a non-negative integer"
        )


def _validate_pruning_amount(amount, tensor_size):
    r"""Validate that the pruning amount is meaningful with respect to the size of the data.

    Validation helper to check that the amount of parameters to prune
    is meaningful with respect to the size of the data (`tensor_size`).

    Args:
        amount (int or float): quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If int, it represents the
            absolute number of parameters to prune.
        tensor_size (int): absolute number of parameters in the tensor
            to prune.
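
    Examples:
        >>> # only an integer amount larger than tensor_size is rejected
        >>> # (illustrative doctest)
        >>> _validate_pruning_amount(5, tensor_size=10)
        >>> _validate_pruning_amount(11, tensor_size=10)
        Traceback (most recent call last):
            ...
        ValueError: amount=11 should be smaller than the number of parameters to prune=10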

    """
    # TODO: consider removing this check and allowing users to specify
    # a number of units to prune that is greater than the number of units
    # left to prune. In this case, the tensor will just be fully pruned.

    if isinstance(amount, numbers.Integral) and amount > tensor_size:
        raise ValueError(
            f"amount={amount} should be smaller than the number of parameters to prune={tensor_size}"
        )


def _validate_structured_pruning(t):
    r"""Validate that the tensor to be pruned is at least 2-dimensional.

    Validation helper to check that the tensor to be pruned is
    multi-dimensional, such that the concept of "channels" is
    well-defined.

    Args:
        t (torch.Tensor): tensor representing the parameter to prune

    Raises:
        ValueError: if the tensor `t` is not at least 2D.
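
    Examples:
        >>> # a 2D weight passes, a 1D bias does not (illustrative doctest)
        >>> _validate_structured_pruning(torch.ones(3, 2))
        >>> _validate_structured_pruning(torch.ones(3))
        Traceback (most recent call last):
            ...
        ValueError: Structured pruning can only be applied to multidimensional tensors. Found tensor of shape torch.Size([3]) with 1 dims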

    """
    shape = t.shape
    if len(shape) <= 1:
        raise ValueError(
            "Structured pruning can only be applied to "
            "multidimensional tensors. Found tensor of shape "
            f"{shape} with {len(shape)} dims"
        )


def _compute_nparams_toprune(amount, tensor_size):
    r"""Convert the pruning amount from a percentage to absolute value.

    Since amount can be expressed either in absolute value or as a
    percentage of the number of units/channels in a tensor, this utility
    function converts the percentage to absolute value to standardize
    the handling of pruning.

    Args:
        amount (int or float): quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If int, it represents the
            absolute number of parameters to prune.
        tensor_size (int): absolute number of parameters in the tensor
            to prune.

    Returns:
        int: the number of units to prune in the tensor
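
    Examples:
        >>> # an int is taken as-is; a float is rounded to a fraction of
        >>> # tensor_size (illustrative doctest)
        >>> _compute_nparams_toprune(5, tensor_size=100)
        5
        >>> _compute_nparams_toprune(0.25, tensor_size=100)
        25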

    """
    # incorrect type already checked in _validate_pruning_amount_init
    if isinstance(amount, numbers.Integral):
        return amount
    else:
        return round(amount * tensor_size)


def _validate_pruning_dim(t, dim):
    r"""Validate that the pruning dimension is within the bounds of the tensor dimension.

    Args:
        t (torch.Tensor): tensor representing the parameter to prune
        dim (int): index of the dim along which we define channels to prune

    Raises:
        IndexError: if ``dim >= t.dim()``.
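
    Examples:
        >>> # a 2D tensor has dims 0 and 1; dim=2 is out of range
        >>> # (illustrative doctest)
        >>> _validate_pruning_dim(torch.ones(3, 2), 1)
        >>> _validate_pruning_dim(torch.ones(3, 2), 2)
        Traceback (most recent call last):
            ...
        IndexError: Invalid index 2 for tensor of size torch.Size([3, 2])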

    """
    if dim >= t.dim():
        raise IndexError(f"Invalid index {dim} for tensor of size {t.shape}")


def _compute_norm(t, n, dim):
    r"""Compute the L_n-norm of a tensor along all dimensions except for the specified dimension.

    The L_n-norm will be computed across all entries in tensor `t` along all
    dimensions except for the one identified by `dim`.
    Example: if `t` is of shape, say, 3x2x4 and dim=2 (the last dim),
    then norm will have Size [4], and each entry will represent the
    `L_n`-norm computed using the 3x2=6 entries for each of the 4 channels.

    Args:
        t (torch.Tensor): tensor representing the parameter to prune
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument p in torch.norm
        dim (int): dim identifying the channels to prune

    Returns:
        norm (torch.Tensor): L_n norm computed across all dimensions except
            for `dim`. By construction, ``norm`` has shape ``(t.shape[dim],)``.
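
    Examples:
        >>> # L2 norms over the 3x2=6 entries of each of the 4 channels of
        >>> # a 3x2x4 tensor of ones, i.e. sqrt(6) each (illustrative)
        >>> _compute_norm(torch.ones(3, 2, 4), n=2, dim=2).shape
        torch.Size([4])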

    """
    # dims = all axes, except for the one identified by `dim`
    dims = list(range(t.dim()))
    # convert negative indexing
    if dim < 0:
        dim = dims[dim]
    dims.remove(dim)

    norm = torch.norm(t, p=n, dim=dims)
    return norm