Spaces:
Runtime error
Runtime error
Optimize video generation by adding torch.no_grad() context to reduce memory usage
Browse files
app.py
CHANGED
|
@@ -198,23 +198,24 @@ def generate_video_from_text(
|
|
| 198 |
def gradio_progress_callback(self, step, timestep, kwargs):
|
| 199 |
progress((step + 1) / num_inference_steps)
|
| 200 |
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
|
|
|
| 218 |
|
| 219 |
output_path = tempfile.mktemp(suffix=".mp4")
|
| 220 |
print(images.shape)
|
|
@@ -268,23 +269,24 @@ def generate_video_from_image(
|
|
| 268 |
def gradio_progress_callback(self, step, timestep, kwargs):
|
| 269 |
progress((step + 1) / num_inference_steps)
|
| 270 |
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
|
|
|
| 288 |
|
| 289 |
output_path = tempfile.mktemp(suffix=".mp4")
|
| 290 |
video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
|
|
|
|
| 198 |
def gradio_progress_callback(self, step, timestep, kwargs):
|
| 199 |
progress((step + 1) / num_inference_steps)
|
| 200 |
|
| 201 |
+
with torch.no_grad():
|
| 202 |
+
images = pipeline(
|
| 203 |
+
num_inference_steps=num_inference_steps,
|
| 204 |
+
num_images_per_prompt=1,
|
| 205 |
+
guidance_scale=guidance_scale,
|
| 206 |
+
generator=generator,
|
| 207 |
+
output_type="pt",
|
| 208 |
+
height=height,
|
| 209 |
+
width=width,
|
| 210 |
+
num_frames=num_frames,
|
| 211 |
+
frame_rate=frame_rate,
|
| 212 |
+
**sample,
|
| 213 |
+
is_video=True,
|
| 214 |
+
vae_per_channel_normalize=True,
|
| 215 |
+
conditioning_method=ConditioningMethod.FIRST_FRAME,
|
| 216 |
+
mixed_precision=True,
|
| 217 |
+
callback_on_step_end=gradio_progress_callback,
|
| 218 |
+
).images
|
| 219 |
|
| 220 |
output_path = tempfile.mktemp(suffix=".mp4")
|
| 221 |
print(images.shape)
|
|
|
|
| 269 |
def gradio_progress_callback(self, step, timestep, kwargs):
|
| 270 |
progress((step + 1) / num_inference_steps)
|
| 271 |
|
| 272 |
+
with torch.no_grad():
|
| 273 |
+
images = pipeline(
|
| 274 |
+
num_inference_steps=num_inference_steps,
|
| 275 |
+
num_images_per_prompt=1,
|
| 276 |
+
guidance_scale=guidance_scale,
|
| 277 |
+
generator=generator,
|
| 278 |
+
output_type="pt",
|
| 279 |
+
height=height,
|
| 280 |
+
width=width,
|
| 281 |
+
num_frames=num_frames,
|
| 282 |
+
frame_rate=frame_rate,
|
| 283 |
+
**sample,
|
| 284 |
+
is_video=True,
|
| 285 |
+
vae_per_channel_normalize=True,
|
| 286 |
+
conditioning_method=ConditioningMethod.FIRST_FRAME,
|
| 287 |
+
mixed_precision=True,
|
| 288 |
+
callback_on_step_end=gradio_progress_callback,
|
| 289 |
+
).images
|
| 290 |
|
| 291 |
output_path = tempfile.mktemp(suffix=".mp4")
|
| 292 |
video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
|