File size: 5,726 Bytes
4e424ea ca753f0 4e424ea 3893fde 7bedcdd 4e424ea ca753f0 73566e5 189dec4 2562fab ae36822 935512c ae36822 4d169d8 ae36822 935512c 3893fde 4c0cc6b 933922d 3893fde 933922d ae36822 933922d ae36822 933922d 3893fde ae36822 3893fde ae36822 3893fde ae36822 933922d 4e424ea ae36822 4e424ea f0f4c78 4e424ea f0f4c78 4e424ea ca753f0 6c641ac 935512c ae36822 935512c f20624c 4c0cc6b 3893fde f20624c 1cfe5df 935512c 417bb8d 0561c55 ae36822 a0044b5 ae36822 4c0cc6b 935512c 3893fde ae36822 3893fde ae36822 933922d 3893fde 933922d ae36822 933922d 3893fde ca753f0 4c0cc6b ae36822 f0f4c78 933922d ae36822 933922d 2562fab f20624c 935512c 2562fab f0f4c78 4e424ea d701afa 4e424ea f0f4c78 4e424ea c81f025 4e424ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import gradio as gr
import re
import subprocess
import time
import threading
from tqdm import tqdm
from huggingface_hub import snapshot_download
# Fetch the Wan2.1 text-to-video 1.3B checkpoint from the Hugging Face Hub into
# ./Wan2.1-T2V-1.3B (the directory infer() passes as --ckpt_dir).
# This runs once, at module import time, before the UI is built.
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./Wan2.1-T2V-1.3B")
def infer(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate a video for ``prompt`` with the Wan2.1 T2V-1.3B generator.

    Runs ``python -u -m generate`` as a subprocess, echoes its merged
    stdout/stderr, and turns that output into tqdm progress bars:

    * "Video Generation" mirrors the generator's own tqdm lines
      (e.g. ``10%|...| 5/50``).
    * "Overall Process" advances once per ``INFO:`` line after the first
      ``irrelevant_steps`` of them.
    * A per-step sub-bar ticks once a second (up to ``sub_bar_seconds``)
      while that step runs, as a rough liveness indicator.

    Args:
        prompt: Text prompt forwarded to the generator as ``--prompt``.
        progress: Gradio progress tracker. Not used directly; with
            ``track_tqdm=True`` Gradio mirrors the tqdm bars above in the UI.

    Returns:
        The path passed as ``--save_file``: ``"generated_video.mp4"``.

    Raises:
        Exception: If the generator exits with a non-zero status.
    """
    total_process_steps = 11
    irrelevant_steps = 4
    # Only steps 5 through 11 (i.e. 7 steps) count.
    relevant_steps = total_process_steps - irrelevant_steps
    # Maximum lifetime, in one-second ticks, of a per-step sub-bar.
    sub_bar_seconds = 20

    overall_bar = tqdm(total=relevant_steps, desc="Overall Process", position=1,
                       ncols=120, dynamic_ncols=False, leave=True)
    processed_steps = 0
    # Regex for detecting video generation progress lines (e.g. "10%|...| 5/50").
    progress_pattern = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")
    gen_progress_bar = None

    # State of the currently active per-step sub-bar (all None when idle).
    current_sub_bar = None
    current_sub_thread = None
    current_cancel_event = None
    # Re-entrant on purpose: the main loop holds this lock while calling
    # cancel_sub_bar(), which acquires it again. A plain Lock deadlocks there.
    sub_lock = threading.RLock()

    def update_sub_bar(sub_bar, cancel_event):
        """Tick ``sub_bar`` once per second until it is full or cancelled."""
        for _ in range(sub_bar_seconds):
            # Event.wait() returns True as soon as the event is set, so a
            # cancelled ticker exits at once instead of after a full sleep.
            if cancel_event.wait(1):
                break
            sub_bar.update(1)
            sub_bar.refresh()
        # Closing and the overall-bar update are handled by cancel_sub_bar().

    def cancel_sub_bar():
        """Stop the active sub-bar, fill it, close it and advance the overall bar.

        Does nothing to the bars when no sub-bar is active, so it is safe to
        call unconditionally during cleanup.
        """
        nonlocal current_sub_bar, current_sub_thread, current_cancel_event
        with sub_lock:
            if current_cancel_event is not None:
                current_cancel_event.set()
            if current_sub_thread is not None:
                current_sub_thread.join(timeout=1)
                current_sub_thread = None
            if current_sub_bar is not None:
                # Complete any remaining ticks so the finished step reads 100%.
                remaining = current_sub_bar.total - current_sub_bar.n
                if remaining > 0:
                    current_sub_bar.update(remaining)
                current_sub_bar.close()
                current_sub_bar = None
                # Update overall progress by one step.
                overall_bar.update(1)
                overall_bar.refresh()
            current_cancel_event = None

    command = [
        "python", "-u", "-m", "generate",  # -u: unbuffered, so lines arrive live
        "--task", "t2v-1.3B",
        "--size", "832*480",
        "--ckpt_dir", "./Wan2.1-T2V-1.3B",
        "--sample_shift", "8",
        "--sample_guide_scale", "6",
        "--prompt", prompt,
        "--save_file", "generated_video.mp4",
    ]
    # text=True enables universal newlines, so "\r"-terminated progress
    # updates from the child are read back as separate lines.
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,  # line-buffered
    )
    try:
        for line in iter(process.stdout.readline, ""):
            stripped_line = line.strip()
            if not stripped_line:
                continue

            # Generator progress lines drive the "Video Generation" bar only.
            progress_match = progress_pattern.search(stripped_line)
            if progress_match:
                current = int(progress_match.group(2))
                total = int(progress_match.group(3))
                if gen_progress_bar is None:
                    gen_progress_bar = tqdm(total=total, desc="Video Generation", position=0,
                                            ncols=120, dynamic_ncols=True, leave=True)
                gen_progress_bar.update(current - gen_progress_bar.n)
                gen_progress_bar.refresh()
                continue

            # Every other non-empty line is logged verbatim.
            tqdm.write(stripped_line)
            if "INFO:" not in stripped_line:
                continue

            # The first `irrelevant_steps` INFO lines are not counted as steps.
            if processed_steps < irrelevant_steps:
                processed_steps += 1
                continue

            # The text after "INFO:" labels the new step's sub-bar.
            msg = stripped_line.split("INFO:", 1)[1].strip()
            with sub_lock:
                # A new step started: finish the previous step's sub-bar first.
                cancel_sub_bar()
                current_sub_bar = tqdm(total=sub_bar_seconds, desc=msg, position=2,
                                       ncols=120, dynamic_ncols=False, leave=True)
                current_cancel_event = threading.Event()
                current_sub_thread = threading.Thread(
                    target=update_sub_bar,
                    args=(current_sub_bar, current_cancel_event),
                    daemon=True,
                )
                current_sub_thread.start()
        process.wait()
    finally:
        # Runs on success and on failure alike, so no child process, ticker
        # thread or progress bar is left behind if anything above raises.
        if process.poll() is None:
            process.kill()
            process.wait()
        process.stdout.close()
        cancel_sub_bar()
        if gen_progress_bar is not None:
            gen_progress_bar.close()
        overall_bar.close()

    if process.returncode == 0:
        print("Command executed successfully.")
        return "generated_video.mp4"
    print("Error executing command.")
    raise Exception("Error executing command")
# Minimal UI: a prompt box and a button that runs infer() and shows the video.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Wan 2.1")
        prompt = gr.Textbox(label="Prompt")
        submit_btn = gr.Button("Submit")
        video_res = gr.Video(label="Generated Video")

    submit_btn.click(
        fn=infer,
        inputs=[prompt],
        outputs=[video_res],
    )

# Enable request queuing and start the server. show_error=True surfaces
# exceptions raised by infer() in the browser.
demo.queue().launch(show_error=True, show_api=False, ssr_mode=False)