Update handler.py
Browse files- handler.py +104 -140
handler.py
CHANGED
@@ -5,12 +5,11 @@ import logging
|
|
5 |
import random
|
6 |
import traceback
|
7 |
import torch
|
8 |
-
from
|
|
|
|
|
9 |
from varnish import Varnish
|
10 |
|
11 |
-
from enhance_a_video import enable_enhance, inject_enhance_for_hunyuanvideo, set_enhance_weight
|
12 |
-
from teacache import enable_teacache, disable_teacache
|
13 |
-
|
14 |
# Configure logging
|
15 |
logging.basicConfig(level=logging.INFO)
|
16 |
logger = logging.getLogger(__name__)
|
@@ -20,14 +19,14 @@ class GenerationConfig:
|
|
20 |
"""Configuration for video generation"""
|
21 |
# Content settings
|
22 |
prompt: str
|
23 |
-
negative_prompt: str = ""
|
24 |
|
25 |
# Model settings
|
26 |
-
num_frames: int =
|
27 |
-
height: int =
|
28 |
-
width: int =
|
29 |
-
num_inference_steps: int =
|
30 |
-
guidance_scale: float =
|
31 |
|
32 |
# Reproducibility
|
33 |
seed: int = -1
|
@@ -44,21 +43,18 @@ class GenerationConfig:
|
|
44 |
audio_prompt: str = ""
|
45 |
audio_negative_prompt: str = "voices, voice, talking, speaking, speech"
|
46 |
|
47 |
-
#
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
55 |
|
56 |
def validate_and_adjust(self) -> 'GenerationConfig':
|
57 |
"""Validate and adjust parameters"""
|
58 |
-
# Ensure num_frames follows 4k + 1 format
|
59 |
-
k = (self.num_frames - 1) // 4
|
60 |
-
self.num_frames = (k * 4) + 1
|
61 |
-
|
62 |
# Set random seed if not specified
|
63 |
if self.seed == -1:
|
64 |
self.seed = random.randint(0, 2**32 - 1)
|
@@ -66,7 +62,7 @@ class GenerationConfig:
|
|
66 |
return self
|
67 |
|
68 |
class EndpointHandler:
|
69 |
-
"""Handles video generation requests using
|
70 |
|
71 |
def __init__(self, path: str = ""):
|
72 |
"""Initialize handler with models
|
@@ -76,32 +72,20 @@ class EndpointHandler:
|
|
76 |
"""
|
77 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
78 |
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
path,
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
)
|
86 |
-
inject_enhance_for_hunyuanvideo(transformer)
|
87 |
-
|
88 |
-
# Initialize HunyuanVideo pipeline with the enhanced transformer
|
89 |
-
self.pipeline = HunyuanVideoPipeline.from_pretrained(
|
90 |
-
path,
|
91 |
-
transformer=transformer,
|
92 |
-
torch_dtype=torch.float16,
|
93 |
-
).to(self.device)
|
94 |
-
|
95 |
-
|
96 |
-
# Initialize text encoders in float16
|
97 |
-
self.pipeline.text_encoder = self.pipeline.text_encoder.half()
|
98 |
-
self.pipeline.text_encoder_2 = self.pipeline.text_encoder_2.half()
|
99 |
-
|
100 |
-
# Initialize transformer in bfloat16
|
101 |
-
self.pipeline.transformer = self.pipeline.transformer.to(torch.bfloat16)
|
102 |
-
|
103 |
-
# Initialize VAE in float16
|
104 |
-
self.pipeline.vae = self.pipeline.vae.half()
|
105 |
|
106 |
# Initialize Varnish for post-processing
|
107 |
self.varnish = Varnish(
|
@@ -136,11 +120,11 @@ class EndpointHandler:
|
|
136 |
config = GenerationConfig(
|
137 |
prompt=prompt,
|
138 |
negative_prompt=params.get("negative_prompt", ""),
|
139 |
-
num_frames=params.get("num_frames",
|
140 |
-
height=params.get("height",
|
141 |
-
width=params.get("width",
|
142 |
-
num_inference_steps=params.get("num_inference_steps",
|
143 |
-
guidance_scale=params.get("guidance_scale",
|
144 |
seed=params.get("seed", -1),
|
145 |
fps=params.get("fps", 30),
|
146 |
double_num_frames=params.get("double_num_frames", False),
|
@@ -150,13 +134,14 @@ class EndpointHandler:
|
|
150 |
enable_audio=params.get("enable_audio", False),
|
151 |
audio_prompt=params.get("audio_prompt", ""),
|
152 |
audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
|
|
160 |
).validate_and_adjust()
|
161 |
|
162 |
try:
|
@@ -164,90 +149,69 @@ class EndpointHandler:
|
|
164 |
if config.seed != -1:
|
165 |
torch.manual_seed(config.seed)
|
166 |
random.seed(config.seed)
|
167 |
-
generator = torch.Generator(device=self.device).manual_seed(config.seed)
|
168 |
-
else:
|
169 |
-
generator = None
|
170 |
-
|
171 |
-
# Configure TeaCache
|
172 |
-
#if config.enable_teacache:
|
173 |
-
# enable_teacache(
|
174 |
-
# self.pipeline.transformer,
|
175 |
-
# num_inference_steps=config.num_inference_steps,
|
176 |
-
# rel_l1_thresh=config.teacache_threshold
|
177 |
-
# )
|
178 |
-
#else:
|
179 |
-
# disable_teacache(self.pipeline.transformer)
|
180 |
-
|
181 |
-
# Configure Enhance-A-Video weight if enabled
|
182 |
-
if config.enable_enhance_a_video:
|
183 |
-
set_enhance_weight(config.enhance_a_video_weight)
|
184 |
-
enable_enhance()
|
185 |
-
else:
|
186 |
-
# Reset enhance weight to 0 to effectively disable it
|
187 |
-
set_enhance_weight(0)
|
188 |
-
|
189 |
-
# Generate video frames
|
190 |
-
with torch.inference_mode():
|
191 |
-
output = self.pipeline(
|
192 |
-
prompt=config.prompt,
|
193 |
-
|
194 |
-
# Failed to generate video: HunyuanVideoPipeline.__call__() got an unexpected keyword argument 'negative_prompt'
|
195 |
-
#negative_prompt=config.negative_prompt,
|
196 |
-
|
197 |
-
num_frames=config.num_frames,
|
198 |
-
height=config.height,
|
199 |
-
width=config.width,
|
200 |
-
num_inference_steps=config.num_inference_steps,
|
201 |
-
guidance_scale=config.guidance_scale,
|
202 |
-
generator=generator,
|
203 |
-
output_type="pt",
|
204 |
-
).frames
|
205 |
-
|
206 |
-
# Process with Varnish
|
207 |
-
import asyncio
|
208 |
-
try:
|
209 |
-
loop = asyncio.get_event_loop()
|
210 |
-
except RuntimeError:
|
211 |
-
loop = asyncio.new_event_loop()
|
212 |
-
asyncio.set_event_loop(loop)
|
213 |
-
|
214 |
-
result = loop.run_until_complete(
|
215 |
-
self.varnish(
|
216 |
-
input_data=output,
|
217 |
-
fps=config.fps,
|
218 |
-
double_num_frames=config.double_num_frames,
|
219 |
-
super_resolution=config.super_resolution,
|
220 |
-
grain_amount=config.grain_amount,
|
221 |
-
enable_audio=config.enable_audio,
|
222 |
-
audio_prompt=config.audio_prompt,
|
223 |
-
audio_negative_prompt=config.audio_negative_prompt,
|
224 |
-
)
|
225 |
-
)
|
226 |
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
)
|
|
|
234 |
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
"
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
}
|
|
|
251 |
|
252 |
except Exception as e:
|
253 |
message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
|
|
|
5 |
import random
|
6 |
import traceback
|
7 |
import torch
|
8 |
+
from skyreelsinfer import TaskType
|
9 |
+
from skyreelsinfer.offload import OffloadConfig
|
10 |
+
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
|
11 |
from varnish import Varnish
|
12 |
|
|
|
|
|
|
|
13 |
# Configure logging
|
14 |
logging.basicConfig(level=logging.INFO)
|
15 |
logger = logging.getLogger(__name__)
|
|
|
19 |
"""Configuration for video generation"""
|
20 |
# Content settings
|
21 |
prompt: str
|
22 |
+
negative_prompt: str = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
|
23 |
|
24 |
# Model settings
|
25 |
+
num_frames: int = 97 # SkyReels default
|
26 |
+
height: int = 544 # SkyReels default
|
27 |
+
width: int = 960 # SkyReels default
|
28 |
+
num_inference_steps: int = 30
|
29 |
+
guidance_scale: float = 6.0
|
30 |
|
31 |
# Reproducibility
|
32 |
seed: int = -1
|
|
|
43 |
audio_prompt: str = ""
|
44 |
audio_negative_prompt: str = "voices, voice, talking, speaking, speech"
|
45 |
|
46 |
+
# Model-specific settings
|
47 |
+
embedded_guidance_scale: float = 1.0
|
48 |
+
quant_model: bool = True
|
49 |
+
gpu_num: int = 1
|
50 |
+
offload: bool = True
|
51 |
+
high_cpu_memory: bool = True
|
52 |
+
parameters_level: bool = False
|
53 |
+
compiler_transformer: bool = False
|
54 |
+
sequence_batch: bool = False
|
55 |
|
56 |
def validate_and_adjust(self) -> 'GenerationConfig':
|
57 |
"""Validate and adjust parameters"""
|
|
|
|
|
|
|
|
|
58 |
# Set random seed if not specified
|
59 |
if self.seed == -1:
|
60 |
self.seed = random.randint(0, 2**32 - 1)
|
|
|
62 |
return self
|
63 |
|
64 |
class EndpointHandler:
|
65 |
+
"""Handles video generation requests using SkyReels and Varnish"""
|
66 |
|
67 |
def __init__(self, path: str = ""):
|
68 |
"""Initialize handler with models
|
|
|
72 |
"""
|
73 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
74 |
|
75 |
+
# Initialize SkyReelsVideoInfer
|
76 |
+
self.predictor = SkyReelsVideoInfer(
|
77 |
+
task_type=TaskType.T2V,
|
78 |
+
model_id=path or "Skywork/SkyReels-V1",
|
79 |
+
quant_model=True, # Enable quantization by default
|
80 |
+
world_size=1, # Single GPU by default
|
81 |
+
is_offload=True, # Enable offloading by default
|
82 |
+
offload_config=OffloadConfig(
|
83 |
+
high_cpu_memory=True,
|
84 |
+
parameters_level=False,
|
85 |
+
compiler_transformer=False,
|
86 |
+
),
|
87 |
+
enable_cfg_parallel=True
|
88 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
# Initialize Varnish for post-processing
|
91 |
self.varnish = Varnish(
|
|
|
120 |
config = GenerationConfig(
|
121 |
prompt=prompt,
|
122 |
negative_prompt=params.get("negative_prompt", ""),
|
123 |
+
num_frames=params.get("num_frames", 97),
|
124 |
+
height=params.get("height", 544),
|
125 |
+
width=params.get("width", 960),
|
126 |
+
num_inference_steps=params.get("num_inference_steps", 30),
|
127 |
+
guidance_scale=params.get("guidance_scale", 6.0),
|
128 |
seed=params.get("seed", -1),
|
129 |
fps=params.get("fps", 30),
|
130 |
double_num_frames=params.get("double_num_frames", False),
|
|
|
134 |
enable_audio=params.get("enable_audio", False),
|
135 |
audio_prompt=params.get("audio_prompt", ""),
|
136 |
audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
|
137 |
+
embedded_guidance_scale=params.get("embedded_guidance_scale", 1.0),
|
138 |
+
quant_model=params.get("quant_model", True),
|
139 |
+
gpu_num=params.get("gpu_num", 1),
|
140 |
+
offload=params.get("offload", True),
|
141 |
+
high_cpu_memory=params.get("high_cpu_memory", True),
|
142 |
+
parameters_level=params.get("parameters_level", False),
|
143 |
+
compiler_transformer=params.get("compiler_transformer", False),
|
144 |
+
sequence_batch=params.get("sequence_batch", False)
|
145 |
).validate_and_adjust()
|
146 |
|
147 |
try:
|
|
|
149 |
if config.seed != -1:
|
150 |
torch.manual_seed(config.seed)
|
151 |
random.seed(config.seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
+
# Prepare generation parameters
|
154 |
+
generation_kwargs = {
|
155 |
+
"prompt": f"FPS-{config.fps}, {config.prompt}", # SkyReels expects FPS in prompt
|
156 |
+
"negative_prompt": config.negative_prompt,
|
157 |
+
"height": config.height,
|
158 |
+
"width": config.width,
|
159 |
+
"num_frames": config.num_frames,
|
160 |
+
"num_inference_steps": config.num_inference_steps,
|
161 |
+
"guidance_scale": config.guidance_scale,
|
162 |
+
"embedded_guidance_scale": config.embedded_guidance_scale,
|
163 |
+
"seed": config.seed,
|
164 |
+
"cfg_for": config.sequence_batch
|
165 |
+
}
|
166 |
+
|
167 |
+
# Generate video frames using SkyReels
|
168 |
+
output = self.predictor.inference(generation_kwargs)
|
169 |
+
|
170 |
+
# Process with Varnish
|
171 |
+
import asyncio
|
172 |
+
try:
|
173 |
+
loop = asyncio.get_event_loop()
|
174 |
+
except RuntimeError:
|
175 |
+
loop = asyncio.new_event_loop()
|
176 |
+
asyncio.set_event_loop(loop)
|
177 |
+
|
178 |
+
result = loop.run_until_complete(
|
179 |
+
self.varnish(
|
180 |
+
input_data=output,
|
181 |
+
fps=config.fps,
|
182 |
+
double_num_frames=config.double_num_frames,
|
183 |
+
super_resolution=config.super_resolution,
|
184 |
+
grain_amount=config.grain_amount,
|
185 |
+
enable_audio=config.enable_audio,
|
186 |
+
audio_prompt=config.audio_prompt,
|
187 |
+
audio_negative_prompt=config.audio_negative_prompt,
|
188 |
)
|
189 |
+
)
|
190 |
|
191 |
+
# Get video data URI
|
192 |
+
video_uri = loop.run_until_complete(
|
193 |
+
result.write(
|
194 |
+
type="data-uri",
|
195 |
+
quality=config.quality
|
196 |
+
)
|
197 |
+
)
|
198 |
+
|
199 |
+
return {
|
200 |
+
"video": video_uri,
|
201 |
+
"content-type": "video/mp4",
|
202 |
+
"metadata": {
|
203 |
+
"width": result.metadata.width,
|
204 |
+
"height": result.metadata.height,
|
205 |
+
"num_frames": result.metadata.frame_count,
|
206 |
+
"fps": result.metadata.fps,
|
207 |
+
"duration": result.metadata.duration,
|
208 |
+
"seed": config.seed,
|
209 |
+
"gpu_num": config.gpu_num,
|
210 |
+
"quant_model": config.quant_model,
|
211 |
+
"guidance_scale": config.guidance_scale,
|
212 |
+
"embedded_guidance_scale": config.embedded_guidance_scale
|
213 |
}
|
214 |
+
}
|
215 |
|
216 |
except Exception as e:
|
217 |
message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
|