File size: 4,753 Bytes
4e424ea
ca753f0
4e424ea
3893fde
 
7bedcdd
4e424ea
 
 
 
 
 
 
 
ca753f0
73566e5
 
2562fab
73566e5
935512c
3893fde
935512c
 
 
3893fde
4c0cc6b
3893fde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e424ea
f0f4c78
4e424ea
 
 
 
 
f0f4c78
4e424ea
 
 
f0f4c78
 
 
 
 
 
 
 
4e424ea
ca753f0
6c641ac
 
 
935512c
417bb8d
935512c
 
 
 
f20624c
4c0cc6b
3893fde
f20624c
 
1cfe5df
935512c
417bb8d
0561c55
a0044b5
 
3893fde
4c0cc6b
935512c
 
 
3893fde
 
 
 
 
 
 
 
 
 
 
 
 
 
ca753f0
4c0cc6b
 
f0f4c78
3893fde
 
 
 
 
2562fab
f20624c
935512c
2562fab
f0f4c78
4e424ea
d701afa
4e424ea
 
f0f4c78
4e424ea
 
 
 
c81f025
4e424ea
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import re
import subprocess
import sys
import threading
import time

import gradio as gr
from huggingface_hub import snapshot_download
from tqdm import tqdm

# Fetch the Wan2.1 text-to-video 1.3B checkpoint into ./Wan2.1-T2V-1.3B;
# `infer` later points the generator at this folder via --ckpt_dir.
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./Wan2.1-T2V-1.3B")

def infer(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate a video for *prompt* by running Wan2.1's ``generate`` module.

    The generator runs as a subprocess whose combined stdout/stderr is read
    line by line.  Non-progress lines are echoed with ``tqdm.write``; the
    output also drives three tqdm bars, intended to be surfaced in the UI by
    ``gr.Progress(track_tqdm=True)``:

    * "Video Generation" -- mirrors the subprocess's ``N%|...| cur/total``
      sampling progress lines.
    * "Overall Process" -- ticks once per "INFO:" line after the first
      ``irrelevant_steps`` of them.
    * a per-step sub-bar titled with the current INFO message.  It is
      completed when the next INFO line arrives or after ``sub_bar_timeout``
      seconds, whichever comes first.

    Args:
        prompt: Text prompt forwarded to the generator's ``--prompt``.
        progress: Gradio progress tracker; only used for its ``track_tqdm``
            side effect.

    Returns:
        The output path ``"generated_video.mp4"`` on success.

    Raises:
        Exception: if the generator exits with a non-zero status.
    """
    total_process_steps = 11
    irrelevant_steps = 4
    relevant_steps = total_process_steps - irrelevant_steps  # 7 steps
    sub_bar_timeout = 20  # seconds before a step's sub-bar is auto-completed

    # Regex for detecting video generation progress lines (e.g., "10%|...| 5/50")
    progress_pattern = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")
    gen_progress_bar = None

    # Overall bar: ticks once each time a step's sub-bar is closed.
    overall_bar = tqdm(total=relevant_steps, desc="Overall Process", position=1,
                       ncols=120, dynamic_ncols=False, leave=True)
    # How many leading INFO lines (which get no sub-bar) have been consumed.
    processed_steps = 0

    # Per-step sub-bar state.  The auto-close timer runs on another thread,
    # so these are only read/written while holding sub_lock.
    current_sub_bar = None
    current_timer = None
    sub_lock = threading.Lock()

    def _close_sub_bar_locked(bar=None):
        """Finish the active sub-bar and tick the overall bar; caller holds sub_lock.

        threading.Lock is not re-entrant, so code already inside
        ``with sub_lock`` must call this rather than ``close_sub_bar``.
        When *bar* is given, do nothing unless it is still the active
        sub-bar, so a timer armed for an earlier step cannot close a later
        step's bar.
        """
        nonlocal current_sub_bar, current_timer
        if current_sub_bar is None:
            return
        if bar is not None and bar is not current_sub_bar:
            return
        if current_timer is not None:
            current_timer.cancel()
        try:
            # Ensure the sub-bar is shown as complete before closing it.
            current_sub_bar.update(1 - current_sub_bar.n)
        except Exception:
            pass  # purely cosmetic; a display glitch must not abort generation
        current_sub_bar.close()
        overall_bar.update(1)
        overall_bar.refresh()
        current_sub_bar = None
        current_timer = None

    def close_sub_bar(bar=None):
        """Thread-safe entry point used by the auto-close timer and cleanup."""
        with sub_lock:
            _close_sub_bar_locked(bar)

    command = [
        # Same interpreter as this app; -u for unbuffered output; module form omits .py
        sys.executable, "-u", "-m", "generate",
        "--task", "t2v-1.3B",
        "--size", "832*480",
        "--ckpt_dir", "./Wan2.1-T2V-1.3B",
        "--sample_shift", "8",
        "--sample_guide_scale", "6",
        "--prompt", prompt,
        "--save_file", "generated_video.mp4",
    ]

    # Start the process with unbuffered output and combine stdout and stderr.
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,  # line-buffered
    )

    try:
        for line in iter(process.stdout.readline, ''):
            stripped_line = line.strip()
            if not stripped_line:
                continue

            # Video generation progress line: mirror it into gen_progress_bar.
            progress_match = progress_pattern.search(stripped_line)
            if progress_match:
                current = int(progress_match.group(2))
                total = int(progress_match.group(3))
                if gen_progress_bar is None:
                    gen_progress_bar = tqdm(total=total, desc="Video Generation", position=0,
                                            ncols=120, dynamic_ncols=True, leave=True)
                gen_progress_bar.update(current - gen_progress_bar.n)
                gen_progress_bar.refresh()
                continue

            tqdm.write(stripped_line)  # Print the log line
            if "INFO:" not in stripped_line:
                continue

            # The first `irrelevant_steps` INFO lines get no sub-bar.
            if processed_steps < irrelevant_steps:
                processed_steps += 1
                continue

            msg = stripped_line.partition("INFO:")[2].strip()
            with sub_lock:
                # Finish the previous step (if still open) before starting this one.
                _close_sub_bar_locked()
                current_sub_bar = tqdm(total=1, desc=msg, position=2,
                                       ncols=120, dynamic_ncols=False, leave=True)
                # Auto-complete this step's bar if no new INFO line arrives in time.
                current_timer = threading.Timer(sub_bar_timeout, close_sub_bar,
                                                args=(current_sub_bar,))
                current_timer.daemon = True
                current_timer.start()

        process.wait()
    finally:
        # On an early exit (exception) don't leave the generator running.
        if process.poll() is None:
            process.kill()
            process.wait()
        process.stdout.close()
        # Clean up: close any still-active sub-bar (also cancels its timer).
        close_sub_bar()
        if gen_progress_bar is not None:
            gen_progress_bar.close()
        overall_bar.close()

    if process.returncode == 0:
        print("Command executed successfully.")
        return "generated_video.mp4"
    print("Error executing command.")
    raise Exception("Error executing command")

# Single-column UI: a prompt box and a Submit button that calls `infer`,
# with the returned video file shown in the video component.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Wan 2.1")
        prompt = gr.Textbox(label="Prompt")
        submit_btn = gr.Button("Submit")
        video_res = gr.Video(label="Generated Video")

    # Button press: textbox value -> infer(prompt) -> video component.
    submit_btn.click(fn=infer, inputs=[prompt], outputs=[video_res])

# Turn on Gradio's request queue for this app, then start serving it.
demo.queue()
demo.launch(show_error=True, show_api=False, ssr_mode=False)