File size: 4,719 Bytes
4e424ea
ca753f0
4e424ea
3893fde
 
7bedcdd
4e424ea
 
 
 
 
 
 
 
ca753f0
73566e5
189dec4
2562fab
ae36822
935512c
4d169d8
 
 
 
935512c
 
 
4c0cc6b
933922d
3893fde
 
933922d
2694c8d
 
933922d
 
 
2694c8d
 
 
 
ae36822
 
2694c8d
 
3893fde
2694c8d
ae36822
2694c8d
ae36822
 
 
3893fde
 
ae36822
 
 
933922d
4e424ea
2694c8d
4e424ea
 
 
 
 
f0f4c78
4e424ea
 
 
f0f4c78
 
 
 
 
2694c8d
f0f4c78
4e424ea
ca753f0
6c641ac
 
 
935512c
2694c8d
935512c
 
 
 
f20624c
4c0cc6b
3893fde
f20624c
 
1cfe5df
935512c
2694c8d
0561c55
a0044b5
 
2694c8d
4c0cc6b
935512c
 
 
3893fde
2694c8d
 
 
933922d
3893fde
933922d
2694c8d
 
3893fde
 
ca753f0
4c0cc6b
 
f0f4c78
2694c8d
933922d
2562fab
f20624c
935512c
2562fab
f0f4c78
4e424ea
d701afa
4e424ea
 
f0f4c78
4e424ea
 
 
 
c81f025
4e424ea
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import re 
import subprocess
import time
import threading
from tqdm import tqdm
from huggingface_hub import snapshot_download

# Fetch the Wan2.1 text-to-video 1.3B checkpoint from the Hugging Face Hub into
# ./Wan2.1-T2V-1.3B, the directory infer() passes to the generator as --ckpt_dir.
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./Wan2.1-T2V-1.3B")

def infer(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate a video for *prompt* by running the ``generate`` module.

    The generator runs in a subprocess. Its merged stdout/stderr is parsed
    line by line to drive tqdm bars, which Gradio mirrors in the UI because
    of ``track_tqdm=True``:

    * "Video Generation" follows the ``NN%|...| cur/total`` progress lines
      printed by the subprocess.
    * "Overall Process" advances once per completed step. A step starts at
      each ``INFO:`` line; the first ``irrelevant_steps`` INFO lines are
      only echoed and do not start a step.
    * A per-step sub-bar ticks once per second (at most 20 ticks) until the
      next step starts or the subprocess ends.

    Every non-progress line is echoed with ``tqdm.write``.

    Args:
        prompt: Text prompt forwarded as ``--prompt``.
        progress: Gradio progress tracker. It is not read directly; its
            default enables tqdm tracking.

    Returns:
        The output path ``"generated_video.mp4"``.

    Raises:
        Exception: If the subprocess exits with a non-zero status.
    """
    total_process_steps = 11
    irrelevant_steps = 4
    relevant_steps = total_process_steps - irrelevant_steps

    overall_bar = tqdm(total=relevant_steps, desc="Overall Process", position=1,
                       ncols=120, dynamic_ncols=False, leave=True)
    processed_steps = 0

    # Matches tqdm-style lines such as " 40%|####      | 20/50 [...]".
    progress_pattern = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")
    gen_progress_bar = None

    current_sub_bar = None
    current_cancel_event = None
    # Guards sub-bar state shared with the ticker thread. This is a plain,
    # non-reentrant Lock, so no helper below may be called while holding it.
    sub_lock = threading.Lock()

    def update_sub_bar(sub_bar, cancel_event):
        """Tick *sub_bar* once per second, up to 20 times, unless canceled."""
        for _ in range(20):
            if cancel_event.is_set():
                break
            time.sleep(1)
            with sub_lock:
                if sub_bar.n < sub_bar.total:
                    sub_bar.update(1)
                    sub_bar.refresh()

    def cancel_sub_bar():
        """Finish the running sub-bar, if any, and advance the overall bar.

        The overall bar advances only when a sub-bar was actually running.
        It therefore counts completed steps, not calls to this helper.
        """
        nonlocal current_sub_bar, current_cancel_event
        with sub_lock:
            if current_cancel_event is not None:
                current_cancel_event.set()
                current_cancel_event = None
            if current_sub_bar is None:
                return
            remaining = current_sub_bar.total - current_sub_bar.n
            if remaining > 0:
                current_sub_bar.update(remaining)
            current_sub_bar.close()
            current_sub_bar = None
            overall_bar.update(1)
            overall_bar.refresh()

    def start_sub_bar(desc):
        """Finish the running sub-bar and start a new one labeled *desc*."""
        nonlocal current_sub_bar, current_cancel_event
        # Called before taking sub_lock: cancel_sub_bar acquires it itself,
        # and re-acquiring a non-reentrant Lock would deadlock.
        cancel_sub_bar()
        with sub_lock:
            current_sub_bar = tqdm(total=20, desc=desc, position=2,
                                   ncols=120, dynamic_ncols=False, leave=True)
            current_cancel_event = threading.Event()
            threading.Thread(target=update_sub_bar,
                             args=(current_sub_bar, current_cancel_event),
                             daemon=True).start()

    command = [
        "python", "-u", "-m", "generate",
        "--task", "t2v-1.3B",
        "--size", "832*480",
        "--ckpt_dir", "./Wan2.1-T2V-1.3B",
        "--sample_shift", "8",
        "--sample_guide_scale", "6",
        "--prompt", prompt,
        "--save_file", "generated_video.mp4"
    ]

    # Argument list with no shell, so the user prompt is passed verbatim.
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1
    )

    try:
        for line in iter(process.stdout.readline, ''):
            stripped_line = line.strip()
            if not stripped_line:
                continue

            # Video generation progress line: mirror it in its own bar.
            progress_match = progress_pattern.search(stripped_line)
            if progress_match:
                current = int(progress_match.group(2))
                total = int(progress_match.group(3))
                if gen_progress_bar is None:
                    gen_progress_bar = tqdm(total=total, desc="Video Generation", position=0,
                                            ncols=120, dynamic_ncols=True, leave=True)
                gen_progress_bar.update(current - gen_progress_bar.n)
                gen_progress_bar.refresh()
                continue

            tqdm.write(stripped_line)

            # INFO lines mark pipeline steps; skip the first few setup ones.
            if "INFO:" in stripped_line:
                if processed_steps < irrelevant_steps:
                    processed_steps += 1
                else:
                    msg = stripped_line.split("INFO:", 1)[1].strip()
                    start_sub_bar(msg)

        process.wait()
    finally:
        # If the loop exits early, do not leave the generator running or
        # the pipe open.
        if process.poll() is None:
            process.kill()
            process.wait()
        process.stdout.close()
        cancel_sub_bar()
        if gen_progress_bar is not None:
            gen_progress_bar.close()
        overall_bar.close()

    if process.returncode == 0:
        print("Command executed successfully.")
        return "generated_video.mp4"
    print("Error executing command.")
    raise Exception(f"Error executing command (exit code {process.returncode})")

# Minimal UI: a prompt box and a Submit button that renders infer()'s video.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Wan 2.1")
        prompt = gr.Textbox(label="Prompt")
        submit_btn = gr.Button("Submit")
        video_res = gr.Video(label="Generated Video")

    # Clicking Submit calls infer(prompt); the returned file path fills the video.
    submit_btn.click(fn=infer, inputs=[prompt], outputs=[video_res])

# Enable the request queue, then start the server. Errors are shown in the UI,
# and the API page and server-side rendering are disabled.
queued_demo = demo.queue()
queued_demo.launch(show_error=True, show_api=False, ssr_mode=False)