# Install the PyTorch nightly (cu126) wheel before importing torch, so the
# CUDA-compatible build is the one that actually gets loaded.
import os
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')

# Imports (on ZeroGPU, `spaces` must be imported before `torch`)
import spaces
import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
import gradio as gr
import tempfile
import random
import numpy as np

# Constants
MODEL_ID = "Runware/Wan2.2-T2V-A14B"
FIXED_FPS = 16
MAX_SEED = np.iinfo(np.int32).max
DEFAULT_HEIGHT = 720
DEFAULT_WIDTH = 1280
MAX_FRAMES = 81
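
# At FIXED_FPS = 16, MAX_FRAMES = 81 comes out to roughly a five-second clip
# (81 frames / 16 fps ≈ 5.1 s).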

# Prompts
default_prompt_t2v = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
# The default negative prompt ships with Wan in Chinese and is kept verbatim.
# Rough translation: "vivid colors, overexposed, static, blurry details,
# subtitles, style, artwork, painting, frame, still, overall gray, worst
# quality, low quality, JPEG compression artifacts, ugly, mutilated, extra
# fingers, poorly drawn hands, poorly drawn face, deformed, disfigured,
# malformed limbs, fused fingers, motionless frame, cluttered background,
# three legs, crowded background, walking backwards".
default_negative_prompt = (
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,"
    "最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,"
    "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
)

# Load pipeline. The VAE stays in float32 for numerical stability; the rest of
# the pipeline runs in bfloat16 when a GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32).to(device)
pipe = WanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=dtype).to(device)
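
# If memory is tight, diffusers' standard offloading hook can trade speed for
# VRAM; a sketch, left disabled because the pipeline is already moved to GPU:
# pipe.enable_model_cpu_offload()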

# Optional warm-up. On ZeroGPU Spaces the CUDA device is only usable inside
# @spaces.GPU functions, so skip the import-time warm-up there (the ZeroGPU
# runtime sets the SPACES_ZERO_GPU environment variable).
if not os.environ.get("SPACES_ZERO_GPU"):
    _ = pipe(
        prompt="warmup",
        negative_prompt=default_negative_prompt,
        height=512,
        width=768,
        num_frames=9,  # Wan frame counts must be of the form 4k + 1
        num_inference_steps=2,
        generator=torch.Generator(device=device).manual_seed(0),
    ).frames[0]

# ZeroGPU duration helper: spaces calls this with the same arguments as
# generate_t2v to size the GPU allocation; ~15 s per denoising step is a
# rough empirical budget.
def get_duration(prompt, negative_prompt, height, width, num_frames, guidance_scale, guidance_scale_2, steps, seed, randomize_seed, progress):
    return int(steps * 15)
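
# Example budget: the default 40 steps requests 40 * 15 = 600 s; ZeroGPU uses
# this number to schedule and cap the GPU allocation for each call.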

@spaces.GPU(duration=get_duration)
def generate_t2v(
    prompt,
    negative_prompt,
    height,
    width,
    num_frames,
    guidance_scale,
    guidance_scale_2,
    steps,
    seed,
    randomize_seed,
    progress=gr.Progress(track_tqdm=True),
):
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    generator = torch.Generator(device=device).manual_seed(current_seed)

    # Wan's VAE compresses time 4x, so valid frame counts have the form 4k + 1;
    # snap the requested count down to the nearest valid value (minimum 9).
    frames = max(9, ((int(num_frames) - 1) // 4) * 4 + 1)

    output_frames = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=int(height),
        width=int(width),
        num_frames=frames,
        guidance_scale=float(guidance_scale),  # high-noise stage
        guidance_scale_2=float(guidance_scale_2),  # low-noise stage
        num_inference_steps=int(steps),
        generator=generator,
    ).frames[0]

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        export_to_video(output_frames, tmpfile.name, fps=FIXED_FPS)
        return tmpfile.name, current_seed
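
# A minimal smoke test (hypothetical; bypasses the UI): short clip, few steps.
# Uncomment to verify the pipeline end-to-end before wiring up Gradio events.
# video_path, seed_used = generate_t2v(
#     default_prompt_t2v, default_negative_prompt,
#     height=480, width=832, num_frames=17,
#     guidance_scale=4.0, guidance_scale_2=3.0,
#     steps=4, seed=42, randomize_seed=False,
# )
# print(video_path, seed_used)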

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🎬 Wan 2.2 T2V: Text-to-Video via Wan-AI")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_t2v)
            negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
            # Keep slider values on Wan-friendly grids: height/width divisible
            # by 16, frame counts of the form 4k + 1.
            height_slider = gr.Slider(352, 1024, step=16, value=DEFAULT_HEIGHT, label="Height")
            width_slider = gr.Slider(352, 1920, step=16, value=DEFAULT_WIDTH, label="Width")
            frames_slider = gr.Slider(9, MAX_FRAMES, value=MAX_FRAMES, step=4, label="Frames (4k + 1)")

            with gr.Accordion("Advanced Settings", open=False):
                guidance_slider = gr.Slider(0.0, 20.0, step=0.5, value=4.0, label="Guidance Scale")
                guidance2_slider = gr.Slider(0.0, 20.0, step=0.5, value=3.0, label="Guidance Scale 2")
                steps_slider = gr.Slider(1, 60, step=1, value=40, label="Inference Steps")
                seed_slider = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed", interactive=True)
                randomize_seed_check = gr.Checkbox(label="Randomize Seed", value=True)

            generate_button = gr.Button("🎥 Generate Video", variant="primary")

        with gr.Column():
            video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
            used_seed = gr.Number(label="Used Seed", interactive=False)

    inputs = [
        prompt_input, negative_prompt_input,
        height_slider, width_slider,
        frames_slider,
        guidance_slider, guidance2_slider,
        steps_slider, seed_slider, randomize_seed_check
    ]

    generate_button.click(fn=generate_t2v, inputs=inputs, outputs=[video_output, used_seed])

if __name__ == "__main__":
    # mcp_server=True (Gradio 5+) also exposes the app's endpoints as MCP tools.
    demo.queue().launch(mcp_server=True)