import asyncio
import logging
import os
import random
import traceback
from dataclasses import dataclass
from typing import Any, Dict

import numpy as np
import torch
from diffusers import LTXPipeline, LTXImageToVideoPipeline
from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig

from teacache import TeaCacheConfig, enable_teacache, disable_teacache
from varnish import Varnish
from varnish.utils import is_truthy, process_input_image

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

hf_token = os.getenv("HF_API_TOKEN")

# Hard limits for LTX-Video generation
MAX_LARGE_SIDE = 1280
MAX_SMALL_SIDE = 768
MAX_FRAMES = (8 * 21) + 1  # the frame count must be of the form 8k + 1
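# At the default 30 fps, MAX_FRAMES = 169 frames is a clip of roughly 5.6 seconds.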

support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))


@dataclass
class GenerationConfig:
    """Configuration for video generation"""

    prompt: str = ""
    negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"

    # Output resolution (snapped to multiples of 32 by validate_and_adjust)
    width: int = 768
    height: int = 416

    # JPEG quality used to degrade the input image (see __call__ for details)
    input_image_quality: int = 70

    # Number of frames; must be of the form 8k + 1
    num_frames: int = (8 * 14) + 1

    guidance_scale: float = 3.5
    num_inference_steps: int = 50

    # -1 means a random seed is chosen
    seed: int = -1

    # Varnish post-processing options
    fps: int = 30
    double_num_frames: bool = False
    super_resolution: bool = False
    grain_amount: float = 0.0

    # Optional audio track generation
    enable_audio: bool = False
    audio_prompt: str = ""
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

    # CRF quality of the encoded MP4 (0 = lossless, 51 = worst)
    quality: int = 18

    # TeaCache acceleration
    enable_teacache: bool = True
    teacache_threshold: float = 0.05

    # Enhance-A-Video
    enable_enhance_a_video: bool = True
    enhance_a_video_weight: float = 5.0

    # Optional LoRA
    lora_model_name: str = ""
    lora_model_weight_file: str = ""
    lora_model_trigger: str = ""

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Validate and adjust parameters to meet model constraints"""
        # Unless the request is exactly 1280x768 or 768x1280, snap the
        # resolution to multiples of 32 and cap the total pixel count
        if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
                (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
            MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE

            total_pixels = self.width * self.height
            if total_pixels > MAX_TOTAL_PIXELS:
                # Scale down while preserving the aspect ratio
                scale = (MAX_TOTAL_PIXELS / total_pixels) ** 0.5
                self.width = max(128, min(MAX_LARGE_SIDE, round(self.width * scale / 32) * 32))
                self.height = max(128, min(MAX_LARGE_SIDE, round(self.height * scale / 32) * 32))
            else:
                self.width = max(128, min(MAX_LARGE_SIDE, round(self.width / 32) * 32))
                self.height = max(128, min(MAX_LARGE_SIDE, round(self.height / 32) * 32))

        # Snap the frame count to the nearest valid 8k + 1 value, capped at MAX_FRAMES
        k = (self.num_frames - 1) // 8
        self.num_frames = min((k * 8) + 1, MAX_FRAMES)

        # Materialize a concrete seed so it can be reported in the output metadata
        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

        return self
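

# A worked example of the snapping logic above (hypothetical request values):
#   GenerationConfig(width=800, height=450, num_frames=100).validate_and_adjust()
# yields width=800 and height=448 (both multiples of 32, under the pixel cap)
# and num_frames=97 (8 * 12 + 1).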


class EndpointHandler:
    """Handles video generation requests using LTX models and Varnish post-processing"""

    def __init__(self, model_path: str = ""):
        """Initialize the handler with LTX models and Varnish

        Args:
            model_path: Path to LTX model weights
        """
        # Load either the image-to-video or the text-to-video pipeline,
        # depending on the deployment's SUPPORT_INPUT_IMAGE_PROMPT setting
        if support_image_prompt:
            self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16
            ).to("cuda")
        else:
            self.text_to_video = LTXPipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16
            ).to("cuda")

        # Tracks the currently loaded LoRA so weights are only reloaded when they change
        self._current_lora_model = None

        # Varnish handles post-processing: frame interpolation, upscaling,
        # film grain, audio generation and final MP4 encoding
        self.varnish = Varnish(
            device="cuda",
            model_base_dir="/repository/varnish",
            enable_mmaudio=True,
        )

        self.text_to_video_teacache = None
        self.image_to_video_teacache = None

    def _configure_teacache(self, model, config: GenerationConfig):
        """Configure TeaCache for a model based on the generation config

        Args:
            model: The pipeline to configure TeaCache for
            config: Generation configuration
        """
        if config.enable_teacache:
            teacache_config = TeaCacheConfig(
                enabled=True,
                rel_l1_thresh=config.teacache_threshold,
                num_inference_steps=config.num_inference_steps
            )
            enable_teacache(model.transformer.__class__, teacache_config)
            logger.info(f"TeaCache enabled with threshold {config.teacache_threshold}")
        else:
            # Only disable TeaCache if it was previously attached to this class
            if hasattr(model.transformer.__class__, 'teacache_config'):
                disable_teacache(model.transformer.__class__)
                logger.info("TeaCache disabled")

    async def process_frames(
        self,
        frames: torch.Tensor,
        config: GenerationConfig
    ) -> tuple[str, dict]:
        """Post-process generated frames using Varnish

        Args:
            frames: Generated video frames tensor
            config: Generation configuration

        Returns:
            Tuple of (video data URI, metadata dictionary)
        """
        try:
            result = await self.varnish(
                input_data=frames,
                fps=config.fps,
                double_num_frames=config.double_num_frames,
                super_resolution=config.super_resolution,
                grain_amount=config.grain_amount,
                enable_audio=config.enable_audio,
                audio_prompt=config.audio_prompt,
                audio_negative_prompt=config.audio_negative_prompt,
            )

            # Encode the processed video as an MP4 data URI
            video_uri = await result.write(type="data-uri", quality=config.quality)

            metadata = {
                "width": result.metadata.width,
                "height": result.metadata.height,
                "num_frames": result.metadata.frame_count,
                "fps": result.metadata.fps,
                "duration": result.metadata.duration,
                "seed": config.seed,
            }

            return video_uri, metadata

        except Exception as e:
            logger.error(f"Error in process_frames: {str(e)}")
            raise RuntimeError(f"Failed to process frames: {str(e)}") from e
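
    # The metadata dictionary returned above looks like this (illustrative values):
    #   {"width": 768, "height": 416, "num_frames": 97, "fps": 30,
    #    "duration": 3.23, "seed": 42}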

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process incoming requests for video generation

        Args:
            data: Request data containing:
                - inputs (dict): Dictionary containing the input, which can be either "prompt" (text field) or "image" (input image)
                - parameters (dict):
                    - prompt (required, string): list of concepts to keep in the video.
                    - negative_prompt (optional, string): list of concepts to ignore in the video.
                    - width (optional, int, default to 768): width, or horizontal size in pixels.
                    - height (optional, int, default to 416): height, or vertical size in pixels.
                    - input_image_quality (optional, int, default to 70): this is a trick we use to convert a "pristine" image into a "dirty" video frame. This helps fool LTX-Video into turning the image into an animated one.
                    - num_frames (optional, int, default to 113): the number of frames must be a multiple of 8, plus 1.
                    - guidance_scale (optional, float, default to 3.5): guidance scale (values between 3.0 and 4.0 work well)
                    - num_inference_steps (optional, int, default to 50): number of inference steps
                    - seed (optional, int, default to -1): random number generator seed; -1 means a random seed is picked.
                    - fps (optional, int, default to 30): FPS of the final video (e.g. 24, 25, 30, 60)
                    - double_num_frames (optional, bool): if enabled, the number of frames will be doubled using RIFE
                    - super_resolution (optional, bool): if enabled, the resolution will be doubled using Real-ESRGAN
                    - grain_amount (optional, float): amount of film grain to add to the output video
                    - enable_audio (optional, bool): automatically generate an audio track
                    - audio_prompt (optional, str): prompt to use for the audio generation (concepts to add)
                    - audio_negative_prompt (optional, str): negative prompt to use for the audio generation (concepts to ignore)
                    - quality (optional, int, default to 18): CRF quality on the 0-51 scale, where 0 is lossless (for 8-bit only; for 10-bit use -qp 0) and 51 is the worst quality possible.
                    - enable_teacache (optional, bool, default to True): generate faster at the cost of a slight quality loss
                    - teacache_threshold (optional, float, default to 0.05): amount of caching: 0 (original), 0.03 (1.6x speedup), 0.05 (default, 2.1x speedup).
                    - enable_enhance_a_video (optional, bool, default to True): enable the Enhance-A-Video optimization
                    - enhance_a_video_weight (optional, float, default to 5.0): amount of video enhancement to apply
                    - lora_model_name (optional, str, default to ""): HuggingFace repo ID or path to a LoRA model
                    - lora_model_weight_file (optional, str, default to ""): specific weight file to load from the LoRA model
                    - lora_model_trigger (optional, str, default to ""): optional trigger word to prepend to the prompt

        Returns:
            Dictionary containing:
                - video: Base64 encoded MP4 data URI
                - content-type: MIME type
                - metadata: Generation metadata
        """
        inputs = data.get("inputs", dict())

        input_prompt = inputs.get("prompt", "")
        input_image = inputs.get("image")

        params = data.get("parameters", dict())

        if not input_image and not input_prompt:
            raise ValueError("Either prompt or image must be provided")

        # Build the generation config from the request parameters, falling back
        # to the dataclass defaults, then snap it to the model's constraints
        config = GenerationConfig(
            prompt=input_prompt,
            negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
            width=params.get("width", GenerationConfig.width),
            height=params.get("height", GenerationConfig.height),
            input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality),
            num_frames=params.get("num_frames", GenerationConfig.num_frames),
            guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
            num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),
            seed=params.get("seed", GenerationConfig.seed),
            fps=params.get("fps", GenerationConfig.fps),
            double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames),
            super_resolution=params.get("super_resolution", GenerationConfig.super_resolution),
            grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
            enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
            audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
            audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
            quality=params.get("quality", GenerationConfig.quality),
            enable_teacache=params.get("enable_teacache", GenerationConfig.enable_teacache),
            teacache_threshold=params.get("teacache_threshold", GenerationConfig.teacache_threshold),
            enable_enhance_a_video=params.get("enable_enhance_a_video", GenerationConfig.enable_enhance_a_video),
            enhance_a_video_weight=params.get("enhance_a_video_weight", GenerationConfig.enhance_a_video_weight),
            lora_model_name=params.get("lora_model_name", GenerationConfig.lora_model_name),
            lora_model_weight_file=params.get("lora_model_weight_file", GenerationConfig.lora_model_weight_file),
            lora_model_trigger=params.get("lora_model_trigger", GenerationConfig.lora_model_trigger),
        ).validate_and_adjust()

        try:
            with torch.inference_mode():
                # Seed every RNG so a run is reproducible from config.seed
                random.seed(config.seed)
                np.random.seed(config.seed)
                torch.manual_seed(config.seed)
                generator = torch.Generator(device='cuda').manual_seed(config.seed)

                generation_kwargs = {
                    "prompt": config.prompt,
                    "negative_prompt": config.negative_prompt,
                    "width": config.width,
                    "height": config.height,
                    "num_frames": config.num_frames,
                    "guidance_scale": config.guidance_scale,
                    "num_inference_steps": config.num_inference_steps,

                    # Return raw tensors so Varnish can post-process the frames
                    "output_type": "pt",
                    "generator": generator,

                    # VAE decoding parameters
                    "decode_timestep": 0.05,
                    "decode_noise_scale": 0.025,
                }

                # Swap LoRA weights when the requested LoRA differs from the one
                # currently loaded (an empty name simply unloads any previous LoRA).
                # getattr is used because only one of the two pipelines exists.
                text_to_video = getattr(self, 'text_to_video', None)
                image_to_video = getattr(self, 'image_to_video', None)
                if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
                    # Unload any previously loaded LoRA
                    if text_to_video is not None and hasattr(text_to_video, 'unload_lora_weights'):
                        text_to_video.unload_lora_weights()
                    if image_to_video is not None and hasattr(image_to_video, 'unload_lora_weights'):
                        image_to_video.unload_lora_weights()

                    if config.lora_model_name:
                        if text_to_video is not None and hasattr(text_to_video, 'load_lora_weights'):
                            text_to_video.load_lora_weights(
                                config.lora_model_name,
                                weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
                                token=hf_token,
                            )
                        if image_to_video is not None and hasattr(image_to_video, 'load_lora_weights'):
                            image_to_video.load_lora_weights(
                                config.lora_model_name,
                                weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
                                token=hf_token,
                            )
                    self._current_lora_model = (config.lora_model_name, config.lora_model_weight_file)

                # Prepend the LoRA trigger word, if any, to the prompt
                if config.lora_model_trigger:
                    generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"

                # Configure Enhance-A-Video; a weight of 0.0 effectively disables it
                enhance_a_video_config = EnhanceAVideoConfig(
                    weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
                    # Note: this callback returns a hard-coded frame count rather
                    # than one derived from config.num_frames
                    num_frames_callback=lambda: (8 + 1),
                    _attention_type=1
                )

                if support_image_prompt and input_image:
                    self._configure_teacache(self.image_to_video, config)
                    # Attach the Enhance-A-Video hook built above to the active transformer
                    apply_enhance_a_video(self.image_to_video.transformer, enhance_a_video_config)

                    # Degrade the input image slightly so LTX-Video treats it as
                    # a video frame rather than a pristine still
                    processed_image = process_input_image(
                        input_image,
                        config.width,
                        config.height,
                        config.input_image_quality,
                    )
                    generation_kwargs["image"] = processed_image
                    frames = self.image_to_video(**generation_kwargs).frames
                else:
                    self._configure_teacache(self.text_to_video, config)
                    apply_enhance_a_video(self.text_to_video.transformer, enhance_a_video_config)
                    frames = self.text_to_video(**generation_kwargs).frames

            # The handler entry point is synchronous, so drive the async
            # Varnish post-processing on an event loop
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))

            return {
                "video": video_uri,
                "content-type": "video/mp4",
                "metadata": metadata
            }

        except Exception as e:
            message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
            logger.error(message)
            raise RuntimeError(message) from e
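

# A minimal local smoke test, assuming a CUDA device and LTX-Video weights at
# the hypothetical path "/repository"; adjust both before running.
if __name__ == "__main__":
    handler = EndpointHandler(model_path="/repository")
    result = handler({
        "inputs": {"prompt": "a spaceship drifting through a purple nebula"},
        "parameters": {"num_frames": 33, "num_inference_steps": 8},
    })
    print(result["metadata"])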