jbilcke-hf HF staff committed on
Commit
d741ed4
·
verified ·
1 Parent(s): e238883

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +104 -140
handler.py CHANGED
@@ -5,12 +5,11 @@ import logging
5
  import random
6
  import traceback
7
  import torch
8
- from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
 
 
9
  from varnish import Varnish
10
 
11
- from enhance_a_video import enable_enhance, inject_enhance_for_hunyuanvideo, set_enhance_weight
12
- from teacache import enable_teacache, disable_teacache
13
-
14
  # Configure logging
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
@@ -20,14 +19,14 @@ class GenerationConfig:
20
  """Configuration for video generation"""
21
  # Content settings
22
  prompt: str
23
- negative_prompt: str = ""
24
 
25
  # Model settings
26
- num_frames: int = 49 # Should be 4k + 1 format
27
- height: int = 320
28
- width: int = 576
29
- num_inference_steps: int = 50
30
- guidance_scale: float = 7.0
31
 
32
  # Reproducibility
33
  seed: int = -1
@@ -44,21 +43,18 @@ class GenerationConfig:
44
  audio_prompt: str = ""
45
  audio_negative_prompt: str = "voices, voice, talking, speaking, speech"
46
 
47
- # TeaCache settings
48
- enable_teacache: bool = True
49
- teacache_threshold: float = 0.15 # values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup)
50
-
51
-
52
- # Enhance-A-Video settings
53
- enable_enhance_a_video: bool = True
54
- enhance_a_video_weight: float = 4.0
 
55
 
56
  def validate_and_adjust(self) -> 'GenerationConfig':
57
  """Validate and adjust parameters"""
58
- # Ensure num_frames follows 4k + 1 format
59
- k = (self.num_frames - 1) // 4
60
- self.num_frames = (k * 4) + 1
61
-
62
  # Set random seed if not specified
63
  if self.seed == -1:
64
  self.seed = random.randint(0, 2**32 - 1)
@@ -66,7 +62,7 @@ class GenerationConfig:
66
  return self
67
 
68
  class EndpointHandler:
69
- """Handles video generation requests using HunyuanVideo and Varnish"""
70
 
71
  def __init__(self, path: str = ""):
72
  """Initialize handler with models
@@ -76,32 +72,20 @@ class EndpointHandler:
76
  """
77
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
78
 
79
-
80
- # Initialize transformer with Enhance-A-Video injection first
81
- transformer = HunyuanVideoTransformer3DModel.from_pretrained(
82
- path,
83
- subfolder="transformer",
84
- torch_dtype=torch.bfloat16
 
 
 
 
 
 
 
85
  )
86
- inject_enhance_for_hunyuanvideo(transformer)
87
-
88
- # Initialize HunyuanVideo pipeline with the enhanced transformer
89
- self.pipeline = HunyuanVideoPipeline.from_pretrained(
90
- path,
91
- transformer=transformer,
92
- torch_dtype=torch.float16,
93
- ).to(self.device)
94
-
95
-
96
- # Initialize text encoders in float16
97
- self.pipeline.text_encoder = self.pipeline.text_encoder.half()
98
- self.pipeline.text_encoder_2 = self.pipeline.text_encoder_2.half()
99
-
100
- # Initialize transformer in bfloat16
101
- self.pipeline.transformer = self.pipeline.transformer.to(torch.bfloat16)
102
-
103
- # Initialize VAE in float16
104
- self.pipeline.vae = self.pipeline.vae.half()
105
 
106
  # Initialize Varnish for post-processing
