UniversalAlgorithmic committed on
Commit 7213cec · verified · 1 Parent(s): b4bbcfb

Upload 11 files

neural-archicture-search/presets.py ADDED
@@ -0,0 +1,71 @@
+ import torch
+ from torchvision.transforms import autoaugment, transforms
+ from torchvision.transforms.functional import InterpolationMode
+
+
+ class ClassificationPresetTrain:
+     def __init__(
+         self,
+         *,
+         crop_size,
+         mean=(0.485, 0.456, 0.406),
+         std=(0.229, 0.224, 0.225),
+         interpolation=InterpolationMode.BILINEAR,
+         hflip_prob=0.5,
+         auto_augment_policy=None,
+         ra_magnitude=9,
+         augmix_severity=3,
+         random_erase_prob=0.0,
+     ):
+         trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
+         if hflip_prob > 0:
+             trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+         if auto_augment_policy is not None:
+             if auto_augment_policy == "ra":
+                 trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
+             elif auto_augment_policy == "ta_wide":
+                 trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
+             elif auto_augment_policy == "augmix":
+                 trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
+             else:
+                 aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
+                 trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
+         trans.extend(
+             [
+                 transforms.PILToTensor(),
+                 transforms.ConvertImageDtype(torch.float),
+                 transforms.Normalize(mean=mean, std=std),
+             ]
+         )
+         if random_erase_prob > 0:
+             trans.append(transforms.RandomErasing(p=random_erase_prob))
+
+         self.transforms = transforms.Compose(trans)
+
+     def __call__(self, img):
+         return self.transforms(img)
+
+
+ class ClassificationPresetEval:
+     def __init__(
+         self,
+         *,
+         crop_size,
+         resize_size=256,
+         mean=(0.485, 0.456, 0.406),
+         std=(0.229, 0.224, 0.225),
+         interpolation=InterpolationMode.BILINEAR,
+     ):
+
+         self.transforms = transforms.Compose(
+             [
+                 transforms.Resize(resize_size, interpolation=interpolation),
+                 transforms.CenterCrop(crop_size),
+                 transforms.PILToTensor(),
+                 transforms.ConvertImageDtype(torch.float),
+                 transforms.Normalize(mean=mean, std=std),
+             ]
+         )
+
+     def __call__(self, img):
+         return self.transforms(img)
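
For reference, a minimal usage sketch of the presets above; the image size and policy name here are illustrative, not part of the file:

from PIL import Image

import presets

train_tf = presets.ClassificationPresetTrain(crop_size=224, auto_augment_policy="ta_wide")
eval_tf = presets.ClassificationPresetEval(crop_size=224, resize_size=256)

img = Image.new("RGB", (500, 375))  # stand-in for a real training image
x = train_tf(img)  # normalized float tensor of shape [3, 224, 224]
y = eval_tf(img)   # deterministic resize + center crop, same shape
print(x.shape, y.shape)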
neural-archicture-search/resnet18/model_3.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e728a634490a078e1a672f464b9baebc04774f83b03fc251ad2437a2731330a0
+ size 136133334
neural-archicture-search/resnet34/model_8.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:179a2d39490980a6cc801c3ef15230bfe08d7e941174d79e2099c8db8b11dfcf
+ size 202898970
neural-archicture-search/resnet50/model_9.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fecb231a0e220e46dde1025a7403bb1f587d81f6d36d143ba1c510c3b477a122
+ size 431365452
neural-archicture-search/run.sh ADDED
@@ -0,0 +1,33 @@
+ # # ✅ Test: Acc@1 70.092 Acc@5 89.314
+ # torchrun --nproc_per_node=4 train.py\
+ # --data-path /home/cs/Documents/datasets/imagenet\
+ # --model resnet18 --output-dir resnet18 --weights ResNet18_Weights.IMAGENET1K_V1\
+ # --batch-size 128 --epochs 10 --lr 0.0004 --lr-step-size 2 --lr-gamma 0.5\
+ # --lr-warmup-method constant --lr-warmup-epochs 1 --lr-warmup-decay 0.\
+ # --apply-trp --trp-depths 3 3 3 --trp-planes 256 --trp-lambdas 0.4 0.2 0.1 --print-freq 100
+ torchrun --nproc_per_node=4 train.py\
+ --data-path /home/cs/Documents/datasets/imagenet\
+ --model resnet18 --resume resnet18/model_3.pth --test-only
+
+ # # ✅ Test: Acc@1 73.900 Acc@5 91.536
+ # torchrun --nproc_per_node=4 train.py\
+ # --data-path /home/cs/Documents/datasets/imagenet\
+ # --model resnet34 --output-dir resnet34 --weights ResNet34_Weights.IMAGENET1K_V1\
+ # --batch-size 96 --epochs 10 --lr 0.0004 --lr-step-size 2 --lr-gamma 0.5\
+ # --lr-warmup-method constant --lr-warmup-epochs 1 --lr-warmup-decay 0.\
+ # --apply-trp --trp-depths 2 2 2 --trp-planes 256 --trp-lambdas 0.4 0.2 0.1 --print-freq 100
+ # torchrun --nproc_per_node=4 train.py\
+ # --data-path /home/cs/Documents/datasets/imagenet\
+ # --model resnet34 --resume resnet34/model_8.pth --test-only
+
+
+ # # ✅ Test: Acc@1 76.896 Acc@5 93.136
+ # torchrun --nproc_per_node=4 train.py\
+ # --data-path /home/cs/Documents/datasets/imagenet\
+ # --model resnet50 --output-dir resnet50 --weights ResNet50_Weights.IMAGENET1K_V1\
+ # --batch-size 64 --epochs 10 --lr 0.0004 --lr-step-size 2 --lr-gamma 0.5\
+ # --lr-warmup-method constant --lr-warmup-epochs 1 --lr-warmup-decay 0.\
+ # --apply-trp --trp-depths 1 1 1 --trp-planes 1024 --trp-lambdas 0.4 0.2 0.1 --print-freq 100
+ # torchrun --nproc_per_node=4 train.py\
+ # --data-path /home/cs/Documents/datasets/imagenet\
+ # --model resnet50 --resume resnet50/model_9.pth --test-only
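
The flags in these commands map onto the argparse options defined in train.py below; as a quick sanity check of how the TRP arguments parse (values taken from the resnet18 command above, list-valued via nargs="+"):

from train import get_args_parser

args = get_args_parser().parse_args(
    ["--model", "resnet18",
     "--apply-trp", "--trp-depths", "3", "3", "3",
     "--trp-planes", "256", "--trp-lambdas", "0.4", "0.2", "0.1"]
)
assert args.trp_depths == [3, 3, 3]        # one TP block per lambda
assert args.trp_lambdas == [0.4, 0.2, 0.1]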
neural-archicture-search/sampler.py ADDED
@@ -0,0 +1,62 @@
+ import math
+
+ import torch
+ import torch.distributed as dist
+
+
+ class RASampler(torch.utils.data.Sampler):
+     """Sampler that restricts data loading to a subset of the dataset for distributed training,
+     with repeated augmentation.
+     It ensures that each augmented version of a sample is visible to a
+     different process (GPU).
+     Heavily based on 'torch.utils.data.DistributedSampler'.
+
+     This is borrowed from the DeiT repo:
+     https://github.com/facebookresearch/deit/blob/main/samplers.py
+     """
+
+     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
+         if num_replicas is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available!")
+             num_replicas = dist.get_world_size()
+         if rank is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available!")
+             rank = dist.get_rank()
+         self.dataset = dataset
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.epoch = 0
+         self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
+         self.total_size = self.num_samples * self.num_replicas
+         self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
+         self.shuffle = shuffle
+         self.seed = seed
+         self.repetitions = repetitions
+
+     def __iter__(self):
+         if self.shuffle:
+             # Deterministically shuffle based on epoch
+             g = torch.Generator()
+             g.manual_seed(self.seed + self.epoch)
+             indices = torch.randperm(len(self.dataset), generator=g).tolist()
+         else:
+             indices = list(range(len(self.dataset)))
+
+         # Repeat each sample, then add extra samples to make the list evenly divisible
+         indices = [ele for ele in indices for i in range(self.repetitions)]
+         indices += indices[: (self.total_size - len(indices))]
+         assert len(indices) == self.total_size
+
+         # Subsample
+         indices = indices[self.rank : self.total_size : self.num_replicas]
+         assert len(indices) == self.num_samples
+
+         return iter(indices[: self.num_selected_samples])
+
+     def __len__(self):
+         return self.num_selected_samples
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
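
A minimal sketch of how RASampler slots in where DistributedSampler would normally go; the tensor dataset is a stand-in, and num_replicas/rank are passed explicitly so the sketch runs without an initialized process group:

import torch
from sampler import RASampler

dataset = torch.utils.data.TensorDataset(torch.randn(1024, 3, 224, 224))
sampler = RASampler(dataset, num_replicas=4, rank=0, shuffle=True, repetitions=3)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffles deterministically per epoch
    for batch in loader:
        pass  # each rank sees a different augmented copy of each sample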
neural-archicture-search/train.py ADDED
@@ -0,0 +1,524 @@
+ import datetime
+ import os
+ import time
+ import warnings
+
+ import presets
+ import torch
+ import torch.utils.data
+ import torchvision
+ import transforms
+ import utils
+ from sampler import RASampler
+ from torch import nn
+ from torch.utils.data.dataloader import default_collate
+ from torchvision.transforms.functional import InterpolationMode
+
+ from trplib import apply_trp
+
+
+ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
+     model.train()
+     metric_logger = utils.MetricLogger(delimiter="  ")
+     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
+     metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
+
+     header = f"Epoch: [{epoch}]"
+     for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
+         start_time = time.time()
+         image, target = image.to(device), target.to(device)
+         with torch.amp.autocast("cuda", enabled=scaler is not None):
+             # output = model(image)
+             # loss = criterion(output, target)
+             output, loss = model(image, target)
+
+         optimizer.zero_grad()
+         if scaler is not None:
+             scaler.scale(loss).backward()
+             if args.clip_grad_norm is not None:
+                 # we should unscale the gradients of the optimizer's assigned params if we do gradient clipping
+                 scaler.unscale_(optimizer)
+                 nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
+             scaler.step(optimizer)
+             scaler.update()
+         else:
+             loss.backward()
+             if args.clip_grad_norm is not None:
+                 nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
+             optimizer.step()
+
+         if model_ema and i % args.model_ema_steps == 0:
+             model_ema.update_parameters(model)
+             if epoch < args.lr_warmup_epochs:
+                 # Reset ema buffer to keep copying weights during warmup period
+                 model_ema.n_averaged.fill_(0)
+
+         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+         batch_size = image.shape[0]
+         metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
+         metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
+         metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+         metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
+
+
+ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
+     model.eval()
+     metric_logger = utils.MetricLogger(delimiter="  ")
+     header = f"Test: {log_suffix}"
+
+     num_processed_samples = 0
+     with torch.inference_mode():
+         for image, target in metric_logger.log_every(data_loader, print_freq, header):
+             image = image.to(device, non_blocking=True)
+             target = target.to(device, non_blocking=True)
+             output = model(image)
+             loss = criterion(output, target)
+
+             acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
+             # FIXME need to take into account that the datasets
+             # could have been padded in distributed setup
+             batch_size = image.shape[0]
+             metric_logger.update(loss=loss.item())
+             metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
+             metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+             num_processed_samples += batch_size
+     # gather the stats from all processes
+
+     num_processed_samples = utils.reduce_across_processes(num_processed_samples)
+     if (
+         hasattr(data_loader.dataset, "__len__")
+         and len(data_loader.dataset) != num_processed_samples
+         and torch.distributed.get_rank() == 0
+     ):
+         # See FIXME above
+         warnings.warn(
+             f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
+             "samples were used for the validation, which might bias the results. "
+             "Try adjusting the batch size and / or the world size. "
+             "Setting the world size to 1 is always a safe bet."
+         )
+
+     metric_logger.synchronize_between_processes()
+
+     print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
+     return metric_logger.acc1.global_avg
+
+
+ def _get_cache_path(filepath):
+     import hashlib
+
+     h = hashlib.sha1(filepath.encode()).hexdigest()
+     cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
+     cache_path = os.path.expanduser(cache_path)
+     return cache_path
+
+
+ def load_data(traindir, valdir, args):
+     # Data loading code
+     print("Loading data")
+     val_resize_size, val_crop_size, train_crop_size = (
+         args.val_resize_size,
+         args.val_crop_size,
+         args.train_crop_size,
+     )
+     interpolation = InterpolationMode(args.interpolation)
+
+     print("Loading training data")
+     st = time.time()
+     cache_path = _get_cache_path(traindir)
+     if args.cache_dataset and os.path.exists(cache_path):
+         # Attention, as the transforms are also cached!
+         print(f"Loading dataset_train from {cache_path}")
+         dataset, _ = torch.load(cache_path)
+     else:
+         auto_augment_policy = getattr(args, "auto_augment", None)
+         random_erase_prob = getattr(args, "random_erase", 0.0)
+         ra_magnitude = args.ra_magnitude
+         augmix_severity = args.augmix_severity
+         dataset = torchvision.datasets.ImageFolder(
+             traindir,
+             presets.ClassificationPresetTrain(
+                 crop_size=train_crop_size,
+                 interpolation=interpolation,
+                 auto_augment_policy=auto_augment_policy,
+                 random_erase_prob=random_erase_prob,
+                 ra_magnitude=ra_magnitude,
+                 augmix_severity=augmix_severity,
+             ),
+         )
+         if args.cache_dataset:
+             print(f"Saving dataset_train to {cache_path}")
+             utils.mkdir(os.path.dirname(cache_path))
+             utils.save_on_master((dataset, traindir), cache_path)
+     print("Took", time.time() - st)
+
+     print("Loading validation data")
+     cache_path = _get_cache_path(valdir)
+     if args.cache_dataset and os.path.exists(cache_path):
+         # Attention, as the transforms are also cached!
+         print(f"Loading dataset_test from {cache_path}")
+         dataset_test, _ = torch.load(cache_path)
+     else:
+         if args.weights and args.test_only:
+             weights = torchvision.models.get_weight(args.weights)
+             preprocessing = weights.transforms()
+         else:
+             preprocessing = presets.ClassificationPresetEval(
+                 crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
+             )
+
+         dataset_test = torchvision.datasets.ImageFolder(
+             valdir,
+             preprocessing,
+         )
+         if args.cache_dataset:
+             print(f"Saving dataset_test to {cache_path}")
+             utils.mkdir(os.path.dirname(cache_path))
+             utils.save_on_master((dataset_test, valdir), cache_path)
+
+     print("Creating data loaders")
+     if args.distributed:
+         if hasattr(args, "ra_sampler") and args.ra_sampler:
+             train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
+         else:
+             train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
+     else:
+         train_sampler = torch.utils.data.RandomSampler(dataset)
+         test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+
+     return dataset, dataset_test, train_sampler, test_sampler
+
+
+ def main(args):
+     if args.output_dir:
+         utils.mkdir(args.output_dir)
+
+     utils.init_distributed_mode(args)
+     print(args)
+
+     device = torch.device(args.device)
+
+     if args.use_deterministic_algorithms:
+         torch.backends.cudnn.benchmark = False
+         torch.use_deterministic_algorithms(True)
+     else:
+         torch.backends.cudnn.benchmark = True
+
+     train_dir = os.path.join(args.data_path, "train")
+     val_dir = os.path.join(args.data_path, "val")
+     dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
+
+     collate_fn = None
+     num_classes = len(dataset.classes)
+     mixup_transforms = []
+     if args.mixup_alpha > 0.0:
+         mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
+     if args.cutmix_alpha > 0.0:
+         mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
+     if mixup_transforms:
+         mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
+
+         def collate_fn(batch):
+             return mixupcutmix(*default_collate(batch))
+
+     data_loader = torch.utils.data.DataLoader(
+         dataset,
+         batch_size=args.batch_size,
+         sampler=train_sampler,
+         num_workers=args.workers,
+         pin_memory=True,
+         collate_fn=collate_fn,
+     )
+     data_loader_test = torch.utils.data.DataLoader(
+         dataset_test, batch_size=8, sampler=test_sampler, num_workers=args.workers, pin_memory=True
+     )
+
+     print("Creating model")
+     model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
+     if args.apply_trp:
+         model = apply_trp(model, args.trp_depths, args.trp_planes, args.trp_lambdas, label_smoothing=args.label_smoothing)
+     model.to(device)
+
+     if args.distributed and args.sync_bn:
+         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
+
+     custom_keys_weight_decay = []
+     if args.bias_weight_decay is not None:
+         custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
+     if args.transformer_embedding_decay is not None:
+         for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
+             custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
+     parameters = utils.set_weight_decay(
+         model,
+         args.weight_decay,
+         norm_weight_decay=args.norm_weight_decay,
+         custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
+     )
+
+     opt_name = args.opt.lower()
+     if opt_name.startswith("sgd"):
+         optimizer = torch.optim.SGD(
+             parameters,
+             lr=args.lr,
+             momentum=args.momentum,
+             weight_decay=args.weight_decay,
+             nesterov="nesterov" in opt_name,
+         )
+     elif opt_name == "rmsprop":
+         optimizer = torch.optim.RMSprop(
+             parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
+         )
+     elif opt_name == "adamw":
+         optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
+     else:
+         raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
+
+     scaler = torch.amp.GradScaler("cuda") if args.amp else None
+
+     args.lr_scheduler = args.lr_scheduler.lower()
+     if args.lr_scheduler == "steplr":
+         main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+     elif args.lr_scheduler == "cosineannealinglr":
+         main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+             optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
+         )
+     elif args.lr_scheduler == "exponentiallr":
+         main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
+     else:
+         raise RuntimeError(
+             f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
+             "are supported."
+         )
+
+     if args.lr_warmup_epochs > 0:
+         if args.lr_warmup_method == "linear":
+             warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
+                 optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
+             )
+         elif args.lr_warmup_method == "constant":
+             warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
+                 optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
+             )
+         else:
+             raise RuntimeError(
+                 f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
+             )
+         lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
+             optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
+         )
+     else:
+         lr_scheduler = main_lr_scheduler
+
+     model_without_ddp = model
+     if args.distributed:
+         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+         model_without_ddp = model.module
+
+     model_ema = None
+     if args.model_ema:
+         # Decay adjustment that aims to keep the decay independent of other hyper-parameters, originally proposed at:
+         # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
+         #
+         # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
+         # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
+         # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
+         adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
+         alpha = 1.0 - args.model_ema_decay
+         alpha = min(1.0, alpha * adjust)
+         model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
+
+     if args.resume:
+         checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
+         model_without_ddp.load_state_dict(checkpoint["model"])
+         if not args.test_only:
+             optimizer.load_state_dict(checkpoint["optimizer"])
+             lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+         args.start_epoch = checkpoint["epoch"] + 1
+         if model_ema:
+             model_ema.load_state_dict(checkpoint["model_ema"])
+         if scaler:
+             scaler.load_state_dict(checkpoint["scaler"])
+
+     if args.test_only:
+         # We disable cudnn benchmarking because it can noticeably affect the accuracy
+         torch.backends.cudnn.benchmark = False
+         torch.backends.cudnn.deterministic = True
+         if model_ema:
+             evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
+         else:
+             evaluate(model, criterion, data_loader_test, device=device)
+         return
+
+     print("Start training")
+     start_time = time.time()
+     for epoch in range(args.start_epoch, args.epochs):
+         if args.distributed:
+             train_sampler.set_epoch(epoch)
+         train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
+         lr_scheduler.step()
+         evaluate(model, criterion, data_loader_test, device=device)
+         if model_ema:
+             evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
+         if args.output_dir:
+             checkpoint = {
+                 "model": model_without_ddp.state_dict() if not args.apply_trp else {k: v for k, v in model_without_ddp.state_dict().items() if not k.startswith("trp_blocks")},  # NOTE: remove TRP heads from the saved weights
+                 "optimizer": optimizer.state_dict(),
+                 "lr_scheduler": lr_scheduler.state_dict(),
+                 "epoch": epoch,
+                 "args": args,
+             }
+             if model_ema:
+                 checkpoint["model_ema"] = model_ema.state_dict() if not args.apply_trp else {k: v for k, v in model_ema.state_dict().items() if not k.startswith("trp_blocks")}  # NOTE: remove TRP heads from the saved weights
+             if scaler:
+                 checkpoint["scaler"] = scaler.state_dict()
+             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
+             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
+
+     total_time = time.time() - start_time
+     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+     print(f"Training time {total_time_str}")
+
+
+ def get_args_parser(add_help=True):
+     import argparse
+
+     parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
+
+     parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
+     parser.add_argument("--model", default="resnet18", type=str, help="model name")
+     parser.add_argument("--device", default="cuda", type=str, help="device (use cuda or cpu; default: cuda)")
+     parser.add_argument(
+         "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
+     )
+     parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
+     parser.add_argument(
+         "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
+     )
+     parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
+     parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
+     parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
+     parser.add_argument(
+         "--wd",
+         "--weight-decay",
+         default=1e-4,
+         type=float,
+         metavar="W",
+         help="weight decay (default: 1e-4)",
+         dest="weight_decay",
+     )
+     parser.add_argument(
+         "--norm-weight-decay",
+         default=None,
+         type=float,
+         help="weight decay for Normalization layers (default: None, same value as --wd)",
+     )
+     parser.add_argument(
+         "--bias-weight-decay",
+         default=None,
+         type=float,
+         help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
+     )
+     parser.add_argument(
+         "--transformer-embedding-decay",
+         default=None,
+         type=float,
+         help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
+     )
+     parser.add_argument(
+         "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
+     )
+     parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
+     parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
+     parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
+     parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
+     parser.add_argument(
+         "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
+     )
+     parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
+     parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
+     parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
+     parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
+     parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
+     parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
+     parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
+     parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
+     parser.add_argument(
+         "--cache-dataset",
+         dest="cache_dataset",
+         help="Cache the datasets for quicker initialization. It also serializes the transforms",
+         action="store_true",
+     )
+     parser.add_argument(
+         "--sync-bn",
+         dest="sync_bn",
+         help="Use sync batch norm",
+         action="store_true",
+     )
+     parser.add_argument(
+         "--test-only",
+         dest="test_only",
+         help="Only test the model",
+         action="store_true",
+     )
+     parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
+     parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
+     parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
+     parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
+
+     # Mixed precision training parameters
+     parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
+     # distributed training parameters
+     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
+     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
+     parser.add_argument(
+         "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
+     )
+     parser.add_argument(
+         "--model-ema-steps",
+         type=int,
+         default=32,
+         help="the number of iterations that controls how often to update the EMA model (default: 32)",
+     )
+     parser.add_argument(
+         "--model-ema-decay",
+         type=float,
+         default=0.99998,
+         help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
+     )
+     parser.add_argument(
+         "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
+     )
+     parser.add_argument(
+         "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
+     )
+     parser.add_argument(
+         "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
+     )
+     parser.add_argument(
+         "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
+     )
+     parser.add_argument(
+         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
+     )
+     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
+     parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
+     parser.add_argument(
+         "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
+     )
+     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+
+     parser.add_argument("--apply-trp", action="store_true", help="enable applying trp")
+     parser.add_argument("--trp-depths", nargs="+", type=int, help="number of layers for trp block")
+     parser.add_argument("--trp-planes", default=1024, type=int, help="channels of the hidden state")
+     parser.add_argument("--trp-lambdas", nargs="+", type=float, help="trp lambdas")
+
+     return parser
+
+
+ if __name__ == "__main__":
+     args = get_args_parser().parse_args()
+     main(args)
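
The EMA decay adjustment in main() is easier to see with concrete numbers. run.sh does not pass --model-ema, but if EMA were enabled with the resnet18 settings there (4 GPUs, batch size 128, 10 epochs) and the parser defaults, the arithmetic works out as in this illustrative sketch:

world_size, batch_size, ema_steps, epochs = 4, 128, 32, 10
model_ema_decay = 0.99998  # parser default

adjust = world_size * batch_size * ema_steps / epochs  # 1638.4
alpha = min(1.0, (1.0 - model_ema_decay) * adjust)     # ~0.0328
decay = 1.0 - alpha                                    # ~0.9672 effective per-update decay
print(adjust, alpha, decay)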
neural-archicture-search/train_quantization.py ADDED
@@ -0,0 +1,265 @@
+ import copy
+ import datetime
+ import os
+ import time
+
+ import torch
+ import torch.ao.quantization
+ import torch.utils.data
+ import torchvision
+ import utils
+ from torch import nn
+ from train import evaluate, load_data, train_one_epoch
+
+
+ def main(args):
+     if args.output_dir:
+         utils.mkdir(args.output_dir)
+
+     utils.init_distributed_mode(args)
+     print(args)
+
+     if args.post_training_quantize and args.distributed:
+         raise RuntimeError("Post training quantization example should not be performed in distributed mode")
+
+     # Set backend engine to ensure that quantized model runs on the correct kernels
+     if args.backend not in torch.backends.quantized.supported_engines:
+         raise RuntimeError("Quantized backend not supported: " + str(args.backend))
+     torch.backends.quantized.engine = args.backend
+
+     device = torch.device(args.device)
+     torch.backends.cudnn.benchmark = True
+
+     # Data loading code
+     print("Loading data")
+     train_dir = os.path.join(args.data_path, "train")
+     val_dir = os.path.join(args.data_path, "val")
+
+     dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
+     data_loader = torch.utils.data.DataLoader(
+         dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
+     )
+
+     data_loader_test = torch.utils.data.DataLoader(
+         dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
+     )
+
+     print("Creating model", args.model)
+     # when training quantized models, we always start from a pre-trained fp32 reference model
+     prefix = "quantized_"
+     model_name = args.model
+     if not model_name.startswith(prefix):
+         model_name = prefix + model_name
+     model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
+     model.to(device)
+
+     if not (args.test_only or args.post_training_quantize):
+         model.fuse_model(is_qat=True)
+         model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+         torch.ao.quantization.prepare_qat(model, inplace=True)
+
+         if args.distributed and args.sync_bn:
+             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+     optimizer = torch.optim.SGD(
+         model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
+     )
+
+     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+
+     criterion = nn.CrossEntropyLoss()
+     model_without_ddp = model
+     if args.distributed:
+         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+         model_without_ddp = model.module
+
+     if args.resume:
+         checkpoint = torch.load(args.resume, map_location="cpu")
+         model_without_ddp.load_state_dict(checkpoint["model"])
+         optimizer.load_state_dict(checkpoint["optimizer"])
+         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+         args.start_epoch = checkpoint["epoch"] + 1
+
+     if args.post_training_quantize:
+         # perform calibration on a subset of the training dataset
+         # for that, create a subset of the training dataset
+         ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
+         data_loader_calibration = torch.utils.data.DataLoader(
+             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
+         )
+         model.eval()
+         model.fuse_model(is_qat=False)
+         model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+         torch.ao.quantization.prepare(model, inplace=True)
+         # Calibrate first
+         print("Calibrating")
+         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
+         torch.ao.quantization.convert(model, inplace=True)
+         if args.output_dir:
+             print("Saving quantized model")
+             if utils.is_main_process():
+                 torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
+         print("Evaluating post-training quantized model")
+         evaluate(model, criterion, data_loader_test, device=device)
+         return
+
+     if args.test_only:
+         evaluate(model, criterion, data_loader_test, device=device)
+         return
+
+     model.apply(torch.ao.quantization.enable_observer)
+     model.apply(torch.ao.quantization.enable_fake_quant)
+     start_time = time.time()
+     for epoch in range(args.start_epoch, args.epochs):
+         if args.distributed:
+             train_sampler.set_epoch(epoch)
+         print("Starting training for epoch", epoch)
+         train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
+         lr_scheduler.step()
+         with torch.inference_mode():
+             if epoch >= args.num_observer_update_epochs:
+                 print("Disabling observer for subseq epochs, epoch = ", epoch)
+                 model.apply(torch.ao.quantization.disable_observer)
+             if epoch >= args.num_batch_norm_update_epochs:
+                 print("Freezing BN for subseq epochs, epoch = ", epoch)
+                 model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
+             print("Evaluate QAT model")
+
+             evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
+             quantized_eval_model = copy.deepcopy(model_without_ddp)
+             quantized_eval_model.eval()
+             quantized_eval_model.to(torch.device("cpu"))
+             torch.ao.quantization.convert(quantized_eval_model, inplace=True)
+
+             print("Evaluate Quantized model")
+             evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
+
+         model.train()
+
+         if args.output_dir:
+             checkpoint = {
+                 "model": model_without_ddp.state_dict(),
+                 "eval_model": quantized_eval_model.state_dict(),
+                 "optimizer": optimizer.state_dict(),
+                 "lr_scheduler": lr_scheduler.state_dict(),
+                 "epoch": epoch,
+                 "args": args,
+             }
+             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
+             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
+         print("Saving models after epoch ", epoch)
+
+     total_time = time.time() - start_time
+     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+     print(f"Training time {total_time_str}")
+
+
+ def get_args_parser(add_help=True):
+     import argparse
+
+     parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)
+
+     parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
+     parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
+     parser.add_argument("--backend", default="qnnpack", type=str, help="fbgemm or qnnpack")
+     parser.add_argument("--device", default="cuda", type=str, help="device (use cuda or cpu; default: cuda)")
+
+     parser.add_argument(
+         "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
+     )
+     parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
+     parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
+     parser.add_argument(
+         "--num-observer-update-epochs",
+         default=4,
+         type=int,
+         metavar="N",
+         help="number of total epochs to update observers",
+     )
+     parser.add_argument(
+         "--num-batch-norm-update-epochs",
+         default=3,
+         type=int,
+         metavar="N",
+         help="number of total epochs to update batch norm stats",
+     )
+     parser.add_argument(
+         "--num-calibration-batches",
+         default=32,
+         type=int,
+         metavar="N",
+         help="number of batches of training set for observer calibration",
+     )
+
+     parser.add_argument(
+         "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
+     )
+     parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
+     parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
+     parser.add_argument(
+         "--wd",
+         "--weight-decay",
+         default=1e-4,
+         type=float,
+         metavar="W",
+         help="weight decay (default: 1e-4)",
+         dest="weight_decay",
+     )
+     parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
+     parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
+     parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
+     parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
+     parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
+     parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
+     parser.add_argument(
+         "--cache-dataset",
+         dest="cache_dataset",
+         help="Cache the datasets for quicker initialization. It also serializes the transforms",
+         action="store_true",
+     )
+     parser.add_argument(
+         "--sync-bn",
+         dest="sync_bn",
+         help="Use sync batch norm",
+         action="store_true",
+     )
+     parser.add_argument(
+         "--test-only",
+         dest="test_only",
+         help="Only test the model",
+         action="store_true",
+     )
+     parser.add_argument(
+         "--post-training-quantize",
+         dest="post_training_quantize",
+         help="Post training quantize the model",
+         action="store_true",
+     )
+
+     # distributed training parameters
+     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
+     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
+
+     parser.add_argument(
+         "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
+     )
+     parser.add_argument(
+         "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
+     )
+     parser.add_argument(
+         "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
+     )
+     parser.add_argument(
+         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
+     )
+     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
+     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+
+     return parser
+
+
+ if __name__ == "__main__":
+     args = get_args_parser().parse_args()
+     main(args)
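
A sketch of a post-training-quantization invocation built from the flags defined above; the data path is illustrative, and PTQ must run non-distributed per the check at the top of main():

from train_quantization import get_args_parser, main

args = get_args_parser().parse_args(
    ["--data-path", "/path/to/imagenet",  # illustrative path to an ImageNet-style folder
     "--model", "mobilenet_v2",
     "--backend", "fbgemm",
     "--device", "cpu",
     "--post-training-quantize"]
)
# main(args)  # calibrates on a training subset, converts to int8, then evaluates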
neural-archicture-search/transforms.py ADDED
@@ -0,0 +1,183 @@
+ import math
+ from typing import Tuple
+
+ import torch
+ from torch import Tensor
+ from torchvision.transforms import functional as F
+
+
+ class RandomMixup(torch.nn.Module):
+     """Randomly apply Mixup to the provided batch and targets.
+     The class implements the data augmentations as described in the paper
+     `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
+
+     Args:
+         num_classes (int): number of classes used for one-hot encoding.
+         p (float): probability of the batch being transformed. Default value is 0.5.
+         alpha (float): hyperparameter of the Beta distribution used for mixup.
+             Default value is 1.0.
+         inplace (bool): boolean to make this transform inplace. Default set to False.
+     """
+
+     def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
+         super().__init__()
+
+         if num_classes < 1:
+             raise ValueError(
+                 f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
+             )
+
+         if alpha <= 0:
+             raise ValueError("Alpha param must be positive.")
+
+         self.num_classes = num_classes
+         self.p = p
+         self.alpha = alpha
+         self.inplace = inplace
+
+     def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+         """
+         Args:
+             batch (Tensor): Float tensor of size (B, C, H, W)
+             target (Tensor): Integer tensor of size (B, )
+
+         Returns:
+             Tuple[Tensor, Tensor]: Randomly transformed batch and target.
+         """
+         if batch.ndim != 4:
+             raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
+         if target.ndim != 1:
+             raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
+         if not batch.is_floating_point():
+             raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
+         if target.dtype != torch.int64:
+             raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
+
+         if not self.inplace:
+             batch = batch.clone()
+             target = target.clone()
+
+         if target.ndim == 1:
+             target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
+
+         if torch.rand(1).item() >= self.p:
+             return batch, target
+
+         # It's faster to roll the batch by one instead of shuffling it to create image pairs
+         batch_rolled = batch.roll(1, 0)
+         target_rolled = target.roll(1, 0)
+
+         # Implemented as in the mixup paper, page 3.
+         lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+         batch_rolled.mul_(1.0 - lambda_param)
+         batch.mul_(lambda_param).add_(batch_rolled)
+
+         target_rolled.mul_(1.0 - lambda_param)
+         target.mul_(lambda_param).add_(target_rolled)
+
+         return batch, target
+
+     def __repr__(self) -> str:
+         s = (
+             f"{self.__class__.__name__}("
+             f"num_classes={self.num_classes}"
+             f", p={self.p}"
+             f", alpha={self.alpha}"
+             f", inplace={self.inplace}"
+             f")"
+         )
+         return s
+
+
+ class RandomCutmix(torch.nn.Module):
+     """Randomly apply Cutmix to the provided batch and targets.
+     The class implements the data augmentations as described in the paper
+     `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
+     <https://arxiv.org/abs/1905.04899>`_.
+
+     Args:
+         num_classes (int): number of classes used for one-hot encoding.
+         p (float): probability of the batch being transformed. Default value is 0.5.
+         alpha (float): hyperparameter of the Beta distribution used for cutmix.
+             Default value is 1.0.
+         inplace (bool): boolean to make this transform inplace. Default set to False.
+     """
+
+     def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
+         super().__init__()
+         if num_classes < 1:
+             raise ValueError("Please provide a valid positive value for the num_classes.")
+         if alpha <= 0:
+             raise ValueError("Alpha param must be positive.")
+
+         self.num_classes = num_classes
+         self.p = p
+         self.alpha = alpha
+         self.inplace = inplace
+
+     def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+         """
+         Args:
+             batch (Tensor): Float tensor of size (B, C, H, W)
+             target (Tensor): Integer tensor of size (B, )
+
+         Returns:
+             Tuple[Tensor, Tensor]: Randomly transformed batch and target.
+         """
+         if batch.ndim != 4:
+             raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
+         if target.ndim != 1:
+             raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
+         if not batch.is_floating_point():
+             raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
+         if target.dtype != torch.int64:
+             raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
+
+         if not self.inplace:
+             batch = batch.clone()
+             target = target.clone()
+
+         if target.ndim == 1:
+             target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
+
+         if torch.rand(1).item() >= self.p:
+             return batch, target
+
+         # It's faster to roll the batch by one instead of shuffling it to create image pairs
+         batch_rolled = batch.roll(1, 0)
+         target_rolled = target.roll(1, 0)
+
+         # Implemented as in the cutmix paper, page 12 (with minor corrections on typos).
+         lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+         _, H, W = F.get_dimensions(batch)
+
+         r_x = torch.randint(W, (1,))
+         r_y = torch.randint(H, (1,))
+
+         r = 0.5 * math.sqrt(1.0 - lambda_param)
+         r_w_half = int(r * W)
+         r_h_half = int(r * H)
+
+         x1 = int(torch.clamp(r_x - r_w_half, min=0))
+         y1 = int(torch.clamp(r_y - r_h_half, min=0))
+         x2 = int(torch.clamp(r_x + r_w_half, max=W))
+         y2 = int(torch.clamp(r_y + r_h_half, max=H))
+
+         batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
+         lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
+
+         target_rolled.mul_(1.0 - lambda_param)
+         target.mul_(lambda_param).add_(target_rolled)
+
+         return batch, target
+
+     def __repr__(self) -> str:
+         s = (
+             f"{self.__class__.__name__}("
+             f"num_classes={self.num_classes}"
+             f", p={self.p}"
+             f", alpha={self.alpha}"
+             f", inplace={self.inplace}"
+             f")"
+         )
+         return s
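
These transforms operate on whole batches rather than single images, which is why train.py applies them inside a collate_fn. A standalone sketch with a random stand-in batch:

import torch
from transforms import RandomMixup

mixup = RandomMixup(num_classes=10, p=1.0, alpha=0.2)
batch = torch.randn(8, 3, 224, 224)
target = torch.randint(0, 10, (8,))

mixed_batch, mixed_target = mixup(batch, target)
print(mixed_target.shape)  # targets become soft labels of shape [8, 10]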
neural-archicture-search/trplib.py ADDED
@@ -0,0 +1,127 @@
+ import types
+ from typing import List, Callable
+
+ import torch
+ from torch import nn, Tensor
+ from torch.nn import functional as F
+ from torchvision.models.resnet import BasicBlock
+
+
+ def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_state: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):
+     losses, rewards = criterion(logits, targets)
+     returns = torch.ones_like(rewards, dtype=torch.float32, device=rewards.device)
+     if loss_normalization:
+         coeff = torch.mean(losses).detach()
+
+     embeds = [hidden_state]
+     predictions = []
+     for k, w in enumerate(lambdas):
+         embeds.append(trp_blocks[k](embeds[-1]))
+         predictions.append(shared_head(embeds[-1]))
+         returns = returns + w * rewards
+         replica_losses, rewards = criterion(predictions[-1], targets, rewards)
+         losses = losses + replica_losses
+     loss = torch.mean(losses * returns)
+
+     if loss_normalization:
+         with torch.no_grad():
+             coeff = torch.exp(coeff) / torch.exp(loss.detach())
+         loss = coeff * loss
+
+     return loss
+
+
+ class TPBlock(nn.Module):
+     def __init__(self, depths: int, inplanes: int, planes: int):
+         super(TPBlock, self).__init__()
+
+         blocks = [BasicBlock(inplanes=inplanes, planes=planes) for _ in range(depths)]
+         self.blocks = nn.Sequential(*blocks)
+         for name, param in self.blocks.named_parameters():
+             if 'conv' in name:
+                 nn.init.zeros_(param)  # zero-init conv weights so each block starts near identity
+             elif 'downsample' in name:
+                 nn.init.zeros_(param)  # zero-init downsample parameters as well
+
+     def forward(self, x):
+         return self.blocks(x)
+
+
+ class ResNetConfig:
+     @staticmethod
+     def gen_criterion(label_smoothing=0.0, top_k=1):
+         def func(input, target, mask=None):
+             """
+             Args:
+                 input (Tensor): Input tensor of shape [B, C].
+                 target (Tensor): Target labels of shape [B] or [B, C].
+                 mask (Tensor, optional): Per-sample mask of shape [B] carried over from the previous replica.
+
+             Returns:
+                 loss (Tensor): Scalar tensor representing the loss.
+                 mask (Tensor): Mask tensor of shape [B] marking samples still predicted correctly (top-k).
+             """
+             label = torch.argmax(target, dim=1) if label_smoothing > 0.0 else target
+
+             unmasked_loss = F.cross_entropy(input, label, reduction="none", label_smoothing=label_smoothing)
+             if mask is None:
+                 mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)
+             loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
+
+             with torch.no_grad():
+                 topk_values, topk_indices = torch.topk(input, top_k, dim=-1)
+                 mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)
+
+             return loss, mask
+         return func
+
+     @staticmethod
+     def gen_shared_head(self):
+         def func(x):
+             """
+             Args:
+                 x (Tensor): Hidden states tensor of shape [B, C, H, W].
+
+             Returns:
+                 logits (Tensor): Logits tensor of shape [B, C].
+             """
+             x = self.layer4(x)
+             x = self.avgpool(x)
+             x = torch.flatten(x, 1)
+             logits = self.fc(x)
+             return logits
+         return func
+
+     @staticmethod
+     def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
+         def func(self, x: Tensor, targets=None) -> Tensor:
+             x = self.conv1(x)
+             x = self.bn1(x)
+             x = self.relu(x)
+             x = self.maxpool(x)
+
+             x = self.layer1(x)
+             x = self.layer2(x)
+             hidden_states = self.layer3(x)
+             x = self.layer4(hidden_states)
+             x = self.avgpool(x)
+             x = torch.flatten(x, 1)
+             logits = self.fc(x)
+
+             if self.training:
+                 shared_head = ResNetConfig.gen_shared_head(self)
+                 criterion = ResNetConfig.gen_criterion(label_smoothing=label_smoothing, top_k=top_k)
+
+                 loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits, targets, loss_normalization=loss_normalization)
+
+                 return logits, loss
+
+             return logits
+
+         return func
+
+
+ def apply_trp(model, depths: List[int], planes: int, lambdas: List[float], **kwargs):
+     print("✅ Applying TRP to ResNet for Image Classification...")
+     model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, inplanes=planes, planes=planes) for d in depths])
+     model.forward = types.MethodType(ResNetConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
+     return model
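
A minimal sketch of apply_trp on a torchvision ResNet, matching the resnet18 command in run.sh (256 is the channel width of layer3's output in resnet18, which is why that command passes --trp-planes 256):

import torch
import torchvision
from trplib import apply_trp

model = torchvision.models.resnet18(num_classes=1000)
model = apply_trp(model, depths=[3, 3, 3], planes=256, lambdas=[0.4, 0.2, 0.1], label_smoothing=0.0)

x = torch.randn(2, 3, 224, 224)
targets = torch.randint(0, 1000, (2,))

model.train()
logits, loss = model(x, targets)  # training mode also returns the TRP loss
model.eval()
logits = model(x)                 # eval mode: plain logits, TRP heads unused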
neural-archicture-search/utils.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import datetime
3
+ import errno
4
+ import hashlib
5
+ import os
6
+ import time
7
+ from collections import defaultdict, deque, OrderedDict
8
+ from typing import List, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+
13
+
+ class SmoothedValue:
+     """Track a series of values and provide access to smoothed values over a
+     window or the global series average.
+     """
+
+     def __init__(self, window_size=20, fmt=None):
+         if fmt is None:
+             fmt = "{median:.4f} ({global_avg:.4f})"
+         self.deque = deque(maxlen=window_size)
+         self.total = 0.0
+         self.count = 0
+         self.fmt = fmt
+
+     def update(self, value, n=1):
+         self.deque.append(value)
+         self.count += n
+         self.total += value * n
+
+     def synchronize_between_processes(self):
+         """
+         Warning: does not synchronize the deque!
+         """
+         t = reduce_across_processes([self.count, self.total])
+         t = t.tolist()
+         self.count = int(t[0])
+         self.total = t[1]
+
+     @property
+     def median(self):
+         d = torch.tensor(list(self.deque))
+         return d.median().item()
+
+     @property
+     def avg(self):
+         d = torch.tensor(list(self.deque), dtype=torch.float32)
+         return d.mean().item()
+
+     @property
+     def global_avg(self):
+         return self.total / self.count
+
+     @property
+     def max(self):
+         return max(self.deque)
+
+     @property
+     def value(self):
+         return self.deque[-1]
+
+     def __str__(self):
+         return self.fmt.format(
+             median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
+         )
+
+
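A quick, deterministic illustration of the smoothing behavior (values made up):

    meter = SmoothedValue(window_size=3, fmt="{median:.2f} ({global_avg:.2f})")
    for v in (1.0, 2.0, 9.0):
        meter.update(v)
    print(str(meter))  # "2.00 (4.00)": median over the window, mean over all updates
    print(meter.max, meter.value)  # 9.0 9.0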
+ class MetricLogger:
+     def __init__(self, delimiter="\t"):
+         self.meters = defaultdict(SmoothedValue)
+         self.delimiter = delimiter
+
+     def update(self, **kwargs):
+         for k, v in kwargs.items():
+             if isinstance(v, torch.Tensor):
+                 v = v.item()
+             assert isinstance(v, (float, int))
+             self.meters[k].update(v)
+
+     def __getattr__(self, attr):
+         if attr in self.meters:
+             return self.meters[attr]
+         if attr in self.__dict__:
+             return self.__dict__[attr]
+         raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
+
+     def __str__(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append(f"{name}: {str(meter)}")
+         return self.delimiter.join(loss_str)
+
+     def synchronize_between_processes(self):
+         for meter in self.meters.values():
+             meter.synchronize_between_processes()
+
+     def add_meter(self, name, meter):
+         self.meters[name] = meter
+
+     def log_every(self, iterable, print_freq, header=None):
+         i = 0
+         if not header:
+             header = ""
+         start_time = time.time()
+         end = time.time()
+         iter_time = SmoothedValue(fmt="{avg:.4f}")
+         data_time = SmoothedValue(fmt="{avg:.4f}")
+         space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+         if torch.cuda.is_available():
+             log_msg = self.delimiter.join(
+                 [
+                     header,
+                     "[{0" + space_fmt + "}/{1}]",
+                     "eta: {eta}",
+                     "{meters}",
+                     "time: {time}",
+                     "data: {data}",
+                     "max mem: {memory:.0f}",
+                 ]
+             )
+         else:
+             log_msg = self.delimiter.join(
+                 [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
+             )
+         MB = 1024.0 * 1024.0
+         for obj in iterable:
+             data_time.update(time.time() - end)
+             yield obj
+             iter_time.update(time.time() - end)
+             if i % print_freq == 0:
+                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                 if torch.cuda.is_available():
+                     print(
+                         log_msg.format(
+                             i,
+                             len(iterable),
+                             eta=eta_string,
+                             meters=str(self),
+                             time=str(iter_time),
+                             data=str(data_time),
+                             memory=torch.cuda.max_memory_allocated() / MB,
+                         )
+                     )
+                 else:
+                     print(
+                         log_msg.format(
+                             i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
+                         )
+                     )
+             i += 1
+             end = time.time()
+         total_time = time.time() - start_time
+         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+         print(f"{header} Total time: {total_time_str}")
+
+
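Typical use in a training loop (a sketch; `data_loader`, `model`, `criterion`, and `optimizer` are placeholders):

    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value}"))
    for images, targets in metric_logger.log_every(data_loader, print_freq=10, header="Epoch: [0]"):
        loss = criterion(model(images), targets)
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])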
+ class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
+     """Maintains moving averages of model parameters using an exponential decay.
+     ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
+     `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
+     is used to compute the EMA.
+     """
+
+     def __init__(self, model, decay, device="cpu"):
+         def ema_avg(avg_model_param, model_param, num_averaged):
+             return decay * avg_model_param + (1 - decay) * model_param
+
+         super().__init__(model, device, ema_avg, use_buffers=True)
+
+
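A sketch of the intended update loop (cadence and decay are hyperparameters, not fixed by this class):

    model_ema = ExponentialMovingAverage(model, decay=0.9999, device="cpu")
    for step, (images, targets) in enumerate(data_loader):
        ...  # forward / backward / optimizer.step()
        if step % 32 == 0:
            model_ema.update_parameters(model)  # inherited from AveragedModel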
+ def accuracy(output, target, topk=(1,)):
+     """Computes the accuracy over the k top predictions for the specified values of k"""
+     with torch.inference_mode():
+         maxk = max(topk)
+         batch_size = target.size(0)
+         if target.ndim == 2:
+             target = target.max(dim=1)[1]
+
+         _, pred = output.topk(maxk, 1, True, True)
+         pred = pred.t()
+         correct = pred.eq(target[None])
+
+         res = []
+         for k in topk:
+             correct_k = correct[:k].flatten().sum(dtype=torch.float32)
+             res.append(correct_k * (100.0 / batch_size))
+         return res
+
+
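A worked example (made-up logits; the returned values are percentage tensors):

    logits = torch.tensor([[0.1, 0.7, 0.2],
                           [0.8, 0.15, 0.05]])
    labels = torch.tensor([1, 2])
    acc1, acc2 = accuracy(logits, labels, topk=(1, 2))
    # acc1 == 50.0: only the first sample's top-1 prediction is correct.
    # acc2 == 50.0: the second sample's true class is ranked third.
    # In the training scripts the usual call is topk=(1, 5).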
+ def mkdir(path):
+     try:
+         os.makedirs(path)
+     except OSError as e:
+         if e.errno != errno.EEXIST:
+             raise
+
+
+ def setup_for_distributed(is_master):
+     """
+     This function disables printing when not in master process
+     """
+     import builtins as __builtin__
+
+     builtin_print = __builtin__.print
+
+     def print(*args, **kwargs):
+         force = kwargs.pop("force", False)
+         if is_master or force:
+             builtin_print(*args, **kwargs)
+
+     __builtin__.print = print
+
+
+ def is_dist_avail_and_initialized():
+     if not dist.is_available():
+         return False
+     if not dist.is_initialized():
+         return False
+     return True
+
+
+ def get_world_size():
+     if not is_dist_avail_and_initialized():
+         return 1
+     return dist.get_world_size()
+
+
+ def get_rank():
+     if not is_dist_avail_and_initialized():
+         return 0
+     return dist.get_rank()
+
+
+ def is_main_process():
+     return get_rank() == 0
+
+
+ def save_on_master(*args, **kwargs):
+     if is_main_process():
+         torch.save(*args, **kwargs)
+
+
+ def init_distributed_mode(args):
+     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+         args.rank = int(os.environ["RANK"])
+         args.world_size = int(os.environ["WORLD_SIZE"])
+         args.gpu = int(os.environ["LOCAL_RANK"])
+     elif "SLURM_PROCID" in os.environ:
+         args.rank = int(os.environ["SLURM_PROCID"])
+         args.gpu = args.rank % torch.cuda.device_count()
+     elif hasattr(args, "rank"):
+         pass
+     else:
+         print("Not using distributed mode")
+         args.distributed = False
+         return
+
+     args.distributed = True
+
+     torch.cuda.set_device(args.gpu)
+     args.dist_backend = "nccl"
+     print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
+     torch.distributed.init_process_group(
+         backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
+     )
+     torch.distributed.barrier()
+     setup_for_distributed(args.rank == 0)
+
+
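Under `torchrun`, `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` are set in the environment, so a training script only needs a sketch like this (the `--dist-url` default of "env://" is the usual PyTorch convention, not dictated by this file):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dist-url", default="env://", type=str)
    args = parser.parse_args()
    init_distributed_mode(args)  # sets args.rank, args.gpu, args.distributed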
+ def average_checkpoints(inputs):
+     """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
+     https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
+
+     Args:
+         inputs (List[str]): An iterable of string paths of checkpoints to load from.
+     Returns:
+         A dict of string keys mapping to various values. The 'model' key
+         from the returned dict should correspond to an OrderedDict mapping
+         string parameter names to torch Tensors.
+     """
+     params_dict = OrderedDict()
+     params_keys = None
+     new_state = None
+     num_models = len(inputs)
+     for fpath in inputs:
+         with open(fpath, "rb") as f:
+             state = torch.load(
+                 f,
+                 map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
+             )
+         # Copies over the settings from the first checkpoint
+         if new_state is None:
+             new_state = state
+         model_params = state["model"]
+         model_params_keys = list(model_params.keys())
+         if params_keys is None:
+             params_keys = model_params_keys
+         elif params_keys != model_params_keys:
+             raise KeyError(
+                 f"For checkpoint {fpath}, expected list of params: {params_keys}, but found: {model_params_keys}"
+             )
+         for k in params_keys:
+             p = model_params[k]
+             if isinstance(p, torch.HalfTensor):
+                 p = p.float()
+             if k not in params_dict:
+                 # NOTE: clone() is needed in case p is a shared parameter
+                 params_dict[k] = p.clone()
+             else:
+                 params_dict[k] += p
+     averaged_params = OrderedDict()
+     for k, v in params_dict.items():
+         averaged_params[k] = v
+         if averaged_params[k].is_floating_point():
+             averaged_params[k].div_(num_models)
+         else:
+             averaged_params[k] //= num_models
+     new_state["model"] = averaged_params
+     return new_state
+
+
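A hedged usage sketch (paths hypothetical; each checkpoint must be a dict with a "model" key, as the training script produces):

    avg_state = average_checkpoints(["ckpt_epoch8.pth", "ckpt_epoch9.pth"])
    torch.save(avg_state, "ckpt_avg.pth")
    model.load_state_dict(avg_state["model"])  # `model` is a matching architecture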
+ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
+     """
+     This method can be used to prepare weights files for new models. It receives as
+     input a model architecture and a checkpoint from the training script and produces
+     a file with the weights ready for release.
+
+     Examples:
+         from torchvision import models as M
+
+         # Classification
+         model = M.mobilenet_v3_large(weights=None)
+         print(store_model_weights(model, './class.pth'))
+
+         # Quantized Classification
+         model = M.quantization.mobilenet_v3_large(weights=None, quantize=False)
+         model.fuse_model(is_qat=True)
+         model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
+         _ = torch.ao.quantization.prepare_qat(model, inplace=True)
+         print(store_model_weights(model, './qat.pth'))
+
+         # Object Detection
+         model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
+         print(store_model_weights(model, './obj.pth'))
+
+         # Segmentation
+         model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True)
+         print(store_model_weights(model, './segm.pth', strict=False))
+
+     Args:
+         model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes.
+         checkpoint_path (str): The path of the checkpoint we will load.
+         checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
+             Default: "model".
+         strict (bool): whether to strictly enforce that the keys
+             in :attr:`state_dict` match the keys returned by this module's
+             :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+
+     Returns:
+         output_path (str): The location where the weights are saved.
+     """
+     # Store the new model next to the checkpoint_path
+     checkpoint_path = os.path.abspath(checkpoint_path)
+     output_dir = os.path.dirname(checkpoint_path)
+
+     # Deep copy to avoid side-effects on the model object.
+     model = copy.deepcopy(model)
+     checkpoint = torch.load(checkpoint_path, map_location="cpu")
+
+     # Load the weights to the model to validate that everything works
+     # and remove unnecessary weights (such as auxiliaries, etc.)
+     if checkpoint_key == "model_ema":
+         del checkpoint[checkpoint_key]["n_averaged"]
+         torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
+     model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
+
+     tmp_path = os.path.join(output_dir, str(model.__hash__()))
+     torch.save(model.state_dict(), tmp_path)
+
+     sha256_hash = hashlib.sha256()
+     with open(tmp_path, "rb") as f:
+         # Read and update hash string value in blocks of 4K
+         for byte_block in iter(lambda: f.read(4096), b""):
+             sha256_hash.update(byte_block)
+         hh = sha256_hash.hexdigest()
+
+     output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
+     os.replace(tmp_path, output_path)
+
+     return output_path
+
+
+ def reduce_across_processes(val):
+     if not is_dist_avail_and_initialized():
+         # nothing to sync, but we still convert to tensor for consistency with the distributed case.
+         return torch.tensor(val)
+
+     t = torch.tensor(val, device="cuda")
+     dist.barrier()
+     dist.all_reduce(t)
+     return t
+
+
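For instance, on a single process this just wraps the value; under DDP every rank would receive the elementwise sum across ranks:

    t = reduce_across_processes([3, 5])
    print(t)  # tensor([3, 5]) when not distributed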
+ def set_weight_decay(
+     model: torch.nn.Module,
+     weight_decay: float,
+     norm_weight_decay: Optional[float] = None,
+     norm_classes: Optional[List[type]] = None,
+     custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
+ ):
+     if not norm_classes:
+         norm_classes = [
+             torch.nn.modules.batchnorm._BatchNorm,
+             torch.nn.LayerNorm,
+             torch.nn.GroupNorm,
+             torch.nn.modules.instancenorm._InstanceNorm,
+             torch.nn.LocalResponseNorm,
+         ]
+     norm_classes = tuple(norm_classes)
+
+     params = {
+         "other": [],
+         "norm": [],
+     }
+     params_weight_decay = {
+         "other": weight_decay,
+         "norm": norm_weight_decay,
+     }
+     custom_keys = []
+     if custom_keys_weight_decay is not None:
+         for key, weight_decay in custom_keys_weight_decay:
+             params[key] = []
+             params_weight_decay[key] = weight_decay
+             custom_keys.append(key)
+
+     def _add_params(module, prefix=""):
+         for name, p in module.named_parameters(recurse=False):
+             if not p.requires_grad:
+                 continue
+             is_custom_key = False
+             for key in custom_keys:
+                 target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
+                 if key == target_name:
+                     params[key].append(p)
+                     is_custom_key = True
+                     break
+             if not is_custom_key:
+                 if norm_weight_decay is not None and isinstance(module, norm_classes):
+                     params["norm"].append(p)
+                 else:
+                     params["other"].append(p)
+
+         for child_name, child_module in module.named_children():
+             child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
+             _add_params(child_module, prefix=child_prefix)
+
+     _add_params(model)
+
+     param_groups = []
+     for key in params:
+         if len(params[key]) > 0:
+             param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
+     return param_groups
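Finally, a short sketch of feeding the returned groups to an optimizer so that norm layers skip weight decay (hyperparameter values illustrative):

    param_groups = set_weight_decay(model, weight_decay=2e-5, norm_weight_decay=0.0)
    optimizer = torch.optim.SGD(param_groups, lr=0.5, momentum=0.9)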