# Spaces: Paused
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. | |
import argparse | |
import os.path as osp | |
import os | |
import sys | |
import warnings | |
import gradio as gr | |
warnings.filterwarnings('ignore') | |
# Model | |
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) | |
import wan | |
from wan.configs import WAN_CONFIGS | |
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander | |
from wan.utils.utils import cache_video | |
# Global Var | |
prompt_expander = None | |
wan_t2v = None | |
# Button Func | |
def prompt_enc(prompt, tar_lang):
    """Expand ``prompt`` with the module-level ``prompt_expander``.

    Args:
        prompt: Text currently in the prompt box.
        tar_lang: Target-language label chosen in the UI (``"CH"`` or
            ``"EN"``); it is lower-cased before being passed to the expander.

    Returns:
        The expanded prompt when the expander reports a truthy ``status``,
        otherwise the original ``prompt`` unchanged.
    """
    # ``prompt_expander`` is only read here, so no ``global`` statement is
    # required.
    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())

    # Treat every falsy status (``False``, ``None``, ...) as a failure so the
    # user's text is never replaced by the result of a failed expansion.
    if not prompt_output.status:
        return prompt
    return prompt_output.prompt
def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
                   shift_scale, seed, n_prompt):
    """Run the module-level ``wan_t2v`` model and save the result as a video.

    Args:
        txt2vid_prompt: Text prompt, passed positionally to
            ``wan_t2v.generate``.
        resolution: ``"<width>*<height>"`` string from the resolution
            dropdown.
        sd_steps: Forwarded as ``sampling_steps``.
        guide_scale: Forwarded as ``guide_scale``.
        shift_scale: Forwarded as ``shift``.
        seed: Forwarded as ``seed``.
        n_prompt: Negative prompt, forwarded as ``n_prompt``.

    Returns:
        ``"example.mp4"`` -- the fixed path handed to ``cache_video`` as
        ``save_file``.
    """
    output_path = "example.mp4"

    # The dropdown encodes the size as "W*H"; take the first two fields.
    dims = resolution.split("*")
    width, height = int(dims[0]), int(dims[1])

    generate_options = {
        "size": (width, height),
        "shift": shift_scale,
        "sampling_steps": sd_steps,
        "guide_scale": guide_scale,
        "n_prompt": n_prompt,
        "seed": seed,
        "offload_model": True,
    }
    video = wan_t2v.generate(txt2vid_prompt, **generate_options)

    # ``video[None]`` presumably adds a leading batch axis for cache_video --
    # TODO confirm against wan.utils.utils.cache_video.
    cache_video(
        tensor=video[None],
        save_file=output_path,
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1),
    )
    return output_path
# Interface | |
def gradio_interface():
    """Assemble the Gradio ``Blocks`` UI for the text-to-video demo.

    The layout is a title banner followed by one row: a column with the
    prompt box, the prompt-enhance controls, an "Advanced Options" accordion
    and the generate button, and a column with the output video player.
    "Prompt Enhance" is wired to ``prompt_enc`` and "Generate Video" to
    ``t2v_generation``.

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """
    banner_html = (
        '\n'
        '<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">\n'
        'Wan2.1 (T2V-14B)\n'
        '</div>\n'
        '<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">\n'
        'Wan: Open and Advanced Large-Scale Video Generative Models.\n'
        '</div>\n'
    )
    # Offered "Width*Height" strings; t2v_generation splits them on "*".
    resolution_choices = [
        '720*1280', '1280*720', '960*960', '1088*832',
        '832*1088', '480*832', '832*480', '624*624',
        '704*544', '544*704',
    ]

    with gr.Blocks() as ui:
        gr.Markdown(banner_html)

        with gr.Row():
            # Left column: prompt entry and generation settings.
            with gr.Column():
                txt2vid_prompt = gr.Textbox(
                    placeholder="Describe the video you want to generate",
                    label="Prompt",
                )
                tar_lang = gr.Radio(
                    label="Target language of prompt enhance",
                    choices=["CH", "EN"],
                    value="CH",
                )
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    resolution = gr.Dropdown(
                        choices=resolution_choices,
                        value='720*1280',
                        label='Resolution(Width*Height)',
                    )

                    with gr.Row():
                        sd_steps = gr.Slider(
                            minimum=1,
                            maximum=1000,
                            step=1,
                            value=50,
                            label="Diffusion steps",
                        )
                        guide_scale = gr.Slider(
                            minimum=0,
                            maximum=20,
                            step=1,
                            value=5.0,
                            label="Guide scale",
                        )

                    with gr.Row():
                        shift_scale = gr.Slider(
                            minimum=0,
                            maximum=10,
                            step=1,
                            value=5.0,
                            label="Shift scale",
                        )
                        seed = gr.Slider(
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1,
                            label="Seed",
                        )

                    n_prompt = gr.Textbox(
                        placeholder="Describe the negative prompt you want to add",
                        label="Negative Prompt",
                    )

                run_t2v_button = gr.Button("Generate Video")

            # Right column: player for the generated clip.
            with gr.Column():
                result_gallery = gr.Video(
                    height=600, interactive=False, label='Generated Video')

        # The enhanced prompt is written back into the same textbox.
        enhance_inputs = [txt2vid_prompt, tar_lang]
        run_p_button.click(
            fn=prompt_enc,
            inputs=enhance_inputs,
            outputs=[txt2vid_prompt],
        )

        # Input order must match t2v_generation's positional parameters.
        generation_inputs = [
            txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
            seed, n_prompt,
        ]
        run_t2v_button.click(
            fn=t2v_generation,
            inputs=generation_inputs,
            outputs=[result_gallery],
        )

    return ui
# Main | |
def _parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, so existing zero-argument calls are unaffected.

    Returns:
        An ``argparse.Namespace`` with ``ckpt_dir``, ``prompt_extend_method``
        and ``prompt_extend_model``.
    """
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default="cache",
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    # Passing ``None`` through keeps argparse's default of sys.argv[1:].
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: build the prompt expander and the T2V model, store
    # both in the module globals read by the button callbacks, then serve the
    # Gradio UI.
    args = _parse_args()

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=False, device=0)
    elif args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=False)
    else:
        # argparse's ``choices`` should already reject anything else; kept as
        # a safety net.
        raise NotImplementedError(
            f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    print("Step2: Init 14B t2v model...", end='', flush=True)
    cfg = WAN_CONFIGS['t2v-14B']
    model_options = {
        "config": cfg,
        "checkpoint_dir": args.ckpt_dir,
        "device_id": 0,
        "rank": 0,
        "t5_fsdp": False,
        "dit_fsdp": False,
        "use_usp": False,
    }
    wan_t2v = wan.WanT2V(**model_options)
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_port=7860, server_name="0.0.0.0", share=False)