# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn


class RNet(nn.Module):
    """Small CNN classifier: six conv blocks with max pooling, followed by an MLP head."""

    def __init__(
        self,
        n_channels=3,
        n_classes=13,
        n_pix=256,
        filters=(8, 16, 32, 64, 64, 128),
        pool=(2, 2),
        kernel_size=(3, 3),
        n_meta=0,
    ) -> None:
        super().__init__()

        def conv_block(in_filters, out_filters, kernel_size):
            layers = nn.Sequential(
                # first conv is across channels, size=1
                nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), padding="same"),
                nn.BatchNorm2d(out_filters),
                nn.ReLU(),
                nn.Conv2d(
                    out_filters, out_filters, kernel_size=kernel_size, padding="same"
                ),
            )
            return layers

        def fc_block(in_features, out_features):
            layers = nn.Sequential(
                nn.Linear(in_features=in_features, out_features=out_features),
                # nn.BatchNorm1d(out_features),
                # nn.InstanceNorm1d(out_features),
                nn.ReLU(),
            )
            return layers

        self.pool = nn.MaxPool2d(pool)
        self.input_layer = conv_block(n_channels, filters[0], kernel_size)
        self.conv_block1 = conv_block(filters[0], filters[1], kernel_size)
        self.conv_block2 = conv_block(filters[1], filters[2], kernel_size)
        self.conv_block3 = conv_block(filters[2], filters[3], kernel_size)
        self.conv_block4 = conv_block(filters[3], filters[4], kernel_size)
        self.conv_block5 = conv_block(filters[4], filters[5], kernel_size)
        # Five pooling steps (assumed stride-2) halve the spatial size each time,
        # so the flattened feature length is filters[5] * (n_pix / 2**n_pool) ** 2.
        n_pool = 5
        self.fc1 = fc_block(
            in_features=int(filters[5] * (n_pix / 2**n_pool) ** 2), out_features=64
        )
        # Note: n_meta widens fc2, but forward() does not concatenate metadata
        # features, so n_meta > 0 requires extending forward() accordingly.
        self.fc2 = fc_block(in_features=64 + n_meta, out_features=64)
        self.fc3 = fc_block(in_features=64, out_features=32)
        self.fc4 = nn.Linear(in_features=32, out_features=n_classes)

    def forward(self, x):
        x1 = self.pool(self.input_layer(x))
        x2 = self.pool(self.conv_block1(x1))
        x3 = self.pool(self.conv_block2(x2))
        x4 = self.pool(self.conv_block3(x3))
        x5 = self.pool(self.conv_block4(x4))
        x6 = self.conv_block5(x5)
        x7 = torch.flatten(x6, 1)  # flatten all dimensions except batch
        x8 = self.fc1(x7)
        x9 = self.fc2(x8)
        x10 = self.fc3(x9)
        x11 = self.fc4(x10)
        return x11
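

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file): instantiate RNet
    # with its default arguments and run one dummy batch to check the output shape.
    # The batch size of 2 and the random input tensor are illustrative assumptions.
    model = RNet()
    dummy = torch.randn(2, 3, 256, 256)  # (batch, n_channels, n_pix, n_pix)
    logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 13]), one logit per class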