Upload ./RepCodec/repcodec/layers/vq_module.py with huggingface_hub
RepCodec/repcodec/layers/vq_module.py
ADDED
@@ -0,0 +1,155 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantize(nn.Module):
    """Vector quantization w/ exponential moving averages (EMA)"""

    def __init__(
            self,
            dim: int,
            codebook_size: int,
            decay=0.8,
            commitment=1.,
            eps=1e-5,
            n_embed=None,
    ):
        super().__init__()
        n_embed = self.default(n_embed, codebook_size)

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps
        self.commitment = commitment

        embed = torch.randn(dim, n_embed)
        self.register_buffer('embed', embed)
        self.register_buffer('cluster_size', torch.zeros(n_embed))
        self.register_buffer('embed_avg', embed.clone())

    @property
    def codebook(self):
        return self.embed.transpose(0, 1)

    def exists(self, val):
        return val is not None

    def default(self, val, d):
        return val if self.exists(val) else d

    def ema_inplace(self, moving_avg, new, decay):
        moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))

    def laplace_smoothing(self, x, n_categories, eps=1e-5):
        return (x + eps) / (x.sum() + n_categories * eps)

    def forward(self, input):
        dtype = input.dtype
        flatten = input.reshape(-1, self.dim)
        dist = (
                flatten.pow(2).sum(1, keepdim=True)
                - 2 * flatten @ self.embed
                + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))

        if self.training:
            self.ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot
            self.ema_inplace(self.embed_avg, embed_sum, self.decay)
            cluster_size = self.laplace_smoothing(self.cluster_size, self.n_embed, self.eps) * self.cluster_size.sum()
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        loss = F.mse_loss(quantize.detach(), input) * self.commitment
        quantize = input + (quantize - input).detach()

        avg_probs = torch.mean(embed_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return quantize, loss, perplexity

    def forward_index(self, input):
        dtype = input.dtype
        flatten = input.reshape(-1, self.dim)
        dist = (
                flatten.pow(2).sum(1, keepdim=True)
                - 2 * flatten @ self.embed
                + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
        quantize = input + (quantize - input).detach()

        return quantize, embed_ind
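
The class above is a standard EMA-updated vector quantizer: `forward` returns the straight-through quantized tensor, the commitment loss, and the codebook perplexity, while `forward_index` additionally returns the selected codebook indices. The nearest-neighbour search expands the squared Euclidean distance as ||x - e||^2 = ||x||^2 - 2 x·e + ||e||^2, which is exactly what the three-term `dist` expression computes. Below is a minimal usage sketch, not part of the uploaded file: the import path assumes the repository root is on PYTHONPATH, and the shapes (4 sequences, 50 frames, 256-dim features, 1024 codes) are illustrative only.

import torch

from repcodec.layers.vq_module import VectorQuantize

# Illustrative shapes: 4 sequences, 50 frames, 256-dim features, 1024 codes.
vq = VectorQuantize(dim=256, codebook_size=1024)
x = torch.randn(4, 50, 256)

vq.train()
quantized, loss, perplexity = vq(x)       # EMA codebook update runs because training=True
print(quantized.shape, loss.item(), perplexity.item())

vq.eval()
quantized, indices = vq.forward_index(x)  # indices: (4, 50) integer codebook ids
print(indices.shape, indices.dtype)
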
class ResidualVQ(nn.Module):
    """Residual VQ following Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf"""

    def __init__(
            self,
            *,
            num_quantizers,
            **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([VectorQuantize(**kwargs) for _ in range(num_quantizers)])

    def forward(self, x):
        quantized_out = 0.
        residual = x
        all_losses = []
        all_perplexities = []
        for layer in self.layers:
            quantized, loss, perplexity = layer(residual)
            # Issue: https://github.com/lucidrains/vector-quantize-pytorch/issues/33
            # We found that considering only the 1st layer VQ's gradient results in better performance.
            # residual = residual - quantized.detach()  # considering all layers' gradients
            residual = residual - quantized  # considering only the first layer's gradient
            quantized_out = quantized_out + quantized
            all_losses.append(loss)
            all_perplexities.append(perplexity)
        all_losses, all_perplexities = map(torch.stack, (all_losses, all_perplexities))
        return quantized_out, all_losses, all_perplexities

    def forward_index(self, x, flatten_idx=False):
        quantized_out = 0.
        residual = x
        all_indices = []
        for i, layer in enumerate(self.layers):
            quantized, indices = layer.forward_index(residual)
            # residual = residual - quantized.detach()
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            if flatten_idx:
                indices += (self.codebook_size * i)
            all_indices.append(indices)
        all_indices = torch.stack(all_indices)
        return quantized_out, all_indices.squeeze(1)

    def initial(self):
        self.codebook = []
        for layer in self.layers:
            self.codebook.append(layer.codebook)
        self.codebook_size = self.codebook[0].size(0)
        self.codebook = torch.stack(self.codebook)
        self.codebook = self.codebook.reshape(-1, self.codebook.size(-1))

    def lookup(self, indices):
        quantized_out = F.embedding(indices, self.codebook)  # Num x T x C
        return torch.sum(quantized_out, dim=0, keepdim=True)
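
A matching sketch for `ResidualVQ`. From the code above, `initial()` has to be called before `forward_index(..., flatten_idx=True)` or `lookup()`, since those rely on `self.codebook_size` and the stacked `self.codebook` that `initial()` builds. The configuration below (2 quantizers, 1024 codes, 256-dim input) is again illustrative, and the import path assumes the repository root is importable.

import torch

from repcodec.layers.vq_module import ResidualVQ

# Illustrative configuration: 2 residual codebooks of 1024 entries over 256-dim features.
rvq = ResidualVQ(num_quantizers=2, dim=256, codebook_size=1024)
x = torch.randn(1, 50, 256)

# Training-style pass: summed quantized output plus per-layer losses and perplexities.
quantized, losses, perplexities = rvq(x)   # losses, perplexities: shape (2,)

# Inference-style pass: cache the stacked codebook first, then extract flattened indices
# and reconstruct the quantized output from them.
rvq.eval()
rvq.initial()
quantized, indices = rvq.forward_index(x, flatten_idx=True)  # indices: (2, 50)
reconstructed = rvq.lookup(indices)                          # (1, 50, 256), sum of per-layer codewords
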