import inspect
from typing import Optional

import torch
import torch.nn.functional as F
from diffusers import ModelMixin
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_pndm import PNDMScheduler
from torch import Tensor
from tqdm import tqdm

from .model_utils import get_custom_betas
from .point_model import PointModel


class ConditionalPointCloudDiffusionModel(ModelMixin):

    def __init__(
        self,
        beta_start: float = 1e-5,
        beta_end: float = 8e-3,
        beta_schedule: str = 'linear',
        point_cloud_model: str = 'simple',
        point_cloud_model_embed_dim: int = 64,
    ):
        super().__init__()
        # Per-frame keypoint feature dimension: the first 6 channels are pose,
        # the remaining channels are expression coefficients (see the losses below).
        self.in_channels = 70
        self.out_channels = 70

        # Create the diffusion schedulers which define the sampling timesteps
        scheduler_kwargs = {}
        if beta_schedule == 'custom':
            scheduler_kwargs.update(dict(trained_betas=get_custom_betas(beta_start=beta_start, beta_end=beta_end)))
        else:
            scheduler_kwargs.update(dict(beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule))
        self.schedulers_map = {
            'ddpm': DDPMScheduler(**scheduler_kwargs, clip_sample=False),
            'ddim': DDIMScheduler(**scheduler_kwargs, clip_sample=False),
            'pndm': PNDMScheduler(**scheduler_kwargs),
        }
        self.scheduler = self.schedulers_map['ddim']  # default; can be overridden at inference

        # Create the point model that predicts the noise at each diffusion step
        self.point_model = PointModel(
            model_type=point_cloud_model,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
        )

    def forward_train(
        self,
        pc: Optional[Tensor],
        ref_kps: Optional[Tensor],
        ori_kps: Optional[Tensor],
        aud_feat: Optional[Tensor],
        mode: str = 'train',
        return_intermediate_steps: bool = False,
    ):
        x_0 = pc
        B, Nf, Np, D = x_0.shape  # batch, number of frames, number of keypoints, point dimension
        x_0 = x_0[:, :, :, 0]  # (B, Nf, 70): keep only the first component of each keypoint

        # Sample random noise
        noise = torch.randn_like(x_0)

        # Sample a random timestep for each sequence in the batch
        timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,),
                                 device=self.device, dtype=torch.long)

        # Add noise to the clean keypoints
        x_t = self.scheduler.add_noise(x_0, noise, timestep)

        # Conditioning: prepend the identity (ori) and reference keypoints as two extra frames
        ref_kps = ref_kps[:, :, 0]
        x_t_input = torch.cat([ori_kps.unsqueeze(1), ref_kps.unsqueeze(1), x_t], dim=1)  # (B, Nf + 2, 70)

        # Prepend two zero frames to the audio features so they stay aligned with the conditioning frames
        aud_feat = torch.cat([torch.zeros(B, 2, 512, device=self.device), aud_feat], dim=1)

        # Audio-feature augmentation: with probability 0.7, add Gaussian noise matched to the feature statistics
        if mode == 'train' and torch.rand(1) > 0.3:
            mean = aud_feat.mean().item()
            std = aud_feat.std().item()
            sample = torch.normal(
                mean=torch.full(aud_feat.shape, mean, device=self.device),
                std=torch.full(aud_feat.shape, std, device=self.device),
            )
            aud_feat = aud_feat + sample

        # Forward pass: predict the noise, then drop the two conditioning frames
        noise_pred = self.point_model(x_t_input, timestep, context=aud_feat)
        noise_pred = noise_pred[:, 2:]

        # Check
        if noise_pred.shape != noise.shape:
            raise ValueError(f'{noise_pred.shape=} and {noise.shape=}')

        # Losses: overall, pose-only (first 6 channels), and expression-only (remaining channels)
        loss = F.mse_loss(noise_pred, noise)
        loss_pose = F.mse_loss(noise_pred[:, :, :6], noise[:, :, :6])
        loss_exp = F.mse_loss(noise_pred[:, :, 6:], noise[:, :, 6:])

        # Optionally return intermediate steps
        if return_intermediate_steps:
            return loss, (x_0, x_t, noise, noise_pred)

        return loss, loss_exp, loss_pose
    @torch.no_grad()
    def forward_sample(
        self,
        num_points: int,
        ref_kps: Optional[Tensor],
        ori_kps: Optional[Tensor],
        aud_feat: Optional[Tensor],
        # Optional overrides
        scheduler: Optional[str] = 'ddpm',
        # Inference parameters
        num_inference_steps: Optional[int] = 1000,
        eta: Optional[float] = 0.0,  # for DDIM
        # Whether to return all the intermediate steps in generation
        return_sample_every_n_steps: int = -1,
        # Whether to disable tqdm
        disable_tqdm: bool = False,
    ):
        # Get the scheduler from the mapping, or use self.scheduler if None
        scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]

        # Get the size of the noise
        Np = num_points
        Nf = aud_feat.size(1)
        B = 1
        D = 3
        device = self.device

        # Sample noise and keep only the first component of each keypoint, matching training
        x_t = torch.randn(B, Nf, Np, D, device=device)
        x_t = x_t[:, :, :, 0]
        ref_kps = ref_kps[:, :, 0]

        # Set timesteps
        accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {"offset": 1} if accepts_offset else {}
        scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta (η) is only used with the DDIMScheduler; it is ignored by other schedulers.
        # eta corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be in [0, 1].
        accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
        extra_step_kwargs = {"eta": eta} if accepts_eta else {}

        # Loop over timesteps
        all_outputs = []
        return_all_outputs = (return_sample_every_n_steps > 0)
        progress_bar = tqdm(scheduler.timesteps.to(device), desc=f'Sampling ({x_t.shape})', disable=disable_tqdm)

        # Prepend two zero frames to the audio features, matching the conditioning frames used in training
        aud_feat = torch.cat([torch.zeros(B, 2, 512, device=device), aud_feat], dim=1)

        for i, t in enumerate(progress_bar):
            # Conditioning: prepend the identity (ori) and reference keypoints as two extra frames
            x_t_input = torch.cat([ori_kps.unsqueeze(1).detach(), ref_kps.unsqueeze(1).detach(), x_t], dim=1)

            # Forward pass: predict the noise, then drop the two conditioning frames
            noise_pred = self.point_model(x_t_input, t.reshape(1).expand(B), context=aud_feat)[:, 2:]

            # Step
            x_t = scheduler.step(noise_pred, t, x_t, **extra_step_kwargs).prev_sample

            # Append to the output list if desired
            if return_all_outputs and (i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1):
                all_outputs.append(x_t)

        # Convert the output back into point-cloud form by replicating the sampled component along the last axis
        output = x_t
        output = torch.stack([output, output, output], dim=-1)
        if return_all_outputs:
            all_outputs = torch.stack(all_outputs, dim=1)  # (B, sample_steps, N, D)
        return (output, all_outputs) if return_all_outputs else output

    def forward(self, batch: dict, mode: str = 'train', **kwargs):
        """A wrapper around the forward methods for training, validation, and sampling."""
        if mode == 'train':
            return self.forward_train(
                pc=batch['sequence_keypoints'],
                ref_kps=batch['ref_keypoint'],
                ori_kps=batch['ori_keypoint'],
                aud_feat=batch['aud_feat'],
                mode='train',
                **kwargs)
        elif mode == 'val':
            return self.forward_train(
                pc=batch['sequence_keypoints'],
                ref_kps=batch['ref_keypoint'],
                ori_kps=batch['ori_keypoint'],
                aud_feat=batch['aud_feat'],
                mode='val',
                **kwargs)
        elif mode == 'sample':
            num_points = 68
            return self.forward_sample(
                num_points=num_points,
                ref_kps=batch['ref_keypoint'],
                ori_kps=batch['ori_keypoint'],
                aud_feat=batch['aud_feat'],
                **kwargs)
        else:
            raise NotImplementedError()
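

# ---------------------------------------------------------------------------
# Minimal training-call sketch (assumption): the batch-tensor shapes below are
# inferred from the slicing and concatenation in forward_train, not from a
# documented interface, and PointModel is assumed to accept the 70-channel
# inputs configured above. Running this file directly also requires the
# package context for the relative imports at the top.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = ConditionalPointCloudDiffusionModel()
    B, Nf = 2, 32
    batch = {
        'sequence_keypoints': torch.randn(B, Nf, 70, 3),  # per-frame keypoints; only [..., 0] is used
        'ref_keypoint': torch.randn(B, 70, 3),            # reference keypoints; only [..., 0] is used
        'ori_keypoint': torch.randn(B, 70),               # identity keypoints, already flattened
        'aud_feat': torch.randn(B, Nf, 512),              # per-frame audio features
    }
    loss, loss_exp, loss_pose = model(batch, mode='train')
    print(loss.item(), loss_exp.item(), loss_pose.item())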