# Source: resnet18_cifar10/model.py
# Author: venkyyuvy — commit 2efd69c ("init commit"), 6.82 kB
'''
ResNet in PyTorch, adapted from https://github.com/kuangliu/pytorch-cifar.
For the pre-activation ResNet variant, see 'preact_resnet.py' in that repository.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
from torch import nn
from torch.nn import functional as F
from torch_lr_finder import LRFinder
class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + BatchNorm layers with a skip connection.

    The skip path is a 1x1 conv + BatchNorm projection whenever the block
    downsamples (``stride != 1``) or changes the channel count; otherwise it
    is an empty ``nn.Sequential`` (identity).

    Args:
        in_planes: number of input channels.
        planes: number of channels produced by the two 3x3 convolutions.
        stride: stride of the first 3x3 conv (and of the projection, if any).
    """

    # Output channels are ``planes * expansion``; the basic block does not widen.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        # Submodules are registered in the order conv1, bn1, conv2, bn2,
        # shortcut, which fixes the state_dict layout and the order in which
        # weights are randomly initialised.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        shape_changes = stride != 1 or in_planes != out_planes
        if shape_changes:
            # Project the input so it can be added to the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        main = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        return F.relu(main + self.shortcut(x))
class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stem, four residual stages, global pooling, linear head.

    Args:
        block: residual block class. It must expose an ``expansion`` attribute
            and be callable as ``block(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: number of output logits.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Channel count fed into the next stage; advanced by _make_layer.
        self.in_planes = 64
        # Stride-1 stem with no max-pool, so layer1 sees the full input resolution.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1.

        Updates ``self.in_planes`` so the next stage starts from this stage's
        output width.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in block_strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch ``(N, 3, H, W)`` to logits ``(N, num_classes)``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. For 32x32 inputs layer4 yields 4x4 maps, so
        # this equals the previous avg_pool2d(out, 4). Unlike that hard-coded
        # window, it also handles other resolutions without dropping border
        # activations or mis-sizing the linear input.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.flatten(1)
        return self.linear(out)
def ResNet18(num_classes=10):
    """Return a ResNet-18 (``BasicBlock`` stages of [2, 2, 2, 2]).

    Args:
        num_classes: size of the classification head. Defaults to 10, the
            value previously hard-coded via ``ResNet``'s default.
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
import torch.nn as nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from data_loader import CifarAlbumentationsDataset,\
CIFAR_CLASS_LABELS, TRAIN_TRANSFORM, TEST_TRANSFORM
import model
from torch_lr_finder import LRFinder
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics.functional import accuracy
class LitResnet(LightningModule):
    """Lightning module training ``ResNet18`` on CIFAR data (``CifarAlbumentationsDataset``).

    ``lr`` and ``batch_size`` are stored with ``save_hyperparameters`` and read
    back through ``self.hparams``. The peak learning rate of the OneCycle
    schedule is chosen by an LR range test run inside ``configure_optimizers``.
    """

    def __init__(self, lr=0.03, batch_size=512):
        super().__init__()
        self.save_hyperparameters()
        self.criterion = nn.CrossEntropyLoss()
        self.model = ResNet18()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one batch and log loss/accuracy."""
        x, y = batch
        output = self.forward(x)
        loss = self.criterion(output, y)
        self.log("train_loss", loss)
        acc = accuracy(torch.argmax(output, dim=1),
                       y, 'multiclass', num_classes=10)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss/accuracy for one batch; log them as ``{stage}_*`` if ``stage`` is set."""
        x, y = batch
        output = self.forward(x)
        loss = self.criterion(output, y)
        preds = torch.argmax(output, dim=1)
        acc = accuracy(preds, y, 'multiclass', num_classes=10)
        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    # TODO: change the default for num_iter
    def lr_finder(self, optimizer, num_iter=200):
        """Run an exponential LR range test and return the suggested learning rate.

        The test runs for up to ``num_iter`` iterations over
        ``self.train_dataloader()``, with ``end_lr=1``. ``reset()`` is then
        called; torch_lr_finder documents it as restoring the model and
        optimizer state.

        Returns:
            The learning rate suggested by ``LRFinder.plot``. If the finder
            could not produce a suggestion, returns ``self.hparams.lr``; in that
            case ``plot`` returns only the axes instead of an ``(ax, lr)`` tuple.
        """
        finder = LRFinder(self, optimizer, self.criterion,
                          device=self.device)
        finder.range_test(
            self.train_dataloader(), end_lr=1,
            num_iter=num_iter, step_mode='exp',
        )
        plot_result = finder.plot(suggest_lr=True)
        if isinstance(plot_result, tuple):
            _, suggested_lr = plot_result
        else:
            # No suggestion available: keep the configured learning rate
            # rather than crashing on tuple unpacking.
            suggested_lr = self.hparams.lr
        # TODO: log the matplotlib figure, e.g.
        # self.logger.experiment.add_image('lr_finder', plt.gcf(), 0)
        finder.reset()
        return suggested_lr

    def configure_optimizers(self):
        """Build SGD (momentum 0.9, weight decay 5e-4) with a per-step OneCycle schedule.

        The schedule peaks at the LR returned by ``lr_finder``, with a linear
        ramp up over 5 epochs (or over the whole run if it has fewer than 5
        epochs) and a linear anneal afterwards.
        """
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4,
        )
        suggested_lr = self.lr_finder(optimizer)
        steps_per_epoch = len(self.train_dataloader())
        max_epochs = self.trainer.max_epochs
        # OneCycleLR requires 0 <= pct_start <= 1. Clamp so runs shorter than
        # 5 epochs warm up for the whole run instead of raising ValueError.
        pct_start = min(5 / max_epochs, 1.0)
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer, max_lr=suggested_lr,
                steps_per_epoch=steps_per_epoch,
                epochs=max_epochs,
                pct_start=pct_start,
                three_phase=False,
                div_factor=100,
                final_div_factor=100,
                anneal_strategy='linear',
            ),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self, data_path='../data'):
        """Instantiate both splits with ``download=True`` (presumably fetches the data if missing)."""
        CifarAlbumentationsDataset(
            data_path, train=True, download=True)
        CifarAlbumentationsDataset(
            data_path, train=False, download=True)

    def setup(self, stage=None, data_dir='../data'):
        """Create the datasets required by the given Lightning stage.

        The training split is built for ``"fit"``. The held-out split (used by
        both the validation and test loaders) is built for ``"fit"``,
        ``"validate"`` and ``"test"``. ``stage=None`` builds both.
        """
        if stage in ("fit", None):
            self.train_dataset = CifarAlbumentationsDataset(
                data_dir, train=True, transform=TRAIN_TRANSFORM)
        if stage in ("fit", "validate", "test", None):
            self.test_dataset = CifarAlbumentationsDataset(
                data_dir, train=False, transform=TEST_TRANSFORM)

    def train_dataloader(self):
        # num_workers=4 could be passed here to parallelise loading.
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size,
                          shuffle=True, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size,
                          shuffle=False, pin_memory=True)

    def test_dataloader(self):
        """Loader for ``trainer.test``; reuses the held-out split like ``val_dataloader``."""
        return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size,
                          shuffle=False, pin_memory=True)