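"""Residual 1D convolutional blocks and encoder/decoder stacks for
sequence data laid out as B x T x C (batch, time, channels)."""
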
from typing import Optional

import torch
import torch.nn as nn

from .utils import get_activation_fn


class ResConv1DBlock(nn.Module):
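    """Pre-activation residual block.

    norm -> activation -> dilated 3-tap conv (n_in -> n_state) -> norm ->
    activation -> 1x1 conv (n_state -> n_in) -> dropout, with the result
    added back onto the input. ``padding=dilation`` keeps the sequence
    length unchanged, so the residual sum is always shape-compatible.
    """
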
    def __init__(self, n_in: int, n_state: int, dilation: int = 1, activation: str = 'silu', dropout: float = 0.1,
                 norm: Optional[str] = None, norm_groups: int = 32, norm_eps: float = 1e-5) -> None:
        super().__init__()

        self.norm = norm
        if norm == "LN":
            self.norm1 = nn.LayerNorm(n_in, eps=norm_eps)
            self.norm2 = nn.LayerNorm(n_in, eps=norm_eps)
        elif norm == "GN":
            self.norm1 = nn.GroupNorm(num_groups=norm_groups, num_channels=n_in, eps=norm_eps)
            self.norm2 = nn.GroupNorm(num_groups=norm_groups, num_channels=n_in, eps=norm_eps)
        elif norm == "BN":
            self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=norm_eps)
            self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=norm_eps)
        else:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()

        self.activation = get_activation_fn(activation)

        self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding=dilation, dilation=dilation)
        self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_orig = x
        if self.norm == "LN":
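            # LayerNorm normalizes the last dim, so transpose to B x T x C,
            # normalize over channels, then transpose back before the conv.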
            x = self.norm1(x.transpose(-2, -1))
            x = self.activation(x.transpose(-2, -1))
        else:
            x = self.norm1(x)
            x = self.activation(x)

        x = self.conv1(x)

        if self.norm == "LN":
            x = self.norm2(x.transpose(-2, -1))
            x = self.activation(x.transpose(-2, -1))
        else:
            x = self.norm2(x)
            x = self.activation(x)

        x = self.conv2(x)
        x = self.dropout(x)
        x = x + x_orig
        return x


class Resnet1D(nn.Module):
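    """Stack of ``n_depth`` ResConv1DBlocks whose dilations grow as
    ``dilation_growth_rate ** depth``; with ``reverse_dilation`` the stack
    is reordered so the largest dilation comes first.
    """
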
    def __init__(self, n_in: int, n_state: int, n_depth: int, reverse_dilation: bool = True,
                 dilation_growth_rate: int = 3, activation: str = 'relu', dropout: float = 0.1,
                 norm: Optional[str] = None, norm_groups: int = 32, norm_eps: float = 1e-5) -> None:
        super().__init__()
        blocks = [ResConv1DBlock(n_in, n_state, dilation=dilation_growth_rate ** depth, activation=activation,
                                 dropout=dropout, norm=norm, norm_groups=norm_groups, norm_eps=norm_eps)
                  for depth in range(n_depth)]
        if reverse_dilation:
            blocks = blocks[::-1]
        self.model = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class ResEncoder(nn.Module):
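    """Encoder: an input conv, then ``down_t`` stages of strided conv
    (time / stride_t each) followed by a Resnet1D, then an output conv.
    ``double_z`` doubles the output channels, e.g. to carry a mean and
    log-variance pair in a VAE-style bottleneck.
    """
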
    def __init__(self,
                 in_width: int = 263,
                 mid_width: int = 512,
                 out_width: int = 512,
                 down_t: int = 2,
                 stride_t: int = 2,
                 n_depth: int = 3,
                 dilation_growth_rate: int = 3,
                 activation: str = 'relu',
                 dropout: float = 0.1,
                 norm: Optional[str] = None,
                 norm_groups: int = 32,
                 norm_eps: float = 1e-5,
                 double_z: bool = False) -> None:
        super().__init__()

        blocks = []
        # Kernel 2*stride with padding stride//2: each strided conv downsamples time by ~stride_t.
        filter_t, pad_t = stride_t * 2, stride_t // 2
        blocks.append(nn.Conv1d(in_width, mid_width, 3, 1, 1))
        blocks.append(get_activation_fn(activation))

        for _ in range(down_t):
            block = nn.Sequential(
                nn.Conv1d(mid_width, mid_width, filter_t, stride_t, pad_t),
                Resnet1D(mid_width, mid_width, n_depth, reverse_dilation=True, dilation_growth_rate=dilation_growth_rate,
                         activation=activation, dropout=dropout, norm=norm, norm_groups=norm_groups, norm_eps=norm_eps))
            blocks.append(block)
        blocks.append(nn.Conv1d(mid_width, out_width * 2 if double_z else out_width, 3, 1, 1))
        self.model = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x.permute(0, 2, 1))  # B x T x C -> B x C x T for the convs; output stays channels first


class ResDecoder(nn.Module):
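    """Decoder mirroring ResEncoder: ``down_t`` stages of Resnet1D,
    nearest-neighbor upsampling (time * stride_t each) and a conv, then two
    output convs back to ``in_width`` channels. Note the mirrored naming:
    ``out_width`` is the latent (input) width here and ``in_width`` the
    reconstructed data width.
    """
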
    def __init__(self,
                 in_width: int = 263,
                 mid_width: int = 512,
                 out_width: int = 512,
                 down_t: int = 2,
                 stride_t: int = 2,
                 n_depth: int = 3,
                 dilation_growth_rate: int = 3,
                 activation: str = 'relu',
                 dropout: float = 0.1,
                 norm: Optional[str] = None,
                 norm_groups: int = 32,
                 norm_eps: float = 1e-5) -> None:
        super().__init__()
        blocks = [nn.Conv1d(out_width, mid_width, 3, 1, 1), get_activation_fn(activation)]

        for _ in range(down_t):
            block = nn.Sequential(
                Resnet1D(mid_width, mid_width, n_depth, reverse_dilation=True, dilation_growth_rate=dilation_growth_rate,
                         activation=activation, dropout=dropout, norm=norm, norm_groups=norm_groups, norm_eps=norm_eps),
                nn.Upsample(scale_factor=stride_t, mode='nearest'),
                nn.Conv1d(mid_width, mid_width, 3, 1, 1))
            blocks.append(block)
        blocks.append(nn.Conv1d(mid_width, mid_width, 3, 1, 1))
        blocks.append(get_activation_fn(activation))
        blocks.append(nn.Conv1d(mid_width, in_width, 3, 1, 1))
        self.model = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x).permute(0, 2, 1)  # back to B x T x C
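
# A minimal round-trip shape check (an illustrative sketch, not part of the
# original API). With the defaults, the encoder shrinks time by
# stride_t ** down_t = 4 and the decoder restores it. Run as a module
# (e.g. ``python -m <package>.<this_module>``) so the relative import resolves.
if __name__ == "__main__":
    encoder = ResEncoder()
    decoder = ResDecoder()
    x = torch.randn(2, 64, 263)   # B x T x C
    z = encoder(x)                # B x 512 x 16 (time / 4, channels first)
    x_hat = decoder(z)            # B x 64 x 263
    assert z.shape == (2, 512, 16)
    assert x_hat.shape == x.shape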