KingNish committed
Commit 5873fc1 · verified · 1 parent: 108690a

Upload ./RepCodec/repcodec/layers/vq_module.py with huggingface_hub

Files changed (1)
  1. RepCodec/repcodec/layers/vq_module.py +155 -0
RepCodec/repcodec/layers/vq_module.py ADDED
@@ -0,0 +1,155 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# Copyright (c) Chutong Meng
+#
+# This source code is licensed under the CC BY-NC license found in the
+# LICENSE file in the root directory of this source tree.
+# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class VectorQuantize(nn.Module):
+    """Vector quantization w/ exponential moving averages (EMA)"""
+
+    def __init__(
+        self,
+        dim: int,
+        codebook_size: int,
+        decay=0.8,
+        commitment=1.,
+        eps=1e-5,
+        n_embed=None,
+    ):
+        super().__init__()
+        n_embed = self.default(n_embed, codebook_size)
+
+        self.dim = dim
+        self.n_embed = n_embed
+        self.decay = decay
+        self.eps = eps
+        self.commitment = commitment
+
+        embed = torch.randn(dim, n_embed)
+        self.register_buffer('embed', embed)
+        self.register_buffer('cluster_size', torch.zeros(n_embed))
+        self.register_buffer('embed_avg', embed.clone())
+
+    @property
+    def codebook(self):
+        return self.embed.transpose(0, 1)
+
+    def exists(self, val):
+        return val is not None
+
+    def default(self, val, d):
+        return val if self.exists(val) else d
+
+    def ema_inplace(self, moving_avg, new, decay):
+        moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+    def laplace_smoothing(self, x, n_categories, eps=1e-5):
+        return (x + eps) / (x.sum() + n_categories * eps)
+
+    def forward(self, input):
+        dtype = input.dtype
+        flatten = input.reshape(-1, self.dim)
+        dist = (
+            flatten.pow(2).sum(1, keepdim=True)
+            - 2 * flatten @ self.embed
+            + self.embed.pow(2).sum(0, keepdim=True)
+        )
+        _, embed_ind = (-dist).max(1)
+        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
+        embed_ind = embed_ind.view(*input.shape[:-1])
+        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
+
+        if self.training:  # EMA update of the codebook (no gradient through the codebook)
+            self.ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+            embed_sum = flatten.transpose(0, 1) @ embed_onehot
+            self.ema_inplace(self.embed_avg, embed_sum, self.decay)
+            cluster_size = self.laplace_smoothing(self.cluster_size, self.n_embed, self.eps) * self.cluster_size.sum()
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
+            self.embed.data.copy_(embed_normalized)
+
+        loss = F.mse_loss(quantize.detach(), input) * self.commitment
+        quantize = input + (quantize - input).detach()  # straight-through estimator
+
+        avg_probs = torch.mean(embed_onehot, dim=0)
+        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+        return quantize, loss, perplexity
+
+    def forward_index(self, input):
+        dtype = input.dtype
+        flatten = input.reshape(-1, self.dim)
+        dist = (
+            flatten.pow(2).sum(1, keepdim=True)
+            - 2 * flatten @ self.embed
+            + self.embed.pow(2).sum(0, keepdim=True)
+        )
+        _, embed_ind = (-dist).max(1)
+        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
+        embed_ind = embed_ind.view(*input.shape[:-1])
+        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
+        quantize = input + (quantize - input).detach()
+
+        return quantize, embed_ind
+
+
+class ResidualVQ(nn.Module):
+    """Residual VQ following Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf"""
+
+    def __init__(
+        self,
+        *,
+        num_quantizers,
+        **kwargs
+    ):
+        super().__init__()
+        self.layers = nn.ModuleList([VectorQuantize(**kwargs) for _ in range(num_quantizers)])
+
+    def forward(self, x):
+        quantized_out = 0.
+        residual = x
+        all_losses = []
+        all_perplexities = []
+        for layer in self.layers:
+            quantized, loss, perplexity = layer(residual)
+            # Issue: https://github.com/lucidrains/vector-quantize-pytorch/issues/33
+            # We found that considering only the 1st layer VQ's gradient results in better performance
+            # residual = residual - quantized.detach()  # considering all layers' gradients
+            residual = residual - quantized  # considering only the first layer's gradient
+            quantized_out = quantized_out + quantized
+            all_losses.append(loss)
+            all_perplexities.append(perplexity)
+        all_losses, all_perplexities = map(torch.stack, (all_losses, all_perplexities))
+        return quantized_out, all_losses, all_perplexities
+
+    def forward_index(self, x, flatten_idx=False):
+        quantized_out = 0.
+        residual = x
+        all_indices = []
+        for i, layer in enumerate(self.layers):
+            quantized, indices = layer.forward_index(residual)
+            # residual = residual - quantized.detach()
+            residual = residual - quantized
+            quantized_out = quantized_out + quantized
+            if flatten_idx:
+                indices += (self.codebook_size * i)
+            all_indices.append(indices)
+        all_indices = torch.stack(all_indices)
+        return quantized_out, all_indices.squeeze(1)
+
+    def initial(self):
+        self.codebook = []
+        for layer in self.layers:
+            self.codebook.append(layer.codebook)
+        self.codebook_size = self.codebook[0].size(0)
+        self.codebook = torch.stack(self.codebook)
+        self.codebook = self.codebook.reshape(-1, self.codebook.size(-1))
+
+    def lookup(self, indices):
+        quantized_out = F.embedding(indices, self.codebook)  # Num x T x C
+        return torch.sum(quantized_out, dim=0, keepdim=True)
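
For reference, a minimal usage sketch of the two modules above. The import path assumes the repository layout shown in the commit; the hyper-parameters, batch/frame sizes, and (batch, frames, feature) input layout are illustrative assumptions, not values taken from the uploaded file.

import torch
from repcodec.layers.vq_module import ResidualVQ  # path assumed from the repo layout

# Hypothetical configuration: 2 quantizer stages, 256-dim features, 1024 codes per stage.
rvq = ResidualVQ(num_quantizers=2, dim=256, codebook_size=1024)
x = torch.randn(1, 50, 256)  # assumed layout: (batch, frames, feature dim)

# Training-style pass: summed quantized output plus per-layer commitment losses
# and codebook perplexities (EMA codebook updates run because the module is in train mode).
quantized, losses, perplexities = rvq(x)
print(quantized.shape, losses.shape, perplexities.shape)  # (1, 50, 256), (2,), (2,)

# Inference-style pass: build the stacked codebook once, then extract flattened
# code indices and reconstruct from them via lookup.
rvq.eval()
rvq.initial()                                    # required before flatten_idx=True / lookup
q, idx = rvq.forward_index(x, flatten_idx=True)  # idx: (num_quantizers, frames) for batch size 1
recon = rvq.lookup(idx)                          # (1, frames, feature dim)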