|
from typing import Any, Dict, List, Tuple |
|
from pathlib import Path |
|
import os |
|
import hashlib |
|
import json |
|
import random |
|
import wandb |
|
import math |
|
import numpy as np |
|
from einops import rearrange, repeat |
|
from safetensors.torch import load_file, save_file |
|
from accelerate.logging import get_logger |
|
|
|
import torch |
|
|
|
from accelerate.utils import gather_object |
|
|
|
from diffusers import ( |
|
AutoencoderKLCogVideoX, |
|
CogVideoXDPMScheduler, |
|
CogVideoXImageToVideoPipeline, |
|
CogVideoXTransformer3DModel, |
|
) |
|
from diffusers.utils.export_utils import export_to_video |
|
|
|
from finetune.pipeline.flovd_FVSM_cogvideox_controlnet_pipeline import FloVDCogVideoXControlnetImageToVideoPipeline |
|
from finetune.constants import LOG_LEVEL, LOG_NAME |
|
|
|
from diffusers.models.embeddings import get_3d_rotary_pos_embed |
|
from PIL import Image |
|
|
from transformers import AutoTokenizer, T5EncoderModel |
|
from typing_extensions import override |
|
|
|
from finetune.schemas import Args, Components, State |
|
from finetune.trainer import Trainer |
|
from finetune.utils import ( |
|
cast_training_params, |
|
free_memory, |
|
get_memory_statistics, |
|
string_to_filename, |
|
unwrap_model, |
|
) |
|
from finetune.datasets.utils import (
    preprocess_image_with_resize,
    preprocess_video_with_resize,
    load_binary_mask_compressed,
)
|
|
|
from finetune.modules.cogvideox_controlnet import CogVideoXControlnet |
|
from finetune.modules.cogvideox_custom_model import CustomCogVideoXTransformer3DModel |
|
from finetune.modules.camera_sampler import SampleManualCam |
|
from finetune.modules.camera_flow_generator import CameraFlowGenerator |
|
from finetune.modules.utils import get_camera_flow_generator_input, forward_bilinear_splatting |
|
|
|
from ..utils import register |
|
|
|
import sys |
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
|
|
|
logger = get_logger(LOG_NAME, LOG_LEVEL) |
|
|
|
class FloVDCogVideoXI2VControlnetTrainer(Trainer): |
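    """Trainer for the FloVD FVSM ControlNet on top of the CogVideoX image-to-video model.

    The ControlNet consumes encoded optical-flow latents (controlnet_hidden_states) and its
    per-block outputs are injected into the base transformer via controlnet_states /
    controlnet_weights, so that generated videos follow the provided camera-induced flow.
    """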
|
UNLOAD_LIST = ["text_encoder"] |
|
|
|
@override |
|
def __init__(self, args: Args) -> None: |
|
super().__init__(args) |
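        # Manual camera-trajectory sampler; used in validate() ('inference' mode) to synthesize
        # camera flow from sampled camera parameters.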
|
|
|
|
|
self.CameraSampler = SampleManualCam() |
|
|
|
|
|
|
|
@override |
|
def load_components(self) -> Dict[str, Any]: |
|
|
|
components = Components() |
|
model_path = str(self.args.model_path) |
|
|
|
components.pipeline_cls = FloVDCogVideoXControlnetImageToVideoPipeline |
|
|
|
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer") |
|
|
|
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder") |
|
|
|
|
|
|
|
components.transformer = CustomCogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer") |
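        # The ControlNet branch is initialized from the pretrained transformer weights; the extra
        # kwargs below set its number of layers, output-projection dimensionality and zero-init,
        # and the notextinflow flag.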
|
|
|
additional_kwargs = { |
|
'num_layers': self.args.controlnet_transformer_num_layers, |
|
'out_proj_dim_factor': self.args.controlnet_out_proj_dim_factor, |
|
'out_proj_dim_zero_init': self.args.controlnet_out_proj_zero_init, |
|
'notextinflow': self.args.notextinflow, |
|
} |
|
components.controlnet = CogVideoXControlnet.from_pretrained(model_path, subfolder="transformer", **additional_kwargs) |
|
|
|
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae") |
|
|
|
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler") |
|
|
|
return components |
|
|
|
|
|
@override |
|
def initialize_pipeline(self) -> FloVDCogVideoXControlnetImageToVideoPipeline: |
|
|
|
pipe = FloVDCogVideoXControlnetImageToVideoPipeline( |
|
tokenizer=self.components.tokenizer, |
|
text_encoder=unwrap_model(self.accelerator, self.components.text_encoder), |
|
vae=unwrap_model(self.accelerator, self.components.vae), |
|
transformer=unwrap_model(self.accelerator, self.components.transformer), |
|
controlnet=unwrap_model(self.accelerator, self.components.controlnet), |
|
scheduler=self.components.scheduler, |
|
) |
|
return pipe |
|
|
|
def initialize_flow_generator(self, ckpt_path): |
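        # Build a camera-flow generator that warps a monocular metric-depth estimate of the first
        # frame according to sampled camera parameters (depth model configured below).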
|
depth_estimator_kwargs = { |
|
"target": 'modules.depth_warping.depth_warping.DepthWarping_wrapper', |
|
"kwargs": { |
|
"ckpt_path": ckpt_path, |
|
"model_config": { |
|
"max_depth": 20, |
|
"encoder": 'vitb', |
|
"features": 128, |
|
"out_channels": [96, 192, 384, 768], |
|
} |
|
|
|
} |
|
} |
|
|
|
return CameraFlowGenerator(depth_estimator_kwargs) |
|
|
|
@override |
|
def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]: |
|
ret = {"encoded_videos": [], "prompt_embedding": [], "images": [], "encoded_flow": []} |
|
|
|
for sample in samples: |
|
encoded_video = sample["encoded_video"] |
|
prompt_embedding = sample["prompt_embedding"] |
|
image = sample["image"] |
|
encoded_flow = sample["encoded_flow"] |
|
|
|
ret["encoded_videos"].append(encoded_video) |
|
ret["prompt_embedding"].append(prompt_embedding) |
|
ret["images"].append(image) |
|
ret["encoded_flow"].append(encoded_flow) |
|
|
|
|
|
ret["encoded_videos"] = torch.stack(ret["encoded_videos"]) |
|
ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"]) |
|
ret["images"] = torch.stack(ret["images"]) |
|
ret["encoded_flow"] = torch.stack(ret["encoded_flow"]) |
|
|
|
return ret |
|
|
|
|
|
@override |
|
def compute_loss(self, batch) -> torch.Tensor: |
|
prompt_embedding = batch["prompt_embedding"] |
|
latent = batch["encoded_videos"] |
|
images = batch["images"] |
|
latent_flow = batch["encoded_flow"] |
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Pad along the frame axis by repeating the first frame so the frame count is divisible by patch_size_t.
        patch_size_t = self.state.transformer_config.patch_size_t
        if patch_size_t is not None:
            ncopy = (patch_size_t - latent.shape[2] % patch_size_t) % patch_size_t
            first_frame = latent[:, :, :1, :, :]
            latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
            assert latent.shape[2] % patch_size_t == 0
|
|
|
batch_size, num_channels, num_frames, height, width = latent.shape |
|
|
|
|
|
_, seq_len, _ = prompt_embedding.shape |
|
prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype) |
|
|
|
|
|
images = images.unsqueeze(2) |
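        # Noise augmentation of the conditioning image before VAE encoding (sigma drawn from a
        # log-normal distribution), as in the CogVideoX I2V training recipe.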
|
|
|
image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device) |
|
image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype) |
|
noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None] |
|
image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist |
|
image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor |
|
|
|
""" |
|
Modify below |
|
""" |
|
|
|
|
|
|
|
|
|
if self.args.enable_time_sampling: |
|
if self.args.time_sampling_type == "truncated_normal": |
|
time_sampling_dict = { |
|
'mean': self.args.time_sampling_mean, |
|
'std': self.args.time_sampling_std, |
|
'a': 1 - self.args.controlnet_guidance_end, |
|
'b': 1 - self.args.controlnet_guidance_start, |
|
} |
|
timesteps = torch.nn.init.trunc_normal_( |
|
torch.empty(batch_size, device=latent.device), **time_sampling_dict |
|
) * self.components.scheduler.config.num_train_timesteps |
|
elif self.args.time_sampling_type == "truncated_uniform": |
|
timesteps = torch.randint( |
|
int((1- self.args.controlnet_guidance_end) * self.components.scheduler.config.num_train_timesteps), |
|
int((1 - self.args.controlnet_guidance_start) * self.components.scheduler.config.num_train_timesteps), |
|
(batch_size,), device=latent.device |
|
) |
|
else: |
|
timesteps = torch.randint( |
|
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device |
|
) |
|
timesteps = timesteps.long() |
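        # Rearrange latents from (B, C, F, H, W) to (B, F, C, H, W); the single-frame image latent
        # is zero-padded along the frame axis and later concatenated with the noisy video latent
        # along the channel axis.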
|
|
|
|
|
latent = latent.permute(0, 2, 1, 3, 4) |
|
latent_flow = latent_flow.permute(0, 2, 1, 3, 4) |
|
image_latents = image_latents.permute(0, 2, 1, 3, 4) |
|
assert (latent.shape[0], *latent.shape[2:]) == (image_latents.shape[0], *image_latents.shape[2:]) == (latent_flow.shape[0], *latent_flow.shape[2:]) |
|
|
|
|
|
padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:]) |
|
latent_padding = image_latents.new_zeros(padding_shape) |
|
image_latents = torch.cat([image_latents, latent_padding], dim=1) |
|
|
|
|
|
noise = torch.randn_like(latent) |
|
latent_noisy = self.components.scheduler.add_noise(latent, noise, timesteps) |
|
|
|
|
|
|
|
|
|
latent_img_noisy = torch.cat([latent_noisy, image_latents], dim=2) |
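        # Prepare 3D rotary positional embeddings when the transformer config enables them.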
|
|
|
|
|
vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1) |
|
transformer_config = self.state.transformer_config |
|
rotary_emb = ( |
|
self.prepare_rotary_positional_embeddings( |
|
height=height * vae_scale_factor_spatial, |
|
width=width * vae_scale_factor_spatial, |
|
num_frames=num_frames, |
|
transformer_config=transformer_config, |
|
vae_scale_factor_spatial=vae_scale_factor_spatial, |
|
device=self.accelerator.device, |
|
) |
|
if transformer_config.use_rotary_positional_embeddings |
|
else None |
|
) |
|
|
|
|
|
ofs_emb = ( |
|
None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0) |
|
) |
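        # ControlNet forward: conditioned on the flow latents, it returns per-block hidden states
        # that are passed to the transformer as controlnet_states (scaled by controlnet_weights).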
|
|
|
|
|
controlnet_states = self.components.controlnet( |
|
hidden_states=latent_noisy, |
|
encoder_hidden_states=prompt_embedding, |
|
image_rotary_emb=rotary_emb, |
|
controlnet_hidden_states=latent_flow, |
|
timestep=timesteps, |
|
return_dict=False, |
|
)[0] |
|
if isinstance(controlnet_states, (tuple, list)): |
|
controlnet_states = [x.to(dtype=self.state.weight_dtype) for x in controlnet_states] |
|
else: |
|
controlnet_states = controlnet_states.to(dtype=self.state.weight_dtype) |
|
|
|
|
|
|
|
predicted_noise = self.components.transformer( |
|
hidden_states=latent_img_noisy, |
|
encoder_hidden_states=prompt_embedding, |
|
controlnet_states=controlnet_states, |
|
controlnet_weights=self.args.controlnet_weights, |
|
timestep=timesteps, |
|
|
|
image_rotary_emb=rotary_emb, |
|
return_dict=False, |
|
)[0] |
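        # CogVideoX is trained with v-prediction: recover the denoised latent estimate from the
        # model output via scheduler.get_velocity, then weight the MSE against the clean latent
        # by 1 / (1 - alpha_cumprod).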
|
|
|
|
|
|
|
latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_noisy, timesteps) |
|
|
|
alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps] |
|
weights = 1 / (1 - alphas_cumprod) |
|
while len(weights.shape) < len(latent_pred.shape): |
|
weights = weights.unsqueeze(-1) |
|
|
|
loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1) |
|
loss = loss.mean() |
|
|
|
return loss |
|
|
|
def prepare_rotary_positional_embeddings( |
|
self, |
|
height: int, |
|
width: int, |
|
num_frames: int, |
|
transformer_config: Dict, |
|
vae_scale_factor_spatial: int, |
|
device: torch.device, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size) |
|
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size) |
|
|
|
if transformer_config.patch_size_t is None: |
|
base_num_frames = num_frames |
|
else: |
|
base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t |
|
|
|
freqs_cos, freqs_sin = get_3d_rotary_pos_embed( |
|
embed_dim=transformer_config.attention_head_dim, |
|
crops_coords=None, |
|
grid_size=(grid_height, grid_width), |
|
temporal_size=base_num_frames, |
|
grid_type="slice", |
|
max_size=(grid_height, grid_width), |
|
device=device, |
|
) |
|
|
|
return freqs_cos, freqs_sin |
|
|
|
|
|
|
|
@override |
|
def prepare_for_validation(self): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_root = self.args.data_root |
|
metadata_path = data_root / "metadata_revised.jsonl" |
|
        assert metadata_path.is_file(), "For this dataset type, metadata_revised.jsonl must exist in the data root"
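        # Each metadata entry references precomputed assets keyed by hash_code:
        #   prompt_embeddings_revised/<hash>.safetensors, video_latent/<resolution>/<hash>.safetensors,
        #   first_frames/<hash>.png, and flow_direct_f_latent/<hash>.safetensors.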
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
metadata = [] |
|
with open(metadata_path, "r") as f: |
|
for line in f: |
|
metadata.append( json.loads(line) ) |
|
|
|
        # Subsample at most max_scene scenes for validation.
        metadata = random.sample(metadata, min(self.args.max_scene, len(metadata)))
|
|
|
prompts = [x["prompt"] for x in metadata] |
|
prompt_embeddings = [data_root / "prompt_embeddings_revised" / (x["hash_code"] + '.safetensors') for x in metadata] |
|
videos = [data_root / "video_latent" / "x".join(str(x) for x in self.args.train_resolution) / (x["hash_code"] + '.safetensors') for x in metadata] |
|
images = [data_root / "first_frames" / (x["hash_code"] + '.png') for x in metadata] |
|
flows = [data_root / "flow_direct_f_latent" / (x["hash_code"] + '.safetensors') for x in metadata] |
|
|
|
|
|
validation_prompts = [] |
|
validation_prompt_embeddings = [] |
|
validation_video_latents = [] |
|
validation_images = [] |
|
validation_flow_latents = [] |
|
for prompt, prompt_embedding, video_latent, image, flow_latent in zip(prompts, prompt_embeddings, videos, images, flows): |
|
validation_prompts.append(prompt) |
|
validation_prompt_embeddings.append(load_file(prompt_embedding)["prompt_embedding"].unsqueeze(0)) |
|
validation_video_latents.append(load_file(video_latent)["encoded_video"].unsqueeze(0)) |
|
validation_flow_latents.append(load_file(flow_latent)["encoded_flow_f"].unsqueeze(0)) |
|
|
|
validation_images.append(image) |
|
|
|
|
|
validation_videos = [None] * len(validation_prompts) |
|
|
|
|
|
self.state.validation_prompts = validation_prompts |
|
self.state.validation_prompt_embeddings = validation_prompt_embeddings |
|
self.state.validation_images = validation_images |
|
self.state.validation_videos = validation_videos |
|
self.state.validation_video_latents = validation_video_latents |
|
self.state.validation_flow_latents = validation_flow_latents |
|
|
|
|
|
|
|
|
|
|
|
@override |
|
def validation_step( |
|
self, eval_data: Dict[str, Any], pipe: FloVDCogVideoXControlnetImageToVideoPipeline |
|
) -> List[Tuple[str, Image.Image | List[Image.Image]]]: |
|
""" |
|
Return the data that needs to be saved. For videos, the data format is List[PIL], |
|
and for images, the data format is PIL |
|
""" |
|
|
|
prompt_embedding, image, flow_latent = eval_data["prompt_embedding"], eval_data["image"], eval_data["flow_latent"] |
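        # Generate a video with the FloVD ControlNet pipeline from the precomputed prompt embedding,
        # the first frame, and the flow latent; ControlNet guidance is active only within
        # [controlnet_guidance_start, controlnet_guidance_end].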
|
|
|
video_generate = pipe( |
|
num_frames=self.state.train_frames, |
|
height=self.state.train_height, |
|
width=self.state.train_width, |
|
prompt=None, |
|
prompt_embeds=prompt_embedding, |
|
image=image, |
|
flow_latent=flow_latent, |
|
generator=self.state.generator, |
|
num_inference_steps=50, |
|
            controlnet_guidance_start=self.args.controlnet_guidance_start,
            controlnet_guidance_end=self.args.controlnet_guidance_end,
|
).frames[0] |
|
return [("synthesized_video", video_generate)] |
|
|
|
|
|
@override |
|
def validate(self, step: int) -> None: |
|
|
|
logger.info("Starting validation") |
|
|
|
accelerator = self.accelerator |
|
num_validation_samples = len(self.state.validation_prompts) |
|
|
|
if num_validation_samples == 0: |
|
logger.warning("No validation samples found. Skipping validation.") |
|
return |
|
|
|
self.components.controlnet.eval() |
|
torch.set_grad_enabled(False) |
|
|
|
memory_statistics = get_memory_statistics() |
|
logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}") |
|
|
|
|
|
pipe = self.initialize_pipeline() |
|
camera_flow_generator = self.initialize_flow_generator(ckpt_path=self.args.depth_ckpt_path).to(device=self.accelerator.device, dtype=self.state.weight_dtype) |
|
|
|
if self.state.using_deepspeed: |
|
|
|
|
|
|
|
self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["controlnet"]) |
|
|
|
else: |
|
|
|
|
|
pipe.enable_model_cpu_offload(device=self.accelerator.device) |
|
|
|
|
|
|
|
pipe = pipe.to(dtype=self.state.weight_dtype) |
|
|
|
|
|
|
|
inference_type = ['training', 'inference'] |
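        # Two validation modes:
        #   'training'  - reuse the dataset's flow latents (and decode them for visualization)
        #   'inference' - sample a camera trajectory and synthesize flow with the camera-flow generator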
|
|
|
for infer_type in inference_type: |
|
|
|
|
|
all_processes_artifacts = [] |
|
for i in range(num_validation_samples): |
|
if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3: |
|
|
|
if i % accelerator.num_processes != accelerator.process_index: |
|
continue |
|
|
|
prompt = self.state.validation_prompts[i] |
|
image = self.state.validation_images[i] |
|
video = self.state.validation_videos[i] |
|
video_latent = self.state.validation_video_latents[i].permute(0,2,1,3,4) |
|
prompt_embedding = self.state.validation_prompt_embeddings[i] |
|
flow_latent = self.state.validation_flow_latents[i].permute(0,2,1,3,4) |
|
|
|
|
|
if image is not None: |
|
image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width) |
|
image_torch = image.detach().clone() |
|
|
|
image = image.to(torch.uint8) |
|
image = image.permute(1, 2, 0).cpu().numpy() |
|
image = Image.fromarray(image) |
|
|
|
if video is not None: |
|
video = preprocess_video_with_resize( |
|
video, self.state.train_frames, self.state.train_height, self.state.train_width |
|
) |
|
|
|
video = video.round().clamp(0, 255).to(torch.uint8) |
|
video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video] |
|
else: |
|
                    if infer_type == 'training':
                        # Decode the stored video and flow latents for visualization / warped-video logging.
                        with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
                            try:
                                video_decoded = decode_latents(video_latent.to(self.accelerator.device), self.components.vae)
                            except Exception:
                                # Retry once; the first VAE decode can fail transiently.
                                video_decoded = decode_latents(video_latent.to(self.accelerator.device), self.components.vae)
                            video = ((video_decoded + 1.) / 2. * 255.)[0].permute(1,0,2,3).float().clip(0., 255.).to(torch.uint8)
                            video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]

                        with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
                            try:
                                flow_decoded = decode_flow(flow_latent.to(self.accelerator.device), self.components.vae, flow_scale_factor=[60, 36])
                            except Exception:
                                # Retry once on a transient decode failure.
                                flow_decoded = decode_flow(flow_latent.to(self.accelerator.device), self.components.vae, flow_scale_factor=[60, 36])
|
|
|
|
|
|
|
                    if infer_type == 'inference':
                        # Sample a manual camera trajectory, synthesize the corresponding camera flow,
                        # and encode it to a flow latent for the ControlNet.
                        with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
                            camparam, cam_name = self.CameraSampler.sample()
                            camera_flow_generator_input = get_camera_flow_generator_input(image_torch, camparam, device=self.accelerator.device, speed=0.5)
                            image_torch = ((image_torch.unsqueeze(0) / 255.) * 2. - 1.).to(self.accelerator.device)
                            camera_flow, log_dict = camera_flow_generator(image_torch, camera_flow_generator_input)
                            camera_flow = camera_flow.to(self.accelerator.device)

                            try:
                                flow_latent = rearrange(encode_flow(camera_flow, self.components.vae, flow_scale_factor=[60, 36]), 'b c f h w -> b f c h w').to(self.accelerator.device, self.state.weight_dtype)
                            except Exception:
                                # Retry once on a transient encode failure.
                                flow_latent = rearrange(encode_flow(camera_flow, self.components.vae, flow_scale_factor=[60, 36]), 'b c f h w -> b f c h w').to(self.accelerator.device, self.state.weight_dtype)
|
|
|
|
|
logger.debug( |
|
f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}", |
|
main_process_only=False, |
|
) |
|
|
|
validation_artifacts = self.validation_step({"prompt_embedding": prompt_embedding, "image": image, "flow_latent": flow_latent}, pipe) |
|
|
|
if ( |
|
self.state.using_deepspeed |
|
and self.accelerator.deepspeed_plugin.zero_stage == 3 |
|
and not accelerator.is_main_process |
|
): |
|
continue |
|
|
|
prompt_filename = string_to_filename(prompt)[:25] |
|
|
|
reversed_prompt = prompt[::-1] |
|
hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5] |
|
|
|
artifacts = { |
|
"image": {"type": "image", "value": image}, |
|
"video": {"type": "video", "value": video}, |
|
} |
|
for i, (artifact_type, artifact_value) in enumerate(validation_artifacts): |
|
artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}}) |
|
if infer_type == 'training': |
|
|
|
image_tensor = repeat(rearrange(torch.tensor(np.array(image)).to(flow_decoded.device, torch.float), 'h w c -> 1 c h w'), 'b c h w -> (b f) c h w', f=flow_decoded.size(0)) |
|
warped_video = forward_bilinear_splatting(image_tensor, flow_decoded.to(torch.float)) |
|
frame_list = [] |
|
for frame in warped_video: |
|
frame = (frame.permute(1,2,0).float().detach().cpu().numpy()).astype(np.uint8).clip(0,255) |
|
frame_list.append(Image.fromarray(frame)) |
|
|
|
artifacts.update({f"artifact_warped_video_{i}": {"type": 'warped_video', "value": frame_list}}) |
|
|
|
if infer_type == 'inference': |
|
warped_video = log_dict['depth_warped_frames'] |
|
frame_list = [] |
|
for frame in warped_video: |
|
frame = (frame + 1.)/2. * 255. |
|
frame = (frame.permute(1,2,0).float().detach().cpu().numpy()).astype(np.uint8).clip(0,255) |
|
frame_list.append(Image.fromarray(frame)) |
|
|
|
artifacts.update({f"artifact_warped_video_{i}": {"type": 'warped_video', "value": frame_list}}) |
|
logger.debug( |
|
f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}", |
|
main_process_only=False, |
|
) |
|
|
|
for key, value in list(artifacts.items()): |
|
artifact_type = value["type"] |
|
artifact_value = value["value"] |
|
if artifact_type not in ["image", "video", "warped_video", "synthesized_video"] or artifact_value is None: |
|
continue |
|
|
|
extension = "png" if artifact_type == "image" else "mp4" |
|
if artifact_type == "warped_video": |
|
filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}-{infer_type}_warped_video.{extension}" |
|
elif artifact_type == "synthesized_video": |
|
filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}-{infer_type}_synthesized_video.{extension}" |
|
else: |
|
filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}-{infer_type}.{extension}" |
|
validation_path = self.args.output_dir / "validation_res" |
|
validation_path.mkdir(parents=True, exist_ok=True) |
|
filename = str(validation_path / filename) |
|
|
|
if artifact_type == "image": |
|
logger.debug(f"Saving image to {filename}") |
|
artifact_value.save(filename) |
|
artifact_value = wandb.Image(filename) |
|
elif artifact_type == "video" or artifact_type == "warped_video" or artifact_type == "synthesized_video": |
|
logger.debug(f"Saving video to {filename}") |
|
export_to_video(artifact_value, filename, fps=self.args.gen_fps) |
|
artifact_value = wandb.Video(filename, caption=prompt) |
|
|
|
all_processes_artifacts.append(artifact_value) |
|
|
|
all_artifacts = gather_object(all_processes_artifacts) |
|
|
|
if accelerator.is_main_process: |
|
tracker_key = "validation" |
|
for tracker in accelerator.trackers: |
|
if tracker.name == "wandb": |
|
image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)] |
|
video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)] |
|
tracker.log( |
|
{ |
|
tracker_key: {f"images_{infer_type}": image_artifacts, f"videos_{infer_type}": video_artifacts}, |
|
}, |
|
step=step, |
|
) |
|
|
|
|
|
if self.state.using_deepspeed: |
|
del pipe |
|
|
|
self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST) |
|
else: |
|
pipe.remove_all_hooks() |
|
del pipe |
|
|
|
self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST) |
|
self.components.controlnet.to(self.accelerator.device, dtype=self.state.weight_dtype) |
|
|
|
|
|
cast_training_params([self.components.controlnet], dtype=torch.float32) |
|
|
|
del camera_flow_generator |
|
|
|
free_memory() |
|
accelerator.wait_for_everyone() |
|
|
|
|
|
memory_statistics = get_memory_statistics() |
|
logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}") |
|
torch.cuda.reset_peak_memory_stats(accelerator.device) |
|
|
|
torch.set_grad_enabled(True) |
|
self.components.controlnet.train() |
|
|
|
|
|
|
|
def __move_components_to_device(self, dtype, ignore_list: List[str] = []): |
|
ignore_list = set(ignore_list) |
|
components = self.components.model_dump() |
|
for name, component in components.items(): |
|
if not isinstance(component, type) and hasattr(component, "to"): |
|
if name not in ignore_list: |
|
setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype)) |
|
|
|
|
|
def __move_components_to_cpu(self, unload_list: List[str] = []): |
|
unload_list = set(unload_list) |
|
components = self.components.model_dump() |
|
for name, component in components.items(): |
|
if not isinstance(component, type) and hasattr(component, "to"): |
|
if name in unload_list: |
|
setattr(self.components, name, component.to("cpu")) |
|
|
|
|
|
register("cogvideox-flovd", "controlnet", FloVDCogVideoXI2VControlnetTrainer) |
|
|
|
|
|
|
|
|
|
def encode_text(prompt: str, components, device) -> torch.Tensor: |
|
prompt_token_ids = components.tokenizer( |
|
prompt, |
|
padding="max_length", |
|
max_length=components.transformer.config.max_text_seq_length, |
|
truncation=True, |
|
add_special_tokens=True, |
|
return_tensors="pt", |
|
) |
|
prompt_token_ids = prompt_token_ids.input_ids |
|
prompt_embedding = components.text_encoder(prompt_token_ids.to(device))[0] |
|
return prompt_embedding |
|
|
|
def encode_video(video: torch.Tensor, vae) -> torch.Tensor: |
|
|
|
video = video.to(vae.device, dtype=vae.dtype) |
|
latent_dist = vae.encode(video).latent_dist |
|
latent = latent_dist.sample() * vae.config.scaling_factor |
|
return latent |
|
|
|
def decode_latents(latents: torch.Tensor, vae) -> torch.Tensor: |
|
latents = latents.permute(0, 2, 1, 3, 4) |
|
latents = 1 / vae.config.scaling_factor * latents |
|
|
|
frames = vae.decode(latents).sample |
|
return frames |
|
|
|
def compute_optical_flow(raft, ctxt, trgt, raft_iter=20, chunk=2, only_forward=True): |
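    # Run RAFT forward (and optionally backward) flow in `chunk` pieces to bound GPU memory.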
|
num_frames = ctxt.shape[0] |
|
chunk_size = (num_frames // chunk) + 1 |
|
|
|
flow_f_list = [] |
|
if not only_forward: |
|
flow_b_list = [] |
|
for i in range(chunk): |
|
start = chunk_size * i |
|
end = chunk_size * (i+1) |
|
|
|
with torch.no_grad(): |
|
flow_f = raft(ctxt[start:end], trgt[start:end], num_flow_updates=raft_iter)[-1] |
|
if not only_forward: |
|
flow_b = raft(trgt[start:end], ctxt[start:end], num_flow_updates=raft_iter)[-1] |
|
|
|
flow_f_list.append(flow_f) |
|
if not only_forward: |
|
flow_b_list.append(flow_b) |
|
|
|
flow_f = torch.cat(flow_f_list) |
|
if not only_forward: |
|
flow_b = torch.cat(flow_b_list) |
|
|
|
if not only_forward: |
|
return flow_f, flow_b |
|
else: |
|
return flow_f, None |
|
|
|
def encode_flow(flow, vae, flow_scale_factor): |
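    # flow: (F, 2, H, W) optical flow. The two flow channels are adaptively normalized, a third
    # channel (the mean of the two) is appended so the tensor matches the VAE's 3-channel input,
    # and the result is encoded as a video latent.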
|
|
|
|
|
assert flow.ndim == 4 |
|
num_frames, _, height, width = flow.shape |
|
|
|
|
|
|
|
flow = rearrange(flow, '(b f) c h w -> b f c h w', b=1) |
|
flow_norm = adaptive_normalize(flow, flow_scale_factor[0], flow_scale_factor[1]) |
|
|
|
|
|
flow_norm = rearrange(flow_norm, 'b f c h w -> (b f) c h w', b=1) |
|
|
|
|
|
num_frames, _, H, W = flow_norm.shape |
|
flow_norm_extended = torch.empty((num_frames, 3, height, width)).to(flow_norm) |
|
flow_norm_extended[:,:2] = flow_norm |
|
flow_norm_extended[:,-1:] = flow_norm.mean(dim=1, keepdim=True) |
|
flow_norm_extended = rearrange(flow_norm_extended, '(b f) c h w -> b c f h w', f=num_frames) |
|
|
|
return encode_video(flow_norm_extended, vae) |
|
|
|
def decode_flow(flow_latent, vae, flow_scale_factor): |
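    # Inverse of encode_flow: decode the flow latent with the VAE, keep the first two channels,
    # and un-normalize back to flow of shape (F, 2, H, W).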
|
flow_latent = flow_latent.permute(0, 2, 1, 3, 4) |
|
flow_latent = 1 / vae.config.scaling_factor * flow_latent |
|
|
|
flow = vae.decode(flow_latent).sample |
|
|
|
|
|
flow = flow[:,:2].detach().clone() |
|
|
|
|
|
flow = rearrange(flow, 'b c f h w -> b f c h w') |
|
flow = adaptive_unnormalize(flow, flow_scale_factor[0], flow_scale_factor[1]) |
|
|
|
flow = rearrange(flow, 'b f c h w -> (b f) c h w') |
|
return flow |
|
|
|
def adaptive_normalize(flow, sf_x, sf_y): |
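    # Signed-sqrt compression of the flow: x -> sign(x) * sqrt(|x| / sf), clamped to
    # +/- sqrt(W / sf_x) horizontally and +/- sqrt(H / sf_y) vertically.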
|
|
|
    assert flow.ndim == 5, 'Flow input must have shape (B, F, C, H, W)'
|
assert sf_x is not None and sf_y is not None |
|
b, f, c, h, w = flow.shape |
|
|
|
max_clip_x = math.sqrt(w/sf_x) * 1.0 |
|
max_clip_y = math.sqrt(h/sf_y) * 1.0 |
|
|
|
flow_norm = flow.detach().clone() |
|
flow_x = flow[:, :, 0].detach().clone() |
|
flow_y = flow[:, :, 1].detach().clone() |
|
|
|
flow_x_norm = torch.sign(flow_x) * torch.sqrt(torch.abs(flow_x)/sf_x + 1e-7) |
|
flow_y_norm = torch.sign(flow_y) * torch.sqrt(torch.abs(flow_y)/sf_y + 1e-7) |
|
|
|
flow_norm[:, :, 0] = torch.clamp(flow_x_norm, min=-max_clip_x, max=max_clip_x) |
|
flow_norm[:, :, 1] = torch.clamp(flow_y_norm, min=-max_clip_y, max=max_clip_y) |
|
|
|
return flow_norm |
|
|
|
|
|
def adaptive_unnormalize(flow, sf_x, sf_y): |
|
|
|
    assert flow.ndim == 5, 'Flow input must have shape (B, F, C, H, W)'
|
assert sf_x is not None and sf_y is not None |
|
|
|
flow_orig = flow.detach().clone() |
|
flow_x = flow[:, :, 0].detach().clone() |
|
flow_y = flow[:, :, 1].detach().clone() |
|
|
|
flow_orig[:, :, 0] = torch.sign(flow_x) * sf_x * (flow_x**2 - 1e-7) |
|
flow_orig[:, :, 1] = torch.sign(flow_y) * sf_y * (flow_y**2 - 1e-7) |
|
|
|
return flow_orig |
|
|
|
|
|
|