jbilcke-hf HF staff committed on
Commit 4585b1e · verified · 1 Parent(s): a2c4d6f

Update handler.py

Files changed (1)
  1. handler.py +16 -3
handler.py CHANGED
@@ -13,11 +13,9 @@ import os
 import numpy as np
 import torch
 from diffusers import LTXPipeline, LTXImageToVideoPipeline
+from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
 from PIL import Image
 
-from enhance_a_video import enable_enhance, set_enhance_weight
-from enhance import inject_enhance_for_ltx
-
 from teacache import TeaCacheConfig, enable_teacache, disable_teacache
 from varnish import Varnish
 from varnish.utils import is_truthy, process_input_image
@@ -150,6 +148,7 @@ class EndpointHandler:
                 model_path,
                 torch_dtype=torch.bfloat16
             ).to("cuda")
+
         else:
             # Initialize models with bfloat16 precision
             self.text_to_video = LTXPipeline.from_pretrained(
@@ -447,9 +446,23 @@ class EndpointHandler:
                 config.input_image_quality,
             )
             generation_kwargs["image"] = processed_image
+
+            apply_enhance_a_video(self.image_to_video.transformer, EnhanceAVideoConfig(
+                weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
+                num_frames_callback=lambda: (config.num_frames - 1),
+                _attention_type=1
+            ))
+
             frames = self.image_to_video(**generation_kwargs).frames
         else:
             self._configure_teacache(self.text_to_video, config)
+
+            apply_enhance_a_video(self.text_to_video.transformer, EnhanceAVideoConfig(
+                weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
+                num_frames_callback=lambda: (config.num_frames - 1),
+                _attention_type=1
+            ))
+
             frames = self.text_to_video(**generation_kwargs).frames
 
         try:
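
The commit drops the standalone enhance_a_video helpers and switches to the hook-based API imported from diffusers.hooks. For context, here is a minimal standalone sketch of that call pattern, not part of the commit: it assumes a diffusers build that actually exposes apply_enhance_a_video and EnhanceAVideoConfig (as imported in the diff), and the model id, weight, and frame count are illustrative placeholders rather than values from handler.py.

# Minimal sketch: enabling Enhance-A-Video on an LTX pipeline via the
# diffusers.hooks API used in the diff above. Assumes a diffusers version
# that provides apply_enhance_a_video / EnhanceAVideoConfig.
import torch
from diffusers import LTXPipeline
from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig

pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video", torch_dtype=torch.bfloat16
).to("cuda")

num_frames = 121  # illustrative; handler.py takes this from its request config

# Register the Enhance-A-Video hook on the transformer, mirroring the diff:
# weight sets the enhancement strength (example value), and the callback
# reports the frame count using the handler's (num_frames - 1) convention.
apply_enhance_a_video(
    pipe.transformer,
    EnhanceAVideoConfig(
        weight=4.0,  # assumed example strength, not taken from handler.py
        num_frames_callback=lambda: num_frames - 1,
        _attention_type=1,
    ),
)

video = pipe(
    prompt="A sailboat gliding across a calm lake",
    num_frames=num_frames,
).frames[0]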