multimodalart (HF Staff) committed
Commit f4cf641 · verified · Parent: 6163c8e

Update app.py

Files changed (1): app.py (+171 -59)

app.py CHANGED
@@ -1,17 +1,18 @@
 import torch
-from diffusers import AutoencoderKLWan, WanPipeline, UniPCMultistepScheduler
-from diffusers.utils import export_to_video
-from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers # Keep this if it's the base for standard LoRA parts
+from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
+from diffusers.utils import export_to_video, load_image
+from transformers import CLIPVisionModel
 import gradio as gr
 import tempfile
 import os
-import spaces
+import spaces # Assuming this is for Hugging Face Spaces GPU decorator
 from huggingface_hub import hf_hub_download
-import logging # For better logging
-import re # For key manipulation
+import logging
+import numpy as np
+from PIL import Image # Added for type hinting

 # --- Global Model Loading & LoRA Handling ---
-MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
+MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
 LORA_REPO_ID = "Kijai/WanVideo_comfy"
 LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"

@@ -20,17 +21,25 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

 # --- Model Loading ---
+logger.info(f"Loading Image Encoder for {MODEL_ID}...")
+image_encoder = CLIPVisionModel.from_pretrained(
+    MODEL_ID,
+    subfolder="image_encoder",
+    torch_dtype=torch.float32
+)
+
 logger.info(f"Loading VAE for {MODEL_ID}...")
 vae = AutoencoderKLWan.from_pretrained(
     MODEL_ID,
     subfolder="vae",
-    torch_dtype=torch.float32 # float32 for VAE stability
+    torch_dtype=torch.float32
 )
 logger.info(f"Loading Pipeline {MODEL_ID}...")
-pipe = WanPipeline.from_pretrained(
+pipe = WanImageToVideoPipeline.from_pretrained(
     MODEL_ID,
     vae=vae,
-    torch_dtype=torch.bfloat16 # bfloat16 for pipeline
+    image_encoder=image_encoder,
+    torch_dtype=torch.bfloat16
 )
 flow_shift = 8.0
 pipe.scheduler = UniPCMultistepScheduler.from_config(
@@ -43,99 +52,202 @@ pipe.to("cuda")
 logger.info(f"Downloading LoRA {LORA_FILENAME} from {LORA_REPO_ID}...")
 causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)

-logger.info("Loading LoRA weights with custom converter...")
-pipe.load_lora_weights(causvid_path,adapter_name="causvid_lora")
+logger.info("Loading LoRA weights...")
+pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
+logger.info("Setting LoRA adapter...")
+pipe.set_adapters(["causvid_lora"], adapter_weights=[1.0])
+
+# MOD_VALUE for height/width constraints
+# From WanImageToVideoPipeline docs: height/width must be multiple of vae_scale_factor * transformer.config.patch_size[1 or 2]
+MOD_VALUE = pipe.vae_scale_factor * pipe.transformer.config.patch_size[1] # e.g., 8 * 16 = 128
+logger.info(f"Derived MOD_VALUE for dimensions: {MOD_VALUE}")
+
+
+# --- Helper functions and constants for automatic dimension adjustment ---
+# These constants must match the Gradio slider definitions below
+DEFAULT_H_SLIDER_VALUE = 384
+DEFAULT_W_SLIDER_VALUE = 640
+DEFAULT_TARGET_AREA = float(DEFAULT_H_SLIDER_VALUE * DEFAULT_W_SLIDER_VALUE)
+
+SLIDER_MIN_H = 128
+SLIDER_MAX_H = 512
+SLIDER_MIN_W = 128
+SLIDER_MAX_W = 1024
+
+def _calculate_new_dimensions_wan(pil_image: Image.Image, mod_val: int, target_area: float,
+                                  min_h: int, max_h: int, min_w: int, max_w: int,
+                                  default_h: int, default_w: int) -> tuple[int, int]:
+    orig_w, orig_h = pil_image.size
+
+    if orig_w == 0 or orig_h == 0:
+        logger.warning("Uploaded image has zero width or height. Using default slider dimensions.")
+        return default_h, default_w
+
+    aspect_ratio = orig_h / orig_w
+
+    # Calculate ideal dimensions for the target area, maintaining aspect ratio
+    ideal_h = np.sqrt(target_area * aspect_ratio)
+    ideal_w = np.sqrt(target_area / aspect_ratio)
+
+    # Round to nearest multiple of mod_val
+    calc_h = round(ideal_h / mod_val) * mod_val
+    calc_w = round(ideal_w / mod_val) * mod_val
+
+    # Ensure dimensions are at least mod_val (smallest valid multiple)
+    calc_h = mod_val if calc_h == 0 else calc_h
+    calc_w = mod_val if calc_w == 0 else calc_w
+
+    # Clamp to slider limits
+    new_h = int(np.clip(calc_h, min_h, max_h))
+    new_w = int(np.clip(calc_w, min_w, max_w))
+
+    logger.info(f"Auto-dim: Original {orig_w}x{orig_h} (AR: {aspect_ratio:.2f}). Target Area: {target_area}.")
+    logger.info(f"Auto-dim: Ideal HxW: {ideal_h:.0f}x{ideal_w:.0f}. Rounded (step {mod_val}): {calc_h}x{calc_w}.")
+    logger.info(f"Auto-dim: Clamped HxW: {new_h}x{new_w} (H_range:[{min_h}-{max_h}], W_range:[{min_w}-{max_w}]).")
+
+    return new_h, new_w
+
+def handle_image_upload_for_dims_wan(uploaded_pil_image: Image.Image | None, current_h_val: int, current_w_val: int):
+    if uploaded_pil_image is None: # Image cleared by user
+        logger.info("Image cleared. Resetting dimensions to default slider values.")
+        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
+
+    try:
+        new_h, new_w = _calculate_new_dimensions_wan(
+            uploaded_pil_image,
+            MOD_VALUE,
+            DEFAULT_TARGET_AREA,
+            SLIDER_MIN_H, SLIDER_MAX_H,
+            SLIDER_MIN_W, SLIDER_MAX_W,
+            DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
+        )
+        return gr.update(value=new_h), gr.update(value=new_w)
+    except Exception as e:
+        logger.error(f"Error auto-adjusting H/W from image: {e}", exc_info=True)
+        # On error, revert to defaults or keep current. Defaults are safer.
+        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)


 # --- Gradio Interface Function ---
-@spaces.GPU
-def generate_video(prompt, negative_prompt, height, width, num_frames, guidance_scale, steps, fps, progress=gr.Progress(track_tqdm=True)):
+@spaces.GPU # type: ignore
+def generate_video(input_image: Image.Image, prompt: str, negative_prompt: str,
+                   height: int, width: int, num_frames: int,
+                   guidance_scale: float, steps: int, fps_for_conditioning_and_export: int,
+                   progress=gr.Progress(track_tqdm=True)):
+    if input_image is None:
+        raise gr.Error("Please upload an input image.")
+
     logger.info("Starting video generation...")
+    logger.info(f" Input Image: Uploaded (Original size: {input_image.size if input_image else 'N/A'})")
     logger.info(f" Prompt: {prompt}")
     logger.info(f" Negative Prompt: {negative_prompt if negative_prompt else 'None'}")
-    logger.info(f" Height: {height}, Width: {width}")
-    logger.info(f" Num Frames: {num_frames}, FPS: {fps}")
-    logger.info(f" Guidance Scale: {guidance_scale}")
+    logger.info(f" Target Output Height: {height}, Target Output Width: {width}")
+    logger.info(f" Num Frames: {num_frames}, FPS for conditioning & export: {fps_for_conditioning_and_export}")
+    logger.info(f" Guidance Scale: {guidance_scale}, Steps: {steps}")
+

-    height = (int(height) // 8) * 8
-    width = (int(width) // 8) * 8
+    target_height = int(height)
+    target_width = int(width)
     num_frames = int(num_frames)
-    fps = int(fps)
+    fps_val = int(fps_for_conditioning_and_export)
+    guidance_scale_val = float(guidance_scale)
+    steps_val = int(steps)
+
+    # Resize the input PIL image to the target dimensions for the pipeline
+    resized_image = input_image.resize((target_width, target_height))
+    logger.info(f" Input image resized to: {resized_image.size} for pipeline input.")
+

     with torch.inference_mode():
         output_frames_list = pipe(
+            image=resized_image,
             prompt=prompt,
             negative_prompt=negative_prompt,
-            height=height,
-            width=width,
+            height=target_height,
+            width=target_width,
             num_frames=num_frames,
-            guidance_scale=float(guidance_scale),
-            num_inference_steps=steps
+            guidance_scale=guidance_scale_val,
+            num_inference_steps=steps_val,
+            fps=fps_val, # For conditioning
+            generator=torch.Generator(device="cuda").manual_seed(0) # For reproducibility
         ).frames[0]

     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
         video_path = tmpfile.name

-    export_to_video(output_frames_list, video_path, fps=fps)
+    export_to_video(output_frames_list, video_path, fps=fps_val) # For export
     logger.info(f"Video successfully generated and saved to {video_path}")
     return video_path

 # --- Gradio UI Definition ---
-default_prompt = "A cat walks on the grass, realistic"
-default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
+default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
+penguin_image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
+

 with gr.Blocks() as demo:
     gr.Markdown(f"""
-    # Text-to-Video with Wan 2.1 (14B) + CausVid LoRA
-    Powered by `diffusers` and `Wan-AI/{MODEL_ID}`.
+    # Image-to-Video with Wan 2.1 I2V (14B) + CausVid LoRA
+    Powered by `diffusers` and `{MODEL_ID}`.
     Model is loaded into memory when the app starts. This might take a few minutes.
     Ensure you have a GPU with sufficient VRAM (e.g., ~24GB+ for these default settings).
+    Output Height and Width must be multiples of **{MOD_VALUE}**. Uploading an image will suggest dimensions based on its aspect ratio and a target area.
     """)
     with gr.Row():
         with gr.Column(scale=2):
-            prompt_input = gr.Textbox(label="Prompt", value=default_prompt, lines=3)
-            negative_prompt_input = gr.Textbox(
-                label="Negative Prompt (Optional)",
-                value=default_negative_prompt,
-                lines=3
-            )
-            with gr.Row():
-                height_input = gr.Slider(minimum=256, maximum=768, step=64, value=480, label="Height (multiple of 8)")
-                width_input = gr.Slider(minimum=256, maximum=1024, step=64, value=832, label="Width (multiple of 8)")
-            with gr.Row():
-                num_frames_input = gr.Slider(minimum=16, maximum=100, step=1, value=25, label="Number of Frames")
-                fps_input = gr.Slider(minimum=5, maximum=30, step=1, value=15, label="Output FPS")
-            steps = gr.Slider(minimum=1.0, maximum=30.0, value=4.0, label="Steps")
-            guidance_scale_input = gr.Slider(minimum=1.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale")
+            input_image_component = gr.Image(type="pil", label="Input Image (will be resized to target H/W)")
+            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v, lines=3)
+
+            with gr.Accordion("Advanced Settings", open=False):
+                negative_prompt_input = gr.Textbox(
+                    label="Negative Prompt (Optional)",
+                    value=default_negative_prompt,
+                    lines=3
+                )
+                with gr.Row():
+                    height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
+                    width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
+                with gr.Row():
+                    num_frames_input = gr.Slider(minimum=8, maximum=81, step=1, value=25, label="Number of Frames")
+                    fps_input = gr.Slider(minimum=5, maximum=30, step=1, value=16, label="FPS (for conditioning & export)")
+                steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps")
+                guidance_scale_input = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale")

             generate_button = gr.Button("Generate Video", variant="primary")

         with gr.Column(scale=3):
-            video_output = gr.Video(label="Generated Video")
+            video_output = gr.Video(label="Generated Video", interactive=False)
+
+    # Event handler for image upload/clear to adjust H/W sliders
+    input_image_component.change(
+        fn=handle_image_upload_for_dims_wan,
+        inputs=[input_image_component, height_input, width_input], # Pass current slider values
+        outputs=[height_input, width_input]
+    )
+
+    inputs_for_click_and_examples = [
+        input_image_component,
+        prompt_input,
+        negative_prompt_input,
+        height_input,
+        width_input,
+        num_frames_input,
+        guidance_scale_input,
+        steps_slider,
+        fps_input
+    ]

     generate_button.click(
         fn=generate_video,
-        inputs=[
-            prompt_input,
-            negative_prompt_input,
-            height_input,
-            width_input,
-            num_frames_input,
-            guidance_scale_input,
-            steps,
-            fps_input
-        ],
+        inputs=inputs_for_click_and_examples,
         outputs=video_output
     )

     gr.Examples(
         examples=[
-            ["A panda eating bamboo in a lush forest, cinematic lighting", default_negative_prompt, 480, 832, 25, 5.0, 4, 15],
-            ["A majestic eagle soaring over snowy mountains", default_negative_prompt, 512, 768, 30, 7.0, 4, 12],
-            ["Timelapse of a flower blooming, vibrant colors", "static, ugly", 384, 640, 40, 6.0, 4, 20],
-            ["Astronaut walking on the moon, Earth in the background, highly detailed", default_negative_prompt, 480, 832, 20, 5.5, 4, 10],
+            [penguin_image_url, "a penguin playfully dancing in the snow, Antarctica", default_negative_prompt, DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE, 25, 1.0, 4, 16]
         ],
-        inputs=[prompt_input, negative_prompt_input, height_input, width_input, num_frames_input, guidance_scale_input, steps, fps_input],
+        inputs=inputs_for_click_and_examples,
         outputs=video_output,
         fn=generate_video,
         cache_examples=False
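
For reference, the image-to-video path this commit wires into the Gradio app can be exercised as a plain script. The following is a minimal sketch, not the committed code: it reuses the model ID, CausVid LoRA, and the 4-step / guidance-1.0 settings from the new Examples entry, but the scheduler arguments are assumed to be `pipe.scheduler.config` plus `flow_shift` (the hunk above cuts off after `from_config(`), the negative prompt is shortened, the pipeline's `fps` and `generator` arguments are omitted, and `output.mp4` is a placeholder path.

# Minimal standalone sketch of the I2V flow introduced by this commit (assumptions noted above).
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
from huggingface_hub import hf_hub_download

MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"

# Same component/dtype split as the updated app.py: float32 image encoder and VAE, bfloat16 pipeline.
image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16)
# Assumed scheduler call; the diff only shows `flow_shift = 8.0` and the opening of `from_config(`.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
pipe.to("cuda")

# The CausVid LoRA is what makes the few-step, guidance-1.0 defaults usable.
lora_path = hf_hub_download(repo_id="Kijai/WanVideo_comfy", filename="Wan21_CausVid_14B_T2V_lora_rank32.safetensors")
pipe.load_lora_weights(lora_path, adapter_name="causvid_lora")
pipe.set_adapters(["causvid_lora"], adapter_weights=[1.0])

# Conditioning image resized to the app's default 640x384 output size (width, height).
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png")
image = image.resize((640, 384))

frames = pipe(
    image=image,
    prompt="a penguin playfully dancing in the snow, Antarctica",
    negative_prompt="worst quality, low quality, static, blurred details",
    height=384,
    width=640,
    num_frames=25,
    guidance_scale=1.0,
    num_inference_steps=4,
).frames[0]
export_to_video(frames, "output.mp4", fps=16)  # placeholder output path

In the app itself, height and width are not fixed at these defaults: handle_image_upload_for_dims_wan snaps the uploaded image's aspect ratio to the nearest MOD_VALUE multiple near the 384x640 target area and clamps the result to the slider ranges.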