from typing import Any, Dict, List

import torch
from torch.nn import functional as F
from torch.optim import Adam, SGD
from torchvision.utils import make_grid
from lightning import LightningModule
from torchmetrics import Accuracy, F1Score

# Select a device and make it the process-wide default for newly created tensors.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch.set_default_device(device=device)

from .mnist_model import Net

__all__: List[str] = ["LitMNISTModel"]


class LitMNISTModel(LightningModule):
    def __init__(
        self,
        learning_rate: float = 3e-4,
        num_classes: int = 10,
        dropout_rate: float = 0.01,
        bias: bool = False,
        momentum: float = 0.9,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate: float = learning_rate
        self.num_classes: int = num_classes
        self.momentum: float = momentum

        # Metrics
        ## Accuracy
        self.train_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        ## F1 score
        self.train_f1 = F1Score(task="multiclass", num_classes=num_classes)
        self.val_f1 = F1Score(task="multiclass", num_classes=num_classes)
        self.test_f1 = F1Score(task="multiclass", num_classes=num_classes)

        ## Model (expected to return log-probabilities, since the steps use F.nll_loss)
        self.model = Net(config={"dropout_rate": dropout_rate, "bias": bias})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.train_accuracy(preds, y)
        f1 = self.train_f1(preds, y)
        # `logger` is a boolean flag in `self.log`, not a logger instance.
        self.log("train/loss", loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("train/acc", acc, prog_bar=True, on_epoch=False, on_step=True, logger=True)
        self.log("train/f1", f1, prog_bar=True, on_epoch=False, on_step=True, logger=True)
        if batch_idx == 0:
            # Log the first batch as an image grid (assumes a TensorBoard-style logger).
            grid = make_grid(x)
            self.logger.experiment.add_image("train_imgs", grid, self.current_epoch)
        return {"loss": loss, "logits": logits, "preds": preds}

    def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.val_accuracy(preds, y)
        f1 = self.val_f1(preds, y)
        self.log("val/loss", loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("val/acc", acc, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("val/f1", f1, prog_bar=True, on_epoch=True, on_step=False, logger=True)
        if batch_idx == 0:
            grid = make_grid(x)
            self.logger.experiment.add_image("val_imgs", grid, self.current_epoch)
        return {"loss": loss, "logits": logits, "preds": preds}

    def predict_step(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        with torch.no_grad():
            logits = self(x)
            # softmax is shift-invariant, so applying it to log-probabilities
            # still yields the class probabilities.
            probs, indices = torch.max(F.softmax(logits, dim=1), dim=1)
            return {"prob": probs, "predict": indices}

    def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.test_accuracy(preds, y)
        f1 = self.test_f1(preds, y)
        self.log("test/loss", loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("test/acc", acc, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("test/f1", f1, prog_bar=True, on_epoch=True, on_step=False, logger=True)
        return {"loss": loss, "logits": logits, "preds": preds}
    def configure_optimizers(self):
        # Alternative: SGD with ReduceLROnPlateau, stepped every 15 steps on val/loss.
        # optimizer = SGD(self.parameters(), lr=self.learning_rate, momentum=self.momentum)
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, factor=0.1, patience=2)
        # return {
        #     "optimizer": optimizer,
        #     "lr_scheduler": scheduler,
        #     "monitor": "val/loss",
        #     "interval": "step",
        #     "frequency": 15,
        # }
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=1e2 * self.learning_rate,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.3,
            cycle_momentum=True,
            div_factor=100,  # initial lr = max_lr / div_factor = self.learning_rate
            final_div_factor=1e10,
            three_phase=True,
        )
        # OneCycleLR is sized in optimizer steps (total_steps above), so it must be
        # stepped per batch rather than Lightning's per-epoch default.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
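

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module itself: it assumes the
    # standard torchvision MNIST dataset and a default `lightning.Trainer`,
    # and that the file is run as a package module (e.g. `python -m pkg.module`)
    # so the relative `Net` import resolves. All names below are illustrative.
    from lightning import Trainer
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    tfm = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    train_ds = datasets.MNIST("data", train=True, download=True, transform=tfm)
    val_ds = datasets.MNIST("data", train=False, download=True, transform=tfm)

    # Because `set_default_device` above may point at CUDA, give the shuffling
    # sampler a generator on the same device so its `randperm` call succeeds.
    train_loader = DataLoader(
        train_ds, batch_size=64, shuffle=True, generator=torch.Generator(device=device)
    )
    val_loader = DataLoader(val_ds, batch_size=256)

    model = LitMNISTModel(learning_rate=3e-4)
    trainer = Trainer(max_epochs=1, accelerator="auto")
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)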