File size: 3,883 Bytes
e18fe42
 
 
44ef737
e18fe42
 
 
 
 
44ef737
 
e18fe42
 
 
 
 
 
37b2b3a
44ef737
e18fe42
 
37b2b3a
e18fe42
 
 
44ef737
 
e18fe42
 
1d2fe7b
e18fe42
44ef737
4df0ad4
 
e18fe42
 
 
44ef737
e18fe42
 
 
 
 
 
 
 
 
 
 
37b2b3a
e18fe42
 
 
37b2b3a
e18fe42
 
 
37b2b3a
44ef737
e18fe42
 
 
 
 
 
 
 
 
 
 
37b2b3a
e18fe42
37b2b3a
e18fe42
 
 
37b2b3a
e18fe42
44ef737
e18fe42
 
37b2b3a
 
e18fe42
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')

import torch
import gradio as gr
import tempfile
import random
import numpy as np
import spaces
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video

# Constants
MODEL_ID = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
MAX_SEED = np.iinfo(np.int32).max  # largest int32; upper bound for seeds in the UI and RNG
FIXED_FPS = 16  # playback rate used when encoding generated frames to MP4
# Default negative prompt (kept in Chinese). English gloss: garish colors,
# overexposed, static, blurry details, subtitles, style, artwork, painting,
# frame, stillness, overall grayish, worst quality, low quality, JPEG
# compression artifacts, ugly, incomplete, extra fingers, poorly drawn hands,
# poorly drawn face, deformed, disfigured, malformed limbs, fused fingers,
# motionless frame, cluttered background, three legs, crowded background,
# walking backwards.
DEFAULT_NEGATIVE_PROMPT = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

# Setup
# NOTE(review): the transformer weights are loaded in float16. The previous
# comment claimed this was "for stability", but the upstream Wan examples use
# bfloat16; confirm float16 does not overflow (NaN / black frames) for this model.
dtype = torch.float16
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model. The VAE is kept in float32 while the rest of the pipeline uses `dtype`.
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=dtype)
pipe.to(device)

# Prime the pipeline (warm-up to reduce first-run latency). This only pays off
# on CUDA; on CPU a throwaway pass through this large model would just delay startup.
if device.type == "cuda":
    # WanPipeline expects (num_frames - 1) to be divisible by its temporal VAE
    # factor (4); the previous value 8 was rounded to 9 with a warning anyway.
    pipe(
        prompt="warmup",
        negative_prompt=DEFAULT_NEGATIVE_PROMPT,
        height=512,
        width=768,
        num_frames=9,
        num_inference_steps=2,
    )

# Each call requests a GPU through the `spaces` decorator with a fixed
# 200-second budget; no per-request duration estimate is computed.
@spaces.GPU(duration=200)
def generate_video(prompt, negative_prompt, height, width, num_frames, guidance_scale, guidance_scale_2, num_steps, seed, randomize_seed):
    """Generate a video with the Wan pipeline and save it as a temporary MP4.

    Args:
        prompt: Text describing the desired video.
        negative_prompt: Text describing content to steer away from.
        height: Requested frame height in pixels. Coerced to int and rounded
            down to a multiple of 16 (minimum 16), because WanPipeline rejects
            sizes that are not divisible by 16.
        width: Requested frame width in pixels; same coercion as ``height``.
        num_frames: Number of frames to generate (coerced to int).
        guidance_scale: Guidance scale forwarded to the pipeline.
        guidance_scale_2: Second guidance scale forwarded to the pipeline
            (presumably for the low-noise transformer stage; confirm against
            the diffusers WanPipeline docs).
        num_steps: Number of denoising steps (coerced to int).
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.

    Returns:
        Tuple ``(video_path, seed_used)``. The MP4 is created with
        ``delete=False`` so it outlives this call; nothing here removes it.
    """
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    generator = torch.Generator(device=device).manual_seed(current_seed)

    # Gradio sliders may deliver floats, and a slider whose minimum is not a
    # multiple of its step yields sizes such as 376. WanPipeline raises
    # ValueError unless height and width are divisible by 16, so snap down.
    height = max(16, int(height) // 16 * 16)
    width = max(16, int(width) // 16 * 16)

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=int(num_frames),
        guidance_scale=guidance_scale,
        guidance_scale_2=guidance_scale_2,
        num_inference_steps=int(num_steps),
        generator=generator,
    ).frames[0]

    # Reserve a unique .mp4 path, then close our handle before encoding so the
    # video writer can open the file itself (an open handle blocks this on Windows).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(output, video_path, fps=FIXED_FPS)
    return video_path, current_seed

# Gradio UI: input controls on the left, generated video and the seed actually
# used on the right. The button forwards all controls to generate_video.
with gr.Blocks() as demo:
    gr.Markdown("## 🎬 Wan2.2 Text-to-Video Generator with HF Spaces GPU")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="Two anthropomorphic cats in comfy boxing gear fight intensely.")
            negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT, lines=3)
            # Slider positions are minimum + k * step, so the minimum must itself
            # be a multiple of 16 for every reachable size to be divisible by 16,
            # as WanPipeline requires. A minimum of 360 made every dragged value
            # 360 + 16k, which is never a multiple of 16.
            height = gr.Slider(368, 1024, value=720, step=16, label="Height")
            width = gr.Slider(368, 1920, value=1280, step=16, label="Width")
            # Wan expects (num_frames - 1) divisible by 4; stepping by 4 from 9
            # offers only counts the pipeline accepts without silent rounding.
            num_frames = gr.Slider(9, 81, value=81, step=4, label="Number of Frames")
            num_steps = gr.Slider(10, 60, value=40, step=1, label="Inference Steps")
            guidance_scale = gr.Slider(1.0, 10.0, value=4.0, step=0.5, label="Guidance Scale")
            guidance_scale_2 = gr.Slider(1.0, 10.0, value=3.0, step=0.5, label="Guidance Scale 2")
            seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)

            generate_button = gr.Button("🎥 Generate Video")

        with gr.Column():
            video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
            final_seed_display = gr.Number(label="Used Seed", interactive=False)

    # Input order must match generate_video's positional parameters.
    generate_button.click(
        fn=generate_video,
        inputs=[prompt, negative_prompt, height, width, num_frames, guidance_scale, guidance_scale_2, num_steps, seed, randomize_seed],
        outputs=[video_output, final_seed_display],
    )

if __name__ == "__main__":
    # Turn on Gradio's request queue, then start serving the app.
    queued_app = demo.queue()
    queued_app.launch()