import torch
import torch.nn.functional as F


def add_first_frame_conditioning(
    latent_model_input,
    first_frame,
    vae
):
    """
    Adds first-frame conditioning to a video diffusion model input.

    Args:
        latent_model_input: Original latent input (bs, channels, num_frames, height, width)
        first_frame: First frame to condition on (bs, channels, height, width)
        vae: VAE model for encoding the conditioning video

    Returns:
        conditioned_latent: The complete conditioned latent input (bs, 36, num_frames, height, width)
    """
    device = latent_model_input.device
    dtype = latent_model_input.dtype
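
    # "temperal_downsample" (sic) is the attribute's actual, misspelled name in
    # diffusers' AutoencoderKLWan; each True entry halves the temporal resolution.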
    vae_scale_factor_temporal = 2 ** sum(vae.temperal_downsample)

    _, _, num_latent_frames, _, _ = latent_model_input.shape
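
    # Map latent frames back to pixel frames: the Wan VAE encodes the first
    # pixel frame on its own and each subsequent group of
    # vae_scale_factor_temporal (4 for Wan) frames into one latent frame.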
    num_frames = (num_latent_frames - 1) * vae_scale_factor_temporal + 1
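
    # Promote an unbatched (C, H, W) frame to a batch of one.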
    if first_frame.ndim == 3:
        first_frame = first_frame.unsqueeze(0)
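
    # Broadcast a single conditioning frame across the whole batch.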
    if first_frame.shape[0] != latent_model_input.shape[0]:
        first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1)
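
    # Resize to the pixel resolution implied by the latent's spatial size
    # (the VAE downsamples height and width by a factor of 8).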
    vae_scale_factor = 8
    first_frame = F.interpolate(
        first_frame,
        size=(latent_model_input.shape[3] * vae_scale_factor,
              latent_model_input.shape[4] * vae_scale_factor),
        mode='bilinear',
        align_corners=False
    )
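
    # Add a temporal axis: (bs, C, H, W) -> (bs, C, 1, H, W).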
    first_frame = first_frame.unsqueeze(2)
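
    # Build the conditioning video: the real first frame followed by
    # num_frames - 1 zero frames.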
    zero_frame = torch.zeros_like(first_frame)
    video_condition = torch.cat([
        first_frame,
        *[zero_frame for _ in range(num_frames - 1)]
    ], dim=2)
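
    # Encode the conditioning video to latents (a stochastic sample from the
    # VAE posterior), matching the model input's device and dtype.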
    latent_condition = vae.encode(
        video_condition.to(device, dtype)
    ).latent_dist.sample()
    latent_condition = latent_condition.to(device, dtype)

    batch_size = first_frame.shape[0]
    latent_height = latent_condition.shape[3]
    latent_width = latent_condition.shape[4]
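
    # Per-frame mask in pixel time: 1 marks the supplied first frame,
    # 0 marks the frames to be generated.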
    mask_lat_size = torch.ones(
        batch_size, 1, num_frames, latent_height, latent_width)
    mask_lat_size[:, :, 1:] = 0
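
    # The first pixel frame occupies a full temporal group in the latent
    # layout, so repeat its mask entry vae_scale_factor_temporal times
    # before re-attaching the masks for the remaining frames.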
    first_frame_mask = mask_lat_size[:, :, 0:1]
    first_frame_mask = torch.repeat_interleave(
        first_frame_mask, dim=2, repeats=vae_scale_factor_temporal)
    mask_lat_size = torch.concat(
        [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
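
    # Fold each temporal group into the channel dim so the mask lines up with
    # the latent frames:
    # (bs, 1, T * num_latent_frames, H, W) -> (bs, T, num_latent_frames, H, W).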
    mask_lat_size = mask_lat_size.view(
        batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
    mask_lat_size = mask_lat_size.transpose(1, 2)
    mask_lat_size = mask_lat_size.to(device, dtype)
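
    # Channel-concatenate the mask (4 channels) with the encoded condition
    # (16 channels), then append to the 16-channel model input for the
    # 36-channel conditioned latent described in the docstring.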
    first_frame_condition = torch.concat(
        [mask_lat_size, latent_condition], dim=1)
    conditioned_latent = torch.cat(
        [latent_model_input, first_frame_condition], dim=1)

    return conditioned_latent
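

if __name__ == "__main__":
    # Minimal usage sketch, assuming a diffusers AutoencoderKLWan and a
    # 16-channel Wan video latent. The checkpoint id and shapes below are
    # illustrative assumptions, not part of the original module.
    from diffusers import AutoencoderKLWan

    vae = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae",
        torch_dtype=torch.float32,
    )

    # 21 latent frames <=> (21 - 1) * 4 + 1 = 81 pixel frames; 480x832 pixels
    # map to a 60x104 latent grid under the VAE's 8x spatial downsampling.
    latents = torch.randn(1, 16, 21, 60, 104)
    first_frame = torch.rand(3, 480, 832)  # (C, H, W), scaled to the VAE's input range
    conditioned = add_first_frame_conditioning(latents, first_frame, vae)
    print(conditioned.shape)  # expected: torch.Size([1, 36, 21, 60, 104])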