File size: 4,023 Bytes
4e424ea
ca753f0
4e424ea
7bedcdd
4e424ea
 
 
 
 
 
 
 
ca753f0
73566e5
 
2562fab
73566e5
935512c
72a0d14
935512c
 
 
2fef112
 
72a0d14
935512c
72a0d14
935512c
 
4e424ea
f0f4c78
4e424ea
 
 
 
 
f0f4c78
4e424ea
 
 
f0f4c78
 
 
 
 
 
 
 
4e424ea
ca753f0
6c641ac
 
 
935512c
2fef112
935512c
 
 
 
f20624c
935512c
72a0d14
f20624c
 
2fef112
935512c
2fef112
935512c
 
 
72a0d14
2fef112
 
935512c
 
 
 
3217fc0
72a0d14
 
 
73566e5
ca753f0
2fef112
f20624c
935512c
f0f4c78
2562fab
f20624c
935512c
2562fab
f0f4c78
4e424ea
d701afa
4e424ea
 
f0f4c78
4e424ea
 
 
 
c81f025
4e424ea
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import re
import subprocess
import sys

import gradio as gr
from huggingface_hub import snapshot_download
from tqdm import tqdm

# Download the Wan2.1 T2V 1.3B checkpoint into ./Wan2.1-T2V-1.3B; `infer`
# points the generator at this directory through --ckpt_dir.
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./Wan2.1-T2V-1.3B")

# Extracts the message of an INFO log line (everything after "INFO:").
_INFO_PATTERN = re.compile(r"\[.*?\]\s+INFO:\s+(.*)")
# Captures tqdm-style progress lines, e.g. " 10%|...| 5/50" -> ("10", "5", "50").
_PROGRESS_PATTERN = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")
# File the child process is told to write; returned to Gradio on success.
_OUTPUT_FILE = "generated_video.mp4"


def _build_command(prompt):
    """Return the argv list that runs the `generate` module for *prompt*."""
    return [
        # Use the interpreter running this app so the child sees the same
        # installed packages; -u keeps the child's output unbuffered so
        # progress lines stream in as they are produced.
        sys.executable, "-u", "-m", "generate",
        "--task", "t2v-1.3B",
        "--size", "832*480",
        "--ckpt_dir", "./Wan2.1-T2V-1.3B",
        "--sample_shift", "8",
        "--sample_guide_scale", "6",
        "--prompt", prompt,
        "--save_file", _OUTPUT_FILE,
    ]


def infer(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate a video for *prompt* by running the `generate` module.

    The child's combined stdout/stderr is streamed line by line:
    progress lines (``N%|...| cur/total``) drive a "Video Generation" bar,
    INFO log lines after the first ``irrelevant_steps`` advance an
    "Overall Process" bar, and every non-progress line is echoed with
    ``tqdm.write``. ``progress`` is created with ``track_tqdm=True`` so
    Gradio can mirror these tqdm bars; it is not used directly.

    Args:
        prompt: Text passed to the generator via ``--prompt``.
        progress: Gradio progress tracker injected by Gradio.

    Returns:
        The path of the generated video ("generated_video.mp4").

    Raises:
        gr.Error: If the generation process exits with a non-zero status.
    """
    total_process_steps = 11
    # Leading INFO lines that are logged but do not advance the overall bar.
    irrelevant_steps = 4
    relevant_steps = total_process_steps - irrelevant_steps  # 7 steps

    # Overall progress bar for the 7 relevant steps.
    overall_bar = tqdm(total=relevant_steps, desc="Overall Process", position=1, dynamic_ncols=True, leave=True)
    gen_progress_bar = None
    skipped_steps = 0
    process = None

    try:
        # Merge stderr into stdout so log lines and progress output share one stream.
        # Text mode converts "\r" line endings to "\n", so carriage-return
        # redraws are read back one line at a time.
        process = subprocess.Popen(
            _build_command(prompt),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,  # line-buffered
        )

        for line in iter(process.stdout.readline, ''):
            stripped_line = line.strip()
            if not stripped_line:
                continue

            progress_match = _PROGRESS_PATTERN.search(stripped_line)
            if progress_match:
                current = int(progress_match.group(2))
                total = int(progress_match.group(3))
                if gen_progress_bar is None:
                    gen_progress_bar = tqdm(total=total, desc="Video Generation", position=0, dynamic_ncols=True, leave=True)
                elif gen_progress_bar.total != total:
                    # A different total means a different loop reported progress;
                    # restart the bar rather than pushing n past the old total.
                    gen_progress_bar.reset(total=total)
                gen_progress_bar.update(current - gen_progress_bar.n)
                gen_progress_bar.refresh()
                continue  # progress lines are not echoed

            # Echo every other line without disturbing the bars.
            tqdm.write(stripped_line)

            info_match = _INFO_PATTERN.search(stripped_line)
            if info_match is None:
                continue
            if skipped_steps < irrelevant_steps:
                skipped_steps += 1
                continue
            # Never advance past the expected step count (avoids >100% display).
            if overall_bar.n < overall_bar.total:
                overall_bar.update(1)
            percentage = (overall_bar.n / overall_bar.total) * 100
            overall_bar.set_description(f"Overall Process - {percentage:.1f}%")
            overall_bar.set_postfix_str(info_match.group(1))
            overall_bar.refresh()

        process.wait()
    finally:
        if process is not None:
            if process.poll() is None:
                # Only reachable if the loop raised: don't leave the generator running.
                process.kill()
                process.wait()
            process.stdout.close()
        # `is not None`, not truthiness: a tqdm bar with total == 0 is falsy.
        if gen_progress_bar is not None:
            gen_progress_bar.close()
        overall_bar.close()

    if process.returncode == 0:
        print("Command executed successfully.")
        return _OUTPUT_FILE
    print("Error executing command.")
    raise gr.Error("Error executing command")

with gr.Blocks() as demo:
    # One column: title, prompt box, submit button, and the video output.
    with gr.Column():
        gr.Markdown("# Wan 2.1")
        prompt = gr.Textbox(label="Prompt")
        submit_btn = gr.Button("Submit")
        video_res = gr.Video(label="Generated Video")

    # Submit runs `infer` on the prompt text and shows the returned video file.
    submit_btn.click(fn=infer, inputs=[prompt], outputs=[video_res])

# Enable the request queue, then start the server with errors surfaced in the UI.
demo.queue().launch(show_error=True, show_api=False, ssr_mode=False)