Download the TorchVision reference scripts for image classification. These are the official reference implementations from TorchVision; they provide training and quantization utilities for image classification models.

In [1]:
! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/presets.py
! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/sampler.py
! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train.py
! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train_quantization.py
! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/transforms.py
! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/utils.py

--2025-08-04 11:06:42--  https://raw.githubusercontent.com/pytorch/vision/main/references/classification/presets.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.


HTTP request sent, awaiting response... 200 OK
Length: 3885 (3.8K) [text/plain]
Saving to: ‘presets.py.3’


2025-08-04 11:06:42 (17.6 MB/s) - ‘presets.py.3’ saved [3885/3885]

--2025-08-04 11:06:42--  https://raw.githubusercontent.com/pytorch/vision/main/references/classification/sampler.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2395 (2.3K) [text/plain]
Saving to: ‘sampler.py.3’


2025-08-04 11:06:42 (23.5 MB/s) - ‘sampler.py.3’ saved [2395/2395]

--2025-08-04 11:06:43--  https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185

In this block, we build a “loss” function for our sequential policy gradient algorithm. When the right data is plugged in, the gradient of this loss is equal to the policy gradient.

In [2]:
import types
from typing import Optional, List, Union, Callable

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from torchvision.models.resnet import ResNet


def compute_policy_loss(loss_sequence, mask_sequence, rewards):
    losses = sum(mask * padded_loss for mask, padded_loss in zip(mask_sequence, loss_sequence))
    returns = sum(padded_reward * mask for padded_reward, mask in zip(rewards, mask_sequence))
    loss = torch.mean(losses * returns)

    return loss
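
To make the weighting concrete, here is a toy evaluation of `compute_policy_loss` (redefined so the snippet runs on its own). The loss values, masks, and rewards are made-up illustration values, not model outputs: two samples, a base prediction plus one replica, with the replica masked out for the second sample.

```python
import torch


def compute_policy_loss(loss_sequence, mask_sequence, rewards):
    # Same definition as above, repeated so this snippet is self-contained.
    losses = sum(mask * padded_loss for mask, padded_loss in zip(mask_sequence, loss_sequence))
    returns = sum(padded_reward * mask for padded_reward, mask in zip(rewards, mask_sequence))
    return torch.mean(losses * returns)


# Illustrative per-sample losses for the base prediction and one replica.
loss_sequence = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
# The base prediction always counts; the replica counts only for sample 0.
mask_sequence = [torch.ones(2), torch.tensor([1.0, 0.0])]
rewards = [1.0, 0.5]

# Per-sample masked losses: [1 + 3, 2 + 0] = [4, 2]
# Per-sample returns:       [1.0 + 0.5, 1.0 + 0.0] = [1.5, 1.0]
toy_loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
print(toy_loss)  # mean([4 * 1.5, 2 * 1.0]) = 4.0
```

Because `zip` stops at the shortest input, any extra trailing entry in `mask_sequence` is ignored.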


In this block, we build a TPBlock for the Task Replica Prediction (TRP) module. This implementation provides the backbone without the shared prediction head.

In [3]:
class TPBlock(nn.Module):
    def __init__(self, depths: int, in_planes: int, out_planes: Optional[int] = None, rank=1, shape_dims=3, channel_first=True, dtype=torch.float32) -> None:
        super().__init__()
        out_planes = in_planes if out_planes is None else out_planes
        self.layers = torch.nn.ModuleList([self._make_layer(in_planes, out_planes, rank, shape_dims, channel_first, dtype) for _ in range(depths)])

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers:
            x = x + layer(x)
        return x

    def _make_layer(self, in_planes: int, out_planes: Optional[int] = None, rank=1, shape_dims=3, channel_first=True, dtype=torch.float32) -> nn.Sequential:

        class Permute(nn.Module):
            def __init__(self, *dims):
                super().__init__()
                self.dims = dims
            def forward(self, x):
                return x.permute(*self.dims)

        class RMSNorm(nn.Module):
            __constants__ = ["eps"]
            eps: float

            def __init__(self, hidden_size, eps: float = 1e-6, device=None, dtype=None):
                """
                RMSNorm over the channel dimension (dim=1) of a channel-first tensor.
                Adapted from LlamaRMSNorm, which is equivalent to T5LayerNorm.
                """
                factory_kwargs = {"device": device, "dtype": dtype}
                super().__init__()
                self.eps = eps
                self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs))

            def forward(self, hidden_states):
                input_dtype = hidden_states.dtype
                hidden_states = hidden_states.to(torch.float32)
                variance = hidden_states.pow(2).mean(dim=1, keepdim=True)
                hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
                weight = self.weight.view(1, -1, *[1] * (hidden_states.ndim - 2))
                return weight * hidden_states.to(input_dtype)

            def extra_repr(self):
                return f"{self.weight.shape[0]}, eps={self.eps}"

        conv_map = {
            2: (nn.Conv1d, (0, 2, 1), (0, 2, 1)),
            3: (nn.Conv2d, (0, 3, 1, 2), (0, 2, 3, 1)),
            4: (nn.Conv3d, (0, 4, 1, 2, 3), (0, 2, 3, 4, 1)),
        }
        Conv, pre_dims, post_dims = conv_map[shape_dims]
        kernel_size, dilation, padding = self.generate_hyperparameters(rank)

        pre_permute = nn.Identity() if channel_first else Permute(*pre_dims)
        post_permute = nn.Identity() if channel_first else Permute(*post_dims)
        # Modules are created on the default device; `model.to(device)` in main() moves them with the rest of the model.
        conv1 = Conv(in_planes, out_planes, kernel_size, padding=padding, dilation=dilation, bias=False, dtype=dtype)
        nn.init.zeros_(conv1.weight)
        bn1 = RMSNorm(out_planes, dtype=dtype)
        relu = nn.ReLU(inplace=True)
        conv2 = Conv(out_planes, in_planes, kernel_size, padding=padding, dilation=dilation, bias=False, dtype=dtype)
        nn.init.zeros_(conv2.weight)
        bn2 = RMSNorm(in_planes, dtype=dtype)

        return torch.nn.Sequential(pre_permute, conv1, bn1, relu, conv2, bn2, relu, post_permute)

    @staticmethod
    def generate_hyperparameters(rank: int):
        """
        Returns the (kernel_size, dilation, padding) triple selected by `rank`.

        Candidate triples are enumerated in order of increasing padded kernel size,
        starting from the 1x1 kernel (1, 1, 0); the `rank`-th candidate is returned.

        Args:
            rank: 1-based index of the triple to return. Values <= 1 return (1, 1, 0).

        Returns:
            Tuple[int, int, int]: A (kernel_size, dilation, padding) tuple where:
                - kernel_size: Always odd and >= 1
                - dilation: Chosen so that (kernel_size - 1) * dilation + 1 equals the padded kernel size
                - padding: dilation * (kernel_size - 1) // 2, which preserves the spatial size at stride 1

        Note:
            Padded kernel size is calculated as:
                (kernel_size - 1) * dilation + 1
            Triples are generated first in order of increasing padded kernel size,
            then by increasing kernel size for equal padded kernel sizes.
        """
        pairs = [(1, 1, 0)]  # Start with smallest possible
        padded_kernel_size = 3

        while len(pairs) < rank:
            for kernel_size in range(3, padded_kernel_size + 1, 2):
                if (padded_kernel_size - 1) % (kernel_size - 1) == 0:
                    dilation = (padded_kernel_size - 1) // (kernel_size - 1)
                    padding = dilation * (kernel_size - 1) // 2
                    pairs.append((kernel_size, dilation, padding))
                    if len(pairs) >= rank:
                        break

            # Move to next odd padded kernel size
            padded_kernel_size += 2

        return pairs[-1]
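
As a quick illustration of the enumeration order, the sketch below tabulates the triples for the first few ranks. It uses a standalone copy of `generate_hyperparameters` (module-level rather than a static method) so it runs without the rest of the class.

```python
def generate_hyperparameters(rank: int):
    # Standalone copy of TPBlock.generate_hyperparameters, for illustration only.
    pairs = [(1, 1, 0)]
    padded_kernel_size = 3
    while len(pairs) < rank:
        for kernel_size in range(3, padded_kernel_size + 1, 2):
            if (padded_kernel_size - 1) % (kernel_size - 1) == 0:
                dilation = (padded_kernel_size - 1) // (kernel_size - 1)
                padding = dilation * (kernel_size - 1) // 2
                pairs.append((kernel_size, dilation, padding))
                if len(pairs) >= rank:
                    break
        padded_kernel_size += 2
    return pairs[-1]


table = {rank: generate_hyperparameters(rank) for rank in range(6)}
for rank, (kernel_size, dilation, padding) in table.items():
    print(f"rank={rank}: kernel_size={kernel_size}, dilation={dilation}, padding={padding}")
# rank=0: kernel_size=1, dilation=1, padding=0
# rank=1: kernel_size=1, dilation=1, padding=0
# rank=2: kernel_size=3, dilation=1, padding=1
# rank=3: kernel_size=3, dilation=2, padding=2
# rank=4: kernel_size=5, dilation=1, padding=2
# rank=5: kernel_size=3, dilation=3, padding=3
```

Ranks 0 and 1 both select the 1x1 kernel, so a caller that passes a 0-based index as `rank` gets 1x1 convolutions for its first two blocks.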

This implementation enables ResNet retraining in SPG mode.

Components:
-------------------------------------------------------------------------------
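1. gen_shared_head()
    - Purpose: Builds the shared prediction head (average pooling, flatten, fully connected layer) that maps convolutional feature maps to logits.

2. gen_logits()
    - Purpose: Applies the shared head to the backbone features and to the output of each TRP block, producing a sequence of logits.

3. gen_mask()
    - Purpose: Computes cumulative top-k correctness masks, so that a replica contributes only while all earlier predictions were correct.

4. gen_criterion()
    - Purpose: Computes per-sample cross-entropy losses for each element of the logits sequence.

5. gen_forward()
    - Purpose: Extended forward pass supporting both standard inference and SPG retraining.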

In [4]:
class ResNetConfig:
    @staticmethod
    def gen_shared_head(self):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].

            Returns:
                logits (Tensor): Logits tensor of shape [B, C].
            """
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.fc(x)
            return logits
        return func

    @staticmethod
    def gen_logits(self, shared_head):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden states tensor of shape [B, C, H, W].

            Returns:
                logits_sequence (List[Tensor]): List of logits tensors, one for the backbone and one per TRP block.
            """
            logits_sequence = [shared_head(hidden_states)]
            for layer in self.trp_blocks:
                logits_sequence.append(shared_head(layer(hidden_states)))
            return logits_sequence
        return func

    @staticmethod
    def gen_mask(label_smoothing=0.0, top_k=1):
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of Logits tensors.
                labels (Tensor): Target labels of shape [B] or [B, C].

            Returns:
                mask_sequence (List[Tensor]): List of float mask tensors of shape [B]. The first is all ones;
                    each later mask is the running product of top-k correctness over the preceding logits.
            """
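            # Convert soft / one-hot targets of shape [B, C] to class indices; integer targets of shape [B] pass through.
            labels = torch.argmax(labels, dim=1) if labels.ndim > 1 else labels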

            mask_sequence = [torch.ones_like(labels, dtype=torch.float32, device=labels.device)]
            for logits in logits_sequence:
                with torch.no_grad():
                    topk_values, topk_indices = torch.topk(logits, top_k, dim=-1)
                    mask = torch.eq(topk_indices, labels[:, None]).any(dim=-1).to(torch.float32)
                    mask_sequence.append(mask_sequence[-1] * mask)
            return mask_sequence
        return func

    @staticmethod
    def gen_criterion(label_smoothing=0.0):
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of logits tensors.
                labels (Tensor): Target labels of shape [B] or [B, C].

            Returns:
                loss_sequence (List[Tensor]): List of per-sample loss tensors of shape [B].
            """
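            # Convert soft / one-hot targets of shape [B, C] to class indices; integer targets of shape [B] pass through.
            labels = torch.argmax(labels, dim=1) if labels.ndim > 1 else labels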

            loss_sequence = []
            for logits in logits_sequence:
                loss_sequence.append(F.cross_entropy(logits, labels, reduction="none", label_smoothing=label_smoothing))

            return loss_sequence
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        def func(self, x: Tensor, targets=None) -> Tensor:
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            hidden_states = self.layer4(x)
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.fc(x)

            if self.training:
                shared_head = ResNetConfig.gen_shared_head(self)
                compute_logits = ResNetConfig.gen_logits(self, shared_head)
                compute_mask = ResNetConfig.gen_mask(label_smoothing, top_k)
                compute_loss = ResNetConfig.gen_criterion(label_smoothing)

                logits_sequence = compute_logits(hidden_states)
                mask_sequence = compute_mask(logits_sequence, targets)
                loss_sequence = compute_loss(logits_sequence, targets)
                loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)

                return logits, loss

            return logits

        return func
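
The cumulative masking in `gen_mask` can be checked on toy inputs. The sketch below uses a hypothetical standalone helper, `cumulative_topk_masks`, that mirrors the loop in `gen_mask`; the logits and labels are made-up values chosen so that one sample fails at each step.

```python
import torch


def cumulative_topk_masks(logits_sequence, labels, top_k=1):
    # Hypothetical standalone helper mirroring the loop in gen_mask.
    masks = [torch.ones_like(labels, dtype=torch.float32)]
    with torch.no_grad():
        for logits in logits_sequence:
            topk_indices = torch.topk(logits, top_k, dim=-1).indices
            correct = torch.eq(topk_indices, labels[:, None]).any(dim=-1).to(torch.float32)
            # Once a sample is misclassified, every later mask stays at zero for it.
            masks.append(masks[-1] * correct)
    return masks


labels = torch.tensor([0, 1, 2])
logits_sequence = [
    torch.tensor([[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [5.0, 0.0, 0.0]]),  # sample 2 is wrong
    torch.tensor([[5.0, 0.0, 0.0], [0.0, 0.0, 5.0], [0.0, 0.0, 5.0]]),  # sample 1 is wrong
]
masks = cumulative_topk_masks(logits_sequence, labels)
print([m.tolist() for m in masks])
# [[1.0, 1.0, 1.0], [1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
```

Sample 2 is predicted correctly by the second set of logits, yet its mask stays at zero because it was already wrong at the first step.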

Applies TRP modules to the base ResNet (main backbone). The k-th TRP module corresponds to a deeper ResNet variant with an additional depth of 3 * sum(depths[:k+1]).

In [5]:
def apply_trp(model, depths: List[int], in_planes: int, out_planes: int, rewards, **kwargs):
    print("✅ Applying TRP to ResNet for Image Classification...")
    model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
    model.forward = types.MethodType(ResNetConfig.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
    return model
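
`apply_trp` replaces `forward` on one model instance by binding a closure with `types.MethodType`. The mechanism is plain Python and can be seen without PyTorch; `TinyModel` and `make_forward` below are hypothetical names used only for this sketch.

```python
import types


class TinyModel:
    def forward(self, x):
        return x + 1


def make_forward(scale):
    # Factory returning a replacement forward; `scale` is captured by the closure,
    # the same way gen_forward captures rewards, label_smoothing, and top_k.
    def forward(self, x):
        return scale * x
    return forward


patched = TinyModel()
untouched = TinyModel()
# Bind the new function to this one instance; the class itself is not modified.
patched.forward = types.MethodType(make_forward(10), patched)

print(patched.forward(2))    # 20: uses the replacement
print(untouched.forward(2))  # 3: class-level forward is unchanged
```

Because the override lives on the instance, other models of the same class keep the original `forward`.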

The following is a training script for classification models, primarily based on the official TorchVision `train.py` reference implementation. We have made two modifications:

Adding TRP Modules: We integrate TRP modules into the base model architecture before training begins:

```python
if args.apply_trp:
    model = apply_trp(model, args.trp_depths, args.in_planes, args.out_planes, args.trp_rewards, label_smoothing=args.label_smoothing)
```

Removing TRP Modules: We filter the TRP parameters out of the state dict before saving, so the checkpoint contains only the base model:
```python
if args.output_dir:
    checkpoint = {
        "model": model_without_ddp.state_dict() if not args.apply_trp else {k: v for k, v in model_without_ddp.state_dict().items() if not "trp_blocks" in k},
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "epoch": epoch,
        "args": args,
    }
    if model_ema:
        checkpoint["model_ema"] = model_ema.state_dict() if not args.apply_trp else {k: v for k, v in model_ema.state_dict().items() if not "trp_blocks" in k}
    if scaler:
        checkpoint["scaler"] = scaler.state_dict()
    utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
    utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
```
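
The effect of the key filter can be checked on a small stand-in model. This is a sketch under assumptions: `Base` is a hypothetical two-parameter module, not the ResNet used here, and the extra modules are attached under the `trp_blocks` attribute name to match the filter.

```python
import torch
from torch import nn


class Base(nn.Module):
    # Hypothetical stand-in for the backbone.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)


base_with_trp = Base()
# Attach extra modules under the same attribute name that apply_trp uses.
base_with_trp.trp_blocks = nn.ModuleList([nn.Linear(4, 4)])

# Keep only the parameters that belong to the base model.
filtered = {k: v for k, v in base_with_trp.state_dict().items() if "trp_blocks" not in k}
print(sorted(filtered))  # ['fc.bias', 'fc.weight']

# The filtered checkpoint loads into a plain model with strict key matching.
plain = Base()
result = plain.load_state_dict(filtered, strict=True)
```

A checkpoint saved this way can be loaded by code that has never heard of TRP, at the cost of not being able to resume TRP training from it.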

In [6]:
import datetime
import os
import time
import warnings

import presets
import torch
import torch.utils.data
import torchvision
import utils
from torch import nn
from torchvision.transforms.functional import InterpolationMode


def load_data(traindir, valdir):
    # Data loading code
    print("Loading data")
    interpolation = InterpolationMode("bilinear")

    print("Loading training data")
    st = time.time()
    dataset = torchvision.datasets.ImageFolder(
        traindir,
        presets.ClassificationPresetTrain(crop_size=224, interpolation=interpolation, auto_augment_policy=None, random_erase_prob=0.0, ra_magnitude=9, augmix_severity=3),
    )
    print("Took", time.time() - st)

    print("Loading validation data")
    dataset_test = torchvision.datasets.ImageFolder(
        valdir,
        presets.ClassificationPresetEval(crop_size=224, resize_size=256, interpolation=interpolation)
    )

    print("Creating data loaders")
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler



def train_one_epoch(model, optimizer, data_loader, device, epoch, args):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = f"Epoch: [{epoch}]"
    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        with torch.amp.autocast("cuda", enabled=False):
            output, loss = model(image, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
    # gather the stats from all processes

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg


def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir)

    num_classes = len(dataset.classes)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=16, pin_memory=True, collate_fn=None)
    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=64, sampler=test_sampler, num_workers=16, pin_memory=True)

    print("Creating model")
    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
    if args.apply_trp:
        model = apply_trp(model, args.trp_depths, args.in_planes, args.out_planes, args.trp_rewards, label_smoothing=args.label_smoothing)
    model.to(device)

    parameters = utils.set_weight_decay(model, args.weight_decay, norm_weight_decay=None, custom_keys_weight_decay=None)
    optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)

    main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs)
    lr_scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs])


    print("Start training")
    start_time = time.time()
    for epoch in range(args.epochs):
        train_one_epoch(model, optimizer, data_loader, device, epoch, args)
        lr_scheduler.step()
        evaluate(model, nn.CrossEntropyLoss(), data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                "model": model.state_dict() if not args.apply_trp else {k: v for k, v in model.state_dict().items() if not "trp_blocks" in k},
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
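
The learning-rate schedule built in `main` can be inspected in isolation. The sketch below reuses the scheduler settings from the demo configuration further down (`lr=0.002`, `lr_warmup_decay=0.0`, `lr_warmup_epochs=1`, `lr_step_size=2`, `lr_gamma=0.5`) on a dummy parameter; the dummy parameter and the six-epoch loop are assumptions made for illustration.

```python
import torch

# A single dummy parameter stands in for the model.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.002)

warmup = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.0, total_iters=1)
main_sched = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, main_sched], milestones=[1])

lrs = []
for epoch in range(6):
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used during this epoch
    optimizer.step()   # stands in for one epoch of training
    scheduler.step()   # stepped once per epoch, as in main()

print(lrs)
```

With a warmup factor of 0.0, the first epoch runs at a learning rate of zero, which matches the `lr: 0.0` entries in the epoch-0 log lines of the run below.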


Prepare the [ImageNet](http://image-net.org/) dataset manually and place it in `/path/to/imagenet`. For image classification examples, pass the argument `--data-path=/path/to/imagenet` to the training script. The extracted dataset directory should follow this structure:
```setup
/path/to/imagenet/:
    train/:
        n01440764:
            n01440764_18.JPEG ...
        n01443537:
            n01443537_2.JPEG ...
    val/:
        n01440764:
            ILSVRC2012_val_00000293.JPEG ...
        n01443537:
            ILSVRC2012_val_00000236.JPEG ...
```
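
Before starting a long run, it can help to confirm that the directory matches this layout. The helper below is a hypothetical sketch (`summarize_imagefolder` is not part of TorchVision); it counts class folders and files per split, and is demonstrated on a throwaway directory rather than on real data.

```python
from pathlib import Path
import tempfile


def summarize_imagefolder(root):
    # Hypothetical helper: count class folders and files under train/ and val/.
    summary = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        class_dirs = sorted(p for p in split_dir.iterdir() if p.is_dir())
        num_files = sum(1 for d in class_dirs for f in d.iterdir() if f.is_file())
        summary[split] = {"classes": len(class_dirs), "files": num_files}
    return summary


# Demo on a temporary directory that mimics the layout above.
with tempfile.TemporaryDirectory() as tmp:
    for split, name in (("train", "n01440764_18.JPEG"), ("val", "ILSVRC2012_val_00000293.JPEG")):
        class_dir = Path(tmp) / split / "n01440764"
        class_dir.mkdir(parents=True)
        (class_dir / name).touch()
    demo_summary = summarize_imagefolder(tmp)

print(demo_summary)
# {'train': {'classes': 1, 'files': 1}, 'val': {'classes': 1, 'files': 1}}
```

On a real dataset, call `summarize_imagefolder("/path/to/imagenet")` and check that both splits report the same number of classes.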

Now you can apply the SPG algorithm to model retraining.

**Implementation Note:**

- This demonstration runs on Google Colab using a single-GPU configuration
- Performance improvement: raises ResNet18 validation accuracy (Acc@1) from 69.758% to 70.554%
- For optimal results:
  - Refer to our README.md for complete setup instructions
  - Recommended hardware: 4× RTX A6000 GPUs

In [7]:
from types import SimpleNamespace

args = SimpleNamespace(
    data_path="/home/cs/Documents/datasets/imagenet",  # Replace with your /path/to/imagenet
    model="resnet18",
    device="cuda",
    batch_size=512,
    epochs=16,
    lr=0.002,
    momentum=0.9,
    weight_decay=1e-4,
    label_smoothing=0.0,
    lr_warmup_epochs=1,
    lr_warmup_decay=0.0,
    lr_step_size=2,
    lr_gamma=0.5,
    print_freq=100,
    output_dir="resnet18",
    use_deterministic_algorithms=False,
    weights="ResNet18_Weights.IMAGENET1K_V1",
    apply_trp=True,
    trp_depths=[4, 4, 4],
    in_planes=512,
    out_planes=8,
    trp_rewards=[1.0, 0.4, 0.2, 0.1],
)

main(args)

namespace(data_path='/home/cs/Documents/datasets/imagenet', model='resnet18', device='cuda', batch_size=512, epochs=16, lr=0.002, momentum=0.9, weight_decay=0.0001, label_smoothing=0.0, lr_warmup_epochs=1, lr_warmup_decay=0.0, lr_step_size=2, lr_gamma=0.5, print_freq=100, output_dir='resnet18', use_deterministic_algorithms=False, weights='ResNet18_Weights.IMAGENET1K_V1', apply_trp=True, trp_depths=[4, 4, 4], in_planes=512, out_planes=8, trp_rewards=[1.0, 0.4, 0.2, 0.1])
Loading data
Loading training data


Took 1.663217306137085
Loading validation data
Creating data loaders
Creating model
✅ Applying TRP to ResNet for Image Classification...
Start training
Epoch: [0]  [   0/2503]  eta: 10:08:27  lr: 0.0  img/s: 50.42287085194474  loss: 2.4802 (2.4802)  acc1: 68.1641 (68.1641)  acc5: 88.6719 (88.6719)  time: 14.5854  data: 4.4313  max mem: 14260
Epoch: [0]  [ 100/2503]  eta: 0:20:23  lr: 0.0  img/s: 1383.432089555604  loss: 2.5032 (2.5529)  acc1: 68.9453 (69.0652)  acc5: 87.6953 (87.2351)  time: 0.3696  data: 0.0003  max mem: 14260
Epoch: [0]  [ 200/2503]  eta: 0:16:54  lr: 0.0  img/s: 1378.2442081212737  loss: 2.5111 (2.5476)  acc1: 69.7266 (69.0396)  acc5: 88.0859 (87.3883)  time: 0.3723  data: 0.0003  max mem: 14260
Epoch: [0]  [ 300/2503]  eta: 0:15:21  lr: 0.0  img/s: 1374.120426975581  loss: 2.5257 (2.5469)  acc1: 69.7266 (69.1783)  acc5: 87.6953 (87.3566)  time: 0.3735  data: 0.0003  max mem: 14260
Epoch: [0]  [ 400/2503]  eta: 0:14:16  lr: 0.0  img/s: 1371.8057112267902  loss: 2.56



Test:   [  0/782]  eta: 0:47:20  loss: 0.6410 (0.6410)  acc1: 85.9375 (85.9375)  acc5: 95.3125 (95.3125)  time: 3.6326  data: 3.1976  max mem: 14260
Test:   [100/782]  eta: 0:00:43  loss: 1.0596 (0.9365)  acc1: 76.5625 (76.3150)  acc5: 89.0625 (92.2030)  time: 0.0358  data: 0.0219  max mem: 14260
Test:   [200/782]  eta: 0:00:28  loss: 0.9542 (0.9132)  acc1: 73.4375 (75.7851)  acc5: 96.8750 (93.2214)  time: 0.0266  data: 0.0126  max mem: 14260
Test:   [300/782]  eta: 0:00:21  loss: 0.8381 (0.9042)  acc1: 76.5625 (76.1991)  acc5: 92.1875 (93.5112)  time: 0.0387  data: 0.0248  max mem: 14260
Test:   [400/782]  eta: 0:00:15  loss: 1.6487 (1.0423)  acc1: 59.3750 (73.5817)  acc5: 82.8125 (91.7160)  time: 0.0286  data: 0.0148  max mem: 14260
Test:   [500/782]  eta: 0:00:11  loss: 1.5886 (1.1231)  acc1: 56.2500 (71.9935)  acc5: 84.3750 (90.5845)  time: 0.0256  data: 0.0116  max mem: 14260
Test:   [600/782]  eta: 0:00:06  loss: 1.3772 (1.1848)  acc1: 64.0625 (70.8403)  acc5: 84.3750 (89.7853)  