File size: 71,876 Bytes
b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a b04de69 9e6844a |
|
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "X4cRE8IbIrIV"
},
"source": [
"Downloading PyTorch Vision Reference Scripts for Image Classification. These scripts are official reference implementations from PyTorch Vision that provide training and quantization utilities for image classification models."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "46CgrVgjg3E-",
"outputId": "7fb20ebe-d7fd-43fa-dc9b-ebbedf31575e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2025-05-22 16:30:12-- https://raw.githubusercontent.com/pytorch/vision/main/references/classification/presets.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 3885 (3.8K) [text/plain]\n",
"Saving to: ‘presets.py’\n",
"\n",
"presets.py 100%[===================>] 3.79K --.-KB/s in 0s \n",
"\n",
"2025-05-22 16:30:12 (12.8 MB/s) - ‘presets.py’ saved [3885/3885]\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2025-05-22 16:30:12-- https://raw.githubusercontent.com/pytorch/vision/main/references/classification/sampler.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 2395 (2.3K) [text/plain]\n",
"Saving to: ‘sampler.py’\n",
"\n",
"sampler.py 100%[===================>] 2.34K --.-KB/s in 0s \n",
"\n",
"2025-05-22 16:30:12 (18.4 MB/s) - ‘sampler.py’ saved [2395/2395]\n",
"\n",
"--2025-05-22 16:30:12-- https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 23324 (23K) [text/plain]\n",
"Saving to: ‘train.py’\n",
"\n",
"train.py 100%[===================>] 22.78K --.-KB/s in 0.01s \n",
"\n",
"2025-05-22 16:30:13 (2.28 MB/s) - ‘train.py’ saved [23324/23324]\n",
"\n",
"--2025-05-22 16:30:13-- https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train_quantization.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 11647 (11K) [text/plain]\n",
"Saving to: ‘train_quantization.py’\n",
"\n",
"train_quantization. 100%[===================>] 11.37K --.-KB/s in 0.001s \n",
"\n",
"2025-05-22 16:30:13 (12.7 MB/s) - ‘train_quantization.py’ saved [11647/11647]\n",
"\n",
"--2025-05-22 16:30:13-- https://raw.githubusercontent.com/pytorch/vision/main/references/classification/transformers.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
"HTTP request sent, awaiting response... 404 Not Found\n",
"2025-05-22 16:30:13 ERROR 404: Not Found.\n",
"\n",
"--2025-05-22 16:30:13-- https://raw.githubusercontent.com/pytorch/vision/main/references/classification/utils.py\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 15791 (15K) [text/plain]\n",
"Saving to: ‘utils.py’\n",
"\n",
"utils.py 100%[===================>] 15.42K --.-KB/s in 0.01s \n",
"\n",
"2025-05-22 16:30:13 (1.43 MB/s) - ‘utils.py’ saved [15791/15791]\n",
"\n"
]
}
],
"source": [
"# Fetch the TorchVision classification reference scripts (the transforms module is transforms.py, not transformers.py).\n",
"! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/presets.py\n",
"! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/sampler.py\n",
"! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train.py\n",
"! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train_quantization.py\n",
"! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/transforms.py\n",
"! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/utils.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HFASsisvIrIb"
},
"source": [
"In this block, we build a “loss” function for our sequential policy gradient algorithm. When the right data is plugged in, the gradient of this loss is equal to the policy gradient."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "EaBokYCpg3FA"
},
"outputs": [],
"source": [
"import types\n",
"from typing import List, Callable\n",
"\n",
"import torch\n",
"from torch import nn, Tensor\n",
"from torch.nn import functional as F\n",
"from torchvision.models.resnet import BasicBlock\n",
"\n",
"\n",
"def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_state: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):\n",
"    \"\"\"Combine the backbone loss with the losses of the chained TRP replicas into one weighted scalar.\n",
"\n",
"    Args:\n",
"        trp_blocks: Modules applied one after another to ``hidden_state``; block ``k`` is used for ``lambdas[k]``.\n",
"        shared_head: Callable mapping a block output to logits.\n",
"        criterion: Called as ``criterion(logits, targets)`` and ``criterion(logits, targets, rewards)``;\n",
"            must return a ``(losses, rewards)`` pair of tensors.\n",
"        lambdas: One weight per replica; its length sets how many blocks are evaluated.\n",
"        hidden_state: Input to the first TRP block.\n",
"        logits: Backbone logits used for the base loss.\n",
"        targets: Targets forwarded unchanged to ``criterion``.\n",
"        loss_normalization: If True, rescale the result by ``exp(mean(base losses) - loss)``,\n",
"            a factor computed without gradient.\n",
"\n",
"    Returns:\n",
"        Scalar tensor ``mean(losses * returns)``, optionally rescaled.\n",
"    \"\"\"\n",
"    losses, rewards = criterion(logits, targets)\n",
"    returns = torch.ones_like(rewards, dtype=torch.float32, device=rewards.device)\n",
"    if loss_normalization:\n",
"        base_loss = torch.mean(losses).detach()\n",
"\n",
"    replica_state = hidden_state\n",
"    for k, w in enumerate(lambdas):\n",
"        replica_state = trp_blocks[k](replica_state)\n",
"        replica_logits = shared_head(replica_state)\n",
"        # The rewards of the previous stage weight the returns before the criterion refreshes them.\n",
"        returns = returns + w * rewards\n",
"        replica_losses, rewards = criterion(replica_logits, targets, rewards)\n",
"        losses = losses + replica_losses\n",
"    loss = torch.mean(losses * returns)\n",
"\n",
"    if loss_normalization:\n",
"        with torch.no_grad():\n",
"            # exp(a - b) equals exp(a) / exp(b) but cannot overflow into inf / inf = NaN for large losses.\n",
"            coeff = torch.exp(base_loss - loss.detach())\n",
"        loss = coeff * loss\n",
"\n",
"    return loss"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_Ig0Jm2w8DPH"
},
"source": [
"In this block, we build a TPBlock for the Task Replica Prediction (TRP) module. This implementation provides the backbone without the shared prediction head."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "wkBlmJT96jZj"
},
"outputs": [],
"source": [
"class TPBlock(nn.Module):\n",
"    \"\"\"A sequential stack of ``depths`` ``BasicBlock`` modules with zero-initialised conv/downsample parameters.\"\"\"\n",
"\n",
"    def __init__(self, depths: int, inplanes: int, planes: int):\n",
"        super().__init__()\n",
"\n",
"        self.blocks = nn.Sequential(\n",
"            *[BasicBlock(inplanes=inplanes, planes=planes) for _ in range(depths)]\n",
"        )\n",
"        # Zero every parameter whose qualified name mentions a conv or a downsample layer.\n",
"        for param_name, tensor in self.blocks.named_parameters():\n",
"            if 'conv' in param_name or 'downsample' in param_name:\n",
"                nn.init.zeros_(tensor)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Run ``x`` through the stacked blocks.\"\"\"\n",
"        return self.blocks(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UGxQdKZaF2NT"
},
"source": [
"This implementation enables ResNet retraining in SPG mode.\n",
"\n",
"Components:\n",
"-------------------------------------------------------------------------------\n",
"1. gen_criterion()\n",
" - Purpose: compute per-sample losses and positional masks\n",
"\n",
"2. gen_shared_head()\n",
" - Purpose: Implements a shared prediction head that processes convolutional feature maps for prediction.\n",
"\n",
"3. gen_forward()\n",
" - Purpose: Extended forward pass supporting both traditional inference and SPG retraining."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "kTZWkoLr8cfE"
},
"outputs": [],
"source": [
"class ResNetConfig:\n",
"    \"\"\"Factories for the closures (criterion, shared head, forward) used to train a ResNet with TRP blocks.\"\"\"\n",
"\n",
"    @staticmethod\n",
"    def gen_criterion(label_smoothing=0.0, top_k=1):\n",
"        \"\"\"Return a criterion yielding per-sample losses and an updated top-k correctness mask.\"\"\"\n",
"        def func(input, target, mask=None):\n",
"            \"\"\"\n",
"            Args:\n",
"                input (Tensor): Input tensor of shape [B, C].\n",
"                target (Tensor): Target labels of shape [B] or [B, C].\n",
"                mask (Tensor, optional): Per-sample weights of shape [B]; all ones when omitted.\n",
"\n",
"            Returns:\n",
"                losses (Tensor): Per-sample cross-entropy multiplied by the incoming mask, shape [B].\n",
"                mask (Tensor): Incoming mask multiplied by a 0/1 indicator of whether the label\n",
"                    is among the ``top_k`` highest logits, shape [B].\n",
"            \"\"\"\n",
"            # Class indices are needed for the top-k check whatever form the targets arrive in,\n",
"            # so the reduction is chosen from the target's rank rather than from label_smoothing.\n",
"            label = torch.argmax(target, dim=1) if target.dim() > 1 else target\n",
"            # The loss sees class indices when smoothing is on and the raw target otherwise.\n",
"            loss_target = label if label_smoothing > 0.0 else target\n",
"\n",
"            unmasked_loss = F.cross_entropy(input, loss_target, reduction=\"none\", label_smoothing=label_smoothing)\n",
"            if mask is None:\n",
"                mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)\n",
"            losses = mask * unmasked_loss\n",
"\n",
"            with torch.no_grad():\n",
"                topk_values, topk_indices = torch.topk(input, top_k, dim=-1)\n",
"                mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)\n",
"\n",
"            return losses, mask\n",
"        return func\n",
"\n",
"    @staticmethod\n",
"    def gen_shared_head(self):\n",
"        \"\"\"Return a closure applying ``self.layer4``, ``self.avgpool`` and ``self.fc`` to a hidden state.\"\"\"\n",
"        def func(x):\n",
"            \"\"\"\n",
"            Args:\n",
"                x (Tensor): Hidden State tensor of shape [B, C, H, W].\n",
"\n",
"            Returns:\n",
"                logits (Tensor): Logits tensor of shape [B, C].\n",
"            \"\"\"\n",
"            x = self.layer4(x)\n",
"            x = self.avgpool(x)\n",
"            x = torch.flatten(x, 1)\n",
"            logits = self.fc(x)\n",
"            return logits\n",
"        return func\n",
"\n",
"    @staticmethod\n",
"    def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):\n",
"        \"\"\"Return a replacement ``forward`` that also computes the TRP loss while the module is training.\"\"\"\n",
"        def func(self, x: Tensor, targets=None) -> Tensor:\n",
"            \"\"\"Return ``logits`` in eval mode and ``(logits, loss)`` in training mode.\"\"\"\n",
"            x = self.conv1(x)\n",
"            x = self.bn1(x)\n",
"            x = self.relu(x)\n",
"            x = self.maxpool(x)\n",
"\n",
"            x = self.layer1(x)\n",
"            x = self.layer2(x)\n",
"            # The layer3 output is the hidden state handed to the TRP blocks.\n",
"            hidden_state = self.layer3(x)\n",
"            x = self.layer4(hidden_state)\n",
"            x = self.avgpool(x)\n",
"            x = torch.flatten(x, 1)\n",
"            logits = self.fc(x)\n",
"\n",
"            if self.training:\n",
"                shared_head = ResNetConfig.gen_shared_head(self)\n",
"                criterion = ResNetConfig.gen_criterion(label_smoothing=label_smoothing, top_k=top_k)\n",
"\n",
"                loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_state, logits, targets, loss_normalization=loss_normalization)\n",
"\n",
"                return logits, loss\n",
"\n",
"            return logits\n",
"\n",
"        return func"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cCn6vwItH1CW"
},
"source": [
"Applies TRP modules to the base ResNet (main backbone). The k-th TRP module corresponds to a deeper ResNet variant with an additional depth of 3 * sum(depths[:k+1])."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "wXQF0oISH5Yp"
},
"outputs": [],
"source": [
"def apply_trp(model, depths: List[int], planes: int, lambdas: List[float], **kwargs):\n",
"    \"\"\"Attach TRP blocks to ``model`` and replace its ``forward`` with the TRP-aware version.\n",
"\n",
"    Args:\n",
"        model: Module to patch in place; it gains a ``trp_blocks`` attribute and a new bound ``forward``.\n",
"        depths: Number of ``BasicBlock`` layers in each TRP block (one ``TPBlock`` per entry).\n",
"        planes: Channel count used as both ``inplanes`` and ``planes`` of every ``TPBlock``.\n",
"        lambdas: Per-replica weights; entry ``k`` is paired with ``trp_blocks[k]`` during training.\n",
"        **kwargs: Accepted but not used by this function.\n",
"\n",
"    Returns:\n",
"        The same ``model`` object, patched.\n",
"\n",
"    Raises:\n",
"        ValueError: If ``lambdas`` has more entries than ``depths``.\n",
"    \"\"\"\n",
"    # Fail at setup time rather than with an IndexError inside the training forward pass.\n",
"    if len(lambdas) > len(depths):\n",
"        raise ValueError(f\"got {len(lambdas)} lambdas but only {len(depths)} TRP depths; each lambda needs its own TRP block\")\n",
"\n",
"    print(\"✅ Applying TRP to ResNet for Image Classification...\")\n",
"    model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, inplanes=planes, planes=planes) for d in depths])\n",
"    model.forward = types.MethodType(ResNetConfig.gen_forward(lambdas), model)\n",
"    return model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kDjSAv3PJr7P"
},
"source": [
"The following is a training script for classification models, primarily based on the official TorchVision `train.py` reference implementation. We have made two modifications:\n",
"\n",
"Adding TRP Modules: We integrate TRP modules into the base model architecture before training begins:\n",
"\n",
"```python\n",
"if args.apply_trp:\n",
" model = apply_trp(model, args.trp_depths, args.trp_planes, args.trp_lambdas)\n",
"```\n",
"Removing TRP Modules: We remove the TRP components from the base model before saving the base model:\n",
"```python\n",
"if args.output_dir:\n",
" checkpoint = {\n",
" \"model\": model.state_dict() if not args.apply_trp else {k: v for k, v in model.state_dict().items() if not k.startswith(\"trp_blocks\")},\n",
" \"optimizer\": optimizer.state_dict(),\n",
" \"lr_scheduler\": lr_scheduler.state_dict(),\n",
" \"epoch\": epoch,\n",
" \"args\": args,\n",
" }\n",
" utils.save_on_master(checkpoint, os.path.join(args.output_dir, f\"model_{epoch}.pth\"))\n",
" utils.save_on_master(checkpoint, os.path.join(args.output_dir, \"checkpoint.pth\"))\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "hK4Y7Sqv4xUa"
},
"outputs": [],
"source": [
"import datetime\n",
"import os\n",
"import time\n",
"import warnings\n",
"\n",
"import presets\n",
"import torch\n",
"import torch.utils.data\n",
"import torchvision\n",
"import utils\n",
"from torch import nn\n",
"from torchvision.transforms.functional import InterpolationMode\n",
"\n",
"\n",
"def load_data(traindir, valdir):\n",
"    \"\"\"Build the train/validation ``ImageFolder`` datasets and their samplers.\n",
"\n",
"    Returns:\n",
"        ``(dataset, dataset_test, train_sampler, test_sampler)`` where the training sampler is random\n",
"        and the validation sampler is sequential.\n",
"    \"\"\"\n",
"    print(\"Loading data\")\n",
"    interpolation = InterpolationMode(\"bilinear\")\n",
"\n",
"    print(\"Loading training data\")\n",
"    st = time.time()\n",
"    train_transform = presets.ClassificationPresetTrain(crop_size=224, interpolation=interpolation, auto_augment_policy=None, random_erase_prob=0.0, ra_magnitude=9, augmix_severity=3)\n",
"    dataset = torchvision.datasets.ImageFolder(traindir, train_transform)\n",
"    print(\"Took\", time.time() - st)\n",
"\n",
"    print(\"Loading validation data\")\n",
"    eval_transform = presets.ClassificationPresetEval(crop_size=224, resize_size=256, interpolation=interpolation)\n",
"    dataset_test = torchvision.datasets.ImageFolder(valdir, eval_transform)\n",
"\n",
"    print(\"Creating data loaders\")\n",
"    train_sampler = torch.utils.data.RandomSampler(dataset)\n",
"    test_sampler = torch.utils.data.SequentialSampler(dataset_test)\n",
"\n",
"    return dataset, dataset_test, train_sampler, test_sampler\n",
"\n",
"\n",
"\n",
"def train_one_epoch(model, optimizer, data_loader, device, epoch, args):\n",
"    \"\"\"Train ``model`` for one pass over ``data_loader``, logging loss, lr, accuracy and throughput.\n",
"\n",
"    When ``args.apply_trp`` is true (or the attribute is absent) the model is called as\n",
"    ``model(image, target)`` and must return ``(logits, loss)``. Otherwise the model is called on the\n",
"    images alone and a plain cross-entropy loss is computed here.\n",
"    \"\"\"\n",
"    model.train()\n",
"    metric_logger = utils.MetricLogger(delimiter=\" \")\n",
"    metric_logger.add_meter(\"lr\", utils.SmoothedValue(window_size=1, fmt=\"{value}\"))\n",
"    metric_logger.add_meter(\"img/s\", utils.SmoothedValue(window_size=10, fmt=\"{value}\"))\n",
"\n",
"    # Only a TRP-patched forward accepts targets and returns the loss alongside the logits.\n",
"    uses_trp = getattr(args, \"apply_trp\", True)\n",
"\n",
"    header = f\"Epoch: [{epoch}]\"\n",
"    for image, target in metric_logger.log_every(data_loader, args.print_freq, header):\n",
"        start_time = time.time()\n",
"        image, target = image.to(device), target.to(device)\n",
"        with torch.amp.autocast(\"cuda\", enabled=False):\n",
"            if uses_trp:\n",
"                output, loss = model(image, target)\n",
"            else:\n",
"                output = model(image)\n",
"                loss = nn.functional.cross_entropy(output, target)\n",
"\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))\n",
"        batch_size = image.shape[0]\n",
"        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0][\"lr\"])\n",
"        metric_logger.meters[\"acc1\"].update(acc1.item(), n=batch_size)\n",
"        metric_logger.meters[\"acc5\"].update(acc5.item(), n=batch_size)\n",
"        metric_logger.meters[\"img/s\"].update(batch_size / (time.time() - start_time))\n",
"\n",
"\n",
"def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=\"\"):\n",
"    \"\"\"Evaluate ``model`` on ``data_loader`` and return the global average top-1 accuracy.\n",
"\n",
"    The model is switched to eval mode and called on the images only; ``criterion(output, target)``\n",
"    is used for the logged loss. A warning is issued when the number of processed samples differs\n",
"    from ``len(data_loader.dataset)``.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    metric_logger = utils.MetricLogger(delimiter=\" \")\n",
"    header = f\"Test: {log_suffix}\"\n",
"\n",
"    num_processed_samples = 0\n",
"    with torch.inference_mode():\n",
"        for image, target in metric_logger.log_every(data_loader, print_freq, header):\n",
"            image = image.to(device, non_blocking=True)\n",
"            target = target.to(device, non_blocking=True)\n",
"            output = model(image)\n",
"            loss = criterion(output, target)\n",
"\n",
"            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))\n",
"            # FIXME need to take into account that the datasets\n",
"            # could have been padded in distributed setup\n",
"            batch_size = image.shape[0]\n",
"            metric_logger.update(loss=loss.item())\n",
"            metric_logger.meters[\"acc1\"].update(acc1.item(), n=batch_size)\n",
"            metric_logger.meters[\"acc5\"].update(acc5.item(), n=batch_size)\n",
"            num_processed_samples += batch_size\n",
"    # gather the stats from all processes\n",
"\n",
"    num_processed_samples = utils.reduce_across_processes(num_processed_samples)\n",
"    if (\n",
"        hasattr(data_loader.dataset, \"__len__\")\n",
"        and len(data_loader.dataset) != num_processed_samples\n",
"        # get_rank() raises without an initialised process group, so a single-process run\n",
"        # is treated as the main process instead of querying its rank.\n",
"        and (\n",
"            not (torch.distributed.is_available() and torch.distributed.is_initialized())\n",
"            or torch.distributed.get_rank() == 0\n",
"        )\n",
"    ):\n",
"        # See FIXME above\n",
"        warnings.warn(\n",
"            f\"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} \"\n",
"            \"samples were used for the validation, which might bias the results. \"\n",
"            \"Try adjusting the batch size and / or the world size. \"\n",
"            \"Setting the world size to 1 is always a safe bet.\"\n",
"        )\n",
"\n",
"    metric_logger.synchronize_between_processes()\n",
"\n",
"    print(f\"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}\")\n",
"    return metric_logger.acc1.global_avg\n",
"\n",
"\n",
"def main(args):\n",
"    \"\"\"Train and evaluate ``args.model`` on the ImageFolder data under ``args.data_path``.\n",
"\n",
"    After every epoch the model is evaluated and, when ``args.output_dir`` is set, a checkpoint is\n",
"    written both as ``model_{epoch}.pth`` and as ``checkpoint.pth``. With ``args.apply_trp`` the\n",
"    ``trp_blocks`` entries are left out of the saved model state.\n",
"    \"\"\"\n",
"    if args.output_dir:\n",
"        utils.mkdir(args.output_dir)\n",
"    print(args)\n",
"\n",
"    device = torch.device(args.device)\n",
"\n",
"    # cudnn autotuning is only enabled when deterministic algorithms are not requested.\n",
"    torch.backends.cudnn.benchmark = not args.use_deterministic_algorithms\n",
"    if args.use_deterministic_algorithms:\n",
"        torch.use_deterministic_algorithms(True)\n",
"\n",
"    train_dir = os.path.join(args.data_path, \"train\")\n",
"    val_dir = os.path.join(args.data_path, \"val\")\n",
"    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir)\n",
"\n",
"    num_classes = len(dataset.classes)\n",
"    data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=16, pin_memory=True, collate_fn=None)\n",
"    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=64, sampler=test_sampler, num_workers=16, pin_memory=True)\n",
"\n",
"    print(\"Creating model\")\n",
"    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)\n",
"    if args.apply_trp:\n",
"        model = apply_trp(model, args.trp_depths, args.trp_planes, args.trp_lambdas)\n",
"    model.to(device)\n",
"\n",
"    parameters = utils.set_weight_decay(model, args.weight_decay, norm_weight_decay=None, custom_keys_weight_decay=None)\n",
"    optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)\n",
"\n",
"    # Constant-factor warmup for the first lr_warmup_epochs epochs, then step decay.\n",
"    warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs)\n",
"    main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)\n",
"    lr_scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs])\n",
"\n",
"    eval_criterion = nn.CrossEntropyLoss()\n",
"\n",
"    print(\"Start training\")\n",
"    start_time = time.time()\n",
"    for epoch in range(args.epochs):\n",
"        train_one_epoch(model, optimizer, data_loader, device, epoch, args)\n",
"        lr_scheduler.step()\n",
"        evaluate(model, eval_criterion, data_loader_test, device=device)\n",
"        if not args.output_dir:\n",
"            continue\n",
"\n",
"        model_state = model.state_dict()\n",
"        if args.apply_trp:\n",
"            # NOTE: remove TRP heads\n",
"            model_state = {k: v for k, v in model_state.items() if not k.startswith(\"trp_blocks\")}\n",
"        checkpoint = {\n",
"            \"model\": model_state,\n",
"            \"optimizer\": optimizer.state_dict(),\n",
"            \"lr_scheduler\": lr_scheduler.state_dict(),\n",
"            \"epoch\": epoch,\n",
"            \"args\": args,\n",
"        }\n",
"        for filename in (f\"model_{epoch}.pth\", \"checkpoint.pth\"):\n",
"            utils.save_on_master(checkpoint, os.path.join(args.output_dir, filename))\n",
"\n",
"    elapsed = int(time.time() - start_time)\n",
"    total_time_str = str(datetime.timedelta(seconds=elapsed))\n",
"    print(f\"Training time {total_time_str}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SV8s5k49KwgS"
},
"source": [
"Prepare the [ImageNet](http://image-net.org/) dataset manually and place it in `/path/to/imagenet`. For image classification examples, pass the argument `--data-path=/path/to/imagenet` to the training script. The extracted dataset directory should follow this structure:\n",
"```setup\n",
"/path/to/imagenet/:\n",
" train/:\n",
" n01440764:\n",
" n01440764_18.JPEG ...\n",
" n01443537:\n",
" n01443537_2.JPEG ...\n",
" val/:\n",
" n01440764:\n",
" ILSVRC2012_val_00000293.JPEG ...\n",
" n01443537:\n",
" ILSVRC2012_val_00000236.JPEG ...\n",
"```\n",
"\n",
"Now you can apply the SPG algorithm in model retraining.\n",
"\n",
"**Implementation Note:**\n",
"\n",
"- This demonstration runs on Google Colab using a single GPU configuration\n",
"- Performance Improvement: Enhances ResNet18 validation accuracy (ACC@1) from 69.76% to 70.09%\n",
"- For optimal results:\n",
" - Refer to our README.md for complete setup instructions\n",
" - Recommended hardware: 4× RTX A6000 GPUs"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "UDZxDNfT4xUb",
"outputId": "bcf86aa0-eb77-4815-e0fa-05997f1e1f1b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"namespace(data_path='/home/cs/Documents/datasets/imagenet', model='resnet18', device='cuda', batch_size=512, epochs=6, lr=0.0004, momentum=0.9, weight_decay=0.0001, lr_warmup_epochs=1, lr_warmup_decay=0.0, lr_step_size=2, lr_gamma=0.5, print_freq=100, output_dir='resnet18', use_deterministic_algorithms=False, weights='ResNet18_Weights.IMAGENET1K_V1', apply_trp=True, trp_depths=[3, 3, 3], trp_planes=256, trp_lambdas=[0.4, 0.2, 0.1])\n",
"Loading data\n",
"Loading training data\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Took 1.9062905311584473\n",
"Loading validation data\n",
"Creating data loaders\n",
"Creating model\n",
"✅ Applying TRP to ResNet for Image Classification...\n",
"Start training\n",
"Epoch: [0] [ 0/2503] eta: 10:05:09 lr: 0.0 img/s: 81.93631887515438 loss: 0.7334 (0.7334) acc1: 71.2891 (71.2891) acc5: 86.1328 (86.1328) time: 14.5065 data: 8.2577 max mem: 19119\n",
"Epoch: [0] [ 100/2503] eta: 0:29:06 lr: 0.0 img/s: 862.8257862120394 loss: 0.7145 (0.7308) acc1: 69.5312 (69.6105) acc5: 87.6953 (87.3704) time: 0.5927 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [ 200/2503] eta: 0:25:23 lr: 0.0 img/s: 860.6862569301302 loss: 0.7355 (0.7353) acc1: 68.9453 (69.3427) acc5: 86.9141 (87.3125) time: 0.5966 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [ 300/2503] eta: 0:23:29 lr: 0.0 img/s: 860.0754340960929 loss: 0.7159 (0.7314) acc1: 69.1406 (69.3463) acc5: 87.5000 (87.3676) time: 0.5967 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [ 400/2503] eta: 0:22:03 lr: 0.0 img/s: 859.0790234707376 loss: 0.7594 (0.7361) acc1: 67.9688 (69.2283) acc5: 86.7188 (87.3232) time: 0.5960 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [ 500/2503] eta: 0:20:46 lr: 0.0 img/s: 859.7486624250741 loss: 0.7204 (0.7343) acc1: 69.7266 (69.2396) acc5: 87.5000 (87.3827) time: 0.5958 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [ 600/2503] eta: 0:19:36 lr: 0.0 img/s: 861.5204710456711 loss: 0.7483 (0.7345) acc1: 69.5312 (69.2449) acc5: 86.7188 (87.3950) time: 0.5958 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [ 700/2503] eta: 0:18:28 lr: 0.0 img/s: 858.9934592 loss: 0.7225 (0.7350) acc1: 68.5547 (69.2331) acc5: 87.6953 (87.3738) time: 0.5958 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [ 800/2503] eta: 0:17:23 lr: 0.0 img/s: 859.4995325247505 loss: 0.7639 (0.7355) acc1: 69.7266 (69.2177) acc5: 86.7188 (87.3578) time: 0.5961 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [ 900/2503] eta: 0:16:18 lr: 0.0 img/s: 860.8087326554238 loss: 0.7118 (0.7349) acc1: 69.9219 (69.2440) acc5: 87.6953 (87.3548) time: 0.5961 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1000/2503] eta: 0:15:15 lr: 0.0 img/s: 859.5858857924882 loss: 0.7224 (0.7351) acc1: 69.3359 (69.2485) acc5: 87.3047 (87.3624) time: 0.5958 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [1100/2503] eta: 0:14:12 lr: 0.0 img/s: 858.8670339725992 loss: 0.7240 (0.7360) acc1: 68.9453 (69.2212) acc5: 87.1094 (87.3361) time: 0.5958 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1200/2503] eta: 0:13:10 lr: 0.0 img/s: 861.4696676125856 loss: 0.7126 (0.7364) acc1: 68.3594 (69.1878) acc5: 87.3047 (87.3190) time: 0.5960 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1300/2503] eta: 0:12:09 lr: 0.0 img/s: 859.3643608581464 loss: 0.7291 (0.7367) acc1: 68.9453 (69.1669) acc5: 86.7188 (87.2990) time: 0.5959 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1400/2503] eta: 0:11:07 lr: 0.0 img/s: 861.1477063020853 loss: 0.7267 (0.7372) acc1: 69.9219 (69.1624) acc5: 87.1094 (87.2990) time: 0.5960 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1500/2503] eta: 0:10:06 lr: 0.0 img/s: 859.0494692253935 loss: 0.7234 (0.7374) acc1: 69.1406 (69.1607) acc5: 87.3047 (87.2939) time: 0.5959 data: 0.0003 max mem: 19119\n",
"Epoch: [0] [1600/2503] eta: 0:09:05 lr: 0.0 img/s: 860.660386236062 loss: 0.7456 (0.7374) acc1: 69.3359 (69.1730) acc5: 87.5000 (87.3019) time: 0.5960 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1700/2503] eta: 0:08:04 lr: 0.0 img/s: 858.9515423647326 loss: 0.7548 (0.7372) acc1: 69.1406 (69.1773) acc5: 87.5000 (87.3198) time: 0.5959 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1800/2503] eta: 0:07:04 lr: 0.0 img/s: 860.6800478217115 loss: 0.7596 (0.7375) acc1: 67.1875 (69.1614) acc5: 87.1094 (87.3191) time: 0.5958 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [1900/2503] eta: 0:06:03 lr: 0.0 img/s: 859.6578027499652 loss: 0.7465 (0.7375) acc1: 68.3594 (69.1633) acc5: 86.7188 (87.3222) time: 0.5959 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [2000/2503] eta: 0:05:03 lr: 0.0 img/s: 860.6507282423033 loss: 0.7385 (0.7375) acc1: 69.3359 (69.1609) acc5: 87.3047 (87.3233) time: 0.5959 data: 0.0002 max mem: 19119\n",
"Epoch: [0] [2100/2503] eta: 0:04:02 lr: 0.0 img/s: 860.72592834858 loss: 0.7153 (0.7373) acc1: 69.3359 (69.1710) acc5: 87.3047 (87.3230) time: 0.5961 data: 0.0004 max mem: 19119\n",
"Epoch: [0] [2200/2503] eta: 0:03:02 lr: 0.0 img/s: 859.2460775467988 loss: 0.7307 (0.7371) acc1: 68.9453 (69.1861) acc5: 87.5000 (87.3380) time: 0.5960 data: 0.0004 max mem: 19119\n",
"Epoch: [0] [2300/2503] eta: 0:02:02 lr: 0.0 img/s: 859.2639554931892 loss: 0.7077 (0.7367) acc1: 69.3359 (69.1971) acc5: 87.6953 (87.3516) time: 0.5959 data: 0.0004 max mem: 19119\n",
"Epoch: [0] [2400/2503] eta: 0:01:01 lr: 0.0 img/s: 861.341130585524 loss: 0.7279 (0.7365) acc1: 68.5547 (69.1921) acc5: 86.9141 (87.3412) time: 0.5961 data: 0.0004 max mem: 19119\n",
"Epoch: [0] [2500/2503] eta: 0:00:01 lr: 0.0 img/s: 861.8382147793436 loss: 0.7469 (0.7368) acc1: 68.5547 (69.1894) acc5: 87.5000 (87.3423) time: 0.5955 data: 0.0005 max mem: 19119\n",
"Epoch: [0] Total time: 0:25:05\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cs/anaconda3/envs/csenv/lib/python3.9/site-packages/torch/optim/lr_scheduler.py:243: UserWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.\n",
" warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test: [ 0/782] eta: 0:23:05 loss: 0.6283 (0.6283) acc1: 89.0625 (89.0625) acc5: 95.3125 (95.3125) time: 1.7719 data: 1.3111 max mem: 19119\n",
"Test: [100/782] eta: 0:00:30 loss: 1.0688 (0.9382) acc1: 76.5625 (76.2840) acc5: 89.0625 (92.1875) time: 0.0399 data: 0.0263 max mem: 19119\n",
"Test: [200/782] eta: 0:00:21 loss: 0.9244 (0.9143) acc1: 73.4375 (75.8240) acc5: 95.3125 (93.2369) time: 0.0244 data: 0.0107 max mem: 19119\n",
"Test: [300/782] eta: 0:00:17 loss: 0.8615 (0.9072) acc1: 76.5625 (76.1991) acc5: 92.1875 (93.5008) time: 0.0381 data: 0.0244 max mem: 19119\n",
"Test: [400/782] eta: 0:00:13 loss: 1.6977 (1.0440) acc1: 59.3750 (73.6323) acc5: 82.8125 (91.7472) time: 0.0313 data: 0.0176 max mem: 19119\n",
"Test: [500/782] eta: 0:00:09 loss: 1.6021 (1.1237) acc1: 54.6875 (72.0964) acc5: 85.9375 (90.5845) time: 0.0247 data: 0.0109 max mem: 19119\n",
"Test: [600/782] eta: 0:00:06 loss: 1.3631 (1.1858) acc1: 64.0625 (70.8741) acc5: 84.3750 (89.7853) time: 0.0291 data: 0.0153 max mem: 19119\n",
"Test: [700/782] eta: 0:00:02 loss: 1.2494 (1.2361) acc1: 68.7500 (69.9313) acc5: 87.5000 (89.1115) time: 0.0391 data: 0.0254 max mem: 19119\n",
"Test: Total time: 0:00:26\n",
"Test: Acc@1 69.846 Acc@5 89.136\n",
"Epoch: [1] [ 0/2503] eta: 4:27:27 lr: 0.0004 img/s: 861.3684242192573 loss: 0.7611 (0.7611) acc1: 68.5547 (68.5547) acc5: 86.1328 (86.1328) time: 6.4115 data: 5.8170 max mem: 19119\n",
"Epoch: [1] [ 100/2503] eta: 0:26:25 lr: 0.0004 img/s: 856.9149263546342 loss: 0.7538 (0.7542) acc1: 70.5078 (69.0536) acc5: 87.5000 (87.1364) time: 0.5982 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [ 200/2503] eta: 0:24:10 lr: 0.0004 img/s: 854.0172713632207 loss: 0.7744 (0.7573) acc1: 69.7266 (69.2990) acc5: 88.0859 (87.3785) time: 0.5998 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [ 300/2503] eta: 0:22:45 lr: 0.0004 img/s: 856.0483105922384 loss: 0.7551 (0.7613) acc1: 69.1406 (69.2834) acc5: 87.3047 (87.4611) time: 0.5996 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [ 400/2503] eta: 0:21:32 lr: 0.0004 img/s: 854.8386016604893 loss: 0.7931 (0.7645) acc1: 68.5547 (69.3004) acc5: 87.3047 (87.4698) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [ 500/2503] eta: 0:20:24 lr: 0.0004 img/s: 855.3431965742996 loss: 0.7744 (0.7684) acc1: 68.1641 (69.2853) acc5: 86.9141 (87.4361) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [ 600/2503] eta: 0:19:19 lr: 0.0004 img/s: 855.1112541063571 loss: 0.7860 (0.7730) acc1: 69.1406 (69.2310) acc5: 86.7188 (87.3941) time: 0.5988 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [ 700/2503] eta: 0:18:15 lr: 0.0004 img/s: 856.4904515045232 loss: 0.7908 (0.7773) acc1: 68.7500 (69.1746) acc5: 86.7188 (87.3543) time: 0.5985 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [ 800/2503] eta: 0:17:13 lr: 0.0004 img/s: 858.0146361031335 loss: 0.8157 (0.7805) acc1: 68.5547 (69.1660) acc5: 87.6953 (87.3181) time: 0.5991 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [ 900/2503] eta: 0:16:11 lr: 0.0004 img/s: 854.9138104963116 loss: 0.7641 (0.7825) acc1: 69.5312 (69.1807) acc5: 88.8672 (87.3346) time: 0.5989 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [1000/2503] eta: 0:15:09 lr: 0.0004 img/s: 855.7491430488731 loss: 0.8024 (0.7852) acc1: 68.1641 (69.1506) acc5: 86.5234 (87.3234) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [1100/2503] eta: 0:14:08 lr: 0.0004 img/s: 856.0848253972304 loss: 0.8099 (0.7872) acc1: 69.1406 (69.1564) acc5: 86.7188 (87.3231) time: 0.5992 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [1200/2503] eta: 0:13:07 lr: 0.0004 img/s: 855.6028761225017 loss: 0.8307 (0.7894) acc1: 68.5547 (69.1569) acc5: 87.1094 (87.3258) time: 0.5989 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [1300/2503] eta: 0:12:06 lr: 0.0004 img/s: 855.8589613399885 loss: 0.8206 (0.7913) acc1: 68.9453 (69.1304) acc5: 87.3047 (87.3177) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [1400/2503] eta: 0:11:05 lr: 0.0004 img/s: 856.6045604019511 loss: 0.8454 (0.7936) acc1: 68.1641 (69.1019) acc5: 86.9141 (87.2906) time: 0.5989 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [1500/2503] eta: 0:10:04 lr: 0.0004 img/s: 854.944442321167 loss: 0.8428 (0.7960) acc1: 68.1641 (69.0905) acc5: 87.3047 (87.2706) time: 0.5990 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [1600/2503] eta: 0:09:04 lr: 0.0004 img/s: 855.0727794914757 loss: 0.7906 (0.7974) acc1: 69.5312 (69.0922) acc5: 87.1094 (87.2686) time: 0.5990 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [1700/2503] eta: 0:08:03 lr: 0.0004 img/s: 855.4958499669949 loss: 0.8199 (0.7989) acc1: 69.7266 (69.0854) acc5: 87.1094 (87.2704) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [1800/2503] eta: 0:07:03 lr: 0.0004 img/s: 855.0251166287029 loss: 0.8257 (0.8007) acc1: 70.1172 (69.0869) acc5: 87.5000 (87.2656) time: 0.5988 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [1900/2503] eta: 0:06:03 lr: 0.0004 img/s: 856.9867390518363 loss: 0.7952 (0.8018) acc1: 68.7500 (69.0943) acc5: 87.3047 (87.2670) time: 0.5989 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [2000/2503] eta: 0:05:02 lr: 0.0004 img/s: 854.3927252530574 loss: 0.8402 (0.8032) acc1: 68.5547 (69.0964) acc5: 87.1094 (87.2747) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [2100/2503] eta: 0:04:02 lr: 0.0004 img/s: 855.2427067851231 loss: 0.8451 (0.8042) acc1: 68.3594 (69.1089) acc5: 87.3047 (87.2816) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [2200/2503] eta: 0:03:02 lr: 0.0004 img/s: 853.8318747507567 loss: 0.8314 (0.8058) acc1: 68.9453 (69.1012) acc5: 87.3047 (87.2716) time: 0.5989 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [2300/2503] eta: 0:02:02 lr: 0.0004 img/s: 855.3312728222841 loss: 0.8350 (0.8070) acc1: 68.7500 (69.0993) acc5: 86.3281 (87.2549) time: 0.5988 data: 0.0003 max mem: 19119\n",
"Epoch: [1] [2400/2503] eta: 0:01:01 lr: 0.0004 img/s: 855.1613103361218 loss: 0.8206 (0.8084) acc1: 68.5547 (69.0927) acc5: 86.9141 (87.2551) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [1] [2500/2503] eta: 0:00:01 lr: 0.0004 img/s: 856.7190414886559 loss: 0.8286 (0.8094) acc1: 69.1406 (69.1000) acc5: 87.1094 (87.2599) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [1] Total time: 0:25:05\n",
"Test: [ 0/782] eta: 0:16:19 loss: 0.5636 (0.5636) acc1: 87.5000 (87.5000) acc5: 96.8750 (96.8750) time: 1.2525 data: 1.2385 max mem: 19119\n",
"Test: [100/782] eta: 0:00:31 loss: 1.0393 (0.9414) acc1: 76.5625 (76.9647) acc5: 90.6250 (92.2958) time: 0.0417 data: 0.0280 max mem: 19119\n",
"Test: [200/782] eta: 0:00:22 loss: 0.8964 (0.9176) acc1: 73.4375 (76.4614) acc5: 95.3125 (93.3147) time: 0.0249 data: 0.0112 max mem: 19119\n",
"Test: [300/782] eta: 0:00:17 loss: 0.7984 (0.9094) acc1: 79.6875 (76.7130) acc5: 92.1875 (93.6150) time: 0.0311 data: 0.0173 max mem: 19119\n",
"Test: [400/782] eta: 0:00:13 loss: 1.7745 (1.0483) acc1: 57.8125 (73.9635) acc5: 84.3750 (91.8758) time: 0.0328 data: 0.0190 max mem: 19119\n",
"Test: [500/782] eta: 0:00:09 loss: 1.6435 (1.1264) acc1: 59.3750 (72.4239) acc5: 84.3750 (90.7934) time: 0.0328 data: 0.0190 max mem: 19119\n",
"Test: [600/782] eta: 0:00:06 loss: 1.3057 (1.1915) acc1: 62.5000 (71.0483) acc5: 85.9375 (90.0010) time: 0.0400 data: 0.0261 max mem: 19119\n",
"Test: [700/782] eta: 0:00:02 loss: 1.2212 (1.2428) acc1: 70.3125 (70.0985) acc5: 87.5000 (89.3010) time: 0.0253 data: 0.0115 max mem: 19119\n",
"Test: Total time: 0:00:26\n",
"Test: Acc@1 70.000 Acc@5 89.320\n",
"Epoch: [2] [ 0/2503] eta: 4:06:15 lr: 0.0004 img/s: 867.4756359685 loss: 0.8414 (0.8414) acc1: 67.9688 (67.9688) acc5: 86.1328 (86.1328) time: 5.9030 data: 5.3128 max mem: 19119\n",
"Epoch: [2] [ 100/2503] eta: 0:25:53 lr: 0.0004 img/s: 859.1872918530421 loss: 0.8472 (0.8456) acc1: 68.9453 (69.1194) acc5: 86.7188 (87.2892) time: 0.5963 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [ 200/2503] eta: 0:23:52 lr: 0.0004 img/s: 857.5945509684602 loss: 0.8563 (0.8443) acc1: 68.1641 (69.1649) acc5: 86.3281 (87.1852) time: 0.5972 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [ 300/2503] eta: 0:22:31 lr: 0.0004 img/s: 859.3399450658505 loss: 0.8386 (0.8440) acc1: 69.1406 (69.0790) acc5: 87.3047 (87.1762) time: 0.5967 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [ 400/2503] eta: 0:21:21 lr: 0.0004 img/s: 859.825426282471 loss: 0.8455 (0.8446) acc1: 69.3359 (69.0422) acc5: 87.3047 (87.1591) time: 0.5963 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [ 500/2503] eta: 0:20:15 lr: 0.0004 img/s: 858.6002202991031 loss: 0.8400 (0.8448) acc1: 67.9688 (69.0373) acc5: 87.5000 (87.1640) time: 0.5962 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [ 600/2503] eta: 0:19:11 lr: 0.0004 img/s: 859.6997885462696 loss: 0.8544 (0.8467) acc1: 68.3594 (69.0217) acc5: 86.9141 (87.1347) time: 0.5963 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [ 700/2503] eta: 0:18:08 lr: 0.0004 img/s: 858.0379481409502 loss: 0.8386 (0.8466) acc1: 68.7500 (69.0314) acc5: 87.1094 (87.1459) time: 0.5966 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [ 800/2503] eta: 0:17:06 lr: 0.0004 img/s: 859.3574830298114 loss: 0.8607 (0.8472) acc1: 69.5312 (69.0338) acc5: 87.1094 (87.1364) time: 0.5965 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [ 900/2503] eta: 0:16:05 lr: 0.0004 img/s: 858.1785328651912 loss: 0.8502 (0.8474) acc1: 68.5547 (69.0273) acc5: 87.1094 (87.1404) time: 0.5966 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [1000/2503] eta: 0:15:04 lr: 0.0004 img/s: 858.6554923981921 loss: 0.8213 (0.8468) acc1: 70.1172 (69.0737) acc5: 87.6953 (87.1601) time: 0.5966 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [1100/2503] eta: 0:14:03 lr: 0.0004 img/s: 858.942266240932 loss: 0.8322 (0.8466) acc1: 68.9453 (69.0824) acc5: 87.5000 (87.1826) time: 0.5965 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [1200/2503] eta: 0:13:02 lr: 0.0004 img/s: 859.9796839014982 loss: 0.8353 (0.8468) acc1: 68.5547 (69.0858) acc5: 87.1094 (87.1886) time: 0.5962 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [1300/2503] eta: 0:12:02 lr: 0.0004 img/s: 858.8996673563247 loss: 0.8654 (0.8471) acc1: 68.3594 (69.0645) acc5: 86.7188 (87.1876) time: 0.5966 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [1400/2503] eta: 0:11:01 lr: 0.0004 img/s: 859.7879032225777 loss: 0.8277 (0.8466) acc1: 70.1172 (69.0861) acc5: 88.6719 (87.2244) time: 0.5963 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [1500/2503] eta: 0:10:01 lr: 0.0004 img/s: 859.6271763544084 loss: 0.8703 (0.8471) acc1: 68.7500 (69.0763) acc5: 86.5234 (87.2101) time: 0.5962 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [1600/2503] eta: 0:09:01 lr: 0.0004 img/s: 859.8206066235718 loss: 0.8818 (0.8481) acc1: 68.5547 (69.0523) acc5: 86.7188 (87.1975) time: 0.5965 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [1700/2503] eta: 0:08:01 lr: 0.0004 img/s: 858.3188206070029 loss: 0.8447 (0.8487) acc1: 69.3359 (69.0259) acc5: 87.3047 (87.1857) time: 0.5965 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [1800/2503] eta: 0:07:01 lr: 0.0004 img/s: 860.4651988761577 loss: 0.8492 (0.8489) acc1: 67.5781 (69.0295) acc5: 86.7188 (87.1785) time: 0.5965 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [1900/2503] eta: 0:06:01 lr: 0.0004 img/s: 858.5823699612625 loss: 0.8559 (0.8491) acc1: 67.9688 (69.0192) acc5: 87.5000 (87.1776) time: 0.5963 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [2000/2503] eta: 0:05:01 lr: 0.0004 img/s: 858.4468002775832 loss: 0.8712 (0.8498) acc1: 68.9453 (69.0207) acc5: 87.6953 (87.1836) time: 0.5963 data: 0.0003 max mem: 19119\n",
"Epoch: [2] [2100/2503] eta: 0:04:01 lr: 0.0004 img/s: 858.6208177650899 loss: 0.8782 (0.8507) acc1: 68.3594 (69.0053) acc5: 85.9375 (87.1818) time: 0.5964 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [2200/2503] eta: 0:03:01 lr: 0.0004 img/s: 858.769492116456 loss: 0.8845 (0.8514) acc1: 68.3594 (68.9953) acc5: 86.5234 (87.1744) time: 0.5963 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [2300/2503] eta: 0:02:01 lr: 0.0004 img/s: 860.1050589782235 loss: 0.8664 (0.8522) acc1: 68.7500 (68.9914) acc5: 87.5000 (87.1735) time: 0.5962 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [2400/2503] eta: 0:01:01 lr: 0.0004 img/s: 859.3210323775423 loss: 0.8824 (0.8529) acc1: 67.7734 (68.9772) acc5: 86.5234 (87.1693) time: 0.5963 data: 0.0002 max mem: 19119\n",
"Epoch: [2] [2500/2503] eta: 0:00:01 lr: 0.0004 img/s: 860.1956686665284 loss: 0.8302 (0.8531) acc1: 69.9219 (68.9880) acc5: 87.5000 (87.1751) time: 0.5962 data: 0.0002 max mem: 19119\n",
"Epoch: [2] Total time: 0:24:57\n",
"Test: [ 0/782] eta: 0:14:49 loss: 0.6400 (0.6400) acc1: 82.8125 (82.8125) acc5: 93.7500 (93.7500) time: 1.1370 data: 1.1232 max mem: 19119\n",
"Test: [100/782] eta: 0:00:28 loss: 1.0691 (0.9495) acc1: 75.0000 (76.8874) acc5: 89.0625 (92.2184) time: 0.0422 data: 0.0284 max mem: 19119\n",
"Test: [200/782] eta: 0:00:20 loss: 0.8384 (0.9253) acc1: 75.0000 (76.3293) acc5: 95.3125 (93.2292) time: 0.0298 data: 0.0161 max mem: 19119\n",
"Test: [300/782] eta: 0:00:16 loss: 0.8140 (0.9153) acc1: 78.1250 (76.6092) acc5: 92.1875 (93.5631) time: 0.0281 data: 0.0143 max mem: 19119\n",
"Test: [400/782] eta: 0:00:12 loss: 1.7029 (1.0528) acc1: 62.5000 (73.9479) acc5: 84.3750 (91.8797) time: 0.0260 data: 0.0123 max mem: 19119\n",
"Test: [500/782] eta: 0:00:09 loss: 1.7149 (1.1295) acc1: 59.3750 (72.4894) acc5: 84.3750 (90.7997) time: 0.0315 data: 0.0177 max mem: 19119\n",
"Test: [600/782] eta: 0:00:06 loss: 1.3215 (1.1949) acc1: 65.6250 (71.1288) acc5: 85.9375 (90.0192) time: 0.0343 data: 0.0204 max mem: 19119\n",
"Test: [700/782] eta: 0:00:02 loss: 1.3000 (1.2468) acc1: 70.3125 (70.1386) acc5: 85.9375 (89.2809) time: 0.0246 data: 0.0108 max mem: 19119\n",
"Test: Total time: 0:00:25\n",
"Test: Acc@1 70.034 Acc@5 89.306\n",
"Epoch: [3] [ 0/2503] eta: 3:48:40 lr: 0.0002 img/s: 868.6651772838787 loss: 0.9922 (0.9922) acc1: 65.8203 (65.8203) acc5: 84.3750 (84.3750) time: 5.4818 data: 4.8924 max mem: 19119\n",
"Epoch: [3] [ 100/2503] eta: 0:25:56 lr: 0.0002 img/s: 857.1146638568258 loss: 0.8599 (0.8484) acc1: 69.7266 (69.1851) acc5: 86.7188 (87.2660) time: 0.5978 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [ 200/2503] eta: 0:23:56 lr: 0.0002 img/s: 854.6256384868216 loss: 0.8801 (0.8570) acc1: 68.7500 (69.0182) acc5: 86.3281 (87.1521) time: 0.5998 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [ 300/2503] eta: 0:22:36 lr: 0.0002 img/s: 855.2042203405152 loss: 0.8260 (0.8538) acc1: 69.3359 (69.0959) acc5: 87.6953 (87.2262) time: 0.5990 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [ 400/2503] eta: 0:21:25 lr: 0.0002 img/s: 856.2763231713128 loss: 0.8881 (0.8553) acc1: 68.3594 (69.1733) acc5: 86.9141 (87.2141) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [3] [ 500/2503] eta: 0:20:19 lr: 0.0002 img/s: 856.0431919432866 loss: 0.8596 (0.8573) acc1: 68.3594 (69.1281) acc5: 87.3047 (87.2291) time: 0.5982 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [ 600/2503] eta: 0:19:15 lr: 0.0002 img/s: 855.2682527981125 loss: 0.8779 (0.8592) acc1: 68.1641 (69.1153) acc5: 86.7188 (87.2007) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [3] [ 700/2503] eta: 0:18:12 lr: 0.0002 img/s: 855.3551206587658 loss: 0.8727 (0.8601) acc1: 68.7500 (69.0902) acc5: 87.8906 (87.2033) time: 0.5988 data: 0.0003 max mem: 19119\n",
"Epoch: [3] [ 800/2503] eta: 0:17:10 lr: 0.0002 img/s: 854.5804059835012 loss: 0.8775 (0.8608) acc1: 69.1406 (69.0684) acc5: 86.5234 (87.1713) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [ 900/2503] eta: 0:16:08 lr: 0.0002 img/s: 855.525160200547 loss: 0.8299 (0.8601) acc1: 69.5312 (69.0866) acc5: 87.5000 (87.1883) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [1000/2503] eta: 0:15:07 lr: 0.0002 img/s: 855.1732293498484 loss: 0.8740 (0.8600) acc1: 68.3594 (69.0608) acc5: 86.7188 (87.1773) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [1100/2503] eta: 0:14:06 lr: 0.0002 img/s: 855.7201584498974 loss: 0.8490 (0.8600) acc1: 69.5312 (69.0574) acc5: 87.6953 (87.1810) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [3] [1200/2503] eta: 0:13:05 lr: 0.0002 img/s: 855.6761738093758 loss: 0.8551 (0.8598) acc1: 70.1172 (69.0749) acc5: 87.3047 (87.1956) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [1300/2503] eta: 0:12:04 lr: 0.0002 img/s: 855.1391759063549 loss: 0.8736 (0.8596) acc1: 68.9453 (69.1111) acc5: 87.5000 (87.2038) time: 0.5985 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [1400/2503] eta: 0:11:04 lr: 0.0002 img/s: 855.9558431039926 loss: 0.8849 (0.8602) acc1: 69.1406 (69.1073) acc5: 86.3281 (87.2021) time: 0.5989 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [1500/2503] eta: 0:10:03 lr: 0.0002 img/s: 856.2879318505251 loss: 0.8493 (0.8600) acc1: 69.7266 (69.1198) acc5: 87.3047 (87.2114) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [3] [1600/2503] eta: 0:09:03 lr: 0.0002 img/s: 855.3885098640291 loss: 0.8944 (0.8605) acc1: 67.9688 (69.1188) acc5: 86.5234 (87.2106) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [3] [1700/2503] eta: 0:08:03 lr: 0.0002 img/s: 855.653671788276 loss: 0.8327 (0.8606) acc1: 69.7266 (69.1132) acc5: 87.3047 (87.1988) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [1800/2503] eta: 0:07:02 lr: 0.0002 img/s: 854.6603313202065 loss: 0.8716 (0.8606) acc1: 69.5312 (69.1106) acc5: 87.3047 (87.2095) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [1900/2503] eta: 0:06:02 lr: 0.0002 img/s: 855.8654421819135 loss: 0.8433 (0.8607) acc1: 68.9453 (69.1149) acc5: 86.9141 (87.1973) time: 0.5985 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [2000/2503] eta: 0:05:02 lr: 0.0002 img/s: 858.0228637365412 loss: 0.8635 (0.8613) acc1: 69.1406 (69.0981) acc5: 86.7188 (87.1936) time: 0.5985 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [2100/2503] eta: 0:04:02 lr: 0.0002 img/s: 855.1837864680422 loss: 0.8389 (0.8614) acc1: 69.7266 (69.1067) acc5: 87.3047 (87.2038) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [2200/2503] eta: 0:03:02 lr: 0.0002 img/s: 854.9267436657309 loss: 0.8588 (0.8618) acc1: 69.5312 (69.1018) acc5: 86.9141 (87.2006) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [2300/2503] eta: 0:02:01 lr: 0.0002 img/s: 857.5592770650364 loss: 0.8385 (0.8623) acc1: 69.7266 (69.1041) acc5: 87.6953 (87.1965) time: 0.5985 data: 0.0003 max mem: 19119\n",
"Epoch: [3] [2400/2503] eta: 0:01:01 lr: 0.0002 img/s: 854.3804880688189 loss: 0.8534 (0.8625) acc1: 68.9453 (69.1074) acc5: 87.3047 (87.1914) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [3] [2500/2503] eta: 0:00:01 lr: 0.0002 img/s: 855.3973686621443 loss: 0.8348 (0.8625) acc1: 69.3359 (69.1134) acc5: 87.6953 (87.1933) time: 0.5984 data: 0.0002 max mem: 19119\n",
"Epoch: [3] Total time: 0:25:03\n",
"Test: [ 0/782] eta: 0:13:34 loss: 0.6298 (0.6298) acc1: 84.3750 (84.3750) acc5: 95.3125 (95.3125) time: 1.0412 data: 1.0273 max mem: 19119\n",
"Test: [100/782] eta: 0:00:28 loss: 1.0908 (0.9514) acc1: 75.0000 (76.9957) acc5: 89.0625 (92.2339) time: 0.0397 data: 0.0260 max mem: 19119\n",
"Test: [200/782] eta: 0:00:21 loss: 0.9058 (0.9231) acc1: 73.4375 (76.3137) acc5: 95.3125 (93.2680) time: 0.0258 data: 0.0121 max mem: 19119\n",
"Test: [300/782] eta: 0:00:17 loss: 0.8269 (0.9143) acc1: 79.6875 (76.5988) acc5: 92.1875 (93.5735) time: 0.0356 data: 0.0218 max mem: 19119\n",
"Test: [400/782] eta: 0:00:13 loss: 1.8047 (1.0535) acc1: 60.9375 (73.9207) acc5: 82.8125 (91.8329) time: 0.0270 data: 0.0133 max mem: 19119\n",
"Test: [500/782] eta: 0:00:09 loss: 1.6839 (1.1303) acc1: 59.3750 (72.5050) acc5: 85.9375 (90.7622) time: 0.0334 data: 0.0196 max mem: 19119\n",
"Test: [600/782] eta: 0:00:06 loss: 1.3633 (1.1951) acc1: 64.0625 (71.1704) acc5: 87.5000 (89.9776) time: 0.0256 data: 0.0118 max mem: 19119\n",
"Test: [700/782] eta: 0:00:02 loss: 1.2720 (1.2480) acc1: 71.8750 (70.1632) acc5: 85.9375 (89.2876) time: 0.0280 data: 0.0142 max mem: 19119\n",
"Test: Total time: 0:00:26\n",
"Test: Acc@1 70.092 Acc@5 89.308\n",
"Epoch: [4] [ 0/2503] eta: 3:53:04 lr: 0.0002 img/s: 868.854963030395 loss: 0.9245 (0.9245) acc1: 67.5781 (67.5781) acc5: 87.3047 (87.3047) time: 5.5871 data: 4.9977 max mem: 19119\n",
"Epoch: [4] [ 100/2503] eta: 0:25:58 lr: 0.0002 img/s: 856.1991675966316 loss: 0.8765 (0.8657) acc1: 68.9453 (69.2605) acc5: 86.5234 (87.1171) time: 0.5978 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [ 200/2503] eta: 0:23:57 lr: 0.0002 img/s: 853.9622551773969 loss: 0.8653 (0.8621) acc1: 69.7266 (69.4321) acc5: 87.3047 (87.2464) time: 0.5998 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [ 300/2503] eta: 0:22:37 lr: 0.0002 img/s: 855.0397553711638 loss: 0.8689 (0.8659) acc1: 69.3359 (69.2879) acc5: 87.3047 (87.2275) time: 0.5990 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [ 400/2503] eta: 0:21:26 lr: 0.0002 img/s: 857.3206551376953 loss: 0.8390 (0.8645) acc1: 68.7500 (69.3432) acc5: 86.5234 (87.2808) time: 0.5983 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [ 500/2503] eta: 0:20:19 lr: 0.0002 img/s: 856.0698095863008 loss: 0.8473 (0.8638) acc1: 69.5312 (69.3262) acc5: 87.5000 (87.3148) time: 0.5984 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [ 600/2503] eta: 0:19:15 lr: 0.0002 img/s: 856.6852065793925 loss: 0.8467 (0.8628) acc1: 69.3359 (69.3753) acc5: 87.3047 (87.3388) time: 0.5985 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [ 700/2503] eta: 0:18:12 lr: 0.0002 img/s: 856.1111043339286 loss: 0.8966 (0.8657) acc1: 68.1641 (69.3106) acc5: 86.7188 (87.2924) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [ 800/2503] eta: 0:17:10 lr: 0.0002 img/s: 856.0469456131993 loss: 0.8508 (0.8662) acc1: 69.3359 (69.2959) acc5: 87.6953 (87.2976) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [ 900/2503] eta: 0:16:08 lr: 0.0002 img/s: 854.9645243753337 loss: 0.8663 (0.8662) acc1: 70.3125 (69.2997) acc5: 87.6953 (87.2860) time: 0.5986 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [1000/2503] eta: 0:15:07 lr: 0.0002 img/s: 855.0186485068215 loss: 0.8554 (0.8671) acc1: 68.3594 (69.2827) acc5: 86.9141 (87.2858) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [1100/2503] eta: 0:14:06 lr: 0.0002 img/s: 855.2873281496619 loss: 0.8594 (0.8670) acc1: 69.7266 (69.2912) acc5: 87.5000 (87.2926) time: 0.5985 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [1200/2503] eta: 0:13:05 lr: 0.0002 img/s: 856.2193086400064 loss: 0.8768 (0.8677) acc1: 68.7500 (69.2626) acc5: 86.7188 (87.2596) time: 0.5982 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [1300/2503] eta: 0:12:04 lr: 0.0002 img/s: 856.0534293024039 loss: 0.8824 (0.8688) acc1: 68.7500 (69.2354) acc5: 87.1094 (87.2479) time: 0.5985 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [1400/2503] eta: 0:11:04 lr: 0.0002 img/s: 857.5664685962558 loss: 0.8628 (0.8691) acc1: 68.7500 (69.2155) acc5: 87.5000 (87.2509) time: 0.5984 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [1500/2503] eta: 0:10:03 lr: 0.0002 img/s: 854.9161928928554 loss: 0.8838 (0.8696) acc1: 69.1406 (69.1992) acc5: 87.5000 (87.2396) time: 0.5983 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [1600/2503] eta: 0:09:03 lr: 0.0002 img/s: 858.2563886134587 loss: 0.8541 (0.8697) acc1: 68.5547 (69.1823) acc5: 87.3047 (87.2258) time: 0.5985 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [1700/2503] eta: 0:08:03 lr: 0.0002 img/s: 855.3152614496619 loss: 0.8702 (0.8694) acc1: 68.9453 (69.1954) acc5: 86.9141 (87.2312) time: 0.5988 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [1800/2503] eta: 0:07:02 lr: 0.0002 img/s: 856.2275018779335 loss: 0.8691 (0.8696) acc1: 68.9453 (69.1939) acc5: 87.1094 (87.2242) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [1900/2503] eta: 0:06:02 lr: 0.0002 img/s: 854.6021714024656 loss: 0.8867 (0.8703) acc1: 68.3594 (69.1812) acc5: 86.5234 (87.2186) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [4] [2000/2503] eta: 0:05:02 lr: 0.0002 img/s: 856.027153906284 loss: 0.8680 (0.8710) acc1: 67.9688 (69.1654) acc5: 87.1094 (87.2148) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [2100/2503] eta: 0:04:02 lr: 0.0002 img/s: 854.944442321167 loss: 0.8930 (0.8715) acc1: 68.9453 (69.1539) acc5: 87.1094 (87.2134) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [2200/2503] eta: 0:03:02 lr: 0.0002 img/s: 857.4733304797255 loss: 0.8344 (0.8713) acc1: 69.7266 (69.1614) acc5: 87.6953 (87.2185) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [2300/2503] eta: 0:02:01 lr: 0.0002 img/s: 854.5545609884018 loss: 0.8644 (0.8712) acc1: 69.1406 (69.1611) acc5: 86.9141 (87.2154) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [2400/2503] eta: 0:01:01 lr: 0.0002 img/s: 855.9186570165338 loss: 0.8843 (0.8714) acc1: 69.1406 (69.1609) acc5: 86.9141 (87.2155) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [4] [2500/2503] eta: 0:00:01 lr: 0.0002 img/s: 855.5220927564314 loss: 0.8757 (0.8719) acc1: 68.9453 (69.1569) acc5: 87.1094 (87.2136) time: 0.5979 data: 0.0002 max mem: 19119\n",
"Epoch: [4] Total time: 0:25:03\n",
"Test: [ 0/782] eta: 0:15:24 loss: 0.5899 (0.5899) acc1: 85.9375 (85.9375) acc5: 95.3125 (95.3125) time: 1.1827 data: 1.1689 max mem: 19119\n",
"Test: [100/782] eta: 0:00:27 loss: 1.0514 (0.9553) acc1: 76.5625 (76.7327) acc5: 89.0625 (92.1411) time: 0.0322 data: 0.0185 max mem: 19119\n",
"Test: [200/782] eta: 0:00:20 loss: 0.8755 (0.9239) acc1: 75.0000 (76.3682) acc5: 95.3125 (93.2369) time: 0.0254 data: 0.0116 max mem: 19119\n",
"Test: [300/782] eta: 0:00:16 loss: 0.7986 (0.9160) acc1: 78.1250 (76.6352) acc5: 92.1875 (93.5424) time: 0.0298 data: 0.0161 max mem: 19119\n",
"Test: [400/782] eta: 0:00:12 loss: 1.7921 (1.0555) acc1: 60.9375 (73.8817) acc5: 84.3750 (91.8095) time: 0.0308 data: 0.0171 max mem: 19119\n",
"Test: [500/782] eta: 0:00:09 loss: 1.7681 (1.1332) acc1: 59.3750 (72.4613) acc5: 84.3750 (90.7248) time: 0.0302 data: 0.0164 max mem: 19119\n",
"Test: [600/782] eta: 0:00:06 loss: 1.3149 (1.1978) acc1: 65.6250 (71.1340) acc5: 85.9375 (89.9880) time: 0.0445 data: 0.0307 max mem: 19119\n",
"Test: [700/782] eta: 0:00:02 loss: 1.2842 (1.2500) acc1: 70.3125 (70.1297) acc5: 87.5000 (89.2899) time: 0.0292 data: 0.0154 max mem: 19119\n",
"Test: Total time: 0:00:25\n",
"Test: Acc@1 70.056 Acc@5 89.296\n",
"Epoch: [5] [ 0/2503] eta: 3:36:03 lr: 0.0001 img/s: 868.6398787818482 loss: 0.9304 (0.9304) acc1: 68.1641 (68.1641) acc5: 84.5703 (84.5703) time: 5.1790 data: 4.5895 max mem: 19119\n",
"Epoch: [5] [ 100/2503] eta: 0:26:17 lr: 0.0001 img/s: 856.2223810858537 loss: 0.8658 (0.8735) acc1: 68.7500 (69.3843) acc5: 86.5234 (87.1229) time: 0.5978 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [ 200/2503] eta: 0:24:06 lr: 0.0001 img/s: 854.6059124455048 loss: 0.8757 (0.8726) acc1: 68.3594 (69.2893) acc5: 87.3047 (87.2027) time: 0.5998 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [ 300/2503] eta: 0:22:42 lr: 0.0001 img/s: 854.7457148952566 loss: 0.9211 (0.8775) acc1: 68.3594 (69.1964) acc5: 86.5234 (87.1548) time: 0.5994 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [ 400/2503] eta: 0:21:30 lr: 0.0001 img/s: 855.9967856527867 loss: 0.8790 (0.8763) acc1: 69.5312 (69.2220) acc5: 87.1094 (87.1795) time: 0.5981 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [ 500/2503] eta: 0:20:22 lr: 0.0001 img/s: 856.1715178770771 loss: 0.8828 (0.8767) acc1: 68.5547 (69.1835) acc5: 87.1094 (87.1916) time: 0.5981 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [ 600/2503] eta: 0:19:17 lr: 0.0001 img/s: 857.1064536316845 loss: 0.8499 (0.8745) acc1: 70.1172 (69.2495) acc5: 87.6953 (87.2351) time: 0.5985 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [ 700/2503] eta: 0:18:14 lr: 0.0001 img/s: 855.5684475906053 loss: 0.8724 (0.8738) acc1: 68.7500 (69.2543) acc5: 87.3047 (87.2381) time: 0.5986 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [ 800/2503] eta: 0:17:11 lr: 0.0001 img/s: 855.3483068555046 loss: 0.8648 (0.8747) acc1: 69.3359 (69.2645) acc5: 87.1094 (87.2354) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [ 900/2503] eta: 0:16:10 lr: 0.0001 img/s: 855.0516709967052 loss: 0.8869 (0.8744) acc1: 68.9453 (69.2787) acc5: 87.1094 (87.2609) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [1000/2503] eta: 0:15:08 lr: 0.0001 img/s: 856.4016452607827 loss: 0.8660 (0.8746) acc1: 69.5312 (69.2706) acc5: 87.3047 (87.2602) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [1100/2503] eta: 0:14:07 lr: 0.0001 img/s: 855.5752649016649 loss: 0.8801 (0.8741) acc1: 69.3359 (69.2911) acc5: 86.7188 (87.2653) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [1200/2503] eta: 0:13:06 lr: 0.0001 img/s: 857.7295089955442 loss: 0.9027 (0.8741) acc1: 68.1641 (69.2950) acc5: 86.7188 (87.2543) time: 0.5983 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [1300/2503] eta: 0:12:05 lr: 0.0001 img/s: 857.4562117036577 loss: 0.8591 (0.8742) acc1: 69.3359 (69.2867) acc5: 87.3047 (87.2523) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [1400/2503] eta: 0:11:04 lr: 0.0001 img/s: 856.1711765336744 loss: 0.8659 (0.8747) acc1: 69.7266 (69.2756) acc5: 87.1094 (87.2657) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [1500/2503] eta: 0:10:04 lr: 0.0001 img/s: 855.6918577361109 loss: 0.8848 (0.8744) acc1: 67.5781 (69.2883) acc5: 86.9141 (87.2744) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [1600/2503] eta: 0:09:03 lr: 0.0001 img/s: 856.2199914038447 loss: 0.8476 (0.8740) acc1: 69.1406 (69.2962) acc5: 87.5000 (87.2808) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [1700/2503] eta: 0:08:03 lr: 0.0001 img/s: 856.7518536501099 loss: 0.8676 (0.8745) acc1: 68.3594 (69.2823) acc5: 87.6953 (87.2798) time: 0.5986 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [1800/2503] eta: 0:07:03 lr: 0.0001 img/s: 855.5047109885885 loss: 0.8884 (0.8745) acc1: 68.1641 (69.2810) acc5: 87.3047 (87.2716) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [1900/2503] eta: 0:06:02 lr: 0.0001 img/s: 855.3728370553123 loss: 0.8565 (0.8744) acc1: 68.9453 (69.2820) acc5: 88.2812 (87.2709) time: 0.5987 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [2000/2503] eta: 0:05:02 lr: 0.0001 img/s: 856.9477535227904 loss: 0.8799 (0.8749) acc1: 68.7500 (69.2718) acc5: 86.7188 (87.2586) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [2100/2503] eta: 0:04:02 lr: 0.0001 img/s: 855.2914158356967 loss: 0.8914 (0.8745) acc1: 68.9453 (69.2829) acc5: 87.5000 (87.2690) time: 0.5987 data: 0.0003 max mem: 19119\n",
"Epoch: [5] [2200/2503] eta: 0:03:02 lr: 0.0001 img/s: 855.3816955287992 loss: 0.8659 (0.8744) acc1: 68.9453 (69.2853) acc5: 87.1094 (87.2749) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [2300/2503] eta: 0:02:02 lr: 0.0001 img/s: 857.8035141651501 loss: 0.8525 (0.8745) acc1: 69.7266 (69.2810) acc5: 87.5000 (87.2740) time: 0.5986 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [2400/2503] eta: 0:01:01 lr: 0.0001 img/s: 855.2345323832691 loss: 0.8424 (0.8745) acc1: 69.5312 (69.2709) acc5: 87.3047 (87.2737) time: 0.5988 data: 0.0002 max mem: 19119\n",
"Epoch: [5] [2500/2503] eta: 0:00:01 lr: 0.0001 img/s: 856.8721864153102 loss: 0.8785 (0.8748) acc1: 68.7500 (69.2651) acc5: 87.5000 (87.2714) time: 0.5982 data: 0.0002 max mem: 19119\n",
"Epoch: [5] Total time: 0:25:04\n",
"Test: [ 0/782] eta: 0:16:48 loss: 0.6137 (0.6137) acc1: 85.9375 (85.9375) acc5: 95.3125 (95.3125) time: 1.2890 data: 1.2749 max mem: 19119\n",
"Test: [100/782] eta: 0:00:29 loss: 1.0820 (0.9476) acc1: 76.5625 (77.0885) acc5: 89.0625 (92.3113) time: 0.0337 data: 0.0200 max mem: 19119\n",
"Test: [200/782] eta: 0:00:21 loss: 0.8791 (0.9212) acc1: 75.0000 (76.4537) acc5: 95.3125 (93.3613) time: 0.0276 data: 0.0139 max mem: 19119\n",
"Test: [300/782] eta: 0:00:16 loss: 0.8066 (0.9144) acc1: 76.5625 (76.6923) acc5: 92.1875 (93.6306) time: 0.0281 data: 0.0144 max mem: 19119\n",
"Test: [400/782] eta: 0:00:13 loss: 1.8165 (1.0555) acc1: 60.9375 (73.9596) acc5: 84.3750 (91.9031) time: 0.0350 data: 0.0214 max mem: 19119\n",
"Test: [500/782] eta: 0:00:09 loss: 1.7107 (1.1325) acc1: 59.3750 (72.5299) acc5: 84.3750 (90.7934) time: 0.0344 data: 0.0206 max mem: 19119\n",
"Test: [600/782] eta: 0:00:06 loss: 1.3799 (1.1970) acc1: 64.0625 (71.2068) acc5: 84.3750 (89.9880) time: 0.0266 data: 0.0127 max mem: 19119\n",
"Test: [700/782] eta: 0:00:02 loss: 1.2741 (1.2493) acc1: 68.7500 (70.1966) acc5: 85.9375 (89.2765) time: 0.0262 data: 0.0124 max mem: 19119\n",
"Test: Total time: 0:00:25\n",
"Test: Acc@1 70.120 Acc@5 89.284\n",
"Training time 2:33:01\n"
]
}
],
"source": [
"from types import SimpleNamespace\n",
"\n",
"args = SimpleNamespace(\n",
" data_path=\"/home/cs/Documents/datasets/imagenet\", # Replace with your /path/to/imagenet\n",
" model=\"resnet18\",\n",
" device=\"cuda\",\n",
" batch_size=512,\n",
" epochs=6,\n",
" lr=0.0004,\n",
" momentum=0.9,\n",
" weight_decay=1e-4,\n",
" lr_warmup_epochs=1,\n",
" lr_warmup_decay=0.0,\n",
" lr_step_size=2,\n",
" lr_gamma=0.5,\n",
" print_freq=100,\n",
" output_dir=\"resnet18\",\n",
" use_deterministic_algorithms=False,\n",
" weights=\"ResNet18_Weights.IMAGENET1K_V1\",\n",
" apply_trp=True,\n",
" trp_depths=[3, 3, 3],\n",
" trp_planes=256,\n",
" trp_lambdas=[0.4, 0.2, 0.1],\n",
")\n",
"\n",
"main(args)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.21"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
|