import inspect
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm

from diffusers import ModelMixin
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_pndm import PNDMScheduler

from .model_utils import get_custom_betas
from .point_model import PointModel

class TemporalSmoothnessLoss(nn.Module):
    """Mean squared frame-to-frame difference of a (B, Nf, C) sequence."""

    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        # First-order temporal difference between consecutive frames.
        diff = x[:, 1:, :] - x[:, :-1, :]
        # Sum squared differences over channels, then average over batch
        # and frames.
        smoothness_loss = torch.mean(torch.sum(diff ** 2, dim=2))
        return smoothness_loss

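# Usage sketch: this loss is defined here but not wired into the training
# objective below; it could penalize jitter in a predicted sequence, e.g.
#   smooth = TemporalSmoothnessLoss()(pred_seq)  # pred_seq: (B, Nf, 70)
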
class ConditionalPointCloudDiffusionModel(ModelMixin):
    def __init__(
        self,
        beta_start: float = 1e-5,
        beta_end: float = 8e-3,
        beta_schedule: str = 'linear',
        point_cloud_model: str = 'simple',
        point_cloud_model_embed_dim: int = 64,
    ):
        super().__init__()
        # 70 feature channels per frame; the losses below treat dims 1:7 as
        # pose and dims 7: as expression coefficients.
        self.in_channels = 70
        self.out_channels = 70

        # Noise schedule: 'custom' uses precomputed betas; any other value is
        # forwarded to the diffusers schedulers as a named schedule.
        scheduler_kwargs = {}
        if beta_schedule == 'custom':
            scheduler_kwargs.update(dict(trained_betas=get_custom_betas(beta_start=beta_start, beta_end=beta_end)))
        else:
            scheduler_kwargs.update(dict(beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule))
        self.schedulers_map = {
            'ddpm': DDPMScheduler(**scheduler_kwargs, clip_sample=False),
            'ddim': DDIMScheduler(**scheduler_kwargs, clip_sample=False),
            'pndm': PNDMScheduler(**scheduler_kwargs),
        }
        self.scheduler = self.schedulers_map['ddim']  # default; overridable at sampling time

        # Denoising network that predicts the noise added to the sequence.
        self.point_model = PointModel(
            model_type=point_cloud_model,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
        )

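    # Note: any of the three schedulers above can be selected at sampling
    # time, e.g. model.forward_sample(..., scheduler='ddim').
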
    def forward_train(
        self,
        pc: Optional[Tensor],
        ref_kps: Optional[Tensor],
        ori_kps: Optional[Tensor],
        aud_feat: Optional[Tensor],
        mode: str = 'train',
        return_intermediate_steps: bool = False
    ):
        # The last dim of the keypoint sequence repeats the same values, so
        # keep a single channel: (B, Nf, Np, D) -> (B, Nf, Np).
        x_0 = pc
        B, Nf, Np, D = x_0.shape
        x_0 = x_0[:, :, :, 0]

        # Sample Gaussian noise and a random diffusion timestep per sample.
        noise = torch.randn_like(x_0)
        timestep = torch.randint(0, self.scheduler.config.num_train_timesteps, (B,),
                                 device=self.device, dtype=torch.long)

        # Forward diffusion: noise x_0 up to the sampled timestep.
        x_t = self.scheduler.add_noise(x_0, noise, timestep)

        # Prepend the ori and reference keypoints as two condition frames.
        ref_kps = ref_kps[:, :, 0]
        x_t_input = torch.cat([ori_kps.unsqueeze(1), ref_kps.unsqueeze(1), x_t], dim=1)

        # Pad the audio context with two zero frames so it stays aligned
        # with the two prepended condition frames.
        aud_feat = torch.cat([torch.zeros(B, 2, 512, device=aud_feat.device), aud_feat], 1)

        # Train-time augmentation: with probability 0.7, perturb the audio
        # features with Gaussian noise matched to their global statistics.
        if mode == 'train' and torch.rand(1) > 0.3:
            mean, std = aud_feat.mean(), aud_feat.std()
            aud_feat = aud_feat + mean + std * torch.randn_like(aud_feat)

        # Predict the noise, then drop the two condition frames.
        noise_pred = self.point_model(x_t_input, timestep, context=aud_feat)
        noise_pred = noise_pred[:, 2:]

        if noise_pred.shape != noise.shape:
            raise ValueError(f'{noise_pred.shape=} and {noise.shape=}')

        # Overall denoising loss, plus pose (dims 1:7) and expression
        # (dims 7:) components.
        loss = F.mse_loss(noise_pred, noise)
        loss_pose = F.mse_loss(noise_pred[:, :, 1:7], noise[:, :, 1:7])
        loss_exp = F.mse_loss(noise_pred[:, :, 7:], noise[:, :, 7:])

        # Optionally expose intermediate tensors for debugging/visualization.
        if return_intermediate_steps:
            return loss, (x_0, x_t, noise, noise_pred)

        return loss, loss_exp, loss_pose

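    # Training-loop sketch (hypothetical optimizer; loss_exp and loss_pose
    # are returned for logging, not weighted into the objective):
    #   loss, loss_exp, loss_pose = model(batch, mode='train')
    #   loss.backward(); optimizer.step(); optimizer.zero_grad()
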
    @torch.no_grad()
    def forward_sample(
        self,
        num_points: int,
        ref_kps: Optional[Tensor],
        ori_kps: Optional[Tensor],
        aud_feat: Optional[Tensor],
        # Which scheduler to sample with; None falls back to self.scheduler.
        scheduler: Optional[str] = 'ddpm',
        num_inference_steps: Optional[int] = 50,
        eta: Optional[float] = 0.0,  # only used by DDIM-style schedulers
        # If > 0, also return every n-th intermediate sample.
        return_sample_every_n_steps: int = -1,
        disable_tqdm: bool = False,
    ):
        scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]

        # One sequence of Nf frames with Np keypoint channels each.
        Np = num_points
        Nf = aud_feat.size(1)
        B = 1
        D = 3
        device = self.device

        # Start from pure Gaussian noise; as in training, keep one channel
        # of the repeated last dim.
        x_t = torch.randn(B, Nf, Np, D, device=device)
        x_t = x_t[:, :, :, 0]

        ref_kps = ref_kps[:, :, 0]

        # Pass 'offset' / 'eta' only to schedulers whose API accepts them.
        accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {"offset": 1} if accepts_offset else {}
        scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
        extra_step_kwargs = {"eta": eta} if accepts_eta else {}

        all_outputs = []
        return_all_outputs = (return_sample_every_n_steps > 0)
        progress_bar = tqdm(scheduler.timesteps.to(device), desc=f'Sampling ({x_t.shape})', disable=disable_tqdm)

        # Pad the audio context with two zero frames, matching the two
        # prepended condition frames (as in training).
        aud_feat = torch.cat([torch.zeros(B, 2, 512, device=device), aud_feat], 1)

        # Reverse diffusion: iteratively denoise x_t.
        for i, t in enumerate(progress_bar):
            x_t_input = torch.cat([ori_kps.unsqueeze(1).detach(), ref_kps.unsqueeze(1).detach(), x_t], dim=1)

            # Predict the noise and drop the two condition frames.
            noise_pred = self.point_model(x_t_input, t.reshape(1).expand(B), context=aud_feat)[:, 2:]

            # One scheduler step toward x_0.
            x_t = scheduler.step(noise_pred, t, x_t, **extra_step_kwargs).prev_sample

            if return_all_outputs and (i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1):
                all_outputs.append(x_t)

        # Restore the (B, Nf, Np, 3) layout by repeating the single channel.
        output = torch.stack([x_t, x_t, x_t], dim=-1)
        if return_all_outputs:
            all_outputs = torch.stack(all_outputs, dim=1)  # (B, sampled_steps, Nf, Np)
        return (output, all_outputs) if return_all_outputs else output

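    # Sampling sketch: one call yields a (1, Nf, 70, 3) sequence driven by
    # the audio features, e.g.
    #   out = model(batch, mode='sample', scheduler='ddim', num_inference_steps=50)
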
    def forward(self, batch: dict, mode: str = 'train', **kwargs):
        """A wrapper around the forward method for training and inference."""
        if mode == 'train':
            return self.forward_train(
                pc=batch['sequence_keypoints'],
                ref_kps=batch['ref_keypoint'],
                ori_kps=batch['ori_keypoint'],
                aud_feat=batch['aud_feat'],
                mode='train',
                **kwargs)
        elif mode == 'val':
            return self.forward_train(
                pc=batch['sequence_keypoints'],
                ref_kps=batch['ref_keypoint'],
                ori_kps=batch['ori_keypoint'],
                aud_feat=batch['aud_feat'],
                mode='val',
                **kwargs)
        elif mode == 'sample':
            num_points = 70
            return self.forward_sample(
                num_points=num_points,
                ref_kps=batch['ref_keypoint'],
                ori_kps=batch['ori_keypoint'],
                aud_feat=batch['aud_feat'],
                **kwargs)
        else:
            raise NotImplementedError(f'Unknown mode: {mode!r}')
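

if __name__ == '__main__':
    # Minimal smoke test, a sketch: assumes the local PointModel accepts the
    # shapes below and runs on CPU; because of the relative imports, run it
    # as a module from the package root (python -m <package>.<this_module>).
    model = ConditionalPointCloudDiffusionModel()
    batch = {
        'sequence_keypoints': torch.randn(2, 16, 70, 3),  # (B, Nf, Np, 3)
        'ref_keypoint': torch.randn(2, 70, 3),            # (B, Np, 3)
        'ori_keypoint': torch.randn(2, 70),               # (B, Np)
        'aud_feat': torch.randn(2, 16, 512),              # (B, Nf, 512)
    }
    loss, loss_exp, loss_pose = model(batch, mode='train')
    print(f'loss={loss.item():.4f} exp={loss_exp.item():.4f} pose={loss_pose.item():.4f}')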