Update handler.py
Browse files- handler.py +19 -0
handler.py
CHANGED
@@ -8,6 +8,7 @@ from diffusers import HunyuanVideoPipeline
|
|
8 |
from varnish import Varnish
|
9 |
|
10 |
from enhance_a_video import enable_enhance, inject_enhance_for_hunyuanvideo, set_enhance_weight
|
|
|
11 |
|
12 |
# Configure logging
|
13 |
logging.basicConfig(level=logging.INFO)
|
@@ -42,6 +43,10 @@ class GenerationConfig:
|
|
42 |
audio_prompt: str = ""
|
43 |
audio_negative_prompt: str = "voices, voice, talking, speaking, speech"
|
44 |
|
|
|
|
|
|
|
|
|
45 |
def validate_and_adjust(self) -> 'GenerationConfig':
|
46 |
"""Validate and adjust parameters"""
|
47 |
# Ensure num_frames follows 4k + 1 format
|
@@ -128,6 +133,8 @@ class EndpointHandler:
|
|
128 |
enable_audio=params.get("enable_audio", False),
|
129 |
audio_prompt=params.get("audio_prompt", ""),
|
130 |
audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
|
|
|
|
|
131 |
).validate_and_adjust()
|
132 |
|
133 |
try:
|
@@ -139,6 +146,16 @@ class EndpointHandler:
|
|
139 |
else:
|
140 |
generator = None
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
# Generate video frames
|
143 |
with torch.inference_mode():
|
144 |
output = self.pipeline(
|
@@ -192,6 +209,8 @@ class EndpointHandler:
|
|
192 |
"fps": result.metadata.fps,
|
193 |
"duration": result.metadata.duration,
|
194 |
"seed": config.seed,
|
|
|
|
|
195 |
}
|
196 |
}
|
197 |
|
|
|
8 |
from varnish import Varnish
|
9 |
|
10 |
from enhance_a_video import enable_enhance, inject_enhance_for_hunyuanvideo, set_enhance_weight
|
11 |
+
from teacache import enable_teacache, disable_teacache
|
12 |
|
13 |
# Configure logging
|
14 |
logging.basicConfig(level=logging.INFO)
|
|
|
43 |
audio_prompt: str = ""
|
44 |
audio_negative_prompt: str = "voices, voice, talking, speaking, speech"
|
45 |
|
46 |
+
# TeaCache settings
|
47 |
+
enable_teacache: bool = True
|
48 |
+
teacache_threshold: float = 0.15
|
49 |
+
|
50 |
def validate_and_adjust(self) -> 'GenerationConfig':
|
51 |
"""Validate and adjust parameters"""
|
52 |
# Ensure num_frames follows 4k + 1 format
|
|
|
133 |
enable_audio=params.get("enable_audio", False),
|
134 |
audio_prompt=params.get("audio_prompt", ""),
|
135 |
audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
|
136 |
+
enable_teacache=params.get("enable_teacache", True),
|
137 |
+
teacache_threshold=params.get("teacache_threshold", 0.15)
|
138 |
).validate_and_adjust()
|
139 |
|
140 |
try:
|
|
|
146 |
else:
|
147 |
generator = None
|
148 |
|
149 |
+
# Configure TeaCache
|
150 |
+
if config.enable_teacache:
|
151 |
+
enable_teacache(
|
152 |
+
self.pipeline.transformer,
|
153 |
+
num_inference_steps=config.num_inference_steps,
|
154 |
+
rel_l1_thresh=config.teacache_threshold
|
155 |
+
)
|
156 |
+
else:
|
157 |
+
disable_teacache(self.pipeline.transformer)
|
158 |
+
|
159 |
# Generate video frames
|
160 |
with torch.inference_mode():
|
161 |
output = self.pipeline(
|
|
|
209 |
"fps": result.metadata.fps,
|
210 |
"duration": result.metadata.duration,
|
211 |
"seed": config.seed,
|
212 |
+
"teacache_enabled": config.enable_teacache,
|
213 |
+
"teacache_threshold": config.teacache_threshold if config.enable_teacache else 0,
|
214 |
}
|
215 |
}
|
216 |
|