import torch
import torch.nn as nn
class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    ____________________________________________
    Discretization bottleneck part of the VQ-VAE.
    Inputs:
    - n_e : number of embeddings
    - e_dim : dimension of each embedding vector
    - beta : commitment cost used in the loss term, beta * ||z_e(x) - sg[e]||^2
    _____________________________________________
    """
    def __init__(self, n_e, e_dim, beta):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        # kept as attributes for better inheritance properties (so that
        # VectorQuantizer1d() only needs to override these two orders)
        self.permute_order_in = [0, 2, 3, 1]
        self.permute_order_out = [0, 3, 1, 2]
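        # ([0, 2, 3, 1] maps (B, C, H, W) -> (B, H, W, C) so that the embedding
        # dimension is last before flattening; [0, 3, 1, 2] undoes this)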
    def forward(self, z):
        """
        Maps the encoder output z to the index of the closest embedding
        vector e_j and returns the corresponding quantized latents:
        z (continuous) -> z_q (discrete)
        2d: z.shape = (batch, channel, height, width)
        1d: z.shape = (batch, channel, time)
        quantization pipeline:
            1. get encoder output   2d: (B, C, H, W)  or 1d: (B, C, T)
            2. flatten input to     2d: (B*H*W, C)    or 1d: (B*T, C)
        """
        # reshape z -> (batch, height, width, channel) or (batch, time, channel)
        # and flatten
        z = z.permute(self.permute_order_in).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j: ||z - e||^2 = ||z||^2 + ||e||^2 - 2 e . z
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - \
            2 * torch.matmul(z_flattened, self.embedding.weight.t())
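        # (shape check: the first term is (N, 1), the second (n_e,), the third
        # (N, n_e), so d broadcasts to (N, n_e) with
        # d[i, j] = ||z_flattened[i] - e_j||^2, where N = B*H*W or B*T)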
        # find closest encodings and build a one-hot selection matrix
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)
        # dtype of min_encodings: torch.float32
        # min_encodings shape: e.g. torch.Size([2048, 512])
        # min_encoding_indices shape: e.g. torch.Size([2048, 1])

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # TODO: the scatter + matmul above could possibly be replaced by a direct
        # embedding lookup:
        #   min_encoding_indices = torch.argmin(d, dim=1)
        #   z_q = self.embedding(min_encoding_indices).view(z.shape)
        # (min_encodings would still be needed for the perplexity below)
        # compute loss for embedding
        loss = torch.mean((z_q.detach() - z) ** 2) + \
            self.beta * torch.mean((z_q - z.detach()) ** 2)
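        # (first term: codebook loss, pulls embeddings toward encoder outputs;
        # second term: commitment loss, scaled by beta, keeps encoder outputs
        # close to their assigned codes; sg[.] in the docstring is .detach() here)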
        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()
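        # (the forward value is exactly z_q, but in the backward pass the gradient
        # flows straight through to z, since argmin/scatter are not differentiable)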
        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
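        # (perplexity is exp of the entropy of the average code usage in the batch:
        # 1 if a single code is used, n_e if all codes are used uniformly;
        # a common diagnostic for codebook collapse)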
        # reshape back to match original input shape
        z_q = z_q.permute(self.permute_order_out).contiguous()

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        # TODO: check for easier handling with nn.Embedding
        min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
        min_encodings.scatter_(1, indices[:, None], 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape; this must stay inside
            # the branch, since the flat (N, e_dim) tensor returned when shape is
            # None cannot be permuted with a len-4 (or len-3) order
            z_q = z_q.permute(self.permute_order_out).contiguous()
        return z_q
class VectorQuantizer1d(VectorQuantizer):
    def __init__(self, n_embed, embed_dim, beta=0.25):
        super().__init__(n_embed, embed_dim, beta)
        # 1d inputs are (batch, channel, time): a single transpose of the last
        # two axes suffices in both directions
        self.permute_order_in = [0, 2, 1]
        self.permute_order_out = [0, 2, 1]
if __name__ == '__main__':
    quantize = VectorQuantizer1d(n_embed=1024, embed_dim=256, beta=0.25)

    # 1d input (feature sequence)
    enc_outputs = torch.rand(6, 256, 53)
    quant, emb_loss, info = quantize(enc_outputs)
    print(quant.shape)  # torch.Size([6, 256, 53])

    quantize = VectorQuantizer(n_e=1024, e_dim=256, beta=0.25)

    # audio (2d input, e.g. a time-frequency representation)
    enc_outputs = torch.rand(4, 256, 5, 53)
    quant, emb_loss, info = quantize(enc_outputs)
    print(quant.shape)  # torch.Size([4, 256, 5, 53])

    # image
    enc_outputs = torch.rand(4, 256, 16, 16)
    quant, emb_loss, info = quantize(enc_outputs)
    print(quant.shape)  # torch.Size([4, 256, 16, 16])
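    # round-trip sketch (shapes taken from the image example above): rebuild the
    # quantized latents from the returned indices via get_codebook_entry
    perplexity, min_encodings, min_encoding_indices = info
    rec = quantize.get_codebook_entry(min_encoding_indices.squeeze(1),
                                      shape=(4, 16, 16, 256))
    print(torch.allclose(rec, quant))  # expected: True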