jbilcke-hf (HF staff) committed
Commit 8b4a69c · verified · 1 parent: fb0b962

Update handler.py

Files changed (1)
handler.py: +19 −0
handler.py CHANGED
@@ -8,6 +8,7 @@ from diffusers import HunyuanVideoPipeline
 from varnish import Varnish
 
 from enhance_a_video import enable_enhance, inject_enhance_for_hunyuanvideo, set_enhance_weight
+from teacache import enable_teacache, disable_teacache
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -42,6 +43,10 @@ class GenerationConfig:
     audio_prompt: str = ""
     audio_negative_prompt: str = "voices, voice, talking, speaking, speech"
 
+    # TeaCache settings
+    enable_teacache: bool = True
+    teacache_threshold: float = 0.15
+
     def validate_and_adjust(self) -> 'GenerationConfig':
         """Validate and adjust parameters"""
         # Ensure num_frames follows 4k + 1 format
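As an aside, the "4k + 1" rule in the context line above is untouched by this commit; a hypothetical helper enforcing it could look like the sketch below, purely for illustration (the real adjustment lives in validate_and_adjust, which this diff does not show).

def snap_to_4k_plus_1(num_frames: int) -> int:
    # Hypothetical illustration of the "4k + 1" frame-count rule; not the
    # handler's actual code.
    k = max((num_frames - 1) // 4, 1)
    return 4 * k + 1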
 
@@ -128,6 +133,8 @@ class EndpointHandler:
             enable_audio=params.get("enable_audio", False),
             audio_prompt=params.get("audio_prompt", ""),
             audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
+            enable_teacache=params.get("enable_teacache", True),
+            teacache_threshold=params.get("teacache_threshold", 0.15)
         ).validate_and_adjust()
 
         try:
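For reference, a request exercising the new parameters might carry a payload like the sketch below; only enable_teacache and teacache_threshold come from this diff, the remaining keys are assumptions about the handler's existing interface.

params = {
    "prompt": "a red panda walking through bamboo",  # assumed existing field, not part of this diff
    "enable_audio": False,                           # existing field, shown in the context above
    "enable_teacache": True,                         # new: toggles TeaCache for this request
    "teacache_threshold": 0.15,                      # new: forwarded as rel_l1_thresh
}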
 
@@ -139,6 +146,16 @@ class EndpointHandler:
             else:
                 generator = None
 
+            # Configure TeaCache
+            if config.enable_teacache:
+                enable_teacache(
+                    self.pipeline.transformer,
+                    num_inference_steps=config.num_inference_steps,
+                    rel_l1_thresh=config.teacache_threshold
+                )
+            else:
+                disable_teacache(self.pipeline.transformer)
+
             # Generate video frames
             with torch.inference_mode():
                 output = self.pipeline(
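What rel_l1_thresh controls: TeaCache skips transformer steps whose input has barely changed since the last computed step and reuses the cached output instead. The sketch below only illustrates that decision rule under the usual TeaCache heuristic; it is not the teacache module's actual implementation, and the function name is made up.

import torch

def should_recompute(prev_inp: torch.Tensor, curr_inp: torch.Tensor,
                     accumulated: float, rel_l1_thresh: float):
    # Accumulate the relative L1 change of the transformer input between
    # consecutive denoising steps.
    rel_l1 = ((curr_inp - prev_inp).abs().mean() / prev_inp.abs().mean()).item()
    accumulated += rel_l1
    if accumulated < rel_l1_thresh:
        return False, accumulated  # change is small: reuse the cached output
    return True, 0.0               # change is large: recompute and reset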
 
@@ -192,6 +209,8 @@ class EndpointHandler:
                         "fps": result.metadata.fps,
                         "duration": result.metadata.duration,
                         "seed": config.seed,
+                        "teacache_enabled": config.enable_teacache,
+                        "teacache_threshold": config.teacache_threshold if config.enable_teacache else 0,
                     }
                 }
 
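Putting it together, a hypothetical caller could compare runs with and without TeaCache; the handler's constructor arguments and exact request/response shape are not shown in this commit, so treat them as assumptions.

handler = EndpointHandler()  # constructor arguments, if any, are not shown in this diff

request = {
    "prompt": "a fox running through a snowy forest",  # assumed existing field
    "teacache_threshold": 0.15,
}
with_cache = handler({**request, "enable_teacache": True})
without_cache = handler({**request, "enable_teacache": False})
# The returned metadata now reports teacache_enabled and teacache_threshold,
# so the two responses can be told apart when benchmarking.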