Commit 3b5290e (verified) by jbilcke-hf (HF Staff)
Parent(s): fa8e04c

Update handler.py

Files changed (1):
  1. handler.py (+199, -17)
handler.py CHANGED
@@ -1,14 +1,31 @@
 from typing import Dict, Any
 import os
 import shutil
-from pathlib import Path
+import gc
 import time
-from datetime import datetime
+from pathlib import Path
 import argparse
+from datetime import datetime
 from loguru import logger
+import torch
+import base64
+
 from hyvideo.utils.file_utils import save_videos_grid
 from hyvideo.inference import HunyuanVideoSampler
-from hyvideo.constants import NEGATIVE_PROMPT
+from hyvideo.constants import NEGATIVE_PROMPT, VAE_PATH, TEXT_ENCODER_PATH, TOKENIZER_PATH
+from hyvideo.modules.attenion import get_attention_modes
+
+try:
+    import triton
+    has_triton = True
+except ImportError:
+    has_triton = False
+
+try:
+    from mmgp import offload, safetensors2, profile_type
+    has_mmgp = True
+except ImportError:
+    has_mmgp = False
 
 # Configure logger
 logger.add("handler_debug.log", rotation="500 MB")
@@ -16,10 +33,13 @@ logger.add("handler_debug.log", rotation="500 MB")
 DEFAULT_RESOLUTION = "720p"
 DEFAULT_WIDTH = 1280
 DEFAULT_HEIGHT = 720
-DEFAULT_NB_FRAMES = (4 * 30) + 1 # or 129 (note: hunyan requires an extra +1 frame)
-DEFAULT_NB_STEPS = 22 # or 50
+DEFAULT_NB_FRAMES = (4 * 30) + 1 # or 129 (note: hunyan requires an extra +1 frame)
+DEFAULT_NB_STEPS = 22 # Default for standard model
 DEFAULT_FPS = 24
 
+# Get supported attention modes
+attention_modes_supported = get_attention_modes()
+
 def setup_vae_path(vae_path: Path) -> Path:
     """Create a temporary directory with correctly named VAE config file"""
     tmp_vae_dir = Path("/tmp/vae")
@@ -124,14 +144,72 @@ def get_default_args():
     parser.add_argument("--ulysses-degree", type=int, default=1)
     parser.add_argument("--ring-degree", type=int, default=1)
 
+    # Added from gradio server
+    parser.add_argument("--attention", type=str, default="auto",
+                        choices=["auto", "sdpa", "flash", "sage", "sage2", "xformers"])
+    parser.add_argument("--profile", type=int, default=1) # HighRAM_HighVRAM
+    parser.add_argument("--quantize-transformer", action="store_true", default=False)
+    parser.add_argument("--tea-cache", type=float, default=0.0)
+    parser.add_argument("--compile", action="store_true", default=False)
+    parser.add_argument("--enable-riflex", action="store_true", default=True)
+    parser.add_argument("--vae-config", type=int, default=0)
+
     # Parse with empty args list to avoid reading sys.argv
     args = parser.parse_args([])
 
     return args
 
+def get_auto_attention():
+    """Select the best available attention mode"""
+    for attn in ["sage2", "sage", "sdpa"]:
+        if attn in attention_modes_supported:
+            return attn
+    return "sdpa"
+
+def setup_vae_config(device_mem_capacity, vae, vae_config=0):
+    """Configure VAE tiling based on available VRAM"""
+    if vae_config == 0:
+        # Auto-select based on VRAM
+        if device_mem_capacity >= 24000:
+            use_vae_config = 1
+        elif device_mem_capacity >= 16000:
+            use_vae_config = 3
+        elif device_mem_capacity >= 12000:
+            use_vae_config = 4
+        else:
+            use_vae_config = 5
+    else:
+        use_vae_config = vae_config
+
+    # VAE tiling configuration options
+    if use_vae_config == 1:
+        sample_tsize = 32
+        sample_size = 256
+    elif use_vae_config == 2:
+        sample_tsize = 64
+        sample_size = 192
+    elif use_vae_config == 3:
+        sample_tsize = 32
+        sample_size = 192
+    elif use_vae_config == 4:
+        sample_tsize = 16
+        sample_size = 256
+    else:
+        sample_tsize = 16
+        sample_size = 192
+
+    # Apply settings
+    vae.tile_sample_min_tsize = sample_tsize
+    vae.tile_latent_min_tsize = sample_tsize // vae.time_compression_ratio
+    vae.tile_sample_min_size = sample_size
+    vae.tile_latent_min_size = int(sample_size / (2 ** (len(vae.config.block_out_channels) - 1)))
+    vae.tile_overlap_factor = 0.25
+
+    return use_vae_config
+
 class EndpointHandler:
     def __init__(self, path: str = ""):
-        """Initialize the handler with model path and default config."""
+        """Initialize the handler with model path and config."""
         logger.info(f"Initializing EndpointHandler with path: {path}")
 
         # Use default args instead of parsing from command line
@@ -144,14 +222,22 @@ class EndpointHandler:
         # Set up model paths
         self.args.model_base = path
 
-        # Set paths for model components
-        dit_weight_path = Path(path) / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
-        original_vae_path = Path(path) / "hunyuan-video-t2v-720p/vae"
-
-        # to save on memory, we activate fp8 weights and we override the previous dit_weight_path setting
+        # Model configurations
+        self.init_model_paths(path)
+        self.configure_model()
+
+        # Initialize model
+        self.initialize_model()
+
+    def init_model_paths(self, path):
+        """Setup paths for model components"""
+        # We'll use the FP8 model for memory efficiency
         self.args.use_fp8 = True
+
+        # Model component paths
         dit_weight_path = Path(path) / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt"
-
+        original_vae_path = Path(path) / "hunyuan-video-t2v-720p/vae"
+
         # Log all critical paths
         logger.info(f"Model base path: {self.args.model_base}")
         logger.info(f"DiT weight path: {dit_weight_path}")
@@ -170,7 +256,6 @@
         tmp_vae_path = setup_vae_path(original_vae_path)
 
         # Override the VAE path in constants to use our temporary directory
-        from hyvideo.constants import VAE_PATH, TEXT_ENCODER_PATH, TOKENIZER_PATH
         VAE_PATH["884-16c-hy"] = str(tmp_vae_path)
         logger.info(f"Updated VAE_PATH to: {VAE_PATH['884-16c-hy']}")
 
@@ -196,16 +281,83 @@
         logger.info(f"TOKENIZER_PATH['clipL']: {TOKENIZER_PATH['clipL']}")
 
         self.args.dit_weight = str(dit_weight_path)
+
+    def configure_model(self):
+        """Configure model based on available hardware and settings"""
+        # Set attention mode (auto-select best available if set to 'auto')
+        if self.args.attention == "auto":
+            self.attention_mode = get_auto_attention()
+        elif self.args.attention in attention_modes_supported:
+            self.attention_mode = self.args.attention
+        else:
+            logger.warning(f"Attention mode {self.args.attention} not supported. Falling back to sdpa.")
+            self.attention_mode = "sdpa"
+
+        logger.info(f"Using attention mode: {self.attention_mode}")
 
-        # Initialize model
-        models_root_path = Path(path)
+        # Set compilation flag based on Triton availability
+        if self.args.compile and not has_triton:
+            logger.warning("Compilation requested but Triton not available. Compilation disabled.")
+            self.args.compile = False
+
+        # Set profile based on memory configuration
+        # We default to HighRAM_HighVRAM (1) as specified
+        if has_mmgp:
+            self.profile = self.args.profile
+            logger.info(f"Using memory profile: {self.profile}")
+        else:
+            logger.warning("MMGP not available. Memory profiles not used.")
+
+    def initialize_model(self):
+        """Initialize the model with configured settings"""
+        models_root_path = Path(self.args.model_base)
         if not models_root_path.exists():
             raise ValueError(f"models_root_path does not exist: {models_root_path}")
 
         try:
             logger.info("Attempting to initialize HunyuanVideoSampler...")
+
+            # Apply attention mode setting
+            self.args.attention = self.attention_mode
+
             self.model = HunyuanVideoSampler.from_pretrained(models_root_path, args=self.args)
+
+            # Set attention mode for transformer blocks
+            if hasattr(self.model, 'pipeline') and hasattr(self.model.pipeline, 'transformer'):
+                transformer = self.model.pipeline.transformer
+                transformer.attention_mode = self.attention_mode
+                # Apply to all blocks
+                if hasattr(transformer, 'double_blocks'):
+                    for module in transformer.double_blocks:
+                        module.attention_mode = self.attention_mode
+                if hasattr(transformer, 'single_blocks'):
+                    for module in transformer.single_blocks:
+                        module.attention_mode = self.attention_mode
+
+                # Enable compilation if requested
+                if self.args.compile:
+                    transformer.any_compilation = True
+                    logger.info("PyTorch compilation enabled for transformer")
+
+                # Enable TeaCache if requested
+                if self.args.tea_cache > 0:
+                    transformer.enable_teacache = True
+                    transformer.rel_l1_thresh = self.args.tea_cache
+                    logger.info(f"TeaCache enabled with threshold: {self.args.tea_cache}")
+                else:
+                    transformer.enable_teacache = False
+
+            # Apply VAE tiling configuration if supported
+            if hasattr(self.model, 'vae'):
+                if torch.cuda.is_available():
+                    device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
+                    vae_config = setup_vae_config(device_mem_capacity, self.model.vae, self.args.vae_config)
+                    logger.info(f"Configured VAE tiling with config: {vae_config}")
+                else:
+                    logger.warning("CUDA not available, using default VAE configuration")
+
             logger.info("Successfully initialized HunyuanVideoSampler")
+
         except Exception as e:
             logger.error(f"Error initializing model: {str(e)}")
             raise
@@ -232,12 +384,27 @@
         guidance_scale = float(data.pop("guidance_scale", 1.0))
         flow_shift = float(data.pop("flow_shift", 7.0))
         embedded_guidance_scale = float(data.pop("embedded_guidance_scale", 6.0))
+        enable_riflex = data.pop("enable_riflex", self.args.enable_riflex)
 
         logger.info(f"Processing with parameters: width={width}, height={height}, "
                     f"video_length={video_length}, seed={seed}, "
                     f"num_inference_steps={num_inference_steps}")
 
         try:
+            # Set up TeaCache for this generation if enabled
+            if hasattr(self.model.pipeline, 'transformer') and self.model.pipeline.transformer.enable_teacache:
+                transformer = self.model.pipeline.transformer
+                transformer.num_steps = num_inference_steps
+                transformer.cnt = 0
+                transformer.accumulated_rel_l1_distance = 0
+                transformer.previous_modulated_input = None
+                transformer.previous_residual = None
+
+            # Clean up memory before generation
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
             # Run inference
             outputs = self.model.predict(
                 prompt=prompt,
@@ -251,7 +418,8 @@
                 num_videos_per_prompt=1,
                 flow_shift=flow_shift,
                 batch_size=1,
-                embedded_guidance_scale=embedded_guidance_scale
+                embedded_guidance_scale=embedded_guidance_scale,
+                enable_riflex=enable_riflex
             )
 
             # Get the video tensor
@@ -265,7 +433,6 @@
             # Read video file and convert to base64
            with open(temp_path, "rb") as f:
                 video_bytes = f.read()
-            import base64
             video_base64 = base64.b64encode(video_bytes).decode()
 
             # Add MP4 data URI prefix
@@ -274,10 +441,25 @@
             # Cleanup
             os.remove(temp_path)
 
+            # Clean up memory after generation
+            if has_mmgp and hasattr(offload, 'last_offload_obj'):
+                offload.last_offload_obj.unload_all()
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
             logger.info("Successfully generated and encoded video")
 
             return video_data_uri
 
         except Exception as e:
             logger.error(f"Error during video generation: {str(e)}")
+
+            # Clean up memory after error
+            if has_mmgp and hasattr(offload, 'last_offload_obj'):
+                offload.last_offload_obj.unload_all()
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
             raise
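
For reference, a minimal sketch of how the updated handler could be exercised end to end. It assumes the model weights are mounted at /repository (a typical Inference Endpoints layout), that __call__ takes the request dict directly, and that keys such as "prompt", "width", "height", "video_length", "seed" and "num_inference_steps" are read by handler code outside this diff; only "guidance_scale", "flow_shift", "embedded_guidance_scale" and "enable_riflex" are visible above, and the exact data URI prefix is likewise assumed.

import base64
from handler import EndpointHandler

# Assumed mount point for the HunyuanVideo weights; adjust to your deployment.
handler = EndpointHandler(path="/repository")

request = {
    "prompt": "A cat walks on the grass, realistic style.",  # assumed key
    "width": 1280,                   # assumed key, matches DEFAULT_WIDTH
    "height": 720,                   # assumed key, matches DEFAULT_HEIGHT
    "video_length": 121,             # assumed key, matches DEFAULT_NB_FRAMES
    "seed": 42,                      # assumed key
    "num_inference_steps": 22,       # assumed key, matches DEFAULT_NB_STEPS
    "guidance_scale": 1.0,           # popped by the handler (see diff)
    "flow_shift": 7.0,               # popped by the handler (see diff)
    "embedded_guidance_scale": 6.0,  # popped by the handler (see diff)
    "enable_riflex": True,           # popped by the handler (see diff)
}

# The handler returns the generated video as a base64 data URI string.
video_data_uri = handler(request)
_, b64_payload = video_data_uri.split(",", 1)
with open("output.mp4", "wb") as f:
    f.write(base64.b64decode(b64_payload))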