{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "X4cRE8IbIrIV"
},
"source": [
"Downloading PyTorch Vision Reference Scripts for Image Classification. These scripts are official reference implementations from PyTorch Vision that provide training and quantization utilities for image classification models."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "46CgrVgjg3E-",
"outputId": "7fb20ebe-d7fd-43fa-dc9b-ebbedf31575e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2025-05-22 16:30:12-- https://raw.githubusercontent.com/pytorch/vision/main/references/classification/presets.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 3885 (3.8K) [text/plain]\n",
"Saving to: ‘presets.py’\n",
"\n",
"presets.py 100%[===================>] 3.79K --.-KB/s in 0s \n",
"\n",
"2025-05-22 16:30:12 (12.8 MB/s) - ‘presets.py’ saved [3885/3885]\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2025-05-22 16:30:12-- https://raw.githubusercontent.com/pytorch/vision/main/references/classification/sampler.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 2395 (2.3K) [text/plain]\n",
"Saving to: ‘sampler.py’\n",
"\n",
"sampler.py 100%[===================>] 2.34K --.-KB/s in 0s \n",
"\n",
"2025-05-22 16:30:12 (18.4 MB/s) - ‘sampler.py’ saved [2395/2395]\n",
"\n",
"--2025-05-22 16:30:12-- https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 23324 (23K) [text/plain]\n",
"Saving to: ‘train.py’\n",
"\n",
"train.py 100%[===================>] 22.78K --.-KB/s in 0.01s \n",
"\n",
"2025-05-22 16:30:13 (2.28 MB/s) - ‘train.py’ saved [23324/23324]\n",
"\n",
"--2025-05-22 16:30:13-- https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train_quantization.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 11647 (11K) [text/plain]\n",
"Saving to: ‘train_quantization.py’\n",
"\n",
"train_quantization. 100%[===================>] 11.37K --.-KB/s in 0.001s \n",
"\n",
"2025-05-22 16:30:13 (12.7 MB/s) - ‘train_quantization.py’ saved [11647/11647]\n",
"\n",
"--2025-05-22 16:30:13-- https://raw.githubusercontent.com/pytorch/vision/main/references/classification/transformers.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
"HTTP request sent, awaiting response... 404 Not Found\n",
"2025-05-22 16:30:13 ERROR 404: Not Found.\n",
"\n",
"--2025-05-22 16:30:13-- https://raw.githubusercontent.com/pytorch/vision/main/references/classification/utils.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 15791 (15K) [text/plain]\n",
"Saving to: ‘utils.py’\n",
"\n",
"utils.py 100%[===================>] 15.42K --.-KB/s in 0.01s \n",
"\n",
"2025-05-22 16:30:13 (1.43 MB/s) - ‘utils.py’ saved [15791/15791]\n",
"\n"
]
}
],
"source": [
"! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/presets.py\n",
"! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/sampler.py\n",
"! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train.py\n",
"! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train_quantization.py\n",
"! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/transformers.py\n",
"! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/utils.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HFASsisvIrIb"
},
"source": [
"In this block, we build a “loss” function for our sequential policy gradient algorithm. When the right data is plugged in, the gradient of this loss is equal to the policy gradient."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "EaBokYCpg3FA"
},
"outputs": [],
"source": [
"import types\n",
"from typing import List, Callable\n",
"\n",
"import torch\n",
"from torch import nn, Tensor\n",
"from torch.nn import functional as F\n",
"from torchvision.models.resnet import BasicBlock\n",
"\n",
"\n",
"def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_state: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):\n",
" losses, rewards = criterion(logits, targets)\n",
" returns = torch.ones_like(rewards, dtype=torch.float32, device=rewards.device)\n",
" if loss_normalization:\n",
" coeff = torch.mean(losses).detach()\n",
"\n",
" embeds = [hidden_state]\n",
" predictions = []\n",
" for k, w in enumerate(lambdas):\n",
" embeds.append(trp_blocks[k](embeds[-1]))\n",
" predictions.append(shared_head(embeds[-1]))\n",
" returns = returns + w * rewards\n",
" replica_losses, rewards = criterion(predictions[-1], targets, rewards)\n",
" losses = losses + replica_losses\n",
" loss = torch.mean(losses * returns)\n",
"\n",
" if loss_normalization:\n",
" with torch.no_grad():\n",
" coeff = torch.exp(coeff) / torch.exp(loss.detach())\n",
" loss = coeff * loss\n",
"\n",
" return loss"
]
},
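  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the next cell exercises `trp_criterion` on toy inputs. Everything here is a hypothetical stand-in (identity TRP blocks, a small linear head, and a simple cross-entropy criterion returning per-sample losses and top-1 hit masks), not part of the actual pipeline. Because the rewards are detached, the returns act as constant weights, so the gradient of the surrogate loss is the reward-weighted sum of the per-sample loss gradients."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical toy check for trp_criterion; all shapes and modules are made up.\n",
    "def toy_criterion(logits, targets, mask=None):\n",
    "    losses = F.cross_entropy(logits, targets, reduction=\"none\")\n",
    "    if mask is None:\n",
    "        mask = torch.ones_like(losses)\n",
    "    with torch.no_grad():  # rewards are detached, mirroring gen_criterion below\n",
    "        hits = torch.eq(logits.argmax(dim=-1), targets).to(logits.dtype)\n",
    "    return mask * losses, mask * hits\n",
    "\n",
    "B, C, H, W, num_classes = 4, 8, 2, 2, 10\n",
    "toy_blocks = nn.ModuleList([nn.Identity(), nn.Identity()])  # stand-ins for TPBlocks\n",
    "toy_fc = nn.Linear(C * H * W, num_classes)\n",
    "toy_head = lambda t: toy_fc(torch.flatten(t, 1))  # shared prediction head\n",
    "\n",
    "hidden = torch.randn(B, C, H, W)\n",
    "targets = torch.randint(0, num_classes, (B,))\n",
    "loss = trp_criterion(toy_blocks, toy_head, toy_criterion, [0.4, 0.2], hidden, toy_head(hidden), targets)\n",
    "loss.backward()  # gradients flow through the shared head\n",
    "print(loss.item())"
   ]
  },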
{
"cell_type": "markdown",
"metadata": {
"id": "_Ig0Jm2w8DPH"
},
"source": [
"In this block, we build a TPBlock for the Task Replica Prediction (TRP) module; This implementation provides the backbone without the shared prediction head."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "wkBlmJT96jZj"
},
"outputs": [],
"source": [
"class TPBlock(nn.Module):\n",
" def __init__(self, depths: int, inplanes: int, planes: int):\n",
" super(TPBlock, self).__init__()\n",
"\n",
" blocks = [BasicBlock(inplanes=inplanes, planes=planes) for _ in range(depths)]\n",
" self.blocks = nn.Sequential(*blocks)\n",
" for name, param in self.blocks.named_parameters():\n",
" if 'conv' in name:\n",
" nn.init.zeros_(param) # Initialize weights\n",
" elif 'downsample' in name:\n",
" nn.init.zeros_(param) # Initialize biases\n",
"\n",
" def forward(self, x):\n",
" return self.blocks(x)"
]
},
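  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A freshly constructed `TPBlock` should not disturb the pretrained backbone: with zero-initialized convolutions, each `BasicBlock` computes `ReLU(0 + x)`, so the whole block reduces to a ReLU of its input. The following hypothetical check verifies this on random features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sanity check: a zero-initialized TPBlock acts as a (ReLU-)identity.\n",
    "blk = TPBlock(depths=2, inplanes=64, planes=64).eval()\n",
    "x = torch.randn(1, 64, 8, 8)\n",
    "with torch.no_grad():\n",
    "    y = blk(x)\n",
    "print(torch.allclose(y, F.relu(x)))  # True: the residual branches contribute nothing yet"
   ]
  },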
{
"cell_type": "markdown",
"metadata": {
"id": "UGxQdKZaF2NT"
},
"source": [
"This implementation enables ResNet retraining in SPG mode.\n",
"\n",
"Components:\n",
"-------------------------------------------------------------------------------\n",
"1. gen_criterion()\n",
" - Purpose: compute per-sample losses and positional masks\n",
"\n",
"2. gen_shared_head()\n",
" - Purpose: Implements a shared prediction head that processes convolutional feature maps for prediction.\n",
"\n",
"3. gen_forward()\n",
" - Purpose: Extended forward pass supporting both traditional inference and SPG retraining."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "kTZWkoLr8cfE"
},
"outputs": [],
"source": [
"class ResNetConfig:\n",
" @staticmethod\n",
" def gen_criterion(label_smoothing=0.0, top_k=1):\n",
" def func(input, target, mask=None):\n",
" \"\"\"\n",
" Args:\n",
" input (Tensor): Input tensor of shape [B, C].\n",
" target (Tensor): Target labels of shape [B] or [B, C].\n",
"\n",
" Returns:\n",
" loss (Tensor): Scalar tensor representing the loss.\n",
" mask (Tensor): Boolean mask tensor of shape [B].\n",
" \"\"\"\n",
" label = torch.argmax(target, dim=1) if label_smoothing > 0.0 else target\n",
"\n",
" unmasked_loss = F.cross_entropy(input, label, reduction=\"none\", label_smoothing=label_smoothing)\n",
" if mask is None:\n",
" mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)\n",
" losses = mask * unmasked_loss\n",
"\n",
" with torch.no_grad():\n",
" topk_values, topk_indices = torch.topk(input, top_k, dim=-1)\n",
" mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)\n",
"\n",
" return losses, mask\n",
" return func\n",
"\n",
" @staticmethod\n",
" def gen_shared_head(self):\n",
" def func(x):\n",
" \"\"\"\n",
" Args:\n",
" x (Tensor): Hidden State tensor of shape [B, C, H, W].\n",
"\n",
" Returns:\n",
" logits (Tensor): Logits tensor of shape [B, C].\n",
" \"\"\"\n",
" x = self.layer4(x)\n",
" x = self.avgpool(x)\n",
" x = torch.flatten(x, 1)\n",
" logits = self.fc(x)\n",
" return logits\n",
" return func\n",
"\n",
" @staticmethod\n",
" def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):\n",
" def func(self, x: Tensor, targets=None) -> Tensor:\n",
" x = self.conv1(x)\n",
" x = self.bn1(x)\n",
" x = self.relu(x)\n",
" x = self.maxpool(x)\n",
"\n",
" x = self.layer1(x)\n",
" x = self.layer2(x)\n",
" hidden_state = self.layer3(x)\n",
" x = self.layer4(hidden_state)\n",
" x = self.avgpool(x)\n",
" x = torch.flatten(x, 1)\n",
" logits = self.fc(x)\n",
"\n",
" if self.training:\n",
" shared_head = ResNetConfig.gen_shared_head(self)\n",
" criterion = ResNetConfig.gen_criterion(label_smoothing=label_smoothing, top_k=top_k)\n",
"\n",
" loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_state, logits, targets, loss_normalization=loss_normalization)\n",
"\n",
" return logits, loss\n",
"\n",
" return logits\n",
"\n",
" return func"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cCn6vwItH1CW"
},
"source": [
"Applies TRP modules to the base ResNet (main backbone). The k-th TRP module corresponding to a deeper ResNet variant with an additional depth of 3 * sum(depths[:k+1])."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "wXQF0oISH5Yp"
},
"outputs": [],
"source": [
"def apply_trp(model, depths: List[int], planes: int, lambdas: List[float], **kwargs):\n",
" print(\"✅ Applying TRP to ResNet for Image Classification...\")\n",
" model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, inplanes=planes, planes=planes) for d in depths])\n",
" model.forward = types.MethodType(ResNetConfig.gen_forward(lambdas), model)\n",
" return model"
]
},
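  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A usage sketch (illustrative only; the training pipeline below builds its model via `torchvision.models.get_model`): attach the TRP modules to a pretrained ResNet-18 with the same hyper-parameters used in the training run, and note that the patched forward returns `(logits, loss)` in training mode and plain logits in eval mode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch; mirrors the hyper-parameters of the training run below.\n",
    "import torchvision\n",
    "\n",
    "demo = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)\n",
    "demo = apply_trp(demo, depths=[3, 3, 3], planes=256, lambdas=[0.4, 0.2, 0.1])  # layer3 of ResNet-18 outputs 256 channels\n",
    "\n",
    "demo.train()\n",
    "images = torch.randn(2, 3, 224, 224)\n",
    "labels = torch.randint(0, 1000, (2,))\n",
    "logits, loss = demo(images, labels)  # training mode: (logits, SPG loss)\n",
    "\n",
    "demo.eval()\n",
    "logits = demo(images)  # eval mode: logits only\n",
    "print(logits.shape, loss.item())"
   ]
  },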
{
"cell_type": "markdown",
"metadata": {
"id": "kDjSAv3PJr7P"
},
"source": [
"The following is a training script for classification models, primarily based on the official TorchVision `train.py` reference implementation. We have made two modifications:\n",
"\n",
"Adding TRP Modules: We integrate TRP modules into the base model architecture before training begins:\n",
"\n",
"```python\n",
"if args.apply_trp:\n",
" model = apply_trp(model, args.trp_depths, args.trp_planes, args.trp_lambdas)\n",
"```\n",
"Removing TRP Modules: We remove the TRP components from the base model before saving the base model:\n",
"```python\n",
"if args.output_dir:\n",
" checkpoint = {\n",
" \"model\": model.state_dict() if not args.apply_trp else {k: v for k, v in model.state_dict().items() if not k.startswith(\"trp_blocks\")},\n",
" \"optimizer\": optimizer.state_dict(),\n",
" \"lr_scheduler\": lr_scheduler.state_dict(),\n",
" \"epoch\": epoch,\n",
" \"args\": args,\n",
" }\n",
" utils.save_on_master(checkpoint, os.path.join(args.output_dir, f\"model_{epoch}.pth\"))\n",
" utils.save_on_master(checkpoint, os.path.join(args.output_dir, \"checkpoint.pth\"))\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "hK4Y7Sqv4xUa"
},
"outputs": [],
"source": [
"import datetime\n",
"import os\n",
"import time\n",
"import warnings\n",
"\n",
"import presets\n",
"import torch\n",
"import torch.utils.data\n",
"import torchvision\n",
"import utils\n",
"from torch import nn\n",
"from torchvision.transforms.functional import InterpolationMode\n",
"\n",
"\n",
"def load_data(traindir, valdir):\n",
" # Data loading code\n",
" print(\"Loading data\")\n",
" interpolation = InterpolationMode(\"bilinear\")\n",
"\n",
" print(\"Loading training data\")\n",
" st = time.time()\n",
" dataset = torchvision.datasets.ImageFolder(\n",
" traindir,\n",
" presets.ClassificationPresetTrain(crop_size=224, interpolation=interpolation, auto_augment_policy=None, random_erase_prob=0.0, ra_magnitude=9, augmix_severity=3),\n",
" )\n",
" print(\"Took\", time.time() - st)\n",
"\n",
" print(\"Loading validation data\")\n",
" dataset_test = torchvision.datasets.ImageFolder(\n",
" valdir,\n",
" presets.ClassificationPresetEval(crop_size=224, resize_size=256, interpolation=interpolation)\n",
" )\n",
"\n",
" print(\"Creating data loaders\")\n",
" train_sampler = torch.utils.data.RandomSampler(dataset)\n",
" test_sampler = torch.utils.data.SequentialSampler(dataset_test)\n",
"\n",
" return dataset, dataset_test, train_sampler, test_sampler\n",
"\n",
"\n",
"\n",
"def train_one_epoch(model, optimizer, data_loader, device, epoch, args):\n",
" model.train()\n",
" metric_logger = utils.MetricLogger(delimiter=\" \")\n",
" metric_logger.add_meter(\"lr\", utils.SmoothedValue(window_size=1, fmt=\"{value}\"))\n",
" metric_logger.add_meter(\"img/s\", utils.SmoothedValue(window_size=10, fmt=\"{value}\"))\n",
"\n",
" header = f\"Epoch: [{epoch}]\"\n",
" for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):\n",
" start_time = time.time()\n",
" image, target = image.to(device), target.to(device)\n",
" with torch.amp.autocast(\"cuda\", enabled=False):\n",
" output, loss = model(image, target)\n",
"\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))\n",
" batch_size = image.shape[0]\n",
" metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0][\"lr\"])\n",
" metric_logger.meters[\"acc1\"].update(acc1.item(), n=batch_size)\n",
" metric_logger.meters[\"acc5\"].update(acc5.item(), n=batch_size)\n",
" metric_logger.meters[\"img/s\"].update(batch_size / (time.time() - start_time))\n",
"\n",
"\n",
"def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=\"\"):\n",
" model.eval()\n",
" metric_logger = utils.MetricLogger(delimiter=\" \")\n",
" header = f\"Test: {log_suffix}\"\n",
"\n",
" num_processed_samples = 0\n",
" with torch.inference_mode():\n",
" for image, target in metric_logger.log_every(data_loader, print_freq, header):\n",
" image = image.to(device, non_blocking=True)\n",
" target = target.to(device, non_blocking=True)\n",
" output = model(image)\n",
" loss = criterion(output, target)\n",
"\n",
" acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))\n",
" # FIXME need to take into account that the datasets\n",
" # could have been padded in distributed setup\n",
" batch_size = image.shape[0]\n",
" metric_logger.update(loss=loss.item())\n",
" metric_logger.meters[\"acc1\"].update(acc1.item(), n=batch_size)\n",
" metric_logger.meters[\"acc5\"].update(acc5.item(), n=batch_size)\n",
" num_processed_samples += batch_size\n",
" # gather the stats from all processes\n",
"\n",
" num_processed_samples = utils.reduce_across_processes(num_processed_samples)\n",
" if (\n",
" hasattr(data_loader.dataset, \"__len__\")\n",
" and len(data_loader.dataset) != num_processed_samples\n",
" and torch.distributed.get_rank() == 0\n",
" ):\n",
" # See FIXME above\n",
" warnings.warn(\n",
" f\"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} \"\n",
" \"samples were used for the validation, which might bias the results. \"\n",
" \"Try adjusting the batch size and / or the world size. \"\n",
" \"Setting the world size to 1 is always a safe bet.\"\n",
" )\n",
"\n",
" metric_logger.synchronize_between_processes()\n",
"\n",
" print(f\"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}\")\n",
" return metric_logger.acc1.global_avg\n",
"\n",
"\n",
"def main(args):\n",
" if args.output_dir:\n",
" utils.mkdir(args.output_dir)\n",
" print(args)\n",
"\n",
" device = torch.device(args.device)\n",
"\n",
" if args.use_deterministic_algorithms:\n",
" torch.backends.cudnn.benchmark = False\n",
" torch.use_deterministic_algorithms(True)\n",
" else:\n",
" torch.backends.cudnn.benchmark = True\n",
"\n",
" train_dir = os.path.join(args.data_path, \"train\")\n",
" val_dir = os.path.join(args.data_path, \"val\")\n",
" dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir)\n",
"\n",
" num_classes = len(dataset.classes)\n",
" data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=16, pin_memory=True, collate_fn=None)\n",
" data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=64, sampler=test_sampler, num_workers=16, pin_memory=True)\n",
"\n",
" print(\"Creating model\")\n",
" model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)\n",
" if args.apply_trp:\n",
" model = apply_trp(model, args.trp_depths, args.trp_planes, args.trp_lambdas)\n",
" model.to(device)\n",
"\n",
" parameters = utils.set_weight_decay(model, args.weight_decay, norm_weight_decay=None, custom_keys_weight_decay=None)\n",
" optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)\n",
"\n",
" main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)\n",
" warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs)\n",
" lr_scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs])\n",
"\n",
"\n",
" print(\"Start training\")\n",
" start_time = time.time()\n",
" for epoch in range(args.epochs):\n",
" train_one_epoch(model, optimizer, data_loader, device, epoch, args)\n",
" lr_scheduler.step()\n",
" evaluate(model, nn.CrossEntropyLoss(), data_loader_test, device=device)\n",
" if args.output_dir:\n",
" checkpoint = {\n",
" \"model\": model.state_dict() if not args.apply_trp else {k: v for k, v in model.state_dict().items() if not k.startswith(\"trp_blocks\")}, # NOTE: remove TRP heads\n",
" \"optimizer\": optimizer.state_dict(),\n",
" \"lr_scheduler\": lr_scheduler.state_dict(),\n",
" \"epoch\": epoch,\n",
" \"args\": args,\n",
" }\n",
" utils.save_on_master(checkpoint, os.path.join(args.output_dir, f\"model_{epoch}.pth\"))\n",
" utils.save_on_master(checkpoint, os.path.join(args.output_dir, \"checkpoint.pth\"))\n",
"\n",
" total_time = time.time() - start_time\n",
" total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n",
" print(f\"Training time {total_time_str}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SV8s5k49KwgS"
},
"source": [
"Prepare the [ImageNet](http://image-net.org/) dataset manually and place it in `/path/to/imagenet`. For image classification examples, pass the argument `--data-path=/path/to/imagenet` to the training script. The extracted dataset directory should follow this structure:\n",
"```setup\n",
"/path/to/imagenet/:\n",
" train/:\n",
" n01440764:\n",
" n01440764_18.JPEG ...\n",
" n01443537:\n",
" n01443537_2.JPEG ...\n",
" val/:\n",
" n01440764:\n",
" ILSVRC2012_val_00000293.JPEG ...\n",
" n01443537:\n",
" ILSVRC2012_val_00000236.JPEG ...\n",
"```\n",
"\n",
"Now you can apply the SPG algorithm in model retraining.\n",
"\n",
"**Implementation Note:**\n",
"\n",
"- This demonstration runs on Google Colab using a single GPU configuration\n",
"- Performance Improvement: Enhances ResNet18 validation accuracy (ACC@1) from 69.76% to 70.09%\n",
"- For optimal results:\n",
" - Refer to our README.md for complete setup instructions\n",
" - Recommended hardware: 4× RTX A6000 GPUs"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "UDZxDNfT4xUb",
"outputId": "bcf86aa0-eb77-4815-e0fa-05997f1e1f1b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"namespace(data_path='/home/cs/Documents/datasets/imagenet', model='resnet18', device='cuda', batch_size=512, epochs=6, lr=0.0004, momentum=0.9, weight_decay=0.0001, lr_warmup_epochs=1, lr_warmup_decay=0.0, lr_step_size=2, lr_gamma=0.5, print_freq=100, output_dir='resnet18', use_deterministic_algorithms=False, weights='ResNet18_Weights.IMAGENET1K_V1', apply_trp=True, trp_depths=[3, 3, 3], trp_planes=256, trp_lambdas=[0.4, 0.2, 0.1])\n",
"Loading data\n",
"Loading training data\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Took 1.9062905311584473\n",
"Loading validation data\n",
"Creating data loaders\n",
"Creating model\n",
"✅ Applying TRP to ResNet for Image Classification...\n",
"Start training\n",
"Epoch: [0] [ 0/2503] eta: 10:05:09 lr: 0.0 img/s: 81.93631887515438 loss: 0.7334 (0.7334) acc1: 71.2891 (71.2891) acc5: 86.1328 (86.1328) time: 14.5065 data: 8.2577 max mem: 19119\n",
"Epoch: [0] [ 100/2503] eta: 0:29:06 lr: 0.0 img/s: 862.8257862120394 loss: 0.7145 (0.7308) acc1: 69.5312 (69.6105) acc5: 87.6953 (87.3704) time: 0.5927 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [ 200/2503] eta: 0:25:23 lr: 0.0 img/s: 860.6862569301302 loss: 0.7355 (0.7353) acc1: 68.9453 (69.3427) acc5: 86.9141 (87.3125) time: 0.5966 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [ 300/2503] eta: 0:23:29 lr: 0.0 img/s: 860.0754340960929 loss: 0.7159 (0.7314) acc1: 69.1406 (69.3463) acc5: 87.5000 (87.3676) time: 0.5967 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [ 400/2503] eta: 0:22:03 lr: 0.0 img/s: 859.0790234707376 loss: 0.7594 (0.7361) acc1: 67.9688 (69.2283) acc5: 86.7188 (87.3232) time: 0.5960 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [ 500/2503] eta: 0:20:46 lr: 0.0 img/s: 859.7486624250741 loss: 0.7204 (0.7343) acc1: 69.7266 (69.2396) acc5: 87.5000 (87.3827) time: 0.5958 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [ 600/2503] eta: 0:19:36 lr: 0.0 img/s: 861.5204710456711 loss: 0.7483 (0.7345) acc1: 69.5312 (69.2449) acc5: 86.7188 (87.3950) time: 0.5958 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [ 700/2503] eta: 0:18:28 lr: 0.0 img/s: 858.9934592 loss: 0.7225 (0.7350) acc1: 68.5547 (69.2331) acc5: 87.6953 (87.3738) time: 0.5958 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [ 800/2503] eta: 0:17:23 lr: 0.0 img/s: 859.4995325247505 loss: 0.7639 (0.7355) acc1: 69.7266 (69.2177) acc5: 86.7188 (87.3578) time: 0.5961 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [ 900/2503] eta: 0:16:18 lr: 0.0 img/s: 860.8087326554238 loss: 0.7118 (0.7349) acc1: 69.9219 (69.2440) acc5: 87.6953 (87.3548) time: 0.5961 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1000/2503] eta: 0:15:15 lr: 0.0 img/s: 859.5858857924882 loss: 0.7224 (0.7351) acc1: 69.3359 (69.2485) acc5: 87.3047 (87.3624) time: 0.5958 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [1100/2503] eta: 0:14:12 lr: 0.0 img/s: 858.8670339725992 loss: 0.7240 (0.7360) acc1: 68.9453 (69.2212) acc5: 87.1094 (87.3361) time: 0.5958 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1200/2503] eta: 0:13:10 lr: 0.0 img/s: 861.4696676125856 loss: 0.7126 (0.7364) acc1: 68.3594 (69.1878) acc5: 87.3047 (87.3190) time: 0.5960 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1300/2503] eta: 0:12:09 lr: 0.0 img/s: 859.3643608581464 loss: 0.7291 (0.7367) acc1: 68.9453 (69.1669) acc5: 86.7188 (87.2990) time: 0.5959 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1400/2503] eta: 0:11:07 lr: 0.0 img/s: 861.1477063020853 loss: 0.7267 (0.7372) acc1: 69.9219 (69.1624) acc5: 87.1094 (87.2990) time: 0.5960 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1500/2503] eta: 0:10:06 lr: 0.0 img/s: 859.0494692253935 loss: 0.7234 (0.7374) acc1: 69.1406 (69.1607) acc5: 87.3047 (87.2939) time: 0.5959 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [1600/2503] eta: 0:09:05 lr: 0.0 img/s: 860.660386236062 loss: 0.7456 (0.7374) acc1: 69.3359 (69.1730) acc5: 87.5000 (87.3019) time: 0.5960 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1700/2503] eta: 0:08:04 lr: 0.0 img/s: 858.9515423647326 loss: 0.7548 (0.7372) acc1: 69.1406 (69.1773) acc5: 87.5000 (87.3198) time: 0.5959 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1800/2503] eta: 0:07:04 lr: 0.0 img/s: 860.6800478217115 loss: 0.7596 (0.7375) acc1: 67.1875 (69.1614) acc5: 87.1094 (87.3191) time: 0.5958 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1900/2503] eta: 0:06:03 lr: 0.0 img/s: 859.6578027499652 loss: 0.7465 (0.7375) acc1: 68.3594 (69.1633) acc5: 86.7188 (87.3222) time: 0.5959 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [2000/2503] eta: 0:05:03 lr: 0.0 img/s: 860.6507282423033 loss: 0.7385 (0.7375) acc1: 69.3359 (69.1609) acc5: 87.3047 (87.3233) time: 0.5959 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [2100/2503] eta: 0:04:02 lr: 0.0 img/s: 860.72592834858 loss: 0.7153 (0.7373) acc1: 69.3359 (69.1710) acc5: 87.3047 (87.3230) time: 0.5961 data: 0.0004 max mem: 19119\n",
"Epoch: [0] [2200/2503] eta: 0:03:02 lr: 0.0 img/s: 859.2460775467988 loss: 0.7307 (0.7371) acc1: 68.9453 (69.1861) acc5: 87.5000 (87.3380) time: 0.5960 data: 0.0004 max mem: 19119\n",
"Epoch: [0] [2300/2503] eta: 0:02:02 lr: 0.0 img/s: 859.2639554931892 loss: 0.7077 (0.7367) acc1: 69.3359 (69.1971) acc5: 87.6953 (87.3516) time: 0.5959 data: 0.0004 max mem: 19119\n",
"Epoch: [0] [2400/2503] eta: 0:01:01 lr: 0.0 img/s: 861.341130585524 loss: 0.7279 (0.7365) acc1: 68.5547 (69.1921) acc5: 86.9141 (87.3412) time: 0.5961 data: 0.0004 max mem: 19119\n",
"Epoch: [0] [2500/2503] eta: 0:00:01 lr: 0.0 img/s: 861.8382147793436 loss: 0.7469 (0.7368) acc1: 68.5547 (69.1894) acc5: 87.5000 (87.3423) time: 0.5955 data: 0.0005 max mem: 19119\n",
"Epoch: [0] Total time: 0:25:05\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cs/anaconda3/envs/csenv/lib/python3.9/site-packages/torch/optim/lr_scheduler.py:243: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.\n",
" warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test: [ 0/782] eta: 0:23:05 loss: 0.6283 (0.6283) acc1: 89.0625 (89.0625) acc5: 95.3125 (95.3125) time: 1.7719 data: 1.3111 max mem: 19119\n",
"Test: [100/782] eta: 0:00:30 loss: 1.0688 (0.9382) acc1: 76.5625 (76.2840) acc5: 89.0625 (92.1875) time: 0.0399 data: 0.0263 max mem: 19119\n",
"Test: [200/782] eta: 0:00:21 loss: 0.9244 (0.9143) acc1: 73.4375 (75.8240) acc5: 95.3125 (93.2369) time: 0.0244 data: 0.0107 max mem: 19119\n",
"Test: [300/782] eta: 0:00:17 loss: 0.8615 (0.9072) acc1: 76.5625 (76.1991) acc5: 92.1875 (93.5008) time: 0.0381 data: 0.0244 max mem: 19119\n",
"Test: [400/782] eta: 0:00:13 loss: 1.6977 (1.0440) acc1: 59.3750 (73.6323) acc5: 82.8125 (91.7472) time: 0.0313 data: 0.0176 max mem: 19119\n",
"Test: [500/782] eta: 0:00:09 loss: 1.6021 (1.1237) acc1: 54.6875 (72.0964) acc5: 85.9375 (90.5845) time: 0.0247 data: 0.0109 max mem: 19119\n",
"Test: [600/782] eta: 0:00:06 loss: 1.3631 (1.1858) acc1: 64.0625 (70.8741) acc5: 84.3750 (89.7853) time: 0.0291 data: 0.0153 max mem: 19119\n",
"Test: [700/782] eta: 0:00:02 loss: 1.2494 (1.2361) acc1: 68.7500 (69.9313) acc5: 87.5000 (89.1115) time: 0.0391 data: 0.0254 max mem: 19119\n",
"Test: Total time: 0:00:26\n",
"Test: Acc@1 69.846 Acc@5 89.136\n",
"Epoch: [1] [ 0/2503] eta: 4:27:27 lr: 0.0004 img/s: 861.3684242192573 loss: 0.7611 (0.7611) acc1: 68.5547 (68.5547) acc5: 86.1328 (86.1328) time: 6.4115 data: 5.8170 max mem: 19119\n",
"Epoch: [1] [ 100/2503] eta: 0:26:25 lr: 0.0004 img/s: 856.9149263546342 loss: 0.7538 (0.7542) acc1: 70.5078 (69.0536) acc5: 87.5000 (87.1364) time: 0.5982 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [ 200/2503] eta: 0:24:10 lr: 0.0004 img/s: 854.0172713632207 loss: 0.7744 (0.7573) acc1: 69.7266 (69.2990) acc5: 88.0859 (87.3785) time: 0.5998 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [ 300/2503] eta: 0:22:45 lr: 0.0004 img/s: 856.0483105922384 loss: 0.7551 (0.7613) acc1: 69.1406 (69.2834) acc5: 87.3047 (87.4611) time: 0.5996 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [ 400/2503] eta: 0:21:32 lr: 0.0004 img/s: 854.8386016604893 loss: 0.7931 (0.7645) acc1: 68.5547 (69.3004) acc5: 87.3047 (87.4698) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [ 500/2503] eta: 0:20:24 lr: 0.0004 img/s: 855.3431965742996 loss: 0.7744 (0.7684) acc1: 68.1641 (69.2853) acc5: 86.9141 (87.4361) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [ 600/2503] eta: 0:19:19 lr: 0.0004 img/s: 855.1112541063571 loss: 0.7860 (0.7730) acc1: 69.1406 (69.2310) acc5: 86.7188 (87.3941) time: 0.5988 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [ 700/2503] eta: 0:18:15 lr: 0.0004 img/s: 856.4904515045232 loss: 0.7908 (0.7773) acc1: 68.7500 (69.1746) acc5: 86.7188 (87.3543) time: 0.5985 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [ 800/2503] eta: 0:17:13 lr: 0.0004 img/s: 858.0146361031335 loss: 0.8157 (0.7805) acc1: 68.5547 (69.1660) acc5: 87.6953 (87.3181) time: 0.5991 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [ 900/2503] eta: 0:16:11 lr: 0.0004 img/s: 854.9138104963116 loss: 0.7641 (0.7825) acc1: 69.5312 (69.1807) acc5: 88.8672 (87.3346) time: 0.5989 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [1000/2503] eta: 0:15:09 lr: 0.0004 img/s: 855.7491430488731 loss: 0.8024 (0.7852) acc1: 68.1641 (69.1506) acc5: 86.5234 (87.3234) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [1100/2503] eta: 0:14:08 lr: 0.0004 img/s: 856.0848253972304 loss: 0.8099 (0.7872) acc1: 69.1406 (69.1564) acc5: 86.7188 (87.3231) time: 0.5992 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [1200/2503] eta: 0:13:07 lr: 0.0004 img/s: 855.6028761225017 loss: 0.8307 (0.7894) acc1: 68.5547 (69.1569) acc5: 87.1094 (87.3258) time: 0.5989 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [1300/2503] eta: 0:12:06 lr: 0.0004 img/s: 855.8589613399885 loss: 0.8206 (0.7913) acc1: 68.9453 (69.1304) acc5: 87.3047 (87.3177) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [1400/2503] eta: 0:11:05 lr: 0.0004 img/s: 856.6045604019511 loss: 0.8454 (0.7936) acc1: 68.1641 (69.1019) acc5: 86.9141 (87.2906) time: 0.5989 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [1500/2503] eta: 0:10:04 lr: 0.0004 img/s: 854.944442321167 loss: 0.8428 (0.7960) acc1: 68.1641 (69.0905) acc5: 87.3047 (87.2706) time: 0.5990 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [1600/2503] eta: 0:09:04 lr: 0.0004 img/s: 855.0727794914757 loss: 0.7906 (0.7974) acc1: 69.5312 (69.0922) acc5: 87.1094 (87.2686) time: 0.5990 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [1700/2503] eta: 0:08:03 lr: 0.0004 img/s: 855.4958499669949 loss: 0.8199 (0.7989) acc1: 69.7266 (69.0854) acc5: 87.1094 (87.2704) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [1800/2503] eta: 0:07:03 lr: 0.0004 img/s: 855.0251166287029 loss: 0.8257 (0.8007) acc1: 70.1172 (69.0869) acc5: 87.5000 (87.2656) time: 0.5988 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [1900/2503] eta: 0:06:03 lr: 0.0004 img/s: 856.9867390518363 loss: 0.7952 (0.8018) acc1: 68.7500 (69.0943) acc5: 87.3047 (87.2670) time: 0.5989 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [2000/2503] eta: 0:05:02 lr: 0.0004 img/s: 854.3927252530574 loss: 0.8402 (0.8032) acc1: 68.5547 (69.0964) acc5: 87.1094 (87.2747) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [2100/2503] eta: 0:04:02 lr: 0.0004 img/s: 855.2427067851231 loss: 0.8451 (0.8042) acc1: 68.3594 (69.1089) acc5: 87.3047 (87.2816) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [2200/2503] eta: 0:03:02 lr: 0.0004 img/s: 853.8318747507567 loss: 0.8314 (0.8058) acc1: 68.9453 (69.1012) acc5: 87.3047 (87.2716) time: 0.5989 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [2300/2503] eta: 0:02:02 lr: 0.0004 img/s: 855.3312728222841 loss: 0.8350 (0.8070) acc1: 68.7500 (69.0993) acc5: 86.3281 (87.2549) time: 0.5988 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [2400/2503] eta: 0:01:01 lr: 0.0004 img/s: 855.1613103361218 loss: 0.8206 (0.8084) acc1: 68.5547 (69.0927) acc5: 86.9141 (87.2551) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [2500/2503] eta: 0:00:01 lr: 0.0004 img/s: 856.7190414886559 loss: 0.8286 (0.8094) acc1: 69.1406 (69.1000) acc5: 87.1094 (87.2599) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [1] Total time: 0:25:05\n",
"Test: [ 0/782] eta: 0:16:19 loss: 0.5636 (0.5636) acc1: 87.5000 (87.5000) acc5: 96.8750 (96.8750) time: 1.2525 data: 1.2385 max mem: 19119\n",
"Test: [100/782] eta: 0:00:31 loss: 1.0393 (0.9414) acc1: 76.5625 (76.9647) acc5: 90.6250 (92.2958) time: 0.0417 data: 0.0280 max mem: 19119\n",
"Test: [200/782] eta: 0:00:22 loss: 0.8964 (0.9176) acc1: 73.4375 (76.4614) acc5: 95.3125 (93.3147) time: 0.0249 data: 0.0112 max mem: 19119\n",
"Test: [300/782] eta: 0:00:17 loss: 0.7984 (0.9094) acc1: 79.6875 (76.7130) acc5: 92.1875 (93.6150) time: 0.0311 data: 0.0173 max mem: 19119\n",
"Test: [400/782] eta: 0:00:13 loss: 1.7745 (1.0483) acc1: 57.8125 (73.9635) acc5: 84.3750 (91.8758) time: 0.0328 data: 0.0190 max mem: 19119\n",
"Test: [500/782] eta: 0:00:09 loss: 1.6435 (1.1264) acc1: 59.3750 (72.4239) acc5: 84.3750 (90.7934) time: 0.0328 data: 0.0190 max mem: 19119\n",
"Test: [600/782] eta: 0:00:06 loss: 1.3057 (1.1915) acc1: 62.5000 (71.0483) acc5: 85.9375 (90.0010) time: 0.0400 data: 0.0261 max mem: 19119\n",
"Test: [700/782] eta: 0:00:02 loss: 1.2212 (1.2428) acc1: 70.3125 (70.0985) acc5: 87.5000 (89.3010) time: 0.0253 data: 0.0115 max mem: 19119\n",
"Test: Total time: 0:00:26\n",
"Test: Acc@1 70.000 Acc@5 89.320\n",
"Epoch: [2] [ 0/2503] eta: 4:06:15 lr: 0.0004 img/s: 867.4756359685 loss: 0.8414 (0.8414) acc1: 67.9688 (67.9688) acc5: 86.1328 (86.1328) time: 5.9030 data: 5.3128 max mem: 19119\n",
"Epoch: [2] [ 100/2503] eta: 0:25:53 lr: 0.0004 img/s: 859.1872918530421 loss: 0.8472 (0.8456) acc1: 68.9453 (69.1194) acc5: 86.7188 (87.2892) time: 0.5963 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [ 200/2503] eta: 0:23:52 lr: 0.0004 img/s: 857.5945509684602 loss: 0.8563 (0.8443) acc1: 68.1641 (69.1649) acc5: 86.3281 (87.1852) time: 0.5972 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [ 300/2503] eta: 0:22:31 lr: 0.0004 img/s: 859.3399450658505 loss: 0.8386 (0.8440) acc1: 69.1406 (69.0790) acc5: 87.3047 (87.1762) time: 0.5967 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [ 400/2503] eta: 0:21:21 lr: 0.0004 img/s: 859.825426282471 loss: 0.8455 (0.8446) acc1: 69.3359 (69.0422) acc5: 87.3047 (87.1591) time: 0.5963 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [ 500/2503] eta: 0:20:15 lr: 0.0004 img/s: 858.6002202991031 loss: 0.8400 (0.8448) acc1: 67.9688 (69.0373) acc5: 87.5000 (87.1640) time: 0.5962 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [ 600/2503] eta: 0:19:11 lr: 0.0004 img/s: 859.6997885462696 loss: 0.8544 (0.8467) acc1: 68.3594 (69.0217) acc5: 86.9141 (87.1347) time: 0.5963 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [ 700/2503] eta: 0:18:08 lr: 0.0004 img/s: 858.0379481409502 loss: 0.8386 (0.8466) acc1: 68.7500 (69.0314) acc5: 87.1094 (87.1459) time: 0.5966 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [ 800/2503] eta: 0:17:06 lr: 0.0004 img/s: 859.3574830298114 loss: 0.8607 (0.8472) acc1: 69.5312 (69.0338) acc5: 87.1094 (87.1364) time: 0.5965 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [ 900/2503] eta: 0:16:05 lr: 0.0004 img/s: 858.1785328651912 loss: 0.8502 (0.8474) acc1: 68.5547 (69.0273) acc5: 87.1094 (87.1404) time: 0.5966 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [1000/2503] eta: 0:15:04 lr: 0.0004 img/s: 858.6554923981921 loss: 0.8213 (0.8468) acc1: 70.1172 (69.0737) acc5: 87.6953 (87.1601) time: 0.5966 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [1100/2503] eta: 0:14:03 lr: 0.0004 img/s: 858.942266240932 loss: 0.8322 (0.8466) acc1: 68.9453 (69.0824) acc5: 87.5000 (87.1826) time: 0.5965 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [1200/2503] eta: 0:13:02 lr: 0.0004 img/s: 859.9796839014982 loss: 0.8353 (0.8468) acc1: 68.5547 (69.0858) acc5: 87.1094 (87.1886) time: 0.5962 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [1300/2503] eta: 0:12:02 lr: 0.0004 img/s: 858.8996673563247 loss: 0.8654 (0.8471) acc1: 68.3594 (69.0645) acc5: 86.7188 (87.1876) time: 0.5966 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [1400/2503] eta: 0:11:01 lr: 0.0004 img/s: 859.7879032225777 loss: 0.8277 (0.8466) acc1: 70.1172 (69.0861) acc5: 88.6719 (87.2244) time: 0.5963 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [1500/2503] eta: 0:10:01 lr: 0.0004 img/s: 859.6271763544084 loss: 0.8703 (0.8471) acc1: 68.7500 (69.0763) acc5: 86.5234 (87.2101) time: 0.5962 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [1600/2503] eta: 0:09:01 lr: 0.0004 img/s: 859.8206066235718 loss: 0.8818 (0.8481) acc1: 68.5547 (69.0523) acc5: 86.7188 (87.1975) time: 0.5965 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [1700/2503] eta: 0:08:01 lr: 0.0004 img/s: 858.3188206070029 loss: 0.8447 (0.8487) acc1: 69.3359 (69.0259) acc5: 87.3047 (87.1857) time: 0.5965 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [1800/2503] eta: 0:07:01 lr: 0.0004 img/s: 860.4651988761577 loss: 0.8492 (0.8489) acc1: 67.5781 (69.0295) acc5: 86.7188 (87.1785) time: 0.5965 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [1900/2503] eta: 0:06:01 lr: 0.0004 img/s: 858.5823699612625 loss: 0.8559 (0.8491) acc1: 67.9688 (69.0192) acc5: 87.5000 (87.1776) time: 0.5963 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [2000/2503] eta: 0:05:01 lr: 0.0004 img/s: 858.4468002775832 loss: 0.8712 (0.8498) acc1: 68.9453 (69.0207) acc5: 87.6953 (87.1836) time: 0.5963 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [2100/2503] eta: 0:04:01 lr: 0.0004 img/s: 858.6208177650899 loss: 0.8782 (0.8507) acc1: 68.3594 (69.0053) acc5: 85.9375 (87.1818) time: 0.5964 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [2200/2503] eta: 0:03:01 lr: 0.0004 img/s: 858.769492116456 loss: 0.8845 (0.8514) acc1: 68.3594 (68.9953) acc5: 86.5234 (87.1744) time: 0.5963 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [2300/2503] eta: 0:02:01 lr: 0.0004 img/s: 860.1050589782235 loss: 0.8664 (0.8522) acc1: 68.7500 (68.9914) acc5: 87.5000 (87.1735) time: 0.5962 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [2400/2503] eta: 0:01:01 lr: 0.0004 img/s: 859.3210323775423 loss: 0.8824 (0.8529) acc1: 67.7734 (68.9772) acc5: 86.5234 (87.1693) time: 0.5963 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [2500/2503] eta: 0:00:01 lr: 0.0004 img/s: 860.1956686665284 loss: 0.8302 (0.8531) acc1: 69.9219 (68.9880) acc5: 87.5000 (87.1751) time: 0.5962 data: 0.0002 max mem: 19119\n",
"Epoch: [2] Total time: 0:24:57\n",
"Test: [ 0/782] eta: 0:14:49 loss: 0.6400 (0.6400) acc1: 82.8125 (82.8125) acc5: 93.7500 (93.7500) time: 1.1370 data: 1.1232 max mem: 19119\n",
"Test: [100/782] eta: 0:00:28 loss: 1.0691 (0.9495) acc1: 75.0000 (76.8874) acc5: 89.0625 (92.2184) time: 0.0422 data: 0.0284 max mem: 19119\n",
"Test: [200/782] eta: 0:00:20 loss: 0.8384 (0.9253) acc1: 75.0000 (76.3293) acc5: 95.3125 (93.2292) time: 0.0298 data: 0.0161 max mem: 19119\n",
"Test: [300/782] eta: 0:00:16 loss: 0.8140 (0.9153) acc1: 78.1250 (76.6092) acc5: 92.1875 (93.5631) time: 0.0281 data: 0.0143 max mem: 19119\n",
"Test: [400/782] eta: 0:00:12 loss: 1.7029 (1.0528) acc1: 62.5000 (73.9479) acc5: 84.3750 (91.8797) time: 0.0260 data: 0.0123 max mem: 19119\n",
"Test: [500/782] eta: 0:00:09 loss: 1.7149 (1.1295) acc1: 59.3750 (72.4894) acc5: 84.3750 (90.7997) time: 0.0315 data: 0.0177 max mem: 19119\n",
"Test: [600/782] eta: 0:00:06 loss: 1.3215 (1.1949) acc1: 65.6250 (71.1288) acc5: 85.9375 (90.0192) time: 0.0343 data: 0.0204 max mem: 19119\n",
"Test: [700/782] eta: 0:00:02 loss: 1.3000 (1.2468) acc1: 70.3125 (70.1386) acc5: 85.9375 (89.2809) time: 0.0246 data: 0.0108 max mem: 19119\n",
"Test: Total time: 0:00:25\n",
"Test: Acc@1 70.034 Acc@5 89.306\n",
"Epoch: [3] [ 0/2503] eta: 3:48:40 lr: 0.0002 img/s: 868.6651772838787 loss: 0.9922 (0.9922) acc1: 65.8203 (65.8203) acc5: 84.3750 (84.3750) time: 5.4818 data: 4.8924 max mem: 19119\n",
"Epoch: [3] [ 100/2503] eta: 0:25:56 lr: 0.0002 img/s: 857.1146638568258 loss: 0.8599 (0.8484) acc1: 69.7266 (69.1851) acc5: 86.7188 (87.2660) time: 0.5978 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [ 200/2503] eta: 0:23:56 lr: 0.0002 img/s: 854.6256384868216 loss: 0.8801 (0.8570) acc1: 68.7500 (69.0182) acc5: 86.3281 (87.1521) time: 0.5998 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [ 300/2503] eta: 0:22:36 lr: 0.0002 img/s: 855.2042203405152 loss: 0.8260 (0.8538) acc1: 69.3359 (69.0959) acc5: 87.6953 (87.2262) time: 0.5990 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [ 400/2503] eta: 0:21:25 lr: 0.0002 img/s: 856.2763231713128 loss: 0.8881 (0.8553) acc1: 68.3594 (69.1733) acc5: 86.9141 (87.2141) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [3] [ 500/2503] eta: 0:20:19 lr: 0.0002 img/s: 856.0431919432866 loss: 0.8596 (0.8573) acc1: 68.3594 (69.1281) acc5: 87.3047 (87.2291) time: 0.5982 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [ 600/2503] eta: 0:19:15 lr: 0.0002 img/s: 855.2682527981125 loss: 0.8779 (0.8592) acc1: 68.1641 (69.1153) acc5: 86.7188 (87.2007) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [3] [ 700/2503] eta: 0:18:12 lr: 0.0002 img/s: 855.3551206587658 loss: 0.8727 (0.8601) acc1: 68.7500 (69.0902) acc5: 87.8906 (87.2033) time: 0.5988 data: 0.0003 max mem: 19119\n",
"Epoch: [3] [ 800/2503] eta: 0:17:10 lr: 0.0002 img/s: 854.5804059835012 loss: 0.8775 (0.8608) acc1: 69.1406 (69.0684) acc5: 86.5234 (87.1713) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [ 900/2503] eta: 0:16:08 lr: 0.0002 img/s: 855.525160200547 loss: 0.8299 (0.8601) acc1: 69.5312 (69.0866) acc5: 87.5000 (87.1883) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [1000/2503] eta: 0:15:07 lr: 0.0002 img/s: 855.1732293498484 loss: 0.8740 (0.8600) acc1: 68.3594 (69.0608) acc5: 86.7188 (87.1773) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [1100/2503] eta: 0:14:06 lr: 0.0002 img/s: 855.7201584498974 loss: 0.8490 (0.8600) acc1: 69.5312 (69.0574) acc5: 87.6953 (87.1810) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [3] [1200/2503] eta: 0:13:05 lr: 0.0002 img/s: 855.6761738093758 loss: 0.8551 (0.8598) acc1: 70.1172 (69.0749) acc5: 87.3047 (87.1956) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [1300/2503] eta: 0:12:04 lr: 0.0002 img/s: 855.1391759063549 loss: 0.8736 (0.8596) acc1: 68.9453 (69.1111) acc5: 87.5000 (87.2038) time: 0.5985 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [1400/2503] eta: 0:11:04 lr: 0.0002 img/s: 855.9558431039926 loss: 0.8849 (0.8602) acc1: 69.1406 (69.1073) acc5: 86.3281 (87.2021) time: 0.5989 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [1500/2503] eta: 0:10:03 lr: 0.0002 img/s: 856.2879318505251 loss: 0.8493 (0.8600) acc1: 69.7266 (69.1198) acc5: 87.3047 (87.2114) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [3] [1600/2503] eta: 0:09:03 lr: 0.0002 img/s: 855.3885098640291 loss: 0.8944 (0.8605) acc1: 67.9688 (69.1188) acc5: 86.5234 (87.2106) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [3] [1700/2503] eta: 0:08:03 lr: 0.0002 img/s: 855.653671788276 loss: 0.8327 (0.8606) acc1: 69.7266 (69.1132) acc5: 87.3047 (87.1988) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [1800/2503] eta: 0:07:02 lr: 0.0002 img/s: 854.6603313202065 loss: 0.8716 (0.8606) acc1: 69.5312 (69.1106) acc5: 87.3047 (87.2095) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [1900/2503] eta: 0:06:02 lr: 0.0002 img/s: 855.8654421819135 loss: 0.8433 (0.8607) acc1: 68.9453 (69.1149) acc5: 86.9141 (87.1973) time: 0.5985 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [2000/2503] eta: 0:05:02 lr: 0.0002 img/s: 858.0228637365412 loss: 0.8635 (0.8613) acc1: 69.1406 (69.0981) acc5: 86.7188 (87.1936) time: 0.5985 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [2100/2503] eta: 0:04:02 lr: 0.0002 img/s: 855.1837864680422 loss: 0.8389 (0.8614) acc1: 69.7266 (69.1067) acc5: 87.3047 (87.2038) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [2200/2503] eta: 0:03:02 lr: 0.0002 img/s: 854.9267436657309 loss: 0.8588 (0.8618) acc1: 69.5312 (69.1018) acc5: 86.9141 (87.2006) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [2300/2503] eta: 0:02:01 lr: 0.0002 img/s: 857.5592770650364 loss: 0.8385 (0.8623) acc1: 69.7266 (69.1041) acc5: 87.6953 (87.1965) time: 0.5985 data: 0.0003 max mem: 19119\n",
"Epoch: [3] [2400/2503] eta: 0:01:01 lr: 0.0002 img/s: 854.3804880688189 loss: 0.8534 (0.8625) acc1: 68.9453 (69.1074) acc5: 87.3047 (87.1914) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [2500/2503] eta: 0:00:01 lr: 0.0002 img/s: 855.3973686621443 loss: 0.8348 (0.8625) acc1: 69.3359 (69.1134) acc5: 87.6953 (87.1933) time: 0.5984 data: 0.0002 max mem: 19119\n",
"Epoch: [3] Total time: 0:25:03\n",
"Test: [ 0/782] eta: 0:13:34 loss: 0.6298 (0.6298) acc1: 84.3750 (84.3750) acc5: 95.3125 (95.3125) time: 1.0412 data: 1.0273 max mem: 19119\n",
"Test: [100/782] eta: 0:00:28 loss: 1.0908 (0.9514) acc1: 75.0000 (76.9957) acc5: 89.0625 (92.2339) time: 0.0397 data: 0.0260 max mem: 19119\n",
"Test: [200/782] eta: 0:00:21 loss: 0.9058 (0.9231) acc1: 73.4375 (76.3137) acc5: 95.3125 (93.2680) time: 0.0258 data: 0.0121 max mem: 19119\n",
"Test: [300/782] eta: 0:00:17 loss: 0.8269 (0.9143) acc1: 79.6875 (76.5988) acc5: 92.1875 (93.5735) time: 0.0356 data: 0.0218 max mem: 19119\n",
"Test: [400/782] eta: 0:00:13 loss: 1.8047 (1.0535) acc1: 60.9375 (73.9207) acc5: 82.8125 (91.8329) time: 0.0270 data: 0.0133 max mem: 19119\n",
"Test: [500/782] eta: 0:00:09 loss: 1.6839 (1.1303) acc1: 59.3750 (72.5050) acc5: 85.9375 (90.7622) time: 0.0334 data: 0.0196 max mem: 19119\n",
"Test: [600/782] eta: 0:00:06 loss: 1.3633 (1.1951) acc1: 64.0625 (71.1704) acc5: 87.5000 (89.9776) time: 0.0256 data: 0.0118 max mem: 19119\n",
"Test: [700/782] eta: 0:00:02 loss: 1.2720 (1.2480) acc1: 71.8750 (70.1632) acc5: 85.9375 (89.2876) time: 0.0280 data: 0.0142 max mem: 19119\n",
"Test: Total time: 0:00:26\n",
"Test: Acc@1 70.092 Acc@5 89.308\n",
"Epoch: [4] [ 0/2503] eta: 3:53:04 lr: 0.0002 img/s: 868.854963030395 loss: 0.9245 (0.9245) acc1: 67.5781 (67.5781) acc5: 87.3047 (87.3047) time: 5.5871 data: 4.9977 max mem: 19119\n",
"Epoch: [4] [ 100/2503] eta: 0:25:58 lr: 0.0002 img/s: 856.1991675966316 loss: 0.8765 (0.8657) acc1: 68.9453 (69.2605) acc5: 86.5234 (87.1171) time: 0.5978 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [ 200/2503] eta: 0:23:57 lr: 0.0002 img/s: 853.9622551773969 loss: 0.8653 (0.8621) acc1: 69.7266 (69.4321) acc5: 87.3047 (87.2464) time: 0.5998 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [ 300/2503] eta: 0:22:37 lr: 0.0002 img/s: 855.0397553711638 loss: 0.8689 (0.8659) acc1: 69.3359 (69.2879) acc5: 87.3047 (87.2275) time: 0.5990 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [ 400/2503] eta: 0:21:26 lr: 0.0002 img/s: 857.3206551376953 loss: 0.8390 (0.8645) acc1: 68.7500 (69.3432) acc5: 86.5234 (87.2808) time: 0.5983 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [ 500/2503] eta: 0:20:19 lr: 0.0002 img/s: 856.0698095863008 loss: 0.8473 (0.8638) acc1: 69.5312 (69.3262) acc5: 87.5000 (87.3148) time: 0.5984 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [ 600/2503] eta: 0:19:15 lr: 0.0002 img/s: 856.6852065793925 loss: 0.8467 (0.8628) acc1: 69.3359 (69.3753) acc5: 87.3047 (87.3388) time: 0.5985 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [ 700/2503] eta: 0:18:12 lr: 0.0002 img/s: 856.1111043339286 loss: 0.8966 (0.8657) acc1: 68.1641 (69.3106) acc5: 86.7188 (87.2924) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [ 800/2503] eta: 0:17:10 lr: 0.0002 img/s: 856.0469456131993 loss: 0.8508 (0.8662) acc1: 69.3359 (69.2959) acc5: 87.6953 (87.2976) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [ 900/2503] eta: 0:16:08 lr: 0.0002 img/s: 854.9645243753337 loss: 0.8663 (0.8662) acc1: 70.3125 (69.2997) acc5: 87.6953 (87.2860) time: 0.5986 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [1000/2503] eta: 0:15:07 lr: 0.0002 img/s: 855.0186485068215 loss: 0.8554 (0.8671) acc1: 68.3594 (69.2827) acc5: 86.9141 (87.2858) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [1100/2503] eta: 0:14:06 lr: 0.0002 img/s: 855.2873281496619 loss: 0.8594 (0.8670) acc1: 69.7266 (69.2912) acc5: 87.5000 (87.2926) time: 0.5985 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [1200/2503] eta: 0:13:05 lr: 0.0002 img/s: 856.2193086400064 loss: 0.8768 (0.8677) acc1: 68.7500 (69.2626) acc5: 86.7188 (87.2596) time: 0.5982 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [1300/2503] eta: 0:12:04 lr: 0.0002 img/s: 856.0534293024039 loss: 0.8824 (0.8688) acc1: 68.7500 (69.2354) acc5: 87.1094 (87.2479) time: 0.5985 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [1400/2503] eta: 0:11:04 lr: 0.0002 img/s: 857.5664685962558 loss: 0.8628 (0.8691) acc1: 68.7500 (69.2155) acc5: 87.5000 (87.2509) time: 0.5984 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [1500/2503] eta: 0:10:03 lr: 0.0002 img/s: 854.9161928928554 loss: 0.8838 (0.8696) acc1: 69.1406 (69.1992) acc5: 87.5000 (87.2396) time: 0.5983 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [1600/2503] eta: 0:09:03 lr: 0.0002 img/s: 858.2563886134587 loss: 0.8541 (0.8697) acc1: 68.5547 (69.1823) acc5: 87.3047 (87.2258) time: 0.5985 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [1700/2503] eta: 0:08:03 lr: 0.0002 img/s: 855.3152614496619 loss: 0.8702 (0.8694) acc1: 68.9453 (69.1954) acc5: 86.9141 (87.2312) time: 0.5988 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [1800/2503] eta: 0:07:02 lr: 0.0002 img/s: 856.2275018779335 loss: 0.8691 (0.8696) acc1: 68.9453 (69.1939) acc5: 87.1094 (87.2242) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [1900/2503] eta: 0:06:02 lr: 0.0002 img/s: 854.6021714024656 loss: 0.8867 (0.8703) acc1: 68.3594 (69.1812) acc5: 86.5234 (87.2186) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [2000/2503] eta: 0:05:02 lr: 0.0002 img/s: 856.027153906284 loss: 0.8680 (0.8710) acc1: 67.9688 (69.1654) acc5: 87.1094 (87.2148) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [2100/2503] eta: 0:04:02 lr: 0.0002 img/s: 854.944442321167 loss: 0.8930 (0.8715) acc1: 68.9453 (69.1539) acc5: 87.1094 (87.2134) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [2200/2503] eta: 0:03:02 lr: 0.0002 img/s: 857.4733304797255 loss: 0.8344 (0.8713) acc1: 69.7266 (69.1614) acc5: 87.6953 (87.2185) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [2300/2503] eta: 0:02:01 lr: 0.0002 img/s: 854.5545609884018 loss: 0.8644 (0.8712) acc1: 69.1406 (69.1611) acc5: 86.9141 (87.2154) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [2400/2503] eta: 0:01:01 lr: 0.0002 img/s: 855.9186570165338 loss: 0.8843 (0.8714) acc1: 69.1406 (69.1609) acc5: 86.9141 (87.2155) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [2500/2503] eta: 0:00:01 lr: 0.0002 img/s: 855.5220927564314 loss: 0.8757 (0.8719) acc1: 68.9453 (69.1569) acc5: 87.1094 (87.2136) time: 0.5979 data: 0.0002 max mem: 19119\n",
"Epoch: [4] Total time: 0:25:03\n",
"Test: [ 0/782] eta: 0:15:24 loss: 0.5899 (0.5899) acc1: 85.9375 (85.9375) acc5: 95.3125 (95.3125) time: 1.1827 data: 1.1689 max mem: 19119\n",
"Test: [100/782] eta: 0:00:27 loss: 1.0514 (0.9553) acc1: 76.5625 (76.7327) acc5: 89.0625 (92.1411) time: 0.0322 data: 0.0185 max mem: 19119\n",
"Test: [200/782] eta: 0:00:20 loss: 0.8755 (0.9239) acc1: 75.0000 (76.3682) acc5: 95.3125 (93.2369) time: 0.0254 data: 0.0116 max mem: 19119\n",
"Test: [300/782] eta: 0:00:16 loss: 0.7986 (0.9160) acc1: 78.1250 (76.6352) acc5: 92.1875 (93.5424) time: 0.0298 data: 0.0161 max mem: 19119\n",
"Test: [400/782] eta: 0:00:12 loss: 1.7921 (1.0555) acc1: 60.9375 (73.8817) acc5: 84.3750 (91.8095) time: 0.0308 data: 0.0171 max mem: 19119\n",
"Test: [500/782] eta: 0:00:09 loss: 1.7681 (1.1332) acc1: 59.3750 (72.4613) acc5: 84.3750 (90.7248) time: 0.0302 data: 0.0164 max mem: 19119\n",
"Test: [600/782] eta: 0:00:06 loss: 1.3149 (1.1978) acc1: 65.6250 (71.1340) acc5: 85.9375 (89.9880) time: 0.0445 data: 0.0307 max mem: 19119\n",
"Test: [700/782] eta: 0:00:02 loss: 1.2842 (1.2500) acc1: 70.3125 (70.1297) acc5: 87.5000 (89.2899) time: 0.0292 data: 0.0154 max mem: 19119\n",
"Test: Total time: 0:00:25\n",
"Test: Acc@1 70.056 Acc@5 89.296\n",
"Epoch: [5] [ 0/2503] eta: 3:36:03 lr: 0.0001 img/s: 868.6398787818482 loss: 0.9304 (0.9304) acc1: 68.1641 (68.1641) acc5: 84.5703 (84.5703) time: 5.1790 data: 4.5895 max mem: 19119\n",
"Epoch: [5] [ 100/2503] eta: 0:26:17 lr: 0.0001 img/s: 856.2223810858537 loss: 0.8658 (0.8735) acc1: 68.7500 (69.3843) acc5: 86.5234 (87.1229) time: 0.5978 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [ 200/2503] eta: 0:24:06 lr: 0.0001 img/s: 854.6059124455048 loss: 0.8757 (0.8726) acc1: 68.3594 (69.2893) acc5: 87.3047 (87.2027) time: 0.5998 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [ 300/2503] eta: 0:22:42 lr: 0.0001 img/s: 854.7457148952566 loss: 0.9211 (0.8775) acc1: 68.3594 (69.1964) acc5: 86.5234 (87.1548) time: 0.5994 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [ 400/2503] eta: 0:21:30 lr: 0.0001 img/s: 855.9967856527867 loss: 0.8790 (0.8763) acc1: 69.5312 (69.2220) acc5: 87.1094 (87.1795) time: 0.5981 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [ 500/2503] eta: 0:20:22 lr: 0.0001 img/s: 856.1715178770771 loss: 0.8828 (0.8767) acc1: 68.5547 (69.1835) acc5: 87.1094 (87.1916) time: 0.5981 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [ 600/2503] eta: 0:19:17 lr: 0.0001 img/s: 857.1064536316845 loss: 0.8499 (0.8745) acc1: 70.1172 (69.2495) acc5: 87.6953 (87.2351) time: 0.5985 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [ 700/2503] eta: 0:18:14 lr: 0.0001 img/s: 855.5684475906053 loss: 0.8724 (0.8738) acc1: 68.7500 (69.2543) acc5: 87.3047 (87.2381) time: 0.5986 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [ 800/2503] eta: 0:17:11 lr: 0.0001 img/s: 855.3483068555046 loss: 0.8648 (0.8747) acc1: 69.3359 (69.2645) acc5: 87.1094 (87.2354) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [ 900/2503] eta: 0:16:10 lr: 0.0001 img/s: 855.0516709967052 loss: 0.8869 (0.8744) acc1: 68.9453 (69.2787) acc5: 87.1094 (87.2609) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [1000/2503] eta: 0:15:08 lr: 0.0001 img/s: 856.4016452607827 loss: 0.8660 (0.8746) acc1: 69.5312 (69.2706) acc5: 87.3047 (87.2602) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [1100/2503] eta: 0:14:07 lr: 0.0001 img/s: 855.5752649016649 loss: 0.8801 (0.8741) acc1: 69.3359 (69.2911) acc5: 86.7188 (87.2653) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [1200/2503] eta: 0:13:06 lr: 0.0001 img/s: 857.7295089955442 loss: 0.9027 (0.8741) acc1: 68.1641 (69.2950) acc5: 86.7188 (87.2543) time: 0.5983 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [1300/2503] eta: 0:12:05 lr: 0.0001 img/s: 857.4562117036577 loss: 0.8591 (0.8742) acc1: 69.3359 (69.2867) acc5: 87.3047 (87.2523) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [1400/2503] eta: 0:11:04 lr: 0.0001 img/s: 856.1711765336744 loss: 0.8659 (0.8747) acc1: 69.7266 (69.2756) acc5: 87.1094 (87.2657) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [1500/2503] eta: 0:10:04 lr: 0.0001 img/s: 855.6918577361109 loss: 0.8848 (0.8744) acc1: 67.5781 (69.2883) acc5: 86.9141 (87.2744) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [1600/2503] eta: 0:09:03 lr: 0.0001 img/s: 856.2199914038447 loss: 0.8476 (0.8740) acc1: 69.1406 (69.2962) acc5: 87.5000 (87.2808) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [1700/2503] eta: 0:08:03 lr: 0.0001 img/s: 856.7518536501099 loss: 0.8676 (0.8745) acc1: 68.3594 (69.2823) acc5: 87.6953 (87.2798) time: 0.5986 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [1800/2503] eta: 0:07:03 lr: 0.0001 img/s: 855.5047109885885 loss: 0.8884 (0.8745) acc1: 68.1641 (69.2810) acc5: 87.3047 (87.2716) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [1900/2503] eta: 0:06:02 lr: 0.0001 img/s: 855.3728370553123 loss: 0.8565 (0.8744) acc1: 68.9453 (69.2820) acc5: 88.2812 (87.2709) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [2000/2503] eta: 0:05:02 lr: 0.0001 img/s: 856.9477535227904 loss: 0.8799 (0.8749) acc1: 68.7500 (69.2718) acc5: 86.7188 (87.2586) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [2100/2503] eta: 0:04:02 lr: 0.0001 img/s: 855.2914158356967 loss: 0.8914 (0.8745) acc1: 68.9453 (69.2829) acc5: 87.5000 (87.2690) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [2200/2503] eta: 0:03:02 lr: 0.0001 img/s: 855.3816955287992 loss: 0.8659 (0.8744) acc1: 68.9453 (69.2853) acc5: 87.1094 (87.2749) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [2300/2503] eta: 0:02:02 lr: 0.0001 img/s: 857.8035141651501 loss: 0.8525 (0.8745) acc1: 69.7266 (69.2810) acc5: 87.5000 (87.2740) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [2400/2503] eta: 0:01:01 lr: 0.0001 img/s: 855.2345323832691 loss: 0.8424 (0.8745) acc1: 69.5312 (69.2709) acc5: 87.3047 (87.2737) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [2500/2503] eta: 0:00:01 lr: 0.0001 img/s: 856.8721864153102 loss: 0.8785 (0.8748) acc1: 68.7500 (69.2651) acc5: 87.5000 (87.2714) time: 0.5982 data: 0.0002 max mem: 19119\n",
"Epoch: [5] Total time: 0:25:04\n",
"Test: [ 0/782] eta: 0:16:48 loss: 0.6137 (0.6137) acc1: 85.9375 (85.9375) acc5: 95.3125 (95.3125) time: 1.2890 data: 1.2749 max mem: 19119\n",
"Test: [100/782] eta: 0:00:29 loss: 1.0820 (0.9476) acc1: 76.5625 (77.0885) acc5: 89.0625 (92.3113) time: 0.0337 data: 0.0200 max mem: 19119\n",
"Test: [200/782] eta: 0:00:21 loss: 0.8791 (0.9212) acc1: 75.0000 (76.4537) acc5: 95.3125 (93.3613) time: 0.0276 data: 0.0139 max mem: 19119\n",
"Test: [300/782] eta: 0:00:16 loss: 0.8066 (0.9144) acc1: 76.5625 (76.6923) acc5: 92.1875 (93.6306) time: 0.0281 data: 0.0144 max mem: 19119\n",
"Test: [400/782] eta: 0:00:13 loss: 1.8165 (1.0555) acc1: 60.9375 (73.9596) acc5: 84.3750 (91.9031) time: 0.0350 data: 0.0214 max mem: 19119\n",
"Test: [500/782] eta: 0:00:09 loss: 1.7107 (1.1325) acc1: 59.3750 (72.5299) acc5: 84.3750 (90.7934) time: 0.0344 data: 0.0206 max mem: 19119\n",
"Test: [600/782] eta: 0:00:06 loss: 1.3799 (1.1970) acc1: 64.0625 (71.2068) acc5: 84.3750 (89.9880) time: 0.0266 data: 0.0127 max mem: 19119\n",
"Test: [700/782] eta: 0:00:02 loss: 1.2741 (1.2493) acc1: 68.7500 (70.1966) acc5: 85.9375 (89.2765) time: 0.0262 data: 0.0124 max mem: 19119\n",
"Test: Total time: 0:00:25\n",
"Test: Acc@1 70.120 Acc@5 89.284\n",
"Training time 2:33:01\n"
]
}
],
"source": [
"from types import SimpleNamespace\n",
"\n",
"args = SimpleNamespace(\n",
" data_path=\"/home/cs/Documents/datasets/imagenet\", # Replace with your /path/to/imagenet\n",
" model=\"resnet18\",\n",
" device=\"cuda\",\n",
" batch_size=512,\n",
" epochs=6,\n",
" lr=0.0004,\n",
" momentum=0.9,\n",
" weight_decay=1e-4,\n",
" lr_warmup_epochs=1,\n",
" lr_warmup_decay=0.0,\n",
" lr_step_size=2,\n",
" lr_gamma=0.5,\n",
" print_freq=100,\n",
" output_dir=\"resnet18\",\n",
" use_deterministic_algorithms=False,\n",
" weights=\"ResNet18_Weights.IMAGENET1K_V1\",\n",
" apply_trp=True,\n",
" trp_depths=[3, 3, 3],\n",
" trp_planes=256,\n",
" trp_lambdas=[0.4, 0.2, 0.1],\n",
")\n",
"\n",
"main(args)"
]
}
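,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The run above finishes at Acc@1 70.120 / Acc@5 89.284. As a quick sanity check, the trained weights can be reloaded for evaluation. The cell below is a minimal sketch that assumes `main()` follows the torchvision reference-training convention of writing a `checkpoint.pth` dict (with a `\"model\"` key) into `output_dir`; the exact file name, and whether the TRP heads are part of the saved state dict, depend on the implementation earlier in this notebook, so `strict=False` is used when loading."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"\n",
"# Rebuild the plain backbone and load the fine-tuned weights.\n",
"# NOTE: \"resnet18/checkpoint.pth\" and the \"model\" key follow torchvision's\n",
"# reference-script convention; adjust if main() saves checkpoints differently.\n",
"model = torchvision.models.resnet18()\n",
"ckpt = torch.load(\"resnet18/checkpoint.pth\", map_location=\"cpu\")\n",
"# strict=False tolerates any auxiliary TRP parameters in the checkpoint\n",
"# that are not part of the plain ResNet-18 module.\n",
"model.load_state_dict(ckpt[\"model\"], strict=False)\n",
"model.eval()"
]
}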
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.21"
}
},
"nbformat": 4,
"nbformat_minor": 0
}