from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from accelerate.logging import get_logger
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer


logger = get_logger("finetrainers")


def load_condition_models(
    model_id: str = "Lightricks/LTX-Video",
    text_encoder_dtype: torch.dtype = torch.bfloat16,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    **kwargs,
) -> Dict[str, nn.Module]:
    tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
    text_encoder = T5EncoderModel.from_pretrained(
        model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
    )
    return {"tokenizer": tokenizer, "text_encoder": text_encoder}


def load_latent_models(
    model_id: str = "Lightricks/LTX-Video",
    vae_dtype: torch.dtype = torch.bfloat16,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    **kwargs,
) -> Dict[str, nn.Module]:
    vae = AutoencoderKLLTXVideo.from_pretrained(
        model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
    )
    return {"vae": vae}


def load_diffusion_models(
    model_id: str = "Lightricks/LTX-Video",
    transformer_dtype: torch.dtype = torch.bfloat16,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    **kwargs,
) -> Dict[str, nn.Module]:
    transformer = LTXVideoTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
    )
    scheduler = FlowMatchEulerDiscreteScheduler()
    return {"transformer": transformer, "scheduler": scheduler}


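# The loaders above return plain dicts so a trainer can pull components in
# independently (e.g. to precompute text embeddings and latents in separate
# passes). A hedged sketch of loading everything at once with non-default
# dtypes:
#
#     components = {
#         **load_condition_models(text_encoder_dtype=torch.float16),
#         **load_latent_models(vae_dtype=torch.float32),
#         **load_diffusion_models(transformer_dtype=torch.bfloat16),
#     }

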
def initialize_pipeline(
    model_id: str = "Lightricks/LTX-Video",
    text_encoder_dtype: torch.dtype = torch.bfloat16,
    transformer_dtype: torch.dtype = torch.bfloat16,
    vae_dtype: torch.dtype = torch.bfloat16,
    tokenizer: Optional[T5Tokenizer] = None,
    text_encoder: Optional[T5EncoderModel] = None,
    transformer: Optional[LTXVideoTransformer3DModel] = None,
    vae: Optional[AutoencoderKLLTXVideo] = None,
    scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
    device: Optional[torch.device] = None,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    enable_slicing: bool = False,
    enable_tiling: bool = False,
    enable_model_cpu_offload: bool = False,
    is_training: bool = False,
    **kwargs,
) -> LTXPipeline:
    component_name_pairs = [
        ("tokenizer", tokenizer),
        ("text_encoder", text_encoder),
        ("transformer", transformer),
        ("vae", vae),
        ("scheduler", scheduler),
    ]
    components = {}
    for name, component in component_name_pairs:
        if component is not None:
            components[name] = component

    pipe = LTXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
    pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
    pipe.vae = pipe.vae.to(dtype=vae_dtype)

    # The transformer is expected to already be in the desired dtype while training,
    # so only cast it for inference. Casting during training can interfere with
    # mixed-precision or layerwise-upcasting setups that manage the dtype themselves.
    if not is_training:
        pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)

    if enable_slicing:
        pipe.vae.enable_slicing()
    if enable_tiling:
        pipe.vae.enable_tiling()

    if enable_model_cpu_offload:
        pipe.enable_model_cpu_offload(device=device)
    else:
        pipe.to(device=device)

    return pipe


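# A hedged sketch of spinning up a validation pipeline from components already
# in memory during training (all variable names here are illustrative):
#
#     pipe = initialize_pipeline(
#         tokenizer=tokenizer,
#         text_encoder=text_encoder,
#         transformer=transformer,
#         vae=vae,
#         device=torch.device("cuda"),
#         enable_slicing=True,
#         enable_tiling=True,
#         is_training=True,
#     )
#     # Components passed in are reused as-is; anything omitted (here, the
#     # scheduler) is loaded from the hub checkpoint.

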
def prepare_conditions(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    max_sequence_length: int = 128,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    device = device or text_encoder.device
    dtype = dtype or text_encoder.dtype

    if isinstance(prompt, str):
        prompt = [prompt]

    return _encode_prompt_t5(tokenizer, text_encoder, prompt, device, dtype, max_sequence_length)


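# For a single prompt, the returned dict holds "prompt_embeds" of shape
# [1, max_sequence_length, hidden_dim] (4096 for the T5-XXL encoder used by
# LTX-Video) and a boolean "prompt_attention_mask" of shape
# [1, max_sequence_length]. A hedged sketch:
#
#     conditions = prepare_conditions(tokenizer, text_encoder, "a cat playing piano")
#     conditions["prompt_embeds"].shape          # torch.Size([1, 128, 4096])
#     conditions["prompt_attention_mask"].shape  # torch.Size([1, 128])

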
def prepare_latents(
    vae: AutoencoderKLLTXVideo,
    image_or_video: torch.Tensor,
    patch_size: int = 1,
    patch_size_t: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    generator: Optional[torch.Generator] = None,
    precompute: bool = False,
) -> Dict[str, Union[torch.Tensor, int]]:
    device = device or vae.device

    # Treat a single image batch [B, C, H, W] as a one-frame video [B, 1, C, H, W].
    if image_or_video.ndim == 4:
        image_or_video = image_or_video.unsqueeze(1)
    assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"

    image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
    image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]
    if not precompute:
        latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
        latents = latents.to(dtype=dtype)
        _, _, num_frames, height, width = latents.shape
        latents = _normalize_latents(latents, vae.latents_mean, vae.latents_std)
        latents = _pack_latents(latents, patch_size, patch_size_t)
        return {"latents": latents, "num_frames": num_frames, "height": height, "width": width}
    else:
        # When precomputing, return the raw encoder output and the normalization
        # statistics, and defer sampling, normalization and packing to
        # post_latent_preparation() below.
        if vae.use_slicing and image_or_video.shape[0] > 1:
            encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = vae._encode(image_or_video)
        _, _, num_frames, height, width = h.shape

        return {
            "latents": h,
            "num_frames": num_frames,
            "height": height,
            "width": width,
            "latents_mean": vae.latents_mean,
            "latents_std": vae.latents_std,
        }


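# prepare_latents() has two modes. With precompute=False it returns packed,
# normalized latents ready for the forward pass. With precompute=True it
# returns the raw VAE encoder output plus the normalization statistics, so the
# expensive encode can be cached; the trainer is then expected to sample from
# the cached moments before handing everything to post_latent_preparation().
# A hedged sketch (DiagonalGaussianDistribution lives in
# diffusers.models.autoencoders.vae):
#
#     out = prepare_latents(vae, video, precompute=True)  # cache this to disk
#     dist = DiagonalGaussianDistribution(out["latents"])
#     out["latents"] = dist.sample()
#     packed = post_latent_preparation(**out)             # normalize + pack later

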
def post_latent_preparation(
    latents: torch.Tensor,
    latents_mean: torch.Tensor,
    latents_std: torch.Tensor,
    num_frames: int,
    height: int,
    width: int,
    patch_size: int = 1,
    patch_size_t: int = 1,
    **kwargs,
) -> Dict[str, Union[torch.Tensor, int]]:
    latents = _normalize_latents(latents, latents_mean, latents_std)
    latents = _pack_latents(latents, patch_size, patch_size_t)
    return {"latents": latents, "num_frames": num_frames, "height": height, "width": width}


def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, Union[torch.Tensor, List[str]]]:
    return {
        "prompts": [x["prompt"] for x in batch[0]],
        "videos": torch.stack([x["video"] for x in batch[0]]),
    }


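# The dataloader is assumed to deliver one pre-batched list of samples, hence
# the `batch[0]` indexing. A hedged sketch of the expected structure (frame
# count and resolution are illustrative):
#
#     batch = [[{"prompt": "a red fox", "video": torch.rand(49, 3, 512, 768)},
#               {"prompt": "ocean waves", "video": torch.rand(49, 3, 512, 768)}]]
#     out = collate_fn_t2v(batch)
#     out["prompts"]       # ["a red fox", "ocean waves"]
#     out["videos"].shape  # torch.Size([2, 49, 3, 512, 768])

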
def forward_pass(
    transformer: LTXVideoTransformer3DModel,
    prompt_embeds: torch.Tensor,
    prompt_attention_mask: torch.Tensor,
    latents: torch.Tensor,
    noisy_latents: torch.Tensor,
    timesteps: torch.LongTensor,
    num_frames: int,
    height: int,
    width: int,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    # RoPE scaling constants: LTX-Video defaults to 25 fps, with 8x temporal and
    # 32x spatial VAE compression, so positions are expressed in latent units.
    frame_rate = 25
    latent_frame_rate = frame_rate / 8
    spatial_compression_ratio = 32
    rope_interpolation_scale = [1 / latent_frame_rate, spatial_compression_ratio, spatial_compression_ratio]

    denoised_latents = transformer(
        hidden_states=noisy_latents,
        encoder_hidden_states=prompt_embeds,
        timestep=timesteps,
        encoder_attention_mask=prompt_attention_mask,
        num_frames=num_frames,
        height=height,
        width=width,
        rope_interpolation_scale=rope_interpolation_scale,
        return_dict=False,
    )[0]

    return {"latents": denoised_latents}


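# forward_pass() wraps the transformer so the trainer can treat every model the
# same way. A hedged sketch of a rectified-flow training step built on top of
# it (the sigma-based noising and the velocity target are assumptions about the
# surrounding trainer, not part of this module):
#
#     noise = torch.randn_like(latents)
#     noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
#     pred = forward_pass(
#         transformer, prompt_embeds, prompt_attention_mask,
#         latents, noisy_latents, timesteps, num_frames, height, width,
#     )["latents"]
#     target = noise - latents
#     loss = torch.nn.functional.mse_loss(pred.float(), target.float())

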
def validation(
    pipeline: LTXPipeline,
    prompt: str,
    image: Optional[Image.Image] = None,
    video: Optional[List[Image.Image]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_frames: Optional[int] = None,
    frame_rate: int = 24,
    num_videos_per_prompt: int = 1,
    generator: Optional[torch.Generator] = None,
    **kwargs,
):
    # `image` and `video` are accepted for interface parity with conditioned
    # models but are unused by the text-to-video pipeline.
    generation_kwargs = {
        "prompt": prompt,
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "frame_rate": frame_rate,
        "num_videos_per_prompt": num_videos_per_prompt,
        "generator": generator,
        "return_dict": True,
        "output_type": "pil",
    }
    # Drop unset arguments so the pipeline falls back to its own defaults.
    generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
    output = pipeline(**generation_kwargs).frames[0]
    return [("video", output)]


def _encode_prompt_t5(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: List[str],
    device: torch.device,
    dtype: torch.dtype,
    max_sequence_length: int,
) -> Dict[str, torch.Tensor]:
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_attention_mask = text_inputs.attention_mask
    prompt_attention_mask = prompt_attention_mask.bool().to(device)

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
    prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)

    return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask}


def _normalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
    # Normalize latents across the channel dimension of a [B, C, F, H, W] tensor.
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents = (latents - latents_mean) * scaling_factor / latents_std
    return latents


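# In math terms: z_norm = (z - mu_c) * s / sigma_c, applied per channel c and
# broadcast over batch, frames, height, and width. The inverse (as used at
# inference time to hand latents back to the VAE decoder) would be
# z = z_norm * sigma_c / s + mu_c.

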
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
    # Unpacked latents of shape [B, C, F, H, W] are patched into
    # [B, C, F // p_t, p_t, H // p, p, W // p, p]. The patch dimensions are then
    # permuted and collapsed to give a token sequence of shape
    # [B, F // p_t * H // p * W // p, C * p_t * p * p], i.e. dim=1 is the
    # effective video sequence length and dim=2 the features per token.
    batch_size, num_channels, num_frames, height, width = latents.shape
    post_patch_num_frames = num_frames // patch_size_t
    post_patch_height = height // patch_size
    post_patch_width = width // patch_size
    latents = latents.reshape(
        batch_size,
        -1,
        post_patch_num_frames,
        patch_size_t,
        post_patch_height,
        patch_size,
        post_patch_width,
        patch_size,
    )
    latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
    return latents


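# A worked example with LTX-Video's defaults (p = p_t = 1, 128 latent channels):
# a latent batch of shape [1, 128, 7, 16, 24] packs to
# [1, 7 * 16 * 24, 128] = [1, 2688, 128]. With p = 2 and p_t = 1 it would pack
# to [1, 7 * 8 * 12, 128 * 1 * 2 * 2] = [1, 672, 512].

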
LTX_VIDEO_T2V_LORA_CONFIG = {
    "pipeline_cls": LTXPipeline,
    "load_condition_models": load_condition_models,
    "load_latent_models": load_latent_models,
    "load_diffusion_models": load_diffusion_models,
    "initialize_pipeline": initialize_pipeline,
    "prepare_conditions": prepare_conditions,
    "prepare_latents": prepare_latents,
    "post_latent_preparation": post_latent_preparation,
    "collate_fn": collate_fn_t2v,
    "forward_pass": forward_pass,
    "validation": validation,
}
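
# This registry maps the model-specific hooks a trainer needs onto the
# functions above, so the training loop itself can stay model-agnostic. A
# hedged sketch of how a trainer might consume it:
#
#     config = LTX_VIDEO_T2V_LORA_CONFIG
#     models = {
#         **config["load_condition_models"](),
#         **config["load_latent_models"](),
#         **config["load_diffusion_models"](),
#     }
#     conditions = config["prepare_conditions"](prompt="a cat playing piano", **models)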