Muthukamalan committed on
Commit
7c9474f
·
1 Parent(s): af3a445

lightning code

Browse files
Files changed (3) hide show
  1. model/__init__.py +2 -0
  2. model/mnist_model.py +93 -0
  3. model/model.py +155 -0
model/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .model import LitMNISTModel
2
+ from .mnist_model import Net
model/mnist_model.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ from typing import Any,List,Tuple,Dict
6
+
7
+
8
class Net(nn.Module):
    """Small CNN for MNIST: three conv stages with pooling transitions.

    Takes a (N, 1, 28, 28) batch and returns (N, 10) log-probabilities.
    ``config`` may supply ``dropout_rate`` (default 0.01) and ``bias``
    (default False) for every convolution.
    """

    def __init__(self, config: Dict):
        super().__init__()

        drop_p = config.get('dropout_rate', 0.01)
        use_bias = config.get('bias', False)

        # Local factories keep the stage definitions compact. The two
        # activation/normalization orderings are preserved exactly as in the
        # original network (conv1 uses ReLU->BN, conv2 uses BN->ReLU, etc.).
        def conv3x3(cin, cout):
            return nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=use_bias)

        def relu_bn_drop(channels):
            return nn.ReLU(), nn.BatchNorm2d(channels), nn.Dropout2d(p=drop_p)

        def bn_relu_drop(channels):
            return nn.BatchNorm2d(channels), nn.ReLU(), nn.Dropout2d(p=drop_p)

        # Stage 1: 1 -> 8 -> 10 -> 10 channels, spatial size preserved.
        self.conv1 = nn.Sequential(
            conv3x3(1, 8), *relu_bn_drop(8),
            conv3x3(8, 10), *relu_bn_drop(10),
            conv3x3(10, 10), *relu_bn_drop(10),
        )
        # Transition 1: halve spatially (padded pool), squeeze to 8 channels.
        # NOTE: the 1x1 conv carries padding=1, which grows the map by 2 px.
        self.trans1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.Conv2d(10, 8, kernel_size=1, bias=use_bias, padding=1),
        )

        # Stage 2: 8 -> 10 -> 12 -> 12 channels (BN before ReLU here).
        self.conv2 = nn.Sequential(
            conv3x3(8, 10), *bn_relu_drop(10),
            conv3x3(10, 12), *bn_relu_drop(12),
            conv3x3(12, 12), *bn_relu_drop(12),
        )
        # Transition 2: pool then 1x1 squeeze to 8 channels, normalized.
        self.trans2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.Conv2d(12, 8, kernel_size=1, bias=use_bias),
            nn.BatchNorm2d(8),
        )

        # Stage 3: 8 -> 10 -> 12 channels (mixed orderings, as original).
        self.conv3 = nn.Sequential(
            conv3x3(8, 10), *bn_relu_drop(10),
            conv3x3(10, 12), *relu_bn_drop(12),
        )
        # Transition 3: 1x1 squeeze to 10 channels, unpadded pool, BN.
        self.trans3 = nn.Sequential(
            nn.Conv2d(12, 10, kernel_size=1, bias=use_bias),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.BatchNorm2d(10),
        )

        # Head: one more 3x3 conv, then average-pool the 3x3 map to 1x1.
        self.out4 = nn.Sequential(
            conv3x3(10, 10),
            nn.AvgPool2d(kernel_size=3),  # -> (N, 10, 1, 1)
        )

    def forward(self, x):
        """Run all stages in order and return per-class log-probabilities."""
        for stage in (self.conv1, self.trans1, self.conv2,
                      self.trans2, self.conv3, self.trans3, self.out4):
            x = stage(x)
        return F.log_softmax(x.view(-1, 10), dim=1)
