multimodalart (HF Staff) committed
Commit 8268b44 · verified · 1 Parent(s): 1e531a7

Update app.py

Files changed (1): app.py (+57, -196)
app.py CHANGED
@@ -6,296 +6,157 @@ import gradio as gr
  import tempfile
  import spaces
  from huggingface_hub import hf_hub_download
- import logging
  import numpy as np
  from PIL import Image
- import random # Added for random seed generation
+ import random

- # --- Global Model Loading & LoRA Handling ---
  MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
  LORA_REPO_ID = "Kijai/WanVideo_comfy"
  LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"

- # Configure logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- # --- Model Loading ---
- logger.info(f"Loading Image Encoder for {MODEL_ID}...")
- image_encoder = CLIPVisionModel.from_pretrained(
-     MODEL_ID,
-     subfolder="image_encoder",
-     torch_dtype=torch.float32 # Using float32 for image encoder as sometimes bfloat16/float16 can be problematic
- )
-
- logger.info(f"Loading VAE for {MODEL_ID}...")
- vae = AutoencoderKLWan.from_pretrained(
-     MODEL_ID,
-     subfolder="vae",
-     torch_dtype=torch.float32 # Using float32 for VAE for precision
- )
- logger.info(f"Loading Pipeline {MODEL_ID}...")
+ image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32)
+ vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
  pipe = WanImageToVideoPipeline.from_pretrained(
-     MODEL_ID,
-     vae=vae,
-     image_encoder=image_encoder,
-     torch_dtype=torch.bfloat16 # Main pipeline can use bfloat16 for speed/memory
- )
- flow_shift = 8.0
- pipe.scheduler = UniPCMultistepScheduler.from_config(
-     pipe.scheduler.config, flow_shift=flow_shift
+     MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
  )
- logger.info("Moving pipeline to CUDA...")
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
  pipe.to("cuda")

- # --- LoRA Loading ---
- logger.info(f"Downloading LoRA {LORA_FILENAME} from {LORA_REPO_ID}...")
  causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
-
- logger.info("Loading LoRA weights...")
  pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
- logger.info("Setting LoRA adapter...")
- pipe.set_adapters(["causvid_lora"], adapter_weights=[1.0])
+ pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])

- # --- Constants for Dimension Calculation ---
  MOD_VALUE = 32
- MOD_VALUE_H = MOD_VALUE_W = MOD_VALUE
-
  DEFAULT_H_SLIDER_VALUE = 512
  DEFAULT_W_SLIDER_VALUE = 896
+ NEW_FORMULA_MAX_AREA = 480.0 * 832.0

- # New fixed max_area for the calculation formula
- NEW_FORMULA_MAX_AREA = float(480 * 832)
-
- SLIDER_MIN_H = 128
- SLIDER_MAX_H = 896
- SLIDER_MIN_W = 128
- SLIDER_MAX_W = 896
-
- # --- Constant for Seed ---
+ SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
+ SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
  MAX_SEED = np.iinfo(np.int32).max

+ FIXED_FPS = 24
+ MIN_FRAMES_MODEL = 8
+ MAX_FRAMES_MODEL = 81
+
  def _calculate_new_dimensions_wan(pil_image: Image.Image, mod_val: int, calculation_max_area: float,
                                    min_slider_h: int, max_slider_h: int,
                                    min_slider_w: int, max_slider_w: int,
                                    default_h: int, default_w: int) -> tuple[int, int]:
      orig_w, orig_h = pil_image.size
-
-     if orig_w <= 0 or orig_h <= 0: # Changed to <= 0 for robustness
-         logger.warning(f"Uploaded image has non-positive width or height ({orig_w}x{orig_h}). Using default slider dimensions.")
+     if orig_w <= 0 or orig_h <= 0:
          return default_h, default_w

      aspect_ratio = orig_h / orig_w

-     sqrt_h_term = np.sqrt(calculation_max_area * aspect_ratio)
-     sqrt_w_term = np.sqrt(calculation_max_area / aspect_ratio)
-
-     calc_h = round(sqrt_h_term) // mod_val * mod_val
-     calc_w = round(sqrt_w_term) // mod_val * mod_val
-
-     calc_h = mod_val if calc_h < mod_val else calc_h
-     calc_w = mod_val if calc_w < mod_val else calc_w
-
-     effective_min_h = min_slider_h
-     effective_min_w = min_slider_w
+     calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
+     calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))

-     effective_max_h_from_slider = (max_slider_h // mod_val) * mod_val
-     effective_max_w_from_slider = (max_slider_w // mod_val) * mod_val
+     calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
+     calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
+
+     new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
+     new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))

-     new_h = int(np.clip(calc_h, effective_min_h, effective_max_h_from_slider))
-     new_w = int(np.clip(calc_w, effective_min_w, effective_max_w_from_slider))
-
-     logger.info(f"Auto-dim: Original {orig_w}x{orig_h} (AR: {aspect_ratio:.2f}). Max Area for calc: {calculation_max_area}.")
-     logger.info(f"Auto-dim: Sqrt terms HxW: {sqrt_h_term:.0f}x{sqrt_w_term:.0f}. Calculated (round(sqrt_term)//{mod_val}*{mod_val}): {calc_h}x{calc_w}.")
-     logger.info(f"Auto-dim: Clamped HxW: {new_h}x{new_w} (Effective H_range:[{effective_min_h}-{effective_max_h_from_slider}], Effective W_range:[{effective_min_w}-{effective_max_w_from_slider}]).")
-
      return new_h, new_w

  def handle_image_upload_for_dims_wan(uploaded_pil_image: Image.Image | None, current_h_val: int, current_w_val: int):
      if uploaded_pil_image is None:
-         logger.info("Image cleared. Resetting dimensions to default slider values.")
          return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
      try:
          new_h, new_w = _calculate_new_dimensions_wan(
-             uploaded_pil_image,
-             MOD_VALUE,
-             NEW_FORMULA_MAX_AREA, # Use the globally defined max_area for the new formula
-             SLIDER_MIN_H, SLIDER_MAX_H,
-             SLIDER_MIN_W, SLIDER_MAX_W,
+             uploaded_pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
+             SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
              DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
          )
          return gr.update(value=new_h), gr.update(value=new_w)
      except Exception as e:
-         logger.error(f"Error auto-adjusting H/W from image: {e}", exc_info=True)
-         # Fallback to default slider values on error, as in the original code
+         gr.Warning("Error attempting to calculate new dimensions")
          return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)

-
- # --- Gradio Interface Function ---
  @spaces.GPU
  def generate_video(input_image: Image.Image, prompt: str, negative_prompt: str,
                     height: int, width: int, duration_seconds: float,
                     guidance_scale: float, steps: int,
                     seed: int, randomize_seed: bool,
                     progress=gr.Progress(track_tqdm=True)):
+
      if input_image is None:
          raise gr.Error("Please upload an input image.")

-     # Constants for frame calculation
-     FIXED_FPS = 24
-     MIN_FRAMES_MODEL = 8
-     MAX_FRAMES_MODEL = 81
-
-     logger.info("Starting video generation...")
-     logger.info(f" Input Image: Uploaded (Original size: {input_image.size if input_image else 'N/A'})")
-     logger.info(f" Prompt: {prompt}")
-     logger.info(f" Negative Prompt: {negative_prompt if negative_prompt else 'None'}")
-     logger.info(f" Target Output Height: {height}, Target Output Width: {width}")
-
-     target_height = int(height)
-     target_width = int(width)
-     guidance_scale_val = float(guidance_scale)
-     steps_val = int(steps)
-
-     num_frames_for_pipeline = int(round(duration_seconds * FIXED_FPS))
-     num_frames_for_pipeline = max(MIN_FRAMES_MODEL, min(MAX_FRAMES_MODEL, num_frames_for_pipeline))
-     if num_frames_for_pipeline < MIN_FRAMES_MODEL:
-         num_frames_for_pipeline = MIN_FRAMES_MODEL
-
-     logger.info(f" Duration: {duration_seconds:.1f}s, Fixed FPS (conditioning & export): {FIXED_FPS}")
-     logger.info(f" Calculated Num Frames: {num_frames_for_pipeline} (clamped to [{MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL}])")
-     logger.info(f" Guidance Scale: {guidance_scale_val}, Steps: {steps_val}")
+     target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
+     target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)

-     # Seed logic
-     current_seed = int(seed)
-     if randomize_seed:
-         current_seed = random.randint(0, MAX_SEED)
-     logger.info(f" Initial Seed: {seed}, Randomize: {randomize_seed}, Using Seed: {current_seed}")
-
-
-     if target_height % MOD_VALUE_H != 0:
-         logger.warning(f"Height {target_height} is not a multiple of {MOD_VALUE_H}. Adjusting...")
-         target_height = (target_height // MOD_VALUE_H) * MOD_VALUE_H
-     if target_width % MOD_VALUE_W != 0:
-         logger.warning(f"Width {target_width} is not a multiple of {MOD_VALUE_W}. Adjusting...")
-         target_width = (target_width // MOD_VALUE_W) * MOD_VALUE_W
-
-     target_height = max(MOD_VALUE_H, target_height if target_height > 0 else MOD_VALUE_H)
-     target_width = max(MOD_VALUE_W, target_width if target_width > 0 else MOD_VALUE_W)
-
+     num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+
+     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

-     resized_image = input_image.resize((target_width, target_height))
-     logger.info(f" Input image resized to: {resized_image.size} for pipeline input.")
+     resized_image = input_image.resize((target_w, target_h))

      with torch.inference_mode():
          output_frames_list = pipe(
-             image=resized_image,
-             prompt=prompt,
-             negative_prompt=negative_prompt,
-             height=target_height,
-             width=target_width,
-             num_frames=num_frames_for_pipeline,
-             guidance_scale=guidance_scale_val,
-             num_inference_steps=steps_val,
-             generator=torch.Generator(device="cuda").manual_seed(current_seed) # Use current_seed
+             image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
+             height=target_h, width=target_w, num_frames=num_frames,
+             guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
+             generator=torch.Generator(device="cuda").manual_seed(current_seed)
          ).frames[0]

      with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
          video_path = tmpfile.name
-
      export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
-     logger.info(f"Video successfully generated and saved to {video_path}")
      return video_path

- # --- Gradio UI Definition ---
  default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
  default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
- penguin_image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"

  with gr.Blocks() as demo:
-     gr.Markdown(f"""
-     # Fast 4 steps Wan 2.1 I2V (14B) with CausVid LoRA
-     """)
+     gr.Markdown("# Fast 4 steps Wan 2.1 I2V (14B) with CausVid LoRA")
      with gr.Row():
          with gr.Column():
-             input_image_component = gr.Image(type="pil", label="Input Image (will be resized to target H/W)")
-             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v, lines=3)
-             duration_seconds_input = gr.Slider(minimum=0.4, maximum=3.3, step=0.1, value=1.7, label="Duration (seconds)", info="The CausVid LoRA was trained on 24fps, Wan has 81 maximum frames limit, limiting the maximum to 3.3s")
+             input_image_component = gr.Image(type="pil", label="Input Image (auto-resized to target H/W)")
+             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
+             duration_seconds_input = gr.Slider(minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1), maximum=round(MAX_FRAMES_MODEL/FIXED_FPS,1), step=0.1, value=2, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")

              with gr.Accordion("Advanced Settings", open=False):
-                 negative_prompt_input = gr.Textbox(
-                     label="Negative Prompt (Optional)",
-                     value=default_negative_prompt,
-                     lines=3
-                 )
-                 # --- Added Seed Controls ---
-                 seed_input = gr.Slider(
-                     label="Seed",
-                     minimum=0,
-                     maximum=MAX_SEED,
-                     step=1,
-                     value=42, # Default seed value
-                     interactive=True
-                 )
-                 randomize_seed_checkbox = gr.Checkbox(
-                     label="Randomize seed",
-                     value=True, # Default to randomize
-                     interactive=True
-                 )
-                 # --- End of Added Seed Controls ---
+                 negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
+                 seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
+                 randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
                  with gr.Row():
                      height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
                      width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
-
                  steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps")
-                 guidance_scale_input = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale", visible=False)
+                 guidance_scale_input = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale", visible=False)

              generate_button = gr.Button("Generate Video", variant="primary")
-
          with gr.Column():
-             video_output = gr.Video(label="Generated Video", interactive=False)
+             video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)

      input_image_component.upload(
          fn=handle_image_upload_for_dims_wan,
          inputs=[input_image_component, height_input, width_input],
          outputs=[height_input, width_input]
      )
-     input_image_component.clear(
+
+     input_image_component.clear(
          fn=handle_image_upload_for_dims_wan,
          inputs=[input_image_component, height_input, width_input],
          outputs=[height_input, width_input]
      )
-
-     inputs_for_click_and_examples = [
-         input_image_component,
-         prompt_input,
-         negative_prompt_input,
-         height_input,
-         width_input,
-         duration_seconds_input,
-         guidance_scale_input,
-         steps_slider,
-         seed_input, # Added seed_input
-         randomize_seed_checkbox # Added randomize_seed_checkbox
+
+     ui_inputs = [
+         input_image_component, prompt_input, negative_prompt_input,
+         height_input, width_input, duration_seconds_input,
+         guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox
      ]
-
-     generate_button.click(
-         fn=generate_video,
-         inputs=inputs_for_click_and_examples,
-         outputs=video_output
-     )
+     generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=video_output)

      gr.Examples(
-         examples=[
-             # Added seed (e.g., 42) and randomize_seed (e.g., True) to examples
-             ["peng.png", "a penguin playfully dancing in the snow, Antarctica", default_negative_prompt, 896, 512, 2, 1.0, 4, 42, False],
-             ["forg.jpg", "the frog jumps around", default_negative_prompt, 448, 832, 2, 1.0, 4, 123, False],
+         examples=[
+             ["peng.png", "a penguin playfully dancing in the snow, Antarctica", default_negative_prompt, 896, 512, 2.0, 1.0, 4, 42, False],
+             ["forg.jpg", "the frog jumps around", default_negative_prompt, 448, 832, 2.0, 1.0, 4, 123, False],
          ],
-         inputs=inputs_for_click_and_examples,
-         outputs=video_output,
-         fn=generate_video,
-         cache_examples="lazy"
+         inputs=ui_inputs, outputs=video_output, fn=generate_video, cache_examples="lazy"
      )

  if __name__ == "__main__":
-     demo.queue().launch(share=True, debug=True)
+     demo.queue().launch()
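
For reference, a minimal standalone sketch of the auto-dimension math this commit keeps in _calculate_new_dimensions_wan: scale the upload to the 480x832 pixel budget at its own aspect ratio, snap each side down to a multiple of 32, and clamp to the 128-896 slider range. The fit_dims helper name and the 1280x720 input below are illustrative only, not part of app.py:

import numpy as np

MOD_VALUE = 32                     # all output dims must be multiples of 32
MAX_AREA = 480.0 * 832.0           # pixel budget (NEW_FORMULA_MAX_AREA in app.py)
SLIDER_MIN, SLIDER_MAX = 128, 896  # slider range from app.py (896 is already a multiple of 32)

def fit_dims(orig_w: int, orig_h: int) -> tuple[int, int]:
    # Same math as the app: keep the aspect ratio, target ~MAX_AREA pixels,
    # snap each side down to a multiple of MOD_VALUE, then clamp to the slider range.
    aspect_ratio = orig_h / orig_w
    calc_h = round(np.sqrt(MAX_AREA * aspect_ratio))
    calc_w = round(np.sqrt(MAX_AREA / aspect_ratio))
    calc_h = max(MOD_VALUE, (calc_h // MOD_VALUE) * MOD_VALUE)
    calc_w = max(MOD_VALUE, (calc_w // MOD_VALUE) * MOD_VALUE)
    new_h = int(np.clip(calc_h, SLIDER_MIN, SLIDER_MAX))
    new_w = int(np.clip(calc_w, SLIDER_MIN, SLIDER_MAX))
    return new_h, new_w

print(fit_dims(1280, 720))  # a 16:9 upload -> (448, 832), i.e. the 480x832 budget snapped to /32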