File size: 4,348 Bytes
e72f2a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import numpy as np
import torch.nn as nn


class ConvBlock(nn.Module):
    def __init__(self, f, w, s, d, in_channels):
        super().__init__()
        p1 = d*(w - 1) // 2
        p2 = d*(w - 1) - p1
        self.pad = nn.ZeroPad2d((0, 0, p1, p2))

        self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=f, kernel_size=(w, 1), stride=(s, 1), dilation=(d, 1))
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(f)
        self.pool = nn.MaxPool2d(kernel_size=(2, 1))
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pad(x)
        x = self.conv2d(x)
        x = self.relu(x)
        x = self.bn(x)
        x = self.pool(x)
        x = self.dropout(x)
        return x


class NoPadConvBlock(nn.Module):
    def __init__(self, f, w, s, d, in_channels):
        super().__init__()

        self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=f, kernel_size=(w, 1), stride=(s, 1),
                                dilation=(d, 1))
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(f)
        self.pool = nn.MaxPool2d(kernel_size=(2, 1))
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.conv2d(x)
        x = self.relu(x)
        x = self.bn(x)
        x = self.pool(x)
        x = self.dropout(x)
        return x


class TinyPathway(nn.Module):
    def __init__(self, dilation=1, hop=256, localize=False,
                 model_capacity="full", n_layers=6, chunk_size=256):
        super().__init__()

        capacity_multiplier = {
            'tiny': 4, 'small': 8, 'medium': 16, 'large': 24, 'full': 32
        }[model_capacity]
        self.layers = [1, 2, 3, 4, 5, 6]
        self.layers = self.layers[:n_layers]
        filters = [n * capacity_multiplier for n in [32, 8, 8, 8, 8, 8]]
        filters = [1] + filters
        widths = [512, 64, 64, 64, 32, 32]
        strides = self.deter_dilations(hop//(4*(2**n_layers)), localize=localize)
        strides[0] = strides[0]*4  # apply 4 times more stride at the first layer
        dilations = self.deter_dilations(dilation)

        for i in range(len(self.layers)):
            f, w, s, d, in_channel = filters[i + 1], widths[i], strides[i], dilations[i], filters[i]
            self.add_module("conv%d" % i, NoPadConvBlock(f, w, s, d, in_channel))
        self.chunk_size = chunk_size
        self.input_window, self.hop = self.find_input_size_for_pathway()
        self.out_dim = filters[n_layers]

    def find_input_size_for_pathway(self):
        def find_input_size(output_size, kernel_size, stride, dilation, padding):
            num = (stride*(output_size-1)) + 1
            input_size = num - 2*padding + dilation*(kernel_size-1)
            return input_size
        conv_calc, n = {}, 0
        for i in self.layers:
            layer = self.__getattr__("conv%d" % (i-1))
            for mm in layer.modules():
                if hasattr(mm, 'kernel_size'):
                    try:
                        d = mm.dilation[0]
                    except TypeError:
                        d = mm.dilation
                    conv_calc[n] = [mm.kernel_size[0], mm.stride[0], 0, d]
                    n += 1
        out = self.chunk_size
        hop = 1
        for n in sorted(conv_calc.keys())[::-1]:
            kernel_size_n, stride_n, padding_n, dilation_n = conv_calc[n]
            out = find_input_size(out, kernel_size_n, stride_n, dilation_n, padding_n)
            hop = hop*stride_n
        return out, hop

    def deter_dilations(self, total_dilation, localize=False):
        n_layers = len(self.layers)
        if localize:  # e.g., 32*1023 window and 3 layers -> [1, 1, 32]
            a = [total_dilation] + [1 for _ in range(n_layers-1)]
        else:  # e.g., 32*1023 window and 3 layers -> [4, 4, 2]
            total_dilation = int(np.log2(total_dilation))
            a = []
            for layer in range(n_layers):
                this_dilation = int(np.ceil(total_dilation/(n_layers-layer)))
                a.append(2**this_dilation)
                total_dilation = total_dilation - this_dilation
        return a[::-1]

    def forward(self, x):
        x = x.view(x.shape[0], 1, -1, 1)
        for i in range(len(self.layers)):
            x = self.__getattr__("conv%d" % i)(x)
        x = x.permute(0, 3, 2, 1)
        return x