jbilcke-hf (HF staff) committed · verified
Commit 0b910bc · 1 Parent(s): 306dec1

Update handler.py

Files changed (1):
  1. handler.py +44 -1
handler.py CHANGED
 
@@ -207,6 +207,11 @@ class GenerationConfig:
     # Enhance-A-Video settings
     enable_enhance_a_video: bool = True
     enhance_a_video_weight: float = 4.0
+
+    # LoRA settings
+    lora_model_name: str = ""         # HuggingFace repo ID or path to LoRA model
+    lora_model_weight_file: str = ""  # Specific weight file to load from the LoRA model
+    lora_model_trigger: str = ""      # Optional trigger word to prepend to the prompt
 
     def validate_and_adjust(self) -> 'GenerationConfig':
         """Validate and adjust parameters to meet constraints"""
 
@@ -260,6 +265,9 @@ class EndpointHandler:
             torch_dtype=torch.bfloat16
         ).to("cuda")
 
+        # Initialize LoRA tracking
+        self._current_lora_model = None
+
         # Enable CPU offload for memory efficiency
         #self.text_to_video.enable_model_cpu_offload()
         #self.image_to_video.enable_model_cpu_offload()
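The None sentinel set here is what the request path later compares against, so an adapter is only (re)loaded when the requested (name, weight file) pair actually changes between requests. A sketch of the pattern (illustrative only, not code from this commit):

    requested = (config.lora_model_name, config.lora_model_weight_file)
    if requested != self._current_lora_model:
        # swap adapters, then remember what is now loaded
        self._current_lora_model = requested

This keeps repeated requests that use the same adapter from paying the load cost on every call.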
 
@@ -393,7 +401,9 @@ class EndpointHandler:
         - teacache_threshold (optional, float, default to 0.05): Amount of cache, 0 (original), 0.03 (1.6x speedup), 0.05 (Default, 2.1x speedup).
         - enable_enhance_a_video (optional, bool, default to True): enable the enhance_a_video optimization
         - enhance_a_video_weight (optional, float, default to 4.0): amount of video enhancement to apply
-
+        - lora_model_name (optional, str, default to ""): HuggingFace repo ID or path to LoRA model
+        - lora_model_weight_file (optional, str, default to ""): specific weight file to load from the LoRA model
+        - lora_model_trigger (optional, str, default to ""): optional trigger word to prepend to the prompt
         Returns:
             Dictionary containing:
             - video: Base64 encoded MP4 data URI
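A hypothetical client call exercising the new parameters (the endpoint URL and token are placeholders, and the payload shape assumes the usual Inference Endpoints inputs/parameters envelope):

    import requests

    response = requests.post(
        "https://YOUR-ENDPOINT.endpoints.huggingface.cloud",   # placeholder URL
        headers={"Authorization": "Bearer hf_..."},            # placeholder token
        json={
            "inputs": "a spaceship drifting through a nebula",
            "parameters": {
                "lora_model_name": "some-user/some-ltx-lora",  # hypothetical repo ID
                "lora_model_trigger": "<style-token>",         # hypothetical trigger word
            },
        },
    )
    result = response.json()  # contains "video": a Base64-encoded MP4 data URI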
 
@@ -450,6 +460,11 @@ class EndpointHandler:
             # Add enhance-a-video settings
             enable_enhance_a_video=params.get("enable_enhance_a_video", True),
             enhance_a_video_weight=params.get("enhance_a_video_weight", 4.0),
+
+            # LoRA settings
+            lora_model_name=params.get("lora_model_name", ""),
+            lora_model_weight_file=params.get("lora_model_weight_file", ""),
+            lora_model_trigger=params.get("lora_model_trigger", ""),
         ).validate_and_adjust()
 
         #logger.debug(f"Global request settings:")
 
@@ -494,6 +509,34 @@ class EndpointHandler:
         }
         #logger.info(f"Video model generation settings:")
         #pprint.pprint(generation_kwargs)
+
+        # Handle LoRA loading/unloading
+        if hasattr(self, '_current_lora_model'):
+            if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
+                # Unload previous LoRA if it exists and is different
+                if hasattr(self.text_to_video, 'unload_lora_weights'):
+                    self.text_to_video.unload_lora_weights()
+                if hasattr(self.image_to_video, 'unload_lora_weights'):
+                    self.image_to_video.unload_lora_weights()
+
+                if config.lora_model_name:
+                    # Load new LoRA
+                    if hasattr(self.text_to_video, 'load_lora_weights'):
+                        self.text_to_video.load_lora_weights(
+                            config.lora_model_name,
+                            weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None
+                        )
+                    if hasattr(self.image_to_video, 'load_lora_weights'):
+                        self.image_to_video.load_lora_weights(
+                            config.lora_model_name,
+                            weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None
+                        )
+                self._current_lora_model = (config.lora_model_name, config.lora_model_weight_file)
+
+        # Modify prompt if trigger word is provided
+        if config.lora_model_trigger:
+            generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"
+
 
         # Check if image-to-video generation is requested
         if input_image:
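The hasattr guards above let the handler degrade gracefully if a pipeline class lacks LoRA support. Standalone, the same load/unload cycle uses the standard diffusers LoRA mixin methods; a sketch assuming the handler wraps diffusers LTX-Video pipelines (the pipeline class, repo ID, and file name are assumptions, not taken from this commit):

    import torch
    from diffusers import LTXPipeline

    pipe = LTXPipeline.from_pretrained(
        "Lightricks/LTX-Video", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Load an adapter; weight_name may be omitted when the repo has one weight file
    pipe.load_lora_weights("some-user/some-ltx-lora", weight_name="adapter.safetensors")

    # ... generate with the adapter active ...

    # Remove the adapter before switching to a different one,
    # mirroring the handler's unload-before-load cache logic
    pipe.unload_lora_weights()

Prepending the trigger word to the prompt, as the last hunk does, follows the common convention that a LoRA fine-tune responds to the token it was trained with.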