Update handler.py
Browse files- handler.py +44 -1
handler.py
CHANGED
@@ -207,6 +207,11 @@ class GenerationConfig:
|
|
207 |
# Enhance-A-Video settings
|
208 |
enable_enhance_a_video: bool = True
|
209 |
enhance_a_video_weight: float = 4.0
|
|
|
|
|
|
|
|
|
|
|
210 |
|
211 |
def validate_and_adjust(self) -> 'GenerationConfig':
|
212 |
"""Validate and adjust parameters to meet constraints"""
|
@@ -260,6 +265,9 @@ class EndpointHandler:
|
|
260 |
torch_dtype=torch.bfloat16
|
261 |
).to("cuda")
|
262 |
|
|
|
|
|
|
|
263 |
# Enable CPU offload for memory efficiency
|
264 |
#self.text_to_video.enable_model_cpu_offload()
|
265 |
#self.image_to_video.enable_model_cpu_offload()
|
@@ -393,7 +401,9 @@ class EndpointHandler:
|
|
393 |
- teacache_threshold (optional, float, default to 0.05): Amount of cache, 0 (original), 0.03 (1.6x speedup), 0.05 (Default, 2.1x speedup).
|
394 |
- enable_enhance_a_video (optional, bool, default to True): enable the enhance_a_video optimization
|
395 |
- enhance_a_video_weight (optional, float, default to 4.0): amount of video enhancement to apply
|
396 |
-
|
|
|
|
|
397 |
Returns:
|
398 |
Dictionary containing:
|
399 |
- video: Base64 encoded MP4 data URI
|
@@ -450,6 +460,11 @@ class EndpointHandler:
|
|
450 |
# Add enhance-a-video settings
|
451 |
enable_enhance_a_video=params.get("enable_enhance_a_video", True),
|
452 |
enhance_a_video_weight=params.get("enhance_a_video_weight", 4.0),
|
|
|
|
|
|
|
|
|
|
|
453 |
).validate_and_adjust()
|
454 |
|
455 |
#logger.debug(f"Global request settings:")
|
@@ -494,6 +509,34 @@ class EndpointHandler:
|
|
494 |
}
|
495 |
#logger.info(f"Video model generation settings:")
|
496 |
#pprint.pprint(generation_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
497 |
|
498 |
# Check if image-to-video generation is requested
|
499 |
if input_image:
|
|
|
207 |
# Enhance-A-Video settings
|
208 |
enable_enhance_a_video: bool = True
|
209 |
enhance_a_video_weight: float = 4.0
|
210 |
+
|
211 |
+
# LoRA settings
|
212 |
+
lora_model_name: str = "" # HuggingFace repo ID or path to LoRA model
|
213 |
+
lora_model_weight_file: str = "" # Specific weight file to load from the LoRA model
|
214 |
+
lora_model_trigger: str = "" # Optional trigger word to prepend to the prompt
|
215 |
|
216 |
def validate_and_adjust(self) -> 'GenerationConfig':
|
217 |
"""Validate and adjust parameters to meet constraints"""
|
|
|
265 |
torch_dtype=torch.bfloat16
|
266 |
).to("cuda")
|
267 |
|
268 |
+
# Initialize LoRA tracking
|
269 |
+
self._current_lora_model = None
|
270 |
+
|
271 |
# Enable CPU offload for memory efficiency
|
272 |
#self.text_to_video.enable_model_cpu_offload()
|
273 |
#self.image_to_video.enable_model_cpu_offload()
|
|
|
401 |
- teacache_threshold (optional, float, default to 0.05): Amount of cache, 0 (original), 0.03 (1.6x speedup), 0.05 (Default, 2.1x speedup).
|
402 |
- enable_enhance_a_video (optional, bool, default to True): enable the enhance_a_video optimization
|
403 |
- enhance_a_video_weight (optional, float, default to 4.0): amount of video enhancement to apply
|
404 |
+
- lora_model_name (optional, str, default to ""): HuggingFace repo ID or path to LoRA model
|
405 |
+
- lora_model_weight_file (optional, str, default to ""): Specific weight file to load from the LoRA model
|
406 |
+
- lora_model_trigger (optional, str, default to ""): Optional trigger word to prepend to the prompt
|
407 |
Returns:
|
408 |
Dictionary containing:
|
409 |
- video: Base64 encoded MP4 data URI
|
|
|
460 |
# Add enhance-a-video settings
|
461 |
enable_enhance_a_video=params.get("enable_enhance_a_video", True),
|
462 |
enhance_a_video_weight=params.get("enhance_a_video_weight", 4.0),
|
463 |
+
|
464 |
+
# LoRA settings
|
465 |
+
lora_model_name=params.get("lora_model_name", ""),
|
466 |
+
lora_model_weight_file=params.get("lora_model_weight_file", ""),
|
467 |
+
lora_model_trigger=params.get("lora_model_trigger", ""),
|
468 |
).validate_and_adjust()
|
469 |
|
470 |
#logger.debug(f"Global request settings:")
|
|
|
509 |
}
|
510 |
#logger.info(f"Video model generation settings:")
|
511 |
#pprint.pprint(generation_kwargs)
|
512 |
+
|
513 |
+
# Handle LoRA loading/unloading
|
514 |
+
if hasattr(self, '_current_lora_model'):
|
515 |
+
if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
|
516 |
+
# Unload previous LoRA if it exists and is different
|
517 |
+
if hasattr(self.text_to_video, 'unload_lora_weights'):
|
518 |
+
self.text_to_video.unload_lora_weights()
|
519 |
+
if hasattr(self.image_to_video, 'unload_lora_weights'):
|
520 |
+
self.image_to_video.unload_lora_weights()
|
521 |
+
|
522 |
+
if config.lora_model_name:
|
523 |
+
# Load new LoRA
|
524 |
+
if hasattr(self.text_to_video, 'load_lora_weights'):
|
525 |
+
self.text_to_video.load_lora_weights(
|
526 |
+
config.lora_model_name,
|
527 |
+
weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None
|
528 |
+
)
|
529 |
+
if hasattr(self.image_to_video, 'load_lora_weights'):
|
530 |
+
self.image_to_video.load_lora_weights(
|
531 |
+
config.lora_model_name,
|
532 |
+
weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None
|
533 |
+
)
|
534 |
+
self._current_lora_model = (config.lora_model_name, config.lora_model_weight_file)
|
535 |
+
|
536 |
+
# Modify prompt if trigger word is provided
|
537 |
+
if config.lora_model_trigger:
|
538 |
+
generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"
|
539 |
+
|
540 |
|
541 |
# Check if image-to-video generation is requested
|
542 |
if input_image:
|