91
+
92
+
93
+
model/model.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any,List,Tuple,Dict
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torchvision.utils import make_grid
7
+ from torch.optim import Optimizer,Adam,SGD
8
+ from lightning import LightningModule
9
+ from torchmetrics import Accuracy,F1Score,AUROC,ConfusionMatrix
10
+
11
+
12
# Select the accelerator once: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# NOTE(review): calling set_default_device at import time is a module-level
# side effect that changes where every tensor created afterwards is allocated.
# Lightning normally manages device placement itself — confirm this is intended.
torch.set_default_device( device= device )
14
+
15
+
16
+
17
+ from .mnist_model import Net
18
+
19
+ __all__: List[str] = ["LitMNISTModel"]
20
+
21
class LitMNISTModel(LightningModule):
    """LightningModule wrapping the ``Net`` CNN for 10-class MNIST.

    Optimizes NLL loss over the log-softmax outputs of ``Net``, tracks
    per-split accuracy/F1 with torchmetrics, and drives the learning rate
    with a per-batch OneCycleLR schedule.
    """

    def __init__(
        self,
        learning_rate: float = 3e-4,
        num_classes: int = 10,
        dropout_rate: float = 0.01,
        bias: bool = False,
        momentum: float = .9,
        *args: Any,
        **kwargs: Any
    ) -> None:
        super().__init__()
        # Persist constructor args into self.hparams (and checkpoints).
        self.save_hyperparameters()

        self.learning_rate: float = learning_rate
        self.num_class: int = num_classes
        self.momentum: float = momentum  # unused by Adam; kept for an SGD variant

        # Separate metric instances per split so running state never leaks
        # between train/val/test.
        self.train_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=num_classes)

        self.train_f1 = F1Score(task="multiclass", num_classes=num_classes)
        self.val_f1 = F1Score(task="multiclass", num_classes=num_classes)
        self.test_f1 = F1Score(task="multiclass", num_classes=num_classes)

        # Underlying CNN; returns log-probabilities of shape (N, num_classes).
        self.model = Net(config={'dropout_rate': dropout_rate, 'bias': bias})

    def forward(self, x) -> Any:
        """Return log-probabilities from the wrapped CNN."""
        return self.model(x)

    def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """One optimization step: NLL loss plus train accuracy/F1 logging."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.train_accuracy(preds, y)
        f1 = self.train_f1(preds, y)

        # `logger` is a boolean flag on self.log (the original passed the
        # logger object, which only works because it is truthy).
        self.log("train/loss", loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("train/acc", acc, prog_bar=True, on_epoch=False, on_step=True, logger=True)
        self.log("train/train_f1", f1, prog_bar=True, on_epoch=False, on_step=True, logger=True)

        if batch_idx == 0:
            # Log one grid of input images per epoch (assumes a
            # TensorBoard-style logger exposing `experiment.add_image`).
            grid = make_grid(x)
            self.logger.experiment.add_image("train_imgs", grid, self.current_epoch)

        return {
            'loss': loss,
            'logits': logits,
            'preds': preds
        }

    def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """Validation mirror of training_step: loss, accuracy, F1 logging."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.val_accuracy(preds, y)
        f1 = self.val_f1(preds, y)

        self.log("val/loss", loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("val/acc", acc, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("val/val_f1", f1, prog_bar=True, on_epoch=True, on_step=False, logger=True)

        if batch_idx == 0:
            grid = make_grid(x)
            self.logger.experiment.add_image("val_imgs", grid, self.current_epoch)

        return {
            'loss': loss,
            'logits': logits,
            'preds': preds
        }

    def predict_step(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """Return the max softmax probability and predicted class per sample."""
        with torch.no_grad():
            logits = self(x)
            probs, indices = torch.max(F.softmax(logits, dim=1), dim=1)
            return {
                'prob': probs,
                'predict': indices
            }

    def test_step(self, batch, batch_idx: int = 0, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """Test mirror of validation_step.

        Fix: Lightning invokes hooks as ``test_step(batch, batch_idx)``; the
        original signature omitted ``batch_idx`` and raised TypeError. The
        default keeps direct callers of ``test_step(batch)`` working.
        """
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.test_accuracy(preds, y)
        f1 = self.test_f1(preds, y)

        self.log("test/loss", loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("test/acc", acc, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("test/test_f1", f1, prog_bar=True, on_epoch=True, on_step=False, logger=True)

        return {
            'loss': loss,
            'logits': logits,
            'preds': preds
        }

    def configure_optimizers(self):
        """Adam + OneCycleLR stepped once per optimizer step.

        Fixes over the original:
        - Adam lr was hard-coded to ``1e3``; it now uses ``self.learning_rate``
          (OneCycleLR overrides the group lr anyway, but the stored value was
          wrong and misleading).
        - OneCycleLR is a per-batch schedule, yet it was returned as a bare
          ``([optimizer], [scheduler])`` pair, which Lightning steps once per
          epoch; the lr_scheduler dict with ``"interval": "step"`` makes
          Lightning step it every batch.
        - Dropped the deprecated ``verbose=`` argument.
        """
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=1e2 * self.learning_rate,  # peak lr = 100x base
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=.3,
            cycle_momentum=True,
            div_factor=100,          # initial lr = max_lr / 100 = learning_rate
            final_div_factor=1e10,
            three_phase=True,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }