{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X4cRE8IbIrIV"
      },
      "source": [
        "Downloading PyTorch Vision Reference Scripts for Image Classification. These scripts are official reference implementations from PyTorch Vision that provide training and quantization utilities for image classification models."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "46CgrVgjg3E-",
        "outputId": "7fb20ebe-d7fd-43fa-dc9b-ebbedf31575e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "--2025-05-22 16:30:12--  https://raw.githubusercontent.com/pytorch/vision/main/references/classification/presets.py\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 3885 (3.8K) [text/plain]\n",
            "Saving to: ‘presets.py’\n",
            "\n",
            "presets.py          100%[===================>]   3.79K  --.-KB/s    in 0s      \n",
            "\n",
            "2025-05-22 16:30:12 (12.8 MB/s) - ‘presets.py’ saved [3885/3885]\n",
            "\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "--2025-05-22 16:30:12--  https://raw.githubusercontent.com/pytorch/vision/main/references/classification/sampler.py\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 2395 (2.3K) [text/plain]\n",
            "Saving to: ‘sampler.py’\n",
            "\n",
            "sampler.py          100%[===================>]   2.34K  --.-KB/s    in 0s      \n",
            "\n",
            "2025-05-22 16:30:12 (18.4 MB/s) - ‘sampler.py’ saved [2395/2395]\n",
            "\n",
            "--2025-05-22 16:30:12--  https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train.py\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 23324 (23K) [text/plain]\n",
            "Saving to: ‘train.py’\n",
            "\n",
            "train.py            100%[===================>]  22.78K  --.-KB/s    in 0.01s   \n",
            "\n",
            "2025-05-22 16:30:13 (2.28 MB/s) - ‘train.py’ saved [23324/23324]\n",
            "\n",
            "--2025-05-22 16:30:13--  https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train_quantization.py\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 11647 (11K) [text/plain]\n",
            "Saving to: ‘train_quantization.py’\n",
            "\n",
            "train_quantization. 100%[===================>]  11.37K  --.-KB/s    in 0.001s  \n",
            "\n",
            "2025-05-22 16:30:13 (12.7 MB/s) - ‘train_quantization.py’ saved [11647/11647]\n",
            "\n",
            "--2025-05-22 16:30:13--  https://raw.githubusercontent.com/pytorch/vision/main/references/classification/transformers.py\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 404 Not Found\n",
            "2025-05-22 16:30:13 ERROR 404: Not Found.\n",
            "\n",
            "--2025-05-22 16:30:13--  https://raw.githubusercontent.com/pytorch/vision/main/references/classification/utils.py\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 15791 (15K) [text/plain]\n",
            "Saving to: ‘utils.py’\n",
            "\n",
            "utils.py            100%[===================>]  15.42K  --.-KB/s    in 0.01s   \n",
            "\n",
            "2025-05-22 16:30:13 (1.43 MB/s) - ‘utils.py’ saved [15791/15791]\n",
            "\n"
          ]
        }
      ],
      "source": [
        "! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/presets.py\n",
        "! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/sampler.py\n",
        "! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train.py\n",
        "! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train_quantization.py\n",
        "! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/transformers.py\n",
        "! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/utils.py"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HFASsisvIrIb"
      },
      "source": [
        "In this block, we build a “loss” function for our sequential policy gradient algorithm. When the right data is plugged in, the gradient of this loss is equal to the policy gradient."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "EaBokYCpg3FA"
      },
      "outputs": [],
      "source": [
        "import types\n",
        "from typing import List, Callable\n",
        "\n",
        "import torch\n",
        "from torch import nn, Tensor\n",
        "from torch.nn import functional as F\n",
        "from torchvision.models.resnet import BasicBlock\n",
        "\n",
        "\n",
        "def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_state: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):\n",
        "    losses, rewards = criterion(logits, targets)\n",
        "    returns = torch.ones_like(rewards, dtype=torch.float32, device=rewards.device)\n",
        "    if loss_normalization:\n",
        "        coeff = torch.mean(losses).detach()\n",
        "\n",
        "    embeds = [hidden_state]\n",
        "    predictions = []\n",
        "    for k, w in enumerate(lambdas):\n",
        "        embeds.append(trp_blocks[k](embeds[-1]))\n",
        "        predictions.append(shared_head(embeds[-1]))\n",
        "        returns = returns + w * rewards\n",
        "        replica_losses, rewards = criterion(predictions[-1], targets, rewards)\n",
        "        losses = losses + replica_losses\n",
        "    loss = torch.mean(losses * returns)\n",
        "\n",
        "    if loss_normalization:\n",
        "        with torch.no_grad():\n",
        "            coeff = torch.exp(coeff) / torch.exp(loss.detach())\n",
        "        loss = coeff * loss\n",
        "\n",
        "    return loss"
      ]
    },
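    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal smoke test of `trp_criterion`, using hypothetical stand-ins for the TRP blocks, shared head, and criterion (the real versions are defined in the following cells):\n",
        "\n",
        "```python\n",
        "blocks = nn.ModuleList([nn.Conv2d(8, 8, 3, padding=1)])  # plays the role of one TRP block\n",
        "head = lambda h: h.mean(dim=(2, 3))  # stand-in shared head -> [B, 8] logits\n",
        "\n",
        "def crit(input, target, mask=None):\n",
        "    # per-sample losses plus a 0/1 reward mask, mirroring gen_criterion below\n",
        "    losses = F.cross_entropy(input, target, reduction=\"none\")\n",
        "    if mask is None:\n",
        "        mask = torch.ones_like(losses)\n",
        "    return mask * losses, mask * (input.argmax(-1) == target).float()\n",
        "\n",
        "hidden = torch.randn(2, 8, 4, 4)\n",
        "loss = trp_criterion(blocks, head, crit, lambdas=[0.4], hidden_state=hidden, logits=head(hidden), targets=torch.tensor([0, 1]))\n",
        "loss.backward()  # gradients flow into the stand-in TRP block\n",
        "```"
      ]
    },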
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_Ig0Jm2w8DPH"
      },
      "source": [
        "In this block, we build a TPBlock for the Task Replica Prediction (TRP) module; This implementation provides the backbone without the shared prediction head."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "wkBlmJT96jZj"
      },
      "outputs": [],
      "source": [
        "class TPBlock(nn.Module):\n",
        "    def __init__(self, depths: int, inplanes: int, planes: int):\n",
        "        super(TPBlock, self).__init__()\n",
        "\n",
        "        blocks = [BasicBlock(inplanes=inplanes, planes=planes) for _ in range(depths)]\n",
        "        self.blocks = nn.Sequential(*blocks)\n",
        "        for name, param in self.blocks.named_parameters():\n",
        "            if 'conv' in name:\n",
        "                nn.init.zeros_(param)  # Initialize weights\n",
        "            elif 'downsample' in name:\n",
        "                nn.init.zeros_(param)   # Initialize biases\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.blocks(x)"
      ]
    },
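    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the convolutions (and any downsample parameters) are zero-initialized, the residual branch of every BasicBlock outputs zeros at first, so a fresh TPBlock reduces to `relu(x)`; on the non-negative feature maps it receives after a ReLU, that is an identity mapping. A quick check, assuming the 256-channel inputs used later:\n",
        "\n",
        "```python\n",
        "block = TPBlock(depths=3, inplanes=256, planes=256).eval()\n",
        "x = torch.relu(torch.randn(2, 256, 14, 14))  # post-ReLU feature maps are non-negative\n",
        "with torch.no_grad():\n",
        "    assert torch.allclose(block(x), x)  # the block starts as an identity\n",
        "```"
      ]
    },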
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UGxQdKZaF2NT"
      },
      "source": [
        "This implementation enables ResNet retraining in SPG mode.\n",
        "\n",
        "Components:\n",
        "-------------------------------------------------------------------------------\n",
        "1. gen_criterion()\n",
        "    - Purpose: compute per-sample losses and positional masks\n",
        "\n",
        "2. gen_shared_head()\n",
        "    - Purpose: Implements a shared prediction head that processes convolutional feature maps for prediction.\n",
        "\n",
        "3. gen_forward()\n",
        "    - Purpose: Extended forward pass supporting both traditional inference and SPG retraining."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "kTZWkoLr8cfE"
      },
      "outputs": [],
      "source": [
        "class ResNetConfig:\n",
        "    @staticmethod\n",
        "    def gen_criterion(label_smoothing=0.0, top_k=1):\n",
        "        def func(input, target, mask=None):\n",
        "            \"\"\"\n",
        "            Args:\n",
        "                input (Tensor): Input tensor of shape [B, C].\n",
        "                target (Tensor): Target labels of shape [B] or [B, C].\n",
        "\n",
        "            Returns:\n",
        "                loss (Tensor): Scalar tensor representing the loss.\n",
        "                mask (Tensor): Boolean mask tensor of shape [B].\n",
        "            \"\"\"\n",
        "            label = torch.argmax(target, dim=1) if label_smoothing > 0.0 else target\n",
        "\n",
        "            unmasked_loss = F.cross_entropy(input, label, reduction=\"none\", label_smoothing=label_smoothing)\n",
        "            if mask is None:\n",
        "                mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)\n",
        "            losses = mask * unmasked_loss\n",
        "\n",
        "            with torch.no_grad():\n",
        "                topk_values, topk_indices = torch.topk(input, top_k, dim=-1)\n",
        "                mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)\n",
        "\n",
        "            return losses, mask\n",
        "        return func\n",
        "\n",
        "    @staticmethod\n",
        "    def gen_shared_head(self):\n",
        "        def func(x):\n",
        "            \"\"\"\n",
        "            Args:\n",
        "                x (Tensor): Hidden State tensor of shape [B, C, H, W].\n",
        "\n",
        "            Returns:\n",
        "                logits (Tensor): Logits tensor of shape [B, C].\n",
        "            \"\"\"\n",
        "            x = self.layer4(x)\n",
        "            x = self.avgpool(x)\n",
        "            x = torch.flatten(x, 1)\n",
        "            logits = self.fc(x)\n",
        "            return logits\n",
        "        return func\n",
        "\n",
        "    @staticmethod\n",
        "    def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):\n",
        "        def func(self, x: Tensor, targets=None) -> Tensor:\n",
        "            x = self.conv1(x)\n",
        "            x = self.bn1(x)\n",
        "            x = self.relu(x)\n",
        "            x = self.maxpool(x)\n",
        "\n",
        "            x = self.layer1(x)\n",
        "            x = self.layer2(x)\n",
        "            hidden_state = self.layer3(x)\n",
        "            x = self.layer4(hidden_state)\n",
        "            x = self.avgpool(x)\n",
        "            x = torch.flatten(x, 1)\n",
        "            logits = self.fc(x)\n",
        "\n",
        "            if self.training:\n",
        "                shared_head = ResNetConfig.gen_shared_head(self)\n",
        "                criterion = ResNetConfig.gen_criterion(label_smoothing=label_smoothing, top_k=top_k)\n",
        "\n",
        "                loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_state, logits, targets, loss_normalization=loss_normalization)\n",
        "\n",
        "                return logits, loss\n",
        "\n",
        "            return logits\n",
        "\n",
        "        return func"
      ]
    },
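    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To illustrate what `gen_criterion` returns (hypothetical logits): the mask keeps only samples whose target appears in the top-k predictions, and it is this mask that drives the rewards in `trp_criterion`:\n",
        "\n",
        "```python\n",
        "criterion = ResNetConfig.gen_criterion(top_k=1)\n",
        "logits = torch.tensor([[2.0, 0.1], [2.0, 0.1]])  # both samples predict class 0\n",
        "targets = torch.tensor([0, 1])                   # the second sample is wrong\n",
        "losses, mask = criterion(logits, targets)\n",
        "print(losses.shape)  # torch.Size([2]) -- per-sample cross-entropy\n",
        "print(mask)          # tensor([1., 0.])\n",
        "```"
      ]
    },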
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cCn6vwItH1CW"
      },
      "source": [
        "Applies TRP modules to the base ResNet (main backbone). The k-th TRP module corresponding to a deeper ResNet variant with an additional depth of 3 * sum(depths[:k+1])."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "wXQF0oISH5Yp"
      },
      "outputs": [],
      "source": [
        "def apply_trp(model, depths: List[int], planes: int, lambdas: List[float], **kwargs):\n",
        "    print(\"✅ Applying TRP to ResNet for Image Classification...\")\n",
        "    model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, inplanes=planes, planes=planes) for d in depths])\n",
        "    model.forward = types.MethodType(ResNetConfig.gen_forward(lambdas), model)\n",
        "    return model"
      ]
    },
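    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For example, retrofitting a pretrained ResNet18 with the same hyperparameters as the run below (a sketch; in training mode the patched forward returns both the logits and the SPG loss):\n",
        "\n",
        "```python\n",
        "import torchvision\n",
        "\n",
        "model = torchvision.models.resnet18(weights=\"IMAGENET1K_V1\")\n",
        "model = apply_trp(model, depths=[3, 3, 3], planes=256, lambdas=[0.4, 0.2, 0.1])\n",
        "model.train()\n",
        "logits, loss = model(torch.randn(2, 3, 224, 224), targets=torch.tensor([1, 2]))\n",
        "```"
      ]
    },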
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kDjSAv3PJr7P"
      },
      "source": [
        "The following is a training script for classification models, primarily based on the official TorchVision `train.py` reference implementation. We have made two modifications:\n",
        "\n",
        "Adding TRP Modules: We integrate TRP modules into the base model architecture before training begins:\n",
        "\n",
        "```python\n",
        "if args.apply_trp:\n",
        "    model = apply_trp(model, args.trp_depths,  args.trp_planes, args.trp_lambdas)\n",
        "```\n",
        "Removing TRP Modules: We remove the TRP components from the base model before saving the base model:\n",
        "```python\n",
        "if args.output_dir:\n",
        "    checkpoint = {\n",
        "        \"model\": model.state_dict() if not args.apply_trp else {k: v for k, v in model.state_dict().items() if not k.startswith(\"trp_blocks\")},\n",
        "        \"optimizer\": optimizer.state_dict(),\n",
        "        \"lr_scheduler\": lr_scheduler.state_dict(),\n",
        "        \"epoch\": epoch,\n",
        "        \"args\": args,\n",
        "    }\n",
        "    utils.save_on_master(checkpoint, os.path.join(args.output_dir, f\"model_{epoch}.pth\"))\n",
        "    utils.save_on_master(checkpoint, os.path.join(args.output_dir, \"checkpoint.pth\"))\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "hK4Y7Sqv4xUa"
      },
      "outputs": [],
      "source": [
        "import datetime\n",
        "import os\n",
        "import time\n",
        "import warnings\n",
        "\n",
        "import presets\n",
        "import torch\n",
        "import torch.utils.data\n",
        "import torchvision\n",
        "import utils\n",
        "from torch import nn\n",
        "from torchvision.transforms.functional import InterpolationMode\n",
        "\n",
        "\n",
        "def load_data(traindir, valdir):\n",
        "    # Data loading code\n",
        "    print(\"Loading data\")\n",
        "    interpolation = InterpolationMode(\"bilinear\")\n",
        "\n",
        "    print(\"Loading training data\")\n",
        "    st = time.time()\n",
        "    dataset = torchvision.datasets.ImageFolder(\n",
        "        traindir,\n",
        "        presets.ClassificationPresetTrain(crop_size=224, interpolation=interpolation, auto_augment_policy=None, random_erase_prob=0.0, ra_magnitude=9, augmix_severity=3),\n",
        "    )\n",
        "    print(\"Took\", time.time() - st)\n",
        "\n",
        "    print(\"Loading validation data\")\n",
        "    dataset_test = torchvision.datasets.ImageFolder(\n",
        "        valdir,\n",
        "        presets.ClassificationPresetEval(crop_size=224, resize_size=256, interpolation=interpolation)\n",
        "    )\n",
        "\n",
        "    print(\"Creating data loaders\")\n",
        "    train_sampler = torch.utils.data.RandomSampler(dataset)\n",
        "    test_sampler = torch.utils.data.SequentialSampler(dataset_test)\n",
        "\n",
        "    return dataset, dataset_test, train_sampler, test_sampler\n",
        "\n",
        "\n",
        "\n",
        "def train_one_epoch(model, optimizer, data_loader, device, epoch, args):\n",
        "    model.train()\n",
        "    metric_logger = utils.MetricLogger(delimiter=\"  \")\n",
        "    metric_logger.add_meter(\"lr\", utils.SmoothedValue(window_size=1, fmt=\"{value}\"))\n",
        "    metric_logger.add_meter(\"img/s\", utils.SmoothedValue(window_size=10, fmt=\"{value}\"))\n",
        "\n",
        "    header = f\"Epoch: [{epoch}]\"\n",
        "    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):\n",
        "        start_time = time.time()\n",
        "        image, target = image.to(device), target.to(device)\n",
        "        with torch.amp.autocast(\"cuda\", enabled=False):\n",
        "            output, loss = model(image, target)\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))\n",
        "        batch_size = image.shape[0]\n",
        "        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0][\"lr\"])\n",
        "        metric_logger.meters[\"acc1\"].update(acc1.item(), n=batch_size)\n",
        "        metric_logger.meters[\"acc5\"].update(acc5.item(), n=batch_size)\n",
        "        metric_logger.meters[\"img/s\"].update(batch_size / (time.time() - start_time))\n",
        "\n",
        "\n",
        "def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=\"\"):\n",
        "    model.eval()\n",
        "    metric_logger = utils.MetricLogger(delimiter=\"  \")\n",
        "    header = f\"Test: {log_suffix}\"\n",
        "\n",
        "    num_processed_samples = 0\n",
        "    with torch.inference_mode():\n",
        "        for image, target in metric_logger.log_every(data_loader, print_freq, header):\n",
        "            image = image.to(device, non_blocking=True)\n",
        "            target = target.to(device, non_blocking=True)\n",
        "            output = model(image)\n",
        "            loss = criterion(output, target)\n",
        "\n",
        "            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))\n",
        "            # FIXME need to take into account that the datasets\n",
        "            # could have been padded in distributed setup\n",
        "            batch_size = image.shape[0]\n",
        "            metric_logger.update(loss=loss.item())\n",
        "            metric_logger.meters[\"acc1\"].update(acc1.item(), n=batch_size)\n",
        "            metric_logger.meters[\"acc5\"].update(acc5.item(), n=batch_size)\n",
        "            num_processed_samples += batch_size\n",
        "    # gather the stats from all processes\n",
        "\n",
        "    num_processed_samples = utils.reduce_across_processes(num_processed_samples)\n",
        "    if (\n",
        "        hasattr(data_loader.dataset, \"__len__\")\n",
        "        and len(data_loader.dataset) != num_processed_samples\n",
        "        and torch.distributed.get_rank() == 0\n",
        "    ):\n",
        "        # See FIXME above\n",
        "        warnings.warn(\n",
        "            f\"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} \"\n",
        "            \"samples were used for the validation, which might bias the results. \"\n",
        "            \"Try adjusting the batch size and / or the world size. \"\n",
        "            \"Setting the world size to 1 is always a safe bet.\"\n",
        "        )\n",
        "\n",
        "    metric_logger.synchronize_between_processes()\n",
        "\n",
        "    print(f\"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}\")\n",
        "    return metric_logger.acc1.global_avg\n",
        "\n",
        "\n",
        "def main(args):\n",
        "    if args.output_dir:\n",
        "        utils.mkdir(args.output_dir)\n",
        "    print(args)\n",
        "\n",
        "    device = torch.device(args.device)\n",
        "\n",
        "    if args.use_deterministic_algorithms:\n",
        "        torch.backends.cudnn.benchmark = False\n",
        "        torch.use_deterministic_algorithms(True)\n",
        "    else:\n",
        "        torch.backends.cudnn.benchmark = True\n",
        "\n",
        "    train_dir = os.path.join(args.data_path, \"train\")\n",
        "    val_dir = os.path.join(args.data_path, \"val\")\n",
        "    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir)\n",
        "\n",
        "    num_classes = len(dataset.classes)\n",
        "    data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=16, pin_memory=True, collate_fn=None)\n",
        "    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=64, sampler=test_sampler, num_workers=16, pin_memory=True)\n",
        "\n",
        "    print(\"Creating model\")\n",
        "    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)\n",
        "    if args.apply_trp:\n",
        "        model = apply_trp(model, args.trp_depths,  args.trp_planes, args.trp_lambdas)\n",
        "    model.to(device)\n",
        "\n",
        "    parameters = utils.set_weight_decay(model, args.weight_decay, norm_weight_decay=None, custom_keys_weight_decay=None)\n",
        "    optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)\n",
        "\n",
        "    main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)\n",
        "    warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs)\n",
        "    lr_scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs])\n",
        "\n",
        "\n",
        "    print(\"Start training\")\n",
        "    start_time = time.time()\n",
        "    for epoch in range(args.epochs):\n",
        "        train_one_epoch(model, optimizer, data_loader, device, epoch, args)\n",
        "        lr_scheduler.step()\n",
        "        evaluate(model, nn.CrossEntropyLoss(), data_loader_test, device=device)\n",
        "        if args.output_dir:\n",
        "            checkpoint = {\n",
        "                \"model\": model.state_dict() if not args.apply_trp else {k: v for k, v in model.state_dict().items() if not k.startswith(\"trp_blocks\")},  # NOTE: remove TRP heads\n",
        "                \"optimizer\": optimizer.state_dict(),\n",
        "                \"lr_scheduler\": lr_scheduler.state_dict(),\n",
        "                \"epoch\": epoch,\n",
        "                \"args\": args,\n",
        "            }\n",
        "            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f\"model_{epoch}.pth\"))\n",
        "            utils.save_on_master(checkpoint, os.path.join(args.output_dir, \"checkpoint.pth\"))\n",
        "\n",
        "    total_time = time.time() - start_time\n",
        "    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n",
        "    print(f\"Training time {total_time_str}\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SV8s5k49KwgS"
      },
      "source": [
        "Prepare the [ImageNet](http://image-net.org/) dataset manually and place it in `/path/to/imagenet`. For image classification examples, pass the argument `--data-path=/path/to/imagenet` to the training script. The extracted dataset directory should follow this structure:\n",
        "```setup\n",
        "/path/to/imagenet/:\n",
        "    train/:\n",
        "        n01440764:\n",
        "            n01440764_18.JPEG ...\n",
        "        n01443537:\n",
        "            n01443537_2.JPEG ...\n",
        "    val/:\n",
        "        n01440764:\n",
        "            ILSVRC2012_val_00000293.JPEG ...\n",
        "        n01443537:\n",
        "            ILSVRC2012_val_00000236.JPEG ...\n",
        "```\n",
        "\n",
        "Now you can apply the SPG algorithm in model retraining.\n",
        "\n",
        "**Implementation Note:**\n",
        "\n",
        "- This demonstration runs on Google Colab using a single GPU configuration\n",
        "- Performance Improvement: Enhances ResNet18 validation accuracy (ACC@1) from 69.76% to 70.09%\n",
        "- For optimal results:\n",
        "  - Refer to our README.md for complete setup instructions\n",
        "  - Recommended hardware: 4× RTX A6000 GPUs"
      ]
    },
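    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The training cell below calls `main` with the arguments echoed in its output. Reconstructed as a namespace (with the dataset path replaced by the placeholder above), the invocation is roughly:\n",
        "\n",
        "```python\n",
        "import argparse\n",
        "\n",
        "args = argparse.Namespace(\n",
        "    data_path=\"/path/to/imagenet\", model=\"resnet18\", device=\"cuda\",\n",
        "    batch_size=512, epochs=6, lr=0.0004, momentum=0.9, weight_decay=0.0001,\n",
        "    lr_warmup_epochs=1, lr_warmup_decay=0.0, lr_step_size=2, lr_gamma=0.5,\n",
        "    print_freq=100, output_dir=\"resnet18\", use_deterministic_algorithms=False,\n",
        "    weights=\"ResNet18_Weights.IMAGENET1K_V1\", apply_trp=True,\n",
        "    trp_depths=[3, 3, 3], trp_planes=256, trp_lambdas=[0.4, 0.2, 0.1],\n",
        ")\n",
        "main(args)\n",
        "```"
      ]
    },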
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "UDZxDNfT4xUb",
        "outputId": "bcf86aa0-eb77-4815-e0fa-05997f1e1f1b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "namespace(data_path='/home/cs/Documents/datasets/imagenet', model='resnet18', device='cuda', batch_size=512, epochs=6, lr=0.0004, momentum=0.9, weight_decay=0.0001, lr_warmup_epochs=1, lr_warmup_decay=0.0, lr_step_size=2, lr_gamma=0.5, print_freq=100, output_dir='resnet18', use_deterministic_algorithms=False, weights='ResNet18_Weights.IMAGENET1K_V1', apply_trp=True, trp_depths=[3, 3, 3], trp_planes=256, trp_lambdas=[0.4, 0.2, 0.1])\n",
            "Loading data\n",
            "Loading training data\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Took 1.9062905311584473\n",
            "Loading validation data\n",
            "Creating data loaders\n",
            "Creating model\n",
            "✅ Applying TRP to ResNet for Image Classification...\n",
            "Start training\n",
            "Epoch: [0]  [   0/2503]  eta: 10:05:09  lr: 0.0  img/s: 81.93631887515438  loss: 0.7334 (0.7334)  acc1: 71.2891 (71.2891)  acc5: 86.1328 (86.1328)  time: 14.5065  data: 8.2577  max mem: 19119\n",
            "Epoch: [0]  [ 100/2503]  eta: 0:29:06  lr: 0.0  img/s: 862.8257862120394  loss: 0.7145 (0.7308)  acc1: 69.5312 (69.6105)  acc5: 87.6953 (87.3704)  time: 0.5927  data: 0.0003  max mem: 19119\n",
            "Epoch: [0]  [ 200/2503]  eta: 0:25:23  lr: 0.0  img/s: 860.6862569301302  loss: 0.7355 (0.7353)  acc1: 68.9453 (69.3427)  acc5: 86.9141 (87.3125)  time: 0.5966  data: 0.0003  max mem: 19119\n",
            "Epoch: [0]  [ 300/2503]  eta: 0:23:29  lr: 0.0  img/s: 860.0754340960929  loss: 0.7159 (0.7314)  acc1: 69.1406 (69.3463)  acc5: 87.5000 (87.3676)  time: 0.5967  data: 0.0003  max mem: 19119\n",
            "Epoch: [0]  [ 400/2503]  eta: 0:22:03  lr: 0.0  img/s: 859.0790234707376  loss: 0.7594 (0.7361)  acc1: 67.9688 (69.2283)  acc5: 86.7188 (87.3232)  time: 0.5960  data: 0.0003  max mem: 19119\n",
            "Epoch: [0]  [ 500/2503]  eta: 0:20:46  lr: 0.0  img/s: 859.7486624250741  loss: 0.7204 (0.7343)  acc1: 69.7266 (69.2396)  acc5: 87.5000 (87.3827)  time: 0.5958  data: 0.0003  max mem: 19119\n",
            "Epoch: [0]  [ 600/2503]  eta: 0:19:36  lr: 0.0  img/s: 861.5204710456711  loss: 0.7483 (0.7345)  acc1: 69.5312 (69.2449)  acc5: 86.7188 (87.3950)  time: 0.5958  data: 0.0002  max mem: 19119\n",
            "Epoch: [0]  [ 700/2503]  eta: 0:18:28  lr: 0.0  img/s: 858.9934592  loss: 0.7225 (0.7350)  acc1: 68.5547 (69.2331)  acc5: 87.6953 (87.3738)  time: 0.5958  data: 0.0003  max mem: 19119\n",
            "Epoch: [0]  [ 800/2503]  eta: 0:17:23  lr: 0.0  img/s: 859.4995325247505  loss: 0.7639 (0.7355)  acc1: 69.7266 (69.2177)  acc5: 86.7188 (87.3578)  time: 0.5961  data: 0.0002  max mem: 19119\n",
            "Epoch: [0]  [ 900/2503]  eta: 0:16:18  lr: 0.0  img/s: 860.8087326554238  loss: 0.7118 (0.7349)  acc1: 69.9219 (69.2440)  acc5: 87.6953 (87.3548)  time: 0.5961  data: 0.0002  max mem: 19119\n",
            "Epoch: [0]  [1000/2503]  eta: 0:15:15  lr: 0.0  img/s: 859.5858857924882  loss: 0.7224 (0.7351)  acc1: 69.3359 (69.2485)  acc5: 87.3047 (87.3624)  time: 0.5958  data: 0.0003  max mem: 19119\n",
            "Epoch: [0]  [1100/2503]  eta: 0:14:12  lr: 0.0  img/s: 858.8670339725992  loss: 0.7240 (0.7360)  acc1: 68.9453 (69.2212)  acc5: 87.1094 (87.3361)  time: 0.5958  data: 0.0002  max mem: 19119\n",
            "Epoch: [0]  [1200/2503]  eta: 0:13:10  lr: 0.0  img/s: 861.4696676125856  loss: 0.7126 (0.7364)  acc1: 68.3594 (69.1878)  acc5: 87.3047 (87.3190)  time: 0.5960  data: 0.0002  max mem: 19119\n",
            "Epoch: [0]  [1300/2503]  eta: 0:12:09  lr: 0.0  img/s: 859.3643608581464  loss: 0.7291 (0.7367)  acc1: 68.9453 (69.1669)  acc5: 86.7188 (87.2990)  time: 0.5959  data: 0.0002  max mem: 19119\n",
            "Epoch: [0]  [1400/2503]  eta: 0:11:07  lr: 0.0  img/s: 861.1477063020853  loss: 0.7267 (0.7372)  acc1: 69.9219 (69.1624)  acc5: 87.1094 (87.2990)  time: 0.5960  data: 0.0002  max mem: 19119\n",
            "Epoch: [0]  [1500/2503]  eta: 0:10:06  lr: 0.0  img/s: 859.0494692253935  loss: 0.7234 (0.7374)  acc1: 69.1406 (69.1607)  acc5: 87.3047 (87.2939)  time: 0.5959  data: 0.0003  max mem: 19119\n",
            "Epoch: [0]  [1600/2503]  eta: 0:09:05  lr: 0.0  img/s: 860.660386236062  loss: 0.7456 (0.7374)  acc1: 69.3359 (69.1730)  acc5: 87.5000 (87.3019)  time: 0.5960  data: 0.0002  max mem: 19119\n",
            "Epoch: [0]  [1700/2503]  eta: 0:08:04  lr: 0.0  img/s: 858.9515423647326  loss: 0.7548 (0.7372)  acc1: 69.1406 (69.1773)  acc5: 87.5000 (87.3198)  time: 0.5959  data: 0.0002  max mem: 19119\n",
            "Epoch: [0]  [1800/2503]  eta: 0:07:04  lr: 0.0  img/s: 860.6800478217115  loss: 0.7596 (0.7375)  acc1: 67.1875 (69.1614)  acc5: 87.1094 (87.3191)  time: 0.5958  data: 0.0002  max mem: 19119\n",
            "Epoch: [0]  [1900/2503]  eta: 0:06:03  lr: 0.0  img/s: 859.6578027499652  loss: 0.7465 (0.7375)  acc1: 68.3594 (69.1633)  acc5: 86.7188 (87.3222)  time: 0.5959  data: 0.0002  max mem: 19119\n",
            "Epoch: [0]  [2000/2503]  eta: 0:05:03  lr: 0.0  img/s: 860.6507282423033  loss: 0.7385 (0.7375)  acc1: 69.3359 (69.1609)  acc5: 87.3047 (87.3233)  time: 0.5959  data: 0.0002  max mem: 19119\n",
            "Epoch: [0]  [2100/2503]  eta: 0:04:02  lr: 0.0  img/s: 860.72592834858  loss: 0.7153 (0.7373)  acc1: 69.3359 (69.1710)  acc5: 87.3047 (87.3230)  time: 0.5961  data: 0.0004  max mem: 19119\n",
            "Epoch: [0]  [2200/2503]  eta: 0:03:02  lr: 0.0  img/s: 859.2460775467988  loss: 0.7307 (0.7371)  acc1: 68.9453 (69.1861)  acc5: 87.5000 (87.3380)  time: 0.5960  data: 0.0004  max mem: 19119\n",
            "Epoch: [0]  [2300/2503]  eta: 0:02:02  lr: 0.0  img/s: 859.2639554931892  loss: 0.7077 (0.7367)  acc1: 69.3359 (69.1971)  acc5: 87.6953 (87.3516)  time: 0.5959  data: 0.0004  max mem: 19119\n",
            "Epoch: [0]  [2400/2503]  eta: 0:01:01  lr: 0.0  img/s: 861.341130585524  loss: 0.7279 (0.7365)  acc1: 68.5547 (69.1921)  acc5: 86.9141 (87.3412)  time: 0.5961  data: 0.0004  max mem: 19119\n",
            "Epoch: [0]  [2500/2503]  eta: 0:00:01  lr: 0.0  img/s: 861.8382147793436  loss: 0.7469 (0.7368)  acc1: 68.5547 (69.1894)  acc5: 87.5000 (87.3423)  time: 0.5955  data: 0.0005  max mem: 19119\n",
            "Epoch: [0] Total time: 0:25:05\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/home/cs/anaconda3/envs/csenv/lib/python3.9/site-packages/torch/optim/lr_scheduler.py:243: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.\n",
            "  warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Test:   [  0/782]  eta: 0:23:05  loss: 0.6283 (0.6283)  acc1: 89.0625 (89.0625)  acc5: 95.3125 (95.3125)  time: 1.7719  data: 1.3111  max mem: 19119\n",
            "Test:   [100/782]  eta: 0:00:30  loss: 1.0688 (0.9382)  acc1: 76.5625 (76.2840)  acc5: 89.0625 (92.1875)  time: 0.0399  data: 0.0263  max mem: 19119\n",
            "Test:   [200/782]  eta: 0:00:21  loss: 0.9244 (0.9143)  acc1: 73.4375 (75.8240)  acc5: 95.3125 (93.2369)  time: 0.0244  data: 0.0107  max mem: 19119\n",
            "Test:   [300/782]  eta: 0:00:17  loss: 0.8615 (0.9072)  acc1: 76.5625 (76.1991)  acc5: 92.1875 (93.5008)  time: 0.0381  data: 0.0244  max mem: 19119\n",
            "Test:   [400/782]  eta: 0:00:13  loss: 1.6977 (1.0440)  acc1: 59.3750 (73.6323)  acc5: 82.8125 (91.7472)  time: 0.0313  data: 0.0176  max mem: 19119\n",
            "Test:   [500/782]  eta: 0:00:09  loss: 1.6021 (1.1237)  acc1: 54.6875 (72.0964)  acc5: 85.9375 (90.5845)  time: 0.0247  data: 0.0109  max mem: 19119\n",
            "Test:   [600/782]  eta: 0:00:06  loss: 1.3631 (1.1858)  acc1: 64.0625 (70.8741)  acc5: 84.3750 (89.7853)  time: 0.0291  data: 0.0153  max mem: 19119\n",
            "Test:   [700/782]  eta: 0:00:02  loss: 1.2494 (1.2361)  acc1: 68.7500 (69.9313)  acc5: 87.5000 (89.1115)  time: 0.0391  data: 0.0254  max mem: 19119\n",
            "Test:  Total time: 0:00:26\n",
            "Test:  Acc@1 69.846 Acc@5 89.136\n",
            "Epoch: [1]  [   0/2503]  eta: 4:27:27  lr: 0.0004  img/s: 861.3684242192573  loss: 0.7611 (0.7611)  acc1: 68.5547 (68.5547)  acc5: 86.1328 (86.1328)  time: 6.4115  data: 5.8170  max mem: 19119\n",
            "Epoch: [1]  [ 100/2503]  eta: 0:26:25  lr: 0.0004  img/s: 856.9149263546342  loss: 0.7538 (0.7542)  acc1: 70.5078 (69.0536)  acc5: 87.5000 (87.1364)  time: 0.5982  data: 0.0003  max mem: 19119\n",
            "Epoch: [1]  [ 200/2503]  eta: 0:24:10  lr: 0.0004  img/s: 854.0172713632207  loss: 0.7744 (0.7573)  acc1: 69.7266 (69.2990)  acc5: 88.0859 (87.3785)  time: 0.5998  data: 0.0003  max mem: 19119\n",
            "Epoch: [1]  [ 300/2503]  eta: 0:22:45  lr: 0.0004  img/s: 856.0483105922384  loss: 0.7551 (0.7613)  acc1: 69.1406 (69.2834)  acc5: 87.3047 (87.4611)  time: 0.5996  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [ 400/2503]  eta: 0:21:32  lr: 0.0004  img/s: 854.8386016604893  loss: 0.7931 (0.7645)  acc1: 68.5547 (69.3004)  acc5: 87.3047 (87.4698)  time: 0.5987  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [ 500/2503]  eta: 0:20:24  lr: 0.0004  img/s: 855.3431965742996  loss: 0.7744 (0.7684)  acc1: 68.1641 (69.2853)  acc5: 86.9141 (87.4361)  time: 0.5986  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [ 600/2503]  eta: 0:19:19  lr: 0.0004  img/s: 855.1112541063571  loss: 0.7860 (0.7730)  acc1: 69.1406 (69.2310)  acc5: 86.7188 (87.3941)  time: 0.5988  data: 0.0003  max mem: 19119\n",
            "Epoch: [1]  [ 700/2503]  eta: 0:18:15  lr: 0.0004  img/s: 856.4904515045232  loss: 0.7908 (0.7773)  acc1: 68.7500 (69.1746)  acc5: 86.7188 (87.3543)  time: 0.5985  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [ 800/2503]  eta: 0:17:13  lr: 0.0004  img/s: 858.0146361031335  loss: 0.8157 (0.7805)  acc1: 68.5547 (69.1660)  acc5: 87.6953 (87.3181)  time: 0.5991  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [ 900/2503]  eta: 0:16:11  lr: 0.0004  img/s: 854.9138104963116  loss: 0.7641 (0.7825)  acc1: 69.5312 (69.1807)  acc5: 88.8672 (87.3346)  time: 0.5989  data: 0.0003  max mem: 19119\n",
            "Epoch: [1]  [1000/2503]  eta: 0:15:09  lr: 0.0004  img/s: 855.7491430488731  loss: 0.8024 (0.7852)  acc1: 68.1641 (69.1506)  acc5: 86.5234 (87.3234)  time: 0.5988  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [1100/2503]  eta: 0:14:08  lr: 0.0004  img/s: 856.0848253972304  loss: 0.8099 (0.7872)  acc1: 69.1406 (69.1564)  acc5: 86.7188 (87.3231)  time: 0.5992  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [1200/2503]  eta: 0:13:07  lr: 0.0004  img/s: 855.6028761225017  loss: 0.8307 (0.7894)  acc1: 68.5547 (69.1569)  acc5: 87.1094 (87.3258)  time: 0.5989  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [1300/2503]  eta: 0:12:06  lr: 0.0004  img/s: 855.8589613399885  loss: 0.8206 (0.7913)  acc1: 68.9453 (69.1304)  acc5: 87.3047 (87.3177)  time: 0.5987  data: 0.0003  max mem: 19119\n",
            "Epoch: [1]  [1400/2503]  eta: 0:11:05  lr: 0.0004  img/s: 856.6045604019511  loss: 0.8454 (0.7936)  acc1: 68.1641 (69.1019)  acc5: 86.9141 (87.2906)  time: 0.5989  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [1500/2503]  eta: 0:10:04  lr: 0.0004  img/s: 854.944442321167  loss: 0.8428 (0.7960)  acc1: 68.1641 (69.0905)  acc5: 87.3047 (87.2706)  time: 0.5990  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [1600/2503]  eta: 0:09:04  lr: 0.0004  img/s: 855.0727794914757  loss: 0.7906 (0.7974)  acc1: 69.5312 (69.0922)  acc5: 87.1094 (87.2686)  time: 0.5990  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [1700/2503]  eta: 0:08:03  lr: 0.0004  img/s: 855.4958499669949  loss: 0.8199 (0.7989)  acc1: 69.7266 (69.0854)  acc5: 87.1094 (87.2704)  time: 0.5988  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [1800/2503]  eta: 0:07:03  lr: 0.0004  img/s: 855.0251166287029  loss: 0.8257 (0.8007)  acc1: 70.1172 (69.0869)  acc5: 87.5000 (87.2656)  time: 0.5988  data: 0.0003  max mem: 19119\n",
            "Epoch: [1]  [1900/2503]  eta: 0:06:03  lr: 0.0004  img/s: 856.9867390518363  loss: 0.7952 (0.8018)  acc1: 68.7500 (69.0943)  acc5: 87.3047 (87.2670)  time: 0.5989  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [2000/2503]  eta: 0:05:02  lr: 0.0004  img/s: 854.3927252530574  loss: 0.8402 (0.8032)  acc1: 68.5547 (69.0964)  acc5: 87.1094 (87.2747)  time: 0.5988  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [2100/2503]  eta: 0:04:02  lr: 0.0004  img/s: 855.2427067851231  loss: 0.8451 (0.8042)  acc1: 68.3594 (69.1089)  acc5: 87.3047 (87.2816)  time: 0.5988  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [2200/2503]  eta: 0:03:02  lr: 0.0004  img/s: 853.8318747507567  loss: 0.8314 (0.8058)  acc1: 68.9453 (69.1012)  acc5: 87.3047 (87.2716)  time: 0.5989  data: 0.0003  max mem: 19119\n",
            "Epoch: [1]  [2300/2503]  eta: 0:02:02  lr: 0.0004  img/s: 855.3312728222841  loss: 0.8350 (0.8070)  acc1: 68.7500 (69.0993)  acc5: 86.3281 (87.2549)  time: 0.5988  data: 0.0003  max mem: 19119\n",
            "Epoch: [1]  [2400/2503]  eta: 0:01:01  lr: 0.0004  img/s: 855.1613103361218  loss: 0.8206 (0.8084)  acc1: 68.5547 (69.0927)  acc5: 86.9141 (87.2551)  time: 0.5988  data: 0.0002  max mem: 19119\n",
            "Epoch: [1]  [2500/2503]  eta: 0:00:01  lr: 0.0004  img/s: 856.7190414886559  loss: 0.8286 (0.8094)  acc1: 69.1406 (69.1000)  acc5: 87.1094 (87.2599)  time: 0.5986  data: 0.0002  max mem: 19119\n",
            "Epoch: [1] Total time: 0:25:05\n",
            "Test:   [  0/782]  eta: 0:16:19  loss: 0.5636 (0.5636)  acc1: 87.5000 (87.5000)  acc5: 96.8750 (96.8750)  time: 1.2525  data: 1.2385  max mem: 19119\n",
            "Test:   [100/782]  eta: 0:00:31  loss: 1.0393 (0.9414)  acc1: 76.5625 (76.9647)  acc5: 90.6250 (92.2958)  time: 0.0417  data: 0.0280  max mem: 19119\n",
            "Test:   [200/782]  eta: 0:00:22  loss: 0.8964 (0.9176)  acc1: 73.4375 (76.4614)  acc5: 95.3125 (93.3147)  time: 0.0249  data: 0.0112  max mem: 19119\n",
            "Test:   [300/782]  eta: 0:00:17  loss: 0.7984 (0.9094)  acc1: 79.6875 (76.7130)  acc5: 92.1875 (93.6150)  time: 0.0311  data: 0.0173  max mem: 19119\n",
            "Test:   [400/782]  eta: 0:00:13  loss: 1.7745 (1.0483)  acc1: 57.8125 (73.9635)  acc5: 84.3750 (91.8758)  time: 0.0328  data: 0.0190  max mem: 19119\n",
            "Test:   [500/782]  eta: 0:00:09  loss: 1.6435 (1.1264)  acc1: 59.3750 (72.4239)  acc5: 84.3750 (90.7934)  time: 0.0328  data: 0.0190  max mem: 19119\n",
            "Test:   [600/782]  eta: 0:00:06  loss: 1.3057 (1.1915)  acc1: 62.5000 (71.0483)  acc5: 85.9375 (90.0010)  time: 0.0400  data: 0.0261  max mem: 19119\n",
            "Test:   [700/782]  eta: 0:00:02  loss: 1.2212 (1.2428)  acc1: 70.3125 (70.0985)  acc5: 87.5000 (89.3010)  time: 0.0253  data: 0.0115  max mem: 19119\n",
            "Test:  Total time: 0:00:26\n",
            "Test:  Acc@1 70.000 Acc@5 89.320\n",
            "Epoch: [2]  [   0/2503]  eta: 4:06:15  lr: 0.0004  img/s: 867.4756359685  loss: 0.8414 (0.8414)  acc1: 67.9688 (67.9688)  acc5: 86.1328 (86.1328)  time: 5.9030  data: 5.3128  max mem: 19119\n",
            "Epoch: [2]  [ 100/2503]  eta: 0:25:53  lr: 0.0004  img/s: 859.1872918530421  loss: 0.8472 (0.8456)  acc1: 68.9453 (69.1194)  acc5: 86.7188 (87.2892)  time: 0.5963  data: 0.0003  max mem: 19119\n",
            "Epoch: [2]  [ 200/2503]  eta: 0:23:52  lr: 0.0004  img/s: 857.5945509684602  loss: 0.8563 (0.8443)  acc1: 68.1641 (69.1649)  acc5: 86.3281 (87.1852)  time: 0.5972  data: 0.0002  max mem: 19119\n",
            "Epoch: [2]  [ 300/2503]  eta: 0:22:31  lr: 0.0004  img/s: 859.3399450658505  loss: 0.8386 (0.8440)  acc1: 69.1406 (69.0790)  acc5: 87.3047 (87.1762)  time: 0.5967  data: 0.0003  max mem: 19119\n",
            "Epoch: [2]  [ 400/2503]  eta: 0:21:21  lr: 0.0004  img/s: 859.825426282471  loss: 0.8455 (0.8446)  acc1: 69.3359 (69.0422)  acc5: 87.3047 (87.1591)  time: 0.5963  data: 0.0003  max mem: 19119\n",
            "Epoch: [2]  [ 500/2503]  eta: 0:20:15  lr: 0.0004  img/s: 858.6002202991031  loss: 0.8400 (0.8448)  acc1: 67.9688 (69.0373)  acc5: 87.5000 (87.1640)  time: 0.5962  data: 0.0003  max mem: 19119\n",
            "Epoch: [2]  [ 600/2503]  eta: 0:19:11  lr: 0.0004  img/s: 859.6997885462696  loss: 0.8544 (0.8467)  acc1: 68.3594 (69.0217)  acc5: 86.9141 (87.1347)  time: 0.5963  data: 0.0002  max mem: 19119\n",
            "Epoch: [2]  [ 700/2503]  eta: 0:18:08  lr: 0.0004  img/s: 858.0379481409502  loss: 0.8386 (0.8466)  acc1: 68.7500 (69.0314)  acc5: 87.1094 (87.1459)  time: 0.5966  data: 0.0002  max mem: 19119\n",
            "Epoch: [2]  [ 800/2503]  eta: 0:17:06  lr: 0.0004  img/s: 859.3574830298114  loss: 0.8607 (0.8472)  acc1: 69.5312 (69.0338)  acc5: 87.1094 (87.1364)  time: 0.5965  data: 0.0003  max mem: 19119\n",
            "Epoch: [2]  [ 900/2503]  eta: 0:16:05  lr: 0.0004  img/s: 858.1785328651912  loss: 0.8502 (0.8474)  acc1: 68.5547 (69.0273)  acc5: 87.1094 (87.1404)  time: 0.5966  data: 0.0002  max mem: 19119\n",
            "Epoch: [2]  [1000/2503]  eta: 0:15:04  lr: 0.0004  img/s: 858.6554923981921  loss: 0.8213 (0.8468)  acc1: 70.1172 (69.0737)  acc5: 87.6953 (87.1601)  time: 0.5966  data: 0.0003  max mem: 19119\n",
            "Epoch: [2]  [1100/2503]  eta: 0:14:03  lr: 0.0004  img/s: 858.942266240932  loss: 0.8322 (0.8466)  acc1: 68.9453 (69.0824)  acc5: 87.5000 (87.1826)  time: 0.5965  data: 0.0003  max mem: 19119\n",
            "Epoch: [2]  [1200/2503]  eta: 0:13:02  lr: 0.0004  img/s: 859.9796839014982  loss: 0.8353 (0.8468)  acc1: 68.5547 (69.0858)  acc5: 87.1094 (87.1886)  time: 0.5962  data: 0.0002  max mem: 19119\n",
            "Epoch: [2]  [1300/2503]  eta: 0:12:02  lr: 0.0004  img/s: 858.8996673563247  loss: 0.8654 (0.8471)  acc1: 68.3594 (69.0645)  acc5: 86.7188 (87.1876)  time: 0.5966  data: 0.0002  max mem: 19119\n",
            "Epoch: [2]  [1400/2503]  eta: 0:11:01  lr: 0.0004  img/s: 859.7879032225777  loss: 0.8277 (0.8466)  acc1: 70.1172 (69.0861)  acc5: 88.6719 (87.2244)  time: 0.5963  data: 0.0003  max mem: 19119\n",
            "Epoch: [2]  [1500/2503]  eta: 0:10:01  lr: 0.0004  img/s: 859.6271763544084  loss: 0.8703 (0.8471)  acc1: 68.7500 (69.0763)  acc5: 86.5234 (87.2101)  time: 0.5962  data: 0.0003  max mem: 19119\n",
            "Epoch: [2]  [1600/2503]  eta: 0:09:01  lr: 0.0004  img/s: 859.8206066235718  loss: 0.8818 (0.8481)  acc1: 68.5547 (69.0523)  acc5: 86.7188 (87.1975)  time: 0.5965  data: 0.0002  max mem: 19119\n",
            "Epoch: [2]  [1700/2503]  eta: 0:08:01  lr: 0.0004  img/s: 858.3188206070029  loss: 0.8447 (0.8487)  acc1: 69.3359 (69.0259)  acc5: 87.3047 (87.1857)  time: 0.5965  data: 0.0002  max mem: 19119\n",
            "Epoch: [2]  [1800/2503]  eta: 0:07:01  lr: 0.0004  img/s: 860.4651988761577  loss: 0.8492 (0.8489)  acc1: 67.5781 (69.0295)  acc5: 86.7188 (87.1785)  time: 0.5965  data: 0.0002  max mem: 19119\n",
            "Epoch: [2]  [1900/2503]  eta: 0:06:01  lr: 0.0004  img/s: 858.5823699612625  loss: 0.8559 (0.8491)  acc1: 67.9688 (69.0192)  acc5: 87.5000 (87.1776)  time: 0.5963  data: 0.0002  max mem: 19119\n",
            "Epoch: [2]  [2000/2503]  eta: 0:05:01  lr: 0.0004  img/s: 858.4468002775832  loss: 0.8712 (0.8498)  acc1: 68.9453 (69.0207)  acc5: 87.6953 (87.1836)  time: 0.5963  data: 0.0003  max mem: 19119\n",
            "Epoch: [2]  [2100/2503]  eta: 0:04:01  lr: 0.0004  img/s: 858.6208177650899  loss: 0.8782 (0.8507)  acc1: 68.3594 (69.0053)  acc5: 85.9375 (87.1818)  time: 0.5964  data: 0.0002  max mem: 19119\n",
            "Epoch: [2]  [2200/2503]  eta: 0:03:01  lr: 0.0004  img/s: 858.769492116456  loss: 0.8845 (0.8514)  acc1: 68.3594 (68.9953)  acc5: 86.5234 (87.1744)  time: 0.5963  data: 0.0002  max mem: 19119\n",
            "Epoch: [2]  [2300/2503]  eta: 0:02:01  lr: 0.0004  img/s: 860.1050589782235  loss: 0.8664 (0.8522)  acc1: 68.7500 (68.9914)  acc5: 87.5000 (87.1735)  time: 0.5962  data: 0.0002  max mem: 19119\n",
            "Epoch: [2]  [2400/2503]  eta: 0:01:01  lr: 0.0004  img/s: 859.3210323775423  loss: 0.8824 (0.8529)  acc1: 67.7734 (68.9772)  acc5: 86.5234 (87.1693)  time: 0.5963  data: 0.0002  max mem: 19119\n",
            "Epoch: [2]  [2500/2503]  eta: 0:00:01  lr: 0.0004  img/s: 860.1956686665284  loss: 0.8302 (0.8531)  acc1: 69.9219 (68.9880)  acc5: 87.5000 (87.1751)  time: 0.5962  data: 0.0002  max mem: 19119\n",
            "Epoch: [2] Total time: 0:24:57\n",
            "Test:   [  0/782]  eta: 0:14:49  loss: 0.6400 (0.6400)  acc1: 82.8125 (82.8125)  acc5: 93.7500 (93.7500)  time: 1.1370  data: 1.1232  max mem: 19119\n",
            "Test:   [100/782]  eta: 0:00:28  loss: 1.0691 (0.9495)  acc1: 75.0000 (76.8874)  acc5: 89.0625 (92.2184)  time: 0.0422  data: 0.0284  max mem: 19119\n",
            "Test:   [200/782]  eta: 0:00:20  loss: 0.8384 (0.9253)  acc1: 75.0000 (76.3293)  acc5: 95.3125 (93.2292)  time: 0.0298  data: 0.0161  max mem: 19119\n",
            "Test:   [300/782]  eta: 0:00:16  loss: 0.8140 (0.9153)  acc1: 78.1250 (76.6092)  acc5: 92.1875 (93.5631)  time: 0.0281  data: 0.0143  max mem: 19119\n",
            "Test:   [400/782]  eta: 0:00:12  loss: 1.7029 (1.0528)  acc1: 62.5000 (73.9479)  acc5: 84.3750 (91.8797)  time: 0.0260  data: 0.0123  max mem: 19119\n",
            "Test:   [500/782]  eta: 0:00:09  loss: 1.7149 (1.1295)  acc1: 59.3750 (72.4894)  acc5: 84.3750 (90.7997)  time: 0.0315  data: 0.0177  max mem: 19119\n",
            "Test:   [600/782]  eta: 0:00:06  loss: 1.3215 (1.1949)  acc1: 65.6250 (71.1288)  acc5: 85.9375 (90.0192)  time: 0.0343  data: 0.0204  max mem: 19119\n",
            "Test:   [700/782]  eta: 0:00:02  loss: 1.3000 (1.2468)  acc1: 70.3125 (70.1386)  acc5: 85.9375 (89.2809)  time: 0.0246  data: 0.0108  max mem: 19119\n",
            "Test:  Total time: 0:00:25\n",
            "Test:  Acc@1 70.034 Acc@5 89.306\n",
            "Epoch: [3]  [   0/2503]  eta: 3:48:40  lr: 0.0002  img/s: 868.6651772838787  loss: 0.9922 (0.9922)  acc1: 65.8203 (65.8203)  acc5: 84.3750 (84.3750)  time: 5.4818  data: 4.8924  max mem: 19119\n",
            "Epoch: [3]  [ 100/2503]  eta: 0:25:56  lr: 0.0002  img/s: 857.1146638568258  loss: 0.8599 (0.8484)  acc1: 69.7266 (69.1851)  acc5: 86.7188 (87.2660)  time: 0.5978  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [ 200/2503]  eta: 0:23:56  lr: 0.0002  img/s: 854.6256384868216  loss: 0.8801 (0.8570)  acc1: 68.7500 (69.0182)  acc5: 86.3281 (87.1521)  time: 0.5998  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [ 300/2503]  eta: 0:22:36  lr: 0.0002  img/s: 855.2042203405152  loss: 0.8260 (0.8538)  acc1: 69.3359 (69.0959)  acc5: 87.6953 (87.2262)  time: 0.5990  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [ 400/2503]  eta: 0:21:25  lr: 0.0002  img/s: 856.2763231713128  loss: 0.8881 (0.8553)  acc1: 68.3594 (69.1733)  acc5: 86.9141 (87.2141)  time: 0.5987  data: 0.0003  max mem: 19119\n",
            "Epoch: [3]  [ 500/2503]  eta: 0:20:19  lr: 0.0002  img/s: 856.0431919432866  loss: 0.8596 (0.8573)  acc1: 68.3594 (69.1281)  acc5: 87.3047 (87.2291)  time: 0.5982  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [ 600/2503]  eta: 0:19:15  lr: 0.0002  img/s: 855.2682527981125  loss: 0.8779 (0.8592)  acc1: 68.1641 (69.1153)  acc5: 86.7188 (87.2007)  time: 0.5987  data: 0.0003  max mem: 19119\n",
            "Epoch: [3]  [ 700/2503]  eta: 0:18:12  lr: 0.0002  img/s: 855.3551206587658  loss: 0.8727 (0.8601)  acc1: 68.7500 (69.0902)  acc5: 87.8906 (87.2033)  time: 0.5988  data: 0.0003  max mem: 19119\n",
            "Epoch: [3]  [ 800/2503]  eta: 0:17:10  lr: 0.0002  img/s: 854.5804059835012  loss: 0.8775 (0.8608)  acc1: 69.1406 (69.0684)  acc5: 86.5234 (87.1713)  time: 0.5988  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [ 900/2503]  eta: 0:16:08  lr: 0.0002  img/s: 855.525160200547  loss: 0.8299 (0.8601)  acc1: 69.5312 (69.0866)  acc5: 87.5000 (87.1883)  time: 0.5986  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [1000/2503]  eta: 0:15:07  lr: 0.0002  img/s: 855.1732293498484  loss: 0.8740 (0.8600)  acc1: 68.3594 (69.0608)  acc5: 86.7188 (87.1773)  time: 0.5987  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [1100/2503]  eta: 0:14:06  lr: 0.0002  img/s: 855.7201584498974  loss: 0.8490 (0.8600)  acc1: 69.5312 (69.0574)  acc5: 87.6953 (87.1810)  time: 0.5987  data: 0.0003  max mem: 19119\n",
            "Epoch: [3]  [1200/2503]  eta: 0:13:05  lr: 0.0002  img/s: 855.6761738093758  loss: 0.8551 (0.8598)  acc1: 70.1172 (69.0749)  acc5: 87.3047 (87.1956)  time: 0.5988  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [1300/2503]  eta: 0:12:04  lr: 0.0002  img/s: 855.1391759063549  loss: 0.8736 (0.8596)  acc1: 68.9453 (69.1111)  acc5: 87.5000 (87.2038)  time: 0.5985  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [1400/2503]  eta: 0:11:04  lr: 0.0002  img/s: 855.9558431039926  loss: 0.8849 (0.8602)  acc1: 69.1406 (69.1073)  acc5: 86.3281 (87.2021)  time: 0.5989  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [1500/2503]  eta: 0:10:03  lr: 0.0002  img/s: 856.2879318505251  loss: 0.8493 (0.8600)  acc1: 69.7266 (69.1198)  acc5: 87.3047 (87.2114)  time: 0.5987  data: 0.0003  max mem: 19119\n",
            "Epoch: [3]  [1600/2503]  eta: 0:09:03  lr: 0.0002  img/s: 855.3885098640291  loss: 0.8944 (0.8605)  acc1: 67.9688 (69.1188)  acc5: 86.5234 (87.2106)  time: 0.5987  data: 0.0003  max mem: 19119\n",
            "Epoch: [3]  [1700/2503]  eta: 0:08:03  lr: 0.0002  img/s: 855.653671788276  loss: 0.8327 (0.8606)  acc1: 69.7266 (69.1132)  acc5: 87.3047 (87.1988)  time: 0.5986  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [1800/2503]  eta: 0:07:02  lr: 0.0002  img/s: 854.6603313202065  loss: 0.8716 (0.8606)  acc1: 69.5312 (69.1106)  acc5: 87.3047 (87.2095)  time: 0.5988  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [1900/2503]  eta: 0:06:02  lr: 0.0002  img/s: 855.8654421819135  loss: 0.8433 (0.8607)  acc1: 68.9453 (69.1149)  acc5: 86.9141 (87.1973)  time: 0.5985  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [2000/2503]  eta: 0:05:02  lr: 0.0002  img/s: 858.0228637365412  loss: 0.8635 (0.8613)  acc1: 69.1406 (69.0981)  acc5: 86.7188 (87.1936)  time: 0.5985  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [2100/2503]  eta: 0:04:02  lr: 0.0002  img/s: 855.1837864680422  loss: 0.8389 (0.8614)  acc1: 69.7266 (69.1067)  acc5: 87.3047 (87.2038)  time: 0.5987  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [2200/2503]  eta: 0:03:02  lr: 0.0002  img/s: 854.9267436657309  loss: 0.8588 (0.8618)  acc1: 69.5312 (69.1018)  acc5: 86.9141 (87.2006)  time: 0.5988  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [2300/2503]  eta: 0:02:01  lr: 0.0002  img/s: 857.5592770650364  loss: 0.8385 (0.8623)  acc1: 69.7266 (69.1041)  acc5: 87.6953 (87.1965)  time: 0.5985  data: 0.0003  max mem: 19119\n",
            "Epoch: [3]  [2400/2503]  eta: 0:01:01  lr: 0.0002  img/s: 854.3804880688189  loss: 0.8534 (0.8625)  acc1: 68.9453 (69.1074)  acc5: 87.3047 (87.1914)  time: 0.5988  data: 0.0002  max mem: 19119\n",
            "Epoch: [3]  [2500/2503]  eta: 0:00:01  lr: 0.0002  img/s: 855.3973686621443  loss: 0.8348 (0.8625)  acc1: 69.3359 (69.1134)  acc5: 87.6953 (87.1933)  time: 0.5984  data: 0.0002  max mem: 19119\n",
            "Epoch: [3] Total time: 0:25:03\n",
            "Test:   [  0/782]  eta: 0:13:34  loss: 0.6298 (0.6298)  acc1: 84.3750 (84.3750)  acc5: 95.3125 (95.3125)  time: 1.0412  data: 1.0273  max mem: 19119\n",
            "Test:   [100/782]  eta: 0:00:28  loss: 1.0908 (0.9514)  acc1: 75.0000 (76.9957)  acc5: 89.0625 (92.2339)  time: 0.0397  data: 0.0260  max mem: 19119\n",
            "Test:   [200/782]  eta: 0:00:21  loss: 0.9058 (0.9231)  acc1: 73.4375 (76.3137)  acc5: 95.3125 (93.2680)  time: 0.0258  data: 0.0121  max mem: 19119\n",
            "Test:   [300/782]  eta: 0:00:17  loss: 0.8269 (0.9143)  acc1: 79.6875 (76.5988)  acc5: 92.1875 (93.5735)  time: 0.0356  data: 0.0218  max mem: 19119\n",
            "Test:   [400/782]  eta: 0:00:13  loss: 1.8047 (1.0535)  acc1: 60.9375 (73.9207)  acc5: 82.8125 (91.8329)  time: 0.0270  data: 0.0133  max mem: 19119\n",
            "Test:   [500/782]  eta: 0:00:09  loss: 1.6839 (1.1303)  acc1: 59.3750 (72.5050)  acc5: 85.9375 (90.7622)  time: 0.0334  data: 0.0196  max mem: 19119\n",
            "Test:   [600/782]  eta: 0:00:06  loss: 1.3633 (1.1951)  acc1: 64.0625 (71.1704)  acc5: 87.5000 (89.9776)  time: 0.0256  data: 0.0118  max mem: 19119\n",
            "Test:   [700/782]  eta: 0:00:02  loss: 1.2720 (1.2480)  acc1: 71.8750 (70.1632)  acc5: 85.9375 (89.2876)  time: 0.0280  data: 0.0142  max mem: 19119\n",
            "Test:  Total time: 0:00:26\n",
            "Test:  Acc@1 70.092 Acc@5 89.308\n",
            "Epoch: [4]  [   0/2503]  eta: 3:53:04  lr: 0.0002  img/s: 868.854963030395  loss: 0.9245 (0.9245)  acc1: 67.5781 (67.5781)  acc5: 87.3047 (87.3047)  time: 5.5871  data: 4.9977  max mem: 19119\n",
            "Epoch: [4]  [ 100/2503]  eta: 0:25:58  lr: 0.0002  img/s: 856.1991675966316  loss: 0.8765 (0.8657)  acc1: 68.9453 (69.2605)  acc5: 86.5234 (87.1171)  time: 0.5978  data: 0.0002  max mem: 19119\n",
            "Epoch: [4]  [ 200/2503]  eta: 0:23:57  lr: 0.0002  img/s: 853.9622551773969  loss: 0.8653 (0.8621)  acc1: 69.7266 (69.4321)  acc5: 87.3047 (87.2464)  time: 0.5998  data: 0.0003  max mem: 19119\n",
            "Epoch: [4]  [ 300/2503]  eta: 0:22:37  lr: 0.0002  img/s: 855.0397553711638  loss: 0.8689 (0.8659)  acc1: 69.3359 (69.2879)  acc5: 87.3047 (87.2275)  time: 0.5990  data: 0.0003  max mem: 19119\n",
            "Epoch: [4]  [ 400/2503]  eta: 0:21:26  lr: 0.0002  img/s: 857.3206551376953  loss: 0.8390 (0.8645)  acc1: 68.7500 (69.3432)  acc5: 86.5234 (87.2808)  time: 0.5983  data: 0.0003  max mem: 19119\n",
            "Epoch: [4]  [ 500/2503]  eta: 0:20:19  lr: 0.0002  img/s: 856.0698095863008  loss: 0.8473 (0.8638)  acc1: 69.5312 (69.3262)  acc5: 87.5000 (87.3148)  time: 0.5984  data: 0.0003  max mem: 19119\n",
            "Epoch: [4]  [ 600/2503]  eta: 0:19:15  lr: 0.0002  img/s: 856.6852065793925  loss: 0.8467 (0.8628)  acc1: 69.3359 (69.3753)  acc5: 87.3047 (87.3388)  time: 0.5985  data: 0.0003  max mem: 19119\n",
            "Epoch: [4]  [ 700/2503]  eta: 0:18:12  lr: 0.0002  img/s: 856.1111043339286  loss: 0.8966 (0.8657)  acc1: 68.1641 (69.3106)  acc5: 86.7188 (87.2924)  time: 0.5986  data: 0.0002  max mem: 19119\n",
            "Epoch: [4]  [ 800/2503]  eta: 0:17:10  lr: 0.0002  img/s: 856.0469456131993  loss: 0.8508 (0.8662)  acc1: 69.3359 (69.2959)  acc5: 87.6953 (87.2976)  time: 0.5987  data: 0.0003  max mem: 19119\n",
            "Epoch: [4]  [ 900/2503]  eta: 0:16:08  lr: 0.0002  img/s: 854.9645243753337  loss: 0.8663 (0.8662)  acc1: 70.3125 (69.2997)  acc5: 87.6953 (87.2860)  time: 0.5986  data: 0.0003  max mem: 19119\n",
            "Epoch: [4]  [1000/2503]  eta: 0:15:07  lr: 0.0002  img/s: 855.0186485068215  loss: 0.8554 (0.8671)  acc1: 68.3594 (69.2827)  acc5: 86.9141 (87.2858)  time: 0.5987  data: 0.0003  max mem: 19119\n",
            "Epoch: [4]  [1100/2503]  eta: 0:14:06  lr: 0.0002  img/s: 855.2873281496619  loss: 0.8594 (0.8670)  acc1: 69.7266 (69.2912)  acc5: 87.5000 (87.2926)  time: 0.5985  data: 0.0002  max mem: 19119\n",
            "Epoch: [4]  [1200/2503]  eta: 0:13:05  lr: 0.0002  img/s: 856.2193086400064  loss: 0.8768 (0.8677)  acc1: 68.7500 (69.2626)  acc5: 86.7188 (87.2596)  time: 0.5982  data: 0.0003  max mem: 19119\n",
            "Epoch: [4]  [1300/2503]  eta: 0:12:04  lr: 0.0002  img/s: 856.0534293024039  loss: 0.8824 (0.8688)  acc1: 68.7500 (69.2354)  acc5: 87.1094 (87.2479)  time: 0.5985  data: 0.0002  max mem: 19119\n",
            "Epoch: [4]  [1400/2503]  eta: 0:11:04  lr: 0.0002  img/s: 857.5664685962558  loss: 0.8628 (0.8691)  acc1: 68.7500 (69.2155)  acc5: 87.5000 (87.2509)  time: 0.5984  data: 0.0002  max mem: 19119\n",
            "Epoch: [4]  [1500/2503]  eta: 0:10:03  lr: 0.0002  img/s: 854.9161928928554  loss: 0.8838 (0.8696)  acc1: 69.1406 (69.1992)  acc5: 87.5000 (87.2396)  time: 0.5983  data: 0.0003  max mem: 19119\n",
            "Epoch: [4]  [1600/2503]  eta: 0:09:03  lr: 0.0002  img/s: 858.2563886134587  loss: 0.8541 (0.8697)  acc1: 68.5547 (69.1823)  acc5: 87.3047 (87.2258)  time: 0.5985  data: 0.0003  max mem: 19119\n",
            "Epoch: [4]  [1700/2503]  eta: 0:08:03  lr: 0.0002  img/s: 855.3152614496619  loss: 0.8702 (0.8694)  acc1: 68.9453 (69.1954)  acc5: 86.9141 (87.2312)  time: 0.5988  data: 0.0003  max mem: 19119\n",
            "Epoch: [4]  [1800/2503]  eta: 0:07:02  lr: 0.0002  img/s: 856.2275018779335  loss: 0.8691 (0.8696)  acc1: 68.9453 (69.1939)  acc5: 87.1094 (87.2242)  time: 0.5987  data: 0.0003  max mem: 19119\n",
            "Epoch: [4]  [1900/2503]  eta: 0:06:02  lr: 0.0002  img/s: 854.6021714024656  loss: 0.8867 (0.8703)  acc1: 68.3594 (69.1812)  acc5: 86.5234 (87.2186)  time: 0.5987  data: 0.0003  max mem: 19119\n",
            "Epoch: [4]  [2000/2503]  eta: 0:05:02  lr: 0.0002  img/s: 856.027153906284  loss: 0.8680 (0.8710)  acc1: 67.9688 (69.1654)  acc5: 87.1094 (87.2148)  time: 0.5986  data: 0.0002  max mem: 19119\n",
            "Epoch: [4]  [2100/2503]  eta: 0:04:02  lr: 0.0002  img/s: 854.944442321167  loss: 0.8930 (0.8715)  acc1: 68.9453 (69.1539)  acc5: 87.1094 (87.2134)  time: 0.5987  data: 0.0002  max mem: 19119\n",
            "Epoch: [4]  [2200/2503]  eta: 0:03:02  lr: 0.0002  img/s: 857.4733304797255  loss: 0.8344 (0.8713)  acc1: 69.7266 (69.1614)  acc5: 87.6953 (87.2185)  time: 0.5986  data: 0.0002  max mem: 19119\n",
            "Epoch: [4]  [2300/2503]  eta: 0:02:01  lr: 0.0002  img/s: 854.5545609884018  loss: 0.8644 (0.8712)  acc1: 69.1406 (69.1611)  acc5: 86.9141 (87.2154)  time: 0.5988  data: 0.0002  max mem: 19119\n",
            "Epoch: [4]  [2400/2503]  eta: 0:01:01  lr: 0.0002  img/s: 855.9186570165338  loss: 0.8843 (0.8714)  acc1: 69.1406 (69.1609)  acc5: 86.9141 (87.2155)  time: 0.5987  data: 0.0002  max mem: 19119\n",
            "Epoch: [4]  [2500/2503]  eta: 0:00:01  lr: 0.0002  img/s: 855.5220927564314  loss: 0.8757 (0.8719)  acc1: 68.9453 (69.1569)  acc5: 87.1094 (87.2136)  time: 0.5979  data: 0.0002  max mem: 19119\n",
            "Epoch: [4] Total time: 0:25:03\n",
            "Test:   [  0/782]  eta: 0:15:24  loss: 0.5899 (0.5899)  acc1: 85.9375 (85.9375)  acc5: 95.3125 (95.3125)  time: 1.1827  data: 1.1689  max mem: 19119\n",
            "Test:   [100/782]  eta: 0:00:27  loss: 1.0514 (0.9553)  acc1: 76.5625 (76.7327)  acc5: 89.0625 (92.1411)  time: 0.0322  data: 0.0185  max mem: 19119\n",
            "Test:   [200/782]  eta: 0:00:20  loss: 0.8755 (0.9239)  acc1: 75.0000 (76.3682)  acc5: 95.3125 (93.2369)  time: 0.0254  data: 0.0116  max mem: 19119\n",
            "Test:   [300/782]  eta: 0:00:16  loss: 0.7986 (0.9160)  acc1: 78.1250 (76.6352)  acc5: 92.1875 (93.5424)  time: 0.0298  data: 0.0161  max mem: 19119\n",
            "Test:   [400/782]  eta: 0:00:12  loss: 1.7921 (1.0555)  acc1: 60.9375 (73.8817)  acc5: 84.3750 (91.8095)  time: 0.0308  data: 0.0171  max mem: 19119\n",
            "Test:   [500/782]  eta: 0:00:09  loss: 1.7681 (1.1332)  acc1: 59.3750 (72.4613)  acc5: 84.3750 (90.7248)  time: 0.0302  data: 0.0164  max mem: 19119\n",
            "Test:   [600/782]  eta: 0:00:06  loss: 1.3149 (1.1978)  acc1: 65.6250 (71.1340)  acc5: 85.9375 (89.9880)  time: 0.0445  data: 0.0307  max mem: 19119\n",
            "Test:   [700/782]  eta: 0:00:02  loss: 1.2842 (1.2500)  acc1: 70.3125 (70.1297)  acc5: 87.5000 (89.2899)  time: 0.0292  data: 0.0154  max mem: 19119\n",
            "Test:  Total time: 0:00:25\n",
            "Test:  Acc@1 70.056 Acc@5 89.296\n",
            "Epoch: [5]  [   0/2503]  eta: 3:36:03  lr: 0.0001  img/s: 868.6398787818482  loss: 0.9304 (0.9304)  acc1: 68.1641 (68.1641)  acc5: 84.5703 (84.5703)  time: 5.1790  data: 4.5895  max mem: 19119\n",
            "Epoch: [5]  [ 100/2503]  eta: 0:26:17  lr: 0.0001  img/s: 856.2223810858537  loss: 0.8658 (0.8735)  acc1: 68.7500 (69.3843)  acc5: 86.5234 (87.1229)  time: 0.5978  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [ 200/2503]  eta: 0:24:06  lr: 0.0001  img/s: 854.6059124455048  loss: 0.8757 (0.8726)  acc1: 68.3594 (69.2893)  acc5: 87.3047 (87.2027)  time: 0.5998  data: 0.0003  max mem: 19119\n",
            "Epoch: [5]  [ 300/2503]  eta: 0:22:42  lr: 0.0001  img/s: 854.7457148952566  loss: 0.9211 (0.8775)  acc1: 68.3594 (69.1964)  acc5: 86.5234 (87.1548)  time: 0.5994  data: 0.0003  max mem: 19119\n",
            "Epoch: [5]  [ 400/2503]  eta: 0:21:30  lr: 0.0001  img/s: 855.9967856527867  loss: 0.8790 (0.8763)  acc1: 69.5312 (69.2220)  acc5: 87.1094 (87.1795)  time: 0.5981  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [ 500/2503]  eta: 0:20:22  lr: 0.0001  img/s: 856.1715178770771  loss: 0.8828 (0.8767)  acc1: 68.5547 (69.1835)  acc5: 87.1094 (87.1916)  time: 0.5981  data: 0.0003  max mem: 19119\n",
            "Epoch: [5]  [ 600/2503]  eta: 0:19:17  lr: 0.0001  img/s: 857.1064536316845  loss: 0.8499 (0.8745)  acc1: 70.1172 (69.2495)  acc5: 87.6953 (87.2351)  time: 0.5985  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [ 700/2503]  eta: 0:18:14  lr: 0.0001  img/s: 855.5684475906053  loss: 0.8724 (0.8738)  acc1: 68.7500 (69.2543)  acc5: 87.3047 (87.2381)  time: 0.5986  data: 0.0003  max mem: 19119\n",
            "Epoch: [5]  [ 800/2503]  eta: 0:17:11  lr: 0.0001  img/s: 855.3483068555046  loss: 0.8648 (0.8747)  acc1: 69.3359 (69.2645)  acc5: 87.1094 (87.2354)  time: 0.5987  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [ 900/2503]  eta: 0:16:10  lr: 0.0001  img/s: 855.0516709967052  loss: 0.8869 (0.8744)  acc1: 68.9453 (69.2787)  acc5: 87.1094 (87.2609)  time: 0.5988  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [1000/2503]  eta: 0:15:08  lr: 0.0001  img/s: 856.4016452607827  loss: 0.8660 (0.8746)  acc1: 69.5312 (69.2706)  acc5: 87.3047 (87.2602)  time: 0.5987  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [1100/2503]  eta: 0:14:07  lr: 0.0001  img/s: 855.5752649016649  loss: 0.8801 (0.8741)  acc1: 69.3359 (69.2911)  acc5: 86.7188 (87.2653)  time: 0.5987  data: 0.0003  max mem: 19119\n",
            "Epoch: [5]  [1200/2503]  eta: 0:13:06  lr: 0.0001  img/s: 857.7295089955442  loss: 0.9027 (0.8741)  acc1: 68.1641 (69.2950)  acc5: 86.7188 (87.2543)  time: 0.5983  data: 0.0003  max mem: 19119\n",
            "Epoch: [5]  [1300/2503]  eta: 0:12:05  lr: 0.0001  img/s: 857.4562117036577  loss: 0.8591 (0.8742)  acc1: 69.3359 (69.2867)  acc5: 87.3047 (87.2523)  time: 0.5987  data: 0.0003  max mem: 19119\n",
            "Epoch: [5]  [1400/2503]  eta: 0:11:04  lr: 0.0001  img/s: 856.1711765336744  loss: 0.8659 (0.8747)  acc1: 69.7266 (69.2756)  acc5: 87.1094 (87.2657)  time: 0.5986  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [1500/2503]  eta: 0:10:04  lr: 0.0001  img/s: 855.6918577361109  loss: 0.8848 (0.8744)  acc1: 67.5781 (69.2883)  acc5: 86.9141 (87.2744)  time: 0.5986  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [1600/2503]  eta: 0:09:03  lr: 0.0001  img/s: 856.2199914038447  loss: 0.8476 (0.8740)  acc1: 69.1406 (69.2962)  acc5: 87.5000 (87.2808)  time: 0.5987  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [1700/2503]  eta: 0:08:03  lr: 0.0001  img/s: 856.7518536501099  loss: 0.8676 (0.8745)  acc1: 68.3594 (69.2823)  acc5: 87.6953 (87.2798)  time: 0.5986  data: 0.0003  max mem: 19119\n",
            "Epoch: [5]  [1800/2503]  eta: 0:07:03  lr: 0.0001  img/s: 855.5047109885885  loss: 0.8884 (0.8745)  acc1: 68.1641 (69.2810)  acc5: 87.3047 (87.2716)  time: 0.5986  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [1900/2503]  eta: 0:06:02  lr: 0.0001  img/s: 855.3728370553123  loss: 0.8565 (0.8744)  acc1: 68.9453 (69.2820)  acc5: 88.2812 (87.2709)  time: 0.5987  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [2000/2503]  eta: 0:05:02  lr: 0.0001  img/s: 856.9477535227904  loss: 0.8799 (0.8749)  acc1: 68.7500 (69.2718)  acc5: 86.7188 (87.2586)  time: 0.5986  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [2100/2503]  eta: 0:04:02  lr: 0.0001  img/s: 855.2914158356967  loss: 0.8914 (0.8745)  acc1: 68.9453 (69.2829)  acc5: 87.5000 (87.2690)  time: 0.5987  data: 0.0003  max mem: 19119\n",
            "Epoch: [5]  [2200/2503]  eta: 0:03:02  lr: 0.0001  img/s: 855.3816955287992  loss: 0.8659 (0.8744)  acc1: 68.9453 (69.2853)  acc5: 87.1094 (87.2749)  time: 0.5986  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [2300/2503]  eta: 0:02:02  lr: 0.0001  img/s: 857.8035141651501  loss: 0.8525 (0.8745)  acc1: 69.7266 (69.2810)  acc5: 87.5000 (87.2740)  time: 0.5986  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [2400/2503]  eta: 0:01:01  lr: 0.0001  img/s: 855.2345323832691  loss: 0.8424 (0.8745)  acc1: 69.5312 (69.2709)  acc5: 87.3047 (87.2737)  time: 0.5988  data: 0.0002  max mem: 19119\n",
            "Epoch: [5]  [2500/2503]  eta: 0:00:01  lr: 0.0001  img/s: 856.8721864153102  loss: 0.8785 (0.8748)  acc1: 68.7500 (69.2651)  acc5: 87.5000 (87.2714)  time: 0.5982  data: 0.0002  max mem: 19119\n",
            "Epoch: [5] Total time: 0:25:04\n",
            "Test:   [  0/782]  eta: 0:16:48  loss: 0.6137 (0.6137)  acc1: 85.9375 (85.9375)  acc5: 95.3125 (95.3125)  time: 1.2890  data: 1.2749  max mem: 19119\n",
            "Test:   [100/782]  eta: 0:00:29  loss: 1.0820 (0.9476)  acc1: 76.5625 (77.0885)  acc5: 89.0625 (92.3113)  time: 0.0337  data: 0.0200  max mem: 19119\n",
            "Test:   [200/782]  eta: 0:00:21  loss: 0.8791 (0.9212)  acc1: 75.0000 (76.4537)  acc5: 95.3125 (93.3613)  time: 0.0276  data: 0.0139  max mem: 19119\n",
            "Test:   [300/782]  eta: 0:00:16  loss: 0.8066 (0.9144)  acc1: 76.5625 (76.6923)  acc5: 92.1875 (93.6306)  time: 0.0281  data: 0.0144  max mem: 19119\n",
            "Test:   [400/782]  eta: 0:00:13  loss: 1.8165 (1.0555)  acc1: 60.9375 (73.9596)  acc5: 84.3750 (91.9031)  time: 0.0350  data: 0.0214  max mem: 19119\n",
            "Test:   [500/782]  eta: 0:00:09  loss: 1.7107 (1.1325)  acc1: 59.3750 (72.5299)  acc5: 84.3750 (90.7934)  time: 0.0344  data: 0.0206  max mem: 19119\n",
            "Test:   [600/782]  eta: 0:00:06  loss: 1.3799 (1.1970)  acc1: 64.0625 (71.2068)  acc5: 84.3750 (89.9880)  time: 0.0266  data: 0.0127  max mem: 19119\n",
            "Test:   [700/782]  eta: 0:00:02  loss: 1.2741 (1.2493)  acc1: 68.7500 (70.1966)  acc5: 85.9375 (89.2765)  time: 0.0262  data: 0.0124  max mem: 19119\n",
            "Test:  Total time: 0:00:25\n",
            "Test:  Acc@1 70.120 Acc@5 89.284\n",
            "Training time 2:33:01\n"
          ]
        }
      ],
      "source": [
        "from types import SimpleNamespace\n",
        "\n",
        "args = SimpleNamespace(\n",
        "    data_path=\"/home/cs/Documents/datasets/imagenet\",  # Replace with your /path/to/imagenet\n",
        "    model=\"resnet18\",\n",
        "    device=\"cuda\",\n",
        "    batch_size=512,\n",
        "    epochs=6,\n",
        "    lr=0.0004,\n",
        "    momentum=0.9,\n",
        "    weight_decay=1e-4,\n",
        "    lr_warmup_epochs=1,\n",
        "    lr_warmup_decay=0.0,\n",
        "    lr_step_size=2,\n",
        "    lr_gamma=0.5,\n",
        "    print_freq=100,\n",
        "    output_dir=\"resnet18\",\n",
        "    use_deterministic_algorithms=False,\n",
        "    weights=\"ResNet18_Weights.IMAGENET1K_V1\",\n",
        "    apply_trp=True,\n",
        "    trp_depths=[3, 3, 3],\n",
        "    trp_planes=256,\n",
        "    trp_lambdas=[0.4, 0.2, 0.1],\n",
        ")\n",
        "\n",
        "main(args)"
      ]
    }
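    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The learning-rate trace in the log above follows directly from these settings: with `lr_warmup_epochs=1`, `lr_step_size=2`, and `lr_gamma=0.5`, the LR holds at 0.0004 through epoch 2, drops to 0.0002 for epochs 3-4, and to 0.0001 for epoch 5. Below is a minimal sketch of how such a warmup-plus-step schedule can be composed from PyTorch's built-in schedulers; the toy model, the optimizer, and the positive `start_factor` (`LinearLR` requires a value greater than 0) are illustrative assumptions, not the exact code inside `main`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Toy model/optimizer so the sketch is self-contained; lr, momentum and\n",
        "# weight_decay mirror the args above.\n",
        "model = torch.nn.Linear(8, 2)\n",
        "optimizer = torch.optim.SGD(model.parameters(), lr=0.0004, momentum=0.9, weight_decay=1e-4)\n",
        "\n",
        "# Linear warmup over the first epoch. LinearLR requires start_factor > 0,\n",
        "# so a small positive value stands in for lr_warmup_decay=0.0.\n",
        "warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=1)\n",
        "\n",
        "# After the warmup, multiply the LR by lr_gamma=0.5 every lr_step_size=2 epochs.\n",
        "decay = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)\n",
        "\n",
        "scheduler = torch.optim.lr_scheduler.SequentialLR(\n",
        "    optimizer, schedulers=[warmup, decay], milestones=[1]\n",
        ")\n",
        "\n",
        "for epoch in range(6):\n",
        "    # ... one training epoch would run here ...\n",
        "    print(f\"epoch {epoch}: lr {scheduler.get_last_lr()[0]:.6f}\")\n",
        "    scheduler.step()"
      ]
    }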
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.21"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}