multimodalart committed (verified)
Commit a314de9 · 1 Parent(s): afdfe21

Update app.py

Files changed (1):
  1. app.py (+2 -13)
app.py CHANGED
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
 # This dictionary will store the manual patches extracted by the converter
 MANUAL_PATCHES_STORE = {}
 
-def _custom_convert_non_diffusers_wan_lora_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+def _custom_convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     """
     Custom converter for Wan 2.1 T2V LoRA.
     Separates LoRA A/B weights for PEFT and diff_b/diff for manual patching.
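
The point of this converter is the split its docstring describes: LoRA A/B pairs can be handed to PEFT as-is, while diff/diff_b tensors are full weight/bias deltas that PEFT's LoRA format cannot represent, so they are set aside in MANUAL_PATCHES_STORE. A minimal sketch of that routing, assuming the checkpoint uses lora_down/lora_up/diff/diff_b key suffixes (an assumption; the real key mapping lives in the unchanged body of app.py):

def _split_wan_lora_keys(state_dict):
    # Hypothetical illustration of the docstring's split; not the app.py code.
    peft_state_dict, manual_patches = {}, {}
    for key, tensor in state_dict.items():
        if key.endswith((".diff", ".diff_b")):
            manual_patches[key] = tensor          # full weight/bias deltas
        elif "lora_down" in key or "lora_up" in key:
            # Rename to the lora_A/lora_B convention PEFT expects.
            new_key = key.replace("lora_down", "lora_A").replace("lora_up", "lora_B")
            peft_state_dict[new_key] = tensor
    return peft_state_dict, manual_patches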
@@ -192,7 +192,7 @@ def _custom_convert_non_diffusers_wan_lora_to_diffusers(state_dict: Dict[str, to
     return final_peft_state_dict
 
 
-def apply_manual_diff_patches(pipe_model: torch.nn.Module, patches: Dict[str, torch.Tensor]):
+def apply_manual_diff_patches(pipe_model, patches):
     """
     Manually applies diff_b/diff patches to the model.
     Assumes PEFT LoRA layers have already been loaded.
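
Manual patching then reduces to adding each stored delta onto the matching parameter once the PEFT layers are attached. A rough sketch, assuming patch keys mirror module paths with .diff mapping to .weight and .diff_b to .bias (an assumption; the diff does not show the real suffix handling):

import torch

@torch.no_grad()
def _apply_diff_patches_sketch(model: torch.nn.Module, patches):
    # Hypothetical illustration; app.py's apply_manual_diff_patches is the real one.
    params = dict(model.named_parameters())
    for key, delta in patches.items():
        if key.endswith(".diff_b"):
            target = key[: -len(".diff_b")] + ".bias"
        elif key.endswith(".diff"):
            target = key[: -len(".diff")] + ".weight"
        else:
            continue
        if target in params:
            # In-place add on top of whatever PEFT already injected.
            params[target].add_(delta.to(device=params[target].device, dtype=params[target].dtype))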
@@ -307,20 +307,9 @@ logger.info(f"Downloading LoRA {LORA_FILENAME} from {LORA_REPO_ID}...")
 causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
 
 logger.info("Loading LoRA weights with custom converter...")
-# The load_lora_weights will use the lora_converters mechanism if available.
-# We need to ensure our custom converter is registered or passed correctly.
-# Since WanPipeline inherits from a LoraLoaderMixin that might have its own
-# lora_state_dict, we need to be careful.
-# A robust way is to load the state_dict, convert it, then load the converted dict.
 
 lora_state_dict_raw = WanPipeline.lora_state_dict(causvid_path) # This might already do some conversion
 
-# If WanPipeline.lora_state_dict doesn't directly call our specific wan converter,
-# we might need to load the raw safetensors file and then call our converter.
-# Let's assume for now that lora_state_dict loads it and we then pass it to our converter.
-# If WanPipeline's lora_state_dict already calls a wan-specific converter,
-# then we need to inject our custom one there, which is not possible without modifying the library.
-
 # Alternative: Load raw state_dict and then convert
 from safetensors.torch import load_file as load_safetensors
 raw_lora_state_dict = load_safetensors(causvid_path)
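
What survives the deletion is the route the removed comments argued for: bypass the library's converter hooks, read the raw safetensors file, run the custom converter, then load the PEFT-compatible part and patch the rest by hand. Wired together with the names from app.py itself (pipe standing in for the constructed WanPipeline and pipe.transformer for the patched module, both assumptions here):

from safetensors.torch import load_file as load_safetensors

raw_lora_state_dict = load_safetensors(causvid_path)
# Converter returns the PEFT-loadable dict and fills MANUAL_PATCHES_STORE as a side effect.
peft_state_dict = _custom_convert_non_diffusers_wan_lora_to_diffusers(raw_lora_state_dict)
pipe.load_lora_weights(peft_state_dict)                            # PEFT-compatible A/B weights
apply_manual_diff_patches(pipe.transformer, MANUAL_PATCHES_STORE)  # leftover diff/diff_b deltas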
 