File size: 1,302 Bytes
b3b0dda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
import torch.nn as nn
from repcodec.layers.vq_module import ResidualVQ
class Quantizer(nn.Module):
    """Thin wrapper around ``ResidualVQ`` that swaps axes 1 and 2 of its inputs.

    Every entry point that receives ``z`` exchanges dims 1 and 2 before
    handing it to the residual codebook. ``forward`` and ``inference`` swap
    the quantized result back; ``encode`` and ``decode`` do not.
    Presumably callers pass channel-first ``(B, code_dim, T)`` tensors and
    ``ResidualVQ`` works channel-last -- TODO confirm against ``ResidualVQ``.
    """

    def __init__(
        self,
        code_dim: int,
        codebook_num: int,
        codebook_size: int,
    ) -> None:
        """Create the underlying residual vector quantizer.

        Args:
            code_dim: Forwarded to ``ResidualVQ`` as ``dim``.
            codebook_num: Forwarded to ``ResidualVQ`` as ``num_quantizers``.
            codebook_size: Forwarded to ``ResidualVQ`` as ``codebook_size``.
        """
        super().__init__()
        self.codebook = ResidualVQ(
            dim=code_dim,
            num_quantizers=codebook_num,
            codebook_size=codebook_size,
        )

    @staticmethod
    def _swap_axes(x):
        # Exchange dims 1 and 2 -- the only layout change this wrapper makes.
        return x.transpose(2, 1)

    def initial(self):
        """Delegate to ``self.codebook.initial()`` (see ``ResidualVQ``)."""
        self.codebook.initial()

    def forward(self, z):
        """Quantize ``z``; return ``(zq, vqloss, perplexity)``.

        ``zq`` comes back in the same axis order as ``z``. ``vqloss`` and
        ``perplexity`` are passed through from the codebook unchanged.
        """
        quantized, loss, ppl = self.codebook(self._swap_axes(z))
        return self._swap_axes(quantized), loss, ppl

    def inference(self, z):
        """Quantize ``z`` via ``forward_index``; return ``(zq, indices)``.

        ``zq`` is swapped back to the axis order of ``z``. ``indices`` is
        whatever ``ResidualVQ.forward_index`` produces by default.
        """
        quantized, codes = self.codebook.forward_index(self._swap_axes(z))
        return self._swap_axes(quantized), codes

    def encode(self, z):
        """Quantize ``z`` with flattened indices; return ``(zq, indices)``."""
        quantized, codes = self.codebook.forward_index(
            self._swap_axes(z), flatten_idx=True
        )
        # NOTE(review): unlike `inference`, `quantized` is NOT swapped back,
        # so it stays in the codebook's axis order. Kept as-is for existing
        # callers -- confirm this asymmetry is intended.
        return quantized, codes

    def decode(self, indices):
        """Map codebook indices back to vectors via ``ResidualVQ.lookup``.

        The result is returned exactly as ``lookup`` produces it; no axis
        swap is applied here.
        """
        return self.codebook.lookup(indices)
|