107
  self.varnish = Varnish(
@@ -136,11 +120,11 @@ class EndpointHandler:
136
  config = GenerationConfig(
137
  prompt=prompt,
138
  negative_prompt=params.get("negative_prompt", ""),
139
- num_frames=params.get("num_frames", 49),
140
- height=params.get("height", 320),
141
- width=params.get("width", 576),
142
- num_inference_steps=params.get("num_inference_steps", 50),
143
- guidance_scale=params.get("guidance_scale", 7.0),
144
  seed=params.get("seed", -1),
145
  fps=params.get("fps", 30),
146
  double_num_frames=params.get("double_num_frames", False),
@@ -150,13 +134,14 @@ class EndpointHandler:
150
  enable_audio=params.get("enable_audio", False),
151
  audio_prompt=params.get("audio_prompt", ""),
152
  audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
153
- enable_teacache=params.get("enable_teacache", True),
154
-
155
- # values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup).
156
- teacache_threshold=params.get("teacache_threshold", 0.15),
157
-
158
- enable_enhance_a_video=params.get("enable_enhance_a_video", True),
159
- enhance_a_video_weight=params.get("enhance_a_video_weight", 4.0)
 
160
  ).validate_and_adjust()
161
 
162
  try:
@@ -164,90 +149,69 @@ class EndpointHandler:
164
  if config.seed != -1:
165
  torch.manual_seed(config.seed)
166
  random.seed(config.seed)
167
- generator = torch.Generator(device=self.device).manual_seed(config.seed)
168
- else:
169
- generator = None
170
-
171
- # Configure TeaCache
172
- #if config.enable_teacache:
173
- # enable_teacache(
174
- # self.pipeline.transformer,
175
- # num_inference_steps=config.num_inference_steps,
176
- # rel_l1_thresh=config.teacache_threshold
177
- # )
178
- #else:
179
- # disable_teacache(self.pipeline.transformer)
180
-
181
- # Configure Enhance-A-Video weight if enabled
182
- if config.enable_enhance_a_video:
183
- set_enhance_weight(config.enhance_a_video_weight)
184
- enable_enhance()
185
- else:
186
- # Reset enhance weight to 0 to effectively disable it
187
- set_enhance_weight(0)
188
-
189
- # Generate video frames
190
- with torch.inference_mode():
191
- output = self.pipeline(
192
- prompt=config.prompt,
193
-
194
- # Failed to generate video: HunyuanVideoPipeline.__call__() got an unexpected keyword argument 'negative_prompt'
195
- #negative_prompt=config.negative_prompt,
196
-
197
- num_frames=config.num_frames,
198
- height=config.height,
199
- width=config.width,
200
- num_inference_steps=config.num_inference_steps,
201
- guidance_scale=config.guidance_scale,
202
- generator=generator,
203
- output_type="pt",
204
- ).frames
205
-
206
- # Process with Varnish
207
- import asyncio
208
- try:
209
- loop = asyncio.get_event_loop()
210
- except RuntimeError:
211
- loop = asyncio.new_event_loop()
212
- asyncio.set_event_loop(loop)
213
-
214
- result = loop.run_until_complete(
215
- self.varnish(
216
- input_data=output,
217
- fps=config.fps,
218
- double_num_frames=config.double_num_frames,
219
- super_resolution=config.super_resolution,
220
- grain_amount=config.grain_amount,
221
- enable_audio=config.enable_audio,
222
- audio_prompt=config.audio_prompt,
223
- audio_negative_prompt=config.audio_negative_prompt,
224
- )
225
- )
226
 
227
- # Get video data URI
228
- video_uri = loop.run_until_complete(
229
- result.write(
230
- type="data-uri",
231
- quality=config.quality
232
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  )
 
234
 
235
- return {
236
- "video": video_uri,
237
- "content-type": "video/mp4",
238
- "metadata": {
239
- "width": result.metadata.width,
240
- "height": result.metadata.height,
241
- "num_frames": result.metadata.frame_count,
242
- "fps": result.metadata.fps,
243
- "duration": result.metadata.duration,
244
- "seed": config.seed,
245
- "enable_teacache": config.enable_teacache,
246
- "teacache_threshold": config.teacache_threshold if config.enable_teacache else 0,
247
- "enable_enhance_a_video": config.enable_enhance_a_video,
248
- "enhance_a_video_weight": config.enhance_a_video_weight if config.enable_enhance_a_video else 0,
249
- }
 
 
 
 
 
 
 
250
  }
 
251
 
252
  except Exception as e:
253
  message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
 
5
  import random
6
  import traceback
7
  import torch
8
+ from skyreelsinfer import TaskType
9
+ from skyreelsinfer.offload import OffloadConfig
10
+ from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
11
  from varnish import Varnish
12
 
 
 
 
13
  # Configure logging
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
 
19
  """Configuration for video generation"""
20
  # Content settings
21
  prompt: str
22
+ negative_prompt: str = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
23
 
24
  # Model settings
25
+ num_frames: int = 97 # SkyReels default
26
+ height: int = 544 # SkyReels default
27
+ width: int = 960 # SkyReels default
28
+ num_inference_steps: int = 30
29
+ guidance_scale: float = 6.0
30
 
31
  # Reproducibility
32
  seed: int = -1
 
43
  audio_prompt: str = ""
44
  audio_negative_prompt: str = "voices, voice, talking, speaking, speech"
45
 
46
+ # Model-specific settings
47
+ embedded_guidance_scale: float = 1.0
48
+ quant_model: bool = True
49
+ gpu_num: int = 1
50
+ offload: bool = True
51
+ high_cpu_memory: bool = True
52
+ parameters_level: bool = False
53
+ compiler_transformer: bool = False
54
+ sequence_batch: bool = False
55
 
56
  def validate_and_adjust(self) -> 'GenerationConfig':
57
  """Validate and adjust parameters"""
 
 
 
 
58
  # Set random seed if not specified
59
  if self.seed == -1:
60
  self.seed = random.randint(0, 2**32 - 1)
 
62
  return self
63
 
64
  class EndpointHandler:
65
+ """Handles video generation requests using SkyReels and Varnish"""
66
 
67
  def __init__(self, path: str = ""):
68
  """Initialize handler with models
 
72
  """
73
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
74
 
75
+ # Initialize SkyReelsVideoInfer
76
+ self.predictor = SkyReelsVideoInfer(
77
+ task_type=TaskType.T2V,
78
+ model_id=path or "Skywork/SkyReels-V1",
79
+ quant_model=True, # Enable quantization by default
80
+ world_size=1, # Single GPU by default
81
+ is_offload=True, # Enable offloading by default
82
+ offload_config=OffloadConfig(
83
+ high_cpu_memory=True,
84
+ parameters_level=False,
85
+ compiler_transformer=False,
86
+ ),
87
+ enable_cfg_parallel=True
88
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  # Initialize Varnish for post-processing
91
  self.varnish = Varnish(
 
120
  config = GenerationConfig(
121
  prompt=prompt,
122
  negative_prompt=params.get("negative_prompt", ""),
123
+ num_frames=params.get("num_frames", 97),
124
+ height=params.get("height", 544),
125
+ width=params.get("width", 960),
126
+ num_inference_steps=params.get("num_inference_steps", 30),
127
+ guidance_scale=params.get("guidance_scale", 6.0),
128
  seed=params.get("seed", -1),
129
  fps=params.get("fps", 30),
130
  double_num_frames=params.get("double_num_frames", False),
 
134
  enable_audio=params.get("enable_audio", False),
135
  audio_prompt=params.get("audio_prompt", ""),
136
  audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
137
+ embedded_guidance_scale=params.get("embedded_guidance_scale", 1.0),
138
+ quant_model=params.get("quant_model", True),
139
+ gpu_num=params.get("gpu_num", 1),
140
+ offload=params.get("offload", True),
141
+ high_cpu_memory=params.get("high_cpu_memory", True),
142
+ parameters_level=params.get("parameters_level", False),
143
+ compiler_transformer=params.get("compiler_transformer", False),
144
+ sequence_batch=params.get("sequence_batch", False)
145
  ).validate_and_adjust()
146
 
147
  try:
 
149
  if config.seed != -1:
150
  torch.manual_seed(config.seed)
151
  random.seed(config.seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
+ # Prepare generation parameters
154
+ generation_kwargs = {
155
+ "prompt": f"FPS-{config.fps}, {config.prompt}", # SkyReels expects FPS in prompt
156
+ "negative_prompt": config.negative_prompt,
157
+ "height": config.height,
158
+ "width": config.width,
159
+ "num_frames": config.num_frames,
160
+ "num_inference_steps": config.num_inference_steps,
161
+ "guidance_scale": config.guidance_scale,
162
+ "embedded_guidance_scale": config.embedded_guidance_scale,
163
+ "seed": config.seed,
164
+ "cfg_for": config.sequence_batch
165
+ }
166
+
167
+ # Generate video frames using SkyReels
168
+ output = self.predictor.inference(generation_kwargs)
169
+
170
+ # Process with Varnish
171
+ import asyncio
172
+ try:
173
+ loop = asyncio.get_event_loop()
174
+ except RuntimeError:
175
+ loop = asyncio.new_event_loop()
176
+ asyncio.set_event_loop(loop)
177
+
178
+ result = loop.run_until_complete(
179
+ self.varnish(
180
+ input_data=output,
181
+ fps=config.fps,
182
+ double_num_frames=config.double_num_frames,
183
+ super_resolution=config.super_resolution,
184
+ grain_amount=config.grain_amount,
185
+ enable_audio=config.enable_audio,
186
+ audio_prompt=config.audio_prompt,
187
+ audio_negative_prompt=config.audio_negative_prompt,
188
  )
189
+ )
190
 
191
+ # Get video data URI
192
+ video_uri = loop.run_until_complete(
193
+ result.write(
194
+ type="data-uri",
195
+ quality=config.quality
196
+ )
197
+ )
198
+
199
+ return {
200
+ "video": video_uri,
201
+ "content-type": "video/mp4",
202
+ "metadata": {
203
+ "width": result.metadata.width,
204
+ "height": result.metadata.height,
205
+ "num_frames": result.metadata.frame_count,
206
+ "fps": result.metadata.fps,
207
+ "duration": result.metadata.duration,
208
+ "seed": config.seed,
209
+ "gpu_num": config.gpu_num,
210
+ "quant_model": config.quant_model,
211
+ "guidance_scale": config.guidance_scale,
212
+ "embedded_guidance_scale": config.embedded_guidance_scale
213
  }
214
+ }
215
 
216
  except Exception as e:
217
  message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"