File size: 6,878 Bytes
4e424ea
ca753f0
4e424ea
3893fde
 
7bedcdd
4e424ea
 
 
 
 
 
 
 
ca753f0
73566e5
665534e
 
 
 
ae36822
935512c
665534e
4d169d8
 
 
 
665534e
935512c
665534e
935512c
665534e
4c0cc6b
933922d
3893fde
665534e
 
 
 
3893fde
933922d
665534e
2694c8d
933922d
 
 
2694c8d
 
 
 
ae36822
 
2694c8d
3893fde
2694c8d
ae36822
2694c8d
665534e
ae36822
 
 
3893fde
 
ae36822
 
 
933922d
665534e
4e424ea
665534e
4e424ea
 
 
 
 
f0f4c78
4e424ea
 
 
f0f4c78
 
 
 
 
2694c8d
f0f4c78
4e424ea
ca753f0
6c641ac
 
 
935512c
665534e
935512c
 
665534e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
935512c
 
665534e
 
 
 
 
 
 
 
 
1cfe5df
935512c
665534e
0561c55
665534e
a0044b5
 
2694c8d
4c0cc6b
665534e
935512c
 
 
665534e
 
 
 
 
3893fde
665534e
 
 
 
933922d
3893fde
665534e
 
 
 
 
 
3893fde
 
ca753f0
4c0cc6b
 
665534e
f0f4c78
665534e
 
 
 
 
 
 
935512c
2562fab
f0f4c78
4e424ea
d701afa
4e424ea
 
f0f4c78
4e424ea
 
 
 
c81f025
4e424ea
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import gradio as gr
import re 
import subprocess
import time
import threading
from tqdm import tqdm
from huggingface_hub import snapshot_download

# Fetch the Wan2.1 T2V 1.3B checkpoint from the Hugging Face Hub into
# ./Wan2.1-T2V-1.3B, the directory infer() passes to `generate` as --ckpt_dir.
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./Wan2.1-T2V-1.3B")

