# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
import torch.nn as nn | |
from repcodec.layers.vq_module import ResidualVQ | |
class Quantizer(nn.Module):
    """Thin wrapper that routes tensors through a ``ResidualVQ`` codebook.

    Every method that takes a feature tensor swaps its dimensions 1 and 2
    before handing it to the codebook. ``forward`` and ``inference`` swap
    the quantized result back; ``encode`` does not.

    NOTE(review): inputs are presumably laid out as (batch, channels, time)
    with the codebook working on (batch, time, channels) -- confirm against
    ``ResidualVQ`` and the callers.
    """

    def __init__(
            self,
            code_dim: int,
            codebook_num: int,
            codebook_size: int,
    ):
        """Build the underlying residual vector quantizer.

        Args:
            code_dim: forwarded to ``ResidualVQ`` as ``dim``.
            codebook_num: forwarded to ``ResidualVQ`` as ``num_quantizers``.
            codebook_size: forwarded to ``ResidualVQ`` as ``codebook_size``.
        """
        super().__init__()
        self.codebook = ResidualVQ(
            dim=code_dim,
            num_quantizers=codebook_num,
            codebook_size=codebook_size,
        )

    @staticmethod
    def _swap_dims(tensor):
        """Return ``tensor`` with dimensions 1 and 2 exchanged."""
        return tensor.transpose(1, 2)

    def initial(self):
        """Delegate to the codebook's own ``initial()``; returns nothing."""
        self.codebook.initial()

    def forward(self, z):
        """Quantize ``z`` by calling the codebook directly.

        Returns:
            A ``(zq, vqloss, perplexity)`` tuple: the quantized tensor with
            dimensions 1 and 2 swapped back, followed by the second and
            third values returned by the codebook, passed through untouched.
        """
        quantized, vq_loss, ppl = self.codebook(self._swap_dims(z))
        return self._swap_dims(quantized), vq_loss, ppl

    def inference(self, z):
        """Quantize ``z`` through the codebook's ``forward_index``.

        Returns:
            A ``(zq, indices)`` tuple: the quantized tensor with dimensions
            1 and 2 swapped back, and the indices exactly as the codebook
            returned them.
        """
        quantized, codes = self.codebook.forward_index(self._swap_dims(z))
        return self._swap_dims(quantized), codes

    def encode(self, z):
        """Quantize ``z`` via ``forward_index`` with ``flatten_idx=True``.

        Unlike ``forward`` and ``inference``, the quantized tensor is handed
        back exactly as the codebook produced it, i.e. it is NOT swapped
        back to the layout of ``z``.

        Returns:
            The ``(zq, indices)`` pair returned by the codebook.
        """
        quantized, codes = self.codebook.forward_index(
            self._swap_dims(z), flatten_idx=True
        )
        return quantized, codes

    def decode(self, indices):
        """Return whatever the codebook's ``lookup`` yields for ``indices``."""
        return self.codebook.lookup(indices)