# biomass/models/regressor.py
# juansensio's picture
# Upload 5 files
# c5d0550 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
import torch.nn as nn
import torchvision
class RNet(nn.Module):
    """Convolutional feature extractor followed by a fully connected head.

    The network stacks six convolutional blocks. Each block is a 1x1
    convolution, batch normalisation, ReLU and a ``kernel_size`` convolution,
    all with ``"same"`` padding. The first five blocks are each followed by
    max-pooling; the sixth is not. The resulting feature map is flattened and
    passed through three ``Linear + ReLU`` blocks and a final ``Linear`` layer
    that produces ``n_classes`` raw (un-activated) outputs.

    Optional per-sample metadata of width ``n_meta`` is concatenated to the
    64 features produced by the first dense block before the second one.

    Args:
        n_channels: Number of channels of the input image.
        n_classes: Number of output values per sample.
        n_pix: Height and width (the input is taken to be square) of the
            input image, in pixels.
        filters: Output channels of the six convolutional blocks; the first
            six entries are used.
        pool: Max-pooling window, either an int or a ``(height, width)``
            pair. The stride equals the window.
        kernel_size: Kernel size of the second convolution of each block.
        n_meta: Number of metadata features passed to ``forward``; ``0``
            disables the metadata input.

    Raises:
        ValueError: If ``n_pix`` is too small to survive the five pooling
            stages with at least one pixel left.
    """

    # Number of pooling stages applied in ``forward``; used to size ``fc1``.
    _N_POOL = 5

    def __init__(
        self,
        n_channels: int = 3,
        n_classes: int = 13,
        n_pix: int = 256,
        filters=(8, 16, 32, 64, 64, 128),
        pool=(2, 2),
        kernel_size=(3, 3),
        n_meta: int = 0,
    ) -> None:
        super().__init__()

        pool_h, pool_w = (pool, pool) if isinstance(pool, int) else tuple(pool)

        # Every pooling stage floors the spatial size, and repeated floor
        # division by p equals a single floor division by p ** stages.
        feat_h = int(n_pix) // pool_h ** self._N_POOL
        feat_w = int(n_pix) // pool_w ** self._N_POOL
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"n_pix={n_pix} is too small for {self._N_POOL} pooling "
                f"stages with window {(pool_h, pool_w)}"
            )

        self.n_meta = n_meta

        # NOTE: attribute names and the layer order inside each Sequential
        # define the state_dict keys; keep them stable so that existing
        # checkpoints continue to load.
        self.pool = nn.MaxPool2d(
            kernel_size=(pool_h, pool_w), stride=(pool_h, pool_w)
        )

        self.input_layer = self._conv_block(n_channels, filters[0], kernel_size)
        self.conv_block1 = self._conv_block(filters[0], filters[1], kernel_size)
        self.conv_block2 = self._conv_block(filters[1], filters[2], kernel_size)
        self.conv_block3 = self._conv_block(filters[2], filters[3], kernel_size)
        self.conv_block4 = self._conv_block(filters[3], filters[4], kernel_size)
        self.conv_block5 = self._conv_block(filters[4], filters[5], kernel_size)

        self.fc1 = self._fc_block(
            in_features=filters[5] * feat_h * feat_w, out_features=64
        )
        # fc2 additionally receives the metadata features (if any).
        self.fc2 = self._fc_block(in_features=64 + n_meta, out_features=64)
        self.fc3 = self._fc_block(in_features=64, out_features=32)
        self.fc4 = nn.Linear(in_features=32, out_features=n_classes)

    @staticmethod
    def _conv_block(in_filters, out_filters, kernel_size) -> nn.Sequential:
        """Build a 1x1 conv -> BatchNorm -> ReLU -> ``kernel_size`` conv block."""
        return nn.Sequential(
            # First conv mixes information across channels only (size 1).
            nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), padding="same"),
            nn.BatchNorm2d(out_filters),
            nn.ReLU(),
            nn.Conv2d(
                out_filters, out_filters, kernel_size=kernel_size, padding="same"
            ),
        )

    @staticmethod
    def _fc_block(in_features, out_features) -> nn.Sequential:
        """Build a ``Linear -> ReLU`` block (no normalisation layer)."""
        return nn.Sequential(
            nn.Linear(in_features=in_features, out_features=out_features),
            nn.ReLU(),
        )

    def forward(
        self, x: torch.Tensor, meta: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: Input batch; the layer sizes assume shape
                ``(batch, n_channels, n_pix, n_pix)``.
            meta: Metadata of shape ``(batch, n_meta)``. Required when the
                model was built with ``n_meta > 0`` and must be omitted
                otherwise.

        Returns:
            Tensor of shape ``(batch, n_classes)`` with the raw outputs of
            the last linear layer.

        Raises:
            ValueError: If ``meta`` is missing, unexpected, or does not have
                shape ``(batch, n_meta)``.
        """
        x = self.pool(self.input_layer(x))
        x = self.pool(self.conv_block1(x))
        x = self.pool(self.conv_block2(x))
        x = self.pool(self.conv_block3(x))
        x = self.pool(self.conv_block4(x))
        x = self.conv_block5(x)  # the last block is not pooled

        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = self.fc1(x)

        if self.n_meta > 0:
            if meta is None:
                raise ValueError(
                    f"model was built with n_meta={self.n_meta}; "
                    "forward() requires a `meta` tensor"
                )
            if meta.dim() != 2 or meta.shape[1] != self.n_meta:
                raise ValueError(
                    f"expected meta of shape (batch, {self.n_meta}), "
                    f"got {tuple(meta.shape)}"
                )
            # fc2 was sized for 64 + n_meta inputs, so the metadata joins here.
            x = torch.cat((x, meta.to(dtype=x.dtype)), dim=1)
        elif meta is not None:
            raise ValueError(
                "model was built with n_meta=0 but a `meta` tensor was passed"
            )

        x = self.fc2(x)
        x = self.fc3(x)
        return self.fc4(x)