Commit 30b4e47 (verified), committed by ginipick
Parent: 52f499a

Update app.py

Files changed (1):
  app.py: +196 -37
app.py CHANGED
@@ -56,6 +56,13 @@ pipeline = wan.WanTI2V(
 print("Pipeline initialized and ready.")
 
 # --- Helper Functions ---
+def clear_gpu_memory():
+    """Clear GPU memory more thoroughly"""
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+    gc.collect()
+
 def select_best_size_for_image(image, available_sizes):
     """Select the size option with aspect ratio closest to the input image."""
     if image is None:
@@ -90,6 +97,23 @@ def handle_image_upload(image):
 
     return gr.update(value=best_size)
 
+def validate_inputs(image, prompt, duration_seconds):
+    """Validate user inputs"""
+    errors = []
+
+    if not prompt or len(prompt.strip()) < 5:
+        errors.append("Prompt must be at least 5 characters long.")
+
+    if image is not None:
+        img = Image.fromarray(image)
+        if img.size[0] * img.size[1] > 4096 * 4096:
+            errors.append("Image size is too large (maximum 4096x4096).")
+
+    if duration_seconds > 5.0 and image is None:
+        errors.append("Videos longer than 5 seconds require an input image.")
+
+    return errors
+
 def get_duration(image,
                  prompt,
                  size,
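For reference, a quick spot-check of the new validator (hypothetical usage, not part of the commit; assumes `validate_inputs` is importable from `app.py`):

```python
from app import validate_inputs  # assumes app.py is importable in this environment

# A long request with no input image should trip the duration rule.
errors = validate_inputs(image=None, prompt="ocean waves at sunset", duration_seconds=10.0)
assert errors == ["Videos longer than 5 seconds require an input image."]

# A too-short prompt should trip the length rule.
errors = validate_inputs(image=None, prompt="cat", duration_seconds=2.0)
assert "Prompt must be at least 5 characters long." in errors
```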
@@ -107,6 +131,14 @@ def get_duration(image,
     else:
         return 90
 
+def apply_template(template, current_prompt):
+    """Apply prompt template"""
+    if "{subject}" in template:
+        # Extract the main subject from current prompt (simple heuristic)
+        subject = current_prompt.split(",")[0] if "," in current_prompt else current_prompt
+        return template.replace("{subject}", subject)
+    return template + " " + current_prompt
+
 # --- 2. Gradio Inference Function ---
 @spaces.GPU(duration=get_duration)
 def generate_video(
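To make the `{subject}` heuristic in `apply_template` concrete, a hypothetical call (again assuming the helper above is in scope):

```python
from app import apply_template  # assumes the helper above is importable

template = "cinematic shot of {subject}, professional lighting, smooth camera movement, 4k quality"

# "{subject}" is present, so everything before the prompt's first comma becomes the subject.
print(apply_template(template, "a red fox, running through snow"))
# -> cinematic shot of a red fox, professional lighting, smooth camera movement, 4k quality

# Without a "{subject}" placeholder, the template is simply prepended to the prompt.
print(apply_template("slow motion capture", "a hummingbird"))
# -> slow motion capture a hummingbird
```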
@@ -121,9 +153,18 @@ def generate_video(
     progress=gr.Progress(track_tqdm=True)
 ):
     """The main function to generate video, called by the Gradio interface."""
+    # Validate inputs
+    errors = validate_inputs(image, prompt, duration_seconds)
+    if errors:
+        raise gr.Error("\n".join(errors))
+
+    progress(0, desc="Setting up...")
+
     if seed == -1:
         seed = random.randint(0, sys.maxsize)
 
+    progress(0.1, desc="Processing image...")
+
     input_image = None
     if image is not None:
         input_image = Image.fromarray(image).convert("RGB")
@@ -134,44 +175,110 @@ def generate_video(
     # Calculate number of frames based on duration
     num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
 
-    video_tensor = pipeline.generate(
-        input_prompt=prompt,
-        img=input_image,  # Pass None for T2V, Image for I2V
-        size=SIZE_CONFIGS[size],
-        max_area=MAX_AREA_CONFIGS[size],
-        frame_num=num_frames,  # Use calculated frames instead of cfg.frame_num
-        shift=shift,
-        sample_solver='unipc',
-        sampling_steps=int(sampling_steps),
-        guide_scale=guide_scale,
-        seed=seed,
-        offload_model=True
-    )
+    progress(0.2, desc="Generating video...")
+
+    try:
+        video_tensor = pipeline.generate(
+            input_prompt=prompt,
+            img=input_image,  # Pass None for T2V, Image for I2V
+            size=SIZE_CONFIGS[size],
+            max_area=MAX_AREA_CONFIGS[size],
+            frame_num=num_frames,  # Use calculated frames instead of cfg.frame_num
+            shift=shift,
+            sample_solver='unipc',
+            sampling_steps=int(sampling_steps),
+            guide_scale=guide_scale,
+            seed=seed,
+            offload_model=True
+        )
 
-    # Save the video to a temporary file
-    video_path = cache_video(
-        tensor=video_tensor[None],  # Add a batch dimension
-        save_file=None,  # cache_video will create a temp file
-        fps=cfg.sample_fps,
-        normalize=True,
-        value_range=(-1, 1)
-    )
-    del video_tensor
-    gc.collect()
+        progress(0.9, desc="Saving video...")
+
+        # Save the video to a temporary file
+        video_path = cache_video(
+            tensor=video_tensor[None],  # Add a batch dimension
+            save_file=None,  # cache_video will create a temp file
+            fps=cfg.sample_fps,
+            normalize=True,
+            value_range=(-1, 1)
+        )
+
+        progress(1.0, desc="Complete!")
+
+    except torch.cuda.OutOfMemoryError:
+        clear_gpu_memory()
+        raise gr.Error("GPU out of memory. Please try with lower settings.")
+    except Exception as e:
+        raise gr.Error(f"Video generation failed: {str(e)}")
+    finally:
+        if 'video_tensor' in locals():
+            del video_tensor
+        clear_gpu_memory()
+
     return video_path
 
 
 # --- 3. Gradio Interface ---
-css = ".gradio-container {max-width: 1100px !important; margin: 0 auto} #output_video {height: 500px;} #input_image {height: 500px;}"
+css = """
+.gradio-container {max-width: 1100px !important; margin: 0 auto}
+#output_video {height: 500px;}
+#input_image {height: 500px;}
+.template-btn {margin: 2px !important;}
+"""
+
+# Default prompt with motion emphasis
+DEFAULT_PROMPT = "Generate a video with smooth and natural movement. Objects should have visible motion while maintaining fluid transitions."
+
+# Prompt templates
+templates = {
+    "Cinematic": "cinematic shot of {subject}, professional lighting, smooth camera movement, 4k quality",
+    "Animation": "animated style {subject}, vibrant colors, fluid motion, dynamic movement",
+    "Nature": "nature documentary footage of {subject}, wildlife photography, natural movement",
+    "Slow Motion": "slow motion capture of {subject}, high speed camera, detailed motion",
+    "Action": "dynamic action shot of {subject}, fast paced movement, energetic motion"
+}
 
 with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
-    gr.Markdown("# Wan 2.2 TI2V 5B")
-    gr.Markdown("generate high quality videos using **Wan 2.2 5B Text-Image-to-Video model**,[[model]](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B),[[paper]](https://arxiv.org/abs/2503.20314)")
+    gr.Markdown("""
+    # Wan 2.2 TI2V Enhanced
+
+    Generate high quality videos using the **Wan 2.2 5B Text-Image-to-Video model**
+    [[model]](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B), [[paper]](https://arxiv.org/abs/2503.20314)
+
+    ### 💡 Tips for best results:
+    - 🖼️ Upload an image for better control over the video content
+    - ⏱️ Longer videos require more processing time
+    - 🎯 Be specific and descriptive in your prompts
+    - 🎬 Include motion-related keywords for dynamic videos
+    """)
 
     with gr.Row():
         with gr.Column(scale=2):
             image_input = gr.Image(type="numpy", label="Input Image (Optional)", elem_id="input_image")
-            prompt_input = gr.Textbox(label="Prompt", value="A beautiful waterfall in a lush jungle, cinematic.", lines=3)
+            prompt_input = gr.Textbox(
+                label="Prompt",
+                value=DEFAULT_PROMPT,
+                lines=3,
+                placeholder="Describe the video you want to generate..."
+            )
+
+            # Prompt templates section
+            with gr.Accordion("Prompt Templates", open=False):
+                gr.Markdown("Click a template to apply it to your prompt:")
+                with gr.Row():
+                    template_buttons = {}
+                    for name, template in templates.items():
+                        btn = gr.Button(name, size="sm", elem_classes=["template-btn"])
+                        template_buttons[name] = (btn, template)
+
+            # Connect template buttons (bind each template via a default argument;
+            # Gradio passes the current prompt value as the positional argument)
+            for name, (btn, template) in template_buttons.items():
+                btn.click(
+                    fn=lambda p, t=template: apply_template(t, p),
+                    inputs=[prompt_input],
+                    outputs=prompt_input
+                )
+
             duration_input = gr.Slider(
                 minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
                 maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
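The try/except/finally above pairs `torch.cuda.OutOfMemoryError` with the new `clear_gpu_memory()` helper. A minimal, self-contained sketch of the same recovery pattern (hypothetical illustration, not part of the commit; requires a CUDA device):

```python
import gc
import torch

def clear_gpu_memory():
    """Same pattern as the commit: drop cached blocks and CUDA IPC handles."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()

try:
    # Deliberately over-allocate to trigger an out-of-memory error.
    x = torch.empty((1 << 40,), device="cuda")
except torch.cuda.OutOfMemoryError:
    clear_gpu_memory()
    print("Recovered: CUDA caches cleared after OOM")
```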
@@ -180,18 +287,57 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
                 label="Duration (seconds)",
                 info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps."
             )
-            size_input = gr.Dropdown(label="Output Resolution", choices=list(SUPPORTED_SIZES[TASK_NAME]), value="704*1280")
+            size_input = gr.Dropdown(
+                label="Output Resolution",
+                choices=list(SUPPORTED_SIZES[TASK_NAME]),
+                value="704*1280"
+            )
+
         with gr.Column(scale=2):
             video_output = gr.Video(label="Generated Video", elem_id="output_video")
 
-
+            # Status indicators
+            with gr.Row():
+                status_text = gr.Textbox(
+                    label="Status",
+                    value="Ready",
+                    interactive=False,
+                    max_lines=1
+                )
+
     with gr.Accordion("Advanced Settings", open=False):
-        steps_input = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=38, step=1)
-        scale_input = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=cfg.sample_guide_scale, step=0.1)
-        shift_input = gr.Slider(label="Sample Shift", minimum=1.0, maximum=20.0, value=cfg.sample_shift, step=0.1)
-        seed_input = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
+        steps_input = gr.Slider(
+            label="Sampling Steps",
+            minimum=10,
+            maximum=50,
+            value=38,
+            step=1,
+            info="Higher values = better quality but slower"
+        )
+        scale_input = gr.Slider(
+            label="Guidance Scale",
+            minimum=1.0,
+            maximum=10.0,
+            value=cfg.sample_guide_scale,
+            step=0.1,
+            info="Higher values = closer to prompt but less creative"
+        )
+        shift_input = gr.Slider(
+            label="Sample Shift",
+            minimum=1.0,
+            maximum=20.0,
+            value=cfg.sample_shift,
+            step=0.1,
+            info="Affects the sampling process dynamics"
+        )
+        seed_input = gr.Number(
+            label="Seed (-1 for random)",
+            value=-1,
+            precision=0,
+            info="Use same seed for reproducible results"
+        )
 
-    run_button = gr.Button("Generate Video", variant="primary")
+    run_button = gr.Button("Generate Video", variant="primary", size="lg")
 
     # Add image upload handler
     image_input.upload(
@@ -206,12 +352,25 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
         outputs=[size_input]
     )
 
+    # Update status when generating
+    def update_status_and_generate(*args):
+        status_text.value = "Generating..."
+        try:
+            result = generate_video(*args)
+            status_text.value = "Complete!"
+            return result
+        except Exception as e:
+            status_text.value = "Error occurred"
+            raise e
+
     example_image_path = os.path.join(os.path.dirname(__file__), "examples/i2v_input.JPG")
     gr.Examples(
         examples=[
-            [example_image_path, "The cat removes the glasses from its eyes.", "1280*704", 1.5],
-            [None, "A cinematic shot of a boat sailing on a calm sea at sunset.", "1280*704", 2.0],
-            [None, "Drone footage flying over a futuristic city with flying cars.", "1280*704", 2.0],
+            [example_image_path, "The cat removes the glasses from its eyes with smooth motion.", "1280*704", 1.5],
+            [None, "A cinematic shot of a boat sailing on calm waves with gentle rocking motion at sunset.", "1280*704", 2.0],
+            [None, "Drone footage flying smoothly over a futuristic city with flying cars in continuous motion.", "1280*704", 2.0],
+            [None, DEFAULT_PROMPT + " A waterfall cascading down rocks.", "704*1280", 2.5],
+            [None, DEFAULT_PROMPT + " Birds flying across a cloudy sky.", "1280*704", 3.0],
         ],
         inputs=[image_input, prompt_input, size_input, duration_input],
         outputs=video_output,
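Note that `update_status_and_generate` assigns to `status_text.value` directly, which does not refresh an already-rendered component in Gradio; status changes normally have to be returned (or yielded) as outputs. A hypothetical sketch of that pattern, assuming `generate_video` and `status_text` from the hunks above (the commit's own `run_button.click` wiring is not shown in this diff):

```python
import gradio as gr

# Hypothetical alternative (not part of the commit): stream status via outputs.
# Gradio event handlers may be generators; each yield updates the declared outputs.
def generate_with_status(*args):
    yield gr.update(value="Generating..."), None      # interim status, no video yet
    video_path = generate_video(*args)
    yield gr.update(value="Complete!"), video_path    # final status plus result

# Wired up with both components as outputs, e.g.:
# run_button.click(fn=generate_with_status, inputs=[...], outputs=[status_text, video_output])
```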
 