UniversalAlgorithmic committed on
Commit 781a945 · verified · 1 Parent(s): 9505a51

Delete neural-archicture-search

neural-archicture-search/presets.py DELETED
@@ -1,71 +0,0 @@
- import torch
- from torchvision.transforms import autoaugment, transforms
- from torchvision.transforms.functional import InterpolationMode
-
-
- class ClassificationPresetTrain:
-     def __init__(
-         self,
-         *,
-         crop_size,
-         mean=(0.485, 0.456, 0.406),
-         std=(0.229, 0.224, 0.225),
-         interpolation=InterpolationMode.BILINEAR,
-         hflip_prob=0.5,
-         auto_augment_policy=None,
-         ra_magnitude=9,
-         augmix_severity=3,
-         random_erase_prob=0.0,
-     ):
-         trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
-         if hflip_prob > 0:
-             trans.append(transforms.RandomHorizontalFlip(hflip_prob))
-         if auto_augment_policy is not None:
-             if auto_augment_policy == "ra":
-                 trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
-             elif auto_augment_policy == "ta_wide":
-                 trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
-             elif auto_augment_policy == "augmix":
-                 trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
-             else:
-                 aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
-                 trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
-         trans.extend(
-             [
-                 transforms.PILToTensor(),
-                 transforms.ConvertImageDtype(torch.float),
-                 transforms.Normalize(mean=mean, std=std),
-             ]
-         )
-         if random_erase_prob > 0:
-             trans.append(transforms.RandomErasing(p=random_erase_prob))
-
-         self.transforms = transforms.Compose(trans)
-
-     def __call__(self, img):
-         return self.transforms(img)
-
-
- class ClassificationPresetEval:
-     def __init__(
-         self,
-         *,
-         crop_size,
-         resize_size=256,
-         mean=(0.485, 0.456, 0.406),
-         std=(0.229, 0.224, 0.225),
-         interpolation=InterpolationMode.BILINEAR,
-     ):
-
-         self.transforms = transforms.Compose(
-             [
-                 transforms.Resize(resize_size, interpolation=interpolation),
-                 transforms.CenterCrop(crop_size),
-                 transforms.PILToTensor(),
-                 transforms.ConvertImageDtype(torch.float),
-                 transforms.Normalize(mean=mean, std=std),
-             ]
-         )
-
-     def __call__(self, img):
-         return self.transforms(img)
 
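For reference, a minimal usage sketch of the two presets above (assuming presets.py is importable; the image path and crop sizes are only illustrative):

    from PIL import Image
    from presets import ClassificationPresetTrain, ClassificationPresetEval

    train_tf = ClassificationPresetTrain(crop_size=224, auto_augment_policy="ta_wide", random_erase_prob=0.1)
    eval_tf = ClassificationPresetEval(crop_size=224, resize_size=256)

    img = Image.open("example.jpg").convert("RGB")  # placeholder image path
    x_train = train_tf(img)  # normalized float tensor of shape [3, 224, 224]
    x_eval = eval_tf(img)

Both presets return normalized float tensors, so train.py can pass them directly to torchvision.datasets.ImageFolder as the transform argument.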
neural-archicture-search/resnet18/model_3.pth DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:e728a634490a078e1a672f464b9baebc04774f83b03fc251ad2437a2731330a0
- size 136133334
 
 
 
 
neural-archicture-search/resnet34/model_8.pth DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:179a2d39490980a6cc801c3ef15230bfe08d7e941174d79e2099c8db8b11dfcf
- size 202898970
 
 
 
 
neural-archicture-search/resnet50/model_9.pth DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:fecb231a0e220e46dde1025a7403bb1f587d81f6d36d143ba1c510c3b477a122
- size 431365452
 
 
 
 
neural-archicture-search/run.sh DELETED
@@ -1,33 +0,0 @@
- # # ✅ Test: Acc@1 70.092 Acc@5 89.314
- # torchrun --nproc_per_node=4 train.py\
- # --data-path /home/cs/Documents/datasets/imagenet\
- # --model resnet18 --output-dir resnet18 --weights ResNet18_Weights.IMAGENET1K_V1\
- # --batch-size 128 --epochs 10 --lr 0.0004 --lr-step-size 2 --lr-gamma 0.5\
- # --lr-warmup-method constant --lr-warmup-epochs 1 --lr-warmup-decay 0.\
- # --apply-trp --trp-depths 3 3 3 --trp-planes 256 --trp-lambdas 0.4 0.2 0.1 --print-freq 100
- torchrun --nproc_per_node=4 train.py\
- --data-path /home/cs/Documents/datasets/imagenet\
- --model resnet18 --resume resnet18/model_3.pth --test-only
-
- # # ✅ Test: Acc@1 73.900 Acc@5 91.536
- # torchrun --nproc_per_node=4 train.py\
- # --data-path /home/cs/Documents/datasets/imagenet\
- # --model resnet34 --output-dir resnet34 --weights ResNet34_Weights.IMAGENET1K_V1\
- # --batch-size 96 --epochs 10 --lr 0.0004 --lr-step-size 2 --lr-gamma 0.5\
- # --lr-warmup-method constant --lr-warmup-epochs 1 --lr-warmup-decay 0.\
- # --apply-trp --trp-depths 2 2 2 --trp-planes 256 --trp-lambdas 0.4 0.2 0.1 --print-freq 100
- # torchrun --nproc_per_node=4 train.py\
- # --data-path /home/cs/Documents/datasets/imagenet\
- # --model resnet34 --resume resnet34/model_8.pth --test-only
-
-
- # # ✅ Test: Acc@1 76.896 Acc@5 93.136
- # torchrun --nproc_per_node=4 train.py\
- # --data-path /home/cs/Documents/datasets/imagenet\
- # --model resnet50 --output-dir resnet50 --weights ResNet50_Weights.IMAGENET1K_V1\
- # --batch-size 64 --epochs 10 --lr 0.0004 --lr-step-size 2 --lr-gamma 0.5\
- # --lr-warmup-method constant --lr-warmup-epochs 1 --lr-warmup-decay 0.\
- # --apply-trp --trp-depths 1 1 1 --trp-planes 1024 --trp-lambdas 0.4 0.2 0.1 --print-freq 100
- # torchrun --nproc_per_node=4 train.py\
- # --data-path /home/cs/Documents/datasets/imagenet\
- # --model resnet50 --resume resnet50/model_9.pth --test-only
 
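The --apply-trp flags in the commented-out training commands map directly onto the apply_trp call in train.py; a hedged sketch of the equivalent Python for the resnet18 configuration above:

    import torchvision
    from trplib import apply_trp

    # --apply-trp --trp-depths 3 3 3 --trp-planes 256 --trp-lambdas 0.4 0.2 0.1
    model = torchvision.models.get_model("resnet18", weights="ResNet18_Weights.IMAGENET1K_V1", num_classes=1000)
    model = apply_trp(model, [3, 3, 3], 256, [0.4, 0.2, 0.1], label_smoothing=0.0)

The uncommented command simply resumes resnet18/model_3.pth and runs evaluation (--test-only).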
neural-archicture-search/sampler.py DELETED
@@ -1,62 +0,0 @@
- import math
-
- import torch
- import torch.distributed as dist
-
-
- class RASampler(torch.utils.data.Sampler):
-     """Sampler that restricts data loading to a subset of the dataset for distributed training,
-     with repeated augmentation.
-     It ensures that each augmented version of a sample will be visible to a
-     different process (GPU).
-     Heavily based on 'torch.utils.data.DistributedSampler'.
-
-     This is borrowed from the DeiT Repo:
-     https://github.com/facebookresearch/deit/blob/main/samplers.py
-     """
-
-     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
-         if num_replicas is None:
-             if not dist.is_available():
-                 raise RuntimeError("Requires distributed package to be available!")
-             num_replicas = dist.get_world_size()
-         if rank is None:
-             if not dist.is_available():
-                 raise RuntimeError("Requires distributed package to be available!")
-             rank = dist.get_rank()
-         self.dataset = dataset
-         self.num_replicas = num_replicas
-         self.rank = rank
-         self.epoch = 0
-         self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
-         self.total_size = self.num_samples * self.num_replicas
-         self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
-         self.shuffle = shuffle
-         self.seed = seed
-         self.repetitions = repetitions
-
-     def __iter__(self):
-         if self.shuffle:
-             # Deterministically shuffle based on epoch
-             g = torch.Generator()
-             g.manual_seed(self.seed + self.epoch)
-             indices = torch.randperm(len(self.dataset), generator=g).tolist()
-         else:
-             indices = list(range(len(self.dataset)))
-
-         # Repeat each index, then add extra samples to make it evenly divisible
-         indices = [ele for ele in indices for i in range(self.repetitions)]
-         indices += indices[: (self.total_size - len(indices))]
-         assert len(indices) == self.total_size
-
-         # Subsample
-         indices = indices[self.rank : self.total_size : self.num_replicas]
-         assert len(indices) == self.num_samples
-
-         return iter(indices[: self.num_selected_samples])
-
-     def __len__(self):
-         return self.num_selected_samples
-
-     def set_epoch(self, epoch):
-         self.epoch = epoch
 
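A small worked example of the index arithmetic in __iter__ (the dataset size, world size, and rank below are made up for illustration):

    import math

    dataset_len, num_replicas, repetitions, rank = 10, 4, 3, 0  # toy numbers
    num_samples = int(math.ceil(dataset_len * repetitions / num_replicas))  # 8
    total_size = num_samples * num_replicas                                 # 32
    indices = [i for i in range(dataset_len) for _ in range(repetitions)]   # 30 repeated indices
    indices += indices[: total_size - len(indices)]                         # pad to 32
    per_rank = indices[rank:total_size:num_replicas]                        # 8 indices for rank 0
    print(per_rank)

num_selected_samples then truncates each rank's list to roughly len(dataset) / num_replicas entries per epoch (the floor(... // 256 * 256 ...) term is 0 for this toy size, but not for ImageNet-scale datasets), so every epoch still sees about one dataset's worth of repeated-augmentation samples.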
neural-archicture-search/train.py DELETED
@@ -1,524 +0,0 @@
1
- import datetime
2
- import os
3
- import time
4
- import warnings
5
-
6
- import presets
7
- import torch
8
- import torch.utils.data
9
- import torchvision
10
- import transforms
11
- import utils
12
- from sampler import RASampler
13
- from torch import nn
14
- from torch.utils.data.dataloader import default_collate
15
- from torchvision.transforms.functional import InterpolationMode
16
-
17
- from trplib import apply_trp
18
-
19
-
20
- def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
21
- model.train()
22
- metric_logger = utils.MetricLogger(delimiter=" ")
23
- metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
24
- metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
25
-
26
- header = f"Epoch: [{epoch}]"
27
- for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
28
- start_time = time.time()
29
- image, target = image.to(device), target.to(device)
30
- with torch.amp.autocast("cuda", enabled=scaler is not None):
31
- # output = model(image)
32
- # loss = criterion(output, target)
33
- output, loss = model(image, target)
34
-
35
- optimizer.zero_grad()
36
- if scaler is not None:
37
- scaler.scale(loss).backward()
38
- if args.clip_grad_norm is not None:
39
- # we should unscale the gradients of optimizer's assigned params if do gradient clipping
40
- scaler.unscale_(optimizer)
41
- nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
42
- scaler.step(optimizer)
43
- scaler.update()
44
- else:
45
- loss.backward()
46
- if args.clip_grad_norm is not None:
47
- nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
48
- optimizer.step()
49
-
50
- if model_ema and i % args.model_ema_steps == 0:
51
- model_ema.update_parameters(model)
52
- if epoch < args.lr_warmup_epochs:
53
- # Reset ema buffer to keep copying weights during warmup period
54
- model_ema.n_averaged.fill_(0)
55
-
56
- acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
57
- batch_size = image.shape[0]
58
- metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
59
- metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
60
- metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
61
- metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
62
-
63
-
64
- def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
65
- model.eval()
66
- metric_logger = utils.MetricLogger(delimiter=" ")
67
- header = f"Test: {log_suffix}"
68
-
69
- num_processed_samples = 0
70
- with torch.inference_mode():
71
- for image, target in metric_logger.log_every(data_loader, print_freq, header):
72
- image = image.to(device, non_blocking=True)
73
- target = target.to(device, non_blocking=True)
74
- output = model(image)
75
- loss = criterion(output, target)
76
-
77
- acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
78
- # FIXME need to take into account that the datasets
79
- # could have been padded in distributed setup
80
- batch_size = image.shape[0]
81
- metric_logger.update(loss=loss.item())
82
- metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
83
- metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
84
- num_processed_samples += batch_size
85
- # gather the stats from all processes
86
-
87
- num_processed_samples = utils.reduce_across_processes(num_processed_samples)
88
- if (
89
- hasattr(data_loader.dataset, "__len__")
90
- and len(data_loader.dataset) != num_processed_samples
91
- and torch.distributed.get_rank() == 0
92
- ):
93
- # See FIXME above
94
- warnings.warn(
95
- f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
96
- "samples were used for the validation, which might bias the results. "
97
- "Try adjusting the batch size and / or the world size. "
98
- "Setting the world size to 1 is always a safe bet."
99
- )
100
-
101
- metric_logger.synchronize_between_processes()
102
-
103
- print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
104
- return metric_logger.acc1.global_avg
105
-
106
-
107
- def _get_cache_path(filepath):
108
- import hashlib
109
-
110
- h = hashlib.sha1(filepath.encode()).hexdigest()
111
- cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
112
- cache_path = os.path.expanduser(cache_path)
113
- return cache_path
114
-
115
-
116
- def load_data(traindir, valdir, args):
117
- # Data loading code
118
- print("Loading data")
119
- val_resize_size, val_crop_size, train_crop_size = (
120
- args.val_resize_size,
121
- args.val_crop_size,
122
- args.train_crop_size,
123
- )
124
- interpolation = InterpolationMode(args.interpolation)
125
-
126
- print("Loading training data")
127
- st = time.time()
128
- cache_path = _get_cache_path(traindir)
129
- if args.cache_dataset and os.path.exists(cache_path):
130
- # Attention, as the transforms are also cached!
131
- print(f"Loading dataset_train from {cache_path}")
132
- dataset, _ = torch.load(cache_path)
133
- else:
134
- auto_augment_policy = getattr(args, "auto_augment", None)
135
- random_erase_prob = getattr(args, "random_erase", 0.0)
136
- ra_magnitude = args.ra_magnitude
137
- augmix_severity = args.augmix_severity
138
- dataset = torchvision.datasets.ImageFolder(
139
- traindir,
140
- presets.ClassificationPresetTrain(
141
- crop_size=train_crop_size,
142
- interpolation=interpolation,
143
- auto_augment_policy=auto_augment_policy,
144
- random_erase_prob=random_erase_prob,
145
- ra_magnitude=ra_magnitude,
146
- augmix_severity=augmix_severity,
147
- ),
148
- )
149
- if args.cache_dataset:
150
- print(f"Saving dataset_train to {cache_path}")
151
- utils.mkdir(os.path.dirname(cache_path))
152
- utils.save_on_master((dataset, traindir), cache_path)
153
- print("Took", time.time() - st)
154
-
155
- print("Loading validation data")
156
- cache_path = _get_cache_path(valdir)
157
- if args.cache_dataset and os.path.exists(cache_path):
158
- # Attention, as the transforms are also cached!
159
- print(f"Loading dataset_test from {cache_path}")
160
- dataset_test, _ = torch.load(cache_path)
161
- else:
162
- if args.weights and args.test_only:
163
- weights = torchvision.models.get_weight(args.weights)
164
- preprocessing = weights.transforms()
165
- else:
166
- preprocessing = presets.ClassificationPresetEval(
167
- crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
168
- )
169
-
170
- dataset_test = torchvision.datasets.ImageFolder(
171
- valdir,
172
- preprocessing,
173
- )
174
- if args.cache_dataset:
175
- print(f"Saving dataset_test to {cache_path}")
176
- utils.mkdir(os.path.dirname(cache_path))
177
- utils.save_on_master((dataset_test, valdir), cache_path)
178
-
179
- print("Creating data loaders")
180
- if args.distributed:
181
- if hasattr(args, "ra_sampler") and args.ra_sampler:
182
- train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
183
- else:
184
- train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
185
- test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
186
- else:
187
- train_sampler = torch.utils.data.RandomSampler(dataset)
188
- test_sampler = torch.utils.data.SequentialSampler(dataset_test)
189
-
190
- return dataset, dataset_test, train_sampler, test_sampler
191
-
192
-
193
- def main(args):
194
- if args.output_dir:
195
- utils.mkdir(args.output_dir)
196
-
197
- utils.init_distributed_mode(args)
198
- print(args)
199
-
200
- device = torch.device(args.device)
201
-
202
- if args.use_deterministic_algorithms:
203
- torch.backends.cudnn.benchmark = False
204
- torch.use_deterministic_algorithms(True)
205
- else:
206
- torch.backends.cudnn.benchmark = True
207
-
208
- train_dir = os.path.join(args.data_path, "train")
209
- val_dir = os.path.join(args.data_path, "val")
210
- dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
211
-
212
- collate_fn = None
213
- num_classes = len(dataset.classes)
214
- mixup_transforms = []
215
- if args.mixup_alpha > 0.0:
216
- mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
217
- if args.cutmix_alpha > 0.0:
218
- mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
219
- if mixup_transforms:
220
- mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
221
-
222
- def collate_fn(batch):
223
- return mixupcutmix(*default_collate(batch))
224
-
225
- data_loader = torch.utils.data.DataLoader(
226
- dataset,
227
- batch_size=args.batch_size,
228
- sampler=train_sampler,
229
- num_workers=args.workers,
230
- pin_memory=True,
231
- collate_fn=collate_fn,
232
- )
233
- data_loader_test = torch.utils.data.DataLoader(
234
- dataset_test, batch_size=8, sampler=test_sampler, num_workers=args.workers, pin_memory=True
235
- )
236
-
237
- print("Creating model")
238
- model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
239
- if args.apply_trp:
240
- model = apply_trp(model, args.trp_depths, args.trp_planes, args.trp_lambdas, label_smoothing=args.label_smoothing)
241
- model.to(device)
242
-
243
- if args.distributed and args.sync_bn:
244
- model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
245
-
246
- criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
247
-
248
- custom_keys_weight_decay = []
249
- if args.bias_weight_decay is not None:
250
- custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
251
- if args.transformer_embedding_decay is not None:
252
- for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
253
- custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
254
- parameters = utils.set_weight_decay(
255
- model,
256
- args.weight_decay,
257
- norm_weight_decay=args.norm_weight_decay,
258
- custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
259
- )
260
-
261
- opt_name = args.opt.lower()
262
- if opt_name.startswith("sgd"):
263
- optimizer = torch.optim.SGD(
264
- parameters,
265
- lr=args.lr,
266
- momentum=args.momentum,
267
- weight_decay=args.weight_decay,
268
- nesterov="nesterov" in opt_name,
269
- )
270
- elif opt_name == "rmsprop":
271
- optimizer = torch.optim.RMSprop(
272
- parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
273
- )
274
- elif opt_name == "adamw":
275
- optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
276
- else:
277
- raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
278
-
279
- scaler = torch.amp.GradScaler("cuda") if args.amp else None
280
-
281
- args.lr_scheduler = args.lr_scheduler.lower()
282
- if args.lr_scheduler == "steplr":
283
- main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
284
- elif args.lr_scheduler == "cosineannealinglr":
285
- main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
286
- optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
287
- )
288
- elif args.lr_scheduler == "exponentiallr":
289
- main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
290
- else:
291
- raise RuntimeError(
292
- f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
293
- "are supported."
294
- )
295
-
296
- if args.lr_warmup_epochs > 0:
297
- if args.lr_warmup_method == "linear":
298
- warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
299
- optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
300
- )
301
- elif args.lr_warmup_method == "constant":
302
- warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
303
- optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
304
- )
305
- else:
306
- raise RuntimeError(
307
- f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
308
- )
309
- lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
310
- optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
311
- )
312
- else:
313
- lr_scheduler = main_lr_scheduler
314
-
315
- model_without_ddp = model
316
- if args.distributed:
317
- model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
318
- model_without_ddp = model.module
319
-
320
- model_ema = None
321
- if args.model_ema:
322
- # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
323
- # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
324
- #
325
- # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
326
- # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
327
- # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
328
- adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
329
- alpha = 1.0 - args.model_ema_decay
330
- alpha = min(1.0, alpha * adjust)
331
- model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
332
-
333
- if args.resume:
334
- checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
335
- model_without_ddp.load_state_dict(checkpoint["model"])
336
- if not args.test_only:
337
- optimizer.load_state_dict(checkpoint["optimizer"])
338
- lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
339
- args.start_epoch = checkpoint["epoch"] + 1
340
- if model_ema:
341
- model_ema.load_state_dict(checkpoint["model_ema"])
342
- if scaler:
343
- scaler.load_state_dict(checkpoint["scaler"])
344
-
345
- if args.test_only:
346
- # We disable the cudnn benchmarking because it can noticeably affect the accuracy
347
- torch.backends.cudnn.benchmark = False
348
- torch.backends.cudnn.deterministic = True
349
- if model_ema:
350
- evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
351
- else:
352
- evaluate(model, criterion, data_loader_test, device=device)
353
- return
354
-
355
- print("Start training")
356
- start_time = time.time()
357
- for epoch in range(args.start_epoch, args.epochs):
358
- if args.distributed:
359
- train_sampler.set_epoch(epoch)
360
- train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
361
- lr_scheduler.step()
362
- evaluate(model, criterion, data_loader_test, device=device)
363
- if model_ema:
364
- evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
365
- if args.output_dir:
366
- checkpoint = {
367
- "model": model_without_ddp.state_dict() if not args.apply_trp else {k: v for k, v in model_without_ddp.state_dict().items() if not k.startswith("trp_blocks")}, # NOTE: remove TRP heads
368
- "optimizer": optimizer.state_dict(),
369
- "lr_scheduler": lr_scheduler.state_dict(),
370
- "epoch": epoch,
371
- "args": args,
372
- }
373
- if model_ema:
374
- checkpoint["model_ema"] = model_ema.state_dict() if not args.apply_trp else {k: v for k, v in model_ema.state_dict().items() if not k.startswith("trp_blocks")} # NOTE: remove TRP heads
375
- if scaler:
376
- checkpoint["scaler"] = scaler.state_dict()
377
- utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
378
- utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
379
-
380
- total_time = time.time() - start_time
381
- total_time_str = str(datetime.timedelta(seconds=int(total_time)))
382
- print(f"Training time {total_time_str}")
383
-
384
-
385
- def get_args_parser(add_help=True):
386
- import argparse
387
-
388
- parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
389
-
390
- parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
391
- parser.add_argument("--model", default="resnet18", type=str, help="model name")
392
- parser.add_argument("--device", default="cuda", type=str, help="device to use (cuda or cpu, default: cuda)")
393
- parser.add_argument(
394
- "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
395
- )
396
- parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
397
- parser.add_argument(
398
- "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
399
- )
400
- parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
401
- parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
402
- parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
403
- parser.add_argument(
404
- "--wd",
405
- "--weight-decay",
406
- default=1e-4,
407
- type=float,
408
- metavar="W",
409
- help="weight decay (default: 1e-4)",
410
- dest="weight_decay",
411
- )
412
- parser.add_argument(
413
- "--norm-weight-decay",
414
- default=None,
415
- type=float,
416
- help="weight decay for Normalization layers (default: None, same value as --wd)",
417
- )
418
- parser.add_argument(
419
- "--bias-weight-decay",
420
- default=None,
421
- type=float,
422
- help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
423
- )
424
- parser.add_argument(
425
- "--transformer-embedding-decay",
426
- default=None,
427
- type=float,
428
- help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
429
- )
430
- parser.add_argument(
431
- "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
432
- )
433
- parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
434
- parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
435
- parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
436
- parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
437
- parser.add_argument(
438
- "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
439
- )
440
- parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
441
- parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
442
- parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
443
- parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
444
- parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
445
- parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
446
- parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
447
- parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
448
- parser.add_argument(
449
- "--cache-dataset",
450
- dest="cache_dataset",
451
- help="Cache the datasets for quicker initialization. It also serializes the transforms",
452
- action="store_true",
453
- )
454
- parser.add_argument(
455
- "--sync-bn",
456
- dest="sync_bn",
457
- help="Use sync batch norm",
458
- action="store_true",
459
- )
460
- parser.add_argument(
461
- "--test-only",
462
- dest="test_only",
463
- help="Only test the model",
464
- action="store_true",
465
- )
466
- parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
467
- parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
468
- parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
469
- parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
470
-
471
- # Mixed precision training parameters
472
- parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
473
-
474
- # distributed training parameters
475
- parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
476
- parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
477
- parser.add_argument(
478
- "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
479
- )
480
- parser.add_argument(
481
- "--model-ema-steps",
482
- type=int,
483
- default=32,
484
- help="the number of iterations that controls how often to update the EMA model (default: 32)",
485
- )
486
- parser.add_argument(
487
- "--model-ema-decay",
488
- type=float,
489
- default=0.99998,
490
- help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
491
- )
492
- parser.add_argument(
493
- "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
494
- )
495
- parser.add_argument(
496
- "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
497
- )
498
- parser.add_argument(
499
- "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
500
- )
501
- parser.add_argument(
502
- "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
503
- )
504
- parser.add_argument(
505
- "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
506
- )
507
- parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
508
- parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
509
- parser.add_argument(
510
- "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
511
- )
512
- parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
513
-
514
- parser.add_argument("--apply-trp", action="store_true", help="enable applying trp")
515
- parser.add_argument("--trp-depths", nargs="+", type=int, help="number of layers for trp block")
516
- parser.add_argument("--trp-planes", default=1024, type=int, help="channels of the hidden state")
517
- parser.add_argument("--trp-lambdas", nargs="+", type=float, help="trp lambdas")
518
-
519
- return parser
520
-
521
-
522
- if __name__ == "__main__":
523
- args = get_args_parser().parse_args()
524
- main(args)
 
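To make the EMA decay adjustment in main() concrete, a short sketch with numbers plugged in (4 GPUs, per-GPU batch size 128, 10 epochs, and the default --model-ema-steps 32 / --model-ema-decay 0.99998; these values are only illustrative, and run.sh does not actually enable --model-ema):

    world_size, batch_size, ema_steps, epochs = 4, 128, 32, 10
    model_ema_decay = 0.99998
    adjust = world_size * batch_size * ema_steps / epochs   # 1638.4
    alpha = min(1.0, (1.0 - model_ema_decay) * adjust)      # ~0.0328
    effective_decay = 1.0 - alpha                           # ~0.967, passed to ExponentialMovingAverage

With few epochs the per-update decay is lowered substantially, which keeps the total amount of averaging roughly independent of the training schedule.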
neural-archicture-search/train_quantization.py DELETED
@@ -1,265 +0,0 @@
1
- import copy
2
- import datetime
3
- import os
4
- import time
5
-
6
- import torch
7
- import torch.ao.quantization
8
- import torch.utils.data
9
- import torchvision
10
- import utils
11
- from torch import nn
12
- from train import evaluate, load_data, train_one_epoch
13
-
14
-
15
- def main(args):
16
- if args.output_dir:
17
- utils.mkdir(args.output_dir)
18
-
19
- utils.init_distributed_mode(args)
20
- print(args)
21
-
22
- if args.post_training_quantize and args.distributed:
23
- raise RuntimeError("Post training quantization example should not be performed on distributed mode")
24
-
25
- # Set backend engine to ensure that quantized model runs on the correct kernels
26
- if args.backend not in torch.backends.quantized.supported_engines:
27
- raise RuntimeError("Quantized backend not supported: " + str(args.backend))
28
- torch.backends.quantized.engine = args.backend
29
-
30
- device = torch.device(args.device)
31
- torch.backends.cudnn.benchmark = True
32
-
33
- # Data loading code
34
- print("Loading data")
35
- train_dir = os.path.join(args.data_path, "train")
36
- val_dir = os.path.join(args.data_path, "val")
37
-
38
- dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
39
- data_loader = torch.utils.data.DataLoader(
40
- dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
41
- )
42
-
43
- data_loader_test = torch.utils.data.DataLoader(
44
- dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
45
- )
46
-
47
- print("Creating model", args.model)
48
- # when training quantized models, we always start from a pre-trained fp32 reference model
49
- prefix = "quantized_"
50
- model_name = args.model
51
- if not model_name.startswith(prefix):
52
- model_name = prefix + model_name
53
- model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
54
- model.to(device)
55
-
56
- if not (args.test_only or args.post_training_quantize):
57
- model.fuse_model(is_qat=True)
58
- model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
59
- torch.ao.quantization.prepare_qat(model, inplace=True)
60
-
61
- if args.distributed and args.sync_bn:
62
- model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
63
-
64
- optimizer = torch.optim.SGD(
65
- model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
66
- )
67
-
68
- lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
69
-
70
- criterion = nn.CrossEntropyLoss()
71
- model_without_ddp = model
72
- if args.distributed:
73
- model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
74
- model_without_ddp = model.module
75
-
76
- if args.resume:
77
- checkpoint = torch.load(args.resume, map_location="cpu")
78
- model_without_ddp.load_state_dict(checkpoint["model"])
79
- optimizer.load_state_dict(checkpoint["optimizer"])
80
- lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
81
- args.start_epoch = checkpoint["epoch"] + 1
82
-
83
- if args.post_training_quantize:
84
- # perform calibration on a subset of the training dataset
85
- # for that, create a subset of the training dataset
86
- ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
87
- data_loader_calibration = torch.utils.data.DataLoader(
88
- ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
89
- )
90
- model.eval()
91
- model.fuse_model(is_qat=False)
92
- model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
93
- torch.ao.quantization.prepare(model, inplace=True)
94
- # Calibrate first
95
- print("Calibrating")
96
- evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
97
- torch.ao.quantization.convert(model, inplace=True)
98
- if args.output_dir:
99
- print("Saving quantized model")
100
- if utils.is_main_process():
101
- torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
102
- print("Evaluating post-training quantized model")
103
- evaluate(model, criterion, data_loader_test, device=device)
104
- return
105
-
106
- if args.test_only:
107
- evaluate(model, criterion, data_loader_test, device=device)
108
- return
109
-
110
- model.apply(torch.ao.quantization.enable_observer)
111
- model.apply(torch.ao.quantization.enable_fake_quant)
112
- start_time = time.time()
113
- for epoch in range(args.start_epoch, args.epochs):
114
- if args.distributed:
115
- train_sampler.set_epoch(epoch)
116
- print("Starting training for epoch", epoch)
117
- train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
118
- lr_scheduler.step()
119
- with torch.inference_mode():
120
- if epoch >= args.num_observer_update_epochs:
121
- print("Disabling observer for subseq epochs, epoch = ", epoch)
122
- model.apply(torch.ao.quantization.disable_observer)
123
- if epoch >= args.num_batch_norm_update_epochs:
124
- print("Freezing BN for subseq epochs, epoch = ", epoch)
125
- model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
126
- print("Evaluate QAT model")
127
-
128
- evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
129
- quantized_eval_model = copy.deepcopy(model_without_ddp)
130
- quantized_eval_model.eval()
131
- quantized_eval_model.to(torch.device("cpu"))
132
- torch.ao.quantization.convert(quantized_eval_model, inplace=True)
133
-
134
- print("Evaluate Quantized model")
135
- evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
136
-
137
- model.train()
138
-
139
- if args.output_dir:
140
- checkpoint = {
141
- "model": model_without_ddp.state_dict(),
142
- "eval_model": quantized_eval_model.state_dict(),
143
- "optimizer": optimizer.state_dict(),
144
- "lr_scheduler": lr_scheduler.state_dict(),
145
- "epoch": epoch,
146
- "args": args,
147
- }
148
- utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
149
- utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
150
- print("Saving models after epoch ", epoch)
151
-
152
- total_time = time.time() - start_time
153
- total_time_str = str(datetime.timedelta(seconds=int(total_time)))
154
- print(f"Training time {total_time_str}")
155
-
156
-
157
- def get_args_parser(add_help=True):
158
- import argparse
159
-
160
- parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)
161
-
162
- parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
163
- parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
164
- parser.add_argument("--backend", default="qnnpack", type=str, help="fbgemm or qnnpack")
165
- parser.add_argument("--device", default="cuda", type=str, help="device to use (cuda or cpu, default: cuda)")
166
-
167
- parser.add_argument(
168
- "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
169
- )
170
- parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
171
- parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
172
- parser.add_argument(
173
- "--num-observer-update-epochs",
174
- default=4,
175
- type=int,
176
- metavar="N",
177
- help="number of total epochs to update observers",
178
- )
179
- parser.add_argument(
180
- "--num-batch-norm-update-epochs",
181
- default=3,
182
- type=int,
183
- metavar="N",
184
- help="number of total epochs to update batch norm stats",
185
- )
186
- parser.add_argument(
187
- "--num-calibration-batches",
188
- default=32,
189
- type=int,
190
- metavar="N",
191
- help="number of batches of training set for \
192
- observer calibration ",
193
- )
194
-
195
- parser.add_argument(
196
- "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
197
- )
198
- parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
199
- parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
200
- parser.add_argument(
201
- "--wd",
202
- "--weight-decay",
203
- default=1e-4,
204
- type=float,
205
- metavar="W",
206
- help="weight decay (default: 1e-4)",
207
- dest="weight_decay",
208
- )
209
- parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
210
- parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
211
- parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
212
- parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
213
- parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
214
- parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
215
- parser.add_argument(
216
- "--cache-dataset",
217
- dest="cache_dataset",
218
- help="Cache the datasets for quicker initialization. \
219
- It also serializes the transforms",
220
- action="store_true",
221
- )
222
- parser.add_argument(
223
- "--sync-bn",
224
- dest="sync_bn",
225
- help="Use sync batch norm",
226
- action="store_true",
227
- )
228
- parser.add_argument(
229
- "--test-only",
230
- dest="test_only",
231
- help="Only test the model",
232
- action="store_true",
233
- )
234
- parser.add_argument(
235
- "--post-training-quantize",
236
- dest="post_training_quantize",
237
- help="Post training quantize the model",
238
- action="store_true",
239
- )
240
-
241
- # distributed training parameters
242
- parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
243
- parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
244
-
245
- parser.add_argument(
246
- "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
247
- )
248
- parser.add_argument(
249
- "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
250
- )
251
- parser.add_argument(
252
- "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
253
- )
254
- parser.add_argument(
255
- "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
256
- )
257
- parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
258
- parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
259
-
260
- return parser
261
-
262
-
263
- if __name__ == "__main__":
264
- args = get_args_parser().parse_args()
265
- main(args)
 
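For orientation, a minimal sketch of the eager-mode QAT flow this script follows (the model and backend are chosen here only for illustration; the real script drives this through its argument parser):

    import copy
    import torch
    import torchvision

    torch.backends.quantized.engine = "fbgemm"  # --backend selects "fbgemm" or "qnnpack"
    model = torchvision.models.quantization.mobilenet_v2(weights=None, quantize=False)

    # QAT preparation: fuse modules, attach a QAT qconfig, insert fake-quant/observer modules
    model.fuse_model(is_qat=True)
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
    torch.ao.quantization.prepare_qat(model, inplace=True)

    # ... train with fake quantization, freezing observers/BN stats after a few epochs ...

    # Convert a CPU copy to a true int8 model for evaluation, as the epoch loop does
    quantized_eval_model = copy.deepcopy(model).eval().to("cpu")
    torch.ao.quantization.convert(quantized_eval_model, inplace=True)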
neural-archicture-search/transforms.py DELETED
@@ -1,183 +0,0 @@
1
- import math
2
- from typing import Tuple
3
-
4
- import torch
5
- from torch import Tensor
6
- from torchvision.transforms import functional as F
7
-
8
-
9
- class RandomMixup(torch.nn.Module):
10
- """Randomly apply Mixup to the provided batch and targets.
11
- The class implements the data augmentations as described in the paper
12
- `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
13
-
14
- Args:
15
- num_classes (int): number of classes used for one-hot encoding.
16
- p (float): probability of the batch being transformed. Default value is 0.5.
17
- alpha (float): hyperparameter of the Beta distribution used for mixup.
18
- Default value is 1.0.
19
- inplace (bool): boolean to make this transform inplace. Default set to False.
20
- """
21
-
22
- def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
23
- super().__init__()
24
-
25
- if num_classes < 1:
26
- raise ValueError(
27
- f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
28
- )
29
-
30
- if alpha <= 0:
31
- raise ValueError("Alpha param can't be zero.")
32
-
33
- self.num_classes = num_classes
34
- self.p = p
35
- self.alpha = alpha
36
- self.inplace = inplace
37
-
38
- def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
39
- """
40
- Args:
41
- batch (Tensor): Float tensor of size (B, C, H, W)
42
- target (Tensor): Integer tensor of size (B, )
43
-
44
- Returns:
45
- Tensor: Randomly transformed batch.
46
- """
47
- if batch.ndim != 4:
48
- raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
49
- if target.ndim != 1:
50
- raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
51
- if not batch.is_floating_point():
52
- raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
53
- if target.dtype != torch.int64:
54
- raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
55
-
56
- if not self.inplace:
57
- batch = batch.clone()
58
- target = target.clone()
59
-
60
- if target.ndim == 1:
61
- target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
62
-
63
- if torch.rand(1).item() >= self.p:
64
- return batch, target
65
-
66
- # It's faster to roll the batch by one instead of shuffling it to create image pairs
67
- batch_rolled = batch.roll(1, 0)
68
- target_rolled = target.roll(1, 0)
69
-
70
- # Implemented as on mixup paper, page 3.
71
- lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
72
- batch_rolled.mul_(1.0 - lambda_param)
73
- batch.mul_(lambda_param).add_(batch_rolled)
74
-
75
- target_rolled.mul_(1.0 - lambda_param)
76
- target.mul_(lambda_param).add_(target_rolled)
77
-
78
- return batch, target
79
-
80
- def __repr__(self) -> str:
81
- s = (
82
- f"{self.__class__.__name__}("
83
- f"num_classes={self.num_classes}"
84
- f", p={self.p}"
85
- f", alpha={self.alpha}"
86
- f", inplace={self.inplace}"
87
- f")"
88
- )
89
- return s
90
-
91
-
92
- class RandomCutmix(torch.nn.Module):
93
- """Randomly apply Cutmix to the provided batch and targets.
94
- The class implements the data augmentations as described in the paper
95
- `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
96
- <https://arxiv.org/abs/1905.04899>`_.
97
-
98
- Args:
99
- num_classes (int): number of classes used for one-hot encoding.
100
- p (float): probability of the batch being transformed. Default value is 0.5.
101
- alpha (float): hyperparameter of the Beta distribution used for cutmix.
102
- Default value is 1.0.
103
- inplace (bool): boolean to make this transform inplace. Default set to False.
104
- """
105
-
106
- def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
107
- super().__init__()
108
- if num_classes < 1:
109
- raise ValueError("Please provide a valid positive value for the num_classes.")
110
- if alpha <= 0:
111
- raise ValueError("Alpha param can't be zero.")
112
-
113
- self.num_classes = num_classes
114
- self.p = p
115
- self.alpha = alpha
116
- self.inplace = inplace
117
-
118
- def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
119
- """
120
- Args:
121
- batch (Tensor): Float tensor of size (B, C, H, W)
122
- target (Tensor): Integer tensor of size (B, )
123
-
124
- Returns:
125
- Tensor: Randomly transformed batch.
126
- """
127
- if batch.ndim != 4:
128
- raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
129
- if target.ndim != 1:
130
- raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
131
- if not batch.is_floating_point():
132
- raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
133
- if target.dtype != torch.int64:
134
- raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
135
-
136
- if not self.inplace:
137
- batch = batch.clone()
138
- target = target.clone()
139
-
140
- if target.ndim == 1:
141
- target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
142
-
143
- if torch.rand(1).item() >= self.p:
144
- return batch, target
145
-
146
- # It's faster to roll the batch by one instead of shuffling it to create image pairs
147
- batch_rolled = batch.roll(1, 0)
148
- target_rolled = target.roll(1, 0)
149
-
150
- # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
151
- lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
152
- _, H, W = F.get_dimensions(batch)
153
-
154
- r_x = torch.randint(W, (1,))
155
- r_y = torch.randint(H, (1,))
156
-
157
- r = 0.5 * math.sqrt(1.0 - lambda_param)
158
- r_w_half = int(r * W)
159
- r_h_half = int(r * H)
160
-
161
- x1 = int(torch.clamp(r_x - r_w_half, min=0))
162
- y1 = int(torch.clamp(r_y - r_h_half, min=0))
163
- x2 = int(torch.clamp(r_x + r_w_half, max=W))
164
- y2 = int(torch.clamp(r_y + r_h_half, max=H))
165
-
166
- batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
167
- lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
168
-
169
- target_rolled.mul_(1.0 - lambda_param)
170
- target.mul_(lambda_param).add_(target_rolled)
171
-
172
- return batch, target
173
-
174
- def __repr__(self) -> str:
175
- s = (
176
- f"{self.__class__.__name__}("
177
- f"num_classes={self.num_classes}"
178
- f", p={self.p}"
179
- f", alpha={self.alpha}"
180
- f", inplace={self.inplace}"
181
- f")"
182
- )
183
- return s
 
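For context, a short sketch of how these batch transforms are plugged into the data loader; it mirrors the collate_fn logic in train.py (batch size, num_classes, and alpha values are arbitrary here):

    import torch
    import torchvision
    from torch.utils.data.dataloader import default_collate
    from transforms import RandomMixup, RandomCutmix

    num_classes = 1000
    mixupcutmix = torchvision.transforms.RandomChoice(
        [RandomMixup(num_classes, p=1.0, alpha=0.2), RandomCutmix(num_classes, p=1.0, alpha=1.0)]
    )

    def collate_fn(batch):
        # batch is a list of (image, label) pairs; mixup/cutmix act on the stacked batch
        return mixupcutmix(*default_collate(batch))

    batch = [(torch.rand(3, 224, 224), torch.tensor(i % num_classes)) for i in range(8)]
    images, targets = collate_fn(batch)  # targets become soft labels of shape [8, num_classes]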
neural-archicture-search/trplib.py DELETED
@@ -1,127 +0,0 @@
- import types
- from typing import List, Callable
-
- import torch
- from torch import nn, Tensor
- from torch.nn import functional as F
- from torchvision.models.resnet import BasicBlock
-
-
- def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_state: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):
-     losses, rewards = criterion(logits, targets)
-     returns = torch.ones_like(rewards, dtype=torch.float32, device=rewards.device)
-     if loss_normalization:
-         coeff = torch.mean(losses).detach()
-
-     embeds = [hidden_state]
-     predictions = []
-     for k, w in enumerate(lambdas):
-         embeds.append(trp_blocks[k](embeds[-1]))
-         predictions.append(shared_head(embeds[-1]))
-         returns = returns + w * rewards
-         replica_losses, rewards = criterion(predictions[-1], targets, rewards)
-         losses = losses + replica_losses
-     loss = torch.mean(losses * returns)
-
-     if loss_normalization:
-         with torch.no_grad():
-             coeff = torch.exp(coeff) / torch.exp(loss.detach())
-         loss = coeff * loss
-
-     return loss
-
-
- class TPBlock(nn.Module):
-     def __init__(self, depths: int, inplanes: int, planes: int):
-         super(TPBlock, self).__init__()
-
-         blocks = [BasicBlock(inplanes=inplanes, planes=planes) for _ in range(depths)]
-         self.blocks = nn.Sequential(*blocks)
-         for name, param in self.blocks.named_parameters():
-             if 'conv' in name:
-                 nn.init.zeros_(param)  # zero-init conv weights so each block starts close to identity
-             elif 'downsample' in name:
-                 nn.init.zeros_(param)  # zero-init downsample weights as well
-
-     def forward(self, x):
-         return self.blocks(x)
-
-
- class ResNetConfig:
-     @staticmethod
-     def gen_criterion(label_smoothing=0.0, top_k=1):
-         def func(input, target, mask=None):
-             """
-             Args:
-                 input (Tensor): Input tensor of shape [B, C].
-                 target (Tensor): Target labels of shape [B] or [B, C].
-
-             Returns:
-                 loss (Tensor): Scalar tensor representing the loss.
-                 mask (Tensor): Float mask tensor of shape [B].
-             """
-             label = torch.argmax(target, dim=1) if label_smoothing > 0.0 else target
-
-             unmasked_loss = F.cross_entropy(input, label, reduction="none", label_smoothing=label_smoothing)
-             if mask is None:
-                 mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)
-             loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
-
-             with torch.no_grad():
-                 topk_values, topk_indices = torch.topk(input, top_k, dim=-1)
-                 mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)
-
-             return loss, mask
-         return func
-
-     @staticmethod
-     def gen_shared_head(self):
-         def func(x):
-             """
-             Args:
-                 x (Tensor): Hidden states tensor of shape [B, C, H, W].
-
-             Returns:
-                 logits (Tensor): Logits tensor of shape [B, C].
-             """
-             x = self.layer4(x)
-             x = self.avgpool(x)
-             x = torch.flatten(x, 1)
-             logits = self.fc(x)
-             return logits
-         return func
-
-     @staticmethod
-     def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
-         def func(self, x: Tensor, targets=None) -> Tensor:
-             x = self.conv1(x)
-             x = self.bn1(x)
-             x = self.relu(x)
-             x = self.maxpool(x)
-
-             x = self.layer1(x)
-             x = self.layer2(x)
-             hidden_states = self.layer3(x)
-             x = self.layer4(hidden_states)
-             x = self.avgpool(x)
-             x = torch.flatten(x, 1)
-             logits = self.fc(x)
-
-             if self.training:
-                 shared_head = ResNetConfig.gen_shared_head(self)
-                 criterion = ResNetConfig.gen_criterion(label_smoothing=label_smoothing, top_k=top_k)
-
-                 loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits, targets, loss_normalization=loss_normalization)
-
-                 return logits, loss
-
-             return logits
-
-         return func
-
-
- def apply_trp(model, depths: List[int], planes: int, lambdas: List[float], **kwargs):
-     print("✅ Applying TRP to ResNet for Image Classification...")
-     model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, inplanes=planes, planes=planes) for d in depths])
-     model.forward = types.MethodType(ResNetConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
-     return model
 
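A minimal sketch of how train.py wires this in, using the resnet18 configuration from run.sh (the random tensors are placeholders):

    import torch
    import torchvision
    from trplib import apply_trp

    model = torchvision.models.resnet18(num_classes=1000)
    # three TRP blocks of depth 3 operating on the 256-channel layer3 feature map
    model = apply_trp(model, depths=[3, 3, 3], planes=256, lambdas=[0.4, 0.2, 0.1], label_smoothing=0.0)

    images = torch.randn(4, 3, 224, 224)
    targets = torch.randint(0, 1000, (4,))

    model.train()
    logits, loss = model(images, targets)  # patched forward returns the TRP loss during training
    model.eval()
    logits = model(images)                 # inference path is the plain ResNet forward

Because the checkpointing code in train.py strips keys starting with "trp_blocks", the saved weights load into a stock ResNet for evaluation.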
neural-archicture-search/utils.py DELETED
@@ -1,465 +0,0 @@
1
- import copy
2
- import datetime
3
- import errno
4
- import hashlib
5
- import os
6
- import time
7
- from collections import defaultdict, deque, OrderedDict
8
- from typing import List, Optional, Tuple
9
-
10
- import torch
11
- import torch.distributed as dist
12
-
13
-
14
- class SmoothedValue:
15
- """Track a series of values and provide access to smoothed values over a
16
- window or the global series average.
17
- """
18
-
19
- def __init__(self, window_size=20, fmt=None):
20
- if fmt is None:
21
- fmt = "{median:.4f} ({global_avg:.4f})"
22
- self.deque = deque(maxlen=window_size)
23
- self.total = 0.0
24
- self.count = 0
25
- self.fmt = fmt
26
-
27
- def update(self, value, n=1):
28
- self.deque.append(value)
29
- self.count += n
30
- self.total += value * n
31
-
32
- def synchronize_between_processes(self):
33
- """
34
- Warning: does not synchronize the deque!
35
- """
36
- t = reduce_across_processes([self.count, self.total])
37
- t = t.tolist()
38
- self.count = int(t[0])
39
- self.total = t[1]
40
-
41
- @property
42
- def median(self):
43
- d = torch.tensor(list(self.deque))
44
- return d.median().item()
45
-
46
- @property
47
- def avg(self):
48
- d = torch.tensor(list(self.deque), dtype=torch.float32)
49
- return d.mean().item()
50
-
51
- @property
52
- def global_avg(self):
53
- return self.total / self.count
54
-
55
- @property
56
- def max(self):
57
- return max(self.deque)
58
-
59
- @property
60
- def value(self):
61
- return self.deque[-1]
62
-
63
- def __str__(self):
64
- return self.fmt.format(
65
- median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
66
- )
67
-
68
-
69
- class MetricLogger:
70
- def __init__(self, delimiter="\t"):
71
- self.meters = defaultdict(SmoothedValue)
72
- self.delimiter = delimiter
73
-
74
- def update(self, **kwargs):
75
- for k, v in kwargs.items():
76
- if isinstance(v, torch.Tensor):
77
- v = v.item()
78
- assert isinstance(v, (float, int))
79
- self.meters[k].update(v)
80
-
81
- def __getattr__(self, attr):
82
- if attr in self.meters:
83
- return self.meters[attr]
84
- if attr in self.__dict__:
85
- return self.__dict__[attr]
86
- raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
87
-
88
- def __str__(self):
89
- loss_str = []
90
- for name, meter in self.meters.items():
91
- loss_str.append(f"{name}: {str(meter)}")
92
- return self.delimiter.join(loss_str)
93
-
94
- def synchronize_between_processes(self):
95
- for meter in self.meters.values():
96
- meter.synchronize_between_processes()
97
-
98
- def add_meter(self, name, meter):
99
- self.meters[name] = meter
100
-
101
- def log_every(self, iterable, print_freq, header=None):
102
- i = 0
103
- if not header:
104
- header = ""
105
- start_time = time.time()
106
- end = time.time()
107
- iter_time = SmoothedValue(fmt="{avg:.4f}")
108
- data_time = SmoothedValue(fmt="{avg:.4f}")
109
- space_fmt = ":" + str(len(str(len(iterable)))) + "d"
110
- if torch.cuda.is_available():
111
- log_msg = self.delimiter.join(
112
- [
113
- header,
114
- "[{0" + space_fmt + "}/{1}]",
115
- "eta: {eta}",
116
- "{meters}",
117
- "time: {time}",
118
- "data: {data}",
119
- "max mem: {memory:.0f}",
120
- ]
121
- )
122
- else:
123
- log_msg = self.delimiter.join(
124
- [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
125
- )
126
- MB = 1024.0 * 1024.0
127
- for obj in iterable:
128
- data_time.update(time.time() - end)
129
- yield obj
130
- iter_time.update(time.time() - end)
131
- if i % print_freq == 0:
132
- eta_seconds = iter_time.global_avg * (len(iterable) - i)
133
- eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
134
- if torch.cuda.is_available():
135
- print(
136
- log_msg.format(
137
- i,
138
- len(iterable),
139
- eta=eta_string,
140
- meters=str(self),
141
- time=str(iter_time),
142
- data=str(data_time),
143
- memory=torch.cuda.max_memory_allocated() / MB,
144
- )
145
- )
146
- else:
147
- print(
148
- log_msg.format(
149
- i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
150
- )
151
- )
152
- i += 1
153
- end = time.time()
154
- total_time = time.time() - start_time
155
- total_time_str = str(datetime.timedelta(seconds=int(total_time)))
156
- print(f"{header} Total time: {total_time_str}")
157
-
158
-
159
- class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
160
- """Maintains moving averages of model parameters using an exponential decay.
161
- ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
162
- `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
163
- is used to compute the EMA.
164
- """
165
-
166
- def __init__(self, model, decay, device="cpu"):
167
- def ema_avg(avg_model_param, model_param, num_averaged):
168
- return decay * avg_model_param + (1 - decay) * model_param
169
-
170
- super().__init__(model, device, ema_avg, use_buffers=True)
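A short sketch of the intended update pattern for the EMA wrapper (the decay value is an assumption; update_parameters comes from the AveragedModel base class):

model_ema = ExponentialMovingAverage(model, decay=0.999)
for images, targets in data_loader:
    optimizer.zero_grad()
    loss = criterion(model(images), targets)
    loss.backward()
    optimizer.step()
    model_ema.update_parameters(model)  # ema = decay * ema + (1 - decay) * param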
171
-
172
-
173
- def accuracy(output, target, topk=(1,)):
174
- """Computes the accuracy over the k top predictions for the specified values of k"""
175
- with torch.inference_mode():
176
- maxk = max(topk)
177
- batch_size = target.size(0)
178
- if target.ndim == 2:
179
- target = target.max(dim=1)[1]
180
-
181
- _, pred = output.topk(maxk, 1, True, True)
182
- pred = pred.t()
183
- correct = pred.eq(target[None])
184
-
185
- res = []
186
- for k in topk:
187
- correct_k = correct[:k].flatten().sum(dtype=torch.float32)
188
- res.append(correct_k * (100.0 / batch_size))
189
- return res
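For clarity, a small example of how this helper is usually called during evaluation (output is [B, num_classes], target is [B]; the returned values are percentages):

acc1, acc5 = accuracy(output, target, topk=(1, 5))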
190
-
191
-
192
- def mkdir(path):
193
- try:
194
- os.makedirs(path)
195
- except OSError as e:
196
- if e.errno != errno.EEXIST:
197
- raise
198
-
199
-
200
- def setup_for_distributed(is_master):
201
- """
202
- This function disables printing when not in master process
203
- """
204
- import builtins as __builtin__
205
-
206
- builtin_print = __builtin__.print
207
-
208
- def print(*args, **kwargs):
209
- force = kwargs.pop("force", False)
210
- if is_master or force:
211
- builtin_print(*args, **kwargs)
212
-
213
- __builtin__.print = print
214
-
215
-
216
- def is_dist_avail_and_initialized():
217
- if not dist.is_available():
218
- return False
219
- if not dist.is_initialized():
220
- return False
221
- return True
222
-
223
-
224
- def get_world_size():
225
- if not is_dist_avail_and_initialized():
226
- return 1
227
- return dist.get_world_size()
228
-
229
-
230
- def get_rank():
231
- if not is_dist_avail_and_initialized():
232
- return 0
233
- return dist.get_rank()
234
-
235
-
236
- def is_main_process():
237
- return get_rank() == 0
238
-
239
-
240
- def save_on_master(*args, **kwargs):
241
- if is_main_process():
242
- torch.save(*args, **kwargs)
243
-
244
-
245
- def init_distributed_mode(args):
246
- if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
247
- args.rank = int(os.environ["RANK"])
248
- args.world_size = int(os.environ["WORLD_SIZE"])
249
- args.gpu = int(os.environ["LOCAL_RANK"])
250
- elif "SLURM_PROCID" in os.environ:
251
- args.rank = int(os.environ["SLURM_PROCID"])
252
- args.gpu = args.rank % torch.cuda.device_count()
253
- elif hasattr(args, "rank"):
254
- pass
255
- else:
256
- print("Not using distributed mode")
257
- args.distributed = False
258
- return
259
-
260
- args.distributed = True
261
-
262
- torch.cuda.set_device(args.gpu)
263
- args.dist_backend = "nccl"
264
- print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
265
- torch.distributed.init_process_group(
266
- backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
267
- )
268
- torch.distributed.barrier()
269
- setup_for_distributed(args.rank == 0)
270
-
271
-
272
- def average_checkpoints(inputs):
273
- """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
274
- https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
275
-
276
- Args:
277
- inputs (List[str]): An iterable of string paths of checkpoints to load from.
278
- Returns:
279
- A dict of string keys mapping to various values. The 'model' key
280
- from the returned dict should correspond to an OrderedDict mapping
281
- string parameter names to torch Tensors.
282
- """
283
- params_dict = OrderedDict()
284
- params_keys = None
285
- new_state = None
286
- num_models = len(inputs)
287
- for fpath in inputs:
288
- with open(fpath, "rb") as f:
289
- state = torch.load(
290
- f,
291
- map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
292
- )
293
- # Copies over the settings from the first checkpoint
294
- if new_state is None:
295
- new_state = state
296
- model_params = state["model"]
297
- model_params_keys = list(model_params.keys())
298
- if params_keys is None:
299
- params_keys = model_params_keys
300
- elif params_keys != model_params_keys:
301
- raise KeyError(
302
- f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}"
303
- )
304
- for k in params_keys:
305
- p = model_params[k]
306
- if isinstance(p, torch.HalfTensor):
307
- p = p.float()
308
- if k not in params_dict:
309
- params_dict[k] = p.clone()
310
- # NOTE: clone() is needed in case of p is a shared parameter
311
- else:
312
- params_dict[k] += p
313
- averaged_params = OrderedDict()
314
- for k, v in params_dict.items():
315
- averaged_params[k] = v
316
- if averaged_params[k].is_floating_point():
317
- averaged_params[k].div_(num_models)
318
- else:
319
- averaged_params[k] //= num_models
320
- new_state["model"] = averaged_params
321
- return new_state
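An assumed usage sketch: average a handful of checkpoints from the same run and load the result back into a model (the file names are placeholders):

averaged = average_checkpoints(["ckpt_epoch_7.pth", "ckpt_epoch_8.pth", "ckpt_epoch_9.pth"])
model.load_state_dict(averaged["model"])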
322
-
323
-
324
- def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
325
- """
326
- This method can be used to prepare weights files for new models. It receives as
327
- input a model architecture and a checkpoint from the training script and produces
328
- a file with the weights ready for release.
329
-
330
- Examples:
331
- from torchvision import models as M
332
-
333
- # Classification
334
- model = M.mobilenet_v3_large(weights=None)
335
- print(store_model_weights(model, './class.pth'))
336
-
337
- # Quantized Classification
338
- model = M.quantization.mobilenet_v3_large(weights=None, quantize=False)
339
- model.fuse_model(is_qat=True)
340
- model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
341
- _ = torch.ao.quantization.prepare_qat(model, inplace=True)
342
- print(store_model_weights(model, './qat.pth'))
343
-
344
- # Object Detection
345
- model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
346
- print(store_model_weights(model, './obj.pth'))
347
-
348
- # Segmentation
349
- model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True)
350
- print(store_model_weights(model, './segm.pth', strict=False))
351
-
352
- Args:
353
- model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes.
354
- checkpoint_path (str): The path of the checkpoint we will load.
355
- checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
356
- Default: "model".
357
- strict (bool): whether to strictly enforce that the keys
358
- in :attr:`state_dict` match the keys returned by this module's
359
- :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
360
-
361
- Returns:
362
- output_path (str): The location where the weights are saved.
363
- """
364
- # Store the new model next to the checkpoint_path
365
- checkpoint_path = os.path.abspath(checkpoint_path)
366
- output_dir = os.path.dirname(checkpoint_path)
367
-
368
- # Deep copy to avoid side-effects on the model object.
369
- model = copy.deepcopy(model)
370
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
371
-
372
- # Load the weights to the model to validate that everything works
373
- # and remove unnecessary weights (such as auxiliaries, etc)
374
- if checkpoint_key == "model_ema":
375
- del checkpoint[checkpoint_key]["n_averaged"]
376
- torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
377
- model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
378
-
379
- tmp_path = os.path.join(output_dir, str(model.__hash__()))
380
- torch.save(model.state_dict(), tmp_path)
381
-
382
- sha256_hash = hashlib.sha256()
383
- with open(tmp_path, "rb") as f:
384
- # Read and update hash string value in blocks of 4K
385
- for byte_block in iter(lambda: f.read(4096), b""):
386
- sha256_hash.update(byte_block)
387
- hh = sha256_hash.hexdigest()
388
-
389
- output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
390
- os.replace(tmp_path, output_path)
391
-
392
- return output_path
393
-
394
-
395
- def reduce_across_processes(val):
396
- if not is_dist_avail_and_initialized():
397
- # nothing to sync, but we still convert to tensor for consistency with the distributed case.
398
- return torch.tensor(val)
399
-
400
- t = torch.tensor(val, device="cuda")
401
- dist.barrier()
402
- dist.all_reduce(t)
403
- return t
404
-
405
-
406
- def set_weight_decay(
407
- model: torch.nn.Module,
408
- weight_decay: float,
409
- norm_weight_decay: Optional[float] = None,
410
- norm_classes: Optional[List[type]] = None,
411
- custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
412
- ):
413
- if not norm_classes:
414
- norm_classes = [
415
- torch.nn.modules.batchnorm._BatchNorm,
416
- torch.nn.LayerNorm,
417
- torch.nn.GroupNorm,
418
- torch.nn.modules.instancenorm._InstanceNorm,
419
- torch.nn.LocalResponseNorm,
420
- ]
421
- norm_classes = tuple(norm_classes)
422
-
423
- params = {
424
- "other": [],
425
- "norm": [],
426
- }
427
- params_weight_decay = {
428
- "other": weight_decay,
429
- "norm": norm_weight_decay,
430
- }
431
- custom_keys = []
432
- if custom_keys_weight_decay is not None:
433
- for key, weight_decay in custom_keys_weight_decay:
434
- params[key] = []
435
- params_weight_decay[key] = weight_decay
436
- custom_keys.append(key)
437
-
438
- def _add_params(module, prefix=""):
439
- for name, p in module.named_parameters(recurse=False):
440
- if not p.requires_grad:
441
- continue
442
- is_custom_key = False
443
- for key in custom_keys:
444
- target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
445
- if key == target_name:
446
- params[key].append(p)
447
- is_custom_key = True
448
- break
449
- if not is_custom_key:
450
- if norm_weight_decay is not None and isinstance(module, norm_classes):
451
- params["norm"].append(p)
452
- else:
453
- params["other"].append(p)
454
-
455
- for child_name, child_module in module.named_children():
456
- child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
457
- _add_params(child_module, prefix=child_prefix)
458
-
459
- _add_params(model)
460
-
461
- param_groups = []
462
- for key in params:
463
- if len(params[key]) > 0:
464
- param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
465
- return param_groups
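A closing sketch of how the returned parameter groups would typically be handed to an optimizer so that normalization layers skip weight decay (the hyperparameter values are assumptions):

param_groups = set_weight_decay(model, weight_decay=1e-4, norm_weight_decay=0.0)
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)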