# UnsolvedMNIST/model/model.py
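"""Lightning wrapper around the CNN defined in ``mnist_model.Net``.

Implements the training/validation/test/predict steps with TorchMetrics
accuracy and F1 logging, and an Adam + OneCycleLR optimisation schedule.
"""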
from typing import Any, List, Tuple, Dict
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.utils import make_grid
from torch.optim import Optimizer, Adam, SGD
from lightning import LightningModule
from torchmetrics import Accuracy, F1Score, AUROC, ConfusionMatrix
# NOTE: Lightning manages device placement for the module itself; this global
# default only makes ad-hoc tensor creation land on the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.set_default_device(device)
from .mnist_model import Net
__all__: List[str] = ["LitMNISTModel"]
class LitMNISTModel(LightningModule):
    def __init__(
        self,
        learning_rate: float = 3e-4,
        num_classes: int = 10,
        dropout_rate: float = 0.01,
        bias: bool = False,
        momentum: float = 0.9,
        *args: Any,
        **kwargs: Any
    ) -> None:
        super().__init__()
        # record the constructor arguments so checkpoints can restore them
        self.save_hyperparameters()
        self.learning_rate: float = learning_rate
        self.num_class: int = num_classes
        self.momentum: float = momentum
# metric
## Accuracy
self.train_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
self.test_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
## F1 Score
self.train_f1 = F1Score(task="multiclass", num_classes=num_classes)
self.val_f1 = F1Score(task="multiclass", num_classes=num_classes)
self.test_f1 = F1Score(task="multiclass", num_classes=num_classes)
## Model
self.model = Net(config={'dropout_rate':dropout_rate, 'bias':bias})
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Net is expected to return log-probabilities (the steps below use F.nll_loss)
        return self.model(x)
    def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.train_accuracy(preds, y)
        f1 = self.train_f1(preds, y)
        # `logger` is a boolean flag telling Lightning to forward the value to the attached logger
        self.log("train/loss", loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("train/acc", acc, prog_bar=True, on_epoch=False, on_step=True, logger=True)
        self.log("train/f1", f1, prog_bar=True, on_epoch=False, on_step=True, logger=True)
        if batch_idx == 0:
            # log the first batch of images once per epoch (assumes a TensorBoard logger)
            grid = make_grid(x)
            self.logger.experiment.add_image("train_imgs", grid, self.current_epoch)
        return {
            'loss': loss,
            'logits': logits,
            'preds': preds
        }
    def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.val_accuracy(preds, y)
        f1 = self.val_f1(preds, y)
        self.log("val/loss", loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("val/acc", acc, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("val/f1", f1, prog_bar=True, on_epoch=True, on_step=False, logger=True)
        if batch_idx == 0:
            # log the first validation batch of images once per epoch (assumes a TensorBoard logger)
            grid = make_grid(x)
            self.logger.experiment.add_image("val_imgs", grid, self.current_epoch)
        return {
            'loss': loss,
            'logits': logits,
            'preds': preds
        }
    def predict_step(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        # Lightning already wraps predict_step in torch.no_grad(); the explicit
        # context manager here is just an extra guard.
        with torch.no_grad():
            logits = self(x)
            probs, indices = torch.max(F.softmax(logits, dim=1), dim=1)
            return {
                'prob': probs,
                'predict': indices
            }
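    # Minimal inference sketch (illustrative only, not part of the module): assuming
    # `lit` is a trained LitMNISTModel and `img` is a normalised (1, 1, 28, 28) tensor,
    # `lit.predict_step(img)` returns {'prob': top_probabilities, 'predict': class_indices}.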
    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.test_accuracy(preds, y)
        f1 = self.test_f1(preds, y)
        self.log("test/loss", loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("test/acc", acc, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("test/f1", f1, prog_bar=True, on_epoch=True, on_step=False, logger=True)
        return {
            'loss': loss,
            'logits': logits,
            'preds': preds
        }
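    # Optimisation policy: Adam driven by OneCycleLR with a peak LR of 100x the base
    # learning rate, a 30% warm-up, and a three-phase (ramp-up / ramp-down / final
    # anneal) shape. The commented-out SGD + ReduceLROnPlateau block below is kept
    # as an alternative.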
    def configure_optimizers(self):
        # Alternative: SGD with ReduceLROnPlateau
        # optimizer = SGD(self.parameters(), lr=self.learning_rate, momentum=self.momentum)
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, factor=.1, patience=2, verbose=True)
        # return {
        #     "optimizer": optimizer,
        #     "lr_scheduler": scheduler,
        #     "monitor": 'val/loss',
        #     'interval': 'step',
        #     "frequency": 15
        # }
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=1e2 * self.learning_rate,   # peak LR is 100x the base learning rate
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=.3,
            cycle_momentum=True,
            div_factor=100,                    # initial LR = max_lr / 100 = self.learning_rate
            final_div_factor=1e10,
            verbose=False,
            three_phase=True
        )
        # OneCycleLR must advance on every optimizer step, not once per epoch
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
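

# ---------------------------------------------------------------------------
# Minimal end-to-end training sketch (illustrative only, not part of the
# package): it assumes plain torchvision MNIST with standard normalisation and
# the default TensorBoard logger; the names below (tfm, train_dl, ...) are
# ad-hoc and not used elsewhere in the project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from lightning import Trainer
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),  # standard MNIST statistics
    ])
    train_ds = datasets.MNIST(root="data", train=True, download=True, transform=tfm)
    val_ds = datasets.MNIST(root="data", train=False, download=True, transform=tfm)

    # The global default device set above means the shuffle generator must live
    # on the same device, otherwise the sampler raises a device mismatch error.
    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True,
                          generator=torch.Generator(device=device))
    val_dl = DataLoader(val_ds, batch_size=256)

    model = LitMNISTModel(learning_rate=3e-4)
    trainer = Trainer(max_epochs=1, accelerator="auto", devices=1)
    trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)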