KingNish committed on
Commit
6e7d2eb
·
verified ·
1 Parent(s): 81df686

Upload ./RepCodec/repcodec/modules/encoder.py with huggingface_hub

Browse files
RepCodec/repcodec/modules/encoder.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from RepCodec.repcodec.layers.conv_layer import Conv1d
12
+ from RepCodec.repcodec.modules.residual_unit import ResidualUnit
13
+
14
+
15
class EncoderBlock(nn.Module):
    """A run of residual units followed by a single (possibly strided) convolution.

    Args:
        in_channels: Passed as both channel arguments of every residual unit
            and as ``in_channels`` of the closing convolution.
        out_channels: ``out_channels`` of the closing convolution.
        stride: ``stride`` of the closing convolution; it also determines
            that convolution's kernel size.
        dilations: One residual unit is built per entry, with the entry used
            as that unit's ``dilation``.
        unit_kernel_size: ``kernel_size`` handed to each residual unit.
        bias: ``bias`` flag handed to the closing convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        dilations=(1, 1),
        unit_kernel_size=3,
        bias=True
    ):
        super().__init__()

        # Build the residual units first so submodule registration order
        # (res_units, then conv) stays the same.
        units = [
            ResidualUnit(
                in_channels,
                in_channels,
                kernel_size=unit_kernel_size,
                dilation=rate,
            )
            for rate in dilations
        ]
        self.res_units = torch.nn.ModuleList(units)
        self.num_res = len(self.res_units)

        # Kernel size is twice the stride, except that stride 1 uses a
        # kernel of 3 instead of 2.
        if stride == 1:
            conv_kernel = 3
        else:
            conv_kernel = 2 * stride

        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=conv_kernel,
            stride=stride,
            bias=bias,
        )

    def forward(self, x):
        """Apply every residual unit in order, then the closing convolution."""
        for unit in self.res_units:
            x = unit(x)
        return self.conv(x)
48
+
49
+
50
class Encoder(nn.Module):
    """An input convolution followed by a chain of ``EncoderBlock`` modules.

    Args:
        input_channels: ``in_channels`` of the input convolution.
        encode_channels: ``out_channels`` of the input convolution; also the
            base width that each entry of ``channel_ratios`` is multiplied by.
        channel_ratios: Per-block multipliers. Block ``i`` outputs
            ``int(encode_channels * channel_ratios[i])`` channels.
        strides: Per-block stride; must have the same length as
            ``channel_ratios``.
        kernel_size: ``kernel_size`` of the input convolution.
        bias: ``bias`` flag forwarded to every ``EncoderBlock``. The input
            convolution is always created with ``bias=False``.
        block_dilations: ``dilations`` forwarded to every ``EncoderBlock``.
        unit_kernel_size: ``unit_kernel_size`` forwarded to every
            ``EncoderBlock``.

    Attributes:
        out_channels: Output width of the last block, or ``encode_channels``
            when no blocks are configured.
    """

    def __init__(
        self,
        input_channels: int,
        encode_channels: int,
        channel_ratios=(1, 1),
        strides=(1, 1),
        kernel_size=3,
        bias=True,
        block_dilations=(1, 1),
        unit_kernel_size=3
    ):
        super().__init__()
        assert len(channel_ratios) == len(strides)

        self.conv = Conv1d(
            in_channels=input_channels,
            out_channels=encode_channels,
            kernel_size=kernel_size,
            stride=1,
            bias=False
        )

        self.conv_blocks = torch.nn.ModuleList()
        in_channels = encode_channels
        # Start from the input convolution's width so that an empty
        # ``strides`` / ``channel_ratios`` configuration still yields a
        # defined ``out_channels`` instead of an UnboundLocalError.
        out_channels = encode_channels
        for stride, ratio in zip(strides, channel_ratios):
            # ``ratio`` may be a float, so truncate to an integer width.
            out_channels = int(encode_channels * ratio)
            self.conv_blocks.append(
                EncoderBlock(
                    in_channels,
                    out_channels,
                    stride,
                    dilations=block_dilations,
                    unit_kernel_size=unit_kernel_size,
                    bias=bias,
                )
            )
            # The next block consumes what this one produces.
            in_channels = out_channels

        self.num_blocks = len(self.conv_blocks)
        self.out_channels = out_channels

    def forward(self, x):
        """Apply the input convolution, then each encoder block in order."""
        x = self.conv(x)
        for block in self.conv_blocks:
            x = block(x)
        return x