Upload 13 files
- nas-examples/image-classification/presets.py +71 -0
- nas-examples/image-classification/sampler.py +62 -0
- nas-examples/image-classification/train.py +525 -0
- nas-examples/image-classification/train_quantization.py +265 -0
- nas-examples/image-classification/transforms.py +183 -0
- nas-examples/image-classification/trplib.py +383 -0
- nas-examples/image-classification/utils.py +465 -0
- nas-examples/semantic-segmentation/coco_utils.py +108 -0
- nas-examples/semantic-segmentation/presets.py +39 -0
- nas-examples/semantic-segmentation/train.py +327 -0
- nas-examples/semantic-segmentation/transforms.py +100 -0
- nas-examples/semantic-segmentation/trplib.py +555 -0
- nas-examples/semantic-segmentation/utils.py +300 -0
nas-examples/image-classification/presets.py
ADDED
@@ -0,0 +1,71 @@
import torch
from torchvision.transforms import autoaugment, transforms
from torchvision.transforms.functional import InterpolationMode


class ClassificationPresetTrain:
    def __init__(
        self,
        *,
        crop_size,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
        hflip_prob=0.5,
        auto_augment_policy=None,
        ra_magnitude=9,
        augmix_severity=3,
        random_erase_prob=0.0,
    ):
        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
        if hflip_prob > 0:
            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
        if auto_augment_policy is not None:
            if auto_augment_policy == "ra":
                trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
            elif auto_augment_policy == "ta_wide":
                trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
            elif auto_augment_policy == "augmix":
                trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
            else:
                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
                trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
        trans.extend(
            [
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(torch.float),
                transforms.Normalize(mean=mean, std=std),
            ]
        )
        if random_erase_prob > 0:
            trans.append(transforms.RandomErasing(p=random_erase_prob))

        self.transforms = transforms.Compose(trans)

    def __call__(self, img):
        return self.transforms(img)


class ClassificationPresetEval:
    def __init__(
        self,
        *,
        crop_size,
        resize_size=256,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
    ):

        self.transforms = transforms.Compose(
            [
                transforms.Resize(resize_size, interpolation=interpolation),
                transforms.CenterCrop(crop_size),
                transforms.PILToTensor(),
                transforms.ConvertImageDtype(torch.float),
                transforms.Normalize(mean=mean, std=std),
            ]
        )

    def __call__(self, img):
        return self.transforms(img)
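The presets above are plain callables over PIL images. A minimal usage sketch follows (the dataset path and hyper-parameter values are illustrative only; this mirrors how load_data in train.py below builds its datasets):

import torchvision
from torchvision.transforms.functional import InterpolationMode

import presets

# Training-time pipeline with TrivialAugment and RandomErasing (example values).
train_tf = presets.ClassificationPresetTrain(
    crop_size=224,
    interpolation=InterpolationMode.BILINEAR,
    auto_augment_policy="ta_wide",
    random_erase_prob=0.1,
)
eval_tf = presets.ClassificationPresetEval(crop_size=224, resize_size=256)

# Hypothetical ImageFolder layout: <root>/train/<class>/<image>.jpeg
dataset = torchvision.datasets.ImageFolder("/path/to/imagenet/train", train_tf)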
nas-examples/image-classification/sampler.py
ADDED
@@ -0,0 +1,62 @@
import math

import torch
import torch.distributed as dist


class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed,
    with repeated augmentation.
    It ensures that each augmented version of a sample will be visible to a
    different process (GPU).
    Heavily based on 'torch.utils.data.DistributedSampler'.

    This is borrowed from the DeiT Repo:
    https://github.com/facebookresearch/deit/blob/main/samplers.py
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available!")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available!")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle
        self.seed = seed
        self.repetitions = repetitions

    def __iter__(self):
        if self.shuffle:
            # Deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Add extra samples to make it evenly divisible
        indices = [ele for ele in indices for i in range(self.repetitions)]
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[: self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
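A minimal sketch of wiring the sampler into a DataLoader, mirroring the ra_sampler branch of load_data in train.py. It assumes torch.distributed has already been initialized (e.g. via torchrun); the toy dataset, batch size, and repetition count are illustrative:

import torch
from sampler import RASampler

# Toy stand-in; in train.py this is an ImageFolder dataset.
dataset = torch.utils.data.TensorDataset(
    torch.randn(1024, 3, 224, 224), torch.zeros(1024, dtype=torch.int64)
)

# Assumes dist.init_process_group(...) has already been called (e.g. launched with torchrun).
train_sampler = RASampler(dataset, shuffle=True, repetitions=3)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=train_sampler, pin_memory=True)

for epoch in range(2):
    train_sampler.set_epoch(epoch)  # re-seeds the deterministic shuffle each epoch
    for images, targets in data_loader:
        pass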
nas-examples/image-classification/train.py
ADDED
@@ -0,0 +1,525 @@
import datetime
import os
import time
import warnings

import presets
import torch
import torch.utils.data
import torchvision
import transforms
import utils
from sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode

from trplib import apply_trp


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = f"Epoch: [{epoch}]"
    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        with torch.amp.autocast("cuda", enabled=scaler is not None):
            # output = model(image)
            # loss = criterion(output, target)
            output, loss = model(image, target)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            if args.clip_grad_norm is not None:
                # we should unscale the gradients of optimizer's assigned params if we do gradient clipping
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

        if model_ema and i % args.model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < args.lr_warmup_epochs:
                # Reset ema buffer to keep copying weights during warmup period
                model_ema.n_averaged.fill_(0)

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
    # gather the stats from all processes

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
    return metric_logger.acc1.global_avg


def _get_cache_path(filepath):
    import hashlib

    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


def load_data(traindir, valdir, args):
    # Data loading code
    print("Loading data")
    val_resize_size, val_crop_size, train_crop_size = (
        args.val_resize_size,
        args.val_crop_size,
        args.train_crop_size,
    )
    interpolation = InterpolationMode(args.interpolation)

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_train from {cache_path}")
        dataset, _ = torch.load(cache_path)
    else:
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        ra_magnitude = args.ra_magnitude
        augmix_severity = args.augmix_severity
        dataset = torchvision.datasets.ImageFolder(
            traindir,
            presets.ClassificationPresetTrain(
                crop_size=train_crop_size,
                interpolation=interpolation,
                auto_augment_policy=auto_augment_policy,
                random_erase_prob=random_erase_prob,
                ra_magnitude=ra_magnitude,
                augmix_severity=augmix_severity,
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_test from {cache_path}")
        dataset_test, _ = torch.load(cache_path)
    else:
        if args.weights and args.test_only:
            weights = torchvision.models.get_weight(args.weights)
            preprocessing = weights.transforms()
        else:
            preprocessing = presets.ClassificationPresetEval(
                crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
            )

        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
            preprocessing,
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if args.distributed:
        if hasattr(args, "ra_sampler") and args.ra_sampler:
            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
        else:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
    mixup_transforms = []
    if args.mixup_alpha > 0.0:
        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
    if args.cutmix_alpha > 0.0:
        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
    if mixup_transforms:
        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)

        def collate_fn(batch):
            return mixupcutmix(*default_collate(batch))

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=8, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model")
    model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
    if args.apply_trp:
        model = apply_trp(model, args.trp_depths, args.in_planes, args.out_planes, args.trp_rewards, label_smoothing=args.label_smoothing)
    model.to(device)

    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    custom_keys_weight_decay = []
    if args.bias_weight_decay is not None:
        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
    if args.transformer_embedding_decay is not None:
        for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
    parameters = utils.set_weight_decay(
        model,
        args.weight_decay,
        norm_weight_decay=args.norm_weight_decay,
        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
    )

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(
            parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
        )
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

    scaler = torch.amp.GradScaler("cuda") if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
        )
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )

    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
        )
    else:
        lr_scheduler = main_lr_scheduler

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        #
        # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
        # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
        # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
        adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
        alpha = 1.0 - args.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
        model_without_ddp.load_state_dict(checkpoint["model"])
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])
        if scaler:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        else:
            evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict() if not args.apply_trp else {k: v for k, v in model_without_ddp.state_dict().items() if "trp_blocks" not in k},
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict() if not args.apply_trp else {k: v for k, v in model_ema.state_dict().items() if "trp_blocks" not in k}
            if scaler:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
    parser.add_argument("--model", default="resnet18", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--norm-weight-decay",
        default=None,
        type=float,
        help="weight decay for Normalization layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--bias-weight-decay",
        default=None,
        type=float,
        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--transformer-embedding-decay",
        default=None,
        type=float,
        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
    )
    parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
    parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument(
        "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
    )
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
    parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
    parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
    parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
    parser.add_argument(
        "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
    )
    parser.add_argument(
        "--model-ema-steps",
        type=int,
        default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)",
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
    parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    parser.add_argument(
        "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
    )
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    parser.add_argument("--apply-trp", action="store_true", help="enable applying trp")
    parser.add_argument("--trp-depths", nargs="+", default=[2, 2, 2], type=int, help="depth of each trp block")
    parser.add_argument("--in-planes", type=int, help="the dimension of the hidden states")
    parser.add_argument("--out-planes", default=8, type=int, help="the dimension of the inner hidden states")
    parser.add_argument("--trp-rewards", nargs="+", default=[1.0, 0.4, 0.2, 0.1], type=float, help="trp rewards")

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
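The TRP-specific part of the script reduces to a small pattern: wrap the backbone with apply_trp, let the wrapped model return (output, loss) when a target is passed, and strip the auxiliary trp_blocks parameters before checkpointing. A condensed sketch of that flow follows; it is not a substitute for the full training loop, the in_planes=512 value is an assumption for resnet18's final feature width, and a CUDA device is assumed because trplib builds its layers on CUDA:

import torch
import torchvision
from trplib import apply_trp

# Backbone exactly as in main(); TRP arguments follow the parser defaults above.
model = torchvision.models.get_model("resnet18", weights=None, num_classes=1000)
model = apply_trp(model, [2, 2, 2], 512, 8, [1.0, 0.4, 0.2, 0.1], label_smoothing=0.0)
model.to("cuda")

images = torch.randn(4, 3, 224, 224, device="cuda")
targets = torch.randint(0, 1000, (4,), device="cuda")
output, loss = model(images, targets)  # wrapped forward returns logits and the TRP training loss
loss.backward()

# Checkpointing mirrors main(): the auxiliary TRP heads are dropped from the saved state dict.
state = {k: v for k, v in model.state_dict().items() if "trp_blocks" not in k}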
nas-examples/image-classification/train_quantization.py
ADDED
@@ -0,0 +1,265 @@
import copy
import datetime
import os
import time

import torch
import torch.ao.quantization
import torch.utils.data
import torchvision
import utils
from torch import nn
from train import evaluate, load_data, train_one_epoch


def main(args):
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    if args.post_training_quantize and args.distributed:
        raise RuntimeError("Post training quantization example should not be performed on distributed mode")

    # Set backend engine to ensure that quantized model runs on the correct kernels
    if args.backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " + str(args.backend))
    torch.backends.quantized.engine = args.backend

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model", args.model)
    # when training quantized models, we always start from a pre-trained fp32 reference model
    prefix = "quantized_"
    model_name = args.model
    if not model_name.startswith(prefix):
        model_name = prefix + model_name
    model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
    model.to(device)

    if not (args.test_only or args.post_training_quantize):
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
        torch.ao.quantization.prepare_qat(model, inplace=True)

        if args.distributed and args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1

    if args.post_training_quantize:
        # perform calibration on a subset of the training dataset
        # for that, create a subset of the training dataset
        ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
        data_loader_calibration = torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
        )
        model.eval()
        model.fuse_model(is_qat=False)
        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
        torch.ao.quantization.prepare(model, inplace=True)
        # Calibrate first
        print("Calibrating")
        evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
        torch.ao.quantization.convert(model, inplace=True)
        if args.output_dir:
            print("Saving quantized model")
            if utils.is_main_process():
                torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
        print("Evaluating post-training quantized model")
        evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    model.apply(torch.ao.quantization.enable_observer)
    model.apply(torch.ao.quantization.enable_fake_quant)
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print("Starting training for epoch", epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
        lr_scheduler.step()
        with torch.inference_mode():
            if epoch >= args.num_observer_update_epochs:
                print("Disabling observer for subseq epochs, epoch = ", epoch)
                model.apply(torch.ao.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                print("Freezing BN for subseq epochs, epoch = ", epoch)
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            print("Evaluate QAT model")

            evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device("cpu"))
            torch.ao.quantization.convert(quantized_eval_model, inplace=True)

            print("Evaluate Quantized model")
            evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))

        model.train()

        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "eval_model": quantized_eval_model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
            print("Saving models after epoch ", epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
    parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
    parser.add_argument("--backend", default="qnnpack", type=str, help="fbgemm or qnnpack")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")

    parser.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "--num-observer-update-epochs",
        default=4,
        type=int,
        metavar="N",
        help="number of total epochs to update observers",
    )
    parser.add_argument(
        "--num-batch-norm-update-epochs",
        default=3,
        type=int,
        metavar="N",
        help="number of total epochs to update batch norm stats",
    )
    parser.add_argument(
        "--num-calibration-batches",
        default=32,
        type=int,
        metavar="N",
        help="number of batches of training set for \
            observer calibration ",
    )

    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. \
            It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--post-training-quantize",
        dest="post_training_quantize",
        help="Post training quantize the model",
        action="store_true",
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
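The quantization-aware training (QAT) path above follows the standard eager-mode recipe. A condensed sketch of that recipe, with the model name and backend taken from the parser defaults (fuse_model exists only on torchvision's quantizable model classes, and the fine-tuning step is elided):

import torch
import torchvision

torch.backends.quantized.engine = "qnnpack"
model = torchvision.models.get_model("quantized_mobilenet_v2", weights=None, quantize=False)

# Fuse conv/bn/relu, attach the QAT qconfig, and insert fake-quant observers.
model.fuse_model(is_qat=True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")
torch.ao.quantization.prepare_qat(model, inplace=True)

# ... fine-tune with fake quantization enabled (train_one_epoch in the script) ...

# Convert to a real int8 model for CPU evaluation, as done each epoch above.
model.eval()
torch.ao.quantization.convert(model, inplace=True)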
nas-examples/image-classification/transforms.py
ADDED
@@ -0,0 +1,183 @@
import math
from typing import Tuple

import torch
from torch import Tensor
from torchvision.transforms import functional as F


class RandomMixup(torch.nn.Module):
    """Randomly apply Mixup to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()

        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
            )

        if alpha <= 0:
            raise ValueError("Alpha param can't be zero.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tensor: Randomly transformed batch.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on mixup paper, page 3.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s


class RandomCutmix(torch.nn.Module):
    """Randomly apply Cutmix to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError("Please provide a valid positive value for the num_classes.")
        if alpha <= 0:
            raise ValueError("Alpha param can't be zero.")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tensor: Randomly transformed batch.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        _, H, W = F.get_dimensions(batch)

        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s
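Both transforms operate on whole batches, so they are applied inside the DataLoader's collate_fn rather than per sample, exactly as main() in train.py does. A minimal sketch follows (the toy dataset, batch size, and alpha values are illustrative):

import torch
import torchvision
from torch.utils.data.dataloader import default_collate

from transforms import RandomMixup, RandomCutmix

num_classes = 1000
mixupcutmix = torchvision.transforms.RandomChoice(
    [RandomMixup(num_classes, p=1.0, alpha=0.2), RandomCutmix(num_classes, p=1.0, alpha=1.0)]
)

def collate_fn(batch):
    # default_collate stacks the (image, target) pairs; mixup/cutmix then blends the stacked batch.
    return mixupcutmix(*default_collate(batch))

# Any classification dataset yielding (float image tensor, int64 label) works here.
dataset = torch.utils.data.TensorDataset(
    torch.randn(64, 3, 224, 224), torch.zeros(64, dtype=torch.int64)
)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
images, targets = next(iter(loader))  # targets are now soft (mixed one-hot) labels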
nas-examples/image-classification/trplib.py
ADDED
@@ -0,0 +1,383 @@
import types
from typing import Optional, List, Union, Callable

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from torchvision.models.mobilenetv2 import MobileNetV2
from torchvision.models.resnet import ResNet
from torchvision.models.efficientnet import EfficientNet
from torchvision.models.vision_transformer import VisionTransformer


def compute_policy_loss(loss_sequence, mask_sequence, rewards):
    losses = sum(mask * padded_loss for mask, padded_loss in zip(mask_sequence, loss_sequence))
    returns = sum(padded_reward * mask for padded_reward, mask in zip(rewards, mask_sequence))
    loss = torch.mean(losses * returns)

    return loss


class TPBlock(nn.Module):
    def __init__(self, depths: int, in_planes: int, out_planes: int = None, rank=1, shape_dims=3, channel_first=True, dtype=torch.float32) -> None:
        super().__init__()
        out_planes = in_planes if out_planes is None else out_planes
        self.layers = torch.nn.ModuleList([self._make_layer(in_planes, out_planes, rank, shape_dims, channel_first, dtype) for _ in range(depths)])

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers:
            x = x + layer(x)
        return x

    def _make_layer(self, in_planes: int, out_planes: int = None, rank=1, shape_dims=3, channel_first=True, dtype=torch.float32) -> nn.Sequential:

        class Permute(nn.Module):
            def __init__(self, *dims):
                super().__init__()
                self.dims = dims

            def forward(self, x):
                return x.permute(*self.dims)

        class RMSNorm(nn.Module):
            __constants__ = ["eps"]
            eps: float

            def __init__(self, hidden_size, eps: float = 1e-6, device=None, dtype=None):
                """
                LlamaRMSNorm is equivalent to T5LayerNorm.
                """
                factory_kwargs = {"device": device, "dtype": dtype}
                super().__init__()
                self.eps = eps
                self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs))

            def forward(self, hidden_states):
                input_dtype = hidden_states.dtype
                hidden_states = hidden_states.to(torch.float32)
                variance = hidden_states.pow(2).mean(dim=1, keepdim=True)
                hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
                weight = self.weight.view(1, -1, *[1] * (hidden_states.ndim - 2))
                return weight * hidden_states.to(input_dtype)

            def extra_repr(self):
                return f"{self.weight.shape[0]}, eps={self.eps}"

        conv_map = {
            2: (nn.Conv1d, (0, 2, 1), (0, 2, 1)),
            3: (nn.Conv2d, (0, 3, 1, 2), (0, 2, 3, 1)),
            4: (nn.Conv3d, (0, 4, 1, 2, 3), (0, 2, 3, 4, 1)),
        }
        Conv, pre_dims, post_dims = conv_map[shape_dims]
        kernel_size, dilation, padding = self.generate_hyperparameters(rank)

        pre_permute = nn.Identity() if channel_first else Permute(*pre_dims)
        post_permute = nn.Identity() if channel_first else Permute(*post_dims)
        conv1 = Conv(in_planes, out_planes, kernel_size, padding=padding, dilation=dilation, bias=False, dtype=dtype, device='cuda')
        nn.init.zeros_(conv1.weight)
        bn1 = RMSNorm(out_planes, dtype=dtype, device="cuda")
        relu = nn.ReLU(inplace=True)
        conv2 = Conv(out_planes, in_planes, kernel_size, padding=padding, dilation=dilation, bias=False, dtype=dtype, device='cuda')
        nn.init.zeros_(conv2.weight)
        bn2 = RMSNorm(in_planes, dtype=dtype, device="cuda")

        return torch.nn.Sequential(pre_permute, conv1, bn1, relu, conv2, bn2, relu, post_permute)

    @staticmethod
    def generate_hyperparameters(rank: int):
        """
        Generates kernel size and dilation rate pairs sorted by increasing padded kernel size.

        Args:
            rank: Number of (kernel_size, dilation, padding) triples to generate. Must be positive.

        Returns:
            Tuple[int, int, int]: A (kernel_size, dilation, padding) triple where:
                - kernel_size: Always odd and >= 1
                - dilation: Computed to maintain consistent padded kernel size growth
                - padding: Chosen so the spatial size is preserved

        Note:
            Padded kernel size is calculated as:
                (kernel_size - 1) * dilation + 1
            Pairs are generated first in order of increasing padded kernel size,
            then by increasing kernel size for equal padded kernel sizes.
        """
        pairs = [(1, 1, 0)]  # Start with smallest possible
        padded_kernel_size = 3

        while len(pairs) < rank:
            for kernel_size in range(3, padded_kernel_size + 1, 2):
                if (padded_kernel_size - 1) % (kernel_size - 1) == 0:
                    dilation = (padded_kernel_size - 1) // (kernel_size - 1)
                    padding = dilation * (kernel_size - 1) // 2
                    pairs.append((kernel_size, dilation, padding))
                if len(pairs) >= rank:
                    break

            # Move to next odd padded kernel size
            padded_kernel_size += 2

        return pairs[-1]


class ResNetConfig:
    @staticmethod
    def gen_shared_head(self):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].

            Returns:
                logits (Tensor): Logits tensor of shape [B, C].
            """
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.fc(x)
            return logits
        return func

    @staticmethod
    def gen_logits(self, shared_head):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor produced by the backbone.

            Returns:
                logits_sequence (List[Tensor]): List of Logits tensors.
            """
            logits_sequence = [shared_head(hidden_states)]
            for layer in self.trp_blocks:
                logits_sequence.append(shared_head(layer(hidden_states)))
            return logits_sequence
        return func

    @staticmethod
    def gen_mask(label_smoothing=0.0, top_k=1):
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of Logits tensors.
                labels (Tensor): Target labels of shape [B] or [B, C].

            Returns:
                mask_sequence (List[Tensor]): List of float mask tensors of shape [B],
                    the initial all-ones mask followed by one cumulative top-k hit mask per logits tensor.
            """
            labels = torch.argmax(labels, dim=1) if label_smoothing > 0.0 else labels

            mask_sequence = [torch.ones_like(labels, dtype=torch.float32, device=labels.device)]
            for logits in logits_sequence:
                with torch.no_grad():
                    topk_values, topk_indices = torch.topk(logits, top_k, dim=-1)
                    mask = torch.eq(topk_indices, labels[:, None]).any(dim=-1).to(torch.float32)
                    mask_sequence.append(mask_sequence[-1] * mask)
            return mask_sequence
        return func

    @staticmethod
    def gen_criterion(label_smoothing=0.0):
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of Logits tensors.
                labels (Tensor): Target labels of shape [B] or [B, C].

            Returns:
                loss_sequence (List[Tensor]): List of per-sample cross-entropy losses,
                    one tensor of shape [B] per logits tensor.
            """
            labels = torch.argmax(labels, dim=1) if label_smoothing > 0.0 else labels

            loss_sequence = []
            for logits in logits_sequence:
                loss_sequence.append(F.cross_entropy(logits, labels, reduction="none", label_smoothing=label_smoothing))

            return loss_sequence
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        def func(self, x: Tensor, targets=None) -> Tensor:
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            hidden_states = self.layer4(x)
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.fc(x)

            if self.training:
                shared_head = ResNetConfig.gen_shared_head(self)
                compute_logits = ResNetConfig.gen_logits(self, shared_head)
                compute_mask = ResNetConfig.gen_mask(label_smoothing, top_k)
                compute_loss = ResNetConfig.gen_criterion(label_smoothing)

                logits_sequence = compute_logits(hidden_states)
                mask_sequence = compute_mask(logits_sequence, targets)
                loss_sequence = compute_loss(logits_sequence, targets)
                loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)

                return logits, loss

            return logits

        return func


class MobileNetV2Config(ResNetConfig):
    @staticmethod
    def gen_shared_head(self):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].

            Returns:
                logits (Tensor): Logits tensor of shape [B, C].
            """
            x = nn.functional.adaptive_avg_pool2d(hidden_states, (1, 1))
            x = torch.flatten(x, 1)
            logits = self.classifier(x)
            return logits
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        def func(self, x: Tensor, targets=None) -> Tensor:
            hidden_states = self.features(x)
            # Cannot use "squeeze" as batch-size can be 1
            x = nn.functional.adaptive_avg_pool2d(hidden_states, (1, 1))
            x = torch.flatten(x, 1)
            logits = self.classifier(x)

            if self.training:
                shared_head = MobileNetV2Config.gen_shared_head(self)
                compute_logits = MobileNetV2Config.gen_logits(self, shared_head)
                compute_mask = MobileNetV2Config.gen_mask(label_smoothing, top_k)
                compute_loss = MobileNetV2Config.gen_criterion(label_smoothing)

                logits_sequence = compute_logits(hidden_states)
                mask_sequence = compute_mask(logits_sequence, targets)
                loss_sequence = compute_loss(logits_sequence, targets)
                loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)

                return logits, loss

            return logits

        return func


class EfficientNetConfig(ResNetConfig):
    @staticmethod
    def gen_shared_head(self):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].

            Returns:
                logits (Tensor): Logits tensor of shape [B, C].
            """
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.classifier(x)
            return logits
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        def func(self, x: Tensor, targets=None) -> Tensor:
            hidden_states = self.features(x)
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.classifier(x)

            if self.training:
                shared_head = EfficientNetConfig.gen_shared_head(self)
                compute_logits = EfficientNetConfig.gen_logits(self, shared_head)
                compute_mask = EfficientNetConfig.gen_mask(label_smoothing, top_k)
                compute_loss = EfficientNetConfig.gen_criterion(label_smoothing)

                logits_sequence = compute_logits(hidden_states)
                mask_sequence = compute_mask(logits_sequence, targets)
                loss_sequence = compute_loss(logits_sequence, targets)
                loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)

                return logits, loss

            return logits

        return func


class VisionTransformerConfig(ResNetConfig):
    @staticmethod
    def gen_shared_head(self):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, L, hidden_dim].

            Returns:
                logits (Tensor): Logits tensor of shape [B, C].
            """
            x = hidden_states[:, 0]
            logits = self.heads(x)
            return logits
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        def func(self, images: Tensor, targets=None):
            x = self._process_input(images)
            n = x.shape[0]
            batch_class_token = self.class_token.expand(n, -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)
            hidden_states = self.encoder(x)
            x = hidden_states[:, 0]

            logits = self.heads(x)

            if self.training:
                shared_head = VisionTransformerConfig.gen_shared_head(self)
                compute_logits = VisionTransformerConfig.gen_logits(self, shared_head)
                compute_mask = VisionTransformerConfig.gen_mask(label_smoothing, top_k)
                compute_loss = VisionTransformerConfig.gen_criterion(label_smoothing)

                logits_sequence = compute_logits(hidden_states)
                mask_sequence = compute_mask(logits_sequence, targets)
                loss_sequence = compute_loss(logits_sequence, targets)
                loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)

                return logits, loss
            return logits
        return func


def apply_trp(model, depths: List[int], in_planes: int, out_planes: int, rewards, **kwargs):
    if isinstance(model, ResNet):
        print("✅ Applying TRP to ResNet for Image Classification...")
        model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
        model.forward = types.MethodType(ResNetConfig.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
    elif isinstance(model, MobileNetV2):
        print("✅ Applying TRP to MobileNetV2 for Image Classification...")
        model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
        model.forward = types.MethodType(MobileNetV2Config.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
    elif isinstance(model, EfficientNet):
        print("✅ Applying TRP to EfficientNet for Image Classification...")
        model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
        model.forward = types.MethodType(EfficientNetConfig.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
    elif isinstance(model, VisionTransformer):
        print("✅ Applying TRP to VisionTransformer for Image Classification...")
        model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k, shape_dims=2, channel_first=False) for k, d in enumerate(depths)])
        model.forward = types.MethodType(VisionTransformerConfig.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
    return model
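apply_trp patches a stock torchvision classifier in place: it attaches the trp_blocks ModuleList and swaps the model's forward for one that also returns the policy loss during training. A minimal sketch, assuming a CUDA device (TPBlock builds its convolutions on 'cuda'); the hyper-parameter values below are illustrative and simply mirror the defaults used by the training scripts in this upload:

import torch
import torchvision

from trplib import apply_trp

model = torchvision.models.resnet50(weights=None)
model = apply_trp(
    model,
    depths=[2, 2, 2],                # three TPBlocks -> four logits per image (shared head + three refinements)
    in_planes=2048,                  # channel width of resnet50's layer4 hidden states
    out_planes=8,
    rewards=[1.0, 0.4, 0.2, 0.1],    # one reward per entry of the logits sequence
    label_smoothing=0.0,             # hard integer targets, so the argmax path in gen_mask/gen_criterion is skipped
).cuda().train()

images = torch.randn(4, 3, 224, 224, device="cuda")
targets = torch.randint(0, 1000, (4,), device="cuda")
logits, loss = model(images, targets)   # in eval mode the patched forward returns logits only
loss.backward()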
nas-examples/image-classification/utils.py
ADDED
@@ -0,0 +1,465 @@
import copy
import datetime
import errno
import hashlib
import os
import time
from collections import defaultdict, deque, OrderedDict
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist


class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        t = reduce_across_processes([self.count, self.total])
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
        )


class MetricLogger:
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {str(meter)}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"{header} Total time: {total_time_str}")


class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Maintains moving averages of model parameters using an exponential decay.
    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
    is used to compute the EMA.
    """

    def __init__(self, model, decay, device="cpu"):
        def ema_avg(avg_model_param, model_param, num_averaged):
            return decay * avg_model_param + (1 - decay) * model_param

        super().__init__(model, device, ema_avg, use_buffers=True)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.inference_mode():
        maxk = max(topk)
        batch_size = target.size(0)
        if target.ndim == 2:
            target = target.max(dim=1)[1]

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target[None])

        res = []
        for k in topk:
            correct_k = correct[:k].flatten().sum(dtype=torch.float32)
            res.append(correct_k * (100.0 / batch_size))
        return res


def mkdir(path):
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    elif hasattr(args, "rank"):
        pass
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
    https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16

    Args:
        inputs (List[str]): An iterable of string paths of checkpoints to load from.
    Returns:
        A dict of string keys mapping to various values. The 'model' key
        from the returned dict should correspond to an OrderedDict mapping
        string parameter names to torch Tensors.
    """
    params_dict = OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)
    for fpath in inputs:
        with open(fpath, "rb") as f:
            state = torch.load(
                f,
                map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
            )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state
        model_params = state["model"]
        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}"
            )
        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                params_dict[k] = p.clone()
                # NOTE: clone() is needed in case of p is a shared parameter
            else:
                params_dict[k] += p
    averaged_params = OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        if averaged_params[k].is_floating_point():
            averaged_params[k].div_(num_models)
        else:
            averaged_params[k] //= num_models
    new_state["model"] = averaged_params
    return new_state


def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
    """
    This method can be used to prepare weights files for new models. It receives as
    input a model architecture and a checkpoint from the training script and produces
    a file with the weights ready for release.

    Examples:
        from torchvision import models as M

        # Classification
        model = M.mobilenet_v3_large(weights=None)
        print(store_model_weights(model, './class.pth'))

        # Quantized Classification
        model = M.quantization.mobilenet_v3_large(weights=None, quantize=False)
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
        _ = torch.ao.quantization.prepare_qat(model, inplace=True)
        print(store_model_weights(model, './qat.pth'))

        # Object Detection
        model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
        print(store_model_weights(model, './obj.pth'))

        # Segmentation
        model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True)
        print(store_model_weights(model, './segm.pth', strict=False))

    Args:
        model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes.
        checkpoint_path (str): The path of the checkpoint we will load.
        checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
            Default: "model".
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

    Returns:
        output_path (str): The location where the weights are saved.
    """
    # Store the new model next to the checkpoint_path
    checkpoint_path = os.path.abspath(checkpoint_path)
    output_dir = os.path.dirname(checkpoint_path)

    # Deep copy to avoid side-effects on the model object.
    model = copy.deepcopy(model)
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Load the weights to the model to validate that everything works
    # and remove unnecessary weights (such as auxiliaries, etc)
    if checkpoint_key == "model_ema":
        del checkpoint[checkpoint_key]["n_averaged"]
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
    model.load_state_dict(checkpoint[checkpoint_key], strict=strict)

    tmp_path = os.path.join(output_dir, str(model.__hash__()))
    torch.save(model.state_dict(), tmp_path)

    sha256_hash = hashlib.sha256()
    with open(tmp_path, "rb") as f:
        # Read and update hash string value in blocks of 4K
        for byte_block in iter(lambda: f.read(4096), b""):
            sha256_hash.update(byte_block)
        hh = sha256_hash.hexdigest()

    output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
    os.replace(tmp_path, output_path)

    return output_path


def reduce_across_processes(val):
    if not is_dist_avail_and_initialized():
        # nothing to sync, but we still convert to tensor for consistency with the distributed case.
        return torch.tensor(val)

    t = torch.tensor(val, device="cuda")
    dist.barrier()
    dist.all_reduce(t)
    return t


def set_weight_decay(
    model: torch.nn.Module,
    weight_decay: float,
    norm_weight_decay: Optional[float] = None,
    norm_classes: Optional[List[type]] = None,
    custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
):
    if not norm_classes:
        norm_classes = [
            torch.nn.modules.batchnorm._BatchNorm,
            torch.nn.LayerNorm,
            torch.nn.GroupNorm,
            torch.nn.modules.instancenorm._InstanceNorm,
            torch.nn.LocalResponseNorm,
        ]
    norm_classes = tuple(norm_classes)

    params = {
        "other": [],
        "norm": [],
    }
    params_weight_decay = {
        "other": weight_decay,
        "norm": norm_weight_decay,
    }
    custom_keys = []
    if custom_keys_weight_decay is not None:
        for key, weight_decay in custom_keys_weight_decay:
            params[key] = []
            params_weight_decay[key] = weight_decay
            custom_keys.append(key)

    def _add_params(module, prefix=""):
        for name, p in module.named_parameters(recurse=False):
            if not p.requires_grad:
                continue
            is_custom_key = False
            for key in custom_keys:
                target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
                if key == target_name:
                    params[key].append(p)
                    is_custom_key = True
                    break
            if not is_custom_key:
                if norm_weight_decay is not None and isinstance(module, norm_classes):
                    params["norm"].append(p)
                else:
                    params["other"].append(p)

        for child_name, child_module in module.named_children():
            child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
            _add_params(child_module, prefix=child_prefix)

    _add_params(model)

    param_groups = []
    for key in params:
        if len(params[key]) > 0:
            param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
    return param_groups
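set_weight_decay builds optimizer parameter groups so that normalization layers (and any custom parameter names) can use a different weight-decay value from the rest of the model. A minimal sketch with illustrative values; the "class_token" key matches the parameter name used by torchvision's VisionTransformer:

import torch
import torchvision

import utils

model = torchvision.models.vit_b_16(weights=None)
param_groups = utils.set_weight_decay(
    model,
    weight_decay=0.05,                                # applied to ordinary weights
    norm_weight_decay=0.0,                            # LayerNorm/BatchNorm parameters get no decay
    custom_keys_weight_decay=[("class_token", 0.0)],  # per-name override
)
optimizer = torch.optim.AdamW(param_groups, lr=3e-4)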
nas-examples/semantic-segmentation/coco_utils.py
ADDED
@@ -0,0 +1,108 @@
import copy
import os

import torch
import torch.utils.data
import torchvision
from PIL import Image
from pycocotools import mask as coco_mask
from transforms import Compose


class FilterAndRemapCocoCategories:
    def __init__(self, categories, remap=True):
        self.categories = categories
        self.remap = remap

    def __call__(self, image, anno):
        anno = [obj for obj in anno if obj["category_id"] in self.categories]
        if not self.remap:
            return image, anno
        anno = copy.deepcopy(anno)
        for obj in anno:
            obj["category_id"] = self.categories.index(obj["category_id"])
        return image, anno


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


class ConvertCocoPolysToMask:
    def __call__(self, image, anno):
        w, h = image.size
        segmentations = [obj["segmentation"] for obj in anno]
        cats = [obj["category_id"] for obj in anno]
        if segmentations:
            masks = convert_coco_poly_to_mask(segmentations, h, w)
            cats = torch.as_tensor(cats, dtype=masks.dtype)
            # merge all instance masks into a single segmentation map
            # with its corresponding categories
            target, _ = (masks * cats[:, None, None]).max(dim=0)
            # discard overlapping instances
            target[masks.sum(0) > 1] = 255
        else:
            target = torch.zeros((h, w), dtype=torch.uint8)
        target = Image.fromarray(target.numpy())
        return image, target


def _coco_remove_images_without_annotations(dataset, cat_list=None):
    def _has_valid_annotation(anno):
        # if it's empty, there is no annotation
        if len(anno) == 0:
            return False
        # if more than 1k pixels occupied in the image
        return sum(obj["area"] for obj in anno) > 1000

    if not isinstance(dataset, torchvision.datasets.CocoDetection):
        raise TypeError(
            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
        )

    ids = []
    for ds_idx, img_id in enumerate(dataset.ids):
        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = dataset.coco.loadAnns(ann_ids)
        if cat_list:
            anno = [obj for obj in anno if obj["category_id"] in cat_list]
        if _has_valid_annotation(anno):
            ids.append(ds_idx)

    dataset = torch.utils.data.Subset(dataset, ids)
    return dataset


def get_coco(root, image_set, transforms):
    PATHS = {
        "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
        "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
        # "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
    }
    CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

    transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])

    img_folder, ann_file = PATHS[image_set]
    img_folder = os.path.join(root, img_folder)
    ann_file = os.path.join(root, ann_file)

    dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)

    if image_set == "train":
        dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)

    return dataset
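get_coco turns COCO instance annotations into VOC-style segmentation targets: polygons are rasterized, categories are remapped to the 21 entries of CAT_LIST, overlapping pixels become 255 (ignore), and train images with fewer than 1000 annotated pixels are dropped. A minimal sketch, assuming the standard COCO 2017 directory layout; the /data/coco path is illustrative:

import presets
from coco_utils import get_coco

dataset = get_coco(
    "/data/coco",        # must contain train2017/, val2017/ and annotations/
    image_set="train",
    transforms=presets.SegmentationPresetTrain(base_size=520, crop_size=480),
)
image, target = dataset[0]   # image: float tensor [3, 480, 480]; target: int64 mask [480, 480]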
nas-examples/semantic-segmentation/presets.py
ADDED
@@ -0,0 +1,39 @@
import torch
import transforms as T


class SegmentationPresetTrain:
    def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        min_size = int(0.5 * base_size)
        max_size = int(2.0 * base_size)

        trans = [T.RandomResize(min_size, max_size)]
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))
        trans.extend(
            [
                T.RandomCrop(crop_size),
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float),
                T.Normalize(mean=mean, std=std),
            ]
        )
        self.transforms = T.Compose(trans)

    def __call__(self, img, target):
        return self.transforms(img, target)


class SegmentationPresetEval:
    def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = T.Compose(
            [
                T.RandomResize(base_size, base_size),
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float),
                T.Normalize(mean=mean, std=std),
            ]
        )

    def __call__(self, img, target):
        return self.transforms(img, target)
nas-examples/semantic-segmentation/train.py
ADDED
@@ -0,0 +1,327 @@
1 |
+
import datetime
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
import presets
|
7 |
+
import torch
|
8 |
+
import torch.utils.data
|
9 |
+
import torchvision
|
10 |
+
import utils
|
11 |
+
from coco_utils import get_coco
|
12 |
+
from torch import nn
|
13 |
+
from torch.optim.lr_scheduler import PolynomialLR
|
14 |
+
from torchvision.transforms import functional as F, InterpolationMode
|
15 |
+
|
16 |
+
from trplib import apply_trp
|
17 |
+
|
18 |
+
|
19 |
+
def get_dataset(dir_path, name, image_set, transform):
|
20 |
+
def sbd(*args, **kwargs):
|
21 |
+
return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)
|
22 |
+
|
23 |
+
paths = {
|
24 |
+
"voc": (dir_path, torchvision.datasets.VOCSegmentation, 21),
|
25 |
+
"voc_aug": (dir_path, sbd, 21),
|
26 |
+
"coco": (dir_path, get_coco, 21),
|
27 |
+
}
|
28 |
+
p, ds_fn, num_classes = paths[name]
|
29 |
+
|
30 |
+
ds = ds_fn(p, image_set=image_set, transforms=transform)
|
31 |
+
return ds, num_classes
|
32 |
+
|
33 |
+
|
34 |
+
def get_transform(train, args):
|
35 |
+
if train:
|
36 |
+
return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
|
37 |
+
elif args.weights and args.test_only:
|
38 |
+
weights = torchvision.models.get_weight(args.weights)
|
39 |
+
trans = weights.transforms()
|
40 |
+
|
41 |
+
def preprocessing(img, target):
|
42 |
+
img = trans(img)
|
43 |
+
size = F.get_dimensions(img)[1:]
|
44 |
+
target = F.resize(target, size, interpolation=InterpolationMode.NEAREST)
|
45 |
+
return img, F.pil_to_tensor(target)
|
46 |
+
|
47 |
+
return preprocessing
|
48 |
+
else:
|
49 |
+
return presets.SegmentationPresetEval(base_size=520)
|
50 |
+
|
51 |
+
|
52 |
+
def criterion(inputs, target):
|
53 |
+
losses = {}
|
54 |
+
for name, x in inputs.items():
|
55 |
+
losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)
|
56 |
+
|
57 |
+
if len(losses) == 1:
|
58 |
+
return losses["out"]
|
59 |
+
|
60 |
+
return losses["out"] + 0.5 * losses["aux"]
|
61 |
+
|
62 |
+
|
63 |
+
def evaluate(model, data_loader, device, num_classes):
|
64 |
+
model.eval()
|
65 |
+
confmat = utils.ConfusionMatrix(num_classes)
|
66 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
67 |
+
header = "Test:"
|
68 |
+
num_processed_samples = 0
|
69 |
+
with torch.inference_mode():
|
70 |
+
for image, target in metric_logger.log_every(data_loader, 100, header):
|
71 |
+
image, target = image.to(device), target.to(device)
|
72 |
+
output = model(image)
|
73 |
+
output = output["out"]
|
74 |
+
|
75 |
+
confmat.update(target.flatten(), output.argmax(1).flatten())
|
76 |
+
# FIXME need to take into account that the datasets
|
77 |
+
# could have been padded in distributed setup
|
78 |
+
num_processed_samples += image.shape[0]
|
79 |
+
|
80 |
+
confmat.reduce_from_all_processes()
|
81 |
+
|
82 |
+
num_processed_samples = utils.reduce_across_processes(num_processed_samples)
|
83 |
+
if (
|
84 |
+
hasattr(data_loader.dataset, "__len__")
|
85 |
+
and len(data_loader.dataset) != num_processed_samples
|
86 |
+
and torch.distributed.get_rank() == 0
|
87 |
+
):
|
88 |
+
# See FIXME above
|
89 |
+
warnings.warn(
|
90 |
+
f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
|
91 |
+
"samples were used for the validation, which might bias the results. "
|
92 |
+
"Try adjusting the batch size and / or the world size. "
|
93 |
+
"Setting the world size to 1 is always a safe bet."
|
94 |
+
)
|
95 |
+
|
96 |
+
return confmat
|
97 |
+
|
98 |
+
|
99 |
+
def train_one_epoch(model, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
|
100 |
+
model.train()
|
101 |
+
metric_logger = utils.MetricLogger(delimiter=" ")
|
102 |
+
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
|
103 |
+
header = f"Epoch: [{epoch}]"
|
104 |
+
for image, target in metric_logger.log_every(data_loader, print_freq, header):
|
105 |
+
image, target = image.to(device), target.to(device)
|
106 |
+
with torch.amp.autocast(device_type="cuda", enabled=scaler is not None):
|
107 |
+
_, loss = model(image, target)
|
108 |
+
# output = model(image)
|
109 |
+
# loss = criterion(output, target)
|
110 |
+
|
111 |
+
optimizer.zero_grad()
|
112 |
+
if scaler is not None:
|
113 |
+
scaler.scale(loss).backward()
|
114 |
+
scaler.step(optimizer)
|
115 |
+
scaler.update()
|
116 |
+
else:
|
117 |
+
loss.backward()
|
118 |
+
optimizer.step()
|
119 |
+
|
120 |
+
lr_scheduler.step()
|
121 |
+
|
122 |
+
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
|
123 |
+
|
124 |
+
|
125 |
+
def main(args):
|
126 |
+
if args.output_dir:
|
127 |
+
utils.mkdir(args.output_dir)
|
128 |
+
|
129 |
+
utils.init_distributed_mode(args)
|
130 |
+
print(args)
|
131 |
+
|
132 |
+
device = torch.device(args.device)
|
133 |
+
|
134 |
+
if args.use_deterministic_algorithms:
|
135 |
+
torch.backends.cudnn.benchmark = False
|
136 |
+
torch.use_deterministic_algorithms(True)
|
137 |
+
else:
|
138 |
+
torch.backends.cudnn.benchmark = True
|
139 |
+
|
140 |
+
dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
|
141 |
+
dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))
|
142 |
+
|
143 |
+
if args.distributed:
|
144 |
+
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
145 |
+
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
|
146 |
+
else:
|
147 |
+
train_sampler = torch.utils.data.RandomSampler(dataset)
|
148 |
+
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
149 |
+
|
150 |
+
data_loader = torch.utils.data.DataLoader(
|
151 |
+
dataset,
|
152 |
+
batch_size=args.batch_size,
|
153 |
+
sampler=train_sampler,
|
154 |
+
num_workers=args.workers,
|
155 |
+
collate_fn=utils.collate_fn,
|
156 |
+
drop_last=True,
|
157 |
+
)
|
158 |
+
|
159 |
+
data_loader_test = torch.utils.data.DataLoader(
|
160 |
+
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
|
161 |
+
)
|
162 |
+
|
163 |
+
model = torchvision.models.get_model(
|
164 |
+
args.model,
|
165 |
+
weights=args.weights,
|
166 |
+
weights_backbone=args.weights_backbone,
|
167 |
+
num_classes=num_classes,
|
168 |
+
aux_loss=args.aux_loss,
|
169 |
+
)
|
170 |
+
if args.apply_trp:
|
171 |
+
model = apply_trp(model, args.trp_depths, None, args.out_planes, args.trp_rewards)
|
172 |
+
model.to(device)
|
173 |
+
if args.distributed:
|
174 |
+
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
175 |
+
|
176 |
+
model_without_ddp = model
|
177 |
+
if args.distributed:
|
178 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
179 |
+
model_without_ddp = model.module
|
180 |
+
|
181 |
+
params_to_optimize = [
|
182 |
+
{"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
|
183 |
+
{"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
|
184 |
+
]
|
185 |
+
if args.aux_loss:
|
186 |
+
params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
|
187 |
+
params_to_optimize.append({"params": params, "lr": args.lr * 10})
|
188 |
+
optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
|
189 |
+
|
190 |
+
scaler = torch.amp.GradScaler(device="cuda") if args.amp else None
|
191 |
+
|
192 |
+
iters_per_epoch = len(data_loader)
|
193 |
+
main_lr_scheduler = PolynomialLR(
|
194 |
+
optimizer, total_iters=iters_per_epoch * (args.epochs - args.lr_warmup_epochs), power=0.9
|
195 |
+
)
|
196 |
+
|
197 |
+
if args.lr_warmup_epochs > 0:
|
198 |
+
warmup_iters = iters_per_epoch * args.lr_warmup_epochs
|
199 |
+
args.lr_warmup_method = args.lr_warmup_method.lower()
|
200 |
+
if args.lr_warmup_method == "linear":
|
201 |
+
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
|
202 |
+
optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
|
203 |
+
)
|
204 |
+
elif args.lr_warmup_method == "constant":
|
205 |
+
warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
|
206 |
+
optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
|
207 |
+
)
|
208 |
+
else:
|
209 |
+
raise RuntimeError(
|
210 |
+
f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
|
211 |
+
)
|
212 |
+
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
|
213 |
+
optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
|
214 |
+
)
|
215 |
+
else:
|
216 |
+
lr_scheduler = main_lr_scheduler
|
217 |
+
|
218 |
+
if args.resume:
|
219 |
+
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
|
220 |
+
model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only)
|
221 |
+
if not args.test_only:
|
222 |
+
optimizer.load_state_dict(checkpoint["optimizer"])
|
223 |
+
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
|
224 |
+
args.start_epoch = checkpoint["epoch"] + 1
|
225 |
+
if args.amp:
|
226 |
+
scaler.load_state_dict(checkpoint["scaler"])
|
227 |
+
|
228 |
+
if args.test_only:
|
229 |
+
# We disable the cudnn benchmarking because it can noticeably affect the accuracy
|
230 |
+
torch.backends.cudnn.benchmark = False
|
231 |
+
torch.backends.cudnn.deterministic = True
|
232 |
+
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
|
233 |
+
print(confmat)
|
234 |
+
return
|
235 |
+
|
236 |
+
start_time = time.time()
|
237 |
+
for epoch in range(args.start_epoch, args.epochs):
|
238 |
+
if args.distributed:
|
239 |
+
train_sampler.set_epoch(epoch)
|
240 |
+
train_one_epoch(model, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
|
241 |
+
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
|
242 |
+
print(confmat)
|
243 |
+
|
244 |
+
if args.output_dir:
|
245 |
+
checkpoint = {
|
246 |
+
"model": model_without_ddp.state_dict() if not args.apply_trp else {k: v for k, v in model_without_ddp.state_dict().items() if not "trp_blocks" in k},
|
247 |
+
"optimizer": optimizer.state_dict(),
|
248 |
+
"lr_scheduler": lr_scheduler.state_dict(),
|
249 |
+
"epoch": epoch,
|
250 |
+
"args": args,
|
251 |
+
}
|
252 |
+
if args.amp:
|
253 |
+
checkpoint["scaler"] = scaler.state_dict()
|
254 |
+
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
|
255 |
+
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
|
256 |
+
|
257 |
+
total_time = time.time() - start_time
|
258 |
+
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Segmentation Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
    parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
    parser.add_argument("--model", default="fcn_resnet101", type=str, help="model name")
    parser.add_argument("--aux-loss", action="store_true", help="auxiliary loss")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run")

    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)")
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    parser.add_argument("--apply-trp", action="store_true", help="enable applying trp")
    parser.add_argument("--trp-depths", nargs="+", default=[2, 2, 2], type=int, help="number of depth for each trp block")
    parser.add_argument("--out-planes", default=8, type=int, help="the dimension of the inner hidden states")
    parser.add_argument("--trp-rewards", nargs="+", default=[1.0, 0.4, 0.2, 0.1], type=float, help="trp rewards")

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
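A quick way to exercise the new TRP flags without launching a full run is to build the parser directly. This is an illustrative snippet, not part of the upload; it assumes the script directory is on sys.path so that train.py imports as `train`.

# Hypothetical smoke test for the argument parser and TRP flags.
import train

args = train.get_args_parser().parse_args([
    "--model", "fcn_resnet50",
    "--apply-trp",
    "--trp-depths", "2", "2", "2",
    "--trp-rewards", "1.0", "0.4", "0.2", "0.1",
    "--epochs", "1",
    "--batch-size", "2",
])
print(args.trp_depths, args.trp_rewards)  # [2, 2, 2] [1.0, 0.4, 0.2, 0.1]
# train.main(args)  # would additionally require the COCO dataset at --data-path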
nas-examples/semantic-segmentation/transforms.py
ADDED
@@ -0,0 +1,100 @@
import random

import numpy as np
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F


def pad_if_smaller(img, size, fill=0):
    min_size = min(img.size)
    if min_size < size:
        ow, oh = img.size
        padh = size - oh if oh < size else 0
        padw = size - ow if ow < size else 0
        img = F.pad(img, (0, 0, padw, padh), fill=fill)
    return img


class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class RandomResize:
    def __init__(self, min_size, max_size=None):
        self.min_size = min_size
        if max_size is None:
            max_size = min_size
        self.max_size = max_size

    def __call__(self, image, target):
        size = random.randint(self.min_size, self.max_size)
        image = F.resize(image, size)
        target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
        return image, target


class RandomHorizontalFlip:
    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        if random.random() < self.flip_prob:
            image = F.hflip(image)
            target = F.hflip(target)
        return image, target


class RandomCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = pad_if_smaller(image, self.size)
        target = pad_if_smaller(target, self.size, fill=255)
        crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
        image = F.crop(image, *crop_params)
        target = F.crop(target, *crop_params)
        return image, target


class CenterCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = F.center_crop(image, self.size)
        target = F.center_crop(target, self.size)
        return image, target


class PILToTensor:
    def __call__(self, image, target):
        image = F.pil_to_tensor(image)
        target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target


class ConvertImageDtype:
    def __init__(self, dtype):
        self.dtype = dtype

    def __call__(self, image, target):
        image = F.convert_image_dtype(image, self.dtype)
        return image, target


class Normalize:
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image, target
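For context, a minimal sketch of how these joint image/mask transforms are meant to be chained for training (presets.py in this folder composes them the same way; the resize range and crop size below are example values, not the preset's own derivation, and the snippet assumes it runs from this folder so the local transforms module is imported).

# Illustrative composition of the joint transforms on a (PIL image, PIL mask) pair.
import torch
import transforms as T

train_transform = T.Compose([
    T.RandomResize(260, 1040),   # example range, roughly 0.5x to 2.0x of a 520 base size
    T.RandomHorizontalFlip(0.5),
    T.RandomCrop(480),
    T.PILToTensor(),
    T.ConvertImageDtype(torch.float),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
# image, target = train_transform(pil_image, pil_mask)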
nas-examples/semantic-segmentation/trplib.py
ADDED
@@ -0,0 +1,555 @@
import types
from typing import Optional, List, Union, Callable
from collections import OrderedDict

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from torchvision.models.mobilenetv2 import MobileNetV2
from torchvision.models.resnet import ResNet
from torchvision.models.efficientnet import EfficientNet
from torchvision.models.vision_transformer import VisionTransformer
from torchvision.models.segmentation.fcn import FCN
from torchvision.models.segmentation.deeplabv3 import DeepLabV3


def compute_policy_loss(loss_sequence, mask_sequence, rewards):
    losses = sum(mask * padded_loss for mask, padded_loss in zip(mask_sequence, loss_sequence))
    returns = sum(padded_reward * mask for padded_reward, mask in zip(rewards, mask_sequence))
    loss = torch.mean(losses * returns)

    return loss


class TPBlock(nn.Module):
    def __init__(self, depths: int, in_planes: int, out_planes: int = None, rank=1, shape_dims=3, channel_first=True, dtype=torch.float32) -> None:
        super().__init__()
        out_planes = in_planes if out_planes is None else out_planes
        self.layers = torch.nn.ModuleList([self._make_layer(in_planes, out_planes, rank, shape_dims, channel_first, dtype) for _ in range(depths)])

    def forward(self, x: Tensor) -> Tensor:
        for layer in self.layers:
            x = x + layer(x)
        return x

    def _make_layer(self, in_planes: int, out_planes: int = None, rank=1, shape_dims=3, channel_first=True, dtype=torch.float32) -> nn.Sequential:

        class Permute(nn.Module):
            def __init__(self, *dims):
                super().__init__()
                self.dims = dims

            def forward(self, x):
                return x.permute(*self.dims)

        class RMSNorm(nn.Module):
            __constants__ = ["eps"]
            eps: float

            def __init__(self, hidden_size, eps: float = 1e-6, device=None, dtype=None):
                """
                LlamaRMSNorm is equivalent to T5LayerNorm.
                """
                factory_kwargs = {"device": device, "dtype": dtype}
                super().__init__()
                self.eps = eps
                self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs))

            def forward(self, hidden_states):
                input_dtype = hidden_states.dtype
                hidden_states = hidden_states.to(torch.float32)
                variance = hidden_states.pow(2).mean(dim=1, keepdim=True)
                hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
                weight = self.weight.view(1, -1, *[1] * (hidden_states.ndim - 2))
                return weight * hidden_states.to(input_dtype)

            def extra_repr(self):
                return f"{self.weight.shape[0]}, eps={self.eps}"

        conv_map = {
            2: (nn.Conv1d, (0, 2, 1), (0, 2, 1)),
            3: (nn.Conv2d, (0, 3, 1, 2), (0, 2, 3, 1)),
            4: (nn.Conv3d, (0, 4, 1, 2, 3), (0, 2, 3, 4, 1)),
        }
        Conv, pre_dims, post_dims = conv_map[shape_dims]
        kernel_size, dilation, padding = self.generate_hyperparameters(rank)

        pre_permute = nn.Identity() if channel_first else Permute(*pre_dims)
        post_permute = nn.Identity() if channel_first else Permute(*post_dims)
        conv1 = Conv(in_planes, out_planes, kernel_size, padding=padding, dilation=dilation, bias=False, dtype=dtype, device='cuda')
        nn.init.zeros_(conv1.weight)
        bn1 = RMSNorm(out_planes, dtype=dtype, device="cuda")
        relu = nn.ReLU(inplace=True)
        conv2 = Conv(out_planes, in_planes, kernel_size, padding=padding, dilation=dilation, bias=False, dtype=dtype, device='cuda')
        nn.init.zeros_(conv2.weight)
        bn2 = RMSNorm(in_planes, dtype=dtype, device="cuda")

        return torch.nn.Sequential(pre_permute, conv1, bn1, relu, conv2, bn2, relu, post_permute)

    @staticmethod
    def generate_hyperparameters(rank: int):
        """
        Generates kernel size and dilation rate pairs sorted by increasing padded kernel size.

        Args:
            rank: Number of (kernel_size, dilation) pairs to generate. Must be positive.

        Returns:
            Tuple[int, int, int]: A (kernel_size, dilation, padding) tuple where:
                - kernel_size: Always odd and >= 1
                - dilation: Computed to maintain consistent padded kernel size growth
                - padding: Chosen so the layer preserves the spatial size

        Note:
            Padded kernel size is calculated as:
                (kernel_size - 1) * dilation + 1
            Pairs are generated first in order of increasing padded kernel size,
            then by increasing kernel size for equal padded kernel sizes.
        """
        pairs = [(1, 1, 0)]  # Start with smallest possible
        padded_kernel_size = 3

        while len(pairs) < rank:
            for kernel_size in range(3, padded_kernel_size + 1, 2):
                if (padded_kernel_size - 1) % (kernel_size - 1) == 0:
                    dilation = (padded_kernel_size - 1) // (kernel_size - 1)
                    padding = dilation * (kernel_size - 1) // 2
                    pairs.append((kernel_size, dilation, padding))
                if len(pairs) >= rank:
                    break

            # Move to next odd padded kernel size
            padded_kernel_size += 2

        return pairs[-1]


# ResNet for Image Classification
class ResNetConfig:
    @staticmethod
    def gen_shared_head(self):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].

            Returns:
                logits (Tensor): Logits tensor of shape [B, C].
            """
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.fc(x)
            return logits
        return func

    @staticmethod
    def gen_logits(self, shared_head):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].

            Returns:
                logits_sequence (List[Tensor]): List of Logits tensors.
            """
            logits_sequence = [shared_head(hidden_states)]
            for layer in self.trp_blocks:
                logits_sequence.append(shared_head(layer(hidden_states)))
            return logits_sequence
        return func

    @staticmethod
    def gen_mask(label_smoothing=0.0, top_k=1):
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of Logits tensors.
                labels (Tensor): Target labels of shape [B] or [B, C].

            Returns:
                mask_sequence (List[Tensor]): List of float mask tensors of shape [B].
            """
            labels = torch.argmax(labels, dim=1) if label_smoothing > 0.0 else labels

            mask_sequence = [torch.ones_like(labels, dtype=torch.float32, device=labels.device)]
            for logits in logits_sequence:
                with torch.no_grad():
                    topk_values, topk_indices = torch.topk(logits, top_k, dim=-1)
                    mask = torch.eq(topk_indices, labels[:, None]).any(dim=-1).to(torch.float32)
                    mask_sequence.append(mask_sequence[-1] * mask)
            return mask_sequence
        return func

    @staticmethod
    def gen_criterion(label_smoothing=0.0):
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of Logits tensors.
                labels (Tensor): Target labels of shape [B] or [B, C].

            Returns:
                loss_sequence (List[Tensor]): List of per-sample loss tensors of shape [B].
            """
            labels = torch.argmax(labels, dim=1) if label_smoothing > 0.0 else labels

            loss_sequence = []
            for logits in logits_sequence:
                loss_sequence.append(F.cross_entropy(logits, labels, reduction="none", label_smoothing=label_smoothing))

            return loss_sequence
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        def func(self, x: Tensor, targets=None) -> Tensor:
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            hidden_states = self.layer4(x)
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.fc(x)

            if self.training:
                shared_head = ResNetConfig.gen_shared_head(self)
                compute_logits = ResNetConfig.gen_logits(self, shared_head)
                compute_mask = ResNetConfig.gen_mask(label_smoothing, top_k)
                compute_loss = ResNetConfig.gen_criterion(label_smoothing)

                logits_sequence = compute_logits(hidden_states)
                mask_sequence = compute_mask(logits_sequence, targets)
                loss_sequence = compute_loss(logits_sequence, targets)
                loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)

                return logits, loss

            return logits

        return func


# MobileNetV2 for Image Classification
class MobileNetV2Config(ResNetConfig):
    @staticmethod
    def gen_shared_head(self):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].

            Returns:
                logits (Tensor): Logits tensor of shape [B, C].
            """
            x = nn.functional.adaptive_avg_pool2d(hidden_states, (1, 1))
            x = torch.flatten(x, 1)
            logits = self.classifier(x)
            return logits
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        def func(self, x: Tensor, targets=None) -> Tensor:
            hidden_states = self.features(x)
            # Cannot use "squeeze" as batch-size can be 1
            x = nn.functional.adaptive_avg_pool2d(hidden_states, (1, 1))
            x = torch.flatten(x, 1)
            logits = self.classifier(x)

            if self.training:
                shared_head = MobileNetV2Config.gen_shared_head(self)
                compute_logits = MobileNetV2Config.gen_logits(self, shared_head)
                compute_mask = MobileNetV2Config.gen_mask(label_smoothing, top_k)
                compute_loss = MobileNetV2Config.gen_criterion(label_smoothing)

                logits_sequence = compute_logits(hidden_states)
                mask_sequence = compute_mask(logits_sequence, targets)
                loss_sequence = compute_loss(logits_sequence, targets)
                loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)

                return logits, loss

            return logits

        return func


# EfficientNet for Image Classification
class EfficientNetConfig(ResNetConfig):
    @staticmethod
    def gen_shared_head(self):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].

            Returns:
                logits (Tensor): Logits tensor of shape [B, C].
            """
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.classifier(x)
            return logits
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        def func(self, x: Tensor, targets=None) -> Tensor:
            hidden_states = self.features(x)
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.classifier(x)

            if self.training:
                shared_head = EfficientNetConfig.gen_shared_head(self)
                compute_logits = EfficientNetConfig.gen_logits(self, shared_head)
                compute_mask = EfficientNetConfig.gen_mask(label_smoothing, top_k)
                compute_loss = EfficientNetConfig.gen_criterion(label_smoothing)

                logits_sequence = compute_logits(hidden_states)
                mask_sequence = compute_mask(logits_sequence, targets)
                loss_sequence = compute_loss(logits_sequence, targets)
                loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)

                return logits, loss

            return logits

        return func


# VisionTransformer for Image Classification
class VisionTransformerConfig(ResNetConfig):
    @staticmethod
    def gen_shared_head(self):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, L, hidden_dim].

            Returns:
                logits (Tensor): Logits tensor of shape [B, C].
            """
            x = hidden_states[:, 0]
            logits = self.heads(x)
            return logits
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        def func(self, images: Tensor, targets=None):
            x = self._process_input(images)
            n = x.shape[0]
            batch_class_token = self.class_token.expand(n, -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)
            hidden_states = self.encoder(x)
            x = hidden_states[:, 0]

            logits = self.heads(x)

            if self.training:
                shared_head = VisionTransformerConfig.gen_shared_head(self)
                compute_logits = VisionTransformerConfig.gen_logits(self, shared_head)
                compute_mask = VisionTransformerConfig.gen_mask(label_smoothing, top_k)
                compute_loss = VisionTransformerConfig.gen_criterion(label_smoothing)

                logits_sequence = compute_logits(hidden_states)
                mask_sequence = compute_mask(logits_sequence, targets)
                loss_sequence = compute_loss(logits_sequence, targets)
                loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)

                return logits, loss
            return logits
        return func


# FCN for Semantic Segmentation
class FCNConfig(ResNetConfig):
    @staticmethod
    def gen_out_shared_head(self, input_shape):
        def func(features):
            """
            Args:
                features (Tensor): features tensor of shape [B, hidden_units, H, W].

            Returns:
                result (Tensor): result tensor of shape [B, C, H, W].
            """
            x = self.classifier(features)
            result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
            return result
        return func

    @staticmethod
    def gen_aux_shared_head(self, input_shape):
        def func(features):
            """
            Args:
                features (Tensor): features tensor of shape [B, hidden_units, H, W].

            Returns:
                result (Tensor): result tensor of shape [B, C, H, W].
            """
            x = self.aux_classifier(features)
            result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
            return result
        return func

    @staticmethod
    def gen_out_logits(self, shared_head):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, hidden_units, H, W].

            Returns:
                logits_sequence (List[Tensor]): List of Logits tensors.
            """
            logits_sequence = [shared_head(hidden_states)]
            for layer in self.out_trp_blocks:
                logits_sequence.append(shared_head(layer(hidden_states)))
            return logits_sequence
        return func

    @staticmethod
    def gen_aux_logits(self, shared_head):
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, hidden_units, H, W].

            Returns:
                logits_sequence (List[Tensor]): List of Logits tensors.
            """
            logits_sequence = [shared_head(hidden_states)]
            for layer in self.aux_trp_blocks:
                logits_sequence.append(shared_head(layer(hidden_states)))
            return logits_sequence
        return func

    @staticmethod
    def gen_mask(label_smoothing=0.0, top_k=1):
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of Logits tensors with shape [B, C, H, W].
                labels (Tensor): Target labels of shape [B, H, W].

            Returns:
                mask_sequence (List[Tensor]): List of float mask tensors of shape [B, H, W].
            """
            labels = torch.argmax(labels, dim=1) if label_smoothing > 0.0 else labels

            mask_sequence = [torch.ones_like(labels, dtype=torch.float32, device=labels.device)]
            for logits in logits_sequence:
                with torch.no_grad():
                    topk_values, topk_indices = torch.topk(logits, top_k, dim=1)
                    mask = torch.eq(topk_indices, labels[:, None, :, :]).any(dim=1).to(torch.float32)
                    mask_sequence.append(mask_sequence[-1] * mask)
            return mask_sequence
        return func

    @staticmethod
    def gen_criterion(label_smoothing=0.0):
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of Logits tensors with shape [B, C, H, W].
                labels (Tensor): Target labels of shape [B, H, W].

            Returns:
                loss_sequence (List[Tensor]): List of per-pixel loss tensors of shape [B, H, W].
            """
            labels = torch.argmax(labels, dim=1) if label_smoothing > 0.0 else labels

            loss_sequence = []
            for logits in logits_sequence:
                loss_sequence.append(F.cross_entropy(logits, labels, ignore_index=255, reduction="none", label_smoothing=label_smoothing))

            return loss_sequence
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        def func(self, images: Tensor, targets=None):
            input_shape = images.shape[-2:]
            # contract: features is a dict of tensors
            features = self.backbone(images)

            result = OrderedDict()
            x = features["out"]
            x = self.classifier(x)
            x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
            result["out"] = x

            if self.aux_classifier is not None:
                x = features["aux"]
                x = self.aux_classifier(x)
                x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
                result["aux"] = x

            if self.training:
                torch._assert(targets is not None, "targets should not be none when in training mode")
                out_shared_head = FCNConfig.gen_out_shared_head(self, input_shape)
                aux_shared_head = FCNConfig.gen_aux_shared_head(self, input_shape)
                compute_out_logits = FCNConfig.gen_out_logits(self, out_shared_head)
                compute_aux_logits = FCNConfig.gen_aux_logits(self, aux_shared_head)
                compute_mask = FCNConfig.gen_mask(label_smoothing, top_k)
                compute_loss = FCNConfig.gen_criterion(label_smoothing)

                out_logits_sequence = compute_out_logits(features["out"])
                out_mask_sequence = compute_mask(out_logits_sequence, targets)
                out_loss_sequence = compute_loss(out_logits_sequence, targets)
                out_loss = compute_policy_loss(out_loss_sequence, out_mask_sequence, rewards)

                aux_logits_sequence = compute_aux_logits(features["aux"])
                aux_mask_sequence = compute_mask(aux_logits_sequence, targets)
                aux_loss_sequence = compute_loss(aux_logits_sequence, targets)
                aux_loss = compute_policy_loss(aux_loss_sequence, aux_mask_sequence, rewards)

                loss = out_loss + 0.5 * aux_loss
                return result, loss
            return result
        return func


# DeepLabV3Config for Semantic Segmentation
class DeepLabV3Config(FCNConfig):
    pass


def apply_trp(model, depths: List[int], in_planes: int, out_planes: int, rewards, **kwargs):
    if isinstance(model, ResNet):
        print("✅ Applying TRP to ResNet for Image Classification...")
        model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
        model.forward = types.MethodType(ResNetConfig.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
    elif isinstance(model, MobileNetV2):
        print("✅ Applying TRP to MobileNetV2 for Image Classification...")
        model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
        model.forward = types.MethodType(MobileNetV2Config.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
    elif isinstance(model, EfficientNet):
        print("✅ Applying TRP to EfficientNet for Image Classification...")
        model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
        model.forward = types.MethodType(EfficientNetConfig.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
    elif isinstance(model, VisionTransformer):
        print("✅ Applying TRP to VisionTransformer for Image Classification...")
        model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k, shape_dims=2, channel_first=False) for k, d in enumerate(depths)])
        model.forward = types.MethodType(VisionTransformerConfig.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
    elif isinstance(model, FCN):
        print("✅ Applying TRP to FCN for Semantic Segmentation...")
        model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=2048, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
        model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=1024, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
        model.forward = types.MethodType(FCNConfig.gen_forward(rewards, label_smoothing=0.0, top_k=1), model)
    elif isinstance(model, DeepLabV3):
        print("✅ Applying TRP to DeepLabV3 for Semantic Segmentation...")
        model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=2048, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
        model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=1024, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
        model.forward = types.MethodType(DeepLabV3Config.gen_forward(rewards, label_smoothing=0.0, top_k=1), model)
    return model
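A hedged usage sketch of apply_trp on a torchvision segmentation model (illustrative only, not part of the upload). It assumes a CUDA device, because TPBlock builds its convolutions directly on cuda, and an aux head, because the training forward always evaluates the aux branch; the depths, out_planes, and rewards match the defaults exposed by train.py.

# Wrap a torchvision FCN with TRP blocks; in_planes is effectively ignored for FCN
# because the out/aux branches hard-code 2048/1024 backbone channels.
import torchvision
from trplib import apply_trp

model = torchvision.models.segmentation.fcn_resnet50(num_classes=21, aux_loss=True)
model = apply_trp(
    model,
    depths=[2, 2, 2],
    in_planes=2048,
    out_planes=8,
    rewards=[1.0, 0.4, 0.2, 0.1],
)
# In training mode forward() now returns (result_dict, trp_policy_loss);
# in eval mode it returns the usual {"out": ..., "aux": ...} dict.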
nas-examples/semantic-segmentation/utils.py
ADDED
@@ -0,0 +1,300 @@
import datetime
import errno
import os
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist


class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        t = reduce_across_processes([self.count, self.total])
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
        )


class ConfusionMatrix:
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None

    def update(self, a, b):
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.inference_mode():
            k = (a >= 0) & (a < n)
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        self.mat.zero_()

    def compute(self):
        h = self.mat.float()
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu

    def reduce_from_all_processes(self):
        reduce_across_processes(self.mat)

    def __str__(self):
        acc_global, acc, iu = self.compute()
        return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
            acc_global.item() * 100,
            [f"{i:.1f}" for i in (acc * 100).tolist()],
            [f"{i:.1f}" for i in (iu * 100).tolist()],
            iu.mean().item() * 100,
        )


class MetricLogger:
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            if not isinstance(v, (float, int)):
                raise TypeError(
                    f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
                )
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {str(meter)}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"{header} Total time: {total_time_str}")


def cat_list(images, fill_value=0):
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    batch_shape = (len(images),) + max_size
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
    for img, pad_img in zip(images, batched_imgs):
        pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
    return batched_imgs


def collate_fn(batch):
    images, targets = list(zip(*batch))
    batched_imgs = cat_list(images, fill_value=0)
    batched_targets = cat_list(targets, fill_value=255)
    return batched_imgs, batched_targets


def mkdir(path):
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    elif hasattr(args, "rank"):
        pass
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def reduce_across_processes(val):
    if not is_dist_avail_and_initialized():
        # nothing to sync, but we still convert to tensor for consistency with the distributed case.
        return torch.tensor(val)

    # ints and lists (e.g. [count, total] from SmoothedValue) are wrapped in a new tensor;
    # tensors (e.g. the confusion matrix) are cloned to avoid the torch.tensor(tensor) copy warning.
    t = torch.tensor(val, device="cuda") if isinstance(val, (int, list)) else val.clone().detach().to("cuda")
    dist.barrier()
    dist.all_reduce(t)
    return t
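A small sanity check for the metric utilities (illustrative only, assuming this file is importable as utils): accumulate a ConfusionMatrix on toy predictions and print the derived accuracy and IoU.

# Toy example: 4 pixels, 3 classes, one misclassified pixel.
import torch
from utils import ConfusionMatrix

cm = ConfusionMatrix(num_classes=3)
target = torch.tensor([0, 1, 2, 2])
pred = torch.tensor([0, 1, 2, 1])
cm.update(target.flatten(), pred.flatten())
acc_global, acc, iu = cm.compute()
print(cm)  # global correct 75.0, per-class accuracy, per-class IoU, mean IoU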