KingNish commited on
Commit
b3b0dda
·
verified ·
1 Parent(s): 6e7d2eb

Upload ./RepCodec/repcodec/modules/quantizer.py with huggingface_hub

Browse files
RepCodec/repcodec/modules/quantizer.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import torch.nn as nn
9
+
10
+ from repcodec.layers.vq_module import ResidualVQ
11
+
12
+
13
+ class Quantizer(nn.Module):
14
+ def __init__(
15
+ self,
16
+ code_dim: int,
17
+ codebook_num: int,
18
+ codebook_size: int,
19
+ ):
20
+ super().__init__()
21
+ self.codebook = ResidualVQ(
22
+ dim=code_dim,
23
+ num_quantizers=codebook_num,
24
+ codebook_size=codebook_size
25
+ )
26
+
27
+ def initial(self):
28
+ self.codebook.initial()
29
+
30
+ def forward(self, z):
31
+ zq, vqloss, perplexity = self.codebook(z.transpose(2, 1))
32
+ zq = zq.transpose(2, 1)
33
+ return zq, vqloss, perplexity
34
+
35
+ def inference(self, z):
36
+ zq, indices = self.codebook.forward_index(z.transpose(2, 1))
37
+ zq = zq.transpose(2, 1)
38
+ return zq, indices
39
+
40
+ def encode(self, z):
41
+ zq, indices = self.codebook.forward_index(z.transpose(2, 1), flatten_idx=True)
42
+ return zq, indices
43
+
44
+ def decode(self, indices):
45
+ z = self.codebook.lookup(indices)
46
+ return z