| """ | |
| Binary Spherical Quantization | |
| Proposed in https://arxiv.org/abs/2406.07548 | |
| In the simplest setup, each dimension is quantized into {-1, 1}. | |
| An entropy penalty is used to encourage utilization. | |
| """ | |
import random
from math import log2, ceil
from functools import partial, cache
from collections import namedtuple
from contextlib import nullcontext

import torch.distributed as dist
from torch.distributed import nn as dist_nn

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module
from torch.amp import autocast

import numpy as np

from einops import rearrange, reduce, pack, unpack
# from einx import get_at

from .dynamic_resolution import predefined_HW_Scales_dynamic

# constants

Return = namedtuple('Return', ['quantized', 'indices', 'bit_indices', 'entropy_aux_loss'])

LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])

# distributed helpers

def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1

def maybe_distributed_mean(t):
    if not is_distributed():
        return t

    dist_nn.all_reduce(t)
    t = t / dist.get_world_size()
    return t

# helper functions

def exists(v):
    return v is not None

def identity(t):
    return t

def default(*args):
    for arg in args:
        if exists(arg):
            return arg() if callable(arg) else arg
    return None

def round_up_multiple(num, mult):
    return ceil(num / mult) * mult

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def l2norm(t):
    return F.normalize(t, dim = -1)

# entropy

def log(t, eps = 1e-5):
    return t.clamp(min = eps).log()

def entropy(prob):
    return (-prob * log(prob)).sum(dim=-1)

# cosine sim linear

class CosineSimLinear(Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        scale = 1.
    ):
        super().__init__()
        self.scale = scale
        self.weight = nn.Parameter(torch.randn(dim_in, dim_out))

    def forward(self, x):
        x = F.normalize(x, dim = -1)
        w = F.normalize(self.weight, dim = 0)
        return (x @ w) * self.scale

def get_latent2scale_schedule(T: int, H: int, W: int, mode="original"):
    assert mode in ["original", "dynamic", "dense", "same1", "same2", "same3"]
    predefined_HW_Scales = {
        # 256 x 256
        (32, 32): [(1, 1), (2, 2), (3, 3), (4, 4), (6, 6), (9, 9), (13, 13), (18, 18), (24, 24), (32, 32)],
        (16, 16): [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (8, 8), (10, 10), (13, 13), (16, 16)],
        # 1024 x 1024
        (64, 64): [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (7, 7), (9, 9), (12, 12), (16, 16), (21, 21), (27, 27), (36, 36), (48, 48), (64, 64)],
        (36, 64): [(1, 1), (2, 2), (3, 3), (4, 4), (6, 6), (9, 12), (13, 16), (18, 24), (24, 32), (32, 48), (36, 64)],
    }
    if mode == "dynamic":
        predefined_HW_Scales.update(predefined_HW_Scales_dynamic)
    elif mode == "dense":
        predefined_HW_Scales[(16, 16)] = [(x, x) for x in range(1, 16 + 1)]
        predefined_HW_Scales[(32, 32)] = predefined_HW_Scales[(16, 16)] + [(20, 20), (24, 24), (28, 28), (32, 32)]
        predefined_HW_Scales[(64, 64)] = predefined_HW_Scales[(32, 32)] + [(40, 40), (48, 48), (56, 56), (64, 64)]
    elif mode.startswith("same"):
        num_quant = int(mode[len("same"):])
        predefined_HW_Scales[(16, 16)] = [(16, 16) for _ in range(num_quant)]
        predefined_HW_Scales[(32, 32)] = [(32, 32) for _ in range(num_quant)]
        predefined_HW_Scales[(64, 64)] = [(64, 64) for _ in range(num_quant)]

    predefined_T_Scales = [1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15, 17, 17, 17, 17, 17]
    patch_THW_shape_per_scale = predefined_HW_Scales[(H, W)]
    if len(predefined_T_Scales) < len(patch_THW_shape_per_scale):
        # print("warning: the length of predefined_T_Scales is less than the length of patch_THW_shape_per_scale!")
        predefined_T_Scales += [predefined_T_Scales[-1]] * (len(patch_THW_shape_per_scale) - len(predefined_T_Scales))
    patch_THW_shape_per_scale = [(min(T, t), h, w) for (h, w), t in zip(patch_THW_shape_per_scale, predefined_T_Scales[:len(patch_THW_shape_per_scale)])]
    return patch_THW_shape_per_scale
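
# Illustrative call (values follow directly from the tables above; this comment
# is not part of the original API documentation):
#
#   get_latent2scale_schedule(T=1, H=16, W=16, mode="original")
#   # -> [(1, 1, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4), (1, 5, 5),
#   #     (1, 6, 6), (1, 8, 8), (1, 10, 10), (1, 13, 13), (1, 16, 16)]
#
# i.e. a coarse-to-fine list of (t, h, w) patch grids, with t clamped to the
# input temporal length T.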
class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last or channels_first (the default here).
    data_format specifies the ordering of the dimensions in the inputs. channels_last corresponds to
    inputs with shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    normalized_shape: int
    """
    def __init__(self, normalized_shape, norm_weight=False, eps=1e-6, data_format="channels_first"):
        super().__init__()
        if norm_weight:
            self.weight = nn.Parameter(torch.ones(normalized_shape) / (normalized_shape ** 0.5))
        else:
            self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            if x.ndim == 4:  # (b, c, h, w)
                x = self.weight[:, None, None] * x + self.bias[:, None, None]
            elif x.ndim == 5:  # (b, c, t, h, w)
                x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
            else:
                raise ValueError("the number of dimensions of the input should be 4 or 5")
            return x
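
# Shape sketch for the channels_first path above (hypothetical sizes):
#
#   ln = LayerNorm(16, data_format="channels_first")
#   y = ln(torch.randn(2, 16, 4, 8, 8))   # (b, c, t, h, w) -> same shape
#   y = ln(torch.randn(2, 16, 8, 8))      # (b, c, h, w)    -> same shape
#
# Normalization is over the channel dimension (dim=1) in both cases.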
class MultiScaleBSQ(Module):
    """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """

    def __init__(
        self,
        *,
        dim,
        codebook_size,
        soft_clamp_input_value = None,
        aux_loss = False,           # intermediate auxiliary loss
        ln_before_quant=False,      # add a LN before multi-scale RQ
        ln_init_by_sqrt=False,      # weight init by 1/sqrt(d)
        use_decay_factor=False,
        use_stochastic_depth=False,
        drop_rate=0.,
        schedule_mode="original",   # ["original", "dynamic", "dense"]
        keep_first_quant=False,
        keep_last_quant=False,
        remove_residual_detach=False,
        random_flip = False,
        flip_prob = 0.5,
        flip_mode = "stochastic",   # "stochastic", "deterministic"
        max_flip_lvl = 1,
        random_flip_1lvl = False,   # random flip one level each time
        flip_lvl_idx = None,
        drop_when_test=False,
        drop_lvl_idx=None,
        drop_lvl_num=0,
        **kwargs
    ):
        super().__init__()
        codebook_dim = int(log2(codebook_size))

        requires_projection = codebook_dim != dim
        self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
        self.has_projections = requires_projection

        self.layernorm = LayerNorm(codebook_dim, norm_weight=ln_init_by_sqrt) if ln_before_quant else nn.Identity()
        self.use_stochastic_depth = use_stochastic_depth
        self.drop_rate = drop_rate
        self.remove_residual_detach = remove_residual_detach

        self.random_flip = random_flip
        self.flip_prob = flip_prob
        self.flip_mode = flip_mode
        self.max_flip_lvl = max_flip_lvl
        self.random_flip_1lvl = random_flip_1lvl
        self.flip_lvl_idx = flip_lvl_idx
        assert (random_flip and random_flip_1lvl) == False

        self.drop_when_test = drop_when_test
        self.drop_lvl_idx = drop_lvl_idx
        self.drop_lvl_num = drop_lvl_num
        if self.drop_when_test:
            assert drop_lvl_idx is not None
            assert drop_lvl_num > 0

        self.lfq = BSQ(
            dim = codebook_dim,
            codebook_scale = 1 / np.sqrt(codebook_dim),
            soft_clamp_input_value = soft_clamp_input_value,
            # experimental_softplus_entropy_loss=True,
            # entropy_loss_offset=2,
            **kwargs
        )

        self.z_interplote_up = 'trilinear'
        self.z_interplote_down = 'area'

        self.use_decay_factor = use_decay_factor
        self.schedule_mode = schedule_mode
        self.keep_first_quant = keep_first_quant
        self.keep_last_quant = keep_last_quant
        if self.use_stochastic_depth and self.drop_rate > 0:
            assert self.keep_first_quant or self.keep_last_quant

    def codebooks(self):
        return self.lfq.codebook

    def get_codes_from_indices(self, indices_list):
        all_codes = []
        for indices in indices_list:
            codes = self.lfq.indices_to_codes(indices)
            all_codes.append(codes)
        _, _, T, H, W = all_codes[-1].size()
        summed_codes = 0
        for code in all_codes:
            summed_codes += F.interpolate(code, size=(T, H, W), mode=self.z_interplote_up)
        return summed_codes

    def get_output_from_indices(self, indices):
        codes = self.get_codes_from_indices(indices)
        codes_summed = reduce(codes, 'q ... -> ...', 'sum')
        return self.project_out(codes_summed)

    def flip_quant(self, x):
        assert self.flip_mode == 'stochastic'
        flip_mask = torch.rand_like(x) < self.flip_prob
        x = x.clone()
        x[flip_mask] = -x[flip_mask]
        return x

    def forward(
        self,
        x,
        scale_schedule=None,
        mask = None,
        return_all_codes = False,
        return_residual_norm_per_scale = False
    ):
        if x.ndim == 4:
            x = x.unsqueeze(2)
        B, C, T, H, W = x.size()

        if scale_schedule is None:
            if self.schedule_mode.startswith("same"):
                scale_num = int(self.schedule_mode[len("same"):])
                assert T == 1
                scale_schedule = [(1, H, W)] * scale_num
            else:
                scale_schedule = get_latent2scale_schedule(T, H, W, mode=self.schedule_mode)
                scale_num = len(scale_schedule)

        # x = self.project_in(x)
        x = x.permute(0, 2, 3, 4, 1).contiguous()  # (b, c, t, h, w) => (b, t, h, w, c)
        x = self.project_in(x)
        x = x.permute(0, 4, 1, 2, 3).contiguous()  # (b, t, h, w, c) => (b, c, t, h, w)
        x = self.layernorm(x)

        quantized_out = 0.
        residual = x

        all_losses = []
        all_indices = []
        all_bit_indices = []
        var_inputs = []
        residual_norm_per_scale = []

        # go through the layers

        out_fact = init_out_fact = 1.0
        # residual_list = []
        # interpolate_residual_list = []
        # quantized_list = []
        if self.drop_when_test:
            drop_lvl_start = self.drop_lvl_idx
            drop_lvl_end = self.drop_lvl_idx + self.drop_lvl_num
        scale_num = len(scale_schedule)
        with autocast('cuda', enabled = False):
            for si, (pt, ph, pw) in enumerate(scale_schedule):
                out_fact = max(0.1, out_fact) if self.use_decay_factor else init_out_fact
                if (pt, ph, pw) != (T, H, W):
                    interpolate_residual = F.interpolate(residual, size=(pt, ph, pw), mode=self.z_interplote_down)
                else:
                    interpolate_residual = residual
                if return_residual_norm_per_scale:
                    residual_norm_per_scale.append((torch.abs(interpolate_residual) < 0.05 * self.lfq.codebook_scale).sum() / interpolate_residual.numel())
                # residual_list.append(torch.norm(residual.detach(), dim=1).mean())
                # interpolate_residual_list.append(torch.norm(interpolate_residual.detach(), dim=1).mean())
                if self.training and self.use_stochastic_depth and random.random() < self.drop_rate:
                    if (si == 0 and self.keep_first_quant) or (si == scale_num - 1 and self.keep_last_quant):
                        quantized, indices, bit_indices, loss = self.lfq(interpolate_residual)  # also keep bit_indices; it is appended below
                        quantized = quantized * out_fact
                        all_indices.append(indices)
                        all_losses.append(loss)
                    else:
                        quantized = torch.zeros_like(interpolate_residual)
                elif self.drop_when_test and drop_lvl_start <= si < drop_lvl_end:
                    continue
                else:
                    # residual_norm = torch.norm(interpolate_residual.detach(), dim=1)  # (b, t, h, w)
                    # print(si, residual_norm.min(), residual_norm.max(), residual_norm.mean())
                    quantized, indices, bit_indices, loss = self.lfq(interpolate_residual)
                    if self.random_flip and si < self.max_flip_lvl:
                        quantized = self.flip_quant(quantized)
                    if self.random_flip_1lvl and si == self.flip_lvl_idx:
                        quantized = self.flip_quant(quantized)
                    quantized = quantized * out_fact
                    all_indices.append(indices)
                # quantized_list.append(torch.norm(quantized.detach(), dim=1).mean())
                if (pt, ph, pw) != (T, H, W):
                    quantized = F.interpolate(quantized, size=(T, H, W), mode=self.z_interplote_up).contiguous()

                if self.remove_residual_detach:
                    residual = residual - quantized
                else:
                    residual = residual - quantized.detach()
                quantized_out = quantized_out + quantized

                all_bit_indices.append(bit_indices)
                all_losses.append(loss)
                if si != scale_num - 1:
                    var_inputs.append(F.interpolate(quantized_out, size=scale_schedule[si + 1], mode=self.z_interplote_down).contiguous())

                if self.use_decay_factor:
                    out_fact -= 0.1

        # print("residual_list:", residual_list)
        # print("interpolate_residual_list:", interpolate_residual_list)
        # print("quantized_list:", quantized_list)
        # import ipdb; ipdb.set_trace()

        # project out, if needed

        quantized_out = quantized_out.permute(0, 2, 3, 4, 1).contiguous()  # (b, c, t, h, w) => (b, t, h, w, c)
        quantized_out = self.project_out(quantized_out)
        quantized_out = quantized_out.permute(0, 4, 1, 2, 3).contiguous()  # (b, t, h, w, c) => (b, c, t, h, w)

        # image
        if quantized_out.size(2) == 1:
            quantized_out = quantized_out.squeeze(2)

        # stack all losses and indices

        all_losses = torch.stack(all_losses, dim = -1)

        ret = (quantized_out, all_indices, all_bit_indices, residual_norm_per_scale, all_losses, var_inputs)

        if not return_all_codes:
            return ret

        # whether to return all codes from all codebooks across layers

        all_codes = self.get_codes_from_indices(all_indices)

        # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)

        return (*ret, all_codes)
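
# Rough usage sketch for MultiScaleBSQ (argument values are illustrative, not
# defaults taken from any config; new_quant=True is forwarded to BSQ, whose
# trimmed forward below only implements the quantize_new path):
#
#   msbsq = MultiScaleBSQ(dim=32, codebook_size=2 ** 16, schedule_mode="original", new_quant=True)
#   z = torch.randn(2, 32, 16, 16)   # (b, c, h, w) image latent
#   out, indices, bit_indices, res_norms, losses, var_inputs = msbsq(z)
#   # out:    (2, 32, 16, 16) multi-scale reconstruction of the latent
#   # losses: per-scale auxiliary losses stacked into a 1-D tensor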
class BSQ(Module):
    def __init__(
        self,
        *,
        dim = None,
        codebook_size = None,
        entropy_loss_weight = 0.1,
        commitment_loss_weight = 0.25,
        diversity_gamma = 1.,
        straight_through_activation = nn.Identity(),
        num_codebooks = 1,
        keep_num_codebooks_dim = None,
        codebook_scale = 1.,              # for residual LFQ, codebook scaled down by 2x at each layer
        frac_per_sample_entropy = 1.,     # make less than 1. to only use a random fraction of the probs for per sample entropy
        has_projections = None,
        projection_has_bias = True,
        soft_clamp_input_value = None,
        cosine_sim_project_in = False,
        cosine_sim_project_in_scale = None,
        channel_first = None,
        experimental_softplus_entropy_loss = False,
        entropy_loss_offset = 5.,         # how much to shift the loss before softplus
        spherical = True,                 # from https://arxiv.org/abs/2406.07548
        force_quantization_f32 = True,    # will force the quantization step to be full precision
        inv_temperature = 100.0,
        gamma0=1.0, gamma=1.0, zeta=1.0,
        preserve_norm = False,            # whether to preserve the original norm info
        new_quant = False,                # new quant function
        mask_out = False,                 # mask the output as 0 in some conditions
        use_out_phi = False,              # use output phi network
        use_out_phi_res = False,          # residual out phi
    ):
        super().__init__()

        # some assert validations

        assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
        assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'

        codebook_size = default(codebook_size, lambda: 2 ** dim)
        self.codebook_size = codebook_size

        codebook_dim = int(log2(codebook_size))
        codebook_dims = codebook_dim * num_codebooks
        dim = default(dim, codebook_dims)
        self.codebook_dims = codebook_dims

        has_projections = default(has_projections, dim != codebook_dims)

        if cosine_sim_project_in:
            cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
            project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
        else:
            project_in_klass = partial(nn.Linear, bias = projection_has_bias)

        self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()  # nn.Identity()
        self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()  # nn.Identity()
        self.has_projections = has_projections

        self.out_phi = nn.Linear(codebook_dims, codebook_dims) if use_out_phi else nn.Identity()
        self.use_out_phi_res = use_out_phi_res
        if self.use_out_phi_res:
            self.out_phi_scale = nn.Parameter(torch.zeros(codebook_dims), requires_grad=True)  # init as zero

        self.dim = dim
        self.codebook_dim = codebook_dim
        self.num_codebooks = num_codebooks

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        # channel first

        self.channel_first = channel_first

        # straight through activation

        self.activation = straight_through_activation

        # For BSQ (binary spherical quantization)
        if not spherical:
            raise ValueError("For BSQ, spherical must be True.")
        self.persample_entropy_compute = 'analytical'
        self.inv_temperature = inv_temperature
        self.gamma0 = gamma0  # loss weight for entropy penalty
        self.gamma = gamma    # loss weight for entropy penalty
        self.zeta = zeta      # loss weight for entire entropy penalty
        self.preserve_norm = preserve_norm
        self.new_quant = new_quant
        self.mask_out = mask_out

        # entropy aux loss related weights

        assert 0 < frac_per_sample_entropy <= 1.
        self.frac_per_sample_entropy = frac_per_sample_entropy

        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight

        # codebook scale

        self.codebook_scale = codebook_scale

        # commitment loss

        self.commitment_loss_weight = commitment_loss_weight

        # whether to soft clamp the input value from -value to value

        self.soft_clamp_input_value = soft_clamp_input_value
        assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale

        # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions)

        self.entropy_loss_offset = entropy_loss_offset
        self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss

        # for no auxiliary loss, during inference

        self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

        # whether to force quantization step to be f32

        self.force_quantization_f32 = force_quantization_f32

        # codes
        # all_codes = torch.arange(codebook_size)
        # bits = ((all_codes[..., None].int() & self.mask) != 0).float()
        # codebook = self.bits_to_codes(bits)
        # self.register_buffer('codebook', codebook.float(), persistent = False)
    def bits_to_codes(self, bits):
        return bits * self.codebook_scale * 2 - self.codebook_scale

    # @property
    # def dtype(self):
    #     return self.codebook.dtype

    def indices_to_codes(
        self,
        indices,
        label_type = 'int_label',
        project_out = True
    ):
        assert label_type in ['int_label', 'bit_label']
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
        should_transpose = default(self.channel_first, is_img_or_video)

        if not self.keep_num_codebooks_dim:
            if label_type == 'int_label':
                indices = rearrange(indices, '... -> ... 1')
            else:
                indices = indices.unsqueeze(-2)

        # indices to codes, which are bits of either -1 or 1

        if label_type == 'int_label':
            assert indices[..., None].int().min() > 0
            bits = ((indices[..., None].int() & self.mask) != 0).float()  # .to(self.dtype)
        else:
            bits = indices

        codes = self.bits_to_codes(bits)

        codes = l2norm(codes)  # must normalize when using BSQ

        codes = rearrange(codes, '... c d -> ... (c d)')

        # whether to project codes out to original dimensions
        # if the input feature dimensions were not log2(codebook size)

        if project_out:
            codes = self.project_out(codes)

        # rearrange codes back to original shape

        if should_transpose:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes
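
    # Worked example for the int_label path (hypothetical 4-bit codebook,
    # i.e. codebook_dim = 4, so mask = [8, 4, 2, 1]):
    #
    #   index 5 -> bits   [0, 1, 0, 1]
    #           -> codes  [-s, s, -s, s]            # s = codebook_scale
    #           -> l2norm [-0.5, 0.5, -0.5, 0.5]    # all entries have equal magnitude
    #
    # followed by the optional project_out back to the model dimension.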
    def quantize(self, z):
        assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}"

        zhat = torch.where(z > 0,
                           torch.tensor(1, dtype=z.dtype, device=z.device),
                           torch.tensor(-1, dtype=z.dtype, device=z.device))
        return z + (zhat - z).detach()

    def quantize_new(self, z):
        assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}"

        zhat = torch.where(z > 0,
                           torch.tensor(1, dtype=z.dtype, device=z.device),
                           torch.tensor(-1, dtype=z.dtype, device=z.device))

        q_scale = 1. / (self.codebook_dims ** 0.5)
        zhat = q_scale * zhat  # on unit sphere

        return z + (zhat - z).detach()
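
    # Note on quantize_new: with codebook_dims = d, q_scale = 1/sqrt(d), so the
    # quantized vector has entries in {-1/sqrt(d), +1/sqrt(d)} and unit L2 norm,
    # i.e. it stays on the same unit sphere as the l2-normalized input. For
    # example, d = 16 gives q_scale = 0.25 and ||zhat||_2 = sqrt(16 * 0.25**2) = 1.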
    def soft_entropy_loss(self, z):
        if self.persample_entropy_compute == 'analytical':
            # if self.l2_norm:
            p = torch.sigmoid(-4 * z / (self.codebook_dims ** 0.5) * self.inv_temperature)
            # else:
            #     p = torch.sigmoid(-4 * z * self.inv_temperature)
            prob = torch.stack([p, 1 - p], dim=-1)  # (b, h, w, 18, 2)
            per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()  # (b, h, w, 18) -> (b, h, w) -> scalar
        else:
            per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()

        # macro average of the probability of each subgroup
        avg_prob = reduce(prob, '... g d -> g d', 'mean')  # (18, 2)
        codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)

        # the approximation of the entropy is the sum of the entropy of each subgroup
        return per_sample_entropy, codebook_entropy.sum(), avg_prob

    def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
        if normalize:
            probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
        else:  # callers above pass normalize=False, i.e. `count` is already a probability
            probs = count
        H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
        return H
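
    # Worked example for get_entropy with normalize=False (natural log): a maximally
    # uncertain bit, prob = [0.5, 0.5], gives H = ln 2 ~= 0.693 per dimension, while a
    # confident bit, prob ~= [1, 0], gives H ~= 0. soft_entropy_loss therefore returns
    # a low per_sample_entropy when each bit is decisive, and a high codebook_entropy
    # when bit usage is balanced across the batch.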
    def forward(
        self,
        x,
        return_loss_breakdown = False,
        mask = None,
        entropy_weight=0.1
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim
        """
        is_img_or_video = x.ndim >= 4
        should_transpose = default(self.channel_first, is_img_or_video)

        # standardize image or video into (batch, seq, dimension)

        if should_transpose:
            x = rearrange(x, 'b d ... -> b ... d')
            x, ps = pack_one(x, 'b * d')  # x.shape [b, hwt, c]

        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

        x = self.project_in(x)

        # split out number of codebooks

        x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)

        x = l2norm(x)

        # whether to force quantization step to be full precision or not

        force_f32 = self.force_quantization_f32
        quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext

        indices = None
        with quantization_context():
            if force_f32:
                orig_dtype = x.dtype
                x = x.float()

            # use straight-through gradients (optionally with custom activation fn) if training

            if self.new_quant:
                quantized = self.quantize_new(x)

            # calculate indices

            bit_indices = (quantized > 0).int()
            entropy_penalty = persample_entropy = cb_entropy = self.zero
            commit_loss = self.zero

            # input back to original dtype if needed

            if force_f32:
                x = x.type(orig_dtype)

        # merge back codebook dim

        x = quantized  # rename quantized to x for output
        x = rearrange(x, 'b n c d -> b n (c d)')

        # project out to feature dimension if needed

        x = self.project_out(x)

        # reconstitute image or video dimensions

        if should_transpose:
            x = unpack_one(x, ps, 'b * d')
            x = rearrange(x, 'b ... d -> b d ...')

            bit_indices = unpack_one(bit_indices, ps, 'b * c d')

        # whether to remove single codebook dim

        if not self.keep_num_codebooks_dim:
            bit_indices = rearrange(bit_indices, '... 1 d -> ... d')

        # complete aux loss

        aux_loss = commit_loss * self.commitment_loss_weight + (self.zeta * entropy_penalty / self.inv_temperature) * entropy_weight

        # returns

        ret = Return(x, indices, bit_indices, aux_loss)

        if not return_loss_breakdown:
            return ret

        return ret, LossBreakdown(persample_entropy, cb_entropy, commit_loss)
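
# Rough usage sketch for BSQ on its own (illustrative shapes; new_quant=True is
# required by this trimmed forward, which only implements the quantize_new path):
#
#   bsq = BSQ(dim=16, new_quant=True)
#   z = torch.randn(2, 16, 8, 8)                 # (b, d, h, w), channel-first
#   out, indices, bit_indices, aux_loss = bsq(z)
#   # out:         (2, 16, 8, 8)  quantized, unit-norm features
#   # bit_indices: (2, 8, 8, 16)  one {0, 1} bit per code dimension
#   # indices:     None           (integer indices are not computed in this path)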