def infer(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate a video for *prompt* by running the ``generate`` module.

    The generator runs as a subprocess. Its combined stdout/stderr is parsed
    line by line to drive three tqdm bars, which ``gr.Progress(track_tqdm=True)``
    lets Gradio mirror in the UI:

    * level 1, "Overall Process": one tick per relevant pipeline step;
    * level 2, a time-based sub-step bar labelled with the latest INFO
      message (it only animates for up to 20 s; the log carries no real
      per-step progress);
    * level 3, "Video Generation": driven by log lines matching
      ``N%|...| i/total``.

    Args:
        prompt: Text prompt forwarded to ``generate`` via ``--prompt``.
        progress: Gradio progress tracker. It is not used directly, but
            ``track_tqdm=True`` makes Gradio pick up the tqdm bars below.

    Returns:
        The path of the generated video, ``"generated_video.mp4"``.

    Raises:
        Exception: if the subprocess exits with a non-zero status.
    """
    # Configuration:
    total_process_steps = 11  # Total steps (including irrelevant ones)
    irrelevant_steps = 4      # First 4 INFO messages get no sub-step bar
    # Relevant steps = 11 - 4 = 7 overall steps that will be shown
    relevant_steps = total_process_steps - irrelevant_steps

    # Overall process progress bar (level 1)
    overall_bar = tqdm(total=relevant_steps, desc="Overall Process", position=1,
                       ncols=120, dynamic_ncols=False, leave=True)
    processed_steps = 0  # number of leading INFO messages skipped so far

    # Regex to capture video generation progress lines (level 3)
    progress_pattern = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")
    video_progress_bar = None

    # State for the sub-step progress bar (level 2)
    current_sub_bar = None
    current_cancel_event = None
    # Must be re-entrant: cancel_sub_bar() acquires this lock and is also
    # called from code that already holds it. With a plain Lock, the second
    # relevant INFO message (or entering the video phase) deadlocked.
    sub_lock = threading.RLock()

    # True while video-generation progress lines are being consumed.
    video_phase = False

    def update_sub_bar(sub_bar, cancel_event):
        # Tick sub_bar once per second for up to 20 seconds. wait() returns
        # True as soon as the bar is cancelled, so the thread exits promptly
        # instead of sleeping out its remaining ticks.
        for _ in range(20):
            if cancel_event.wait(1):
                break
            with sub_lock:
                if not cancel_event.is_set() and sub_bar.n < sub_bar.total:
                    sub_bar.update(1)
                    sub_bar.refresh()

    def cancel_sub_bar():
        # Stop the active sub-step bar (if any), fill and close it, then
        # advance the overall bar by one step.
        nonlocal current_sub_bar, current_cancel_event
        with sub_lock:
            if current_cancel_event:
                current_cancel_event.set()
            if current_sub_bar:
                # Finish any remaining ticks so the bar ends at 100%.
                remaining = current_sub_bar.total - current_sub_bar.n
                if remaining > 0:
                    current_sub_bar.update(remaining)
                current_sub_bar.close()
                current_sub_bar = None
            overall_bar.update(1)
            overall_bar.refresh()
            current_cancel_event = None

    # Build the command (argument list, no shell, so the prompt is passed verbatim).
    command = [
        "python", "-u", "-m", "generate",  # -u forces unbuffered output
        "--task", "t2v-1.3B",
        "--size", "832*480",
        "--ckpt_dir", "./Wan2.1-T2V-1.3B",
        "--sample_shift", "8",
        "--sample_guide_scale", "6",
        "--prompt", prompt,
        "--save_file", "generated_video.mp4"
    ]

    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1
    )

    try:
        for line in iter(process.stdout.readline, ''):
            stripped_line = line.strip()
            if not stripped_line:
                continue

            # Video generation progress (level 3)
            progress_match = progress_pattern.search(stripped_line)
            if progress_match:
                current = int(progress_match.group(2))
                total = int(progress_match.group(3))
                if not video_phase:
                    # First progress line: close any active sub-step bar and
                    # start the video bar, sized from the log's total.
                    with sub_lock:
                        if current_sub_bar:
                            cancel_sub_bar()
                    video_phase = True
                    if video_progress_bar is None:
                        video_progress_bar = tqdm(total=total, desc="Video Generation", position=0,
                                                  ncols=120, dynamic_ncols=True, leave=True)
                video_progress_bar.update(current - video_progress_bar.n)
                video_progress_bar.refresh()
                # Video progress complete: finish the video phase.
                if video_progress_bar.n >= video_progress_bar.total:
                    video_phase = False
                    overall_bar.update(1)
                    overall_bar.refresh()
                    video_progress_bar.close()
                    video_progress_bar = None
                continue

            # Anything that is not an INFO message is just echoed.
            if "INFO:" not in stripped_line:
                tqdm.write(stripped_line)
                continue

            # INFO messages (level 2 sub-steps); msg is the text after "INFO:".
            msg = stripped_line.split("INFO:", 1)[1].strip()
            tqdm.write(stripped_line)

            # Skip the first `irrelevant_steps` INFO messages.
            if processed_steps < irrelevant_steps:
                processed_steps += 1
                continue

            # During the video phase, new INFO messages get no sub-step bar.
            if video_phase:
                continue

            with sub_lock:
                # Close the previous sub-step bar (counting it as done), then
                # start a new one labelled with this message.
                if current_sub_bar is not None:
                    cancel_sub_bar()
                current_cancel_event = threading.Event()
                current_sub_bar = tqdm(total=20, desc=msg, position=2,
                                       ncols=120, dynamic_ncols=False, leave=True)
                threading.Thread(
                    target=update_sub_bar,
                    args=(current_sub_bar, current_cancel_event),
                    daemon=True,
                ).start()

        process.wait()
    finally:
        # If parsing raised part-way, don't leave the generator running.
        if process.poll() is None:
            process.kill()
            process.wait()
        process.stdout.close()
        # Clean up any active sub-step and remaining bars.
        with sub_lock:
            if current_cancel_event:
                current_cancel_event.set()
            if current_sub_bar:
                cancel_sub_bar()
        if video_progress_bar:
            video_progress_bar.close()
        overall_bar.close()

    if process.returncode == 0:
        print("Command executed successfully.")
        return "generated_video.mp4"
    print("Error executing command.")
    raise Exception("Error executing command")

# UI: a single column holding the prompt box, the submit button and the
# video player that displays infer()'s output.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Wan 2.1")
        prompt = gr.Textbox(label="Prompt")
        submit_btn = gr.Button("Submit")
        video_res = gr.Video(label="Generated Video")

    # Clicking Submit runs infer(prompt) and shows the returned file path.
    submit_btn.click(fn=infer, inputs=[prompt], outputs=[video_res])

# Enable the request queue, then start the server.
demo.queue().launch(
    show_error=True,
    show_api=False,
    ssr_mode=False,
)