import random
from math import ceil  # needed by round_up_multiple below

import torch
import torch.nn.functional as F
import torch.distributed as dist

from typing import List
from torch import nn
from torch.nn import Module
from torch.amp import autocast
from einx import get_at
from einops import rearrange, reduce, pack, unpack

from sparktts.modules.fsq.finite_scalar_quantization import FSQ


def exists(val):
    return val is not None


def first(l):
    return l[0]


def default(val, d):
    return val if exists(val) else d


def round_up_multiple(num, mult):
    return ceil(num / mult) * mult


def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1


def get_maybe_sync_seed(device, max_size=10_000):
    # draw a random seed and, when running distributed, all-reduce it so that
    # every rank ends up with the same value (keeps quantizer dropout in sync)
    rand_int = torch.randint(0, max_size, (), device=device)

    if is_distributed():
        dist.all_reduce(rand_int)

    return rand_int.item()


class ResidualFSQ(Module):
    """Follows Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf"""

    def __init__(
        self,
        *,
        levels: List[int],
        num_quantizers,
        dim=None,
        is_channel_first=False,
        quantize_dropout=False,
        quantize_dropout_cutoff_index=0,
        quantize_dropout_multiple_of=1,
        **kwargs,
    ):
        super().__init__()
        codebook_dim = len(levels)
        dim = default(dim, codebook_dim)

        requires_projection = codebook_dim != dim
        self.project_in = (
            nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
        )
        self.has_projections = requires_projection

        self.is_channel_first = is_channel_first
        self.num_quantizers = num_quantizers

        self.levels = levels
        self.layers = nn.ModuleList([])

        levels_tensor = torch.Tensor(levels)

        scales = []

        # one FSQ per quantizer; the i-th quantizer is assigned the scale
        # (levels - 1) ** -i
        for ind in range(num_quantizers):
            scales.append((levels_tensor - 1) ** -ind)

            fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs)

            self.layers.append(fsq)

        assert all([not fsq.has_projections for fsq in self.layers])

        self.codebook_size = self.layers[0].codebook_size

        self.register_buffer("scales", torch.stack(scales), persistent=False)

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of

    @property
    def codebooks(self):
        codebooks = [layer.implicit_codebook for layer in self.layers]
        codebooks = torch.stack(codebooks, dim=0)
        return codebooks

    def get_codes_from_indices(self, indices):
        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # may also receive indices in the shape of 'b h w q' (feature maps)
        indices, ps = pack([indices], "b * q")

        # because of quantize dropout, one can pass in indices that are coarse,
        # and the network should still be able to reconstruct
        if quantize_dim < self.num_quantizers:
            assert (
                self.quantize_dropout > 0.0
            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)

        # -1 marks quantizers that were dropped out; temporarily point them at
        # code 0 so the gather below is valid, then mask them out afterwards
        mask = indices == -1
        indices = indices.masked_fill(mask, 0)

        all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices)

        # zero out codes belonging to dropped-out quantizers
        all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0)

        # scale the codes per quantizer
        scales = rearrange(self.scales, "q d -> q 1 1 d")
        all_codes = all_codes * scales

        (all_codes,) = unpack(all_codes, ps, "q b * d")

        return all_codes

    def get_output_from_indices(self, indices):
        codes = self.get_codes_from_indices(indices)
        codes_summed = reduce(codes, "q ... -> ...", "sum")
        return self.project_out(codes_summed)

    def forward(
        self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None
    ):
        num_quant, quant_dropout_multiple_of, device = (
            self.num_quantizers,
            self.quantize_dropout_multiple_of,
            x.device,
        )

        # handle channel-first inputs, e.g. (b, d, n) or (b, d, h, w)
        if self.is_channel_first:
            x = rearrange(x, "b d ... -> b ... d")
            x, ps = pack([x], "b * d")

        # maybe project in
        x = self.project_in(x)

        quantized_out = 0.0
        residual = x

        all_indices = []

        should_quantize_dropout = self.training and self.quantize_dropout

        # sample the layer index past which quantizers are dropped out,
        # and prepare the null indices used for those layers
        if should_quantize_dropout:

            # a fixed seed may be passed in to keep dropout in sync across groups
            if not exists(rand_quantize_dropout_fixed_seed):
                rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)

            rand = random.Random(rand_quantize_dropout_fixed_seed)

            rand_quantize_dropout_index = rand.randrange(
                self.quantize_dropout_cutoff_index, num_quant
            )

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = (
                    round_up_multiple(
                        rand_quantize_dropout_index + 1, quant_dropout_multiple_of
                    )
                    - 1
                )

            null_indices = torch.full(
                x.shape[:2], -1, device=device, dtype=torch.long
            )

        # go through the quantizer layers
        with autocast("cuda", enabled=False):
            for quantizer_index, (layer, scale) in enumerate(
                zip(self.layers, self.scales)
            ):

                if (
                    should_quantize_dropout
                    and quantizer_index > rand_quantize_dropout_index
                ):
                    all_indices.append(null_indices)
                    continue

                quantized, indices = layer(residual / scale)

                quantized = quantized * scale

                residual = residual - quantized.detach()
                quantized_out = quantized_out + quantized

                all_indices.append(indices)

        # project out, if needed
        quantized_out = self.project_out(quantized_out)

        # stack all indices
        all_indices = torch.stack(all_indices, dim=-1)

        # restore channel-first layout
        if self.is_channel_first:
            (quantized_out,) = unpack(quantized_out, ps, "b * d")
            (all_indices,) = unpack(all_indices, ps, "b * d")

            quantized_out = rearrange(quantized_out, "b ... d -> b d ...")
            all_indices = rearrange(all_indices, "b ... d -> b d ...")

        ret = (quantized_out, all_indices)

        if not return_all_codes:
            return ret

        # also return the codes from all layers, shape (quantize, batch, n, dim)
        all_codes = self.get_codes_from_indices(all_indices)

        return (*ret, all_codes)


class GroupedResidualFSQ(Module):
    """Runs an independent ResidualFSQ over each of `groups` chunks of the feature dim."""

    def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs):
        super().__init__()
        self.dim = dim
        self.groups = groups
        assert (dim % groups) == 0
        dim_per_group = dim // groups

        self.accept_image_fmap = accept_image_fmap

        self.rvqs = nn.ModuleList([])

        for _ in range(groups):
            self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs))

        self.codebook_size = self.rvqs[0].codebook_size

    @property
    def codebooks(self):
        return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))

    @property
    def split_dim(self):
        return 1 if self.accept_image_fmap else -1

    def get_codes_from_indices(self, indices):
        codes = tuple(
            rvq.get_codes_from_indices(chunk_indices)
            for rvq, chunk_indices in zip(self.rvqs, indices)
        )
        return torch.stack(codes)

    def get_output_from_indices(self, indices):
        outputs = tuple(
            rvq.get_output_from_indices(chunk_indices)
            for rvq, chunk_indices in zip(self.rvqs, indices)
        )
        return torch.cat(outputs, dim=self.split_dim)

    def forward(self, x, return_all_codes=False):
        shape, split_dim, device = x.shape, self.split_dim, x.device
        assert shape[split_dim] == self.dim

        # split the feature dimension into groups
        x = x.chunk(self.groups, dim=split_dim)

        forward_kwargs = dict(
            return_all_codes=return_all_codes,
            rand_quantize_dropout_fixed_seed=(
                get_maybe_sync_seed(device) if self.training else None
            ),
        )

        # invoke residual fsq on each group, then zip the outputs together
        out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
        out = tuple(zip(*out))

        quantized, all_indices, *maybe_all_codes = out

        quantized = torch.cat(quantized, dim=split_dim)
        all_indices = torch.stack(all_indices)

        ret = (quantized, all_indices, *maybe_all_codes)
        return ret


if __name__ == "__main__":
    model = ResidualFSQ(
        levels=[4, 4, 4, 4, 4, 4],
        num_quantizers=1,
        dim=30,
        is_channel_first=True,
        quantize_dropout=False,
    )
    x = torch.randn(2, 30, 10)
    quantize, embed_ind = model(x)

    emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2))

    print(quantize == emb_from_ind.transpose(1, 2))

    print("quantize shape", quantize.shape)
    print("embed_ind", embed_ind)
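
    # A minimal usage sketch for the grouped variant (illustrative sizes):
    # GroupedResidualFSQ chunks the feature dim across `groups` independent
    # ResidualFSQ quantizers and returns indices stacked with a leading group dim.
    grouped = GroupedResidualFSQ(
        dim=30,
        groups=2,
        levels=[4, 4, 4, 4, 4, 4],
        num_quantizers=1,
    )
    gx = torch.randn(2, 10, 30)  # channel-last: (batch, seq, dim)
    g_quantized, g_indices = grouped(gx)

    print("grouped quantize shape", g_quantized.shape)  # (2, 10, 30)
    print("grouped indices shape", g_indices.shape)  # (groups, batch, seq, num_quantizers)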