import pdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils.layer import BasicBlock
from einops import rearrange
import pickle
from .timm_transformer.transformer import Block as mytimmBlock


class MDM(nn.Module):
    def __init__(self, args):
        super().__init__()
        njoints = 768
        nfeats = 1
        latent_dim = 512
        ff_size = 1024
        num_layers = 8
        num_heads = 4
        dropout = 0.1
        ablation = None
        activation = "gelu"
        legacy = False
        data_rep = 'rot6d'
        dataset = 'amass'
        audio_feat_dim = 64
        emb_trans_dec = False
        audio_rep = ''
        n_seed = 8
        cond_mode = ''
        kargs = {}

        if args.vqvae_type == 'rvqvae':
            njoints = 1536
        elif args.vqvae_type == 'novqvae':
            njoints = 312

        self.args = args
        self.legacy = legacy
        self.njoints = njoints
        self.nfeats = nfeats
        self.data_rep = data_rep
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.ablation = ablation
        self.activation = activation
        self.action_emb = kargs.get('action_emb', None)
        self.input_feats = self.njoints * self.nfeats
        self.cond_mask_prob = kargs.get('cond_mask_prob', 0.3)
        self.use_motionclip = args.use_motionclip

        if args.audio_rep == 'onset+amplitude':
            self.WavEncoder = WavEncoder(args.audio_f, audio_in=2)
            self.audio_feat_dim = args.audio_f
            self.text_encoder_body = nn.Linear(300, args.audio_f)
            with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f:
                self.lang_model = pickle.load(f)
            pre_trained_embedding = self.lang_model.word_embedding_weights
            self.text_pre_encoder_body = nn.Embedding.from_pretrained(
                torch.FloatTensor(pre_trained_embedding), freeze=args.t_fix_pre)

        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
        self.emb_trans_dec = emb_trans_dec
        self.cond_mode = cond_mode
        self.num_head = 8

        self.mytimmblocks = nn.ModuleList([
            mytimmBlock(dim=self.latent_dim, num_heads=self.num_heads,
                        mlp_ratio=self.ff_size // self.latent_dim, drop_path=self.dropout)
            # dim matches the dimension of the input x; attn_heads should be 12 in principle,
            # a smaller value was used here only to make debugging the pipeline easier.
            for _ in range(self.num_layers)])

        self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
        self.n_seed = n_seed
        self.style_dim = 64
        self.embed_style = nn.Linear(6, self.style_dim)
        self.embed_text = nn.Linear(self.input_feats * 4, self.latent_dim)
        self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim,
                                            self.njoints, self.nfeats)
        self.rel_pos = SinusoidalEmbeddings(self.latent_dim // self.num_head)
        self.input_process = InputProcess(self.data_rep, self.input_feats, self.latent_dim)
        self.input_process2 = nn.Linear(self.latent_dim * 2 + self.audio_feat_dim, self.latent_dim)
        if self.use_motionclip:
            self.input_process3 = nn.Linear(self.latent_dim + 512, self.latent_dim)
        self.mix_audio_text = nn.Linear(args.audio_f + args.word_f, 256)

    def mask_cond(self, cond, force_mask=False):
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_mask_prob > 0.:
            # 1 -> use null cond, 0 -> use real cond
            mask = torch.bernoulli(
                torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond

    def mask_cond_audio(self, cond, force_mask=False):
        # Note: relies on self.cond_mask_prob_audio, which is not set in __init__.
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_mask_prob_audio > 0.:
            # 1 -> use null cond, 0 -> use real cond
            mask = torch.bernoulli(
                torch.ones(bs, device=cond.device) * self.cond_mask_prob_audio).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond

    def forward(self, x, timesteps, y=None, uncond_info=False):
        """
        x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper
        timesteps: [batch_size] (int)
        seed: [batch_size, njoints, nfeats]
        """
        _, _, _, noise_length = x.shape
        y = y.copy()
        bs, njoints, nfeats, nframes = x.shape  # 300, 1141, 1, 88
        emb_t = self.embed_timestep(timesteps)  # [1, bs, d], (1, 2, 256)

        force_mask = y.get('uncond', False)  # False
        # force_mask = uncond_info
        if self.n_seed != 0:
            embed_text = self.embed_text(y['seed'].reshape(bs, -1))  # (bs, 256-64)
            emb_seed = embed_text

        audio_feat = self.WavEncoder(y['audio']).permute(1, 0, 2)
        text_feat = self.text_pre_encoder_body(y['word'])
        text_feat = self.text_encoder_body(text_feat).permute(1, 0, 2)
        at_feat = torch.cat([audio_feat, text_feat], dim=2)
        at_feat = self.mix_audio_text(at_feat)
        at_feat = F.avg_pool1d(at_feat.permute(1, 2, 0), self.args.vqvae_squeeze_scale).permute(2, 0, 1)

        # This part exercises the timm transformer blocks.
        x = x.reshape(bs, njoints * nfeats, 1, nframes)  # [300, 1141, 1, 88] -> [300, 1141, 1, 88]
        # self-attention
        x_ = self.input_process(x)  # [300, 1141, 1, 88] -> [88, 300, 256]
        # local cross-attention
        xseq = torch.cat((x_, at_feat), axis=2)  # [88, 300, 256], [88, 300, 64] -> [88, 300, 320]
        # all frames
        embed_style_2 = (emb_seed + emb_t).repeat(nframes, 1, 1)  # [300, 256], [1, 300, 256] -> [88, 300, 256]
        xseq = torch.cat((embed_style_2, xseq), axis=2)  # -> [88, 300, 576]
        xseq = self.input_process2(xseq)  # [88, 300, 576] -> [88, 300, 256]

        if self.use_motionclip:
            xseq = torch.cat((xseq, self.mask_cond(y['style_feature'], force_mask)
                              .unsqueeze(0).repeat(nframes, 1, 1)), axis=2)
            xseq = self.input_process3(xseq)

        # The next ten lines apply rotary positional encoding; it seems to help slightly,
        # though the gain may be within noise.
        xseq = xseq.permute(1, 0, 2)  # [88, 300, 256] -> [300, 88, 256]
        xseq = xseq.view(bs, nframes, self.num_head, -1)  # [300, 88, 256] -> [300, 88, 8, 32]
        xseq = xseq.permute(0, 2, 1, 3)  # [300, 88, 8, 32] -> [300, 8, 88, 32]
        xseq = xseq.reshape(bs * self.num_head, nframes, -1)  # [300, 8, 88, 32] -> [2400, 88, 32]
        pos_emb = self.rel_pos(xseq)  # (88, 32)
        xseq, _ = apply_rotary_pos_emb(xseq, xseq, pos_emb)  # [2400, 88, 32]
        xseq_rpe = xseq.reshape(bs, self.num_head, nframes, -1)  # [300, 8, 88, 32]
        xseq = xseq_rpe.permute(0, 2, 1, 3)  # [300, 8, 88, 32] -> [300, 88, 8, 32]
        xseq = xseq.view(bs, nframes, -1)  # [300, 88, 8, 32] -> [300, 88, 256]

        for block in self.mytimmblocks:
            xseq = block(xseq)
        xseq = xseq.permute(1, 0, 2)  # [300, 88, 256] -> [88, 300, 256]

        output = xseq
        output = self.output_process(output)  # [88, 300, 256] -> [300, 1141, 1, 88]
        return output[..., :noise_length]

    @staticmethod
    def apply_rotary(x, sinusoidal_pos):
        sin, cos = sinusoidal_pos
        x1, x2 = x[..., 0::2], x[..., 1::2]
        # When rotating query/key, a plain cat is enough, because the matrix product sums over
        # this dimension anyway (only the alignment between query and key positions matters):
        # torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
        # When rotating value, the stack-then-flatten below is required, because the trained
        # model expects the two halves to be interleaved along the last dimension.
        return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1)
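
# Illustration (not part of the original model): a minimal, self-contained sketch of the
# classifier-free-guidance-style dropout performed by MDM.mask_cond during training.
# The function name, probability, and dummy tensor sizes are placeholder choices.
def _mask_cond_demo(cond_mask_prob: float = 0.3) -> torch.Tensor:
    cond = torch.ones(4, 8)  # (bs, d) dummy conditioning vectors
    # Draw a Bernoulli mask per sample: 1 -> replace with the null condition, 0 -> keep it.
    mask = torch.bernoulli(torch.full((4, 1), cond_mask_prob))
    return cond * (1. - mask)
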

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)  # (5000, 128)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (5000, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # not used in the final model
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)


class TimestepEmbedder(nn.Module):
    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps):
        return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
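
# Illustration (not part of the original model): TimestepEmbedder reuses the sinusoidal
# table built by PositionalEncoding, indexing it with the diffusion step ids and pushing
# the rows through a small MLP. The dimensions and step values below are example choices.
def _timestep_embedding_demo(latent_dim: int = 512) -> torch.Tensor:
    pos_enc = PositionalEncoding(latent_dim, dropout=0.1)
    embedder = TimestepEmbedder(latent_dim, pos_enc)
    t = torch.tensor([12, 85])  # (bs,) diffusion timesteps
    return embedder(t)          # (1, bs, latent_dim)
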

class InputProcess(nn.Module):
    def __init__(self, data_rep, input_feats, latent_dim):
        super().__init__()
        self.data_rep = data_rep
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
        if self.data_rep == 'rot_vel':
            self.velEmbedding = nn.Linear(self.input_feats, self.latent_dim)

    def forward(self, x):
        bs, njoints, nfeats, nframes = x.shape
        x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints * nfeats)

        if self.data_rep in ['rot6d', 'xyz', 'hml_vec']:
            x = self.poseEmbedding(x)  # [seqlen, bs, d]
            return x
        elif self.data_rep == 'rot_vel':
            first_pose = x[[0]]  # [1, bs, 150]
            first_pose = self.poseEmbedding(first_pose)  # [1, bs, d]
            vel = x[1:]  # [seqlen-1, bs, 150]
            vel = self.velEmbedding(vel)  # [seqlen-1, bs, d]
            return torch.cat((first_pose, vel), axis=0)  # [seqlen, bs, d]
        else:
            raise ValueError


class OutputProcess(nn.Module):
    def __init__(self, data_rep, input_feats, latent_dim, njoints, nfeats):
        super().__init__()
        self.data_rep = data_rep
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.njoints = njoints
        self.nfeats = nfeats
        self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)
        if self.data_rep == 'rot_vel':
            self.velFinal = nn.Linear(self.latent_dim, self.input_feats)

    def forward(self, output):
        nframes, bs, d = output.shape
        if self.data_rep in ['rot6d', 'xyz', 'hml_vec']:
            output = self.poseFinal(output)  # [88, 300, 256] -> [88, 300, 1141]
        elif self.data_rep == 'rot_vel':
            first_pose = output[[0]]  # [1, bs, d]
            first_pose = self.poseFinal(first_pose)  # [1, bs, 150]
            vel = output[1:]  # [seqlen-1, bs, d]
            vel = self.velFinal(vel)  # [seqlen-1, bs, 150]
            output = torch.cat((first_pose, vel), axis=0)  # [seqlen, bs, 150]
        else:
            raise ValueError
        output = output.reshape(nframes, bs, self.njoints, self.nfeats)
        output = output.permute(1, 2, 3, 0)  # [bs, njoints, nfeats, nframes]
        return output


class WavEncoder(nn.Module):
    def __init__(self, out_dim, audio_in=1):
        super().__init__()
        self.out_dim = out_dim
        self.feat_extractor = nn.Sequential(
            BasicBlock(audio_in, out_dim // 4, 15, 5, first_dilation=1700, downsample=True),
            BasicBlock(out_dim // 4, out_dim // 4, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(out_dim // 4, out_dim // 4, 15, 1, first_dilation=7),
            BasicBlock(out_dim // 4, out_dim // 2, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(out_dim // 2, out_dim // 2, 15, 1, first_dilation=7),
            BasicBlock(out_dim // 2, out_dim, 15, 3, first_dilation=0, downsample=True),
        )

    def forward(self, wav_data):
        if wav_data.dim() == 2:
            wav_data = wav_data.unsqueeze(1)
        else:
            wav_data = wav_data.transpose(1, 2)
        out = self.feat_extractor(wav_data)
        return out.transpose(1, 2)


class SinusoidalEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        n = x.shape[-2]
        t = torch.arange(n, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    x = rearrange(x, 'b ... (r d) -> b (...) r d', r=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, freqs):
    q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
    return q, k
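
# Illustration (not part of the original model): a shape check for the rotary position
# embedding helpers above, using the (bs * num_head, nframes, head_dim) layout produced
# in MDM.forward. The concrete sizes are example values only.
def _rotary_shape_check(nframes: int = 88, head_dim: int = 32) -> None:
    rel_pos = SinusoidalEmbeddings(head_dim)
    x = torch.randn(4, nframes, head_dim)       # (bs * num_head, nframes, head_dim)
    pos_emb = rel_pos(x)                        # (nframes, head_dim)
    q, k = apply_rotary_pos_emb(x, x, pos_emb)  # rotated q/k, shapes unchanged
    assert q.shape == x.shape and k.shape == x.shape
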

if __name__ == '__main__':
    '''
    cd ./main/model
    python mdm.py
    '''
    # Note: this smoke test predates the args-based constructor; MDM.__init__ now takes
    # a single `args` namespace rather than these keyword arguments.
    n_frames = 240
    n_seed = 8
    model = MDM(modeltype='', njoints=1140, nfeats=1, cond_mode='cross_local_attention5_style1',
                action_emb='tensor', audio_rep='mfcc', arch='mytrans_enc', latent_dim=256,
                n_seed=n_seed, cond_mask_prob=0.1)
    x = torch.randn(2, 1140, 1, 88)
    t = torch.tensor([12, 85])
    model_kwargs_ = {'y': {}}
    model_kwargs_['y']['mask'] = (torch.zeros([1, 1, 1, n_frames]) < 1)  # [..., n_seed:]
    model_kwargs_['y']['audio'] = torch.randn(2, 88, 13).permute(1, 0, 2)  # [n_seed:, ...]
    model_kwargs_['y']['style'] = torch.randn(2, 6)
    model_kwargs_['y']['mask_local'] = torch.ones(2, 88).bool()
    model_kwargs_['y']['seed'] = x[..., 0:n_seed]
    y = model(x, t, model_kwargs_['y'])
    print(y.shape)