Upload ./RepCodec/repcodec/modules/quantizer.py with huggingface_hub
Browse files
RepCodec/repcodec/modules/quantizer.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) ByteDance, Inc. and its affiliates.
|
2 |
+
# Copyright (c) Chutong Meng
|
3 |
+
#
|
4 |
+
# This source code is licensed under the CC BY-NC license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
|
7 |
+
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from repcodec.layers.vq_module import ResidualVQ
|
11 |
+
|
12 |
+
|
13 |
+
class Quantizer(nn.Module):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
code_dim: int,
|
17 |
+
codebook_num: int,
|
18 |
+
codebook_size: int,
|
19 |
+
):
|
20 |
+
super().__init__()
|
21 |
+
self.codebook = ResidualVQ(
|
22 |
+
dim=code_dim,
|
23 |
+
num_quantizers=codebook_num,
|
24 |
+
codebook_size=codebook_size
|
25 |
+
)
|
26 |
+
|
27 |
+
def initial(self):
|
28 |
+
self.codebook.initial()
|
29 |
+
|
30 |
+
def forward(self, z):
|
31 |
+
zq, vqloss, perplexity = self.codebook(z.transpose(2, 1))
|
32 |
+
zq = zq.transpose(2, 1)
|
33 |
+
return zq, vqloss, perplexity
|
34 |
+
|
35 |
+
def inference(self, z):
|
36 |
+
zq, indices = self.codebook.forward_index(z.transpose(2, 1))
|
37 |
+
zq = zq.transpose(2, 1)
|
38 |
+
return zq, indices
|
39 |
+
|
40 |
+
def encode(self, z):
|
41 |
+
zq, indices = self.codebook.forward_index(z.transpose(2, 1), flatten_idx=True)
|
42 |
+
return zq, indices
|
43 |
+
|
44 |
+
def decode(self, indices):
|
45 |
+
z = self.codebook.lookup(indices)
|
46 |
+
return z
|