Delete neural-archicture-search
Browse files- neural-archicture-search/presets.py +0 -71
- neural-archicture-search/resnet18/model_3.pth +0 -3
- neural-archicture-search/resnet34/model_8.pth +0 -3
- neural-archicture-search/resnet50/model_9.pth +0 -3
- neural-archicture-search/run.sh +0 -33
- neural-archicture-search/sampler.py +0 -62
- neural-archicture-search/train.py +0 -524
- neural-archicture-search/train_quantization.py +0 -265
- neural-archicture-search/transforms.py +0 -183
- neural-archicture-search/trplib.py +0 -127
- neural-archicture-search/utils.py +0 -465
neural-archicture-search/presets.py
DELETED
@@ -1,71 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torchvision.transforms import autoaugment, transforms
|
3 |
-
from torchvision.transforms.functional import InterpolationMode
|
4 |
-
|
5 |
-
|
6 |
-
class ClassificationPresetTrain:
|
7 |
-
def __init__(
|
8 |
-
self,
|
9 |
-
*,
|
10 |
-
crop_size,
|
11 |
-
mean=(0.485, 0.456, 0.406),
|
12 |
-
std=(0.229, 0.224, 0.225),
|
13 |
-
interpolation=InterpolationMode.BILINEAR,
|
14 |
-
hflip_prob=0.5,
|
15 |
-
auto_augment_policy=None,
|
16 |
-
ra_magnitude=9,
|
17 |
-
augmix_severity=3,
|
18 |
-
random_erase_prob=0.0,
|
19 |
-
):
|
20 |
-
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
|
21 |
-
if hflip_prob > 0:
|
22 |
-
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
|
23 |
-
if auto_augment_policy is not None:
|
24 |
-
if auto_augment_policy == "ra":
|
25 |
-
trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
|
26 |
-
elif auto_augment_policy == "ta_wide":
|
27 |
-
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
|
28 |
-
elif auto_augment_policy == "augmix":
|
29 |
-
trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
|
30 |
-
else:
|
31 |
-
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
|
32 |
-
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
|
33 |
-
trans.extend(
|
34 |
-
[
|
35 |
-
transforms.PILToTensor(),
|
36 |
-
transforms.ConvertImageDtype(torch.float),
|
37 |
-
transforms.Normalize(mean=mean, std=std),
|
38 |
-
]
|
39 |
-
)
|
40 |
-
if random_erase_prob > 0:
|
41 |
-
trans.append(transforms.RandomErasing(p=random_erase_prob))
|
42 |
-
|
43 |
-
self.transforms = transforms.Compose(trans)
|
44 |
-
|
45 |
-
def __call__(self, img):
|
46 |
-
return self.transforms(img)
|
47 |
-
|
48 |
-
|
49 |
-
class ClassificationPresetEval:
|
50 |
-
def __init__(
|
51 |
-
self,
|
52 |
-
*,
|
53 |
-
crop_size,
|
54 |
-
resize_size=256,
|
55 |
-
mean=(0.485, 0.456, 0.406),
|
56 |
-
std=(0.229, 0.224, 0.225),
|
57 |
-
interpolation=InterpolationMode.BILINEAR,
|
58 |
-
):
|
59 |
-
|
60 |
-
self.transforms = transforms.Compose(
|
61 |
-
[
|
62 |
-
transforms.Resize(resize_size, interpolation=interpolation),
|
63 |
-
transforms.CenterCrop(crop_size),
|
64 |
-
transforms.PILToTensor(),
|
65 |
-
transforms.ConvertImageDtype(torch.float),
|
66 |
-
transforms.Normalize(mean=mean, std=std),
|
67 |
-
]
|
68 |
-
)
|
69 |
-
|
70 |
-
def __call__(self, img):
|
71 |
-
return self.transforms(img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
neural-archicture-search/resnet18/model_3.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:e728a634490a078e1a672f464b9baebc04774f83b03fc251ad2437a2731330a0
|
3 |
-
size 136133334
|
|
|
|
|
|
|
|
neural-archicture-search/resnet34/model_8.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:179a2d39490980a6cc801c3ef15230bfe08d7e941174d79e2099c8db8b11dfcf
|
3 |
-
size 202898970
|
|
|
|
|
|
|
|
neural-archicture-search/resnet50/model_9.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:fecb231a0e220e46dde1025a7403bb1f587d81f6d36d143ba1c510c3b477a122
|
3 |
-
size 431365452
|
|
|
|
|
|
|
|
neural-archicture-search/run.sh
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
# # ✅ Test: Acc@1 70.092 Acc@5 89.314
|
2 |
-
# torchrun --nproc_per_node=4 train.py\
|
3 |
-
# --data-path /home/cs/Documents/datasets/imagenet\
|
4 |
-
# --model resnet18 --output-dir resnet18 --weights ResNet18_Weights.IMAGENET1K_V1\
|
5 |
-
# --batch-size 128 --epochs 10 --lr 0.0004 --lr-step-size 2 --lr-gamma 0.5\
|
6 |
-
# --lr-warmup-method constant --lr-warmup-epochs 1 --lr-warmup-decay 0.\
|
7 |
-
# --apply-trp --trp-depths 3 3 3 --trp-planes 256 --trp-lambdas 0.4 0.2 0.1 --print-freq 100
|
8 |
-
torchrun --nproc_per_node=4 train.py\
|
9 |
-
--data-path /home/cs/Documents/datasets/imagenet\
|
10 |
-
--model resnet18 --resume resnet18/model_3.pth --test-only
|
11 |
-
|
12 |
-
# # ✅ Test: Acc@1 73.900 Acc@5 91.536
|
13 |
-
# torchrun --nproc_per_node=4 train.py\
|
14 |
-
# --data-path /home/cs/Documents/datasets/imagenet\
|
15 |
-
# --model resnet34 --output-dir resnet34 --weights ResNet34_Weights.IMAGENET1K_V1\
|
16 |
-
# --batch-size 96 --epochs 10 --lr 0.0004 --lr-step-size 2 --lr-gamma 0.5\
|
17 |
-
# --lr-warmup-method constant --lr-warmup-epochs 1 --lr-warmup-decay 0.\
|
18 |
-
# --apply-trp --trp-depths 2 2 2 --trp-planes 256 --trp-lambdas 0.4 0.2 0.1 --print-freq 100
|
19 |
-
# torchrun --nproc_per_node=4 train.py\
|
20 |
-
# --data-path /home/cs/Documents/datasets/imagenet\
|
21 |
-
# --model resnet34 --resume resnet34/model_8.pth --test-only
|
22 |
-
|
23 |
-
|
24 |
-
# # ✅ Test: Acc@1 76.896 Acc@5 93.136
|
25 |
-
# torchrun --nproc_per_node=4 train.py\
|
26 |
-
# --data-path /home/cs/Documents/datasets/imagenet\
|
27 |
-
# --model resnet50 --output-dir resnet50 --weights ResNet50_Weights.IMAGENET1K_V1\
|
28 |
-
# --batch-size 64 --epochs 10 --lr 0.0004 --lr-step-size 2 --lr-gamma 0.5\
|
29 |
-
# --lr-warmup-method constant --lr-warmup-epochs 1 --lr-warmup-decay 0.\
|
30 |
-
# --apply-trp --trp-depths 1 1 1 --trp-planes 1024 --trp-lambdas 0.4 0.2 0.1 --print-freq 100
|
31 |
-
# torchrun --nproc_per_node=4 train.py\
|
32 |
-
# --data-path /home/cs/Documents/datasets/imagenet\
|
33 |
-
# --model resnet50 --resume resnet50/model_9.pth --test-only
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
neural-archicture-search/sampler.py
DELETED
@@ -1,62 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.distributed as dist
|
5 |
-
|
6 |
-
|
7 |
-
class RASampler(torch.utils.data.Sampler):
|
8 |
-
"""Sampler that restricts data loading to a subset of the dataset for distributed,
|
9 |
-
with repeated augmentation.
|
10 |
-
It ensures that different each augmented version of a sample will be visible to a
|
11 |
-
different process (GPU).
|
12 |
-
Heavily based on 'torch.utils.data.DistributedSampler'.
|
13 |
-
|
14 |
-
This is borrowed from the DeiT Repo:
|
15 |
-
https://github.com/facebookresearch/deit/blob/main/samplers.py
|
16 |
-
"""
|
17 |
-
|
18 |
-
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
|
19 |
-
if num_replicas is None:
|
20 |
-
if not dist.is_available():
|
21 |
-
raise RuntimeError("Requires distributed package to be available!")
|
22 |
-
num_replicas = dist.get_world_size()
|
23 |
-
if rank is None:
|
24 |
-
if not dist.is_available():
|
25 |
-
raise RuntimeError("Requires distributed package to be available!")
|
26 |
-
rank = dist.get_rank()
|
27 |
-
self.dataset = dataset
|
28 |
-
self.num_replicas = num_replicas
|
29 |
-
self.rank = rank
|
30 |
-
self.epoch = 0
|
31 |
-
self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
|
32 |
-
self.total_size = self.num_samples * self.num_replicas
|
33 |
-
self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
|
34 |
-
self.shuffle = shuffle
|
35 |
-
self.seed = seed
|
36 |
-
self.repetitions = repetitions
|
37 |
-
|
38 |
-
def __iter__(self):
|
39 |
-
if self.shuffle:
|
40 |
-
# Deterministically shuffle based on epoch
|
41 |
-
g = torch.Generator()
|
42 |
-
g.manual_seed(self.seed + self.epoch)
|
43 |
-
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
44 |
-
else:
|
45 |
-
indices = list(range(len(self.dataset)))
|
46 |
-
|
47 |
-
# Add extra samples to make it evenly divisible
|
48 |
-
indices = [ele for ele in indices for i in range(self.repetitions)]
|
49 |
-
indices += indices[: (self.total_size - len(indices))]
|
50 |
-
assert len(indices) == self.total_size
|
51 |
-
|
52 |
-
# Subsample
|
53 |
-
indices = indices[self.rank : self.total_size : self.num_replicas]
|
54 |
-
assert len(indices) == self.num_samples
|
55 |
-
|
56 |
-
return iter(indices[: self.num_selected_samples])
|
57 |
-
|
58 |
-
def __len__(self):
|
59 |
-
return self.num_selected_samples
|
60 |
-
|
61 |
-
def set_epoch(self, epoch):
|
62 |
-
self.epoch = epoch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
neural-archicture-search/train.py
DELETED
@@ -1,524 +0,0 @@
|
|
1 |
-
import datetime
|
2 |
-
import os
|
3 |
-
import time
|
4 |
-
import warnings
|
5 |
-
|
6 |
-
import presets
|
7 |
-
import torch
|
8 |
-
import torch.utils.data
|
9 |
-
import torchvision
|
10 |
-
import transforms
|
11 |
-
import utils
|
12 |
-
from sampler import RASampler
|
13 |
-
from torch import nn
|
14 |
-
from torch.utils.data.dataloader import default_collate
|
15 |
-
from torchvision.transforms.functional import InterpolationMode
|
16 |
-
|
17 |
-
from trplib import apply_trp
|
18 |
-
|
19 |
-
|
20 |
-
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
|
21 |
-
model.train()
|
22 |
-
metric_logger = utils.MetricLogger(delimiter=" ")
|
23 |
-
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
|
24 |
-
metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
|
25 |
-
|
26 |
-
header = f"Epoch: [{epoch}]"
|
27 |
-
for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
|
28 |
-
start_time = time.time()
|
29 |
-
image, target = image.to(device), target.to(device)
|
30 |
-
with torch.amp.autocast("cuda", enabled=scaler is not None):
|
31 |
-
# output = model(image)
|
32 |
-
# loss = criterion(output, target)
|
33 |
-
output, loss = model(image, target)
|
34 |
-
|
35 |
-
optimizer.zero_grad()
|
36 |
-
if scaler is not None:
|
37 |
-
scaler.scale(loss).backward()
|
38 |
-
if args.clip_grad_norm is not None:
|
39 |
-
# we should unscale the gradients of optimizer's assigned params if do gradient clipping
|
40 |
-
scaler.unscale_(optimizer)
|
41 |
-
nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
|
42 |
-
scaler.step(optimizer)
|
43 |
-
scaler.update()
|
44 |
-
else:
|
45 |
-
loss.backward()
|
46 |
-
if args.clip_grad_norm is not None:
|
47 |
-
nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
|
48 |
-
optimizer.step()
|
49 |
-
|
50 |
-
if model_ema and i % args.model_ema_steps == 0:
|
51 |
-
model_ema.update_parameters(model)
|
52 |
-
if epoch < args.lr_warmup_epochs:
|
53 |
-
# Reset ema buffer to keep copying weights during warmup period
|
54 |
-
model_ema.n_averaged.fill_(0)
|
55 |
-
|
56 |
-
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
|
57 |
-
batch_size = image.shape[0]
|
58 |
-
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
|
59 |
-
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
|
60 |
-
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
|
61 |
-
metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
|
62 |
-
|
63 |
-
|
64 |
-
def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
|
65 |
-
model.eval()
|
66 |
-
metric_logger = utils.MetricLogger(delimiter=" ")
|
67 |
-
header = f"Test: {log_suffix}"
|
68 |
-
|
69 |
-
num_processed_samples = 0
|
70 |
-
with torch.inference_mode():
|
71 |
-
for image, target in metric_logger.log_every(data_loader, print_freq, header):
|
72 |
-
image = image.to(device, non_blocking=True)
|
73 |
-
target = target.to(device, non_blocking=True)
|
74 |
-
output = model(image)
|
75 |
-
loss = criterion(output, target)
|
76 |
-
|
77 |
-
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
|
78 |
-
# FIXME need to take into account that the datasets
|
79 |
-
# could have been padded in distributed setup
|
80 |
-
batch_size = image.shape[0]
|
81 |
-
metric_logger.update(loss=loss.item())
|
82 |
-
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
|
83 |
-
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
|
84 |
-
num_processed_samples += batch_size
|
85 |
-
# gather the stats from all processes
|
86 |
-
|
87 |
-
num_processed_samples = utils.reduce_across_processes(num_processed_samples)
|
88 |
-
if (
|
89 |
-
hasattr(data_loader.dataset, "__len__")
|
90 |
-
and len(data_loader.dataset) != num_processed_samples
|
91 |
-
and torch.distributed.get_rank() == 0
|
92 |
-
):
|
93 |
-
# See FIXME above
|
94 |
-
warnings.warn(
|
95 |
-
f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
|
96 |
-
"samples were used for the validation, which might bias the results. "
|
97 |
-
"Try adjusting the batch size and / or the world size. "
|
98 |
-
"Setting the world size to 1 is always a safe bet."
|
99 |
-
)
|
100 |
-
|
101 |
-
metric_logger.synchronize_between_processes()
|
102 |
-
|
103 |
-
print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
|
104 |
-
return metric_logger.acc1.global_avg
|
105 |
-
|
106 |
-
|
107 |
-
def _get_cache_path(filepath):
|
108 |
-
import hashlib
|
109 |
-
|
110 |
-
h = hashlib.sha1(filepath.encode()).hexdigest()
|
111 |
-
cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
|
112 |
-
cache_path = os.path.expanduser(cache_path)
|
113 |
-
return cache_path
|
114 |
-
|
115 |
-
|
116 |
-
def load_data(traindir, valdir, args):
|
117 |
-
# Data loading code
|
118 |
-
print("Loading data")
|
119 |
-
val_resize_size, val_crop_size, train_crop_size = (
|
120 |
-
args.val_resize_size,
|
121 |
-
args.val_crop_size,
|
122 |
-
args.train_crop_size,
|
123 |
-
)
|
124 |
-
interpolation = InterpolationMode(args.interpolation)
|
125 |
-
|
126 |
-
print("Loading training data")
|
127 |
-
st = time.time()
|
128 |
-
cache_path = _get_cache_path(traindir)
|
129 |
-
if args.cache_dataset and os.path.exists(cache_path):
|
130 |
-
# Attention, as the transforms are also cached!
|
131 |
-
print(f"Loading dataset_train from {cache_path}")
|
132 |
-
dataset, _ = torch.load(cache_path)
|
133 |
-
else:
|
134 |
-
auto_augment_policy = getattr(args, "auto_augment", None)
|
135 |
-
random_erase_prob = getattr(args, "random_erase", 0.0)
|
136 |
-
ra_magnitude = args.ra_magnitude
|
137 |
-
augmix_severity = args.augmix_severity
|
138 |
-
dataset = torchvision.datasets.ImageFolder(
|
139 |
-
traindir,
|
140 |
-
presets.ClassificationPresetTrain(
|
141 |
-
crop_size=train_crop_size,
|
142 |
-
interpolation=interpolation,
|
143 |
-
auto_augment_policy=auto_augment_policy,
|
144 |
-
random_erase_prob=random_erase_prob,
|
145 |
-
ra_magnitude=ra_magnitude,
|
146 |
-
augmix_severity=augmix_severity,
|
147 |
-
),
|
148 |
-
)
|
149 |
-
if args.cache_dataset:
|
150 |
-
print(f"Saving dataset_train to {cache_path}")
|
151 |
-
utils.mkdir(os.path.dirname(cache_path))
|
152 |
-
utils.save_on_master((dataset, traindir), cache_path)
|
153 |
-
print("Took", time.time() - st)
|
154 |
-
|
155 |
-
print("Loading validation data")
|
156 |
-
cache_path = _get_cache_path(valdir)
|
157 |
-
if args.cache_dataset and os.path.exists(cache_path):
|
158 |
-
# Attention, as the transforms are also cached!
|
159 |
-
print(f"Loading dataset_test from {cache_path}")
|
160 |
-
dataset_test, _ = torch.load(cache_path)
|
161 |
-
else:
|
162 |
-
if args.weights and args.test_only:
|
163 |
-
weights = torchvision.models.get_weight(args.weights)
|
164 |
-
preprocessing = weights.transforms()
|
165 |
-
else:
|
166 |
-
preprocessing = presets.ClassificationPresetEval(
|
167 |
-
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
|
168 |
-
)
|
169 |
-
|
170 |
-
dataset_test = torchvision.datasets.ImageFolder(
|
171 |
-
valdir,
|
172 |
-
preprocessing,
|
173 |
-
)
|
174 |
-
if args.cache_dataset:
|
175 |
-
print(f"Saving dataset_test to {cache_path}")
|
176 |
-
utils.mkdir(os.path.dirname(cache_path))
|
177 |
-
utils.save_on_master((dataset_test, valdir), cache_path)
|
178 |
-
|
179 |
-
print("Creating data loaders")
|
180 |
-
if args.distributed:
|
181 |
-
if hasattr(args, "ra_sampler") and args.ra_sampler:
|
182 |
-
train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
|
183 |
-
else:
|
184 |
-
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
185 |
-
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
|
186 |
-
else:
|
187 |
-
train_sampler = torch.utils.data.RandomSampler(dataset)
|
188 |
-
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
|
189 |
-
|
190 |
-
return dataset, dataset_test, train_sampler, test_sampler
|
191 |
-
|
192 |
-
|
193 |
-
def main(args):
|
194 |
-
if args.output_dir:
|
195 |
-
utils.mkdir(args.output_dir)
|
196 |
-
|
197 |
-
utils.init_distributed_mode(args)
|
198 |
-
print(args)
|
199 |
-
|
200 |
-
device = torch.device(args.device)
|
201 |
-
|
202 |
-
if args.use_deterministic_algorithms:
|
203 |
-
torch.backends.cudnn.benchmark = False
|
204 |
-
torch.use_deterministic_algorithms(True)
|
205 |
-
else:
|
206 |
-
torch.backends.cudnn.benchmark = True
|
207 |
-
|
208 |
-
train_dir = os.path.join(args.data_path, "train")
|
209 |
-
val_dir = os.path.join(args.data_path, "val")
|
210 |
-
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
|
211 |
-
|
212 |
-
collate_fn = None
|
213 |
-
num_classes = len(dataset.classes)
|
214 |
-
mixup_transforms = []
|
215 |
-
if args.mixup_alpha > 0.0:
|
216 |
-
mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
|
217 |
-
if args.cutmix_alpha > 0.0:
|
218 |
-
mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
|
219 |
-
if mixup_transforms:
|
220 |
-
mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
|
221 |
-
|
222 |
-
def collate_fn(batch):
|
223 |
-
return mixupcutmix(*default_collate(batch))
|
224 |
-
|
225 |
-
data_loader = torch.utils.data.DataLoader(
|
226 |
-
dataset,
|
227 |
-
batch_size=args.batch_size,
|
228 |
-
sampler=train_sampler,
|
229 |
-
num_workers=args.workers,
|
230 |
-
pin_memory=True,
|
231 |
-
collate_fn=collate_fn,
|
232 |
-
)
|
233 |
-
data_loader_test = torch.utils.data.DataLoader(
|
234 |
-
dataset_test, batch_size=8, sampler=test_sampler, num_workers=args.workers, pin_memory=True
|
235 |
-
)
|
236 |
-
|
237 |
-
print("Creating model")
|
238 |
-
model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
|
239 |
-
if args.apply_trp:
|
240 |
-
model = apply_trp(model, args.trp_depths, args.trp_planes, args.trp_lambdas, label_smoothing=args.label_smoothing)
|
241 |
-
model.to(device)
|
242 |
-
|
243 |
-
if args.distributed and args.sync_bn:
|
244 |
-
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
245 |
-
|
246 |
-
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
|
247 |
-
|
248 |
-
custom_keys_weight_decay = []
|
249 |
-
if args.bias_weight_decay is not None:
|
250 |
-
custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
|
251 |
-
if args.transformer_embedding_decay is not None:
|
252 |
-
for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
|
253 |
-
custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
|
254 |
-
parameters = utils.set_weight_decay(
|
255 |
-
model,
|
256 |
-
args.weight_decay,
|
257 |
-
norm_weight_decay=args.norm_weight_decay,
|
258 |
-
custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
|
259 |
-
)
|
260 |
-
|
261 |
-
opt_name = args.opt.lower()
|
262 |
-
if opt_name.startswith("sgd"):
|
263 |
-
optimizer = torch.optim.SGD(
|
264 |
-
parameters,
|
265 |
-
lr=args.lr,
|
266 |
-
momentum=args.momentum,
|
267 |
-
weight_decay=args.weight_decay,
|
268 |
-
nesterov="nesterov" in opt_name,
|
269 |
-
)
|
270 |
-
elif opt_name == "rmsprop":
|
271 |
-
optimizer = torch.optim.RMSprop(
|
272 |
-
parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
|
273 |
-
)
|
274 |
-
elif opt_name == "adamw":
|
275 |
-
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
|
276 |
-
else:
|
277 |
-
raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
|
278 |
-
|
279 |
-
scaler = torch.amp.GradScaler("cuda") if args.amp else None
|
280 |
-
|
281 |
-
args.lr_scheduler = args.lr_scheduler.lower()
|
282 |
-
if args.lr_scheduler == "steplr":
|
283 |
-
main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
|
284 |
-
elif args.lr_scheduler == "cosineannealinglr":
|
285 |
-
main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
286 |
-
optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
|
287 |
-
)
|
288 |
-
elif args.lr_scheduler == "exponentiallr":
|
289 |
-
main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
|
290 |
-
else:
|
291 |
-
raise RuntimeError(
|
292 |
-
f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
|
293 |
-
"are supported."
|
294 |
-
)
|
295 |
-
|
296 |
-
if args.lr_warmup_epochs > 0:
|
297 |
-
if args.lr_warmup_method == "linear":
|
298 |
-
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
|
299 |
-
optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
|
300 |
-
)
|
301 |
-
elif args.lr_warmup_method == "constant":
|
302 |
-
warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
|
303 |
-
optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
|
304 |
-
)
|
305 |
-
else:
|
306 |
-
raise RuntimeError(
|
307 |
-
f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
|
308 |
-
)
|
309 |
-
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
|
310 |
-
optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
|
311 |
-
)
|
312 |
-
else:
|
313 |
-
lr_scheduler = main_lr_scheduler
|
314 |
-
|
315 |
-
model_without_ddp = model
|
316 |
-
if args.distributed:
|
317 |
-
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
318 |
-
model_without_ddp = model.module
|
319 |
-
|
320 |
-
model_ema = None
|
321 |
-
if args.model_ema:
|
322 |
-
# Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
|
323 |
-
# https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
|
324 |
-
#
|
325 |
-
# total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
|
326 |
-
# We consider constant = Dataset_size for a given dataset/setup and ommit it. Thus:
|
327 |
-
# adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
|
328 |
-
adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
|
329 |
-
alpha = 1.0 - args.model_ema_decay
|
330 |
-
alpha = min(1.0, alpha * adjust)
|
331 |
-
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
|
332 |
-
|
333 |
-
if args.resume:
|
334 |
-
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=False)
|
335 |
-
model_without_ddp.load_state_dict(checkpoint["model"])
|
336 |
-
if not args.test_only:
|
337 |
-
optimizer.load_state_dict(checkpoint["optimizer"])
|
338 |
-
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
|
339 |
-
args.start_epoch = checkpoint["epoch"] + 1
|
340 |
-
if model_ema:
|
341 |
-
model_ema.load_state_dict(checkpoint["model_ema"])
|
342 |
-
if scaler:
|
343 |
-
scaler.load_state_dict(checkpoint["scaler"])
|
344 |
-
|
345 |
-
if args.test_only:
|
346 |
-
# We disable the cudnn benchmarking because it can noticeably affect the accuracy
|
347 |
-
torch.backends.cudnn.benchmark = False
|
348 |
-
torch.backends.cudnn.deterministic = True
|
349 |
-
if model_ema:
|
350 |
-
evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
|
351 |
-
else:
|
352 |
-
evaluate(model, criterion, data_loader_test, device=device)
|
353 |
-
return
|
354 |
-
|
355 |
-
print("Start training")
|
356 |
-
start_time = time.time()
|
357 |
-
for epoch in range(args.start_epoch, args.epochs):
|
358 |
-
if args.distributed:
|
359 |
-
train_sampler.set_epoch(epoch)
|
360 |
-
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
|
361 |
-
lr_scheduler.step()
|
362 |
-
evaluate(model, criterion, data_loader_test, device=device)
|
363 |
-
if model_ema:
|
364 |
-
evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
|
365 |
-
if args.output_dir:
|
366 |
-
checkpoint = {
|
367 |
-
"model": model_without_ddp.state_dict() if not args.apply_trp else {k: v for k, v in model_without_ddp.state_dict().items() if not k.startswith("trp_blocks")}, # NOTE: remove TRP heads
|
368 |
-
"optimizer": optimizer.state_dict(),
|
369 |
-
"lr_scheduler": lr_scheduler.state_dict(),
|
370 |
-
"epoch": epoch,
|
371 |
-
"args": args,
|
372 |
-
}
|
373 |
-
if model_ema:
|
374 |
-
checkpoint["model_ema"] = model_ema.state_dict() if not args.apply_trp else {k: v for k, v in model_ema.state_dict().items() if not k.startswith("trp_blocks")} # NOTE: remove TRP heads
|
375 |
-
if scaler:
|
376 |
-
checkpoint["scaler"] = scaler.state_dict()
|
377 |
-
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
|
378 |
-
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
|
379 |
-
|
380 |
-
total_time = time.time() - start_time
|
381 |
-
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
382 |
-
print(f"Training time {total_time_str}")
|
383 |
-
|
384 |
-
|
385 |
-
def get_args_parser(add_help=True):
|
386 |
-
import argparse
|
387 |
-
|
388 |
-
parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
|
389 |
-
|
390 |
-
parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
|
391 |
-
parser.add_argument("--model", default="resnet18", type=str, help="model name")
|
392 |
-
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
|
393 |
-
parser.add_argument(
|
394 |
-
"-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
|
395 |
-
)
|
396 |
-
parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
|
397 |
-
parser.add_argument(
|
398 |
-
"-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
|
399 |
-
)
|
400 |
-
parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
|
401 |
-
parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
|
402 |
-
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
|
403 |
-
parser.add_argument(
|
404 |
-
"--wd",
|
405 |
-
"--weight-decay",
|
406 |
-
default=1e-4,
|
407 |
-
type=float,
|
408 |
-
metavar="W",
|
409 |
-
help="weight decay (default: 1e-4)",
|
410 |
-
dest="weight_decay",
|
411 |
-
)
|
412 |
-
parser.add_argument(
|
413 |
-
"--norm-weight-decay",
|
414 |
-
default=None,
|
415 |
-
type=float,
|
416 |
-
help="weight decay for Normalization layers (default: None, same value as --wd)",
|
417 |
-
)
|
418 |
-
parser.add_argument(
|
419 |
-
"--bias-weight-decay",
|
420 |
-
default=None,
|
421 |
-
type=float,
|
422 |
-
help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
|
423 |
-
)
|
424 |
-
parser.add_argument(
|
425 |
-
"--transformer-embedding-decay",
|
426 |
-
default=None,
|
427 |
-
type=float,
|
428 |
-
help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
|
429 |
-
)
|
430 |
-
parser.add_argument(
|
431 |
-
"--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
|
432 |
-
)
|
433 |
-
parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
|
434 |
-
parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
|
435 |
-
parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
|
436 |
-
parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
|
437 |
-
parser.add_argument(
|
438 |
-
"--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
|
439 |
-
)
|
440 |
-
parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
|
441 |
-
parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
|
442 |
-
parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
|
443 |
-
parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
|
444 |
-
parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
|
445 |
-
parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
|
446 |
-
parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
|
447 |
-
parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
|
448 |
-
parser.add_argument(
|
449 |
-
"--cache-dataset",
|
450 |
-
dest="cache_dataset",
|
451 |
-
help="Cache the datasets for quicker initialization. It also serializes the transforms",
|
452 |
-
action="store_true",
|
453 |
-
)
|
454 |
-
parser.add_argument(
|
455 |
-
"--sync-bn",
|
456 |
-
dest="sync_bn",
|
457 |
-
help="Use sync batch norm",
|
458 |
-
action="store_true",
|
459 |
-
)
|
460 |
-
parser.add_argument(
|
461 |
-
"--test-only",
|
462 |
-
dest="test_only",
|
463 |
-
help="Only test the model",
|
464 |
-
action="store_true",
|
465 |
-
)
|
466 |
-
parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
|
467 |
-
parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
|
468 |
-
parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
|
469 |
-
parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
|
470 |
-
|
471 |
-
# Mixed precision training parameters
|
472 |
-
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
|
473 |
-
|
474 |
-
# distributed training parameters
|
475 |
-
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
|
476 |
-
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
|
477 |
-
parser.add_argument(
|
478 |
-
"--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
|
479 |
-
)
|
480 |
-
parser.add_argument(
|
481 |
-
"--model-ema-steps",
|
482 |
-
type=int,
|
483 |
-
default=32,
|
484 |
-
help="the number of iterations that controls how often to update the EMA model (default: 32)",
|
485 |
-
)
|
486 |
-
parser.add_argument(
|
487 |
-
"--model-ema-decay",
|
488 |
-
type=float,
|
489 |
-
default=0.99998,
|
490 |
-
help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
|
491 |
-
)
|
492 |
-
parser.add_argument(
|
493 |
-
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
|
494 |
-
)
|
495 |
-
parser.add_argument(
|
496 |
-
"--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
|
497 |
-
)
|
498 |
-
parser.add_argument(
|
499 |
-
"--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
|
500 |
-
)
|
501 |
-
parser.add_argument(
|
502 |
-
"--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
|
503 |
-
)
|
504 |
-
parser.add_argument(
|
505 |
-
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
|
506 |
-
)
|
507 |
-
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
|
508 |
-
parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
|
509 |
-
parser.add_argument(
|
510 |
-
"--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
|
511 |
-
)
|
512 |
-
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
|
513 |
-
|
514 |
-
parser.add_argument("--apply-trp", action="store_true", help="enable applying trp")
|
515 |
-
parser.add_argument("--trp-depths", nargs="+", type=int, help="number of layers for trp block")
|
516 |
-
parser.add_argument("--trp-planes", default=1024, type=int, help="channels of the hidden state")
|
517 |
-
parser.add_argument("--trp-lambdas", nargs="+", type=float, help="trp lambdas")
|
518 |
-
|
519 |
-
return parser
|
520 |
-
|
521 |
-
|
522 |
-
if __name__ == "__main__":
|
523 |
-
args = get_args_parser().parse_args()
|
524 |
-
main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
neural-archicture-search/train_quantization.py
DELETED
@@ -1,265 +0,0 @@
|
|
1 |
-
import copy
|
2 |
-
import datetime
|
3 |
-
import os
|
4 |
-
import time
|
5 |
-
|
6 |
-
import torch
|
7 |
-
import torch.ao.quantization
|
8 |
-
import torch.utils.data
|
9 |
-
import torchvision
|
10 |
-
import utils
|
11 |
-
from torch import nn
|
12 |
-
from train import evaluate, load_data, train_one_epoch
|
13 |
-
|
14 |
-
|
15 |
-
def main(args):
|
16 |
-
if args.output_dir:
|
17 |
-
utils.mkdir(args.output_dir)
|
18 |
-
|
19 |
-
utils.init_distributed_mode(args)
|
20 |
-
print(args)
|
21 |
-
|
22 |
-
if args.post_training_quantize and args.distributed:
|
23 |
-
raise RuntimeError("Post training quantization example should not be performed on distributed mode")
|
24 |
-
|
25 |
-
# Set backend engine to ensure that quantized model runs on the correct kernels
|
26 |
-
if args.backend not in torch.backends.quantized.supported_engines:
|
27 |
-
raise RuntimeError("Quantized backend not supported: " + str(args.backend))
|
28 |
-
torch.backends.quantized.engine = args.backend
|
29 |
-
|
30 |
-
device = torch.device(args.device)
|
31 |
-
torch.backends.cudnn.benchmark = True
|
32 |
-
|
33 |
-
# Data loading code
|
34 |
-
print("Loading data")
|
35 |
-
train_dir = os.path.join(args.data_path, "train")
|
36 |
-
val_dir = os.path.join(args.data_path, "val")
|
37 |
-
|
38 |
-
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
|
39 |
-
data_loader = torch.utils.data.DataLoader(
|
40 |
-
dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
|
41 |
-
)
|
42 |
-
|
43 |
-
data_loader_test = torch.utils.data.DataLoader(
|
44 |
-
dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
|
45 |
-
)
|
46 |
-
|
47 |
-
print("Creating model", args.model)
|
48 |
-
# when training quantized models, we always start from a pre-trained fp32 reference model
|
49 |
-
prefix = "quantized_"
|
50 |
-
model_name = args.model
|
51 |
-
if not model_name.startswith(prefix):
|
52 |
-
model_name = prefix + model_name
|
53 |
-
model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
|
54 |
-
model.to(device)
|
55 |
-
|
56 |
-
if not (args.test_only or args.post_training_quantize):
|
57 |
-
model.fuse_model(is_qat=True)
|
58 |
-
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
|
59 |
-
torch.ao.quantization.prepare_qat(model, inplace=True)
|
60 |
-
|
61 |
-
if args.distributed and args.sync_bn:
|
62 |
-
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
63 |
-
|
64 |
-
optimizer = torch.optim.SGD(
|
65 |
-
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
|
66 |
-
)
|
67 |
-
|
68 |
-
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
|
69 |
-
|
70 |
-
criterion = nn.CrossEntropyLoss()
|
71 |
-
model_without_ddp = model
|
72 |
-
if args.distributed:
|
73 |
-
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
74 |
-
model_without_ddp = model.module
|
75 |
-
|
76 |
-
if args.resume:
|
77 |
-
checkpoint = torch.load(args.resume, map_location="cpu")
|
78 |
-
model_without_ddp.load_state_dict(checkpoint["model"])
|
79 |
-
optimizer.load_state_dict(checkpoint["optimizer"])
|
80 |
-
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
|
81 |
-
args.start_epoch = checkpoint["epoch"] + 1
|
82 |
-
|
83 |
-
if args.post_training_quantize:
|
84 |
-
# perform calibration on a subset of the training dataset
|
85 |
-
# for that, create a subset of the training dataset
|
86 |
-
ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
|
87 |
-
data_loader_calibration = torch.utils.data.DataLoader(
|
88 |
-
ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
|
89 |
-
)
|
90 |
-
model.eval()
|
91 |
-
model.fuse_model(is_qat=False)
|
92 |
-
model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
|
93 |
-
torch.ao.quantization.prepare(model, inplace=True)
|
94 |
-
# Calibrate first
|
95 |
-
print("Calibrating")
|
96 |
-
evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
|
97 |
-
torch.ao.quantization.convert(model, inplace=True)
|
98 |
-
if args.output_dir:
|
99 |
-
print("Saving quantized model")
|
100 |
-
if utils.is_main_process():
|
101 |
-
torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
|
102 |
-
print("Evaluating post-training quantized model")
|
103 |
-
evaluate(model, criterion, data_loader_test, device=device)
|
104 |
-
return
|
105 |
-
|
106 |
-
if args.test_only:
|
107 |
-
evaluate(model, criterion, data_loader_test, device=device)
|
108 |
-
return
|
109 |
-
|
110 |
-
model.apply(torch.ao.quantization.enable_observer)
|
111 |
-
model.apply(torch.ao.quantization.enable_fake_quant)
|
112 |
-
start_time = time.time()
|
113 |
-
for epoch in range(args.start_epoch, args.epochs):
|
114 |
-
if args.distributed:
|
115 |
-
train_sampler.set_epoch(epoch)
|
116 |
-
print("Starting training for epoch", epoch)
|
117 |
-
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
|
118 |
-
lr_scheduler.step()
|
119 |
-
with torch.inference_mode():
|
120 |
-
if epoch >= args.num_observer_update_epochs:
|
121 |
-
print("Disabling observer for subseq epochs, epoch = ", epoch)
|
122 |
-
model.apply(torch.ao.quantization.disable_observer)
|
123 |
-
if epoch >= args.num_batch_norm_update_epochs:
|
124 |
-
print("Freezing BN for subseq epochs, epoch = ", epoch)
|
125 |
-
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
|
126 |
-
print("Evaluate QAT model")
|
127 |
-
|
128 |
-
evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
|
129 |
-
quantized_eval_model = copy.deepcopy(model_without_ddp)
|
130 |
-
quantized_eval_model.eval()
|
131 |
-
quantized_eval_model.to(torch.device("cpu"))
|
132 |
-
torch.ao.quantization.convert(quantized_eval_model, inplace=True)
|
133 |
-
|
134 |
-
print("Evaluate Quantized model")
|
135 |
-
evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
|
136 |
-
|
137 |
-
model.train()
|
138 |
-
|
139 |
-
if args.output_dir:
|
140 |
-
checkpoint = {
|
141 |
-
"model": model_without_ddp.state_dict(),
|
142 |
-
"eval_model": quantized_eval_model.state_dict(),
|
143 |
-
"optimizer": optimizer.state_dict(),
|
144 |
-
"lr_scheduler": lr_scheduler.state_dict(),
|
145 |
-
"epoch": epoch,
|
146 |
-
"args": args,
|
147 |
-
}
|
148 |
-
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
|
149 |
-
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
|
150 |
-
print("Saving models after epoch ", epoch)
|
151 |
-
|
152 |
-
total_time = time.time() - start_time
|
153 |
-
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
154 |
-
print(f"Training time {total_time_str}")
|
155 |
-
|
156 |
-
|
157 |
-
def get_args_parser(add_help=True):
|
158 |
-
import argparse
|
159 |
-
|
160 |
-
parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)
|
161 |
-
|
162 |
-
parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
|
163 |
-
parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
|
164 |
-
parser.add_argument("--backend", default="qnnpack", type=str, help="fbgemm or qnnpack")
|
165 |
-
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
|
166 |
-
|
167 |
-
parser.add_argument(
|
168 |
-
"-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
|
169 |
-
)
|
170 |
-
parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
|
171 |
-
parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
|
172 |
-
parser.add_argument(
|
173 |
-
"--num-observer-update-epochs",
|
174 |
-
default=4,
|
175 |
-
type=int,
|
176 |
-
metavar="N",
|
177 |
-
help="number of total epochs to update observers",
|
178 |
-
)
|
179 |
-
parser.add_argument(
|
180 |
-
"--num-batch-norm-update-epochs",
|
181 |
-
default=3,
|
182 |
-
type=int,
|
183 |
-
metavar="N",
|
184 |
-
help="number of total epochs to update batch norm stats",
|
185 |
-
)
|
186 |
-
parser.add_argument(
|
187 |
-
"--num-calibration-batches",
|
188 |
-
default=32,
|
189 |
-
type=int,
|
190 |
-
metavar="N",
|
191 |
-
help="number of batches of training set for \
|
192 |
-
observer calibration ",
|
193 |
-
)
|
194 |
-
|
195 |
-
parser.add_argument(
|
196 |
-
"-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
|
197 |
-
)
|
198 |
-
parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
|
199 |
-
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
|
200 |
-
parser.add_argument(
|
201 |
-
"--wd",
|
202 |
-
"--weight-decay",
|
203 |
-
default=1e-4,
|
204 |
-
type=float,
|
205 |
-
metavar="W",
|
206 |
-
help="weight decay (default: 1e-4)",
|
207 |
-
dest="weight_decay",
|
208 |
-
)
|
209 |
-
parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
|
210 |
-
parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
|
211 |
-
parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
|
212 |
-
parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
|
213 |
-
parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
|
214 |
-
parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
|
215 |
-
parser.add_argument(
|
216 |
-
"--cache-dataset",
|
217 |
-
dest="cache_dataset",
|
218 |
-
help="Cache the datasets for quicker initialization. \
|
219 |
-
It also serializes the transforms",
|
220 |
-
action="store_true",
|
221 |
-
)
|
222 |
-
parser.add_argument(
|
223 |
-
"--sync-bn",
|
224 |
-
dest="sync_bn",
|
225 |
-
help="Use sync batch norm",
|
226 |
-
action="store_true",
|
227 |
-
)
|
228 |
-
parser.add_argument(
|
229 |
-
"--test-only",
|
230 |
-
dest="test_only",
|
231 |
-
help="Only test the model",
|
232 |
-
action="store_true",
|
233 |
-
)
|
234 |
-
parser.add_argument(
|
235 |
-
"--post-training-quantize",
|
236 |
-
dest="post_training_quantize",
|
237 |
-
help="Post training quantize the model",
|
238 |
-
action="store_true",
|
239 |
-
)
|
240 |
-
|
241 |
-
# distributed training parameters
|
242 |
-
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
|
243 |
-
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
|
244 |
-
|
245 |
-
parser.add_argument(
|
246 |
-
"--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
|
247 |
-
)
|
248 |
-
parser.add_argument(
|
249 |
-
"--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
|
250 |
-
)
|
251 |
-
parser.add_argument(
|
252 |
-
"--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
|
253 |
-
)
|
254 |
-
parser.add_argument(
|
255 |
-
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
|
256 |
-
)
|
257 |
-
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
|
258 |
-
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
|
259 |
-
|
260 |
-
return parser
|
261 |
-
|
262 |
-
|
263 |
-
if __name__ == "__main__":
|
264 |
-
args = get_args_parser().parse_args()
|
265 |
-
main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
neural-archicture-search/transforms.py
DELETED
@@ -1,183 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
from typing import Tuple
|
3 |
-
|
4 |
-
import torch
|
5 |
-
from torch import Tensor
|
6 |
-
from torchvision.transforms import functional as F
|
7 |
-
|
8 |
-
|
9 |
-
class RandomMixup(torch.nn.Module):
|
10 |
-
"""Randomly apply Mixup to the provided batch and targets.
|
11 |
-
The class implements the data augmentations as described in the paper
|
12 |
-
`"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
|
13 |
-
|
14 |
-
Args:
|
15 |
-
num_classes (int): number of classes used for one-hot encoding.
|
16 |
-
p (float): probability of the batch being transformed. Default value is 0.5.
|
17 |
-
alpha (float): hyperparameter of the Beta distribution used for mixup.
|
18 |
-
Default value is 1.0.
|
19 |
-
inplace (bool): boolean to make this transform inplace. Default set to False.
|
20 |
-
"""
|
21 |
-
|
22 |
-
def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
|
23 |
-
super().__init__()
|
24 |
-
|
25 |
-
if num_classes < 1:
|
26 |
-
raise ValueError(
|
27 |
-
f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
|
28 |
-
)
|
29 |
-
|
30 |
-
if alpha <= 0:
|
31 |
-
raise ValueError("Alpha param can't be zero.")
|
32 |
-
|
33 |
-
self.num_classes = num_classes
|
34 |
-
self.p = p
|
35 |
-
self.alpha = alpha
|
36 |
-
self.inplace = inplace
|
37 |
-
|
38 |
-
def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
|
39 |
-
"""
|
40 |
-
Args:
|
41 |
-
batch (Tensor): Float tensor of size (B, C, H, W)
|
42 |
-
target (Tensor): Integer tensor of size (B, )
|
43 |
-
|
44 |
-
Returns:
|
45 |
-
Tensor: Randomly transformed batch.
|
46 |
-
"""
|
47 |
-
if batch.ndim != 4:
|
48 |
-
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
|
49 |
-
if target.ndim != 1:
|
50 |
-
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
|
51 |
-
if not batch.is_floating_point():
|
52 |
-
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
|
53 |
-
if target.dtype != torch.int64:
|
54 |
-
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
|
55 |
-
|
56 |
-
if not self.inplace:
|
57 |
-
batch = batch.clone()
|
58 |
-
target = target.clone()
|
59 |
-
|
60 |
-
if target.ndim == 1:
|
61 |
-
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
|
62 |
-
|
63 |
-
if torch.rand(1).item() >= self.p:
|
64 |
-
return batch, target
|
65 |
-
|
66 |
-
# It's faster to roll the batch by one instead of shuffling it to create image pairs
|
67 |
-
batch_rolled = batch.roll(1, 0)
|
68 |
-
target_rolled = target.roll(1, 0)
|
69 |
-
|
70 |
-
# Implemented as on mixup paper, page 3.
|
71 |
-
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
|
72 |
-
batch_rolled.mul_(1.0 - lambda_param)
|
73 |
-
batch.mul_(lambda_param).add_(batch_rolled)
|
74 |
-
|
75 |
-
target_rolled.mul_(1.0 - lambda_param)
|
76 |
-
target.mul_(lambda_param).add_(target_rolled)
|
77 |
-
|
78 |
-
return batch, target
|
79 |
-
|
80 |
-
def __repr__(self) -> str:
|
81 |
-
s = (
|
82 |
-
f"{self.__class__.__name__}("
|
83 |
-
f"num_classes={self.num_classes}"
|
84 |
-
f", p={self.p}"
|
85 |
-
f", alpha={self.alpha}"
|
86 |
-
f", inplace={self.inplace}"
|
87 |
-
f")"
|
88 |
-
)
|
89 |
-
return s
|
90 |
-
|
91 |
-
|
92 |
-
class RandomCutmix(torch.nn.Module):
|
93 |
-
"""Randomly apply Cutmix to the provided batch and targets.
|
94 |
-
The class implements the data augmentations as described in the paper
|
95 |
-
`"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
|
96 |
-
<https://arxiv.org/abs/1905.04899>`_.
|
97 |
-
|
98 |
-
Args:
|
99 |
-
num_classes (int): number of classes used for one-hot encoding.
|
100 |
-
p (float): probability of the batch being transformed. Default value is 0.5.
|
101 |
-
alpha (float): hyperparameter of the Beta distribution used for cutmix.
|
102 |
-
Default value is 1.0.
|
103 |
-
inplace (bool): boolean to make this transform inplace. Default set to False.
|
104 |
-
"""
|
105 |
-
|
106 |
-
def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
|
107 |
-
super().__init__()
|
108 |
-
if num_classes < 1:
|
109 |
-
raise ValueError("Please provide a valid positive value for the num_classes.")
|
110 |
-
if alpha <= 0:
|
111 |
-
raise ValueError("Alpha param can't be zero.")
|
112 |
-
|
113 |
-
self.num_classes = num_classes
|
114 |
-
self.p = p
|
115 |
-
self.alpha = alpha
|
116 |
-
self.inplace = inplace
|
117 |
-
|
118 |
-
def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
|
119 |
-
"""
|
120 |
-
Args:
|
121 |
-
batch (Tensor): Float tensor of size (B, C, H, W)
|
122 |
-
target (Tensor): Integer tensor of size (B, )
|
123 |
-
|
124 |
-
Returns:
|
125 |
-
Tensor: Randomly transformed batch.
|
126 |
-
"""
|
127 |
-
if batch.ndim != 4:
|
128 |
-
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
|
129 |
-
if target.ndim != 1:
|
130 |
-
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
|
131 |
-
if not batch.is_floating_point():
|
132 |
-
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
|
133 |
-
if target.dtype != torch.int64:
|
134 |
-
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
|
135 |
-
|
136 |
-
if not self.inplace:
|
137 |
-
batch = batch.clone()
|
138 |
-
target = target.clone()
|
139 |
-
|
140 |
-
if target.ndim == 1:
|
141 |
-
target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
|
142 |
-
|
143 |
-
if torch.rand(1).item() >= self.p:
|
144 |
-
return batch, target
|
145 |
-
|
146 |
-
# It's faster to roll the batch by one instead of shuffling it to create image pairs
|
147 |
-
batch_rolled = batch.roll(1, 0)
|
148 |
-
target_rolled = target.roll(1, 0)
|
149 |
-
|
150 |
-
# Implemented as on cutmix paper, page 12 (with minor corrections on typos).
|
151 |
-
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
|
152 |
-
_, H, W = F.get_dimensions(batch)
|
153 |
-
|
154 |
-
r_x = torch.randint(W, (1,))
|
155 |
-
r_y = torch.randint(H, (1,))
|
156 |
-
|
157 |
-
r = 0.5 * math.sqrt(1.0 - lambda_param)
|
158 |
-
r_w_half = int(r * W)
|
159 |
-
r_h_half = int(r * H)
|
160 |
-
|
161 |
-
x1 = int(torch.clamp(r_x - r_w_half, min=0))
|
162 |
-
y1 = int(torch.clamp(r_y - r_h_half, min=0))
|
163 |
-
x2 = int(torch.clamp(r_x + r_w_half, max=W))
|
164 |
-
y2 = int(torch.clamp(r_y + r_h_half, max=H))
|
165 |
-
|
166 |
-
batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
|
167 |
-
lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
|
168 |
-
|
169 |
-
target_rolled.mul_(1.0 - lambda_param)
|
170 |
-
target.mul_(lambda_param).add_(target_rolled)
|
171 |
-
|
172 |
-
return batch, target
|
173 |
-
|
174 |
-
def __repr__(self) -> str:
|
175 |
-
s = (
|
176 |
-
f"{self.__class__.__name__}("
|
177 |
-
f"num_classes={self.num_classes}"
|
178 |
-
f", p={self.p}"
|
179 |
-
f", alpha={self.alpha}"
|
180 |
-
f", inplace={self.inplace}"
|
181 |
-
f")"
|
182 |
-
)
|
183 |
-
return s
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
neural-archicture-search/trplib.py
DELETED
@@ -1,127 +0,0 @@
|
|
1 |
-
import types
|
2 |
-
from typing import List, Callable
|
3 |
-
|
4 |
-
import torch
|
5 |
-
from torch import nn, Tensor
|
6 |
-
from torch.nn import functional as F
|
7 |
-
from torchvision.models.resnet import BasicBlock
|
8 |
-
|
9 |
-
|
10 |
-
def trp_criterion(trp_blocks: nn.ModuleList, shared_head: Callable, criterion: Callable, lambdas: List[float], hidden_state: Tensor, logits: Tensor, targets: Tensor, loss_normalization=False):
|
11 |
-
losses, rewards = criterion(logits, targets)
|
12 |
-
returns = torch.ones_like(rewards, dtype=torch.float32, device=rewards.device)
|
13 |
-
if loss_normalization:
|
14 |
-
coeff = torch.mean(losses).detach()
|
15 |
-
|
16 |
-
embeds = [hidden_state]
|
17 |
-
predictions = []
|
18 |
-
for k, w in enumerate(lambdas):
|
19 |
-
embeds.append(trp_blocks[k](embeds[-1]))
|
20 |
-
predictions.append(shared_head(embeds[-1]))
|
21 |
-
returns = returns + w * rewards
|
22 |
-
replica_losses, rewards = criterion(predictions[-1], targets, rewards)
|
23 |
-
losses = losses + replica_losses
|
24 |
-
loss = torch.mean(losses * returns)
|
25 |
-
|
26 |
-
if loss_normalization:
|
27 |
-
with torch.no_grad():
|
28 |
-
coeff = torch.exp(coeff) / torch.exp(loss.detach())
|
29 |
-
loss = coeff * loss
|
30 |
-
|
31 |
-
return loss
|
32 |
-
|
33 |
-
|
34 |
-
class TPBlock(nn.Module):
|
35 |
-
def __init__(self, depths: int, inplanes: int, planes: int):
|
36 |
-
super(TPBlock, self).__init__()
|
37 |
-
|
38 |
-
blocks = [BasicBlock(inplanes=inplanes, planes=planes) for _ in range(depths)]
|
39 |
-
self.blocks = nn.Sequential(*blocks)
|
40 |
-
for name, param in self.blocks.named_parameters():
|
41 |
-
if 'conv' in name:
|
42 |
-
nn.init.zeros_(param) # Initialize weights
|
43 |
-
elif 'downsample' in name:
|
44 |
-
nn.init.zeros_(param) # Initialize biases
|
45 |
-
|
46 |
-
def forward(self, x):
|
47 |
-
return self.blocks(x)
|
48 |
-
|
49 |
-
|
50 |
-
class ResNetConfig:
|
51 |
-
@staticmethod
|
52 |
-
def gen_criterion(label_smoothing=0.0, top_k=1):
|
53 |
-
def func(input, target, mask=None):
|
54 |
-
"""
|
55 |
-
Args:
|
56 |
-
input (Tensor): Input tensor of shape [B, C].
|
57 |
-
target (Tensor): Target labels of shape [B] or [B, C].
|
58 |
-
|
59 |
-
Returns:
|
60 |
-
loss (Tensor): Scalar tensor representing the loss.
|
61 |
-
mask (Tensor): Boolean mask tensor of shape [B].
|
62 |
-
"""
|
63 |
-
label = torch.argmax(target, dim=1) if label_smoothing > 0.0 else target
|
64 |
-
|
65 |
-
unmasked_loss = F.cross_entropy(input, label, reduction="none", label_smoothing=label_smoothing)
|
66 |
-
if mask is None:
|
67 |
-
mask = torch.ones_like(unmasked_loss, dtype=torch.float32, device=target.device)
|
68 |
-
loss = torch.sum(mask * unmasked_loss) / (torch.sum(mask) + 1e-6)
|
69 |
-
|
70 |
-
with torch.no_grad():
|
71 |
-
topk_values, topk_indices = torch.topk(input, top_k, dim=-1)
|
72 |
-
mask = mask * torch.eq(topk_indices, label[:, None]).any(dim=-1).to(input.dtype)
|
73 |
-
|
74 |
-
return loss, mask
|
75 |
-
return func
|
76 |
-
|
77 |
-
@staticmethod
|
78 |
-
def gen_shared_head(self):
|
79 |
-
def func(x):
|
80 |
-
"""
|
81 |
-
Args:
|
82 |
-
x (Tensor): Hidden States tensor of shape [B, C, H, Whidden_units].
|
83 |
-
|
84 |
-
Returns:
|
85 |
-
logits (Tensor): Logits tensor of shape [B, C].
|
86 |
-
"""
|
87 |
-
x = self.layer4(x)
|
88 |
-
x = self.avgpool(x)
|
89 |
-
x = torch.flatten(x, 1)
|
90 |
-
logits = self.fc(x)
|
91 |
-
return logits
|
92 |
-
return func
|
93 |
-
|
94 |
-
@staticmethod
|
95 |
-
def gen_forward(lambdas, loss_normalization=True, label_smoothing=0.0, top_k=1):
|
96 |
-
def func(self, x: Tensor, targets=None) -> Tensor:
|
97 |
-
x = self.conv1(x)
|
98 |
-
x = self.bn1(x)
|
99 |
-
x = self.relu(x)
|
100 |
-
x = self.maxpool(x)
|
101 |
-
|
102 |
-
x = self.layer1(x)
|
103 |
-
x = self.layer2(x)
|
104 |
-
hidden_states = self.layer3(x)
|
105 |
-
x = self.layer4(hidden_states)
|
106 |
-
x = self.avgpool(x)
|
107 |
-
x = torch.flatten(x, 1)
|
108 |
-
logits = self.fc(x)
|
109 |
-
|
110 |
-
if self.training:
|
111 |
-
shared_head = ResNetConfig.gen_shared_head(self)
|
112 |
-
criterion = ResNetConfig.gen_criterion(label_smoothing=label_smoothing, top_k=top_k)
|
113 |
-
|
114 |
-
loss = trp_criterion(self.trp_blocks, shared_head, criterion, lambdas, hidden_states, logits, targets, loss_normalization=loss_normalization)
|
115 |
-
|
116 |
-
return logits, loss
|
117 |
-
|
118 |
-
return logits
|
119 |
-
|
120 |
-
return func
|
121 |
-
|
122 |
-
|
123 |
-
def apply_trp(model, depths: List[int], planes: int, lambdas: List[float], **kwargs):
|
124 |
-
print("✅ Applying TRP to ResNet for Image Classification...")
|
125 |
-
model.trp_blocks = torch.nn.ModuleList([TPBlock(depths=d, inplanes=planes, planes=planes) for d in depths])
|
126 |
-
model.forward = types.MethodType(ResNetConfig.gen_forward(lambdas, True, label_smoothing=kwargs["label_smoothing"], top_k=1), model)
|
127 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
neural-archicture-search/utils.py
DELETED
@@ -1,465 +0,0 @@
|
|
1 |
-
import copy
|
2 |
-
import datetime
|
3 |
-
import errno
|
4 |
-
import hashlib
|
5 |
-
import os
|
6 |
-
import time
|
7 |
-
from collections import defaultdict, deque, OrderedDict
|
8 |
-
from typing import List, Optional, Tuple
|
9 |
-
|
10 |
-
import torch
|
11 |
-
import torch.distributed as dist
|
12 |
-
|
13 |
-
|
14 |
-
class SmoothedValue:
|
15 |
-
"""Track a series of values and provide access to smoothed values over a
|
16 |
-
window or the global series average.
|
17 |
-
"""
|
18 |
-
|
19 |
-
def __init__(self, window_size=20, fmt=None):
|
20 |
-
if fmt is None:
|
21 |
-
fmt = "{median:.4f} ({global_avg:.4f})"
|
22 |
-
self.deque = deque(maxlen=window_size)
|
23 |
-
self.total = 0.0
|
24 |
-
self.count = 0
|
25 |
-
self.fmt = fmt
|
26 |
-
|
27 |
-
def update(self, value, n=1):
|
28 |
-
self.deque.append(value)
|
29 |
-
self.count += n
|
30 |
-
self.total += value * n
|
31 |
-
|
32 |
-
def synchronize_between_processes(self):
|
33 |
-
"""
|
34 |
-
Warning: does not synchronize the deque!
|
35 |
-
"""
|
36 |
-
t = reduce_across_processes([self.count, self.total])
|
37 |
-
t = t.tolist()
|
38 |
-
self.count = int(t[0])
|
39 |
-
self.total = t[1]
|
40 |
-
|
41 |
-
@property
|
42 |
-
def median(self):
|
43 |
-
d = torch.tensor(list(self.deque))
|
44 |
-
return d.median().item()
|
45 |
-
|
46 |
-
@property
|
47 |
-
def avg(self):
|
48 |
-
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
49 |
-
return d.mean().item()
|
50 |
-
|
51 |
-
@property
|
52 |
-
def global_avg(self):
|
53 |
-
return self.total / self.count
|
54 |
-
|
55 |
-
@property
|
56 |
-
def max(self):
|
57 |
-
return max(self.deque)
|
58 |
-
|
59 |
-
@property
|
60 |
-
def value(self):
|
61 |
-
return self.deque[-1]
|
62 |
-
|
63 |
-
def __str__(self):
|
64 |
-
return self.fmt.format(
|
65 |
-
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
66 |
-
)
|
67 |
-
|
68 |
-
|
69 |
-
class MetricLogger:
|
70 |
-
def __init__(self, delimiter="\t"):
|
71 |
-
self.meters = defaultdict(SmoothedValue)
|
72 |
-
self.delimiter = delimiter
|
73 |
-
|
74 |
-
def update(self, **kwargs):
|
75 |
-
for k, v in kwargs.items():
|
76 |
-
if isinstance(v, torch.Tensor):
|
77 |
-
v = v.item()
|
78 |
-
assert isinstance(v, (float, int))
|
79 |
-
self.meters[k].update(v)
|
80 |
-
|
81 |
-
def __getattr__(self, attr):
|
82 |
-
if attr in self.meters:
|
83 |
-
return self.meters[attr]
|
84 |
-
if attr in self.__dict__:
|
85 |
-
return self.__dict__[attr]
|
86 |
-
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
|
87 |
-
|
88 |
-
def __str__(self):
|
89 |
-
loss_str = []
|
90 |
-
for name, meter in self.meters.items():
|
91 |
-
loss_str.append(f"{name}: {str(meter)}")
|
92 |
-
return self.delimiter.join(loss_str)
|
93 |
-
|
94 |
-
def synchronize_between_processes(self):
|
95 |
-
for meter in self.meters.values():
|
96 |
-
meter.synchronize_between_processes()
|
97 |
-
|
98 |
-
def add_meter(self, name, meter):
|
99 |
-
self.meters[name] = meter
|
100 |
-
|
101 |
-
def log_every(self, iterable, print_freq, header=None):
|
102 |
-
i = 0
|
103 |
-
if not header:
|
104 |
-
header = ""
|
105 |
-
start_time = time.time()
|
106 |
-
end = time.time()
|
107 |
-
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
108 |
-
data_time = SmoothedValue(fmt="{avg:.4f}")
|
109 |
-
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
110 |
-
if torch.cuda.is_available():
|
111 |
-
log_msg = self.delimiter.join(
|
112 |
-
[
|
113 |
-
header,
|
114 |
-
"[{0" + space_fmt + "}/{1}]",
|
115 |
-
"eta: {eta}",
|
116 |
-
"{meters}",
|
117 |
-
"time: {time}",
|
118 |
-
"data: {data}",
|
119 |
-
"max mem: {memory:.0f}",
|
120 |
-
]
|
121 |
-
)
|
122 |
-
else:
|
123 |
-
log_msg = self.delimiter.join(
|
124 |
-
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
|
125 |
-
)
|
126 |
-
MB = 1024.0 * 1024.0
|
127 |
-
for obj in iterable:
|
128 |
-
data_time.update(time.time() - end)
|
129 |
-
yield obj
|
130 |
-
iter_time.update(time.time() - end)
|
131 |
-
if i % print_freq == 0:
|
132 |
-
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
133 |
-
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
134 |
-
if torch.cuda.is_available():
|
135 |
-
print(
|
136 |
-
log_msg.format(
|
137 |
-
i,
|
138 |
-
len(iterable),
|
139 |
-
eta=eta_string,
|
140 |
-
meters=str(self),
|
141 |
-
time=str(iter_time),
|
142 |
-
data=str(data_time),
|
143 |
-
memory=torch.cuda.max_memory_allocated() / MB,
|
144 |
-
)
|
145 |
-
)
|
146 |
-
else:
|
147 |
-
print(
|
148 |
-
log_msg.format(
|
149 |
-
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
|
150 |
-
)
|
151 |
-
)
|
152 |
-
i += 1
|
153 |
-
end = time.time()
|
154 |
-
total_time = time.time() - start_time
|
155 |
-
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
156 |
-
print(f"{header} Total time: {total_time_str}")
|
157 |
-
|
158 |
-
|
159 |
-
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
|
160 |
-
"""Maintains moving averages of model parameters using an exponential decay.
|
161 |
-
``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
|
162 |
-
`torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
|
163 |
-
is used to compute the EMA.
|
164 |
-
"""
|
165 |
-
|
166 |
-
def __init__(self, model, decay, device="cpu"):
|
167 |
-
def ema_avg(avg_model_param, model_param, num_averaged):
|
168 |
-
return decay * avg_model_param + (1 - decay) * model_param
|
169 |
-
|
170 |
-
super().__init__(model, device, ema_avg, use_buffers=True)
|
171 |
-
|
172 |
-
|
173 |
-
def accuracy(output, target, topk=(1,)):
|
174 |
-
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
175 |
-
with torch.inference_mode():
|
176 |
-
maxk = max(topk)
|
177 |
-
batch_size = target.size(0)
|
178 |
-
if target.ndim == 2:
|
179 |
-
target = target.max(dim=1)[1]
|
180 |
-
|
181 |
-
_, pred = output.topk(maxk, 1, True, True)
|
182 |
-
pred = pred.t()
|
183 |
-
correct = pred.eq(target[None])
|
184 |
-
|
185 |
-
res = []
|
186 |
-
for k in topk:
|
187 |
-
correct_k = correct[:k].flatten().sum(dtype=torch.float32)
|
188 |
-
res.append(correct_k * (100.0 / batch_size))
|
189 |
-
return res
|
190 |
-
|
191 |
-
|
192 |
-
def mkdir(path):
|
193 |
-
try:
|
194 |
-
os.makedirs(path)
|
195 |
-
except OSError as e:
|
196 |
-
if e.errno != errno.EEXIST:
|
197 |
-
raise
|
198 |
-
|
199 |
-
|
200 |
-
def setup_for_distributed(is_master):
|
201 |
-
"""
|
202 |
-
This function disables printing when not in master process
|
203 |
-
"""
|
204 |
-
import builtins as __builtin__
|
205 |
-
|
206 |
-
builtin_print = __builtin__.print
|
207 |
-
|
208 |
-
def print(*args, **kwargs):
|
209 |
-
force = kwargs.pop("force", False)
|
210 |
-
if is_master or force:
|
211 |
-
builtin_print(*args, **kwargs)
|
212 |
-
|
213 |
-
__builtin__.print = print
|
214 |
-
|
215 |
-
|
216 |
-
def is_dist_avail_and_initialized():
|
217 |
-
if not dist.is_available():
|
218 |
-
return False
|
219 |
-
if not dist.is_initialized():
|
220 |
-
return False
|
221 |
-
return True
|
222 |
-
|
223 |
-
|
224 |
-
def get_world_size():
|
225 |
-
if not is_dist_avail_and_initialized():
|
226 |
-
return 1
|
227 |
-
return dist.get_world_size()
|
228 |
-
|
229 |
-
|
230 |
-
def get_rank():
|
231 |
-
if not is_dist_avail_and_initialized():
|
232 |
-
return 0
|
233 |
-
return dist.get_rank()
|
234 |
-
|
235 |
-
|
236 |
-
def is_main_process():
|
237 |
-
return get_rank() == 0
|
238 |
-
|
239 |
-
|
240 |
-
def save_on_master(*args, **kwargs):
|
241 |
-
if is_main_process():
|
242 |
-
torch.save(*args, **kwargs)
|
243 |
-
|
244 |
-
|
245 |
-
def init_distributed_mode(args):
|
246 |
-
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
247 |
-
args.rank = int(os.environ["RANK"])
|
248 |
-
args.world_size = int(os.environ["WORLD_SIZE"])
|
249 |
-
args.gpu = int(os.environ["LOCAL_RANK"])
|
250 |
-
elif "SLURM_PROCID" in os.environ:
|
251 |
-
args.rank = int(os.environ["SLURM_PROCID"])
|
252 |
-
args.gpu = args.rank % torch.cuda.device_count()
|
253 |
-
elif hasattr(args, "rank"):
|
254 |
-
pass
|
255 |
-
else:
|
256 |
-
print("Not using distributed mode")
|
257 |
-
args.distributed = False
|
258 |
-
return
|
259 |
-
|
260 |
-
args.distributed = True
|
261 |
-
|
262 |
-
torch.cuda.set_device(args.gpu)
|
263 |
-
args.dist_backend = "nccl"
|
264 |
-
print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
|
265 |
-
torch.distributed.init_process_group(
|
266 |
-
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
267 |
-
)
|
268 |
-
torch.distributed.barrier()
|
269 |
-
setup_for_distributed(args.rank == 0)
|
270 |
-
|
271 |
-
|
272 |
-
def average_checkpoints(inputs):
|
273 |
-
"""Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
|
274 |
-
https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
|
275 |
-
|
276 |
-
Args:
|
277 |
-
inputs (List[str]): An iterable of string paths of checkpoints to load from.
|
278 |
-
Returns:
|
279 |
-
A dict of string keys mapping to various values. The 'model' key
|
280 |
-
from the returned dict should correspond to an OrderedDict mapping
|
281 |
-
string parameter names to torch Tensors.
|
282 |
-
"""
|
283 |
-
params_dict = OrderedDict()
|
284 |
-
params_keys = None
|
285 |
-
new_state = None
|
286 |
-
num_models = len(inputs)
|
287 |
-
for fpath in inputs:
|
288 |
-
with open(fpath, "rb") as f:
|
289 |
-
state = torch.load(
|
290 |
-
f,
|
291 |
-
map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
|
292 |
-
)
|
293 |
-
# Copies over the settings from the first checkpoint
|
294 |
-
if new_state is None:
|
295 |
-
new_state = state
|
296 |
-
model_params = state["model"]
|
297 |
-
model_params_keys = list(model_params.keys())
|
298 |
-
if params_keys is None:
|
299 |
-
params_keys = model_params_keys
|
300 |
-
elif params_keys != model_params_keys:
|
301 |
-
raise KeyError(
|
302 |
-
f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}"
|
303 |
-
)
|
304 |
-
for k in params_keys:
|
305 |
-
p = model_params[k]
|
306 |
-
if isinstance(p, torch.HalfTensor):
|
307 |
-
p = p.float()
|
308 |
-
if k not in params_dict:
|
309 |
-
params_dict[k] = p.clone()
|
310 |
-
# NOTE: clone() is needed in case of p is a shared parameter
|
311 |
-
else:
|
312 |
-
params_dict[k] += p
|
313 |
-
averaged_params = OrderedDict()
|
314 |
-
for k, v in params_dict.items():
|
315 |
-
averaged_params[k] = v
|
316 |
-
if averaged_params[k].is_floating_point():
|
317 |
-
averaged_params[k].div_(num_models)
|
318 |
-
else:
|
319 |
-
averaged_params[k] //= num_models
|
320 |
-
new_state["model"] = averaged_params
|
321 |
-
return new_state
|
322 |
-
|
323 |
-
|
324 |
-
def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
|
325 |
-
"""
|
326 |
-
This method can be used to prepare weights files for new models. It receives as
|
327 |
-
input a model architecture and a checkpoint from the training script and produces
|
328 |
-
a file with the weights ready for release.
|
329 |
-
|
330 |
-
Examples:
|
331 |
-
from torchvision import models as M
|
332 |
-
|
333 |
-
# Classification
|
334 |
-
model = M.mobilenet_v3_large(weights=None)
|
335 |
-
print(store_model_weights(model, './class.pth'))
|
336 |
-
|
337 |
-
# Quantized Classification
|
338 |
-
model = M.quantization.mobilenet_v3_large(weights=None, quantize=False)
|
339 |
-
model.fuse_model(is_qat=True)
|
340 |
-
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
|
341 |
-
_ = torch.ao.quantization.prepare_qat(model, inplace=True)
|
342 |
-
print(store_model_weights(model, './qat.pth'))
|
343 |
-
|
344 |
-
# Object Detection
|
345 |
-
model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
|
346 |
-
print(store_model_weights(model, './obj.pth'))
|
347 |
-
|
348 |
-
# Segmentation
|
349 |
-
model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True)
|
350 |
-
print(store_model_weights(model, './segm.pth', strict=False))
|
351 |
-
|
352 |
-
Args:
|
353 |
-
model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes.
|
354 |
-
checkpoint_path (str): The path of the checkpoint we will load.
|
355 |
-
checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
|
356 |
-
Default: "model".
|
357 |
-
strict (bool): whether to strictly enforce that the keys
|
358 |
-
in :attr:`state_dict` match the keys returned by this module's
|
359 |
-
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
|
360 |
-
|
361 |
-
Returns:
|
362 |
-
output_path (str): The location where the weights are saved.
|
363 |
-
"""
|
364 |
-
# Store the new model next to the checkpoint_path
|
365 |
-
checkpoint_path = os.path.abspath(checkpoint_path)
|
366 |
-
output_dir = os.path.dirname(checkpoint_path)
|
367 |
-
|
368 |
-
# Deep copy to avoid side-effects on the model object.
|
369 |
-
model = copy.deepcopy(model)
|
370 |
-
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
371 |
-
|
372 |
-
# Load the weights to the model to validate that everything works
|
373 |
-
# and remove unnecessary weights (such as auxiliaries, etc)
|
374 |
-
if checkpoint_key == "model_ema":
|
375 |
-
del checkpoint[checkpoint_key]["n_averaged"]
|
376 |
-
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
|
377 |
-
model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
|
378 |
-
|
379 |
-
tmp_path = os.path.join(output_dir, str(model.__hash__()))
|
380 |
-
torch.save(model.state_dict(), tmp_path)
|
381 |
-
|
382 |
-
sha256_hash = hashlib.sha256()
|
383 |
-
with open(tmp_path, "rb") as f:
|
384 |
-
# Read and update hash string value in blocks of 4K
|
385 |
-
for byte_block in iter(lambda: f.read(4096), b""):
|
386 |
-
sha256_hash.update(byte_block)
|
387 |
-
hh = sha256_hash.hexdigest()
|
388 |
-
|
389 |
-
output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
|
390 |
-
os.replace(tmp_path, output_path)
|
391 |
-
|
392 |
-
return output_path
|
393 |
-
|
394 |
-
|
395 |
-
def reduce_across_processes(val):
|
396 |
-
if not is_dist_avail_and_initialized():
|
397 |
-
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
|
398 |
-
return torch.tensor(val)
|
399 |
-
|
400 |
-
t = torch.tensor(val, device="cuda")
|
401 |
-
dist.barrier()
|
402 |
-
dist.all_reduce(t)
|
403 |
-
return t
|
404 |
-
|
405 |
-
|
406 |
-
def set_weight_decay(
|
407 |
-
model: torch.nn.Module,
|
408 |
-
weight_decay: float,
|
409 |
-
norm_weight_decay: Optional[float] = None,
|
410 |
-
norm_classes: Optional[List[type]] = None,
|
411 |
-
custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
|
412 |
-
):
|
413 |
-
if not norm_classes:
|
414 |
-
norm_classes = [
|
415 |
-
torch.nn.modules.batchnorm._BatchNorm,
|
416 |
-
torch.nn.LayerNorm,
|
417 |
-
torch.nn.GroupNorm,
|
418 |
-
torch.nn.modules.instancenorm._InstanceNorm,
|
419 |
-
torch.nn.LocalResponseNorm,
|
420 |
-
]
|
421 |
-
norm_classes = tuple(norm_classes)
|
422 |
-
|
423 |
-
params = {
|
424 |
-
"other": [],
|
425 |
-
"norm": [],
|
426 |
-
}
|
427 |
-
params_weight_decay = {
|
428 |
-
"other": weight_decay,
|
429 |
-
"norm": norm_weight_decay,
|
430 |
-
}
|
431 |
-
custom_keys = []
|
432 |
-
if custom_keys_weight_decay is not None:
|
433 |
-
for key, weight_decay in custom_keys_weight_decay:
|
434 |
-
params[key] = []
|
435 |
-
params_weight_decay[key] = weight_decay
|
436 |
-
custom_keys.append(key)
|
437 |
-
|
438 |
-
def _add_params(module, prefix=""):
|
439 |
-
for name, p in module.named_parameters(recurse=False):
|
440 |
-
if not p.requires_grad:
|
441 |
-
continue
|
442 |
-
is_custom_key = False
|
443 |
-
for key in custom_keys:
|
444 |
-
target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
|
445 |
-
if key == target_name:
|
446 |
-
params[key].append(p)
|
447 |
-
is_custom_key = True
|
448 |
-
break
|
449 |
-
if not is_custom_key:
|
450 |
-
if norm_weight_decay is not None and isinstance(module, norm_classes):
|
451 |
-
params["norm"].append(p)
|
452 |
-
else:
|
453 |
-
params["other"].append(p)
|
454 |
-
|
455 |
-
for child_name, child_module in module.named_children():
|
456 |
-
child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
|
457 |
-
_add_params(child_module, prefix=child_prefix)
|
458 |
-
|
459 |
-
_add_params(model)
|
460 |
-
|
461 |
-
param_groups = []
|
462 |
-
for key in params:
|
463 |
-
if len(params[key]) > 0:
|
464 |
-
param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
|
465 |
-
return param_groups
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|