Commit ab273c0 (verified) · multimodalart (HF Staff) · 1 parent: 93c7ebe

attempt resize

Files changed (1): app.py (+140 -58)

app.py
@@ -16,20 +16,26 @@ import random
 import gc
 from optimization import optimize_pipeline_

-
 MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"

-LANDSCAPE_WIDTH = 832
-LANDSCAPE_HEIGHT = 480
+# Dynamic sizing parameters
+MOD_VALUE = 32
+DEFAULT_H = 480
+DEFAULT_W = 832
+MAX_AREA = 480.0 * 832.0  # Maximum area for resolution calculation
+
+# Slider ranges for manual adjustment
+SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
+SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
+
 MAX_SEED = np.iinfo(np.int32).max

 FIXED_FPS = 16
 MIN_FRAMES_MODEL = 8
 MAX_FRAMES_MODEL = 81

-MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS,1)
-MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS,1)
-
+MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS, 1)
+MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS, 1)

 pipe = WanImageToVideoPipeline.from_pretrained(MODEL_ID,
     transformer=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
@@ -50,40 +56,63 @@ for i in range(3):
     torch.cuda.synchronize()
     torch.cuda.empty_cache()

+# Optimize with default dimensions for initial load
 optimize_pipeline_(pipe,
-    image=Image.new('RGB', (LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT)),
+    image=Image.new('RGB', (DEFAULT_W, DEFAULT_H)),
     prompt='prompt',
-    height=LANDSCAPE_HEIGHT,
-    width=LANDSCAPE_WIDTH,
+    height=DEFAULT_H,
+    width=DEFAULT_W,
     num_frames=MAX_FRAMES_MODEL,
 )

-
 default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
 default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"


-def resize_image(image: Image.Image) -> Image.Image:
-    if image.height > image.width:
-        transposed = image.transpose(Image.Transpose.ROTATE_90)
-        resized = resize_image_landscape(transposed)
-        return resized.transpose(Image.Transpose.ROTATE_270)
-    return resize_image_landscape(image)
-
-
-def resize_image_landscape(image: Image.Image) -> Image.Image:
-    target_aspect = LANDSCAPE_WIDTH / LANDSCAPE_HEIGHT
-    width, height = image.size
-    in_aspect = width / height
-    if in_aspect > target_aspect:
-        new_width = round(height * target_aspect)
-        left = (width - new_width) // 2
-        image = image.crop((left, 0, left + new_width, height))
-    else:
-        new_height = round(width / target_aspect)
-        top = (height - new_height) // 2
-        image = image.crop((0, top, width, top + new_height))
-    return image.resize((LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT), Image.LANCZOS)
+def calculate_optimal_dimensions(pil_image):
+    """
+    Calculate optimal dimensions for the output video based on input image aspect ratio.
+    Maintains aspect ratio while fitting within the maximum area constraint.
+    """
+    if pil_image is None:
+        return DEFAULT_H, DEFAULT_W
+
+    orig_w, orig_h = pil_image.size
+    if orig_w <= 0 or orig_h <= 0:
+        return DEFAULT_H, DEFAULT_W
+
+    # Calculate aspect ratio
+    aspect_ratio = orig_h / orig_w
+
+    # Calculate dimensions that maintain aspect ratio within max area
+    calc_h = round(np.sqrt(MAX_AREA * aspect_ratio))
+    calc_w = round(np.sqrt(MAX_AREA / aspect_ratio))
+
+    # Ensure dimensions are multiples of MOD_VALUE
+    calc_h = max(MOD_VALUE, (calc_h // MOD_VALUE) * MOD_VALUE)
+    calc_w = max(MOD_VALUE, (calc_w // MOD_VALUE) * MOD_VALUE)
+
+    # Clamp to slider ranges
+    new_h = int(np.clip(calc_h, SLIDER_MIN_H, (SLIDER_MAX_H // MOD_VALUE) * MOD_VALUE))
+    new_w = int(np.clip(calc_w, SLIDER_MIN_W, (SLIDER_MAX_W // MOD_VALUE) * MOD_VALUE))
+
+    return new_h, new_w
+
+
+def handle_image_upload(uploaded_image, current_h, current_w):
+    """
+    Update height and width sliders when an image is uploaded.
+    """
+    if uploaded_image is None:
+        return gr.update(value=DEFAULT_H), gr.update(value=DEFAULT_W)
+
+    try:
+        new_h, new_w = calculate_optimal_dimensions(uploaded_image)
+        return gr.update(value=new_h), gr.update(value=new_w)
+    except Exception as e:
+        gr.Warning("Error calculating dimensions, using defaults")
+        return gr.update(value=DEFAULT_H), gr.update(value=DEFAULT_W)
+

 def get_duration(
     input_image,
@@ -91,6 +120,8 @@ def get_duration(
     steps,
     negative_prompt,
     duration_seconds,
+    height,
+    width,
     guidance_scale,
     guidance_scale_2,
     seed,
@@ -99,24 +130,27 @@
 ):
     return int(steps) * 15

+
 @spaces.GPU(duration=get_duration)
 def generate_video(
     input_image,
     prompt,
-    steps = 4,
+    steps=4,
     negative_prompt=default_negative_prompt,
-    duration_seconds = MAX_DURATION,
-    guidance_scale = 1,
-    guidance_scale_2 = 1,
-    seed = 42,
-    randomize_seed = False,
+    duration_seconds=MAX_DURATION,
+    height=DEFAULT_H,
+    width=DEFAULT_W,
+    guidance_scale=1,
+    guidance_scale_2=1,
+    seed=42,
+    randomize_seed=False,
     progress=gr.Progress(track_tqdm=True),
 ):
     """
     Generate a video from an input image using the Wan 2.2 14B I2V model with Lightning LoRA.

     This function takes an input image and generates a video animation based on the provided
-    prompt and parameters. It uses an FP8 qunatized Wan 2.2 14B Image-to-Video model in with Lightning LoRA
+    prompt and parameters. It uses an FP8 quantized Wan 2.2 14B Image-to-Video model with Lightning LoRA
     for fast generation in 4-8 steps.

     Args:
@@ -127,11 +161,13 @@ def generate_video(
         negative_prompt (str, optional): Negative prompt to avoid unwanted elements.
             Defaults to default_negative_prompt (contains unwanted visual artifacts).
         duration_seconds (float, optional): Duration of the generated video in seconds.
-            Defaults to 2. Clamped between MIN_FRAMES_MODEL/FIXED_FPS and MAX_FRAMES_MODEL/FIXED_FPS.
+            Defaults to MAX_DURATION. Clamped between MIN_DURATION and MAX_DURATION.
+        height (int): Target height for the output video. Will be adjusted to multiple of MOD_VALUE (32).
+        width (int): Target width for the output video. Will be adjusted to multiple of MOD_VALUE (32).
         guidance_scale (float, optional): Controls adherence to the prompt. Higher values = more adherence.
-            Defaults to 1.0. Range: 0.0-20.0.
-        guidance_scale_2 (float, optional): Controls adherence to the prompt. Higher values = more adherence.
-            Defaults to 1.0. Range: 0.0-20.0.
+            Defaults to 1.0. Range: 0.0-10.0.
+        guidance_scale_2 (float, optional): Controls adherence to the prompt in low noise stage.
+            Defaults to 1.0. Range: 0.0-10.0.
         seed (int, optional): Random seed for reproducible results. Defaults to 42.
             Range: 0 to MAX_SEED (2147483647).
         randomize_seed (bool, optional): Whether to use a random seed instead of the provided seed.
@@ -145,27 +181,26 @@ def generate_video(

     Raises:
         gr.Error: If input_image is None (no image uploaded).
-
-    Note:
-    - The function automatically resizes the input image to the target dimensions
-    - Frame count is calculated as duration_seconds * FIXED_FPS (24)
-    - Output dimensions are adjusted to be multiples of MOD_VALUE (32)
-    - The function uses GPU acceleration via the @spaces.GPU decorator
-    - Generation time varies based on steps and duration (see get_duration function)
     """
     if input_image is None:
         raise gr.Error("Please upload an input image.")

+    # Ensure dimensions are multiples of MOD_VALUE
+    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
+    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
+
     num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
-    resized_image = resize_image(input_image)
+
+    # Resize image to target dimensions
+    resized_image = input_image.resize((target_w, target_h), Image.LANCZOS)

     output_frames_list = pipe(
         image=resized_image,
         prompt=prompt,
         negative_prompt=negative_prompt,
-        height=resized_image.height,
-        width=resized_image.width,
+        height=target_h,
+        width=target_w,
         num_frames=num_frames,
         guidance_scale=float(guidance_scale),
         guidance_scale_2=float(guidance_scale_2),
@@ -180,39 +215,83 @@ def generate_video(

     return video_path, current_seed

+
 with gr.Blocks() as demo:
     gr.Markdown("# Fast 4 steps Wan 2.2 I2V (14B) with Lightning LoRA")
-    gr.Markdown("run Wan 2.2 in just 4-8 steps, with [Lightning LoRA](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Wan22-Lightning), fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU⚡️")
+    gr.Markdown("Run Wan 2.2 in just 4-8 steps, with [Lightning LoRA](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Wan22-Lightning), fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU⚡️")
+
     with gr.Row():
         with gr.Column():
             input_image_component = gr.Image(type="pil", label="Input Image (auto-resized to target H/W)")
             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
-            duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
+            duration_seconds_input = gr.Slider(
+                minimum=MIN_DURATION,
+                maximum=MAX_DURATION,
+                step=0.1,
+                value=3.5,
+                label="Duration (seconds)",
+                info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps."
+            )

             with gr.Accordion("Advanced Settings", open=False):
                 negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                 seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                 randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
+
+                with gr.Row():
+                    height_input = gr.Slider(
+                        minimum=SLIDER_MIN_H,
+                        maximum=SLIDER_MAX_H,
+                        step=MOD_VALUE,
+                        value=DEFAULT_H,
+                        label=f"Output Height (multiple of {MOD_VALUE})"
+                    )
+                    width_input = gr.Slider(
+                        minimum=SLIDER_MIN_W,
+                        maximum=SLIDER_MAX_W,
+                        step=MOD_VALUE,
+                        value=DEFAULT_W,
+                        label=f"Output Width (multiple of {MOD_VALUE})"
+                    )
+
                 steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=6, label="Inference Steps")
                 guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage")
                 guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage")

             generate_button = gr.Button("Generate Video", variant="primary")
+
         with gr.Column():
             video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)

+    # Auto-update dimensions when image is uploaded
+    input_image_component.upload(
+        fn=handle_image_upload,
+        inputs=[input_image_component, height_input, width_input],
+        outputs=[height_input, width_input]
+    )
+
+    # Reset dimensions when image is cleared
+    input_image_component.clear(
+        fn=handle_image_upload,
+        inputs=[input_image_component, height_input, width_input],
+        outputs=[height_input, width_input]
+    )
+
     ui_inputs = [
         input_image_component, prompt_input, steps_slider,
         negative_prompt_input, duration_seconds_input,
-        guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox
+        height_input, width_input,
+        guidance_scale_input, guidance_scale_2_input,
+        seed_input, randomize_seed_checkbox
     ]
+
     generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])

     gr.Examples(
         examples=[
             [
                 "wan_i2v_input.JPG",
-                "POV selfie video, white cat with sunglasses standing on surfboard, relaxed smile, tropical beach behind (clear water, green hills, blue sky with clouds). Surfboard tips, cat falls into ocean, camera plunges underwater with bubbles and sunlight beams. Brief underwater view of cats face, then cat resurfaces, still filming selfie, playful summer vacation mood.",
+                "POV selfie video, white cat with sunglasses standing on surfboard, relaxed smile, tropical beach behind (clear water, green hills, blue sky with clouds). Surfboard tips, cat falls into ocean, camera plunges underwater with bubbles and sunlight beams. Brief underwater view of cat's face, then cat resurfaces, still filming selfie, playful summer vacation mood.",
                 4,
             ],
             [
@@ -226,8 +305,11 @@ with gr.Blocks() as demo:
                 6,
             ],
         ],
-        inputs=[input_image_component, prompt_input, steps_slider], outputs=[video_output, seed_input], fn=generate_video, cache_examples="lazy"
+        inputs=[input_image_component, prompt_input, steps_slider],
+        outputs=[video_output, seed_input],
+        fn=generate_video,
+        cache_examples="lazy"
     )

 if __name__ == "__main__":
-    demo.queue().launch(mcp_server=True)
+    demo.queue().launch(mcp_server=True)
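For reference, a minimal standalone sketch of the dynamic sizing math introduced in this commit, with the constants copied from the diff above. optimal_dimensions is a hypothetical stand-in that mirrors calculate_optimal_dimensions but takes a raw width/height pair instead of a PIL image; the printed numbers are only a worked example.

# Standalone sketch of the sizing logic from this commit (constants copied from the diff;
# optimal_dimensions is a hypothetical helper mirroring calculate_optimal_dimensions).
import numpy as np

MOD_VALUE = 32
DEFAULT_H, DEFAULT_W = 480, 832
MAX_AREA = 480.0 * 832.0
SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
FIXED_FPS = 16
MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 8, 81


def optimal_dimensions(orig_w, orig_h):
    """Fit the input aspect ratio inside MAX_AREA, snapped to multiples of MOD_VALUE."""
    if orig_w <= 0 or orig_h <= 0:
        return DEFAULT_H, DEFAULT_W
    aspect_ratio = orig_h / orig_w
    # Dimensions whose product is close to MAX_AREA while keeping the aspect ratio
    calc_h = round(np.sqrt(MAX_AREA * aspect_ratio))
    calc_w = round(np.sqrt(MAX_AREA / aspect_ratio))
    # Snap down to multiples of MOD_VALUE, then clamp to the slider ranges
    calc_h = max(MOD_VALUE, (calc_h // MOD_VALUE) * MOD_VALUE)
    calc_w = max(MOD_VALUE, (calc_w // MOD_VALUE) * MOD_VALUE)
    new_h = int(np.clip(calc_h, SLIDER_MIN_H, (SLIDER_MAX_H // MOD_VALUE) * MOD_VALUE))
    new_w = int(np.clip(calc_w, SLIDER_MIN_W, (SLIDER_MAX_W // MOD_VALUE) * MOD_VALUE))
    return new_h, new_w


if __name__ == "__main__":
    # A 1080x1920 portrait upload keeps its aspect ratio while staying near the 480x832 pixel budget.
    h, w = optimal_dimensions(orig_w=1080, orig_h=1920)
    # Frame count for the default 3.5 s slider value, clamped to the model's limits.
    frames = int(np.clip(round(3.5 * FIXED_FPS), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
    print(h, w, frames)  # 832 448 56

In the committed app.py, generate_video additionally re-snaps whatever the height/width sliders hold to multiples of MOD_VALUE before resizing the input image and calling the pipeline.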