from typing import Any, Dict, List

import torch
from torch.nn import functional as F
from torch.optim import Adam, SGD
from torchvision.utils import make_grid
from lightning import LightningModule
from torchmetrics import Accuracy, F1Score

from .mnist_model import Net

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch.set_default_device(device)

__all__: List[str] = ["LitMNISTModel"]
class LitMNISTModel(LightningModule):
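    """LightningModule wrapping the MNIST classifier ``Net``.

    Tracks loss, accuracy, and F1 score for the train/val/test stages and
    logs a grid of input images on the first batch of each train/val epoch.
    """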
    def __init__(
        self,
        learning_rate: float = 3e-4,
        num_classes: int = 10,
        dropout_rate: float = 0.01,
        bias: bool = False,
        momentum: float = 0.9,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate: float = learning_rate
        self.num_classes: int = num_classes
        self.momentum: float = momentum
        # Metrics
        ## Accuracy
        self.train_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        ## F1 score
        self.train_f1 = F1Score(task="multiclass", num_classes=num_classes)
        self.val_f1 = F1Score(task="multiclass", num_classes=num_classes)
        self.test_f1 = F1Score(task="multiclass", num_classes=num_classes)
        ## Model
        self.model = Net(config={"dropout_rate": dropout_rate, "bias": bias})
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
    def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)  # Net is expected to emit log-probabilities
        preds = torch.argmax(logits, dim=1)
        acc = self.train_accuracy(preds, y)
        f1 = self.train_f1(preds, y)
        self.log("train/loss", loss, prog_bar=True, on_epoch=True, on_step=True)
        self.log("train/acc", acc, prog_bar=True, on_epoch=False, on_step=True)
        self.log("train/f1", f1, prog_bar=True, on_epoch=False, on_step=True)
        # Log a grid of the first batch's images once per epoch (TensorBoard).
        if batch_idx == 0:
            grid = make_grid(x)
            self.logger.experiment.add_image("train_imgs", grid, self.current_epoch)
        return {
            "loss": loss,
            "logits": logits,
            "preds": preds,
        }
    def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.val_accuracy(preds, y)
        f1 = self.val_f1(preds, y)
        self.log("val/loss", loss, prog_bar=True, on_epoch=True, on_step=True)
        self.log("val/acc", acc, prog_bar=True, on_epoch=True, on_step=True)
        self.log("val/f1", f1, prog_bar=True, on_epoch=True, on_step=False)
        if batch_idx == 0:
            grid = make_grid(x)
            self.logger.experiment.add_image("val_imgs", grid, self.current_epoch)
        return {
            "loss": loss,
            "logits": logits,
            "preds": preds,
        }
    def predict_step(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        # no_grad is redundant under `Trainer.predict`, but keeps direct calls cheap.
        with torch.no_grad():
            logits = self(x)
            probs, indices = torch.max(F.softmax(logits, dim=1), dim=1)
            return {
                "prob": probs,
                "predict": indices,
            }
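    # Example (a sketch): for a batch `x` of MNIST images, assumed shape (N, 1, 28, 28),
    #   out = model.predict_step(x)
    # yields out["prob"] (max softmax probability) and out["predict"]
    # (predicted class index), each of shape (N,).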
    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.test_accuracy(preds, y)
        f1 = self.test_f1(preds, y)
        self.log("test/loss", loss, prog_bar=True, on_epoch=True, on_step=True)
        self.log("test/acc", acc, prog_bar=True, on_epoch=True, on_step=True)
        self.log("test/f1", f1, prog_bar=True, on_epoch=True, on_step=False)
        return {
            "loss": loss,
            "logits": logits,
            "preds": preds,
        }
    def configure_optimizers(self):
        # Alternative: SGD + ReduceLROnPlateau driven by the validation loss.
        # optimizer = SGD(self.parameters(), lr=self.learning_rate, momentum=self.momentum)
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, factor=0.1, patience=2)
        # return {
        #     "optimizer": optimizer,
        #     "lr_scheduler": {
        #         "scheduler": scheduler,
        #         "monitor": "val/loss",
        #         "interval": "step",
        #         "frequency": 15,
        #     },
        # }
        # OneCycleLR overrides the optimizer lr (initial lr = max_lr / div_factor).
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=1e2 * self.learning_rate,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.3,
            cycle_momentum=True,
            div_factor=100,
            final_div_factor=1e10,
            three_phase=True,
        )
        # OneCycleLR is sized in optimizer steps, so it must step every batch.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
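
# Example usage (a minimal sketch; `train_loader`, `val_loader`, and `test_loader`
# are assumed MNIST DataLoaders, e.g. built from torchvision.datasets.MNIST):
#
#   from lightning import Trainer
#
#   model = LitMNISTModel(learning_rate=3e-4, dropout_rate=0.01)
#   trainer = Trainer(max_epochs=3, accelerator="auto")
#   trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
#   trainer.test(model, dataloaders=test_loader)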