Spaces:
Sleeping
Sleeping
Commit
·
7c9474f
1
Parent(s):
af3a445
lightning code
Browse files- model/__init__.py +2 -0
- model/mnist_model.py +93 -0
- model/model.py +155 -0
model/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .model import LitMNISTModel
|
2 |
+
from .mnist_model import Net
|
model/mnist_model.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
from typing import Any,List,Tuple,Dict
|
6 |
+
|
7 |
+
|
8 |
+
class Net(nn.Module):
|
9 |
+
def __init__(self,config:Dict):
|
10 |
+
super(Net,self).__init__()
|
11 |
+
|
12 |
+
DROPOUT= config.get('dropout_rate',0.01)
|
13 |
+
BIAS = config.get('bias',False)
|
14 |
+
|
15 |
+
self.conv1 = nn.Sequential(
|
16 |
+
nn.Conv2d(in_channels=1,out_channels=8,kernel_size=3,stride=1,padding=1,bias=BIAS),
|
17 |
+
nn.ReLU(),
|
18 |
+
nn.BatchNorm2d(8),
|
19 |
+
nn.Dropout2d(p=DROPOUT),
|
20 |
+
|
21 |
+
nn.Conv2d(in_channels=8,out_channels=10,kernel_size=3,stride=1,padding=1,bias=BIAS),
|
22 |
+
nn.ReLU(),
|
23 |
+
nn.BatchNorm2d(10),
|
24 |
+
nn.Dropout2d(p=DROPOUT),
|
25 |
+
|
26 |
+
nn.Conv2d(in_channels=10,out_channels=10,kernel_size=3,stride=1,padding=1,bias=BIAS),
|
27 |
+
nn.ReLU(),
|
28 |
+
nn.BatchNorm2d(10),
|
29 |
+
nn.Dropout2d(p=DROPOUT),
|
30 |
+
)
|
31 |
+
self.trans1 = nn.Sequential(
|
32 |
+
|
33 |
+
nn.MaxPool2d( kernel_size =2 , stride =2 , padding =1 ),
|
34 |
+
nn.Conv2d(in_channels=10,out_channels=8,kernel_size=1,bias=BIAS,padding=1),
|
35 |
+
)
|
36 |
+
|
37 |
+
self.conv2 =nn.Sequential(
|
38 |
+
nn.Conv2d(in_channels=8,out_channels=10,kernel_size=3,stride=1,padding=1,bias=BIAS),
|
39 |
+
nn.BatchNorm2d(10),
|
40 |
+
nn.ReLU(),
|
41 |
+
nn.Dropout2d(p=DROPOUT),
|
42 |
+
|
43 |
+
|
44 |
+
nn.Conv2d(in_channels=10,out_channels=12,kernel_size=3,stride=1,padding=1,bias=BIAS),
|
45 |
+
nn.BatchNorm2d(12),
|
46 |
+
nn.ReLU(),
|
47 |
+
nn.Dropout2d(p=DROPOUT),
|
48 |
+
|
49 |
+
nn.Conv2d(in_channels=12,out_channels=12,kernel_size=3,stride=1,padding=1,bias=BIAS),
|
50 |
+
nn.BatchNorm2d(12),
|
51 |
+
nn.ReLU(),
|
52 |
+
nn.Dropout2d(p=DROPOUT),
|
53 |
+
)
|
54 |
+
self.trans2 = nn.Sequential(
|
55 |
+
nn.MaxPool2d( kernel_size =2 , stride =2 , padding =1 ),
|
56 |
+
nn.Conv2d(in_channels=12,out_channels=8,kernel_size=1,bias=BIAS),
|
57 |
+
nn.BatchNorm2d(8),
|
58 |
+
)
|
59 |
+
|
60 |
+
self.conv3 = nn.Sequential(
|
61 |
+
nn.Conv2d(in_channels=8,out_channels=10,kernel_size=3,stride=1,padding=1,bias=BIAS),
|
62 |
+
nn.BatchNorm2d(10),
|
63 |
+
nn.ReLU(),
|
64 |
+
nn.Dropout2d(p=DROPOUT),
|
65 |
+
|
66 |
+
nn.Conv2d(in_channels=10,out_channels=12,kernel_size=3,stride=1,padding=1,bias=BIAS),
|
67 |
+
nn.ReLU(),
|
68 |
+
nn.BatchNorm2d(12),
|
69 |
+
nn.Dropout2d(p=DROPOUT),
|
70 |
+
|
71 |
+
)
|
72 |
+
self.trans3 = nn.Sequential(
|
73 |
+
nn.Conv2d(in_channels=12,out_channels=10,kernel_size=1,bias=BIAS),
|
74 |
+
nn.MaxPool2d( kernel_size =2 , stride =2 , padding =0 ),
|
75 |
+
nn.BatchNorm2d(10),
|
76 |
+
)
|
77 |
+
|
78 |
+
self.out4 = nn.Sequential(
|
79 |
+
nn.Conv2d(in_channels=10 ,out_channels=10, kernel_size=3,stride=1,padding=1,bias=BIAS),
|
80 |
+
nn.AvgPool2d(kernel_size=3) #(1*1*10)
|
81 |
+
)
|
82 |
+
|
83 |
+
|
84 |
+
def forward(self,x):
|
85 |
+
x = self.trans1( self.conv1(x) )
|
86 |
+
x = self.trans2( self.conv2(x) )
|
87 |
+
x = self.trans3( self.conv3(x) )
|
88 |
+
x = self.out4(x)
|
89 |
+
x = x.view(-1,10)
|
90 |
+
return F.log_softmax(x,dim=1)
|
91 |
+
|
92 |
+
|
93 |
+
|
model/model.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any,List,Tuple,Dict
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from torchvision.utils import make_grid
|
7 |
+
from torch.optim import Optimizer,Adam,SGD
|
8 |
+
from lightning import LightningModule
|
9 |
+
from torchmetrics import Accuracy,F1Score,AUROC,ConfusionMatrix
|
10 |
+
|
11 |
+
|
12 |
+
# Pick CUDA when available, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# NOTE(review): calling set_default_device at import time is a module-level
# side effect — every tensor created anywhere after importing this module is
# placed on `device`. Lightning normally manages device placement itself;
# consider moving this into the entry point or dropping it. Confirm intent.
torch.set_default_device( device= device )
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
from .mnist_model import Net
|
18 |
+
|
19 |
+
__all__: List[str] = ["LitMNISTModel"]
|
20 |
+
|
21 |
+
class LitMNISTModel(LightningModule):
    """Lightning wrapper around the `Net` CNN for 10-class MNIST.

    The wrapped model outputs log-probabilities, so all stages use
    ``F.nll_loss``. Accuracy and F1 are tracked with one metric instance per
    stage (train/val/test) so running state never mixes between stages.

    Args:
        learning_rate: base learning rate for Adam and the OneCycle schedule.
        num_classes: number of output classes (MNIST: 10).
        dropout_rate: Dropout2d probability forwarded to `Net`.
        bias: whether `Net`'s conv layers use a bias term, forwarded to `Net`.
        momentum: retained for an optional SGD configuration; unused by Adam.
    """

    def __init__(
        self,
        learning_rate: float = 3e-4,
        num_classes: int = 10,
        dropout_rate: float = 0.01,
        bias: bool = False,
        momentum: float = .9,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        # Record ctor args into self.hparams and checkpoints.
        self.save_hyperparameters()

        self.learning_rate: float = learning_rate
        self.num_class: int = num_classes
        self.momentum: float = momentum

        # Metrics — independent instances per stage.
        self.train_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=num_classes)

        self.train_f1 = F1Score(task="multiclass", num_classes=num_classes)
        self.val_f1 = F1Score(task="multiclass", num_classes=num_classes)
        self.test_f1 = F1Score(task="multiclass", num_classes=num_classes)

        # Underlying CNN.
        self.model = Net(config={'dropout_rate': dropout_rate, 'bias': bias})

    def forward(self, x) -> Any:
        """Return (N, 10) log-probabilities from the wrapped CNN."""
        return self.model(x)

    def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """One optimization step: NLL loss plus per-step acc/F1 logging."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.train_accuracy(preds, y)
        f1 = self.train_f1(preds, y)

        # BUG FIX: `logger=` expects a bool; the old code passed the logger
        # object itself (truthy, so it only worked by accident).
        self.log("train/loss", loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("train/acc", acc, prog_bar=True, on_epoch=False, on_step=True, logger=True)
        self.log("train/train_f1", f1, prog_bar=True, on_epoch=False, on_step=True, logger=True)

        if batch_idx == 0:
            # NOTE(review): assumes a TensorBoard-style logger exposing
            # `experiment.add_image` — confirm against the configured logger.
            grid = make_grid(x)
            self.logger.experiment.add_image("train_imgs", grid, self.current_epoch)

        return {
            'loss': loss,
            'logits': logits,
            'preds': preds
        }

    def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """Validation step: same loss/metrics as training, epoch-level logging."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.val_accuracy(preds, y)
        f1 = self.val_f1(preds, y)

        self.log("val/loss", loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("val/acc", acc, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("val/val_f1", f1, prog_bar=True, on_epoch=True, on_step=False, logger=True)

        if batch_idx == 0:
            grid = make_grid(x)
            self.logger.experiment.add_image("val_imgs", grid, self.current_epoch)

        return {
            'loss': loss,
            'logits': logits,
            'preds': preds
        }

    def predict_step(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """Return the max softmax probability and predicted class per sample."""
        with torch.no_grad():
            logits = self(x)
        probs, indices = torch.max(F.softmax(logits, dim=1), dim=1)
        return {
            'prob': probs,
            'predict': indices
        }

    def test_step(self, batch, batch_idx: int = 0, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """Test step: loss plus test-stage acc/F1 logging."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = self.test_accuracy(preds, y)
        f1 = self.test_f1(preds, y)

        self.log("test/loss", loss, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("test/acc", acc, prog_bar=True, on_epoch=True, on_step=True, logger=True)
        self.log("test/test_f1", f1, prog_bar=True, on_epoch=True, on_step=False, logger=True)

        return {
            'loss': loss,
            'logits': logits,
            'preds': preds
        }

    def configure_optimizers(self):
        """Adam + OneCycleLR, with the scheduler advanced every batch."""
        # BUG FIX: was Adam(self.parameters(), lr=1e3) — a runaway learning
        # rate of 1000 that also ignored the `learning_rate` hyperparameter.
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=1e2 * self.learning_rate,  # peak LR = 100x base (OneCycle ramps up then anneals)
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=.3,
            cycle_momentum=True,
            div_factor=100,
            final_div_factor=1e10,
            three_phase=True,
        )
        # BUG FIX: `total_steps` counts optimizer steps (batches), so the
        # scheduler must step per batch. The old bare-list return made
        # Lightning step it once per epoch, leaving the LR schedule ~len(loader)x
        # too slow.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
|