KingNish's picture
Upload ./RepCodec/repcodec/modules/quantizer.py with huggingface_hub
b3b0dda verified
raw
history blame
1.3 kB
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
import torch.nn as nn
from repcodec.layers.vq_module import ResidualVQ
class Quantizer(nn.Module):
def __init__(
self,
code_dim: int,
codebook_num: int,
codebook_size: int,
):
super().__init__()
self.codebook = ResidualVQ(
dim=code_dim,
num_quantizers=codebook_num,
codebook_size=codebook_size
)
def initial(self):
self.codebook.initial()
def forward(self, z):
zq, vqloss, perplexity = self.codebook(z.transpose(2, 1))
zq = zq.transpose(2, 1)
return zq, vqloss, perplexity
def inference(self, z):
zq, indices = self.codebook.forward_index(z.transpose(2, 1))
zq = zq.transpose(2, 1)
return zq, indices
def encode(self, z):
zq, indices = self.codebook.forward_index(z.transpose(2, 1), flatten_idx=True)
return zq, indices
def decode(self, indices):
z = self.codebook.lookup(indices)
return z