jbilcke-hf committed
Commit 609f5cd · verified · 1 Parent(s): 3715d88

Update handler.py

Files changed (1): handler.py (+15 -10)
handler.py CHANGED

@@ -9,10 +9,10 @@ import torch
 
 # note: there is no HunyuanImageToVideoPipeline yet in Diffusers
 from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
 from varnish import Varnish
 from varnish.utils import is_truthy, process_input_image
 
-from enhance_a_video import enable_enhance, inject_enhance_for_hunyuanvideo, set_enhance_weight
 from teacache import enable_teacache, disable_teacache
 
 # Configure logging
@@ -258,13 +258,6 @@ class EndpointHandler:
         # disable_teacache(self.pipeline.transformer)
 
         with torch.inference_mode():
-            # Configure Enhance-A-Video weight if enabled
-            if config.enable_enhance_a_video:
-                set_enhance_weight(config.enhance_a_video_weight)
-                enable_enhance()
-            else:
-                # Reset enhance weight to 0 to effectively disable it
-                set_enhance_weight(0)
 
             # Prepare generation parameters
             generation_kwargs = {
@@ -314,8 +307,6 @@
             if config.lora_model_trigger:
                 generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"
 
-
-
             # Check if image-to-video generation is requested
             if support_image_prompt and input_image:
                 self._configure_teacache(self.image_to_video, config)
@@ -326,9 +317,23 @@
                     config.input_image_quality,
                 )
                 generation_kwargs["image"] = processed_image
+
+                apply_enhance_a_video(self.image_to_video.transformer, EnhanceAVideoConfig(
+                    weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
+                    num_frames_callback=lambda: (config.num_frames - 1),
+                    _attention_type=1
+                ))
+
                 frames = self.image_to_video(**generation_kwargs).frames
             else:
                 self._configure_teacache(self.text_to_video, config)
+
+                apply_enhance_a_video(self.text_to_video.transformer, EnhanceAVideoConfig(
+                    weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
+                    num_frames_callback=lambda: (config.num_frames - 1),
+                    _attention_type=1
+                ))
+
                 frames = self.text_to_video(**generation_kwargs).frames
 
 
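Note for reviewers: this change swaps the standalone enhance_a_video package for the Enhance-A-Video hook that ships with recent Diffusers releases, so the per-request toggle becomes a plain config value (weight 0.0 effectively disables the effect, as the removed comment said) instead of module-level globals. Below is a minimal standalone sketch of the same pattern outside the handler. The checkpoint id, dtypes, weight of 4.0, and frame count are illustrative assumptions, not values taken from handler.py.

# Minimal standalone sketch of the diffusers-native Enhance-A-Video hook.
# Assumptions (not from handler.py): the "hunyuanvideo-community/HunyuanVideo"
# checkpoint, bfloat16 transformer weights, weight=4.0, 49 output frames.
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo"  # assumed checkpoint
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.float16
).to("cuda")

num_frames = 49  # assumed; HunyuanVideo expects 4k+1 frames

# Register the hook on the transformer before calling the pipeline.
# Passing weight=0.0 keeps the hook in place but makes it a no-op,
# which is how the handler flips the feature per request.
apply_enhance_a_video(
    pipe.transformer,
    EnhanceAVideoConfig(
        weight=4.0,  # assumed strength; the handler reads it from its config
        num_frames_callback=lambda: num_frames - 1,
        _attention_type=1,
    ),
)

with torch.inference_mode():
    frames = pipe(prompt="a cat walks on the grass", num_frames=num_frames).frames[0]
export_to_video(frames, "output.mp4", fps=15)

As in the handler, the hook is applied to the pipeline's transformer rather than the pipeline itself, and steering it through the weight alone means the toggle never has to re-patch attention between requests.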