UniversalAlgorithmic committed on
Commit 97d8aaa · verified · 1 Parent(s): 09a2af4

Upload 13 files

nas-examples/image-classification/presets.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch
2
+ from torchvision.transforms import autoaugment, transforms
3
+ from torchvision.transforms.functional import InterpolationMode
4
+
5
+
6
+ class ClassificationPresetTrain:
7
+ def __init__(
8
+ self,
9
+ *,
10
+ crop_size,
11
+ mean=(0.485, 0.456, 0.406),
12
+ std=(0.229, 0.224, 0.225),
13
+ interpolation=InterpolationMode.BILINEAR,
14
+ hflip_prob=0.5,
15
+ auto_augment_policy=None,
16
+ ra_magnitude=9,
17
+ augmix_severity=3,
18
+ random_erase_prob=0.0,
19
+ ):
20
+ trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
21
+ if hflip_prob > 0:
22
+ trans.append(transforms.RandomHorizontalFlip(hflip_prob))
23
+ if auto_augment_policy is not None:
24
+ if auto_augment_policy == "ra":
25
+ trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
26
+ elif auto_augment_policy == "ta_wide":
27
+ trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
28
+ elif auto_augment_policy == "augmix":
29
+ trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
30
+ else:
31
+ aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
32
+ trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
33
+ trans.extend(
34
+ [
35
+ transforms.PILToTensor(),
36
+ transforms.ConvertImageDtype(torch.float),
37
+ transforms.Normalize(mean=mean, std=std),
38
+ ]
39
+ )
40
+ if random_erase_prob > 0:
41
+ trans.append(transforms.RandomErasing(p=random_erase_prob))
42
+
43
+ self.transforms = transforms.Compose(trans)
44
+
45
+ def __call__(self, img):
46
+ return self.transforms(img)
47
+
48
+
49
+ class ClassificationPresetEval:
50
+ def __init__(
51
+ self,
52
+ *,
53
+ crop_size,
54
+ resize_size=256,
55
+ mean=(0.485, 0.456, 0.406),
56
+ std=(0.229, 0.224, 0.225),
57
+ interpolation=InterpolationMode.BILINEAR,
58
+ ):
59
+
60
+ self.transforms = transforms.Compose(
61
+ [
62
+ transforms.Resize(resize_size, interpolation=interpolation),
63
+ transforms.CenterCrop(crop_size),
64
+ transforms.PILToTensor(),
65
+ transforms.ConvertImageDtype(torch.float),
66
+ transforms.Normalize(mean=mean, std=std),
67
+ ]
68
+ )
69
+
70
+ def __call__(self, img):
71
+ return self.transforms(img)
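
As a quick sanity check of the presets above, here is a minimal usage sketch (the image path and augmentation policy are illustrative, not prescribed by the commit) that applies the train and eval pipelines to a single PIL image:

    # Hedged usage sketch; "example.jpg" is a placeholder path.
    from PIL import Image
    from presets import ClassificationPresetTrain, ClassificationPresetEval

    train_tf = ClassificationPresetTrain(crop_size=224, auto_augment_policy="ta_wide")
    eval_tf = ClassificationPresetEval(crop_size=224, resize_size=256)

    img = Image.open("example.jpg").convert("RGB")
    x_train = train_tf(img)  # float tensor of shape [3, 224, 224], normalized
    x_eval = eval_tf(img)    # deterministic resize + center-crop variant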
nas-examples/image-classification/sampler.py ADDED
@@ -0,0 +1,62 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+
6
+
7
+ class RASampler(torch.utils.data.Sampler):
8
+ """Sampler that restricts data loading to a subset of the dataset for distributed,
9
+ with repeated augmentation.
10
+ It ensures that each augmented version of a sample will be visible to a
11
+ different process (GPU).
12
+ Heavily based on 'torch.utils.data.DistributedSampler'.
13
+
14
+ This is borrowed from the DeiT Repo:
15
+ https://github.com/facebookresearch/deit/blob/main/samplers.py
16
+ """
17
+
18
+ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
19
+ if num_replicas is None:
20
+ if not dist.is_available():
21
+ raise RuntimeError("Requires distributed package to be available!")
22
+ num_replicas = dist.get_world_size()
23
+ if rank is None:
24
+ if not dist.is_available():
25
+ raise RuntimeError("Requires distributed package to be available!")
26
+ rank = dist.get_rank()
27
+ self.dataset = dataset
28
+ self.num_replicas = num_replicas
29
+ self.rank = rank
30
+ self.epoch = 0
31
+ self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
32
+ self.total_size = self.num_samples * self.num_replicas
33
+ self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
34
+ self.shuffle = shuffle
35
+ self.seed = seed
36
+ self.repetitions = repetitions
37
+
38
+ def __iter__(self):
39
+ if self.shuffle:
40
+ # Deterministically shuffle based on epoch
41
+ g = torch.Generator()
42
+ g.manual_seed(self.seed + self.epoch)
43
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
44
+ else:
45
+ indices = list(range(len(self.dataset)))
46
+
47
+ # Add extra samples to make it evenly divisible
48
+ indices = [ele for ele in indices for _ in range(self.repetitions)]
49
+ indices += indices[: (self.total_size - len(indices))]
50
+ assert len(indices) == self.total_size
51
+
52
+ # Subsample
53
+ indices = indices[self.rank : self.total_size : self.num_replicas]
54
+ assert len(indices) == self.num_samples
55
+
56
+ return iter(indices[: self.num_selected_samples])
57
+
58
+ def __len__(self):
59
+ return self.num_selected_samples
60
+
61
+ def set_epoch(self, epoch):
62
+ self.epoch = epoch
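
For context, a minimal sketch of how RASampler is typically wired into a DataLoader (assuming the process group has already been initialized, e.g. via torchrun; FakeData stands in for a real dataset):

    import torch
    import torchvision
    from sampler import RASampler

    dataset = torchvision.datasets.FakeData(size=512, transform=torchvision.transforms.ToTensor())
    sampler = RASampler(dataset, shuffle=True, repetitions=3)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # re-seeds the deterministic shuffle each epoch
        for images, targets in loader:
            pass  # training step would go here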
nas-examples/image-classification/train.py ADDED
@@ -0,0 +1,525 @@
1
+ import datetime
2
+ import os
3
+ import time
4
+ import warnings
5
+
6
+ import presets
7
+ import torch
8
+ import torch.utils.data
9
+ import torchvision
10
+ import transforms
11
+ import utils
12
+ from sampler import RASampler
13
+ from torch import nn
14
+ from torch.utils.data.dataloader import default_collate
15
+ from torchvision.transforms.functional import InterpolationMode
16
+
17
+ from trplib import apply_trp
18
+
19
+
20
+ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
21
+ model.train()
22
+ metric_logger = utils.MetricLogger(delimiter=" ")
23
+ metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
24
+ metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
25
+
26
+ header = f"Epoch: [{epoch}]"
27
+ for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
28
+ start_time = time.time()
29
+ image, target = image.to(device), target.to(device)
30
+ with torch.amp.autocast("cuda", enabled=scaler is not None):
31
+ # output = model(image)
32
+ # loss = criterion(output, target)
33
+ output, loss = model(image, target)
34
+
35
+ optimizer.zero_grad()
36
+ if scaler is not None:
37
+ scaler.scale(loss).backward()
38
+ if args.clip_grad_norm is not None:
39
+ # we should unscale the gradients of the optimizer's assigned params before doing gradient clipping
40
+ scaler.unscale_(optimizer)
41
+ nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
42
+ scaler.step(optimizer)
43
+ scaler.update()
44
+ else:
45
+ loss.backward()
46
+ if args.clip_grad_norm is not None:
47
+ nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
48
+ optimizer.step()
49
+
50
+ if model_ema and i % args.model_ema_steps == 0:
51
+ model_ema.update_parameters(model)
52
+ if epoch < args.lr_warmup_epochs:
53
+ # Reset ema buffer to keep copying weights during warmup period
54
+ model_ema.n_averaged.fill_(0)
55
+
56
+ acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
57
+ batch_size = image.shape[0]
58
+ metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
59
+ metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
60
+ metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
61
+ metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
62
+
63
+
64
+ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
65
+ model.eval()
66
+ metric_logger = utils.MetricLogger(delimiter=" ")
67
+ header = f"Test: {log_suffix}"
68
+
69
+ num_processed_samples = 0
70
+ with torch.inference_mode():
71
+ for image, target in metric_logger.log_every(data_loader, print_freq, header):
72
+ image = image.to(device, non_blocking=True)
73
+ target = target.to(device, non_blocking=True)
74
+ output = model(image)
75
+ loss = criterion(output, target)
76
+
77
+ acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
78
+ # FIXME need to take into account that the datasets
79
+ # could have been padded in distributed setup
80
+ batch_size = image.shape[0]
81
+ metric_logger.update(loss=loss.item())
82
+ metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
83
+ metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
84
+ num_processed_samples += batch_size
85
+ # gather the stats from all processes
86
+
87
+ num_processed_samples = utils.reduce_across_processes(num_processed_samples)
88
+ if (
89
+ hasattr(data_loader.dataset, "__len__")
90
+ and len(data_loader.dataset) != num_processed_samples
91
+ and torch.distributed.get_rank() == 0
92
+ ):
93
+ # See FIXME above
94
+ warnings.warn(
95
+ f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
96
+ "samples were used for the validation, which might bias the results. "
97
+ "Try adjusting the batch size and / or the world size. "
98
+ "Setting the world size to 1 is always a safe bet."
99
+ )
100
+
101
+ metric_logger.synchronize_between_processes()
102
+
103
+ print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
104
+ return metric_logger.acc1.global_avg
105
+
106
+
107
+ def _get_cache_path(filepath):
108
+ import hashlib
109
+
110
+ h = hashlib.sha1(filepath.encode()).hexdigest()
111
+ cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
112
+ cache_path = os.path.expanduser(cache_path)
113
+ return cache_path
114
+
115
+
116
+ def load_data(traindir, valdir, args):
117
+ # Data loading code
118
+ print("Loading data")
119
+ val_resize_size, val_crop_size, train_crop_size = (
120
+ args.val_resize_size,
121
+ args.val_crop_size,
122
+ args.train_crop_size,
123
+ )
124
+ interpolation = InterpolationMode(args.interpolation)
125
+
126
+ print("Loading training data")
127
+ st = time.time()
128
+ cache_path = _get_cache_path(traindir)
129
+ if args.cache_dataset and os.path.exists(cache_path):
130
+ # Attention, as the transforms are also cached!
131
+ print(f"Loading dataset_train from {cache_path}")
132
+ dataset, _ = torch.load(cache_path)
133
+ else:
134
+ auto_augment_policy = getattr(args, "auto_augment", None)
135
+ random_erase_prob = getattr(args, "random_erase", 0.0)
136
+ ra_magnitude = args.ra_magnitude
137
+ augmix_severity = args.augmix_severity
138
+ dataset = torchvision.datasets.ImageFolder(
139
+ traindir,
140
+ presets.ClassificationPresetTrain(
141
+ crop_size=train_crop_size,
142
+ interpolation=interpolation,
143
+ auto_augment_policy=auto_augment_policy,
144
+ random_erase_prob=random_erase_prob,
145
+ ra_magnitude=ra_magnitude,
146
+ augmix_severity=augmix_severity,
147
+ ),
148
+ )
149
+ if args.cache_dataset:
150
+ print(f"Saving dataset_train to {cache_path}")
151
+ utils.mkdir(os.path.dirname(cache_path))
152
+ utils.save_on_master((dataset, traindir), cache_path)
153
+ print("Took", time.time() - st)
154
+
155
+ print("Loading validation data")
156
+ cache_path = _get_cache_path(valdir)
157
+ if args.cache_dataset and os.path.exists(cache_path):
158
+ # Attention, as the transforms are also cached!
159
+ print(f"Loading dataset_test from {cache_path}")
160
+ dataset_test, _ = torch.load(cache_path)
161
+ else:
162
+ if args.weights and args.test_only:
163
+ weights = torchvision.models.get_weight(args.weights)
164
+ preprocessing = weights.transforms()
165
+ else:
166
+ preprocessing = presets.ClassificationPresetEval(
167
+ crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
168
+ )
169
+
170
+ dataset_test = torchvision.datasets.ImageFolder(
171
+ valdir,
172
+ preprocessing,
173
+ )
174
+ if args.cache_dataset:
175
+ print(f"Saving dataset_test to {cache_path}")
176
+ utils.mkdir(os.path.dirname(cache_path))
177
+ utils.save_on_master((dataset_test, valdir), cache_path)
178
+
179
+ print("Creating data loaders")
180
+ if args.distributed:
181
+ if hasattr(args, "ra_sampler") and args.ra_sampler:
182
+ train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
183
+ else:
184
+ train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
185
+ test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
186
+ else:
187
+ train_sampler = torch.utils.data.RandomSampler(dataset)
188
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
189
+
190
+ return dataset, dataset_test, train_sampler, test_sampler
191
+
192
+
193
+ def main(args):
194
+ if args.output_dir:
195
+ utils.mkdir(args.output_dir)
196
+
197
+ utils.init_distributed_mode(args)
198
+ print(args)
199
+
200
+ device = torch.device(args.device)
201
+
202
+ if args.use_deterministic_algorithms:
203
+ torch.backends.cudnn.benchmark = False
204
+ torch.use_deterministic_algorithms(True)
205
+ else:
206
+ torch.backends.cudnn.benchmark = True
207
+
208
+ train_dir = os.path.join(args.data_path, "train")
209
+ val_dir = os.path.join(args.data_path, "val")
210
+ dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
211
+
212
+ collate_fn = None
213
+ num_classes = len(dataset.classes)
214
+ mixup_transforms = []
215
+ if args.mixup_alpha > 0.0:
216
+ mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
217
+ if args.cutmix_alpha > 0.0:
218
+ mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
219
+ if mixup_transforms:
220
+ mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
221
+
222
+ def collate_fn(batch):
223
+ return mixupcutmix(*default_collate(batch))
224
+
225
+ data_loader = torch.utils.data.DataLoader(
226
+ dataset,
227
+ batch_size=args.batch_size,
228
+ sampler=train_sampler,
229
+ num_workers=args.workers,
230
+ pin_memory=True,
231
+ collate_fn=collate_fn,
232
+ )
233
+ data_loader_test = torch.utils.data.DataLoader(
234
+ dataset_test, batch_size=8, sampler=test_sampler, num_workers=args.workers, pin_memory=True
235
+ )
236
+
237
+ print("Creating model")
238
+ model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
239
+ if args.apply_trp:
240
+ model = apply_trp(model, args.trp_depths, args.in_planes, args.out_planes, args.trp_rewards, label_smoothing=args.label_smoothing)
241
+ model.to(device)
242
+
243
+ if args.distributed and args.sync_bn:
244
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
245
+
246
+ criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
247
+
248
+ custom_keys_weight_decay = []
249
+ if args.bias_weight_decay is not None:
250
+ custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
251
+ if args.transformer_embedding_decay is not None:
252
+ for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
253
+ custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
254
+ parameters = utils.set_weight_decay(
255
+ model,
256
+ args.weight_decay,
257
+ norm_weight_decay=args.norm_weight_decay,
258
+ custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
259
+ )
260
+
261
+ opt_name = args.opt.lower()
262
+ if opt_name.startswith("sgd"):
263
+ optimizer = torch.optim.SGD(
264
+ parameters,
265
+ lr=args.lr,
266
+ momentum=args.momentum,
267
+ weight_decay=args.weight_decay,
268
+ nesterov="nesterov" in opt_name,
269
+ )
270
+ elif opt_name == "rmsprop":
271
+ optimizer = torch.optim.RMSprop(
272
+ parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
273
+ )
274
+ elif opt_name == "adamw":
275
+ optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
276
+ else:
277
+ raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
278
+
279
+ scaler = torch.amp.GradScaler("cuda") if args.amp else None
280
+
281
+ args.lr_scheduler = args.lr_scheduler.lower()
282
+ if args.lr_scheduler == "steplr":
283
+ main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
284
+ elif args.lr_scheduler == "cosineannealinglr":
285
+ main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
286
+ optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
287
+ )
288
+ elif args.lr_scheduler == "exponentiallr":
289
+ main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
290
+ else:
291
+ raise RuntimeError(
292
+ f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
293
+ "are supported."
294
+ )
295
+
296
+ if args.lr_warmup_epochs > 0:
297
+ if args.lr_warmup_method == "linear":
298
+ warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
299
+ optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
300
+ )
301
+ elif args.lr_warmup_method == "constant":
302
+ warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
303
+ optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
304
+ )
305
+ else:
306
+ raise RuntimeError(
307
+ f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
308
+ )
309
+ lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
310
+ optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
311
+ )
312
+ else:
313
+ lr_scheduler = main_lr_scheduler
314
+
315
+ model_without_ddp = model
316
+ if args.distributed:
317
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
318
+ model_without_ddp = model.module
319
+
320
+ model_ema = None
321
+ if args.model_ema:
322
+ # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
323
+ # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
324
+ #
325
+ # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
326
+ # We consider Dataset_size constant for a given dataset/setup and omit it. Thus:
327
+ # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
328
+ adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
329
+ alpha = 1.0 - args.model_ema_decay
330
+ alpha = min(1.0, alpha * adjust)
331
+ model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
332
+
333
+ if args.resume:
334
+ checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
335
+ model_without_ddp.load_state_dict(checkpoint["model"])
336
+ if not args.test_only:
337
+ optimizer.load_state_dict(checkpoint["optimizer"])
338
+ lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
339
+ args.start_epoch = checkpoint["epoch"] + 1
340
+ if model_ema:
341
+ model_ema.load_state_dict(checkpoint["model_ema"])
342
+ if scaler:
343
+ scaler.load_state_dict(checkpoint["scaler"])
344
+
345
+ if args.test_only:
346
+ # We disable the cudnn benchmarking because it can noticeably affect the accuracy
347
+ torch.backends.cudnn.benchmark = False
348
+ torch.backends.cudnn.deterministic = True
349
+ if model_ema:
350
+ evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
351
+ else:
352
+ evaluate(model, criterion, data_loader_test, device=device)
353
+ return
354
+
355
+ print("Start training")
356
+ start_time = time.time()
357
+ for epoch in range(args.start_epoch, args.epochs):
358
+ if args.distributed:
359
+ train_sampler.set_epoch(epoch)
360
+ train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
361
+ lr_scheduler.step()
362
+ evaluate(model, criterion, data_loader_test, device=device)
363
+ if model_ema:
364
+ evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
365
+ if args.output_dir:
366
+ checkpoint = {
367
+ "model": model_without_ddp.state_dict() if not args.apply_trp else {k: v for k, v in model_without_ddp.state_dict().items() if "trp_blocks" not in k},
368
+ "optimizer": optimizer.state_dict(),
369
+ "lr_scheduler": lr_scheduler.state_dict(),
370
+ "epoch": epoch,
371
+ "args": args,
372
+ }
373
+ if model_ema:
374
+ checkpoint["model_ema"] = model_ema.state_dict() if not args.apply_trp else {k: v for k, v in model_ema.state_dict().items() if "trp_blocks" not in k}
375
+ if scaler:
376
+ checkpoint["scaler"] = scaler.state_dict()
377
+ utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
378
+ utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
379
+
380
+ total_time = time.time() - start_time
381
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
382
+ print(f"Training time {total_time_str}")
383
+
384
+
385
+ def get_args_parser(add_help=True):
386
+ import argparse
387
+
388
+ parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
389
+
390
+ parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
391
+ parser.add_argument("--model", default="resnet18", type=str, help="model name")
392
+ parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
393
+ parser.add_argument(
394
+ "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
395
+ )
396
+ parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
397
+ parser.add_argument(
398
+ "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
399
+ )
400
+ parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
401
+ parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
402
+ parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
403
+ parser.add_argument(
404
+ "--wd",
405
+ "--weight-decay",
406
+ default=1e-4,
407
+ type=float,
408
+ metavar="W",
409
+ help="weight decay (default: 1e-4)",
410
+ dest="weight_decay",
411
+ )
412
+ parser.add_argument(
413
+ "--norm-weight-decay",
414
+ default=None,
415
+ type=float,
416
+ help="weight decay for Normalization layers (default: None, same value as --wd)",
417
+ )
418
+ parser.add_argument(
419
+ "--bias-weight-decay",
420
+ default=None,
421
+ type=float,
422
+ help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
423
+ )
424
+ parser.add_argument(
425
+ "--transformer-embedding-decay",
426
+ default=None,
427
+ type=float,
428
+ help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
429
+ )
430
+ parser.add_argument(
431
+ "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
432
+ )
433
+ parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
434
+ parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
435
+ parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
436
+ parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
437
+ parser.add_argument(
438
+ "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
439
+ )
440
+ parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
441
+ parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
442
+ parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
443
+ parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
444
+ parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
445
+ parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
446
+ parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
447
+ parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
448
+ parser.add_argument(
449
+ "--cache-dataset",
450
+ dest="cache_dataset",
451
+ help="Cache the datasets for quicker initialization. It also serializes the transforms",
452
+ action="store_true",
453
+ )
454
+ parser.add_argument(
455
+ "--sync-bn",
456
+ dest="sync_bn",
457
+ help="Use sync batch norm",
458
+ action="store_true",
459
+ )
460
+ parser.add_argument(
461
+ "--test-only",
462
+ dest="test_only",
463
+ help="Only test the model",
464
+ action="store_true",
465
+ )
466
+ parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
467
+ parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
468
+ parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
469
+ parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
470
+
471
+ # Mixed precision training parameters
472
+ parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
473
+
474
+ # distributed training parameters
475
+ parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
476
+ parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
477
+ parser.add_argument(
478
+ "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
479
+ )
480
+ parser.add_argument(
481
+ "--model-ema-steps",
482
+ type=int,
483
+ default=32,
484
+ help="the number of iterations that controls how often to update the EMA model (default: 32)",
485
+ )
486
+ parser.add_argument(
487
+ "--model-ema-decay",
488
+ type=float,
489
+ default=0.99998,
490
+ help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
491
+ )
492
+ parser.add_argument(
493
+ "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
494
+ )
495
+ parser.add_argument(
496
+ "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
497
+ )
498
+ parser.add_argument(
499
+ "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
500
+ )
501
+ parser.add_argument(
502
+ "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
503
+ )
504
+ parser.add_argument(
505
+ "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
506
+ )
507
+ parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
508
+ parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
509
+ parser.add_argument(
510
+ "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
511
+ )
512
+ parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
513
+
514
+ parser.add_argument("--apply-trp", action="store_true", help="enable applying trp")
515
+ parser.add_argument("--trp-depths", nargs="+", default=[2, 2, 2], type=int, help="number of depth for each trp block")
516
+ parser.add_argument("--in-planes", type=int, help="the dimension of the hidden states")
517
+ parser.add_argument("--out-planes", default=8, type=int, help="the dimension of the inner hidden states")
518
+ parser.add_argument("--trp-rewards", nargs="+", default=[1.0, 0.4, 0.2, 0.1], type=float, help="trp rewards")
519
+
520
+ return parser
521
+
522
+
523
+ if __name__ == "__main__":
524
+ args = get_args_parser().parse_args()
525
+ main(args)
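
For reference, a hedged sketch of driving this script programmatically for a short single-process run with the TRP flags enabled (the data path is a placeholder; --in-planes 512 matches resnet18's final feature width, and a CUDA device is assumed since trplib builds its blocks on "cuda"):

    from train import get_args_parser, main

    args = get_args_parser().parse_args([
        "--data-path", "/path/to/imagenet",   # placeholder ImageNet-style folder
        "--model", "resnet18",
        "--epochs", "1",
        "--batch-size", "32",
        "--apply-trp",
        "--trp-depths", "2", "2", "2",
        "--in-planes", "512",
        "--out-planes", "8",
        "--trp-rewards", "1.0", "0.4", "0.2", "0.1",
        "--output-dir", "./checkpoints",
    ])
    main(args)

In a multi-GPU setting the same flags would normally be passed to the script under torchrun instead.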
nas-examples/image-classification/train_quantization.py ADDED
@@ -0,0 +1,265 @@
1
+ import copy
2
+ import datetime
3
+ import os
4
+ import time
5
+
6
+ import torch
7
+ import torch.ao.quantization
8
+ import torch.utils.data
9
+ import torchvision
10
+ import utils
11
+ from torch import nn
12
+ from train import evaluate, load_data, train_one_epoch
13
+
14
+
15
+ def main(args):
16
+ if args.output_dir:
17
+ utils.mkdir(args.output_dir)
18
+
19
+ utils.init_distributed_mode(args)
20
+ print(args)
21
+
22
+ if args.post_training_quantize and args.distributed:
23
+ raise RuntimeError("Post training quantization example should not be performed on distributed mode")
24
+
25
+ # Set backend engine to ensure that quantized model runs on the correct kernels
26
+ if args.backend not in torch.backends.quantized.supported_engines:
27
+ raise RuntimeError("Quantized backend not supported: " + str(args.backend))
28
+ torch.backends.quantized.engine = args.backend
29
+
30
+ device = torch.device(args.device)
31
+ torch.backends.cudnn.benchmark = True
32
+
33
+ # Data loading code
34
+ print("Loading data")
35
+ train_dir = os.path.join(args.data_path, "train")
36
+ val_dir = os.path.join(args.data_path, "val")
37
+
38
+ dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
39
+ data_loader = torch.utils.data.DataLoader(
40
+ dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
41
+ )
42
+
43
+ data_loader_test = torch.utils.data.DataLoader(
44
+ dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
45
+ )
46
+
47
+ print("Creating model", args.model)
48
+ # when training quantized models, we always start from a pre-trained fp32 reference model
49
+ prefix = "quantized_"
50
+ model_name = args.model
51
+ if not model_name.startswith(prefix):
52
+ model_name = prefix + model_name
53
+ model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
54
+ model.to(device)
55
+
56
+ if not (args.test_only or args.post_training_quantize):
57
+ model.fuse_model(is_qat=True)
58
+ model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
59
+ torch.ao.quantization.prepare_qat(model, inplace=True)
60
+
61
+ if args.distributed and args.sync_bn:
62
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
63
+
64
+ optimizer = torch.optim.SGD(
65
+ model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
66
+ )
67
+
68
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
69
+
70
+ criterion = nn.CrossEntropyLoss()
71
+ model_without_ddp = model
72
+ if args.distributed:
73
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
74
+ model_without_ddp = model.module
75
+
76
+ if args.resume:
77
+ checkpoint = torch.load(args.resume, map_location="cpu")
78
+ model_without_ddp.load_state_dict(checkpoint["model"])
79
+ optimizer.load_state_dict(checkpoint["optimizer"])
80
+ lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
81
+ args.start_epoch = checkpoint["epoch"] + 1
82
+
83
+ if args.post_training_quantize:
84
+ # perform calibration on a subset of the training dataset
85
+ # for that, create a subset of the training dataset
86
+ ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
87
+ data_loader_calibration = torch.utils.data.DataLoader(
88
+ ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
89
+ )
90
+ model.eval()
91
+ model.fuse_model(is_qat=False)
92
+ model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
93
+ torch.ao.quantization.prepare(model, inplace=True)
94
+ # Calibrate first
95
+ print("Calibrating")
96
+ evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
97
+ torch.ao.quantization.convert(model, inplace=True)
98
+ if args.output_dir:
99
+ print("Saving quantized model")
100
+ if utils.is_main_process():
101
+ torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
102
+ print("Evaluating post-training quantized model")
103
+ evaluate(model, criterion, data_loader_test, device=device)
104
+ return
105
+
106
+ if args.test_only:
107
+ evaluate(model, criterion, data_loader_test, device=device)
108
+ return
109
+
110
+ model.apply(torch.ao.quantization.enable_observer)
111
+ model.apply(torch.ao.quantization.enable_fake_quant)
112
+ start_time = time.time()
113
+ for epoch in range(args.start_epoch, args.epochs):
114
+ if args.distributed:
115
+ train_sampler.set_epoch(epoch)
116
+ print("Starting training for epoch", epoch)
117
+ train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
118
+ lr_scheduler.step()
119
+ with torch.inference_mode():
120
+ if epoch >= args.num_observer_update_epochs:
121
+ print("Disabling observer for subseq epochs, epoch = ", epoch)
122
+ model.apply(torch.ao.quantization.disable_observer)
123
+ if epoch >= args.num_batch_norm_update_epochs:
124
+ print("Freezing BN for subseq epochs, epoch = ", epoch)
125
+ model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
126
+ print("Evaluate QAT model")
127
+
128
+ evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
129
+ quantized_eval_model = copy.deepcopy(model_without_ddp)
130
+ quantized_eval_model.eval()
131
+ quantized_eval_model.to(torch.device("cpu"))
132
+ torch.ao.quantization.convert(quantized_eval_model, inplace=True)
133
+
134
+ print("Evaluate Quantized model")
135
+ evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
136
+
137
+ model.train()
138
+
139
+ if args.output_dir:
140
+ checkpoint = {
141
+ "model": model_without_ddp.state_dict(),
142
+ "eval_model": quantized_eval_model.state_dict(),
143
+ "optimizer": optimizer.state_dict(),
144
+ "lr_scheduler": lr_scheduler.state_dict(),
145
+ "epoch": epoch,
146
+ "args": args,
147
+ }
148
+ utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
149
+ utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
150
+ print("Saving models after epoch ", epoch)
151
+
152
+ total_time = time.time() - start_time
153
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
154
+ print(f"Training time {total_time_str}")
155
+
156
+
157
+ def get_args_parser(add_help=True):
158
+ import argparse
159
+
160
+ parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)
161
+
162
+ parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
163
+ parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
164
+ parser.add_argument("--backend", default="qnnpack", type=str, help="fbgemm or qnnpack")
165
+ parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
166
+
167
+ parser.add_argument(
168
+ "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
169
+ )
170
+ parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
171
+ parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
172
+ parser.add_argument(
173
+ "--num-observer-update-epochs",
174
+ default=4,
175
+ type=int,
176
+ metavar="N",
177
+ help="number of total epochs to update observers",
178
+ )
179
+ parser.add_argument(
180
+ "--num-batch-norm-update-epochs",
181
+ default=3,
182
+ type=int,
183
+ metavar="N",
184
+ help="number of total epochs to update batch norm stats",
185
+ )
186
+ parser.add_argument(
187
+ "--num-calibration-batches",
188
+ default=32,
189
+ type=int,
190
+ metavar="N",
191
+ help="number of batches of training set for \
192
+ observer calibration ",
193
+ )
194
+
195
+ parser.add_argument(
196
+ "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
197
+ )
198
+ parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
199
+ parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
200
+ parser.add_argument(
201
+ "--wd",
202
+ "--weight-decay",
203
+ default=1e-4,
204
+ type=float,
205
+ metavar="W",
206
+ help="weight decay (default: 1e-4)",
207
+ dest="weight_decay",
208
+ )
209
+ parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
210
+ parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
211
+ parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
212
+ parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
213
+ parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
214
+ parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
215
+ parser.add_argument(
216
+ "--cache-dataset",
217
+ dest="cache_dataset",
218
+ help="Cache the datasets for quicker initialization. \
219
+ It also serializes the transforms",
220
+ action="store_true",
221
+ )
222
+ parser.add_argument(
223
+ "--sync-bn",
224
+ dest="sync_bn",
225
+ help="Use sync batch norm",
226
+ action="store_true",
227
+ )
228
+ parser.add_argument(
229
+ "--test-only",
230
+ dest="test_only",
231
+ help="Only test the model",
232
+ action="store_true",
233
+ )
234
+ parser.add_argument(
235
+ "--post-training-quantize",
236
+ dest="post_training_quantize",
237
+ help="Post training quantize the model",
238
+ action="store_true",
239
+ )
240
+
241
+ # distributed training parameters
242
+ parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
243
+ parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
244
+
245
+ parser.add_argument(
246
+ "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
247
+ )
248
+ parser.add_argument(
249
+ "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
250
+ )
251
+ parser.add_argument(
252
+ "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
253
+ )
254
+ parser.add_argument(
255
+ "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
256
+ )
257
+ parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
258
+ parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
259
+
260
+ return parser
261
+
262
+
263
+ if __name__ == "__main__":
264
+ args = get_args_parser().parse_args()
265
+ main(args)
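
Similarly, a hedged sketch of a short quantization-aware training run driven programmatically (the data path is a placeholder; in practice one would usually start from a pretrained float checkpoint via --weights or --resume rather than random initialization):

    from train_quantization import get_args_parser, main

    args = get_args_parser().parse_args([
        "--data-path", "/path/to/imagenet",  # placeholder ImageNet-style folder
        "--model", "mobilenet_v2",           # the script prefixes this to "quantized_mobilenet_v2"
        "--backend", "fbgemm",
        "--epochs", "1",
        "--lr", "0.0001",
        "--output-dir", "./checkpoints_qat",
    ])
    main(args)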
nas-examples/image-classification/transforms.py ADDED
@@ -0,0 +1,183 @@
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torchvision.transforms import functional as F
7
+
8
+
9
+ class RandomMixup(torch.nn.Module):
10
+ """Randomly apply Mixup to the provided batch and targets.
11
+ The class implements the data augmentations as described in the paper
12
+ `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
13
+
14
+ Args:
15
+ num_classes (int): number of classes used for one-hot encoding.
16
+ p (float): probability of the batch being transformed. Default value is 0.5.
17
+ alpha (float): hyperparameter of the Beta distribution used for mixup.
18
+ Default value is 1.0.
19
+ inplace (bool): boolean to make this transform inplace. Default set to False.
20
+ """
21
+
22
+ def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
23
+ super().__init__()
24
+
25
+ if num_classes < 1:
26
+ raise ValueError(
27
+ f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
28
+ )
29
+
30
+ if alpha <= 0:
31
+ raise ValueError("Alpha param can't be zero.")
32
+
33
+ self.num_classes = num_classes
34
+ self.p = p
35
+ self.alpha = alpha
36
+ self.inplace = inplace
37
+
38
+ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
39
+ """
40
+ Args:
41
+ batch (Tensor): Float tensor of size (B, C, H, W)
42
+ target (Tensor): Integer tensor of size (B, )
43
+
44
+ Returns:
45
+ Tuple[Tensor, Tensor]: Randomly transformed batch and target.
46
+ """
47
+ if batch.ndim != 4:
48
+ raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
49
+ if target.ndim != 1:
50
+ raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
51
+ if not batch.is_floating_point():
52
+ raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
53
+ if target.dtype != torch.int64:
54
+ raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
55
+
56
+ if not self.inplace:
57
+ batch = batch.clone()
58
+ target = target.clone()
59
+
60
+ if target.ndim == 1:
61
+ target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
62
+
63
+ if torch.rand(1).item() >= self.p:
64
+ return batch, target
65
+
66
+ # It's faster to roll the batch by one instead of shuffling it to create image pairs
67
+ batch_rolled = batch.roll(1, 0)
68
+ target_rolled = target.roll(1, 0)
69
+
70
+ # Implemented as in the mixup paper, page 3.
71
+ lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
72
+ batch_rolled.mul_(1.0 - lambda_param)
73
+ batch.mul_(lambda_param).add_(batch_rolled)
74
+
75
+ target_rolled.mul_(1.0 - lambda_param)
76
+ target.mul_(lambda_param).add_(target_rolled)
77
+
78
+ return batch, target
79
+
80
+ def __repr__(self) -> str:
81
+ s = (
82
+ f"{self.__class__.__name__}("
83
+ f"num_classes={self.num_classes}"
84
+ f", p={self.p}"
85
+ f", alpha={self.alpha}"
86
+ f", inplace={self.inplace}"
87
+ f")"
88
+ )
89
+ return s
90
+
91
+
92
+ class RandomCutmix(torch.nn.Module):
93
+ """Randomly apply Cutmix to the provided batch and targets.
94
+ The class implements the data augmentations as described in the paper
95
+ `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
96
+ <https://arxiv.org/abs/1905.04899>`_.
97
+
98
+ Args:
99
+ num_classes (int): number of classes used for one-hot encoding.
100
+ p (float): probability of the batch being transformed. Default value is 0.5.
101
+ alpha (float): hyperparameter of the Beta distribution used for cutmix.
102
+ Default value is 1.0.
103
+ inplace (bool): boolean to make this transform inplace. Default set to False.
104
+ """
105
+
106
+ def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
107
+ super().__init__()
108
+ if num_classes < 1:
109
+ raise ValueError("Please provide a valid positive value for the num_classes.")
110
+ if alpha <= 0:
111
+ raise ValueError("Alpha param can't be zero.")
112
+
113
+ self.num_classes = num_classes
114
+ self.p = p
115
+ self.alpha = alpha
116
+ self.inplace = inplace
117
+
118
+ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
119
+ """
120
+ Args:
121
+ batch (Tensor): Float tensor of size (B, C, H, W)
122
+ target (Tensor): Integer tensor of size (B, )
123
+
124
+ Returns:
125
+ Tuple[Tensor, Tensor]: Randomly transformed batch and target.
126
+ """
127
+ if batch.ndim != 4:
128
+ raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
129
+ if target.ndim != 1:
130
+ raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
131
+ if not batch.is_floating_point():
132
+ raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
133
+ if target.dtype != torch.int64:
134
+ raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
135
+
136
+ if not self.inplace:
137
+ batch = batch.clone()
138
+ target = target.clone()
139
+
140
+ if target.ndim == 1:
141
+ target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
142
+
143
+ if torch.rand(1).item() >= self.p:
144
+ return batch, target
145
+
146
+ # It's faster to roll the batch by one instead of shuffling it to create image pairs
147
+ batch_rolled = batch.roll(1, 0)
148
+ target_rolled = target.roll(1, 0)
149
+
150
+ # Implemented as in the cutmix paper, page 12 (with minor corrections on typos).
151
+ lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
152
+ _, H, W = F.get_dimensions(batch)
153
+
154
+ r_x = torch.randint(W, (1,))
155
+ r_y = torch.randint(H, (1,))
156
+
157
+ r = 0.5 * math.sqrt(1.0 - lambda_param)
158
+ r_w_half = int(r * W)
159
+ r_h_half = int(r * H)
160
+
161
+ x1 = int(torch.clamp(r_x - r_w_half, min=0))
162
+ y1 = int(torch.clamp(r_y - r_h_half, min=0))
163
+ x2 = int(torch.clamp(r_x + r_w_half, max=W))
164
+ y2 = int(torch.clamp(r_y + r_h_half, max=H))
165
+
166
+ batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
167
+ lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
168
+
169
+ target_rolled.mul_(1.0 - lambda_param)
170
+ target.mul_(lambda_param).add_(target_rolled)
171
+
172
+ return batch, target
173
+
174
+ def __repr__(self) -> str:
175
+ s = (
176
+ f"{self.__class__.__name__}("
177
+ f"num_classes={self.num_classes}"
178
+ f", p={self.p}"
179
+ f", alpha={self.alpha}"
180
+ f", inplace={self.inplace}"
181
+ f")"
182
+ )
183
+ return s
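
As a usage note, these batch-level transforms are meant to be applied inside a collate_fn, mirroring how train.py wires them up; the sketch below (with an assumed num_classes of 10 and FakeData standing in for a real dataset) shows the pattern:

    import torch
    import torchvision
    from torch.utils.data.dataloader import default_collate
    from transforms import RandomMixup, RandomCutmix

    num_classes = 10  # assumption for the sketch
    mixupcutmix = torchvision.transforms.RandomChoice(
        [RandomMixup(num_classes, p=1.0, alpha=0.2), RandomCutmix(num_classes, p=1.0, alpha=1.0)]
    )

    def collate_fn(batch):
        # default_collate stacks images/targets; the chosen transform then mixes the batch
        return mixupcutmix(*default_collate(batch))

    dataset = torchvision.datasets.FakeData(size=64, num_classes=num_classes,
                                            transform=torchvision.transforms.ToTensor())
    loader = torch.utils.data.DataLoader(dataset, batch_size=16, collate_fn=collate_fn)
    images, targets = next(iter(loader))  # targets become soft labels of shape [16, 10]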
nas-examples/image-classification/trplib.py ADDED
@@ -0,0 +1,383 @@
1
+ import types
2
+ from typing import Optional, List, Union, Callable
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+ from torch.nn import functional as F
7
+
8
+ from torchvision.models.mobilenetv2 import MobileNetV2
9
+ from torchvision.models.resnet import ResNet
10
+ from torchvision.models.efficientnet import EfficientNet
11
+ from torchvision.models.vision_transformer import VisionTransformer
12
+
13
+
14
+ def compute_policy_loss(loss_sequence, mask_sequence, rewards):
15
+ losses = sum(mask * padded_loss for mask, padded_loss in zip(mask_sequence, loss_sequence))
16
+ returns = sum(padded_reward * mask for padded_reward, mask in zip(rewards, mask_sequence))
17
+ loss = torch.mean(losses * returns)
18
+
19
+ return loss
20
+
21
+
22
+ class TPBlock(nn.Module):
23
+ def __init__(self, depths: int, in_planes: int, out_planes: int = None, rank=1, shape_dims=3, channel_first=True, dtype=torch.float32) -> None:
24
+ super().__init__()
25
+ out_planes = in_planes if out_planes is None else out_planes
26
+ self.layers = torch.nn.ModuleList([self._make_layer(in_planes, out_planes, rank, shape_dims, channel_first, dtype) for _ in range(depths)])
27
+
28
+ def forward(self, x: Tensor) -> Tensor:
29
+ for layer in self.layers:
30
+ x = x + layer(x)
31
+ return x
32
+
33
+ def _make_layer(self, in_planes: int, out_planes: int = None, rank=1, shape_dims=3, channel_first=True, dtype=torch.float32) -> nn.Sequential:
34
+
35
+ class Permute(nn.Module):
36
+ def __init__(self, *dims):
37
+ super().__init__()
38
+ self.dims = dims
39
+ def forward(self, x):
40
+ return x.permute(*self.dims)
41
+
42
+ class RMSNorm(nn.Module):
43
+ __constants__ = ["eps"]
44
+ eps: float
45
+
46
+ def __init__(self, hidden_size, eps: float = 1e-6, device=None, dtype=None):
47
+ """
48
+ LlamaRMSNorm is equivalent to T5LayerNorm.
49
+ """
50
+ factory_kwargs = {"device": device, "dtype": dtype}
51
+ super().__init__()
52
+ self.eps = eps
53
+ self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs))
54
+
55
+ def forward(self, hidden_states):
56
+ input_dtype = hidden_states.dtype
57
+ hidden_states = hidden_states.to(torch.float32)
58
+ variance = hidden_states.pow(2).mean(dim=1, keepdim=True)
59
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
60
+ weight = self.weight.view(1, -1, *[1] * (hidden_states.ndim - 2))
61
+ return weight * hidden_states.to(input_dtype)
62
+
63
+ def extra_repr(self):
64
+ return f"{self.weight.shape[0]}, eps={self.eps}"
65
+
66
+ conv_map = {
67
+ 2: (nn.Conv1d, (0, 2, 1), (0, 2, 1)),
68
+ 3: (nn.Conv2d, (0, 3, 1, 2), (0, 2, 3, 1)),
69
+ 4: (nn.Conv3d, (0, 4, 1, 2, 3), (0, 2, 3, 4, 1)),
70
+ }
71
+ Conv, pre_dims, post_dims = conv_map[shape_dims]
72
+ kernel_size, dilation, padding = self.generate_hyperparameters(rank)
73
+
74
+ pre_permute = nn.Identity() if channel_first else Permute(*pre_dims)
75
+ post_permute = nn.Identity() if channel_first else Permute(*post_dims)
76
+ conv1 = Conv(in_planes, out_planes, kernel_size, padding=padding, dilation=dilation, bias=False, dtype=dtype, device='cuda')
77
+ nn.init.zeros_(conv1.weight)
78
+ bn1 = RMSNorm(out_planes, dtype=dtype, device="cuda")
79
+ relu = nn.ReLU(inplace=True)
80
+ conv2 = Conv(out_planes, in_planes, kernel_size, padding=padding, dilation=dilation, bias=False, dtype=dtype, device='cuda')
81
+ nn.init.zeros_(conv2.weight)
82
+ bn2 = RMSNorm(in_planes, dtype=dtype, device="cuda")
83
+
84
+ return torch.nn.Sequential(pre_permute, conv1, bn1, relu, conv2, bn2, relu, post_permute)
85
+
86
+ @staticmethod
87
+ def generate_hyperparameters(rank: int):
88
+ """
89
+ Generates kernel size and dilation rate pairs sorted by increasing padded kernel size.
90
+
91
+ Args:
92
+ rank: Number of (kernel_size, dilation) pairs to generate. Must be positive.
93
+
94
+ Returns:
95
+ Tuple[int, int, int]: A (kernel_size, dilation, padding) triple where:
96
+ - kernel_size: Always odd and >= 1
97
+ - dilation, padding: computed so the effective kernel size grows consistently while the spatial size is preserved
98
+
99
+ Note:
100
+ Padded kernel size is calculated as:
101
+ (kernel_size - 1) * dilation + 1
102
+ Pairs are generated first in order of increasing padded kernel size,
103
+ then by increasing kernel size for equal padded kernel sizes.
104
+ """
105
+ pairs = [(1, 1, 0)] # Start with smallest possible
106
+ padded_kernel_size = 3
107
+
108
+ while len(pairs) < rank:
109
+ for kernel_size in range(3, padded_kernel_size + 1, 2):
110
+ if (padded_kernel_size - 1) % (kernel_size - 1) == 0:
111
+ dilation = (padded_kernel_size - 1) // (kernel_size - 1)
112
+ padding = dilation * (kernel_size - 1) // 2
113
+ pairs.append((kernel_size, dilation, padding))
114
+ if len(pairs) >= rank:
115
+ break
116
+
117
+ # Move to next odd padded kernel size
118
+ padded_kernel_size += 2
119
+
120
+ return pairs[-1]
121
+
122
+
123
+ class ResNetConfig:
124
+ @staticmethod
125
+ def gen_shared_head(self):
126
+ def func(hidden_states):
127
+ """
128
+ Args:
129
+ hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].
130
+
131
+ Returns:
132
+ logits (Tensor): Logits tensor of shape [B, C].
133
+ """
134
+ x = self.avgpool(hidden_states)
135
+ x = torch.flatten(x, 1)
136
+ logits = self.fc(x)
137
+ return logits
138
+ return func
139
+
140
+ @staticmethod
141
+ def gen_logits(self, shared_head):
142
+ def func(hidden_states):
143
+ """
144
+ Args:
145
+ hidden_states (Tensor): Hidden states tensor (e.g. [B, C, H, W] for CNN backbones).
146
+
147
+ Returns:
148
+ logits_sequence (List[Tensor]): List of logits tensors.
149
+ """
150
+ logits_sequence = [shared_head(hidden_states)]
151
+ for layer in self.trp_blocks:
152
+ logits_sequence.append(shared_head(layer(hidden_states)))
153
+ return logits_sequence
154
+ return func
155
+
156
+ @staticmethod
157
+ def gen_mask(label_smoothing=0.0, top_k=1):
158
+ def func(logits_sequence, labels):
159
+ """
160
+ Args:
161
+ logits_sequence (List[Tensor]): List of Logits tensors.
162
+ labels (Tensor): Target labels of shape [B] or [B, C].
163
+
164
+ Returns:
165
+ mask_sequence (List[Tensor]): List of float mask tensors, each of shape [B]; entry i
166
+ is 1 for samples whose label stayed in the top-k predictions of the first i heads.
167
+ """
168
+ labels = torch.argmax(labels, dim=1) if label_smoothing > 0.0 else labels
169
+
170
+ mask_sequence = [torch.ones_like(labels, dtype=torch.float32, device=labels.device)]
171
+ for logits in logits_sequence:
172
+ with torch.no_grad():
173
+ topk_values, topk_indices = torch.topk(logits, top_k, dim=-1)
174
+ mask = torch.eq(topk_indices, labels[:, None]).any(dim=-1).to(torch.float32)
175
+ mask_sequence.append(mask_sequence[-1] * mask)
176
+ return mask_sequence
177
+ return func
178
+
179
+ @staticmethod
180
+ def gen_criterion(label_smoothing=0.0):
181
+ def func(logits_sequence, labels):
182
+ """
183
+ Args:
184
+ logits_sequence (List[Tensor]): List of Logits tensor.
185
+ labels (Tensor): Target labels of shape [B] or [B, C].
186
+
187
+ Returns:
188
+ loss_sequence (List[Tensor]): List of per-sample loss tensors, each of shape [B],
190
+ one entry per logits tensor in logits_sequence.
190
+ """
191
+ labels = torch.argmax(labels, dim=1) if label_smoothing > 0.0 else labels
192
+
193
+ loss_sequence = []
194
+ for logits in logits_sequence:
195
+ loss_sequence.append(F.cross_entropy(logits, labels, reduction="none", label_smoothing=label_smoothing))
196
+
197
+ return loss_sequence
198
+ return func
199
+
200
+ @staticmethod
201
+ def gen_forward(rewards, label_smoothing=0.0, top_k=1):
202
+ def func(self, x: Tensor, targets=None) -> Tensor:
203
+ x = self.conv1(x)
204
+ x = self.bn1(x)
205
+ x = self.relu(x)
206
+ x = self.maxpool(x)
207
+
208
+ x = self.layer1(x)
209
+ x = self.layer2(x)
210
+ x = self.layer3(x)
211
+ hidden_states = self.layer4(x)
212
+ x = self.avgpool(hidden_states)
213
+ x = torch.flatten(x, 1)
214
+ logits = self.fc(x)
215
+
216
+ if self.training:
217
+ shared_head = ResNetConfig.gen_shared_head(self)
218
+ compute_logits = ResNetConfig.gen_logits(self, shared_head)
219
+ compute_mask = ResNetConfig.gen_mask(label_smoothing, top_k)
220
+ compute_loss = ResNetConfig.gen_criterion(label_smoothing)
221
+
222
+ logits_sequence = compute_logits(hidden_states)
223
+ mask_sequence = compute_mask(logits_sequence, targets)
224
+ loss_sequence = compute_loss(logits_sequence, targets)
225
+ loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
226
+
227
+ return logits, loss
228
+
229
+ return logits
230
+
231
+ return func
232
+
233
+
234
+ class MobileNetV2Config(ResNetConfig):
235
+ @staticmethod
236
+ def gen_shared_head(self):
237
+ def func(hidden_states):
238
+ """
239
+ Args:
240
+ hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].
241
+
242
+ Returns:
243
+ logits (Tensor): Logits tensor of shape [B, C].
244
+ """
245
+ x = nn.functional.adaptive_avg_pool2d(hidden_states, (1, 1))
246
+ x = torch.flatten(x, 1)
247
+ logits = self.classifier(x)
248
+ return logits
249
+ return func
250
+
251
+ @staticmethod
252
+ def gen_forward(rewards, label_smoothing=0.0, top_k=1):
253
+ def func(self, x: Tensor, targets=None) -> Tensor:
254
+ hidden_states = self.features(x)
255
+ # Cannot use "squeeze" as batch-size can be 1
256
+ x = nn.functional.adaptive_avg_pool2d(hidden_states, (1, 1))
257
+ x = torch.flatten(x, 1)
258
+ logits = self.classifier(x)
259
+
260
+ if self.training:
261
+ shared_head = MobileNetV2Config.gen_shared_head(self)
262
+ compute_logits = MobileNetV2Config.gen_logits(self, shared_head)
263
+ compute_mask = MobileNetV2Config.gen_mask(label_smoothing, top_k)
264
+ compute_loss = MobileNetV2Config.gen_criterion(label_smoothing)
265
+
266
+ logits_sequence = compute_logits(hidden_states)
267
+ mask_sequence = compute_mask(logits_sequence, targets)
268
+ loss_sequence = compute_loss(logits_sequence, targets)
269
+ loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
270
+
271
+ return logits, loss
272
+
273
+ return logits
274
+
275
+ return func
276
+
277
+
278
+ class EfficientNetConfig(ResNetConfig):
279
+ @staticmethod
280
+ def gen_shared_head(self):
281
+ def func(hidden_states):
282
+ """
283
+ Args:
284
+ hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].
285
+
286
+ Returns:
287
+ logits (Tensor): Logits tensor of shape [B, C].
288
+ """
289
+ x = self.avgpool(hidden_states)
290
+ x = torch.flatten(x, 1)
291
+ logits = self.classifier(x)
292
+ return logits
293
+ return func
294
+
295
+ @staticmethod
296
+ def gen_forward(rewards, label_smoothing=0.0, top_k=1):
297
+ def func(self, x: Tensor, targets=None) -> Tensor:
298
+ hidden_states = self.features(x)
299
+ x = self.avgpool(hidden_states)
300
+ x = torch.flatten(x, 1)
301
+ logits = self.classifier(x)
302
+
303
+ if self.training:
304
+ shared_head = EfficientNetConfig.gen_shared_head(self)
305
+ compute_logits = EfficientNetConfig.gen_logits(self, shared_head)
306
+ compute_mask = EfficientNetConfig.gen_mask(label_smoothing, top_k)
307
+ compute_loss = EfficientNetConfig.gen_criterion(label_smoothing)
308
+
309
+ logits_sequence = compute_logits(hidden_states)
310
+ mask_sequence = compute_mask(logits_sequence, targets)
311
+ loss_sequence = compute_loss(logits_sequence, targets)
312
+ loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
313
+
314
+ return logits, loss
315
+
316
+ return logits
317
+
318
+ return func
319
+
320
+
321
+ class VisionTransformerConfig(ResNetConfig):
322
+ @staticmethod
323
+ def gen_shared_head(self):
324
+ def func(hidden_states):
325
+ """
326
+ Args:
327
+ hidden_states (Tensor): Hidden states tensor of shape [B, L, hidden_dim] (token sequence from the encoder).
328
+
329
+ Returns:
330
+ logits (Tensor): Logits tensor of shape [B, C].
331
+ """
332
+ x = hidden_states[:, 0]
333
+ logits = self.heads(x)
334
+ return logits
335
+ return func
336
+
337
+ @staticmethod
338
+ def gen_forward(rewards, label_smoothing=0.0, top_k=1):
339
+ def func(self, images: Tensor, targets=None):
340
+ x = self._process_input(images)
341
+ n = x.shape[0]
342
+ batch_class_token = self.class_token.expand(n, -1, -1)
343
+ x = torch.cat([batch_class_token, x], dim=1)
344
+ hidden_states = self.encoder(x)
345
+ x = hidden_states[:, 0]
346
+
347
+ logits = self.heads(x)
348
+
349
+
350
+ if self.training:
351
+ shared_head = VisionTransformerConfig.gen_shared_head(self)
352
+ compute_logits = VisionTransformerConfig.gen_logits(self, shared_head)
353
+ compute_mask = VisionTransformerConfig.gen_mask(label_smoothing, top_k)
354
+ compute_loss = VisionTransformerConfig.gen_criterion(label_smoothing)
355
+
356
+ logits_sequence = compute_logits(hidden_states)
357
+ mask_sequence = compute_mask(logits_sequence, targets)
358
+ loss_sequence = compute_loss(logits_sequence, targets)
359
+ loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
360
+
361
+ return logits, loss
362
+ return logits
363
+ return func
364
+
365
+
366
+ def apply_trp(model, depths: List[int], in_planes: int, out_planes: int, rewards, **kwargs):
367
+ if isinstance(model, ResNet):
368
+ print("✅ Applying TRP to ResNet for Image Classification...")
369
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
370
+ model.forward = types.MethodType(ResNetConfig.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
371
+ elif isinstance(model, MobileNetV2):
372
+ print("✅ Applying TRP to MobileNetV2 for Image Classification...")
373
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
374
+ model.forward = types.MethodType(MobileNetV2Config.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
375
+ elif isinstance(model, EfficientNet):
376
+ print("✅ Applying TRP to EfficientNet for Image Classification...")
377
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
378
+ model.forward = types.MethodType(EfficientNetConfig.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
379
+ elif isinstance(model, VisionTransformer):
380
+ print("✅ Applying TRP to VisionTransformer for Image Classification...")
381
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k, shape_dims=2, channel_first=False) for k, d in enumerate(depths)])
382
+ model.forward = types.MethodType(VisionTransformerConfig.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
383
+ return model
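A minimal usage sketch of apply_trp from this file (illustrative only; it assumes a CUDA device is available, since TPBlock builds its layers on "cuda", and that in_planes matches the backbone's final feature width, e.g. 2048 for ResNet-50):

import torch
import torchvision
from trplib import apply_trp

model = torchvision.models.resnet50(weights=None)
model = apply_trp(model, depths=[2, 2, 2], in_planes=2048, out_planes=8,
                  rewards=[1.0, 0.4, 0.2, 0.1], label_smoothing=0.0)
model = model.cuda().train()

images = torch.randn(4, 3, 224, 224, device="cuda")
targets = torch.randint(0, 1000, (4,), device="cuda")
logits, loss = model(images, targets)   # patched forward returns (logits, policy loss) in training mode
loss.backward()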
nas-examples/image-classification/utils.py ADDED
@@ -0,0 +1,465 @@
 
1
+ import copy
2
+ import datetime
3
+ import errno
4
+ import hashlib
5
+ import os
6
+ import time
7
+ from collections import defaultdict, deque, OrderedDict
8
+ from typing import List, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+
13
+
14
+ class SmoothedValue:
15
+ """Track a series of values and provide access to smoothed values over a
16
+ window or the global series average.
17
+ """
18
+
19
+ def __init__(self, window_size=20, fmt=None):
20
+ if fmt is None:
21
+ fmt = "{median:.4f} ({global_avg:.4f})"
22
+ self.deque = deque(maxlen=window_size)
23
+ self.total = 0.0
24
+ self.count = 0
25
+ self.fmt = fmt
26
+
27
+ def update(self, value, n=1):
28
+ self.deque.append(value)
29
+ self.count += n
30
+ self.total += value * n
31
+
32
+ def synchronize_between_processes(self):
33
+ """
34
+ Warning: does not synchronize the deque!
35
+ """
36
+ t = reduce_across_processes([self.count, self.total])
37
+ t = t.tolist()
38
+ self.count = int(t[0])
39
+ self.total = t[1]
40
+
41
+ @property
42
+ def median(self):
43
+ d = torch.tensor(list(self.deque))
44
+ return d.median().item()
45
+
46
+ @property
47
+ def avg(self):
48
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
49
+ return d.mean().item()
50
+
51
+ @property
52
+ def global_avg(self):
53
+ return self.total / self.count
54
+
55
+ @property
56
+ def max(self):
57
+ return max(self.deque)
58
+
59
+ @property
60
+ def value(self):
61
+ return self.deque[-1]
62
+
63
+ def __str__(self):
64
+ return self.fmt.format(
65
+ median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
66
+ )
67
+
68
+
69
+ class MetricLogger:
70
+ def __init__(self, delimiter="\t"):
71
+ self.meters = defaultdict(SmoothedValue)
72
+ self.delimiter = delimiter
73
+
74
+ def update(self, **kwargs):
75
+ for k, v in kwargs.items():
76
+ if isinstance(v, torch.Tensor):
77
+ v = v.item()
78
+ assert isinstance(v, (float, int))
79
+ self.meters[k].update(v)
80
+
81
+ def __getattr__(self, attr):
82
+ if attr in self.meters:
83
+ return self.meters[attr]
84
+ if attr in self.__dict__:
85
+ return self.__dict__[attr]
86
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
87
+
88
+ def __str__(self):
89
+ loss_str = []
90
+ for name, meter in self.meters.items():
91
+ loss_str.append(f"{name}: {str(meter)}")
92
+ return self.delimiter.join(loss_str)
93
+
94
+ def synchronize_between_processes(self):
95
+ for meter in self.meters.values():
96
+ meter.synchronize_between_processes()
97
+
98
+ def add_meter(self, name, meter):
99
+ self.meters[name] = meter
100
+
101
+ def log_every(self, iterable, print_freq, header=None):
102
+ i = 0
103
+ if not header:
104
+ header = ""
105
+ start_time = time.time()
106
+ end = time.time()
107
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
108
+ data_time = SmoothedValue(fmt="{avg:.4f}")
109
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
110
+ if torch.cuda.is_available():
111
+ log_msg = self.delimiter.join(
112
+ [
113
+ header,
114
+ "[{0" + space_fmt + "}/{1}]",
115
+ "eta: {eta}",
116
+ "{meters}",
117
+ "time: {time}",
118
+ "data: {data}",
119
+ "max mem: {memory:.0f}",
120
+ ]
121
+ )
122
+ else:
123
+ log_msg = self.delimiter.join(
124
+ [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
125
+ )
126
+ MB = 1024.0 * 1024.0
127
+ for obj in iterable:
128
+ data_time.update(time.time() - end)
129
+ yield obj
130
+ iter_time.update(time.time() - end)
131
+ if i % print_freq == 0:
132
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
133
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
134
+ if torch.cuda.is_available():
135
+ print(
136
+ log_msg.format(
137
+ i,
138
+ len(iterable),
139
+ eta=eta_string,
140
+ meters=str(self),
141
+ time=str(iter_time),
142
+ data=str(data_time),
143
+ memory=torch.cuda.max_memory_allocated() / MB,
144
+ )
145
+ )
146
+ else:
147
+ print(
148
+ log_msg.format(
149
+ i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
150
+ )
151
+ )
152
+ i += 1
153
+ end = time.time()
154
+ total_time = time.time() - start_time
155
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
156
+ print(f"{header} Total time: {total_time_str}")
157
+
158
+
159
+ class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
160
+ """Maintains moving averages of model parameters using an exponential decay.
161
+ ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
162
+ `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
163
+ is used to compute the EMA.
164
+ """
165
+
166
+ def __init__(self, model, decay, device="cpu"):
167
+ def ema_avg(avg_model_param, model_param, num_averaged):
168
+ return decay * avg_model_param + (1 - decay) * model_param
169
+
170
+ super().__init__(model, device, ema_avg, use_buffers=True)
171
+
172
+
173
+ def accuracy(output, target, topk=(1,)):
174
+ """Computes the accuracy over the k top predictions for the specified values of k"""
175
+ with torch.inference_mode():
176
+ maxk = max(topk)
177
+ batch_size = target.size(0)
178
+ if target.ndim == 2:
179
+ target = target.max(dim=1)[1]
180
+
181
+ _, pred = output.topk(maxk, 1, True, True)
182
+ pred = pred.t()
183
+ correct = pred.eq(target[None])
184
+
185
+ res = []
186
+ for k in topk:
187
+ correct_k = correct[:k].flatten().sum(dtype=torch.float32)
188
+ res.append(correct_k * (100.0 / batch_size))
189
+ return res
190
+
191
+
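A small sanity check of accuracy (sketch only; relies on this module's torch import):

output = torch.randn(8, 1000)            # random logits for a batch of 8
target = torch.randint(0, 1000, (8,))
acc1, acc5 = accuracy(output, target, topk=(1, 5))
print(acc1.item(), acc5.item())          # percentages in [0, 100]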
192
+ def mkdir(path):
193
+ try:
194
+ os.makedirs(path)
195
+ except OSError as e:
196
+ if e.errno != errno.EEXIST:
197
+ raise
198
+
199
+
200
+ def setup_for_distributed(is_master):
201
+ """
202
+ This function disables printing when not in master process
203
+ """
204
+ import builtins as __builtin__
205
+
206
+ builtin_print = __builtin__.print
207
+
208
+ def print(*args, **kwargs):
209
+ force = kwargs.pop("force", False)
210
+ if is_master or force:
211
+ builtin_print(*args, **kwargs)
212
+
213
+ __builtin__.print = print
214
+
215
+
216
+ def is_dist_avail_and_initialized():
217
+ if not dist.is_available():
218
+ return False
219
+ if not dist.is_initialized():
220
+ return False
221
+ return True
222
+
223
+
224
+ def get_world_size():
225
+ if not is_dist_avail_and_initialized():
226
+ return 1
227
+ return dist.get_world_size()
228
+
229
+
230
+ def get_rank():
231
+ if not is_dist_avail_and_initialized():
232
+ return 0
233
+ return dist.get_rank()
234
+
235
+
236
+ def is_main_process():
237
+ return get_rank() == 0
238
+
239
+
240
+ def save_on_master(*args, **kwargs):
241
+ if is_main_process():
242
+ torch.save(*args, **kwargs)
243
+
244
+
245
+ def init_distributed_mode(args):
246
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
247
+ args.rank = int(os.environ["RANK"])
248
+ args.world_size = int(os.environ["WORLD_SIZE"])
249
+ args.gpu = int(os.environ["LOCAL_RANK"])
250
+ elif "SLURM_PROCID" in os.environ:
251
+ args.rank = int(os.environ["SLURM_PROCID"])
252
+ args.gpu = args.rank % torch.cuda.device_count()
253
+ elif hasattr(args, "rank"):
254
+ pass
255
+ else:
256
+ print("Not using distributed mode")
257
+ args.distributed = False
258
+ return
259
+
260
+ args.distributed = True
261
+
262
+ torch.cuda.set_device(args.gpu)
263
+ args.dist_backend = "nccl"
264
+ print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
265
+ torch.distributed.init_process_group(
266
+ backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
267
+ )
268
+ torch.distributed.barrier()
269
+ setup_for_distributed(args.rank == 0)
270
+
271
+
272
+ def average_checkpoints(inputs):
273
+ """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
274
+ https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
275
+
276
+ Args:
277
+ inputs (List[str]): An iterable of string paths of checkpoints to load from.
278
+ Returns:
279
+ A dict of string keys mapping to various values. The 'model' key
280
+ from the returned dict should correspond to an OrderedDict mapping
281
+ string parameter names to torch Tensors.
282
+ """
283
+ params_dict = OrderedDict()
284
+ params_keys = None
285
+ new_state = None
286
+ num_models = len(inputs)
287
+ for fpath in inputs:
288
+ with open(fpath, "rb") as f:
289
+ state = torch.load(
290
+ f,
291
+ map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
292
+ )
293
+ # Copies over the settings from the first checkpoint
294
+ if new_state is None:
295
+ new_state = state
296
+ model_params = state["model"]
297
+ model_params_keys = list(model_params.keys())
298
+ if params_keys is None:
299
+ params_keys = model_params_keys
300
+ elif params_keys != model_params_keys:
301
+ raise KeyError(
302
+ f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}"
303
+ )
304
+ for k in params_keys:
305
+ p = model_params[k]
306
+ if isinstance(p, torch.HalfTensor):
307
+ p = p.float()
308
+ if k not in params_dict:
309
+ params_dict[k] = p.clone()
310
+ # NOTE: clone() is needed in case of p is a shared parameter
311
+ else:
312
+ params_dict[k] += p
313
+ averaged_params = OrderedDict()
314
+ for k, v in params_dict.items():
315
+ averaged_params[k] = v
316
+ if averaged_params[k].is_floating_point():
317
+ averaged_params[k].div_(num_models)
318
+ else:
319
+ averaged_params[k] //= num_models
320
+ new_state["model"] = averaged_params
321
+ return new_state
322
+
323
+
324
+ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
325
+ """
326
+ This method can be used to prepare weights files for new models. It receives as
327
+ input a model architecture and a checkpoint from the training script and produces
328
+ a file with the weights ready for release.
329
+
330
+ Examples:
331
+ from torchvision import models as M
332
+
333
+ # Classification
334
+ model = M.mobilenet_v3_large(weights=None)
335
+ print(store_model_weights(model, './class.pth'))
336
+
337
+ # Quantized Classification
338
+ model = M.quantization.mobilenet_v3_large(weights=None, quantize=False)
339
+ model.fuse_model(is_qat=True)
340
+ model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
341
+ _ = torch.ao.quantization.prepare_qat(model, inplace=True)
342
+ print(store_model_weights(model, './qat.pth'))
343
+
344
+ # Object Detection
345
+ model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
346
+ print(store_model_weights(model, './obj.pth'))
347
+
348
+ # Segmentation
349
+ model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True)
350
+ print(store_model_weights(model, './segm.pth', strict=False))
351
+
352
+ Args:
353
+ model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes.
354
+ checkpoint_path (str): The path of the checkpoint we will load.
355
+ checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
356
+ Default: "model".
357
+ strict (bool): whether to strictly enforce that the keys
358
+ in :attr:`state_dict` match the keys returned by this module's
359
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
360
+
361
+ Returns:
362
+ output_path (str): The location where the weights are saved.
363
+ """
364
+ # Store the new model next to the checkpoint_path
365
+ checkpoint_path = os.path.abspath(checkpoint_path)
366
+ output_dir = os.path.dirname(checkpoint_path)
367
+
368
+ # Deep copy to avoid side-effects on the model object.
369
+ model = copy.deepcopy(model)
370
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
371
+
372
+ # Load the weights to the model to validate that everything works
373
+ # and remove unnecessary weights (such as auxiliaries, etc)
374
+ if checkpoint_key == "model_ema":
375
+ del checkpoint[checkpoint_key]["n_averaged"]
376
+ torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
377
+ model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
378
+
379
+ tmp_path = os.path.join(output_dir, str(model.__hash__()))
380
+ torch.save(model.state_dict(), tmp_path)
381
+
382
+ sha256_hash = hashlib.sha256()
383
+ with open(tmp_path, "rb") as f:
384
+ # Read and update hash string value in blocks of 4K
385
+ for byte_block in iter(lambda: f.read(4096), b""):
386
+ sha256_hash.update(byte_block)
387
+ hh = sha256_hash.hexdigest()
388
+
389
+ output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
390
+ os.replace(tmp_path, output_path)
391
+
392
+ return output_path
393
+
394
+
395
+ def reduce_across_processes(val):
396
+ if not is_dist_avail_and_initialized():
397
+ # nothing to sync, but we still convert to tensor for consistency with the distributed case.
398
+ return torch.tensor(val)
399
+
400
+ t = torch.tensor(val, device="cuda")
401
+ dist.barrier()
402
+ dist.all_reduce(t)
403
+ return t
404
+
405
+
406
+ def set_weight_decay(
407
+ model: torch.nn.Module,
408
+ weight_decay: float,
409
+ norm_weight_decay: Optional[float] = None,
410
+ norm_classes: Optional[List[type]] = None,
411
+ custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
412
+ ):
413
+ if not norm_classes:
414
+ norm_classes = [
415
+ torch.nn.modules.batchnorm._BatchNorm,
416
+ torch.nn.LayerNorm,
417
+ torch.nn.GroupNorm,
418
+ torch.nn.modules.instancenorm._InstanceNorm,
419
+ torch.nn.LocalResponseNorm,
420
+ ]
421
+ norm_classes = tuple(norm_classes)
422
+
423
+ params = {
424
+ "other": [],
425
+ "norm": [],
426
+ }
427
+ params_weight_decay = {
428
+ "other": weight_decay,
429
+ "norm": norm_weight_decay,
430
+ }
431
+ custom_keys = []
432
+ if custom_keys_weight_decay is not None:
433
+ for key, weight_decay in custom_keys_weight_decay:
434
+ params[key] = []
435
+ params_weight_decay[key] = weight_decay
436
+ custom_keys.append(key)
437
+
438
+ def _add_params(module, prefix=""):
439
+ for name, p in module.named_parameters(recurse=False):
440
+ if not p.requires_grad:
441
+ continue
442
+ is_custom_key = False
443
+ for key in custom_keys:
444
+ target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
445
+ if key == target_name:
446
+ params[key].append(p)
447
+ is_custom_key = True
448
+ break
449
+ if not is_custom_key:
450
+ if norm_weight_decay is not None and isinstance(module, norm_classes):
451
+ params["norm"].append(p)
452
+ else:
453
+ params["other"].append(p)
454
+
455
+ for child_name, child_module in module.named_children():
456
+ child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
457
+ _add_params(child_module, prefix=child_prefix)
458
+
459
+ _add_params(model)
460
+
461
+ param_groups = []
462
+ for key in params:
463
+ if len(params[key]) > 0:
464
+ param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
465
+ return param_groups
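set_weight_decay is typically used to build optimizer parameter groups; a minimal sketch (the values here are illustrative, not taken from the training script):

import torch
import torchvision

model = torchvision.models.resnet18(weights=None)
param_groups = set_weight_decay(
    model,
    weight_decay=1e-4,
    norm_weight_decay=0.0,                     # no decay on normalization layers
    custom_keys_weight_decay=[("bias", 0.0)],  # no decay on biases either
)
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)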
nas-examples/semantic-segmentation/coco_utils.py ADDED
@@ -0,0 +1,108 @@
 
1
+ import copy
2
+ import os
3
+
4
+ import torch
5
+ import torch.utils.data
6
+ import torchvision
7
+ from PIL import Image
8
+ from pycocotools import mask as coco_mask
9
+ from transforms import Compose
10
+
11
+
12
+ class FilterAndRemapCocoCategories:
13
+ def __init__(self, categories, remap=True):
14
+ self.categories = categories
15
+ self.remap = remap
16
+
17
+ def __call__(self, image, anno):
18
+ anno = [obj for obj in anno if obj["category_id"] in self.categories]
19
+ if not self.remap:
20
+ return image, anno
21
+ anno = copy.deepcopy(anno)
22
+ for obj in anno:
23
+ obj["category_id"] = self.categories.index(obj["category_id"])
24
+ return image, anno
25
+
26
+
27
+ def convert_coco_poly_to_mask(segmentations, height, width):
28
+ masks = []
29
+ for polygons in segmentations:
30
+ rles = coco_mask.frPyObjects(polygons, height, width)
31
+ mask = coco_mask.decode(rles)
32
+ if len(mask.shape) < 3:
33
+ mask = mask[..., None]
34
+ mask = torch.as_tensor(mask, dtype=torch.uint8)
35
+ mask = mask.any(dim=2)
36
+ masks.append(mask)
37
+ if masks:
38
+ masks = torch.stack(masks, dim=0)
39
+ else:
40
+ masks = torch.zeros((0, height, width), dtype=torch.uint8)
41
+ return masks
42
+
43
+
44
+ class ConvertCocoPolysToMask:
45
+ def __call__(self, image, anno):
46
+ w, h = image.size
47
+ segmentations = [obj["segmentation"] for obj in anno]
48
+ cats = [obj["category_id"] for obj in anno]
49
+ if segmentations:
50
+ masks = convert_coco_poly_to_mask(segmentations, h, w)
51
+ cats = torch.as_tensor(cats, dtype=masks.dtype)
52
+ # merge all instance masks into a single segmentation map
53
+ # with its corresponding categories
54
+ target, _ = (masks * cats[:, None, None]).max(dim=0)
55
+ # discard overlapping instances
56
+ target[masks.sum(0) > 1] = 255
57
+ else:
58
+ target = torch.zeros((h, w), dtype=torch.uint8)
59
+ target = Image.fromarray(target.numpy())
60
+ return image, target
61
+
62
+
63
+ def _coco_remove_images_without_annotations(dataset, cat_list=None):
64
+ def _has_valid_annotation(anno):
65
+ # if it's empty, there is no annotation
66
+ if len(anno) == 0:
67
+ return False
68
+ # if more than 1k pixels occupied in the image
69
+ return sum(obj["area"] for obj in anno) > 1000
70
+
71
+ if not isinstance(dataset, torchvision.datasets.CocoDetection):
72
+ raise TypeError(
73
+ f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
74
+ )
75
+
76
+ ids = []
77
+ for ds_idx, img_id in enumerate(dataset.ids):
78
+ ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
79
+ anno = dataset.coco.loadAnns(ann_ids)
80
+ if cat_list:
81
+ anno = [obj for obj in anno if obj["category_id"] in cat_list]
82
+ if _has_valid_annotation(anno):
83
+ ids.append(ds_idx)
84
+
85
+ dataset = torch.utils.data.Subset(dataset, ids)
86
+ return dataset
87
+
88
+
89
+ def get_coco(root, image_set, transforms):
90
+ PATHS = {
91
+ "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
92
+ "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
93
+ # "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
94
+ }
95
+ CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]
96
+
97
+ transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms])
98
+
99
+ img_folder, ann_file = PATHS[image_set]
100
+ img_folder = os.path.join(root, img_folder)
101
+ ann_file = os.path.join(root, ann_file)
102
+
103
+ dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
104
+
105
+ if image_set == "train":
106
+ dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)
107
+
108
+ return dataset
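Usage sketch for get_coco (assumes a standard COCO 2017 layout under the root directory; the paired transform comes from the segmentation presets added below):

import presets
from coco_utils import get_coco

dataset = get_coco(
    "/path/to/coco2017",        # contains train2017/, val2017/, annotations/
    image_set="train",
    transforms=presets.SegmentationPresetTrain(base_size=520, crop_size=480),
)
img, target = dataset[0]        # float image tensor and int64 per-pixel class map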
nas-examples/semantic-segmentation/presets.py ADDED
@@ -0,0 +1,39 @@
 
1
+ import torch
2
+ import transforms as T
3
+
4
+
5
+ class SegmentationPresetTrain:
6
+ def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
7
+ min_size = int(0.5 * base_size)
8
+ max_size = int(2.0 * base_size)
9
+
10
+ trans = [T.RandomResize(min_size, max_size)]
11
+ if hflip_prob > 0:
12
+ trans.append(T.RandomHorizontalFlip(hflip_prob))
13
+ trans.extend(
14
+ [
15
+ T.RandomCrop(crop_size),
16
+ T.PILToTensor(),
17
+ T.ConvertImageDtype(torch.float),
18
+ T.Normalize(mean=mean, std=std),
19
+ ]
20
+ )
21
+ self.transforms = T.Compose(trans)
22
+
23
+ def __call__(self, img, target):
24
+ return self.transforms(img, target)
25
+
26
+
27
+ class SegmentationPresetEval:
28
+ def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
29
+ self.transforms = T.Compose(
30
+ [
31
+ T.RandomResize(base_size, base_size),
32
+ T.PILToTensor(),
33
+ T.ConvertImageDtype(torch.float),
34
+ T.Normalize(mean=mean, std=std),
35
+ ]
36
+ )
37
+
38
+ def __call__(self, img, target):
39
+ return self.transforms(img, target)
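Unlike the classification presets, every transform here operates on an (image, target) pair so the mask stays aligned with the image; a quick sketch:

from PIL import Image
import presets

train_tf = presets.SegmentationPresetTrain(base_size=520, crop_size=480)
img = Image.new("RGB", (640, 480))
mask = Image.new("L", (640, 480))
img_t, mask_t = train_tf(img, mask)   # both are resized, cropped, and flipped together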
nas-examples/semantic-segmentation/train.py ADDED
@@ -0,0 +1,327 @@
 
1
+ import datetime
2
+ import os
3
+ import time
4
+ import warnings
5
+
6
+ import presets
7
+ import torch
8
+ import torch.utils.data
9
+ import torchvision
10
+ import utils
11
+ from coco_utils import get_coco
12
+ from torch import nn
13
+ from torch.optim.lr_scheduler import PolynomialLR
14
+ from torchvision.transforms import functional as F, InterpolationMode
15
+
16
+ from trplib import apply_trp
17
+
18
+
19
+ def get_dataset(dir_path, name, image_set, transform):
20
+ def sbd(*args, **kwargs):
21
+ return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)
22
+
23
+ paths = {
24
+ "voc": (dir_path, torchvision.datasets.VOCSegmentation, 21),
25
+ "voc_aug": (dir_path, sbd, 21),
26
+ "coco": (dir_path, get_coco, 21),
27
+ }
28
+ p, ds_fn, num_classes = paths[name]
29
+
30
+ ds = ds_fn(p, image_set=image_set, transforms=transform)
31
+ return ds, num_classes
32
+
33
+
34
+ def get_transform(train, args):
35
+ if train:
36
+ return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
37
+ elif args.weights and args.test_only:
38
+ weights = torchvision.models.get_weight(args.weights)
39
+ trans = weights.transforms()
40
+
41
+ def preprocessing(img, target):
42
+ img = trans(img)
43
+ size = F.get_dimensions(img)[1:]
44
+ target = F.resize(target, size, interpolation=InterpolationMode.NEAREST)
45
+ return img, F.pil_to_tensor(target)
46
+
47
+ return preprocessing
48
+ else:
49
+ return presets.SegmentationPresetEval(base_size=520)
50
+
51
+
52
+ def criterion(inputs, target):
53
+ losses = {}
54
+ for name, x in inputs.items():
55
+ losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)
56
+
57
+ if len(losses) == 1:
58
+ return losses["out"]
59
+
60
+ return losses["out"] + 0.5 * losses["aux"]
61
+
62
+
63
+ def evaluate(model, data_loader, device, num_classes):
64
+ model.eval()
65
+ confmat = utils.ConfusionMatrix(num_classes)
66
+ metric_logger = utils.MetricLogger(delimiter=" ")
67
+ header = "Test:"
68
+ num_processed_samples = 0
69
+ with torch.inference_mode():
70
+ for image, target in metric_logger.log_every(data_loader, 100, header):
71
+ image, target = image.to(device), target.to(device)
72
+ output = model(image)
73
+ output = output["out"]
74
+
75
+ confmat.update(target.flatten(), output.argmax(1).flatten())
76
+ # FIXME need to take into account that the datasets
77
+ # could have been padded in distributed setup
78
+ num_processed_samples += image.shape[0]
79
+
80
+ confmat.reduce_from_all_processes()
81
+
82
+ num_processed_samples = utils.reduce_across_processes(num_processed_samples)
83
+ if (
84
+ hasattr(data_loader.dataset, "__len__")
85
+ and len(data_loader.dataset) != num_processed_samples
86
+ and torch.distributed.get_rank() == 0
87
+ ):
88
+ # See FIXME above
89
+ warnings.warn(
90
+ f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
91
+ "samples were used for the validation, which might bias the results. "
92
+ "Try adjusting the batch size and / or the world size. "
93
+ "Setting the world size to 1 is always a safe bet."
94
+ )
95
+
96
+ return confmat
97
+
98
+
99
+ def train_one_epoch(model, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
100
+ model.train()
101
+ metric_logger = utils.MetricLogger(delimiter=" ")
102
+ metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
103
+ header = f"Epoch: [{epoch}]"
104
+ for image, target in metric_logger.log_every(data_loader, print_freq, header):
105
+ image, target = image.to(device), target.to(device)
106
+ with torch.amp.autocast(device_type="cuda", enabled=scaler is not None):
107
+ _, loss = model(image, target)
108
+ # output = model(image)
109
+ # loss = criterion(output, target)
110
+
111
+ optimizer.zero_grad()
112
+ if scaler is not None:
113
+ scaler.scale(loss).backward()
114
+ scaler.step(optimizer)
115
+ scaler.update()
116
+ else:
117
+ loss.backward()
118
+ optimizer.step()
119
+
120
+ lr_scheduler.step()
121
+
122
+ metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
123
+
124
+
125
+ def main(args):
126
+ if args.output_dir:
127
+ utils.mkdir(args.output_dir)
128
+
129
+ utils.init_distributed_mode(args)
130
+ print(args)
131
+
132
+ device = torch.device(args.device)
133
+
134
+ if args.use_deterministic_algorithms:
135
+ torch.backends.cudnn.benchmark = False
136
+ torch.use_deterministic_algorithms(True)
137
+ else:
138
+ torch.backends.cudnn.benchmark = True
139
+
140
+ dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
141
+ dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))
142
+
143
+ if args.distributed:
144
+ train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
145
+ test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
146
+ else:
147
+ train_sampler = torch.utils.data.RandomSampler(dataset)
148
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
149
+
150
+ data_loader = torch.utils.data.DataLoader(
151
+ dataset,
152
+ batch_size=args.batch_size,
153
+ sampler=train_sampler,
154
+ num_workers=args.workers,
155
+ collate_fn=utils.collate_fn,
156
+ drop_last=True,
157
+ )
158
+
159
+ data_loader_test = torch.utils.data.DataLoader(
160
+ dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
161
+ )
162
+
163
+ model = torchvision.models.get_model(
164
+ args.model,
165
+ weights=args.weights,
166
+ weights_backbone=args.weights_backbone,
167
+ num_classes=num_classes,
168
+ aux_loss=args.aux_loss,
169
+ )
170
+ if args.apply_trp:
171
+ model = apply_trp(model, args.trp_depths, None, args.out_planes, args.trp_rewards)
172
+ model.to(device)
173
+ if args.distributed:
174
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
175
+
176
+ model_without_ddp = model
177
+ if args.distributed:
178
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
179
+ model_without_ddp = model.module
180
+
181
+ params_to_optimize = [
182
+ {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
183
+ {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
184
+ ]
185
+ if args.aux_loss:
186
+ params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
187
+ params_to_optimize.append({"params": params, "lr": args.lr * 10})
188
+ optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
189
+
190
+ scaler = torch.amp.GradScaler(device="cuda") if args.amp else None
191
+
192
+ iters_per_epoch = len(data_loader)
193
+ main_lr_scheduler = PolynomialLR(
194
+ optimizer, total_iters=iters_per_epoch * (args.epochs - args.lr_warmup_epochs), power=0.9
195
+ )
196
+
197
+ if args.lr_warmup_epochs > 0:
198
+ warmup_iters = iters_per_epoch * args.lr_warmup_epochs
199
+ args.lr_warmup_method = args.lr_warmup_method.lower()
200
+ if args.lr_warmup_method == "linear":
201
+ warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
202
+ optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
203
+ )
204
+ elif args.lr_warmup_method == "constant":
205
+ warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
206
+ optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
207
+ )
208
+ else:
209
+ raise RuntimeError(
210
+ f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
211
+ )
212
+ lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
213
+ optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
214
+ )
215
+ else:
216
+ lr_scheduler = main_lr_scheduler
217
+
218
+ if args.resume:
219
+ checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
220
+ model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only)
221
+ if not args.test_only:
222
+ optimizer.load_state_dict(checkpoint["optimizer"])
223
+ lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
224
+ args.start_epoch = checkpoint["epoch"] + 1
225
+ if args.amp:
226
+ scaler.load_state_dict(checkpoint["scaler"])
227
+
228
+ if args.test_only:
229
+ # We disable the cudnn benchmarking because it can noticeably affect the accuracy
230
+ torch.backends.cudnn.benchmark = False
231
+ torch.backends.cudnn.deterministic = True
232
+ confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
233
+ print(confmat)
234
+ return
235
+
236
+ start_time = time.time()
237
+ for epoch in range(args.start_epoch, args.epochs):
238
+ if args.distributed:
239
+ train_sampler.set_epoch(epoch)
240
+ train_one_epoch(model, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
241
+ confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
242
+ print(confmat)
243
+
244
+ if args.output_dir:
245
+ checkpoint = {
246
+ "model": model_without_ddp.state_dict() if not args.apply_trp else {k: v for k, v in model_without_ddp.state_dict().items() if "trp_blocks" not in k},
247
+ "optimizer": optimizer.state_dict(),
248
+ "lr_scheduler": lr_scheduler.state_dict(),
249
+ "epoch": epoch,
250
+ "args": args,
251
+ }
252
+ if args.amp:
253
+ checkpoint["scaler"] = scaler.state_dict()
254
+ utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
255
+ utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
256
+
257
+ total_time = time.time() - start_time
258
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
259
+ print(f"Training time {total_time_str}")
260
+
261
+
262
+ def get_args_parser(add_help=True):
263
+ import argparse
264
+
265
+ parser = argparse.ArgumentParser(description="PyTorch Segmentation Training", add_help=add_help)
266
+
267
+ parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
268
+ parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
269
+ parser.add_argument("--model", default="fcn_resnet101", type=str, help="model name")
270
+ parser.add_argument("--aux-loss", action="store_true", help="auxiliary loss")
271
+ parser.add_argument("--device", default="cuda", type=str, help="device (use cuda or cpu, default: cuda)")
272
+ parser.add_argument(
273
+ "-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
274
+ )
275
+ parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run")
276
+
277
+ parser.add_argument(
278
+ "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
279
+ )
280
+ parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
281
+ parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
282
+ parser.add_argument(
283
+ "--wd",
284
+ "--weight-decay",
285
+ default=1e-4,
286
+ type=float,
287
+ metavar="W",
288
+ help="weight decay (default: 1e-4)",
289
+ dest="weight_decay",
290
+ )
291
+ parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
292
+ parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)")
293
+ parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
294
+ parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
295
+ parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
296
+ parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
297
+ parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
298
+ parser.add_argument(
299
+ "--test-only",
300
+ dest="test_only",
301
+ help="Only test the model",
302
+ action="store_true",
303
+ )
304
+ parser.add_argument(
305
+ "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
306
+ )
307
+ # distributed training parameters
308
+ parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
309
+ parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
310
+
311
+ parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
312
+ parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load")
313
+
314
+ # Mixed precision training parameters
315
+ parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
316
+
317
+ parser.add_argument("--apply-trp", action="store_true", help="enable applying trp")
318
+ parser.add_argument("--trp-depths", nargs="+", default=[2, 2, 2], type=int, help="number of depth for each trp block")
319
+ parser.add_argument("--out-planes", default=8, type=int, help="the dimension of the inner hidden states")
320
+ parser.add_argument("--trp-rewards", nargs="+", default=[1.0, 0.4, 0.2, 0.1], type=float, help="trp rewards")
321
+
322
+ return parser
323
+
324
+
325
+ if __name__ == "__main__":
326
+ args = get_args_parser().parse_args()
327
+ main(args)
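A sketch of launching this script with TRP enabled (paths are placeholders; the multi-GPU torchrun line is just one plausible way to launch it):

# e.g. torchrun --nproc_per_node=4 train.py --dataset coco --model fcn_resnet101 --aux-loss \
#          --apply-trp --trp-depths 2 2 2 --out-planes 8 --trp-rewards 1.0 0.4 0.2 0.1
args = get_args_parser().parse_args([
    "--data-path", "/path/to/coco2017",
    "--model", "fcn_resnet101",
    "--aux-loss",
    "--apply-trp",
    "--trp-depths", "2", "2", "2",
    "--out-planes", "8",
    "--trp-rewards", "1.0", "0.4", "0.2", "0.1",
])
main(args)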
nas-examples/semantic-segmentation/transforms.py ADDED
@@ -0,0 +1,100 @@
 
1
+ import random
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torchvision import transforms as T
6
+ from torchvision.transforms import functional as F
7
+
8
+
9
+ def pad_if_smaller(img, size, fill=0):
10
+ min_size = min(img.size)
11
+ if min_size < size:
12
+ ow, oh = img.size
13
+ padh = size - oh if oh < size else 0
14
+ padw = size - ow if ow < size else 0
15
+ img = F.pad(img, (0, 0, padw, padh), fill=fill)
16
+ return img
17
+
18
+
19
+ class Compose:
20
+ def __init__(self, transforms):
21
+ self.transforms = transforms
22
+
23
+ def __call__(self, image, target):
24
+ for t in self.transforms:
25
+ image, target = t(image, target)
26
+ return image, target
27
+
28
+
29
+ class RandomResize:
30
+ def __init__(self, min_size, max_size=None):
31
+ self.min_size = min_size
32
+ if max_size is None:
33
+ max_size = min_size
34
+ self.max_size = max_size
35
+
36
+ def __call__(self, image, target):
37
+ size = random.randint(self.min_size, self.max_size)
38
+ image = F.resize(image, size)
39
+ target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
40
+ return image, target
41
+
42
+
43
+ class RandomHorizontalFlip:
44
+ def __init__(self, flip_prob):
45
+ self.flip_prob = flip_prob
46
+
47
+ def __call__(self, image, target):
48
+ if random.random() < self.flip_prob:
49
+ image = F.hflip(image)
50
+ target = F.hflip(target)
51
+ return image, target
52
+
53
+
54
+ class RandomCrop:
55
+ def __init__(self, size):
56
+ self.size = size
57
+
58
+ def __call__(self, image, target):
59
+ image = pad_if_smaller(image, self.size)
60
+ target = pad_if_smaller(target, self.size, fill=255)
61
+ crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
62
+ image = F.crop(image, *crop_params)
63
+ target = F.crop(target, *crop_params)
64
+ return image, target
65
+
66
+
67
+ class CenterCrop:
68
+ def __init__(self, size):
69
+ self.size = size
70
+
71
+ def __call__(self, image, target):
72
+ image = F.center_crop(image, self.size)
73
+ target = F.center_crop(target, self.size)
74
+ return image, target
75
+
76
+
77
+ class PILToTensor:
78
+ def __call__(self, image, target):
79
+ image = F.pil_to_tensor(image)
80
+ target = torch.as_tensor(np.array(target), dtype=torch.int64)
81
+ return image, target
82
+
83
+
84
+ class ConvertImageDtype:
85
+ def __init__(self, dtype):
86
+ self.dtype = dtype
87
+
88
+ def __call__(self, image, target):
89
+ image = F.convert_image_dtype(image, self.dtype)
90
+ return image, target
91
+
92
+
93
+ class Normalize:
94
+ def __init__(self, mean, std):
95
+ self.mean = mean
96
+ self.std = std
97
+
98
+ def __call__(self, image, target):
99
+ image = F.normalize(image, mean=self.mean, std=self.std)
100
+ return image, target
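Note that RandomCrop pads the target with 255, which matches the ignore_index used by the cross-entropy criterion in train.py, so padded pixels never contribute to the loss. A compact sketch of composing these paired transforms directly:

from PIL import Image
import torch
import transforms as T

paired = T.Compose([
    T.RandomResize(400, 600),
    T.RandomCrop(480),                 # pads the target with 255 (the ignore index)
    T.PILToTensor(),
    T.ConvertImageDtype(torch.float),
])
img, target = paired(Image.new("RGB", (500, 375)), Image.new("L", (500, 375)))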
nas-examples/semantic-segmentation/trplib.py ADDED
@@ -0,0 +1,555 @@
 
1
+ import types
2
+ from typing import Optional, List, Union, Callable
3
+ from collections import OrderedDict
4
+
5
+ import torch
6
+ from torch import nn, Tensor
7
+ from torch.nn import functional as F
8
+
9
+ from torchvision.models.mobilenetv2 import MobileNetV2
10
+ from torchvision.models.resnet import ResNet
11
+ from torchvision.models.efficientnet import EfficientNet
12
+ from torchvision.models.vision_transformer import VisionTransformer
13
+ from torchvision.models.segmentation.fcn import FCN
14
+ from torchvision.models.segmentation.deeplabv3 import DeepLabV3
15
+
16
+
17
+ def compute_policy_loss(loss_sequence, mask_sequence, rewards):
18
+ losses = sum(mask * padded_loss for mask, padded_loss in zip(mask_sequence, loss_sequence))
19
+ returns = sum(padded_reward * mask for padded_reward, mask in zip(rewards, mask_sequence))
20
+ loss = torch.mean(losses * returns)
21
+
22
+ return loss
23
+
24
+
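To make the reward/mask bookkeeping concrete, here is a small sketch of compute_policy_loss with two logits heads and a batch of three (all numbers are made up):

import torch

loss_sequence = [torch.tensor([0.2, 1.5, 0.7]),    # per-sample CE loss of the base head
                 torch.tensor([0.3, 1.8, 0.9])]    # per-sample CE loss after one TRP block
mask_sequence = [torch.ones(3),                    # base head: every sample counts
                 torch.tensor([1.0, 0.0, 1.0])]    # the second sample fell out of the top-k, so it is masked out
rewards = [1.0, 0.4]

loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
# losses  = mask[0] * loss[0] + mask[1] * loss[1]          (per sample)
# returns = rewards[0] * mask[0] + rewards[1] * mask[1]
# loss    = mean(losses * returns)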
25
+ class TPBlock(nn.Module):
26
+ def __init__(self, depths: int, in_planes: int, out_planes: int = None, rank=1, shape_dims=3, channel_first=True, dtype=torch.float32) -> None:
27
+ super().__init__()
28
+ out_planes = in_planes if out_planes is None else out_planes
29
+ self.layers = torch.nn.ModuleList([self._make_layer(in_planes, out_planes, rank, shape_dims, channel_first, dtype) for _ in range(depths)])
30
+
31
+ def forward(self, x: Tensor) -> Tensor:
32
+ for layer in self.layers:
33
+ x = x + layer(x)
34
+ return x
35
+
36
+ def _make_layer(self, in_planes: int, out_planes: int = None, rank=1, shape_dims=3, channel_first=True, dtype=torch.float32) -> nn.Sequential:
37
+
38
+ class Permute(nn.Module):
39
+ def __init__(self, *dims):
40
+ super().__init__()
41
+ self.dims = dims
42
+ def forward(self, x):
43
+ return x.permute(*self.dims)
44
+
45
+ class RMSNorm(nn.Module):
46
+ __constants__ = ["eps"]
47
+ eps: float
48
+
49
+ def __init__(self, hidden_size, eps: float = 1e-6, device=None, dtype=None):
50
+ """
51
+ LlamaRMSNorm is equivalent to T5LayerNorm.
52
+ """
53
+ factory_kwargs = {"device": device, "dtype": dtype}
54
+ super().__init__()
55
+ self.eps = eps
56
+ self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs))
57
+
58
+ def forward(self, hidden_states):
59
+ input_dtype = hidden_states.dtype
60
+ hidden_states = hidden_states.to(torch.float32)
61
+ variance = hidden_states.pow(2).mean(dim=1, keepdim=True)
62
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
63
+ weight = self.weight.view(1, -1, *[1] * (hidden_states.ndim - 2))
64
+ return weight * hidden_states.to(input_dtype)
65
+
66
+ def extra_repr(self):
67
+ return f"{self.weight.shape[0]}, eps={self.eps}"
68
+
69
+ conv_map = {
70
+ 2: (nn.Conv1d, (0, 2, 1), (0, 2, 1)),
71
+ 3: (nn.Conv2d, (0, 3, 1, 2), (0, 2, 3, 1)),
72
+ 4: (nn.Conv3d, (0, 4, 1, 2, 3), (0, 2, 3, 4, 1)),
73
+ }
74
+ Conv, pre_dims, post_dims = conv_map[shape_dims]
75
+ kernel_size, dilation, padding = self.generate_hyperparameters(rank)
76
+
77
+ pre_permute = nn.Identity() if channel_first else Permute(*pre_dims)
78
+ post_permute = nn.Identity() if channel_first else Permute(*post_dims)
79
+ conv1 = Conv(in_planes, out_planes, kernel_size, padding=padding, dilation=dilation, bias=False, dtype=dtype, device='cuda')
80
+ nn.init.zeros_(conv1.weight)
81
+ bn1 = RMSNorm(out_planes, dtype=dtype, device="cuda")
82
+ relu = nn.ReLU(inplace=True)
83
+ conv2 = Conv(out_planes, in_planes, kernel_size, padding=padding, dilation=dilation, bias=False, dtype=dtype, device='cuda')
84
+ nn.init.zeros_(conv2.weight)
85
+ bn2 = RMSNorm(in_planes, dtype=dtype, device="cuda")
86
+
87
+ return torch.nn.Sequential(pre_permute, conv1, bn1, relu, conv2, bn2, relu, post_permute)
88
+
89
+ @staticmethod
90
+ def generate_hyperparameters(rank: int):
91
+ """
92
+ Generates kernel size and dilation rate pairs sorted by increasing padded kernel size.
93
+
94
+ Args:
95
+ rank: Number of (kernel_size, dilation, padding) triples to generate; the last one is returned. Must be non-negative.
96
+
97
+ Returns:
98
+ Tuple[int, int, int]: A (kernel_size, dilation, padding) triple where:
+ - kernel_size: Always odd and >= 1
+ - dilation: Chosen so that the padded (effective) kernel size grows consistently with rank
+ - padding: dilation * (kernel_size - 1) // 2, which preserves the spatial size
101
+
102
+ Note:
103
+ Padded kernel size is calculated as:
104
+ (kernel_size - 1) * dilation + 1
105
+ Pairs are generated first in order of increasing padded kernel size,
106
+ then by increasing kernel size for equal padded kernel sizes.
107
+ """
108
+ pairs = [(1, 1, 0)] # Start with smallest possible
109
+ padded_kernel_size = 3
110
+
111
+ while len(pairs) < rank:
112
+ for kernel_size in range(3, padded_kernel_size + 1, 2):
113
+ if (padded_kernel_size - 1) % (kernel_size - 1) == 0:
114
+ dilation = (padded_kernel_size - 1) // (kernel_size - 1)
115
+ padding = dilation * (kernel_size - 1) // 2
116
+ pairs.append((kernel_size, dilation, padding))
117
+ if len(pairs) >= rank:
118
+ break
119
+
120
+ # Move to next odd padded kernel size
121
+ padded_kernel_size += 2
122
+
123
+ return pairs[-1]
124
+
125
+
126
+ # ResNet for Image Classification
127
+ class ResNetConfig:
128
+ @staticmethod
129
+ def gen_shared_head(self):
130
+ def func(hidden_states):
131
+ """
132
+ Args:
133
+ hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].
134
+
135
+ Returns:
136
+ logits (Tensor): Logits tensor of shape [B, C].
137
+ """
138
+ x = self.avgpool(hidden_states)
139
+ x = torch.flatten(x, 1)
140
+ logits = self.fc(x)
141
+ return logits
142
+ return func
143
+
144
+ @staticmethod
145
+ def gen_logits(self, shared_head):
146
+ def func(hidden_states):
147
+ """
148
+ Args:
149
+ hidden_states (Tensor): Hidden states tensor, e.g. [B, C, H, W] for CNN backbones or [B, L, hidden_dim] for transformer encoders.
150
+
151
+ Returns:
152
+ logits_sequence (List[Tensor]): List of logits tensors, one per head.
153
+ """
154
+ logits_sequence = [shared_head(hidden_states)]
155
+ for layer in self.trp_blocks:
156
+ logits_sequence.append(shared_head(layer(hidden_states)))
157
+ return logits_sequence
158
+ return func
159
+
160
+ @staticmethod
161
+ def gen_mask(label_smoothing=0.0, top_k=1):
162
+ def func(logits_sequence, labels):
163
+ """
164
+ Args:
165
+ logits_sequence (List[Tensor]): List of Logits tensors.
166
+ labels (Tensor): Target labels of shape [B] or [B, C].
167
+
168
+ Returns:
169
+ mask_sequence (List[Tensor]): List of float mask tensors of shape [B]; each entry cumulatively marks the samples whose label is still within the top-k predictions.
170
+ """
171
+ labels = torch.argmax(labels, dim=1) if label_smoothing > 0.0 else labels
172
+
173
+ mask_sequence = [torch.ones_like(labels, dtype=torch.float32, device=labels.device)]
174
+ for logits in logits_sequence:
175
+ with torch.no_grad():
176
+ topk_values, topk_indices = torch.topk(logits, top_k, dim=-1)
177
+ mask = torch.eq(topk_indices, labels[:, None]).any(dim=-1).to(torch.float32)
178
+ mask_sequence.append(mask_sequence[-1] * mask)
179
+ return mask_sequence
180
+ return func
181
+
182
+ @staticmethod
183
+ def gen_criterion(label_smoothing=0.0):
184
+ def func(logits_sequence, labels):
185
+ """
186
+ Args:
187
+ logits_sequence (List[Tensor]): List of logits tensors.
+ labels (Tensor): Target labels of shape [B] or [B, C].
189
+
190
+ Returns:
191
+ loss_sequence (List[Tensor]): Per-sample cross-entropy losses, one tensor of shape [B] per logits tensor.
193
+ """
194
+ labels = torch.argmax(labels, dim=1) if label_smoothing > 0.0 else labels
195
+
196
+ loss_sequence = []
197
+ for logits in logits_sequence:
198
+ loss_sequence.append(F.cross_entropy(logits, labels, reduction="none", label_smoothing=label_smoothing))
199
+
200
+ return loss_sequence
201
+ return func
202
+
203
+ @staticmethod
204
+ def gen_forward(rewards, label_smoothing=0.0, top_k=1):
205
+ def func(self, x: Tensor, targets=None) -> Tensor:
206
+ x = self.conv1(x)
207
+ x = self.bn1(x)
208
+ x = self.relu(x)
209
+ x = self.maxpool(x)
210
+
211
+ x = self.layer1(x)
212
+ x = self.layer2(x)
213
+ x = self.layer3(x)
214
+ hidden_states = self.layer4(x)
215
+ x = self.avgpool(hidden_states)
216
+ x = torch.flatten(x, 1)
217
+ logits = self.fc(x)
218
+
219
+ if self.training:
220
+ shared_head = ResNetConfig.gen_shared_head(self)
221
+ compute_logits = ResNetConfig.gen_logits(self, shared_head)
222
+ compute_mask = ResNetConfig.gen_mask(label_smoothing, top_k)
223
+ compute_loss = ResNetConfig.gen_criterion(label_smoothing)
224
+
225
+ logits_sequence = compute_logits(hidden_states)
226
+ mask_sequence = compute_mask(logits_sequence, targets)
227
+ loss_sequence = compute_loss(logits_sequence, targets)
228
+ loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
229
+
230
+ return logits, loss
231
+
232
+ return logits
233
+
234
+ return func
235
+
236
+
237
+ # MobileNetV2 for Image Classification
238
+ class MobileNetV2Config(ResNetConfig):
239
+ @staticmethod
240
+ def gen_shared_head(self):
241
+ def func(hidden_states):
242
+ """
243
+ Args:
244
+ hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].
245
+
246
+ Returns:
247
+ logits (Tensor): Logits tensor of shape [B, C].
248
+ """
249
+ x = nn.functional.adaptive_avg_pool2d(hidden_states, (1, 1))
250
+ x = torch.flatten(x, 1)
251
+ logits = self.classifier(x)
252
+ return logits
253
+ return func
254
+
255
+ @staticmethod
256
+ def gen_forward(rewards, label_smoothing=0.0, top_k=1):
257
+ def func(self, x: Tensor, targets=None) -> Tensor:
258
+ hidden_states = self.features(x)
259
+ # Cannot use "squeeze" as batch-size can be 1
260
+ x = nn.functional.adaptive_avg_pool2d(hidden_states, (1, 1))
261
+ x = torch.flatten(x, 1)
262
+ logits = self.classifier(x)
263
+
264
+ if self.training:
265
+ shared_head = MobileNetV2Config.gen_shared_head(self)
266
+ compute_logits = MobileNetV2Config.gen_logits(self, shared_head)
267
+ compute_mask = MobileNetV2Config.gen_mask(label_smoothing, top_k)
268
+ compute_loss = MobileNetV2Config.gen_criterion(label_smoothing)
269
+
270
+ logits_sequence = compute_logits(hidden_states)
271
+ mask_sequence = compute_mask(logits_sequence, targets)
272
+ loss_sequence = compute_loss(logits_sequence, targets)
273
+ loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
274
+
275
+ return logits, loss
276
+
277
+ return logits
278
+
279
+ return func
280
+
281
+
282
+ # EfficientNet for Image Classification
283
+ class EfficientNetConfig(ResNetConfig):
284
+ @staticmethod
285
+ def gen_shared_head(self):
286
+ def func(hidden_states):
287
+ """
288
+ Args:
289
+ hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].
290
+
291
+ Returns:
292
+ logits (Tensor): Logits tensor of shape [B, C].
293
+ """
294
+ x = self.avgpool(hidden_states)
295
+ x = torch.flatten(x, 1)
296
+ logits = self.classifier(x)
297
+ return logits
298
+ return func
299
+
300
+ @staticmethod
301
+ def gen_forward(rewards, label_smoothing=0.0, top_k=1):
302
+ def func(self, x: Tensor, targets=None) -> Tensor:
303
+ hidden_states = self.features(x)
304
+ x = self.avgpool(hidden_states)
305
+ x = torch.flatten(x, 1)
306
+ logits = self.classifier(x)
307
+
308
+ if self.training:
309
+ shared_head = EfficientNetConfig.gen_shared_head(self)
310
+ compute_logits = EfficientNetConfig.gen_logits(self, shared_head)
311
+ compute_mask = EfficientNetConfig.gen_mask(label_smoothing, top_k)
312
+ compute_loss = EfficientNetConfig.gen_criterion(label_smoothing)
313
+
314
+ logits_sequence = compute_logits(hidden_states)
315
+ mask_sequence = compute_mask(logits_sequence, targets)
316
+ loss_sequence = compute_loss(logits_sequence, targets)
317
+ loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
318
+
319
+ return logits, loss
320
+
321
+ return logits
322
+
323
+ return func
324
+
325
+
326
+ # VisionTransformer for Image Classification
327
+ class VisionTransformerConfig(ResNetConfig):
328
+ @staticmethod
329
+ def gen_shared_head(self):
330
+ def func(hidden_states):
331
+ """
332
+ Args:
333
+ hidden_states (Tensor): Encoder output of shape [B, L, hidden_dim], with the class token at index 0.
334
+
335
+ Returns:
336
+ logits (Tensor): Logits tensor of shape [B, C].
337
+ """
338
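+ # Classify from the class token (index 0 of the encoder output).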
+ x = hidden_states[:, 0]
339
+ logits = self.heads(x)
340
+ return logits
341
+ return func
342
+
343
+ @staticmethod
344
+ def gen_forward(rewards, label_smoothing=0.0, top_k=1):
345
+ def func(self, images: Tensor, targets=None):
346
+ x = self._process_input(images)
347
+ n = x.shape[0]
348
+ batch_class_token = self.class_token.expand(n, -1, -1)
349
+ x = torch.cat([batch_class_token, x], dim=1)
350
+ hidden_states = self.encoder(x)
351
+ x = hidden_states[:, 0]
352
+
353
+ logits = self.heads(x)
354
+
355
+
356
+ if self.training:
357
+ shared_head = VisionTransformerConfig.gen_shared_head(self)
358
+ compute_logits = VisionTransformerConfig.gen_logits(self, shared_head)
359
+ compute_mask = VisionTransformerConfig.gen_mask(label_smoothing, top_k)
360
+ compute_loss = VisionTransformerConfig.gen_criterion(label_smoothing)
361
+
362
+ logits_sequence = compute_logits(hidden_states)
363
+ mask_sequence = compute_mask(logits_sequence, targets)
364
+ loss_sequence = compute_loss(logits_sequence, targets)
365
+ loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
366
+
367
+ return logits, loss
368
+ return logits
369
+ return func
370
+
371
+
372
+ # FCN for Semantic Segmentation
373
+ class FCNConfig(ResNetConfig):
374
+ @staticmethod
375
+ def gen_out_shared_head(self, input_shape):
376
+ def func(features):
377
+ """
378
+ Args:
379
+ features (Tensor): features tensor of shape [B, hidden_units, H, W].
380
+
381
+ Returns:
382
+ result (Tensor): Result tensor of shape [B, C, H, W].
383
+ """
384
+ x = self.classifier(features)
385
+ result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
386
+ return result
387
+ return func
388
+
389
+ @staticmethod
390
+ def gen_aux_shared_head(self, input_shape):
391
+ def func(features):
392
+ """
393
+ Args:
394
+ features (Tensor): features tensor of shape [B, hidden_units, H, W].
395
+
396
+ Returns:
397
+ result (Tensor): Result tensor of shape [B, C, H, W].
398
+ """
399
+ x = self.aux_classifier(features)
400
+ result = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
401
+ return result
402
+ return func
403
+
404
+ @staticmethod
405
+ def gen_out_logits(self, shared_head):
406
+ def func(hidden_states):
407
+ """
408
+ Args:
409
+ hidden_states (Tensor): Hidden states tensor of shape [B, C, H, W].
410
+
411
+ Returns:
412
+ logits_sequence (List[Tensor]): List of logits tensors of shape [B, C, H, W].
413
+ """
414
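+ # First entry: prediction from the raw backbone features; each TRP block then
+ # transforms the features and is scored by the same shared head.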
+ logits_sequence = [shared_head(hidden_states)]
415
+ for layer in self.out_trp_blocks:
416
+ logits_sequence.append(shared_head(layer(hidden_states)))
417
+ return logits_sequence
418
+ return func
419
+
420
+ @staticmethod
421
+ def gen_aux_logits(self, shared_head):
422
+ def func(hidden_states):
423
+ """
424
+ Args:
425
+ hidden_states (Tensor): Hidden states tensor of shape [B, C, H, W].
426
+
427
+ Returns:
428
+ logits_sequence (List[Tensor]): List of logits tensors of shape [B, C, H, W].
429
+ """
430
+ logits_sequence = [shared_head(hidden_states)]
431
+ for layer in self.aux_trp_blocks:
432
+ logits_sequence.append(shared_head(layer(hidden_states)))
433
+ return logits_sequence
434
+ return func
435
+
436
+ @staticmethod
437
+ def gen_mask(label_smoothing=0.0, top_k=1):
438
+ def func(logits_sequence, labels):
439
+ """
440
+ Args:
441
+ logits_sequence (List[Tensor]): List of Logits tensors with shape [B, C, H, W].
442
+ labels (Tensor): Target labels of shape [B, H, W].
443
+
444
+ Returns:
445
+ mask_sequence (List[Tensor]): List of float mask tensors of shape [B, H, W].
446
+ """
447
+ labels = torch.argmax(labels, dim=1) if label_smoothing > 0.0 else labels
448
+
449
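+ # Start from an all-ones mask; each head keeps only the pixels whose label is
+ # within its top-k predictions, so the mask shrinks monotonically across heads.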
+ mask_sequence = [torch.ones_like(labels, dtype=torch.float32, device=labels.device)]
450
+ for logits in logits_sequence:
451
+ with torch.no_grad():
452
+ topk_values, topk_indices = torch.topk(logits, top_k, dim=1)
453
+ mask = torch.eq(topk_indices, labels[:, None, :, :]).any(dim=1).to(torch.float32)
454
+ mask_sequence.append(mask_sequence[-1] * mask)
455
+ return mask_sequence
456
+ return func
457
+
458
+ @staticmethod
459
+ def gen_criterion(label_smoothing=0.0):
460
+ def func(logits_sequence, labels):
461
+ """
462
+ Args:
463
+ logits_sequence (List[Tensor]): List of logits tensors of shape [B, C, H, W].
464
+ labels (Tensor): Target labels of shape [B, H, W] or [B, C, H, W].
465
+
466
+ Returns:
467
+ loss_sequence (List[Tensor]): List of per-pixel loss tensors, each of shape [B, H, W].
469
+ """
470
+ labels = torch.argmax(labels, dim=1) if label_smoothing > 0.0 else labels
471
+
472
+ loss_sequence = []
473
+ for logits in logits_sequence:
474
+ loss_sequence.append(F.cross_entropy(logits, labels, ignore_index=255, reduction="none", label_smoothing=label_smoothing))
475
+
476
+ return loss_sequence
477
+ return func
478
+
479
+ @staticmethod
480
+ def gen_forward(rewards, label_smoothing=0.0, top_k=1):
481
+ def func(self, images: Tensor, targets=None):
482
+ input_shape = images.shape[-2:]
483
+ # contract: features is a dict of tensors
484
+ features = self.backbone(images)
485
+
486
+ result = OrderedDict()
487
+ x = features["out"]
488
+ x = self.classifier(x)
489
+ x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
490
+ result["out"] = x
491
+
492
+ if self.aux_classifier is not None:
493
+ x = features["aux"]
494
+ x = self.aux_classifier(x)
495
+ x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
496
+ result["aux"] = x
497
+
498
+ if self.training:
499
+ torch._assert(targets is not None, "targets should not be none when in training mode")
500
+ out_shared_head = FCNConfig.gen_out_shared_head(self, input_shape)
501
+ aux_shared_head = FCNConfig.gen_aux_shared_head(self, input_shape)
502
+ compute_out_logits = FCNConfig.gen_out_logits(self, out_shared_head)
503
+ compute_aux_logits = FCNConfig.gen_aux_logits(self, aux_shared_head)
504
+ compute_mask = FCNConfig.gen_mask(label_smoothing, top_k)
505
+ compute_loss = FCNConfig.gen_criterion(label_smoothing)
506
+
507
+ out_logits_sequence = compute_out_logits(features["out"])
508
+ out_mask_sequence = compute_mask(out_logits_sequence, targets)
509
+ out_loss_sequence = compute_loss(out_logits_sequence, targets)
510
+ out_loss = compute_policy_loss(out_loss_sequence, out_mask_sequence, rewards)
511
+
512
+ aux_logits_sequence = compute_aux_logits(features["aux"])
513
+ aux_mask_sequence = compute_mask(aux_logits_sequence, targets)
514
+ aux_loss_sequence = compute_loss(aux_logits_sequence, targets)
515
+ aux_loss = compute_policy_loss(aux_loss_sequence, aux_mask_sequence, rewards)
516
+
517
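+ # Auxiliary-classifier loss is down-weighted by 0.5, as in the standard
+ # torchvision segmentation recipe.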
+ loss = out_loss + 0.5 * aux_loss
518
+ return result, loss
519
+ return result
520
+ return func
521
+
522
+
523
+ # DeepLabV3Config for Semantic Segmentation
524
+ class DeepLabV3Config(FCNConfig):
525
+ pass
526
+
527
+
528
+ def apply_trp(model, depths: List[int], in_planes: int, out_planes: int, rewards, **kwargs):
529
+ if isinstance(model, ResNet):
530
+ print("✅ Applying TRP to ResNet for Image Classification...")
531
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
532
+ model.forward = types.MethodType(ResNetConfig.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
533
+ elif isinstance(model, MobileNetV2):
534
+ print("✅ Applying TRP to MobileNetV2 for Image Classification...")
535
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
536
+ model.forward = types.MethodType(MobileNetV2Config.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
537
+ elif isinstance(model, EfficientNet):
538
+ print("✅ Applying TRP to EfficientNet for Image Classification...")
539
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
540
+ model.forward = types.MethodType(EfficientNetConfig.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
541
+ elif isinstance(model, VisionTransformer):
542
+ print("✅ Applying TRP to VisionTransformer for Image Classification...")
543
+ model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k, shape_dims=2, channel_first=False) for k, d in enumerate(depths)])
544
+ model.forward = types.MethodType(VisionTransformerConfig.gen_forward(rewards, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
545
+ elif isinstance(model, FCN):
546
+ print("✅ Applying TRP to FCN for Semantic Segmentation...")
547
+ model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=2048, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
548
+ model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=1024, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
549
+ model.forward = types.MethodType(FCNConfig.gen_forward(rewards, label_smoothing=0.0, top_k=1), model)
550
+ elif isinstance(model, DeepLabV3):
551
+ print("✅ Applying TRP to DeepLabV3 for Semantic Segmentation...")
552
+ model.out_trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=2048, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
553
+ model.aux_trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, in_planes=1024, out_planes=out_planes, rank=k) for k, d in enumerate(depths)])
554
+ model.forward = types.MethodType(DeepLabV3Config.gen_forward(rewards, label_smoothing=0.0, top_k=1), model)
555
+ return model
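A minimal usage sketch (illustrative only; the depths, rewards, and the 1000-class head below are hypothetical values, not taken from the training scripts). apply_trp patches a stock torchvision model in place: after the call, forward returns (logits, loss) in train mode and plain logits in eval mode.

    from torchvision.models import resnet50

    model = resnet50(num_classes=1000)
    model = apply_trp(
        model,
        depths=[1, 2, 3],          # one TRP block per entry (values are illustrative)
        in_planes=2048,            # channel width of ResNet-50's layer4 output
        out_planes=1000,           # hypothetical; must match what TPBlock expects
        rewards=[1.0, 0.5, 0.25],  # hypothetical per-head rewards
        label_smoothing=0.1,
    )
    model.train()   # model(images, targets) -> (logits, loss)
    model.eval()    # model(images) -> logits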
nas-examples/semantic-segmentation/utils.py ADDED
@@ -0,0 +1,300 @@
1
+ import datetime
2
+ import errno
3
+ import os
4
+ import time
5
+ from collections import defaultdict, deque
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+
10
+
11
+ class SmoothedValue:
12
+ """Track a series of values and provide access to smoothed values over a
13
+ window or the global series average.
14
+ """
15
+
16
+ def __init__(self, window_size=20, fmt=None):
17
+ if fmt is None:
18
+ fmt = "{median:.4f} ({global_avg:.4f})"
19
+ self.deque = deque(maxlen=window_size)
20
+ self.total = 0.0
21
+ self.count = 0
22
+ self.fmt = fmt
23
+
24
+ def update(self, value, n=1):
25
+ self.deque.append(value)
26
+ self.count += n
27
+ self.total += value * n
28
+
29
+ def synchronize_between_processes(self):
30
+ """
31
+ Warning: does not synchronize the deque!
32
+ """
33
+ t = reduce_across_processes([self.count, self.total])
34
+ t = t.tolist()
35
+ self.count = int(t[0])
36
+ self.total = t[1]
37
+
38
+ @property
39
+ def median(self):
40
+ d = torch.tensor(list(self.deque))
41
+ return d.median().item()
42
+
43
+ @property
44
+ def avg(self):
45
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
46
+ return d.mean().item()
47
+
48
+ @property
49
+ def global_avg(self):
50
+ return self.total / self.count
51
+
52
+ @property
53
+ def max(self):
54
+ return max(self.deque)
55
+
56
+ @property
57
+ def value(self):
58
+ return self.deque[-1]
59
+
60
+ def __str__(self):
61
+ return self.fmt.format(
62
+ median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
63
+ )
64
+
65
+
66
+ class ConfusionMatrix:
67
+ def __init__(self, num_classes):
68
+ self.num_classes = num_classes
69
+ self.mat = None
70
+
71
+ def update(self, a, b):
72
+ n = self.num_classes
73
+ if self.mat is None:
74
+ self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
75
+ with torch.inference_mode():
76
+ k = (a >= 0) & (a < n)
77
+ inds = n * a[k].to(torch.int64) + b[k]
78
+ self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
79
+
80
+ def reset(self):
81
+ self.mat.zero_()
82
+
83
+ def compute(self):
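+ # Global pixel accuracy, per-class accuracy, and per-class IoU derived from the
+ # accumulated confusion matrix (rows index ground truth, columns predictions).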
84
+ h = self.mat.float()
85
+ acc_global = torch.diag(h).sum() / h.sum()
86
+ acc = torch.diag(h) / h.sum(1)
87
+ iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
88
+ return acc_global, acc, iu
89
+
90
+ def reduce_from_all_processes(self):
91
+ self.mat = reduce_across_processes(self.mat)
92
+
93
+ def __str__(self):
94
+ acc_global, acc, iu = self.compute()
95
+ return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
96
+ acc_global.item() * 100,
97
+ [f"{i:.1f}" for i in (acc * 100).tolist()],
98
+ [f"{i:.1f}" for i in (iu * 100).tolist()],
99
+ iu.mean().item() * 100,
100
+ )
101
+
102
+
103
+ class MetricLogger:
104
+ def __init__(self, delimiter="\t"):
105
+ self.meters = defaultdict(SmoothedValue)
106
+ self.delimiter = delimiter
107
+
108
+ def update(self, **kwargs):
109
+ for k, v in kwargs.items():
110
+ if isinstance(v, torch.Tensor):
111
+ v = v.item()
112
+ if not isinstance(v, (float, int)):
113
+ raise TypeError(
114
+ f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
115
+ )
116
+ self.meters[k].update(v)
117
+
118
+ def __getattr__(self, attr):
119
+ if attr in self.meters:
120
+ return self.meters[attr]
121
+ if attr in self.__dict__:
122
+ return self.__dict__[attr]
123
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
124
+
125
+ def __str__(self):
126
+ loss_str = []
127
+ for name, meter in self.meters.items():
128
+ loss_str.append(f"{name}: {str(meter)}")
129
+ return self.delimiter.join(loss_str)
130
+
131
+ def synchronize_between_processes(self):
132
+ for meter in self.meters.values():
133
+ meter.synchronize_between_processes()
134
+
135
+ def add_meter(self, name, meter):
136
+ self.meters[name] = meter
137
+
138
+ def log_every(self, iterable, print_freq, header=None):
139
+ i = 0
140
+ if not header:
141
+ header = ""
142
+ start_time = time.time()
143
+ end = time.time()
144
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
145
+ data_time = SmoothedValue(fmt="{avg:.4f}")
146
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
147
+ if torch.cuda.is_available():
148
+ log_msg = self.delimiter.join(
149
+ [
150
+ header,
151
+ "[{0" + space_fmt + "}/{1}]",
152
+ "eta: {eta}",
153
+ "{meters}",
154
+ "time: {time}",
155
+ "data: {data}",
156
+ "max mem: {memory:.0f}",
157
+ ]
158
+ )
159
+ else:
160
+ log_msg = self.delimiter.join(
161
+ [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
162
+ )
163
+ MB = 1024.0 * 1024.0
164
+ for obj in iterable:
165
+ data_time.update(time.time() - end)
166
+ yield obj
167
+ iter_time.update(time.time() - end)
168
+ if i % print_freq == 0:
169
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
170
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
171
+ if torch.cuda.is_available():
172
+ print(
173
+ log_msg.format(
174
+ i,
175
+ len(iterable),
176
+ eta=eta_string,
177
+ meters=str(self),
178
+ time=str(iter_time),
179
+ data=str(data_time),
180
+ memory=torch.cuda.max_memory_allocated() / MB,
181
+ )
182
+ )
183
+ else:
184
+ print(
185
+ log_msg.format(
186
+ i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
187
+ )
188
+ )
189
+ i += 1
190
+ end = time.time()
191
+ total_time = time.time() - start_time
192
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
193
+ print(f"{header} Total time: {total_time_str}")
194
+
195
+
196
+ def cat_list(images, fill_value=0):
197
+ max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
198
+ batch_shape = (len(images),) + max_size
199
+ batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
200
+ for img, pad_img in zip(images, batched_imgs):
201
+ pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
202
+ return batched_imgs
203
+
204
+
205
+ def collate_fn(batch):
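+ # Pad images with zeros and targets with 255 up to the largest size in the batch;
+ # 255 matches the ignore_index used by the segmentation loss, so padding is ignored.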
206
+ images, targets = list(zip(*batch))
207
+ batched_imgs = cat_list(images, fill_value=0)
208
+ batched_targets = cat_list(targets, fill_value=255)
209
+ return batched_imgs, batched_targets
210
+
211
+
212
+ def mkdir(path):
213
+ try:
214
+ os.makedirs(path)
215
+ except OSError as e:
216
+ if e.errno != errno.EEXIST:
217
+ raise
218
+
219
+
220
+ def setup_for_distributed(is_master):
221
+ """
222
+ This function disables printing when not in master process
223
+ """
224
+ import builtins as __builtin__
225
+
226
+ builtin_print = __builtin__.print
227
+
228
+ def print(*args, **kwargs):
229
+ force = kwargs.pop("force", False)
230
+ if is_master or force:
231
+ builtin_print(*args, **kwargs)
232
+
233
+ __builtin__.print = print
234
+
235
+
236
+ def is_dist_avail_and_initialized():
237
+ if not dist.is_available():
238
+ return False
239
+ if not dist.is_initialized():
240
+ return False
241
+ return True
242
+
243
+
244
+ def get_world_size():
245
+ if not is_dist_avail_and_initialized():
246
+ return 1
247
+ return dist.get_world_size()
248
+
249
+
250
+ def get_rank():
251
+ if not is_dist_avail_and_initialized():
252
+ return 0
253
+ return dist.get_rank()
254
+
255
+
256
+ def is_main_process():
257
+ return get_rank() == 0
258
+
259
+
260
+ def save_on_master(*args, **kwargs):
261
+ if is_main_process():
262
+ torch.save(*args, **kwargs)
263
+
264
+
265
+ def init_distributed_mode(args):
266
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
267
+ args.rank = int(os.environ["RANK"])
268
+ args.world_size = int(os.environ["WORLD_SIZE"])
269
+ args.gpu = int(os.environ["LOCAL_RANK"])
270
+ elif "SLURM_PROCID" in os.environ:
271
+ args.rank = int(os.environ["SLURM_PROCID"])
272
+ args.gpu = args.rank % torch.cuda.device_count()
273
+ elif hasattr(args, "rank"):
274
+ pass
275
+ else:
276
+ print("Not using distributed mode")
277
+ args.distributed = False
278
+ return
279
+
280
+ args.distributed = True
281
+
282
+ torch.cuda.set_device(args.gpu)
283
+ args.dist_backend = "nccl"
284
+ print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
285
+ torch.distributed.init_process_group(
286
+ backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
287
+ )
288
+ torch.distributed.barrier()
289
+ setup_for_distributed(args.rank == 0)
290
+
291
+
292
+ def reduce_across_processes(val):
293
+ if not is_dist_avail_and_initialized():
294
+ # nothing to sync, but we still convert to tensor for consistency with the distributed case.
295
+ return torch.tensor(val)
296
+
297
+ t = torch.tensor(val, device="cuda") if isinstance(val, int) else val.clone().detach().to("cuda")
298
+ dist.barrier()
299
+ dist.all_reduce(t)
300
+ return t