#!/usr/bin/env python
# coding: utf-8

### config
total_epochs = 100
batch_size = 256
num_processes = 2
image_size = 224
drop_path = 0.05
## Loss function: CrossEntropy (BCE is also worth trying; see the sketch below)
# Rule of thumb: SGD for CNNs, AdamW for ViTs (SGD is hard to converge for ViTs);
# LAMB with a cosine LR schedule is another option at large batch sizes.
## Mixup and CutMix turn the targets into soft, multi-label distributions
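# A minimal sketch of the BCE variant mentioned above (not wired into the
# trainer below; assumes hard integer targets that get one-hot encoded first):
#   loss_fn = torch.nn.BCEWithLogitsLoss()
#   one_hot = torch.nn.functional.one_hot(targets, num_classes=1000).float()
#   loss = loss_fn(outputs, one_hot)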
LR = 5e-3
weight_decay = 0.05
warmup_epoch = 5
dropout = 0
import os

import wandb

# Read the W&B API key from the environment rather than hardcoding the secret.
wandb.login(key=os.environ["WANDB_API_KEY"])

import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import pandas as pd
import albumentations

# Training-time augmentation: resize, mild geometric jitter, mild photometric
# jitter, horizontal flip, then ImageNet mean/std normalization.
train_aug = albumentations.Compose(
    [
        albumentations.Resize(image_size, image_size, p=1),
        albumentations.ShiftScaleRotate(
            shift_limit=0.0625, scale_limit=0.1, rotate_limit=10, p=0.8
        ),
        albumentations.OneOf(
            [
                albumentations.RandomGamma(gamma_limit=(90, 110)),
                albumentations.RandomBrightnessContrast(
                    brightness_limit=0.1, contrast_limit=0.1
                ),
            ],
            p=0.5,
        ),
        albumentations.HorizontalFlip(),
        albumentations.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
            p=1.0,
        ),
    ],
    p=1.0,
)

# Validation-time transform: resize and normalize only.
valid_aug = albumentations.Compose(
    [
        albumentations.Resize(image_size, image_size, p=1),
        albumentations.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
            p=1.0,
        ),
    ],
    p=1.0,
)
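
# Optional sanity check of the pipeline on random data (albumentations operates
# on HWC uint8 arrays and Normalize returns float32); safe to delete.
_dummy = np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8)
assert train_aug(image=_dummy)["image"].shape == (image_size, image_size, 3)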


class ImageNetDataset(torch.utils.data.Dataset):
    def __init__(self, image_path, augmentations=None, train=True):
        self.image_path = image_path
        self.augmentations = augmentations
        # Class-name -> label lookup for training images (label from folder name).
        self.df = pd.read_csv(
            "/home/ubuntu/training/training/imagenet_class_labels.csv"
        )
        # Image-id -> label lookup for validation images (flat directory).
        self.valid_df = pd.read_csv(
            "/home/ubuntu/training/training/validation_classes.csv"
        )
        self.train = train

    def __len__(self):
        return len(self.image_path)

    def __getitem__(self, item):
        image_path = self.image_path[item]
        with Image.open(image_path) as img:
            image = img.convert("RGB")
            image = np.asarray(image)
        ## center crop: keep the central 92% of each side (~85% of the area)
        H, W, C = image.shape
        image = image[int(0.04 * H) : int(0.96 * H), int(0.04 * W) : int(0.96 * W), :]
        if self.train:
            # Training labels come from the class folder in the path.
            class_id = str(self.image_path[item].split("/")[-2])
            targets = self.df[self.df["Index"] == class_id]["ID"].values[0] - 1
        else:
            # Validation labels come from the filename (".JPEG" stripped).
            class_id = str(self.image_path[item].split("/")[-1][:-5])
            targets = (
                self.valid_df[self.valid_df["ImageId"] == class_id]["LabelId"].values[0]
                - 1
            )
        if self.augmentations is not None:
            augmented = self.augmentations(image=image)
            image = augmented["image"]
        # HWC -> CHW for PyTorch.
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        return {
            "image": torch.tensor(image, dtype=torch.float),
            "targets": torch.tensor(targets, dtype=torch.long),
        }


from timm.data.mixup import Mixup

# Note: with switch_prob=0 the cutmix branch never fires, so cutmix_alpha=1.0
# is effectively inert and only Mixup (alpha=0.1) is applied, with prob 0.7.
mixup_args = {
    "mixup_alpha": 0.1,
    "cutmix_alpha": 1.0,
    "cutmix_minmax": None,
    "prob": 0.7,
    "switch_prob": 0,
    "mode": "batch",
    "label_smoothing": 0.1,
    "num_classes": 1000,
}
mixup_fn = Mixup(**mixup_args)
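
# mixup_fn is defined but never called in the LightningModule below; a minimal
# sketch of how it would slot into training_step (targets become soft
# (N, num_classes) tensors, which nn.CrossEntropyLoss accepts in recent
# PyTorch, but the top-k accuracy helper would then need the hard targets):
#   images, soft_targets = mixup_fn(images, targets)
#   loss = self.loss_fn(self.model(images), soft_targets)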

import glob

train_paths = glob.glob(
    "/home/ubuntu/training/Imagenet/ILSVRC/Data/ImageNet/train/*/*.JPEG"
)
valid_paths = glob.glob(
    "/home/ubuntu/training/Imagenet/ILSVRC/Data/ImageNet/val/*.JPEG"
)
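
# Optional sanity check (only runs if the ImageNet paths above resolve): fetch
# one sample and confirm the shapes the model will see; safe to delete.
if train_paths:
    _sample = ImageNetDataset(train_paths, train_aug, train=True)[0]
    assert _sample["image"].shape == (3, image_size, image_size)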

import pytorch_lightning as L
from pytorch_lightning.loggers import WandbLogger
from timm import create_model

# from timm.scheduler.cosine_lr import CosineLRScheduler


class LitClassification(L.LightningModule):
    def __init__(self):
        super().__init__()
        # ResNet-50 from scratch with stochastic depth (drop_path).
        self.model = create_model(
            "resnet50", pretrained=False, drop_path_rate=drop_path
        )
        # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, targets = batch["image"], batch["targets"]
        outputs = self.model(images)
        loss = self.loss_fn(outputs, targets)
        acc1, acc5 = self.__accuracy(outputs, targets, topk=(1, 5))
        self.log("train_loss", loss)
        self.log(
            "train_acc1", acc1, on_step=True, prog_bar=True, on_epoch=True, logger=True
        )
        self.log("train_acc5", acc5, on_step=True, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, targets = batch["image"], batch["targets"]
        outputs = self(images)
        loss = self.loss_fn(outputs, targets)
        acc1, acc5 = self.__accuracy(outputs, targets, topk=(1, 5))
        self.log("valid_loss", loss)
        self.log("val_acc1", acc1, on_step=True, prog_bar=True, on_epoch=True)
        self.log("val_acc5", acc5, on_step=True, on_epoch=True)

    # staticmethod is required here: without it, calling self.__accuracy(...)
    # would pass self as `output` and raise a TypeError.
    @staticmethod
    def __accuracy(output, target, topk=(1,)):
        """Computes the accuracy over the k top predictions for the specified values of k."""
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)
            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            # A prediction is correct if the target appears anywhere in the top-k.
            correct = pred.eq(target.view(1, -1).expand_as(pred))
            res = []
            for k in topk:
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size))
            return res

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=LR, weight_decay=weight_decay
        )
        # total_steps alone defines the schedule length (epochs/steps_per_epoch
        # would be ignored when it is set); pct_start is the warmup fraction.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=LR,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.3,
            anneal_strategy="cos",
            cycle_momentum=True,
            base_momentum=0.85,
            max_momentum=0.95,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=False,
        )
        # OneCycleLR is a per-step schedule, so tell Lightning to step it
        # every batch instead of the default once-per-epoch.
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def train_dataloader(self):
        train_dataset = ImageNetDataset(train_paths, train_aug, train=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_processes,
            pin_memory=True,
        )
        return train_loader

    def val_dataloader(self):
        valid_dataset = ImageNetDataset(valid_paths, valid_aug, train=False)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=batch_size,
            shuffle=False,
        )
        return valid_loader


L.seed_everything(879246)
wandb_logger = WandbLogger(log_model="all", project="ImageNet_Lightning")

# ModelCheckpoint needs an explicit `monitor` for mode="max" to track the best
# model, and save_top_k=-1 to actually keep a checkpoint every 10 epochs
# rather than only the most recent one.
best_checkpoint_callback = L.callbacks.ModelCheckpoint(
    monitor="val_acc1", filename="bestmodel-{epoch}-monitor-{val_acc1}", mode="max"
)
every_epoch_checkpoint_callback = L.callbacks.ModelCheckpoint(
    filename="{epoch}_{val_acc1}", every_n_epochs=10, save_top_k=-1
)

# Initialize a trainer
trainer = L.Trainer(
    max_epochs=total_epochs,
    devices=torch.cuda.device_count(),
    accelerator="gpu",
    logger=wandb_logger,
    precision=16,
    callbacks=[best_checkpoint_callback, every_epoch_checkpoint_callback],
)

model = LitClassification()
# Resume training from the best checkpoint of a previous run.
trainer.fit(
    model,
    ckpt_path="/home/ubuntu/training/training/ImageNet_Lightning/h94dnl2b/checkpoints/bestmodel-epoch=32-monitor-val_acc1=62.54399871826172.ckpt",
)