# Copyright (c) 2025 SparkAudio # 2025 Xinsheng Wang (w.xinshawn@gmail.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Heavily based on https://github.com/lucidrains/vector-quantize-pytorch from typing import Any, Dict import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from torch.nn.utils import weight_norm def WNConv1d(*args, **kwargs): return weight_norm(nn.Conv1d(*args, **kwargs)) def ema_inplace(moving_avg, new, decay): moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) class FactorizedVectorQuantize(nn.Module): def __init__( self, input_dim: int, codebook_size: int, codebook_dim: int, commitment: float, codebook_loss_weight: float = 1.0, decay: float = 0.99, threshold_ema_dead_code: float = 2, momentum: float = 0.99, **kwargs, ): super().__init__() self.input_dim = input_dim self.codebook_size = codebook_size self.codebook_dim = codebook_dim self.commitment = commitment self.codebook_loss_weight = codebook_loss_weight self.decay = decay self.threshold_ema_dead_code = threshold_ema_dead_code self.momentum = momentum if input_dim != self.codebook_dim: self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1) self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1) else: self.in_project = nn.Identity() self.out_project = nn.Identity() self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim) self.register_buffer("cluster_size", torch.zeros(self.codebook_size)) def forward(self, z: torch.Tensor) -> Dict[str, Any]: """Quantized the input tensor using a fixed codebook and returns the corresponding codebook vectors Parameters ---------- z : Tensor[B x D x T] Returns ------- Tensor[B x D x T] Quantized continuous representation of input Tensor[1] Commitment loss to train encoder to predict vectors closer to codebook entries Tensor[1] Codebook loss to update the codebook Tensor[B x T] Codebook indices (quantized discrete representation of input) Tensor[B x D x T] Projected latents (continuous representation of input before quantization) """ # transpose since we use linear # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim z_e = self.in_project(z) z_q, indices, dists = self.decode_latents(z_e) # statistic the usage of codes embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype) avg_probs = torch.mean(embed_onehot.reshape(-1, self.codebook_size), dim=0) perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) active_num = (embed_onehot.sum(0).sum(0) > 0).sum() if self.training: # We do the expiry of code at that point as buffers are in sync # and all the workers will take the same decision. ema_inplace(self.cluster_size, embed_onehot.sum(0).sum(0), self.decay) active_num = sum(self.cluster_size > self.threshold_ema_dead_code) if self.training: commit_loss = ( F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) * self.commitment ) codebook_loss = ( F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) * self.codebook_loss_weight ) else: commit_loss = torch.zeros(0, device=z.device) codebook_loss = torch.zeros(0, device=z.device) z_q = ( z_e + (z_q - z_e).detach() ) # noop in forward pass, straight-through gradient estimator in backward pass z_q = self.out_project(z_q) vq_loss = (commit_loss + codebook_loss).mean() return { "z_q": z_q, "indices": indices, "dists": dists, "vq_loss": vq_loss, "perplexity": perplexity, "active_num": active_num.float(), } def vq2emb(self, vq, out_proj=True): emb = self.embed_code(vq) if out_proj: emb = self.out_project(emb) return emb def tokenize(self, z: torch.Tensor) -> torch.Tensor: """tokenize the input tensor""" z_e = self.in_project(z) _, indices, _ = self.decode_latents(z_e) return indices def detokenize(self, indices): """detokenize the input indices""" z_q = self.decode_code(indices) z_q = self.out_project(z_q) return z_q def get_emb(self): return self.codebook.weight def embed_code(self, embed_id): return F.embedding(embed_id, self.codebook.weight) def decode_code(self, embed_id): return self.embed_code(embed_id).transpose(1, 2) def decode_latents(self, latents): encodings = rearrange(latents, "b d t -> (b t) d") codebook = self.codebook.weight # L2 normalize encodings and codebook encodings = F.normalize(encodings) codebook = F.normalize(codebook) # Compute euclidean distance between encodings and codebook, # with L2 normalization, the distance is equal to cosine distance dist = ( encodings.pow(2).sum(1, keepdim=True) - 2 * encodings @ codebook.t() + codebook.pow(2).sum(1, keepdim=True).t() ) indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) z_q = self.decode_code(indices) return z_q, indices, dist