import random
import torch.nn as nn
from models.vq.encdec import Encoder, Decoder
from models.vq.residual_vq import ResidualVQ
class RVQVAE(nn.Module):
    """Residual VQ-VAE: an Encoder/Decoder pair with a ResidualVQ quantization bottleneck."""

    def __init__(self,
                 args,
                 input_width=263,
                 nb_code=1024,
                 code_dim=512,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None):
        super().__init__()
        assert output_emb_width == code_dim
        self.code_dim = code_dim
        self.num_code = nb_code
        # self.quant = args.quantizer
        self.encoder = Encoder(input_width, output_emb_width, down_t, stride_t, width, depth,
                               dilation_growth_rate, activation=activation, norm=norm)
        self.decoder = Decoder(input_width, output_emb_width, down_t, stride_t, width, depth,
                               dilation_growth_rate, activation=activation, norm=norm)

        rvqvae_config = {
            'num_quantizers': args.num_quantizers,
            'shared_codebook': args.shared_codebook,
            'quantize_dropout_prob': args.quantize_dropout_prob,
            'quantize_dropout_cutoff_index': 0,
            'nb_code': nb_code,
            'code_dim': code_dim,
            'args': args,
        }
        self.quantizer = ResidualVQ(**rvqvae_config)
    def preprocess(self, x):
        # (bs, T, Jx3) -> (bs, Jx3, T)
        x = x.permute(0, 2, 1).float()
        return x

    def postprocess(self, x):
        # (bs, Jx3, T) -> (bs, T, Jx3)
        x = x.permute(0, 2, 1)
        return x
    def encode(self, x):
        # x: (bs, T, Jx3) motion sequence
        N, T, _ = x.shape
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        # code_idx: (N, T, Q) codebook indices; all_codes: the quantized latents per quantizer layer
        code_idx, all_codes = self.quantizer.quantize(x_encoder, return_latent=True)
        return code_idx, all_codes
    def forward(self, x):
        x_in = self.preprocess(x)
        # Encode
        x_encoder = self.encoder(x_in)
        # Quantize (sample_codebook_temp controls stochastic codebook sampling)
        # x_quantized, code_idx, commit_loss, perplexity = self.quantizer(x_encoder, sample_codebook_temp=0.5,
        #                                                                 force_dropout_index=0)  # TODO: hardcoded
        x_quantized, code_idx, commit_loss, perplexity = self.quantizer(x_encoder, sample_codebook_temp=0.5)
        # Decode back to the input space
        x_out = self.decoder(x_quantized)
        return {
            'rec_pose': x_out,
            'commit_loss': commit_loss,
            'perplexity': perplexity,
        }
    def forward_decoder(self, x):
        # x: codebook indices; look up the corresponding code vectors per quantizer layer
        x_d = self.quantizer.get_codes_from_indices(x)
        # Sum codes across quantizer layers, then move channels to dim 1 for the decoder
        x = x_d.sum(dim=0).permute(0, 2, 1)
        x_out = self.decoder(x)
        return x_out
    def map2latent(self, x):
        x_in = self.preprocess(x)
        # Encode to the continuous (pre-quantization) latent, returned as (bs, T', code_dim)
        x_encoder = self.encoder(x_in)
        x_encoder = x_encoder.permute(0, 2, 1)
        return x_encoder
    def latent2origin(self, x):
        # x: (bs, T', code_dim) continuous latent; quantize it, then decode back to motion space
        x = x.permute(0, 2, 1)
        x_quantized, code_idx, commit_loss, perplexity = self.quantizer(x, sample_codebook_temp=0.5)
        x_out = self.decoder(x_quantized)
        return x_out, commit_loss, perplexity
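# Usage sketch: assumes an already-constructed RVQVAE (`vq_model`) and a motion tensor of
# shape (bs, T, input_width); the function name and shapes are illustrative assumptions.
def _demo_rvqvae_roundtrip(vq_model, motion):
    # Full pass: encode, quantize, decode; returns reconstruction plus VQ losses.
    out = vq_model(motion)
    # Discrete path: map motion to codebook indices, then decode from the indices alone.
    code_idx, _all_codes = vq_model.encode(motion)
    rec_from_codes = vq_model.forward_decoder(code_idx)
    return out['rec_pose'], rec_from_codes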
class LengthEstimator(nn.Module):
    def __init__(self, input_size, output_size):
        super(LengthEstimator, self).__init__()
        nd = 512
        self.output = nn.Sequential(
            nn.Linear(input_size, nd),
            nn.LayerNorm(nd),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),
            nn.Linear(nd, nd // 2),
            nn.LayerNorm(nd // 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),
            nn.Linear(nd // 2, nd // 4),
            nn.LayerNorm(nd // 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(nd // 4, output_size)
        )

        self.output.apply(self.__init_weights)

    def __init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, text_emb):
        return self.output(text_emb)
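# Minimal smoke test for LengthEstimator; the sizes below are illustrative assumptions
# (a 512-dim text embedding mapped to 50 output logits), not values from a config.
if __name__ == "__main__":
    import torch

    estimator = LengthEstimator(input_size=512, output_size=50)
    dummy_text_emb = torch.randn(4, 512)
    length_logits = estimator(dummy_text_emb)
    print(length_logits.shape)  # expected: torch.Size([4, 50])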