KingNish commited on
Commit
108690a
·
verified ·
1 Parent(s): 0fa0976

Upload ./RepCodec/repcodec/RepCodec.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. RepCodec/repcodec/RepCodec.py +84 -0
RepCodec/repcodec/RepCodec.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch.nn as nn
9
+
10
+ from repcodec.modules.decoder import Decoder
11
+ from repcodec.modules.encoder import Encoder
12
+ from repcodec.modules.projector import Projector
13
+ from repcodec.modules.quantizer import Quantizer
14
+
15
+
16
+ class RepCodec(nn.Module):
17
+ def __init__(
18
+ self,
19
+ input_channels=768,
20
+ output_channels=768,
21
+ encode_channels=768,
22
+ decode_channels=768,
23
+ code_dim=768,
24
+ codebook_num=1,
25
+ codebook_size=1024,
26
+ bias=True,
27
+ enc_ratios=(1, 1),
28
+ dec_ratios=(1, 1),
29
+ enc_strides=(1, 1),
30
+ dec_strides=(1, 1),
31
+ enc_kernel_size=3,
32
+ dec_kernel_size=3,
33
+ enc_block_dilations=(1, 1),
34
+ enc_block_kernel_size=3,
35
+ dec_block_dilations=(1, 1),
36
+ dec_block_kernel_size=3
37
+ ):
38
+ super().__init__()
39
+
40
+ self.input_channels = input_channels
41
+
42
+ self.encoder = Encoder(
43
+ input_channels=input_channels,
44
+ encode_channels=encode_channels,
45
+ channel_ratios=enc_ratios,
46
+ strides=enc_strides,
47
+ kernel_size=enc_kernel_size,
48
+ bias=bias,
49
+ block_dilations=enc_block_dilations,
50
+ unit_kernel_size=enc_block_kernel_size
51
+ )
52
+
53
+ self.decoder = Decoder(
54
+ code_dim=code_dim,
55
+ output_channels=output_channels,
56
+ decode_channels=decode_channels,
57
+ channel_ratios=dec_ratios,
58
+ strides=dec_strides,
59
+ kernel_size=dec_kernel_size,
60
+ bias=bias,
61
+ block_dilations=dec_block_dilations,
62
+ unit_kernel_size=dec_block_kernel_size
63
+ )
64
+
65
+ self.projector = Projector(
66
+ input_channels=self.encoder.out_channels,
67
+ code_dim=code_dim,
68
+ kernel_size=3,
69
+ stride=1,
70
+ bias=False
71
+ )
72
+
73
+ self.quantizer = Quantizer(
74
+ code_dim=code_dim,
75
+ codebook_num=codebook_num,
76
+ codebook_size=codebook_size
77
+ )
78
+
79
+ def forward(self, x):
80
+ x = self.encoder(x)
81
+ z = self.projector(x)
82
+ zq, vqloss, perplexity = self.quantizer(z)
83
+ y = self.decoder(zq)
84
+ return y, zq, z, vqloss, perplexity