# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

class RNet(nn.Module):
    """Small CNN classifier: six convolutional blocks (five followed by max
    pooling) whose flattened features feed four fully connected layers."""

    def __init__(
        self,
        n_channels=3,
        n_classes=13,
        n_pix=256,
        filters=(8, 16, 32, 64, 64, 128),
        pool=(2, 2),
        kernel_size=(3, 3),
        n_meta=0,
    ) -> None:
        super().__init__()

        def conv_block(in_filters, out_filters, kernel_size):
            layers = nn.Sequential(
                # first conv is across channels, size=1
                nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), padding="same"),
                nn.BatchNorm2d(out_filters),
                nn.ReLU(),
                nn.Conv2d(
                    out_filters, out_filters, kernel_size=kernel_size, padding="same"
                ),
            )
            return layers

        def fc_block(in_features, out_features):
            layers = nn.Sequential(
                nn.Linear(in_features=in_features, out_features=out_features),
                # nn.BatchNorm1d(out_features),
                # nn.InstanceNorm1d(out_features),
                nn.ReLU(),
            )
            return layers

        # shared downsampling layer; stride defaults to the pooling kernel size
        self.pool = nn.MaxPool2d(pool)
        self.input_layer = conv_block(n_channels, filters[0], kernel_size)
        self.conv_block1 = conv_block(filters[0], filters[1], kernel_size)
        self.conv_block2 = conv_block(filters[1], filters[2], kernel_size)
        self.conv_block3 = conv_block(filters[2], filters[3], kernel_size)
        self.conv_block4 = conv_block(filters[3], filters[4], kernel_size)
        self.conv_block5 = conv_block(filters[4], filters[5], kernel_size)
        # five pooling stages shrink each spatial dimension by pool[0] ** n_pool,
        # so fc1 sees filters[5] * (n_pix / pool[0] ** n_pool) ** 2 values per sample
        n_pool = 5
        self.fc1 = fc_block(
            in_features=int(filters[5] * (n_pix // pool[0] ** n_pool) ** 2),
            out_features=64,
        )
        self.fc2 = fc_block(in_features=64 + n_meta, out_features=64)
        self.fc3 = fc_block(in_features=64, out_features=32)
        self.fc4 = nn.Linear(in_features=32, out_features=n_classes)

    def forward(self, x, meta=None):
        x1 = self.pool(self.input_layer(x))
        x2 = self.pool(self.conv_block1(x1))
        x3 = self.pool(self.conv_block2(x2))
        x4 = self.pool(self.conv_block3(x3))
        x5 = self.pool(self.conv_block4(x4))
        x6 = self.conv_block5(x5)
        x7 = torch.flatten(x6, 1)  # flatten all dimensions except batch
        x8 = self.fc1(x7)
        if meta is not None:
            # fc2 is sized for 64 + n_meta inputs, so per-sample metadata
            # features are concatenated onto the image features here
            x8 = torch.cat([x8, meta], dim=1)
        x9 = self.fc2(x8)
        x10 = self.fc3(x9)
        x11 = self.fc4(x10)
        return x11
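

# Minimal usage sketch: run a dummy batch through the model to check output
# shapes. The batch size, random inputs, and the n_meta=4 case below are
# illustrative values, not fixed by the model itself.
if __name__ == "__main__":
    model = RNet()
    model.eval()  # use running BatchNorm statistics for this smoke test
    with torch.no_grad():
        dummy = torch.randn(2, 3, 256, 256)  # (batch, channels, n_pix, n_pix)
        logits = model(dummy)
    print(logits.shape)  # torch.Size([2, 13])

    # with metadata: n_meta extra per-sample features are concatenated before fc2
    meta_model = RNet(n_meta=4).eval()
    with torch.no_grad():
        meta_logits = meta_model(dummy, meta=torch.randn(2, 4))
    print(meta_logits.shape)  # torch.Size([2, 13])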