import spaces
import gradio as gr
import numpy as np
import os
import torch
from PIL import Image
from pathlib import Path
from diffusers import HunyuanVideoPipeline
from huggingface_hub import snapshot_download
# Configuration
LORA_CHOICES = [
    "Top_Off.safetensors",
    "huanyan_helper.safetensors",
    "huanyan_helper_alpha.safetensors",
    "hunyuan-t-solo-v1.0.safetensors",
    "stripe_v2.safetensors",
]

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
# Initialize pipeline with ZeroGPU optimizations
model_id = "Tencent-Hunyuan/Hunyuan-Video-Lite"
pipe = HunyuanVideoPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
).to("cuda")
# Load all available LoRAs
for lora_file in LORA_CHOICES:
    try:
        pipe.load_lora_weights(
            "Sergidev/TTV4ME",
            weight_name=lora_file,
            adapter_name=lora_file.split('.')[0],
            token=os.environ.get("HF_TOKEN"),
        )
    except Exception as e:
        print(f"Error loading {lora_file}: {str(e)}")
@spaces.GPU  # request a ZeroGPU slot for the duration of each generation call
def generate(
    prompt,
    image_input,
    height,
    width,
    num_frames,
    num_inference_steps,
    seed_value,
    fps,
    *lora_args,
    progress=gr.Progress(track_tqdm=True),
):
    # Gradio passes each LoRA checkbox and slider as a separate positional value,
    # so split the flattened arguments back into selections and weights.
    selected_loras = lora_args[:len(LORA_CHOICES)]
    lora_weights = lora_args[len(LORA_CHOICES):]
    # Image validation
    if image_input is not None:
        img = Image.open(image_input)
        if img.size != (width, height):
            raise gr.Error(f"Image resolution {img.size} must match video resolution {width}x{height}")
        prompt = f"Image prompt: {prompt}" if prompt else "Based on uploaded image"
    # Set active LoRAs
    active_adapters = []
    adapter_weights = []
    for idx, selected in enumerate(selected_loras):
        if selected:
            active_adapters.append(LORA_CHOICES[idx].split('.')[0])
            adapter_weights.append(lora_weights[idx])

    if active_adapters:
        pipe.set_adapters(active_adapters, adapter_weights)
    # Generation logic
    torch.cuda.empty_cache()
    if seed_value == -1:
        seed_value = torch.randint(0, MAX_SEED, (1,)).item()
    generator = torch.Generator('cuda').manual_seed(seed_value)

    # NOTE: fps is received from the UI but is not used by the pipeline calls below.
    try:
        if image_input:
            output = pipe.image_to_video(
                Image.open(image_input).convert("RGB"),
                prompt=prompt,
                height=height,
                width=width,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                generator=generator,
            )
        else:
            output = pipe.text_to_video(
                prompt=prompt,
                height=height,
                width=width,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                generator=generator,
            )
        return output.video
    finally:
        torch.cuda.empty_cache()
def apply_preset(preset_name):
    # Values are returned in the order: height, width, num_frames, num_inference_steps, fps
    if preset_name == "Higher Resolution":
        return [608, 448, 24, 29, 12]
    elif preset_name == "More Frames":
        return [512, 320, 42, 27, 14]
    return [512, 512, 24, 25, 12]
css = """ | |
/* Existing CSS remains unchanged */ | |
""" | |
with gr.Blocks(css=css, theme="dark") as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# 🎬 Hunyuan Studio", elem_classes=["title"])
        gr.Markdown(
            """Text-to-Video & Image-to-Video generation with multiple LoRA adapters.<br>
            Ensure image resolution matches selected video dimensions.""",
            elem_classes=["description"]
        )

        with gr.Column(elem_classes=["prompt-container"]):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter text prompt or describe the image...",
                elem_classes=["prompt-textbox"],
                lines=3
            )
            image_input = gr.Image(
                label="Upload Reference Image (Optional)",
                type="filepath",
                visible=True
            )

        with gr.Row():
            run_button = gr.Button("🎬 Generate Video", variant="primary", size="lg")

        with gr.Row(elem_classes=["preset-buttons"]):
            preset_high_res = gr.Button("📺 Resolution Preset")
            preset_more_frames = gr.Button("🎞️ Frames Preset")

        with gr.Row():
            result = gr.Video(label="Generated Video")
with gr.Accordion("⚙️ Advanced Settings", open=False): | |
with gr.Row(): | |
seed = gr.Slider( | |
label="Seed (-1 for random)", | |
minimum=-1, | |
maximum=MAX_SEED, | |
step=1, | |
value=-1, | |
) | |
with gr.Row(): | |
height = gr.Slider( | |
label="Height", | |
minimum=256, | |
maximum=MAX_IMAGE_SIZE, | |
step=16, | |
value=512, | |
) | |
width = gr.Slider( | |
label="Width", | |
minimum=256, | |
maximum=MAX_IMAGE_SIZE, | |
step=16, | |
value=512, | |
) | |
with gr.Row(): | |
num_frames = gr.Slider( | |
label="Frame Count", | |
minimum=1, | |
maximum=257, | |
step=1, | |
value=24, | |
) | |
num_inference_steps = gr.Slider( | |
label="Inference Steps", | |
minimum=1, | |
maximum=50, | |
step=1, | |
value=25, | |
) | |
fps = gr.Slider( | |
label="FPS", | |
minimum=1, | |
maximum=60, | |
step=1, | |
value=12, | |
) | |
with gr.Accordion("🧩 LoRA Configuration", open=False): | |
lora_checkboxes = [] | |
lora_sliders = [] | |
for lora in LORA_CHOICES: | |
with gr.Row(): | |
cb = gr.Checkbox(label=f"Enable {lora}", value=False) | |
sl = gr.Slider(0.0, 1.0, value=0.8, label=f"{lora} Weight") | |
lora_checkboxes.append(cb) | |
lora_sliders.append(sl) | |
    # Event handling
    run_button.click(
        fn=generate,
        # Gradio requires a flat list of components, so the LoRA checkboxes and
        # sliders are unpacked here rather than passed as nested lists.
        inputs=[prompt, image_input, height, width, num_frames,
                num_inference_steps, seed, fps,
                *lora_checkboxes, *lora_sliders],
        outputs=result
    )
    preset_high_res.click(
        fn=lambda: apply_preset("Higher Resolution"),
        outputs=[height, width, num_frames, num_inference_steps, fps]
    )
    preset_more_frames.click(
        fn=lambda: apply_preset("More Frames"),
        outputs=[height, width, num_frames, num_inference_steps, fps]
    )
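
# The original file ends without starting the app; a minimal sketch of the usual
# Gradio entry point, assuming the standard pattern for Spaces. queue() is
# optional but typical for long-running GPU jobs.
if __name__ == "__main__":
    demo.queue().launch()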