Download the TorchVision reference scripts for image classification. These scripts are official reference implementations from TorchVision that provide training and quantization utilities for image classification models.

In [1]:
! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/presets.py
! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/sampler.py
! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train.py
! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train_quantization.py
! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/transforms.py
! wget https://raw.githubusercontent.com/pytorch/vision/main/references/classification/utils.py

--2025-05-22 16:30:12--  https://raw.githubusercontent.com/pytorch/vision/main/references/classification/presets.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3885 (3.8K) [text/plain]
Saving to: ‘presets.py’


2025-05-22 16:30:12 (12.8 MB/s) - ‘presets.py’ saved [3885/3885]



--2025-05-22 16:30:12--  https://raw.githubusercontent.com/pytorch/vision/main/references/classification/sampler.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2395 (2.3K) [text/plain]
Saving to: ‘sampler.py’


2025-05-22 16:30:12 (18.4 MB/s) - ‘sampler.py’ saved [2395/2395]

--2025-05-22 16:30:12--  https://raw.githubusercontent.com/pytorch/vision/main/references/classification/train.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 23324 (23K) [text/plain]
Saving to: ‘train.py’


2025-05-22 16:30:13 (2.28 MB/s) - ‘train.py’

In this block, we build a “loss” function for our sequential policy gradient algorithm. When the right data is plugged in, the gradient of this loss is equal to the policy gradient.

In [2]:
import types
from typing import List, Callable

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.models.resnet import BasicBlock


def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_state: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):
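    # Base model: per-sample losses and a 0/1 reward marking correct (top-k) predictions.
    losses, rewards = criterion(logits, targets)
    returns = torch.ones_like(rewards, dtype=torch.float32, device=rewards.device)
    if loss_normalization:
        # Remember the mean base loss so the final loss can be rescaled below.
        coeff = torch.mean(losses).detach()

    embeds = [hidden_state]
    predictions = []
    for k, w in enumerate(lambdas):
        # Replica k: pass the previous embedding through the k-th TRP block and the shared head.
        embeds.append(trp_blocks[k](embeds[-1]))
        predictions.append(shared_head(embeds[-1]))
        # Add the weighted reward earned so far to each sample's return.
        returns = returns + w * rewards
        # The previous reward masks this replica's loss; it stays 1 only while predictions stay correct.
        replica_losses, rewards = criterion(predictions[-1], targets, rewards)
        losses = losses + replica_losses
    loss = torch.mean(losses * returns)

    if loss_normalization:
        with torch.no_grad():
            # Rescale by exp(mean base loss) / exp(combined loss); the coefficient carries no gradient.
            coeff = torch.exp(coeff) / torch.exp(loss.detach())
        loss = coeff * loss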

    return loss
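
To make the weighting concrete, the following plain-Python sketch traces how `returns` and `rewards` evolve for a single sample inside `trp_criterion`. The helper name `trace_returns` and the example correctness flags are assumptions made for illustration; the sketch mirrors only the bookkeeping, not the loss computation.

```python
from typing import List, Tuple


def trace_returns(correct: List[bool], lambdas: List[float]) -> Tuple[float, List[float]]:
    """Trace the return of one sample through the replica loop.

    correct[0] says whether the base prediction is correct; correct[k + 1]
    says whether replica k is correct. Hypothetical helper, illustration only.
    """
    reward = 1.0 if correct[0] else 0.0  # reward after the base prediction
    ret = 1.0                            # returns start at one
    rewards = [reward]
    for k, w in enumerate(lambdas):
        ret += w * reward                         # add the weighted reward so far
        reward *= 1.0 if correct[k + 1] else 0.0  # reward drops to 0 at the first miss
        rewards.append(reward)
    return ret, rewards


lambdas = [0.4, 0.2, 0.1]
# Base model and first replica correct, second replica wrong.
ret, rewards = trace_returns([True, True, False, True], lambdas)
print(round(ret, 6), rewards)  # 1.6 [1.0, 1.0, 0.0, 0.0]
```

Under these assumptions, a sample that every stage classifies correctly ends with a return of 1 + 0.4 + 0.2 + 0.1 = 1.7, while a sample the base model misses keeps a return of 1.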

In this block, we build a TPBlock for the Task Replica Prediction (TRP) module. This implementation provides the backbone without the shared prediction head.

In [3]:
class TPBlock(nn.Module):
    def __init__(self, depths: int, inplanes: int, planes: int):
        super(TPBlock, self).__init__()

        blocks = [BasicBlock(inplanes=inplanes, planes=planes) for _ in range(depths)]
        self.blocks = nn.Sequential(*blocks)
        for name, param in self.blocks.named_parameters():
            if 'conv' in name:
                nn.init.zeros_(param)  # Zero-initialize convolution weights
            elif 'downsample' in name:
                nn.init.zeros_(param)  # Zero-initialize downsample parameters, if present

    def forward(self, x):
        return self.blocks(x)
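
One consequence of zeroing the convolution weights appears to be that each residual block initially passes non-negative inputs through unchanged, because only the skip connection contributes. The sketch below illustrates this with a deliberately simplified block: scalar weights `w1` and `w2` stand in for the convolution kernels, and batch normalization is omitted. `toy_residual_block` is a hypothetical helper, not torchvision's `BasicBlock`.

```python
from typing import List


def relu(values: List[float]) -> List[float]:
    return [max(0.0, v) for v in values]


def toy_residual_block(x: List[float], w1: float, w2: float) -> List[float]:
    """Simplified residual block: relu(w2 * relu(w1 * x) + x)."""
    branch = [w2 * h for h in relu([w1 * v for v in x])]  # residual branch
    return relu([b + v for b, v in zip(branch, x)])       # skip connection + ReLU


x = [0.0, 0.5, 2.0]  # non-negative, like activations that follow a ReLU
print(toy_residual_block(x, 0.0, 0.0) == x)  # True: zero weights give the identity
print(toy_residual_block(x, 1.0, 1.0) == x)  # False: non-zero weights change the output
```

Since the hidden state fed to the TRP blocks comes out of `layer3`, which ends in a ReLU, the same reasoning suggests that the replicas start out reproducing the base prediction.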

This implementation enables ResNet retraining in sequential policy gradient (SPG) mode.

Components:
-------------------------------------------------------------------------------
1. gen_criterion()
    - Purpose: Computes per-sample losses and a per-sample mask marking whether the target is among the top-k predictions.

2. gen_shared_head()
    - Purpose: Implements a shared prediction head that processes convolutional feature maps for prediction.

3. gen_forward()
    - Purpose: Extended forward pass supporting both traditional inference and SPG retraining.
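
As a rough, single-sample illustration of what `gen_criterion()` computes, the plain-Python sketch below evaluates a masked cross-entropy loss and updates the mask according to top-k correctness. It ignores label smoothing and batching, and the names `cross_entropy`, `in_top_k`, and `masked_step` are assumptions made for this example.

```python
import math
from typing import List, Tuple


def cross_entropy(logits: List[float], label: int) -> float:
    """Numerically stable softmax cross-entropy for one sample."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[label]


def in_top_k(logits: List[float], label: int, k: int) -> bool:
    """True if the label is among the k highest-scoring classes."""
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    return label in ranked[:k]


def masked_step(logits: List[float], label: int, mask: float = 1.0, top_k: int = 1) -> Tuple[float, float]:
    """Return (masked loss, updated mask) for one sample."""
    loss = mask * cross_entropy(logits, label)
    new_mask = mask * (1.0 if in_top_k(logits, label, top_k) else 0.0)
    return loss, new_mask


logits = [2.0, 1.0, 0.1]
loss, mask = masked_step(logits, label=1, top_k=1)  # label 1 is ranked second
print(mask)                                         # 0.0
print(masked_step(logits, label=1, mask=mask))      # (0.0, 0.0): later steps are masked out
```

Once the mask reaches zero for a sample, every later call contributes no loss for it, which matches how the mask is threaded through successive criterion calls.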

In [4]:
class ResNetConfig:
    @staticmethod
    def gen_criterion(label_smoothing=0.0, top_k=1):
        def func(input, target, mask=None):
            """
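            Args:
                input (Tensor): Logits of shape [B, C].
                target (Tensor): Target labels of shape [B] or [B, C].
                mask (Tensor, optional): Per-sample mask of shape [B] from the previous step. Defaults to all ones.

            Returns:
                losses (Tensor): Masked per-sample losses of shape [B].
                mask (Tensor): Float 0/1 mask of shape [B]; 1 where the incoming mask is 1 and the target is among the top-k predictions.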
            """
            label = torch.argmax(target, dim=1) if label_smoothing > 0.0 else target

            unmasked_loss = F.cross_entropy(input, label, reduction="none", label_smoothing=label_smoothing)
            if mask is None:
                mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)
            losses = mask * unmasked_loss

            with torch.no_grad():
                topk_values, topk_indices = torch.topk(input, top_k, dim=-1)
                mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)

            return losses, mask
        return func

    @staticmethod
    def gen_shared_head(self):
        def func(x):
            """
            Args:
                x (Tensor): Hidden State tensor of shape [B, C, H, W].

            Returns:
                logits (Tensor): Logits tensor of shape [B, C].
            """
            x = self.layer4(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            logits = self.fc(x)
            return logits
        return func

    @staticmethod
    def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
        def func(self, x: Tensor, targets=None) -> Tensor:
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x = self.layer1(x)
            x = self.layer2(x)
            hidden_state = self.layer3(x)
            x = self.layer4(hidden_state)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            logits = self.fc(x)

            if self.training:
                shared_head = ResNetConfig.gen_shared_head(self)
                criterion = ResNetConfig.gen_criterion(label_smoothing=label_smoothing, top_k=top_k)

                loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_state, logits, targets, loss_normalization=loss_normalization)

                return logits, loss

            return logits

        return func

Applies TRP modules to the base ResNet (main backbone). The k-th TRP module corresponds to a deeper ResNet variant with an additional depth of 3 * sum(depths[:k+1]).

In [5]:
def apply_trp(model, depths: List[int], planes: int, lambdas: List[float], **kwargs):
    print("✅ Applying TRP to ResNet for Image Classification...")
    model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, inplanes=planes, planes=planes) for d in depths])
    model.forward = types.MethodType(ResNetConfig.gen_forward(lambdas), model)
    return model
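
`apply_trp` relies on `types.MethodType` to bind a freshly generated function to one model instance, leaving the class itself untouched. The toy sketch below shows that mechanism in isolation; `ToyModel`, this simplified `gen_forward`, and `scale` are assumptions for illustration and carry no TRP logic.

```python
import types


class ToyModel:
    def __init__(self):
        self.training = True

    def forward(self, x):
        return x


def gen_forward(scale):
    # Factory that closes over a hyperparameter, similar in spirit to
    # ResNetConfig.gen_forward(lambdas).
    def func(self, x, targets=None):
        out = x * 2
        if self.training:
            loss = scale * abs(out - targets)  # stand-in for a real loss
            return out, loss
        return out
    return func


patched, untouched = ToyModel(), ToyModel()
patched.forward = types.MethodType(gen_forward(0.5), patched)

print(patched.forward(3, targets=5))  # (6, 0.5)
patched.training = False
print(patched.forward(3))             # 6
print(untouched.forward(3))           # 3
```

Because the bound method is stored on the instance, other instances of the same class keep their original `forward`.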

The following is a training script for classification models, primarily based on the official TorchVision `train.py` reference implementation. We have made two modifications:

**Adding TRP modules:** We integrate the TRP modules into the base model before training begins:

```python
if args.apply_trp:
    model = apply_trp(model, args.trp_depths, args.trp_planes, args.trp_lambdas)
```

**Removing TRP modules:** We drop every `trp_blocks` entry from the state dict before saving, so the checkpoint contains only the base model:
```python
if args.output_dir:
    checkpoint = {
        "model": model.state_dict() if not args.apply_trp else {k: v for k, v in model.state_dict().items() if not k.startswith("trp_blocks")},
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "epoch": epoch,
        "args": args,
    }
    utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
    utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
```
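
The dictionary comprehension above can be tried on a stand-in state dict. In this sketch the keys and values are made up (a real state dict maps parameter names to tensors), and `strip_trp` is a hypothetical helper name.

```python
def strip_trp(state_dict, prefix="trp_blocks"):
    """Return a copy of state_dict without entries whose key starts with prefix."""
    return {k: v for k, v in state_dict.items() if not k.startswith(prefix)}


# Made-up entries standing in for parameter tensors.
state = {
    "conv1.weight": "w0",
    "layer4.1.bn2.bias": "b0",
    "fc.weight": "w1",
    "trp_blocks.0.blocks.0.conv1.weight": "w2",
    "trp_blocks.2.blocks.1.bn2.bias": "b1",
}
base_only = strip_trp(state)
print(sorted(base_only))  # ['conv1.weight', 'fc.weight', 'layer4.1.bn2.bias']
```

The remaining keys are those of the base model, so the saved checkpoint should load without the TRP modules attached.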

In [6]:
import datetime
import os
import time
import warnings

import presets
import torch
import torch.utils.data
import torchvision
import utils
from torch import nn
from torchvision.transforms.functional import InterpolationMode


def load_data(traindir, valdir):
    # Data loading code
    print("Loading data")
    interpolation = InterpolationMode("bilinear")

    print("Loading training data")
    st = time.time()
    dataset = torchvision.datasets.ImageFolder(
        traindir,
        presets.ClassificationPresetTrain(crop_size=224, interpolation=interpolation, auto_augment_policy=None, random_erase_prob=0.0, ra_magnitude=9, augmix_severity=3),
    )
    print("Took", time.time() - st)

    print("Loading validation data")
    dataset_test = torchvision.datasets.ImageFolder(
        valdir,
        presets.ClassificationPresetEval(crop_size=224, resize_size=256, interpolation=interpolation)
    )

    print("Creating data loaders")
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler



def train_one_epoch(model, optimizer, data_loader, device, epoch, args):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = f"Epoch: [{epoch}]"
    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        with torch.amp.autocast("cuda", enabled=False):
            output, loss = model(image, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
    # gather the stats from all processes

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg


def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir)

    num_classes = len(dataset.classes)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=16, pin_memory=True, collate_fn=None)
    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=64, sampler=test_sampler, num_workers=16, pin_memory=True)

    print("Creating model")
    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
    if args.apply_trp:
        model = apply_trp(model, args.trp_depths, args.trp_planes, args.trp_lambdas)
    model.to(device)

    parameters = utils.set_weight_decay(model, args.weight_decay, norm_weight_decay=None, custom_keys_weight_decay=None)
    optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)

    main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs)
    lr_scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs])


    print("Start training")
    start_time = time.time()
    for epoch in range(args.epochs):
        train_one_epoch(model, optimizer, data_loader, device, epoch, args)
        lr_scheduler.step()
        evaluate(model, nn.CrossEntropyLoss(), data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                "model": model.state_dict() if not args.apply_trp else {k: v for k, v in model.state_dict().items() if not k.startswith("trp_blocks")},  # NOTE: remove TRP heads
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
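
The learning-rate schedule built in `main` combines a constant warmup factor with step decay. Assuming the scheduler is stepped once per epoch, as in the training loop above, the rate for each epoch can be approximated in closed form. `lr_at_epoch` is a hypothetical helper, and the values passed to it are the ones set in the `args` cell of this notebook.

```python
def lr_at_epoch(epoch, base_lr, warmup_epochs, warmup_factor, step_size, gamma):
    """Closed-form approximation of ConstantLR warmup followed by StepLR."""
    if epoch < warmup_epochs:
        return base_lr * warmup_factor
    # The step decay starts counting from the end of the warmup.
    return base_lr * gamma ** ((epoch - warmup_epochs) // step_size)


for epoch in range(6):
    lr = lr_at_epoch(
        epoch,
        base_lr=0.0004,
        warmup_epochs=1,
        warmup_factor=0.0,
        step_size=2,
        gamma=0.5,
    )
    print(epoch, lr)
```

With `lr_warmup_decay=0.0` this sketch gives a learning rate of 0 for the first epoch, which agrees with the `lr: 0.0` entries in the epoch-0 training log.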


Prepare the [ImageNet](http://image-net.org/) dataset manually and place it in `/path/to/imagenet`. For image classification examples, pass the argument `--data-path=/path/to/imagenet` to the training script; in this notebook, the equivalent is the `data_path` field of `args`. The extracted dataset directory should follow this structure:
```setup
/path/to/imagenet/:
    train/:
        n01440764:
            n01440764_18.JPEG ...
        n01443537:
            n01443537_2.JPEG ...
    val/:
        n01440764:
            ILSVRC2012_val_00000293.JPEG ...
        n01443537:
            ILSVRC2012_val_00000236.JPEG ...
```
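
Before starting a long run, it can help to confirm that the directory matches the structure above. The following is a minimal sketch of such a check, demonstrated on a throwaway directory; `check_layout` is a hypothetical helper and only verifies that `train/` and `val/` exist and contain class folders.

```python
import tempfile
from pathlib import Path


def check_layout(root):
    """Return {split: number of class folders}; raise if a split is missing or empty."""
    counts = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing directory: {split_dir}")
        classes = [d for d in split_dir.iterdir() if d.is_dir()]
        if not classes:
            raise ValueError(f"no class folders in {split_dir}")
        counts[split] = len(classes)
    return counts


# Demonstration on a temporary directory that mimics the structure above.
with tempfile.TemporaryDirectory() as tmp:
    for split in ("train", "val"):
        for wnid in ("n01440764", "n01443537"):
            (Path(tmp) / split / wnid).mkdir(parents=True)
    print(check_layout(tmp))  # {'train': 2, 'val': 2}
```

On a real dataset, the call would be `check_layout("/path/to/imagenet")`.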

Now you can apply the SPG algorithm to retrain the model.

**Implementation Note:**

- This demonstration runs on Google Colab using a single GPU configuration
- Performance Improvement: Enhances ResNet18 validation accuracy (ACC@1) from 69.76% to 70.09%
- For optimal results:
  - Refer to our README.md for complete setup instructions
  - Recommended hardware: 4× RTX A6000 GPUs

In [7]:
from types import SimpleNamespace

args = SimpleNamespace(
    data_path="/home/cs/Documents/datasets/imagenet",  # Replace with your /path/to/imagenet
    model="resnet18",
    device="cuda",
    batch_size=512,
    epochs=6,
    lr=0.0004,
    momentum=0.9,
    weight_decay=1e-4,
    lr_warmup_epochs=1,
    lr_warmup_decay=0.0,
    lr_step_size=2,
    lr_gamma=0.5,
    print_freq=100,
    output_dir="resnet18",
    use_deterministic_algorithms=False,
    weights="ResNet18_Weights.IMAGENET1K_V1",
    apply_trp=True,
    trp_depths=[3, 3, 3],
    trp_planes=256,
    trp_lambdas=[0.4, 0.2, 0.1],
)

main(args)

namespace(data_path='/home/cs/Documents/datasets/imagenet', model='resnet18', device='cuda', batch_size=512, epochs=6, lr=0.0004, momentum=0.9, weight_decay=0.0001, lr_warmup_epochs=1, lr_warmup_decay=0.0, lr_step_size=2, lr_gamma=0.5, print_freq=100, output_dir='resnet18', use_deterministic_algorithms=False, weights='ResNet18_Weights.IMAGENET1K_V1', apply_trp=True, trp_depths=[3, 3, 3], trp_planes=256, trp_lambdas=[0.4, 0.2, 0.1])
Loading data
Loading training data


Took 1.9062905311584473
Loading validation data
Creating data loaders
Creating model
✅ Applying TRP to ResNet for Image Classification...
Start training
Epoch: [0]  [   0/2503]  eta: 10:05:09  lr: 0.0  img/s: 81.93631887515438  loss: 0.7334 (0.7334)  acc1: 71.2891 (71.2891)  acc5: 86.1328 (86.1328)  time: 14.5065  data: 8.2577  max mem: 19119
Epoch: [0]  [ 100/2503]  eta: 0:29:06  lr: 0.0  img/s: 862.8257862120394  loss: 0.7145 (0.7308)  acc1: 69.5312 (69.6105)  acc5: 87.6953 (87.3704)  time: 0.5927  data: 0.0003  max mem: 19119
Epoch: [0]  [ 200/2503]  eta: 0:25:23  lr: 0.0  img/s: 860.6862569301302  loss: 0.7355 (0.7353)  acc1: 68.9453 (69.3427)  acc5: 86.9141 (87.3125)  time: 0.5966  data: 0.0003  max mem: 19119
Epoch: [0]  [ 300/2503]  eta: 0:23:29  lr: 0.0  img/s: 860.0754340960929  loss: 0.7159 (0.7314)  acc1: 69.1406 (69.3463)  acc5: 87.5000 (87.3676)  time: 0.5967  data: 0.0003  max mem: 19119
Epoch: [0]  [ 400/2503]  eta: 0:22:03  lr: 0.0  img/s: 859.0790234707376  loss: 0.759



Test:   [  0/782]  eta: 0:23:05  loss: 0.6283 (0.6283)  acc1: 89.0625 (89.0625)  acc5: 95.3125 (95.3125)  time: 1.7719  data: 1.3111  max mem: 19119
Test:   [100/782]  eta: 0:00:30  loss: 1.0688 (0.9382)  acc1: 76.5625 (76.2840)  acc5: 89.0625 (92.1875)  time: 0.0399  data: 0.0263  max mem: 19119
Test:   [200/782]  eta: 0:00:21  loss: 0.9244 (0.9143)  acc1: 73.4375 (75.8240)  acc5: 95.3125 (93.2369)  time: 0.0244  data: 0.0107  max mem: 19119
Test:   [300/782]  eta: 0:00:17  loss: 0.8615 (0.9072)  acc1: 76.5625 (76.1991)  acc5: 92.1875 (93.5008)  time: 0.0381  data: 0.0244  max mem: 19119
Test:   [400/782]  eta: 0:00:13  loss: 1.6977 (1.0440)  acc1: 59.3750 (73.6323)  acc5: 82.8125 (91.7472)  time: 0.0313  data: 0.0176  max mem: 19119
Test:   [500/782]  eta: 0:00:09  loss: 1.6021 (1.1237)  acc1: 54.6875 (72.0964)  acc5: 85.9375 (90.5845)  time: 0.0247  data: 0.0109  max mem: 19119
Test:   [600/782]  eta: 0:00:06  loss: 1.3631 (1.1858)  acc1: 64.0625 (70.8741)  acc5: 84.3750 (89.7853)  