|
import torch |
|
import torch.nn as nn |
|
import torchvision.models as models |
|
|
|
|
|
|
|
class Net_BINARY(nn.Module):
    """Convolutional classifier for single-channel images.

    Three conv stages (Conv2d -> PReLU -> BatchNorm2d -> ReLU6/Sigmoid ->
    AvgPool2d -> Dropout) feed a two-layer linear head. The head returns raw,
    unnormalised scores of shape ``(batch, n_classes)``; no softmax is applied.

    Args:
        n_classes: Number of output scores per sample.
        input_size: Side length of the square input images. It is used only
            to size the first linear layer. The default of 128 reproduces the
            original hard-coded 1152 input features (128 channels x 3 x 3).
            Any side length from 111 to 137 yields that same size.

    Raises:
        ValueError: If ``input_size`` is too small for the conv/pool stack to
            produce a non-empty feature map.
    """

    def __init__(self, n_classes: int, *, input_size: int = 128) -> None:
        super().__init__()

        # Pure arithmetic (no RNG, no layer state). It runs first so an invalid
        # input_size fails fast.
        side = self._feature_side(input_size)

        # NOTE: the module order below defines the state_dict keys
        # (cnn_layers.0 .. cnn_layers.17, linear_layers.0 / .1). Inserting,
        # removing or reordering modules breaks loading existing checkpoints.
        self.cnn_layers = nn.Sequential(
            # Stage 1: 1 -> 32 channels.
            nn.Conv2d(1, 32, kernel_size=4, stride=1),
            nn.PReLU(),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True),
            nn.AvgPool2d(kernel_size=3),  # stride defaults to kernel_size
            nn.Dropout(p=0.5, inplace=True),
            # Stage 2: 32 -> 64 channels.
            nn.Conv2d(32, 64, kernel_size=4, stride=1),
            nn.PReLU(),
            nn.BatchNorm2d(64),
            nn.ReLU6(inplace=True),
            nn.AvgPool2d(kernel_size=3),
            nn.Dropout(p=0.25, inplace=True),
            # Stage 3: 64 -> 128 channels; this stage uses Sigmoid, not ReLU6.
            nn.Conv2d(64, 128, kernel_size=3, stride=1),
            nn.PReLU(),
            nn.BatchNorm2d(128),
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=3),
            nn.Dropout(p=0.125, inplace=True),
        )

        self.linear_layers = nn.Sequential(
            # 128 = out_channels of the last Conv2d above.
            nn.Linear(128 * side * side, 312),
            # NOTE(review): there is no activation between these two Linear
            # layers, so together they compute a single affine map. Adding one
            # would change outputs and the state_dict index of the second
            # layer, so it is left as-is. Confirm whether this is intended.
            nn.Linear(312, n_classes),
        )

    @staticmethod
    def _feature_side(input_size: int) -> int:
        """Return the spatial side length ``cnn_layers`` produces for a square input.

        Mirrors ``cnn_layers`` and must be kept in sync with it. Each Conv2d
        has stride 1 and no padding, so it shrinks the side by ``kernel - 1``.
        Each AvgPool2d(kernel_size=3) uses its default stride of 3, so its
        output side ``(s - 3) // 3 + 1`` equals ``s // 3``.

        Raises:
            ValueError: If any stage would produce an empty feature map.
        """
        side = input_size
        for kernel in (4, 4, 3):  # Conv2d kernel sizes, in layer order
            side = side - kernel + 1  # Conv2d output side
            side = side // 3          # AvgPool2d output side
            if side < 1:
                raise ValueError(
                    f"input_size={input_size} is too small for Net_BINARY's "
                    "conv/pool stack"
                )
        return side

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class scores of shape ``(batch, n_classes)``.

        ``x`` must have 1 channel (the first Conv2d takes 1 input channel).
        It is presumably batched as ``(batch, 1, input_size, input_size)``;
        other spatial sizes only work if they flatten to the same feature count.
        """
        x = self.cnn_layers(x)
        # torch.flatten (unlike .view) also accepts non-contiguous feature
        # maps, e.g. channels_last, by copying only when necessary.
        x = torch.flatten(x, 1)
        return self.linear_layers(x)
|
|
|
|