Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from typing import Any,List,Tuple,Dict | |
class Net(nn.Module):
    """Small all-convolutional classifier that emits 10-way log-probabilities.

    The network has four stages.

    * ``conv1`` / ``conv2`` / ``conv3`` stack 3x3, stride-1, padding-1
      convolutions. Each is combined with ReLU, BatchNorm and
      ``Dropout2d``, so these stages keep the spatial size.
    * ``trans1`` / ``trans2`` / ``trans3`` downsample with a 2x2 max-pool
      and remap channels with a 1x1 convolution.
    * ``out4`` is a final 3x3 convolution followed by average pooling.

    The first convolution expects 1 input channel. For a 1x28x28 input
    (presumably MNIST-style digits; not established by this file), the
    spatial size evolves as::

        28 -> 15 (pool, padding=1) -> 17 (1x1 conv, padding=1)
           -> 9 (trans2) -> 4 (trans3) -> 1 (out4 avg-pool)

    The head therefore yields one 10-channel vector per sample. Inputs that
    do not reduce to a 1x1 head map make ``forward`` raise, instead of
    being silently reshaped into extra batch rows.

    Module names and the order of layers inside every ``nn.Sequential`` are
    significant: they determine ``state_dict`` keys such as
    ``conv1.2.running_mean``. Keep them stable so saved checkpoints load.

    Args:
        config: Options mapping. Recognised keys:

            * ``'dropout_rate'`` (default ``0.01``): ``p`` for every
              ``Dropout2d`` layer.
            * ``'bias'`` (default ``False``): whether the convolutions
              have a bias term.
    """

    def __init__(self, config: Dict):
        super().__init__()
        dropout = config.get('dropout_rate', 0.01)
        bias = config.get('bias', False)

        # NOTE(review): conv1 applies ReLU *before* BatchNorm, while conv2
        # (and the first half of conv3) apply BatchNorm before ReLU. This is
        # kept as-is: reordering changes both the computation and the
        # Sequential indices used in state_dict keys.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.ReLU(),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(in_channels=8, out_channels=10, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.ReLU(),
            nn.BatchNorm2d(10),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(in_channels=10, out_channels=10, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.ReLU(),
            nn.BatchNorm2d(10),
            nn.Dropout2d(p=dropout),
        )
        self.trans1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            # padding=1 on a 1x1 conv only adds a constant border (zeros
            # when bias=False), growing the map by 2 in each dimension.
            # It is kept because the downstream sizes depend on it.
            nn.Conv2d(in_channels=10, out_channels=8, kernel_size=1, bias=bias, padding=1),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=10, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.BatchNorm2d(10),
            nn.ReLU(),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(in_channels=10, out_channels=12, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.Dropout2d(p=dropout),
        )
        self.trans2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.Conv2d(in_channels=12, out_channels=8, kernel_size=1, bias=bias),
            nn.BatchNorm2d(8),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=10, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.BatchNorm2d(10),
            nn.ReLU(),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(in_channels=10, out_channels=12, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.ReLU(),
            nn.BatchNorm2d(12),
            nn.Dropout2d(p=dropout),
        )
        self.trans3 = nn.Sequential(
            nn.Conv2d(in_channels=12, out_channels=10, kernel_size=1, bias=bias),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.BatchNorm2d(10),
        )
        self.out4 = nn.Sequential(
            nn.Conv2d(in_channels=10, out_channels=10, kernel_size=3, stride=1, padding=1, bias=bias),
            # Reduces to 1x1x10. NOTE(review): on the 4x4 map produced by a
            # 28x28 input, a kernel-3 / stride-3 window averages only the
            # top-left 3x3 positions. If true global average pooling was
            # intended, that would need retraining, so it is left unchanged.
            nn.AvgPool2d(kernel_size=3),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-sample log-probabilities of shape ``(batch, 10)``.

        Args:
            x: Input batch for the first convolution, which takes 1 channel.
                It is assumed to be ``(batch, 1, H, W)``.

        Raises:
            RuntimeError: if the head does not reduce the input to a
                ``(batch, 10, 1, 1)`` map, i.e. the input size is
                incompatible with the fixed pooling schedule.
        """
        x = self.trans1(self.conv1(x))
        x = self.trans2(self.conv2(x))
        x = self.trans3(self.conv3(x))
        x = self.out4(x)
        # Pin the batch dimension explicitly. ``view(-1, 10)`` would silently
        # fold leftover spatial positions into extra batch rows, and cannot
        # reshape an empty batch.
        x = x.view(x.size(0), 10)
        return F.log_softmax(x, dim=1)