multimodalart (HF Staff) committed
Commit 9762ac2 · verified · 1 Parent(s): 1bb48a3

Update app.py

Files changed (1):
  1. app.py +213 -251

app.py CHANGED
@@ -22,277 +22,241 @@ logger = logging.getLogger(__name__)
 MANUAL_PATCHES_STORE = {}
 
 def _custom_convert_non_diffusers_wan_lora_to_diffusers(state_dict):
-    """
-    Custom converter for Wan 2.1 T2V LoRA.
-    Separates LoRA A/B weights for PEFT and diff_b/diff for manual patching.
-    Stores diff_b/diff in the global MANUAL_PATCHES_STORE.
-    """
     global MANUAL_PATCHES_STORE
-    MANUAL_PATCHES_STORE.clear()  # Clear previous patches if any

-    converted_state_dict_for_peft = {}
-    manual_diff_patches = {}

-    # Strip "diffusion_model." prefix
-    original_state_dict = {
-        k[len("diffusion_model.") :]: v
-        for k, v in state_dict.items()
-        if k.startswith("diffusion_model.")
-    }

-    # --- Determine number of blocks ---
     block_indices = set()
-    for k_orig in original_state_dict:
-        if "blocks." in k_orig:
             try:
-                block_idx_str = k_orig.split("blocks.")[1].split(".")[0]
                 if block_idx_str.isdigit():
                     block_indices.add(int(block_idx_str))
-            except (IndexError, ValueError) as e:
-                logger.warning(f"Could not parse block index from key: {k_orig} due to {e}")
-
-    num_transformer_blocks = max(block_indices) + 1 if block_indices else 0
-    if not block_indices and any("blocks." in k for k in original_state_dict):
-        logger.warning("Found 'blocks.' in keys but could not determine num_transformer_blocks reliably.")
-
-
-    # --- Convert Transformer Blocks (blocks.0 to blocks.N-1) ---
-    for i in range(num_transformer_blocks):
-        # Self-attention (attn1 in Diffusers DiT)
-        for lora_key_part, diffusers_layer_name in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            orig_lora_down_key = f"blocks.{i}.self_attn.{lora_key_part}.lora_down.weight"
-            orig_lora_up_key = f"blocks.{i}.self_attn.{lora_key_part}.lora_up.weight"
-            target_base_key_peft = f"blocks.{i}.attn1.{diffusers_layer_name}"
-            target_base_key_manual = f"transformer.blocks.{i}.attn1.{diffusers_layer_name}"
-
-            if orig_lora_down_key in original_state_dict and orig_lora_up_key in original_state_dict:
-                converted_state_dict_for_peft[f"{target_base_key_peft}.lora_A.weight"] = original_state_dict.pop(orig_lora_down_key)
-                converted_state_dict_for_peft[f"{target_base_key_peft}.lora_B.weight"] = original_state_dict.pop(orig_lora_up_key)
-
-            orig_diff_b_key = f"blocks.{i}.self_attn.{lora_key_part}.diff_b"
-            if orig_diff_b_key in original_state_dict:
-                manual_diff_patches[f"{target_base_key_manual}.bias"] = original_state_dict.pop(orig_diff_b_key)
-
-        # Cross-attention (attn2 in Diffusers DiT)
-        for lora_key_part, diffusers_layer_name in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            orig_lora_down_key = f"blocks.{i}.cross_attn.{lora_key_part}.lora_down.weight"
-            orig_lora_up_key = f"blocks.{i}.cross_attn.{lora_key_part}.lora_up.weight"
-            target_base_key_peft = f"blocks.{i}.attn2.{diffusers_layer_name}"
-            target_base_key_manual = f"transformer.blocks.{i}.attn2.{diffusers_layer_name}"
-
-            if orig_lora_down_key in original_state_dict and orig_lora_up_key in original_state_dict:
-                converted_state_dict_for_peft[f"{target_base_key_peft}.lora_A.weight"] = original_state_dict.pop(orig_lora_down_key)
-                converted_state_dict_for_peft[f"{target_base_key_peft}.lora_B.weight"] = original_state_dict.pop(orig_lora_up_key)
-
-            orig_diff_b_key = f"blocks.{i}.cross_attn.{lora_key_part}.diff_b"
-            if orig_diff_b_key in original_state_dict:
-                manual_diff_patches[f"{target_base_key_manual}.bias"] = original_state_dict.pop(orig_diff_b_key)

         # FFN
-        for original_ffn_idx, diffusers_ffn_path_part in zip(["0", "2"], ["net.0.proj", "net.2"]):
-            orig_lora_down_key = f"blocks.{i}.ffn.{original_ffn_idx}.lora_down.weight"
-            orig_lora_up_key = f"blocks.{i}.ffn.{original_ffn_idx}.lora_up.weight"
-            target_base_key_peft = f"blocks.{i}.ffn.{diffusers_ffn_path_part}"
-            target_base_key_manual = f"transformer.blocks.{i}.ffn.{diffusers_ffn_path_part}"
-
-            if orig_lora_down_key in original_state_dict and orig_lora_up_key in original_state_dict:
-                converted_state_dict_for_peft[f"{target_base_key_peft}.lora_A.weight"] = original_state_dict.pop(orig_lora_down_key)
-                converted_state_dict_for_peft[f"{target_base_key_peft}.lora_B.weight"] = original_state_dict.pop(orig_lora_up_key)
-
-            orig_diff_b_key = f"blocks.{i}.ffn.{original_ffn_idx}.diff_b"
-            if orig_diff_b_key in original_state_dict:
-                manual_diff_patches[f"{target_base_key_manual}.bias"] = original_state_dict.pop(orig_diff_b_key)
-
-        # Norm layers within blocks
-        # LoRA has `norm3.diff` and `norm3.diff_b`. Wan2.1 base DiTBlock has `norm2`.
-        norm3_diff_key = f"blocks.{i}.norm3.diff"
-        norm3_diff_b_key = f"blocks.{i}.norm3.diff_b"
-        target_norm_key_base_manual = f"transformer.blocks.{i}.norm2"  # Diffusers DiTBlock's second norm
-        if norm3_diff_key in original_state_dict:
-            manual_diff_patches[f"{target_norm_key_base_manual}.weight"] = original_state_dict.pop(norm3_diff_key)
-        if norm3_diff_b_key in original_state_dict:
-            manual_diff_patches[f"{target_norm_key_base_manual}.bias"] = original_state_dict.pop(norm3_diff_b_key)
-
-        # Attention QK norms
-        for attn_type, diffusers_attn_block in zip(["self_attn", "cross_attn"], ["attn1", "attn2"]):
-            for norm_target_suffix in ["norm_q", "norm_k"]:
-                orig_norm_diff_key = f"blocks.{i}.{attn_type}.{norm_target_suffix}.diff"
-                target_norm_key_manual = f"transformer.blocks.{i}.{diffusers_attn_block}.{norm_target_suffix}.weight"
-                if orig_norm_diff_key in original_state_dict:
-                    manual_diff_patches[target_norm_key_manual] = original_state_dict.pop(orig_norm_diff_key)
-
-    # --- Convert Non-Block Components ---
     # Patch Embedding
     patch_emb_diff_b_key = "patch_embedding.diff_b"
-    if patch_emb_diff_b_key in original_state_dict:
-        manual_diff_patches["transformer.patch_embedding.bias"] = original_state_dict.pop(patch_emb_diff_b_key)
-
-    # Text Embedding
-    for orig_idx, diffusers_linear_idx in zip(["0", "2"], ["linear_1", "linear_2"]):
-        orig_lora_down_key = f"text_embedding.{orig_idx}.lora_down.weight"
-        orig_lora_up_key = f"text_embedding.{orig_idx}.lora_up.weight"
-        target_base_key_peft = f"condition_embedder.text_embedder.{diffusers_linear_idx}"
-        target_base_key_manual = f"transformer.condition_embedder.text_embedder.{diffusers_linear_idx}"
-        if orig_lora_down_key in original_state_dict and orig_lora_up_key in original_state_dict:
-            converted_state_dict_for_peft[f"{target_base_key_peft}.lora_A.weight"] = original_state_dict.pop(orig_lora_down_key)
-            converted_state_dict_for_peft[f"{target_base_key_peft}.lora_B.weight"] = original_state_dict.pop(orig_lora_up_key)
-        orig_diff_b_key = f"text_embedding.{orig_idx}.diff_b"
-        if orig_diff_b_key in original_state_dict:
-            manual_diff_patches[f"{target_base_key_manual}.bias"] = original_state_dict.pop(orig_diff_b_key)
-
-    # Time Embedding
-    for orig_idx, diffusers_linear_idx in zip(["0", "2"], ["linear_1", "linear_2"]):
-        orig_lora_down_key = f"time_embedding.{orig_idx}.lora_down.weight"
-        orig_lora_up_key = f"time_embedding.{orig_idx}.lora_up.weight"
-        target_base_key_peft = f"condition_embedder.time_embedder.{diffusers_linear_idx}"
-        target_base_key_manual = f"transformer.condition_embedder.time_embedder.{diffusers_linear_idx}"
-        if orig_lora_down_key in original_state_dict and orig_lora_up_key in original_state_dict:
-            converted_state_dict_for_peft[f"{target_base_key_peft}.lora_A.weight"] = original_state_dict.pop(orig_lora_down_key)
-            converted_state_dict_for_peft[f"{target_base_key_peft}.lora_B.weight"] = original_state_dict.pop(orig_lora_up_key)
-        orig_diff_b_key = f"time_embedding.{orig_idx}.diff_b"
-        if orig_diff_b_key in original_state_dict:
-            manual_diff_patches[f"{target_base_key_manual}.bias"] = original_state_dict.pop(orig_diff_b_key)
-
-    # Time Projection
-    orig_lora_down_key = "time_projection.1.lora_down.weight"
-    orig_lora_up_key = "time_projection.1.lora_up.weight"
-    target_base_key_peft = "condition_embedder.time_proj"
-    target_base_key_manual = "transformer.condition_embedder.time_proj"
-    if orig_lora_down_key in original_state_dict and orig_lora_up_key in original_state_dict:
-        converted_state_dict_for_peft[f"{target_base_key_peft}.lora_A.weight"] = original_state_dict.pop(orig_lora_down_key)
-        converted_state_dict_for_peft[f"{target_base_key_peft}.lora_B.weight"] = original_state_dict.pop(orig_lora_up_key)
-    orig_diff_b_key = "time_projection.1.diff_b"
-    if orig_diff_b_key in original_state_dict:
-        manual_diff_patches[f"{target_base_key_manual}.bias"] = original_state_dict.pop(orig_diff_b_key)
-
-    # Head
-    orig_lora_down_key = "head.head.lora_down.weight"
-    orig_lora_up_key = "head.head.lora_up.weight"
-    target_base_key_peft = "proj_out"  # Directly under transformer in Diffusers DiT
-    target_base_key_manual = "transformer.proj_out"
-    if orig_lora_down_key in original_state_dict and orig_lora_up_key in original_state_dict:
-        converted_state_dict_for_peft[f"{target_base_key_peft}.lora_A.weight"] = original_state_dict.pop(orig_lora_down_key)
-        converted_state_dict_for_peft[f"{target_base_key_peft}.lora_B.weight"] = original_state_dict.pop(orig_lora_up_key)
-    orig_diff_b_key = "head.head.diff_b"
-    if orig_diff_b_key in original_state_dict:
-        manual_diff_patches[f"{target_base_key_manual}.bias"] = original_state_dict.pop(orig_diff_b_key)
-
-    # Log any remaining keys from the original LoRA after stripping "diffusion_model."
-    if len(original_state_dict) > 0:
         logger.warning(
-            f"Following keys from LoRA (after stripping 'diffusion_model.') were not converted or explicitly handled for PEFT/manual patching: {original_state_dict.keys()}"
         )

-    # Add "transformer." prefix for Diffusers LoraLoaderMixin to the PEFT keys
-    final_peft_state_dict = {}
-    for k_peft, v_peft in converted_state_dict_for_peft.items():
-        final_peft_state_dict[f"transformer.{k_peft}"] = v_peft
-
-    MANUAL_PATCHES_STORE = manual_diff_patches  # Store for later use
-    return final_peft_state_dict


-def apply_manual_diff_patches(pipe_model, patches):
-    """
-    Manually applies diff_b/diff patches to the model.
-    Assumes PEFT LoRA layers have already been loaded.
-    """
-    if not patches:
         logger.info("No manual diff patches to apply.")
         return

-    logger.info(f"Applying {len(patches)} manual diff patches...")
-    patched_keys_count = 0
-    unpatched_keys_count = 0
-    skipped_keys_details = []
-
-    for key, diff_tensor in patches.items():
         try:
-            # key is like "transformer.blocks.0.attn1.to_q.bias"
-            current_module = pipe_model  # Starts from pipe.transformer
-            path_parts = key.split('.')[1:]  # Remove "transformer." prefix for getattr navigation
-            # e.g., ["blocks", "0", "attn1", "to_q", "bias"]
-
-            # Navigate to the parent module of the parameter
-            # Example: for "blocks.0.attn1.to_q.bias", parent_module_path is "blocks.0.attn1.to_q"
-            parent_module_path = path_parts[:-1]
-            param_name_to_patch = path_parts[-1]  # "bias" or "weight"
-
-            for part in parent_module_path:
-                if hasattr(current_module, part):
-                    current_module = getattr(current_module, part)
-                elif hasattr(current_module, 'base_layer') and hasattr(current_module.base_layer, part):
-                    # This case is unlikely here as we are navigating *to* the layer,
-                    # not trying to access a sub-component of a base_layer.
-                    # PEFT wrapping affects the layer itself, not its parent structure.
-                    current_module = getattr(current_module.base_layer, part)
-                else:
-                    raise AttributeError(f"Submodule '{part}' not found in path '{'.'.join(parent_module_path)}' within {key}")
-
-            # Now, current_module is the layer whose parameter we want to patch
-            # e.g., if key was transformer.blocks.0.attn1.to_q.bias,
-            # current_module is the to_q Linear layer (or LoraLayer wrapping it)
-
-            layer_to_modify = current_module
-            # If PEFT wrapped the Linear layer (common for attention q,k,v,o and ffn projections)
-            if hasattr(layer_to_modify, "base_layer") and isinstance(layer_to_modify.base_layer, (torch.nn.Linear, torch.nn.LayerNorm)):
-                actual_param_owner = layer_to_modify.base_layer
-            else:  # For non-wrapped layers like LayerNorm, or if it's already the base_layer
-                actual_param_owner = layer_to_modify
-
-            if not hasattr(actual_param_owner, param_name_to_patch):
-                skipped_keys_details.append(f"Key: {key}, Reason: Parameter '{param_name_to_patch}' not found in layer '{actual_param_owner}'. Layer type: {type(actual_param_owner)}")
-                unpatched_keys_count += 1
-                continue

-            original_param = getattr(actual_param_owner, param_name_to_patch)
-
-            if original_param is None and param_name_to_patch == "bias":
-                logger.info(f"Key '{key}': Original bias is None. Attempting to initialize.")
-                if isinstance(actual_param_owner, torch.nn.Linear) or isinstance(actual_param_owner, torch.nn.LayerNorm):
-                    # For LayerNorm, bias exists if elementwise_affine=True (default).
-                    # If it was False, we are making it affine by adding a bias.
-                    # For Linear, if bias was False, we are adding one.
-                    actual_param_owner.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
-                    original_param = actual_param_owner.bias
-                    logger.info(f"Key '{key}': Initialized bias for {type(actual_param_owner)}.")
-                else:
-                    skipped_keys_details.append(f"Key: {key}, Reason: Original bias is None and layer '{actual_param_owner}' is not Linear or LayerNorm. Cannot initialize.")
-                    unpatched_keys_count += 1
-                    continue
-
-            # Special handling for RMSNorm which typically has no bias
-            if isinstance(actual_param_owner, torch.nn.RMSNorm) and param_name_to_patch == "bias":
-                skipped_keys_details.append(f"Key: {key}, Reason: Layer '{actual_param_owner}' is RMSNorm which has no bias parameter. Skipping bias diff.")
-                unpatched_keys_count += 1
                 continue

-
-            if original_param is not None:
-                if original_param.shape != diff_tensor.shape:
-                    skipped_keys_details.append(f"Key: {key}, Reason: Shape mismatch. Model param: {original_param.shape}, LoRA diff: {diff_tensor.shape}. Layer: {actual_param_owner}")
-                    unpatched_keys_count += 1
-                    continue
-                with torch.no_grad():
-                    original_param.add_(diff_tensor.to(original_param.device, original_param.dtype))
-                # logger.info(f"Successfully applied diff to '{key}'")  # Too verbose, will log summary
-                patched_keys_count += 1
-            else:
-                skipped_keys_details.append(f"Key: {key}, Reason: Original parameter '{param_name_to_patch}' is None and was not initialized. Layer: {actual_param_owner}")
-                unpatched_keys_count += 1
-
-        except AttributeError as e:
-            skipped_keys_details.append(f"Key: {key}, Reason: AttributeError - {e}")
-            unpatched_keys_count += 1
         except Exception as e:
-            skipped_keys_details.append(f"Key: {key}, Reason: General Exception - {e}")
-            unpatched_keys_count += 1
-
-    logger.info(f"Manual patching summary: {patched_keys_count} keys patched, {unpatched_keys_count} keys failed or skipped.")
-    if unpatched_keys_count > 0:
-        logger.warning("Details of unpatched/skipped keys:")
-        for detail in skipped_keys_details:
-            logger.warning(f" - {detail}")

 # --- Model Loading ---
 logger.info(f"Loading VAE for {MODEL_ID}...")
@@ -326,21 +290,19 @@ logger.info("Loading LoRA weights with custom converter...")
 from safetensors.torch import load_file as load_safetensors
 raw_lora_state_dict = load_safetensors(causvid_path)

-# Now call our custom converter which will populate MANUAL_PATCHES_STORE
 peft_state_dict = _custom_convert_non_diffusers_wan_lora_to_diffusers(raw_lora_state_dict)

-# Load the LoRA A/B matrices using PEFT
 if peft_state_dict:
     pipe.load_lora_weights(
-        peft_state_dict,  # Pass the dictionary directly
         adapter_name="causvid_lora"
     )
     logger.info("PEFT LoRA A/B weights loaded.")
 else:
     logger.warning("No PEFT-compatible LoRA weights found after conversion.")

-# Apply manual diff_b and diff patches
-apply_manual_diff_patches(pipe.transformer, MANUAL_PATCHES_STORE)  # Apply to the transformer component
 logger.info("Manual diff_b/diff patches applied.")

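As an aside (an editorial illustration, not part of the committed file): both the removed and the added converter implement the same key-name translation from the Wan 2.1 checkpoint layout into Diffusers/PEFT keys on the transformer. A few representative mappings, written as a hypothetical dict for reference:

    # Wan 2.1 LoRA key (after the "diffusion_model." prefix is stripped) -> Diffusers / PEFT key
    wan_to_diffusers_examples = {
        "blocks.0.self_attn.q.lora_down.weight": "transformer.blocks.0.attn1.to_q.lora_A.weight",
        "blocks.0.cross_attn.o.lora_up.weight": "transformer.blocks.0.attn2.to_out.0.lora_B.weight",
        "text_embedding.0.lora_down.weight": "transformer.condition_embedder.text_embedder.linear_1.lora_A.weight",
        "blocks.0.ffn.0.diff_b": "transformer.blocks.0.ffn.net.0.proj.bias",  # handled as a manual diff_b patch, not via PEFT
    }

The updated version of app.py (the right-hand column of this diff) follows.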
 
 
@@ -22,277 +22,241 @@ logger = logging.getLogger(__name__)
 MANUAL_PATCHES_STORE = {}
 
 def _custom_convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     global MANUAL_PATCHES_STORE
+    MANUAL_PATCHES_STORE = {}  # Clear previous patches
+    peft_state_dict = {}
+    unhandled_keys = []

+    original_keys = list(state_dict.keys())

+    processed_state_dict = {}
+    for k, v in state_dict.items():
+        if k.startswith("diffusion_model."):
+            processed_state_dict[k[len("diffusion_model."):]] = v
+        elif k.startswith("difusion_model."):  # Handle potential typo
+            processed_state_dict[k[len("difusion_model."):]] = v
+        else:
+            unhandled_keys.append(k)  # Will be logged later if not handled by diff/diff_b

     block_indices = set()
+    for k_proc in processed_state_dict:
+        if k_proc.startswith("blocks."):
             try:
+                block_idx_str = k_proc.split("blocks.")[1].split(".")[0]
                 if block_idx_str.isdigit():
                     block_indices.add(int(block_idx_str))
+            except IndexError:
+                pass  # Will be handled as a non-block key or logged
+    num_blocks = 0
+    if block_indices:
+        num_blocks = max(block_indices) + 1
+
+    is_i2v_lora = any("k_img" in k for k in processed_state_dict) and \
+                  any("v_img" in k for k in processed_state_dict)
+
+    handled_original_keys = set()
+
+    # --- Handle Block-level LoRAs & Diffs ---
+    for i in range(num_blocks):
+        # Self-attention (maps to attn1 in WanTransformerBlock)
+        for o_lora, c_diffusers in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            lora_down_key_proc = f"blocks.{i}.self_attn.{o_lora}.lora_down.weight"
+            lora_up_key_proc = f"blocks.{i}.self_attn.{o_lora}.lora_up.weight"
+            diff_b_key_proc = f"blocks.{i}.self_attn.{o_lora}.diff_b"
+            diff_w_key_proc = f"blocks.{i}.self_attn.{o_lora}.diff"  # Assuming .diff for weight
+
+            if lora_down_key_proc in processed_state_dict and lora_up_key_proc in processed_state_dict:
+                peft_state_dict[f"transformer.blocks.{i}.attn1.{c_diffusers}.lora_A.weight"] = processed_state_dict[lora_down_key_proc]
+                peft_state_dict[f"transformer.blocks.{i}.attn1.{c_diffusers}.lora_B.weight"] = processed_state_dict[lora_up_key_proc]
+                handled_original_keys.add(f"diffusion_model.{lora_down_key_proc}")
+                handled_original_keys.add(f"diffusion_model.{lora_up_key_proc}")
+            if diff_b_key_proc in processed_state_dict:
+                target_bias_key = f"transformer.blocks.{i}.attn1.{c_diffusers}.bias"
+                MANUAL_PATCHES_STORE[target_bias_key] = ("diff_b", processed_state_dict[diff_b_key_proc])
+                handled_original_keys.add(f"diffusion_model.{diff_b_key_proc}")
+            if diff_w_key_proc in processed_state_dict:
+                target_weight_key = f"transformer.blocks.{i}.attn1.{c_diffusers}.weight"
+                MANUAL_PATCHES_STORE[target_weight_key] = ("diff", processed_state_dict[diff_w_key_proc])
+                handled_original_keys.add(f"diffusion_model.{diff_w_key_proc}")
+
+        # Cross-attention (maps to attn2 in WanTransformerBlock)
+        for o_lora, c_diffusers in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            lora_down_key_proc = f"blocks.{i}.cross_attn.{o_lora}.lora_down.weight"
+            lora_up_key_proc = f"blocks.{i}.cross_attn.{o_lora}.lora_up.weight"
+            diff_b_key_proc = f"blocks.{i}.cross_attn.{o_lora}.diff_b"
+            diff_w_key_proc = f"blocks.{i}.cross_attn.{o_lora}.diff"
+            norm_q_diff_key_proc = f"blocks.{i}.cross_attn.norm_q.diff"  # specific norm diff
+            norm_k_diff_key_proc = f"blocks.{i}.cross_attn.norm_k.diff"  # specific norm diff
+
+            if lora_down_key_proc in processed_state_dict and lora_up_key_proc in processed_state_dict:
+                peft_state_dict[f"transformer.blocks.{i}.attn2.{c_diffusers}.lora_A.weight"] = processed_state_dict[lora_down_key_proc]
+                peft_state_dict[f"transformer.blocks.{i}.attn2.{c_diffusers}.lora_B.weight"] = processed_state_dict[lora_up_key_proc]
+                handled_original_keys.add(f"diffusion_model.{lora_down_key_proc}")
+                handled_original_keys.add(f"diffusion_model.{lora_up_key_proc}")
+            if diff_b_key_proc in processed_state_dict:
+                target_bias_key = f"transformer.blocks.{i}.attn2.{c_diffusers}.bias"
+                MANUAL_PATCHES_STORE[target_bias_key] = ("diff_b", processed_state_dict[diff_b_key_proc])
+                handled_original_keys.add(f"diffusion_model.{diff_b_key_proc}")
+            if diff_w_key_proc in processed_state_dict:
+                target_weight_key = f"transformer.blocks.{i}.attn2.{c_diffusers}.weight"
+                MANUAL_PATCHES_STORE[target_weight_key] = ("diff", processed_state_dict[diff_w_key_proc])
+                handled_original_keys.add(f"diffusion_model.{diff_w_key_proc}")
+
+            if norm_q_diff_key_proc in processed_state_dict:  # Assuming norm_q on q_proj
+                MANUAL_PATCHES_STORE[f"transformer.blocks.{i}.attn2.norm_q.weight"] = ("diff", processed_state_dict[norm_q_diff_key_proc])
+                handled_original_keys.add(f"diffusion_model.{norm_q_diff_key_proc}")
+            if norm_k_diff_key_proc in processed_state_dict:  # Assuming norm_k on k_proj
+                MANUAL_PATCHES_STORE[f"transformer.blocks.{i}.attn2.norm_k.weight"] = ("diff", processed_state_dict[norm_k_diff_key_proc])
+                handled_original_keys.add(f"diffusion_model.{norm_k_diff_key_proc}")
+
+
+        if is_i2v_lora:
+            for o_lora, c_diffusers in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                lora_down_key_proc = f"blocks.{i}.cross_attn.{o_lora}.lora_down.weight"
+                lora_up_key_proc = f"blocks.{i}.cross_attn.{o_lora}.lora_up.weight"
+                diff_b_key_proc = f"blocks.{i}.cross_attn.{o_lora}.diff_b"
+                diff_w_key_proc = f"blocks.{i}.cross_attn.{o_lora}.diff"
+
+                if lora_down_key_proc in processed_state_dict and lora_up_key_proc in processed_state_dict:
+                    peft_state_dict[f"transformer.blocks.{i}.attn2.{c_diffusers}.lora_A.weight"] = processed_state_dict[lora_down_key_proc]
+                    peft_state_dict[f"transformer.blocks.{i}.attn2.{c_diffusers}.lora_B.weight"] = processed_state_dict[lora_up_key_proc]
+                    handled_original_keys.add(f"diffusion_model.{lora_down_key_proc}")
+                    handled_original_keys.add(f"diffusion_model.{lora_up_key_proc}")
+                if diff_b_key_proc in processed_state_dict:
+                    target_bias_key = f"transformer.blocks.{i}.attn2.{c_diffusers}.bias"
+                    MANUAL_PATCHES_STORE[target_bias_key] = ("diff_b", processed_state_dict[diff_b_key_proc])
+                    handled_original_keys.add(f"diffusion_model.{diff_b_key_proc}")
+                if diff_w_key_proc in processed_state_dict:
+                    target_weight_key = f"transformer.blocks.{i}.attn2.{c_diffusers}.weight"
+                    MANUAL_PATCHES_STORE[target_weight_key] = ("diff", processed_state_dict[diff_w_key_proc])
+                    handled_original_keys.add(f"diffusion_model.{diff_w_key_proc}")

         # FFN
+        for o_lora_suffix, c_diffusers_path in zip([".0", ".2"], ["net.0.proj", "net.2"]):
+            lora_down_key_proc = f"blocks.{i}.ffn{o_lora_suffix}.lora_down.weight"
+            lora_up_key_proc = f"blocks.{i}.ffn{o_lora_suffix}.lora_up.weight"
+            diff_b_key_proc = f"blocks.{i}.ffn{o_lora_suffix}.diff_b"
+            diff_w_key_proc = f"blocks.{i}.ffn{o_lora_suffix}.diff"  # Assuming .diff for weight
+
+            if lora_down_key_proc in processed_state_dict and lora_up_key_proc in processed_state_dict:
+                peft_state_dict[f"transformer.blocks.{i}.ffn.{c_diffusers_path}.lora_A.weight"] = processed_state_dict[lora_down_key_proc]
+                peft_state_dict[f"transformer.blocks.{i}.ffn.{c_diffusers_path}.lora_B.weight"] = processed_state_dict[lora_up_key_proc]
+                handled_original_keys.add(f"diffusion_model.{lora_down_key_proc}")
+                handled_original_keys.add(f"diffusion_model.{lora_up_key_proc}")
+            if diff_b_key_proc in processed_state_dict:
+                target_bias_key = f"transformer.blocks.{i}.ffn.{c_diffusers_path}.bias"
+                MANUAL_PATCHES_STORE[target_bias_key] = ("diff_b", processed_state_dict[diff_b_key_proc])
+                handled_original_keys.add(f"diffusion_model.{diff_b_key_proc}")
+            if diff_w_key_proc in processed_state_dict:
+                target_weight_key = f"transformer.blocks.{i}.ffn.{c_diffusers_path}.weight"
+                MANUAL_PATCHES_STORE[target_weight_key] = ("diff", processed_state_dict[diff_w_key_proc])
+                handled_original_keys.add(f"diffusion_model.{diff_w_key_proc}")
+
+        # Block norm3 diffs (assuming norm3 applies to the output of the FFN in the original Wan block structure)
+        norm3_diff_key_proc = f"blocks.{i}.norm3.diff"
+        norm3_diff_b_key_proc = f"blocks.{i}.norm3.diff_b"
+        if norm3_diff_key_proc in processed_state_dict:
+            MANUAL_PATCHES_STORE[f"transformer.blocks.{i}.norm3.weight"] = ("diff", processed_state_dict[norm3_diff_key_proc])  # Norms usually have .weight
+            handled_original_keys.add(f"diffusion_model.{norm3_diff_key_proc}")
+        if norm3_diff_b_key_proc in processed_state_dict:
+            MANUAL_PATCHES_STORE[f"transformer.blocks.{i}.norm3.bias"] = ("diff_b", processed_state_dict[norm3_diff_b_key_proc])  # And .bias
+            handled_original_keys.add(f"diffusion_model.{norm3_diff_b_key_proc}")
+
+
+    # --- Handle Top-level LoRAs & Diffs ---
+    top_level_mappings = [
+        # (lora_base_path_proc, diffusers_base_path, lora_suffixes, diffusers_suffixes)
+        ("text_embedding", "transformer.condition_embedder.text_embedder", ["0", "2"], ["linear_1", "linear_2"]),
+        ("time_embedding", "transformer.condition_embedder.time_embedder", ["0", "2"], ["linear_1", "linear_2"]),
+        ("time_projection", "transformer.condition_embedder.time_proj", ["1"], [""]),  # Wan has .1, Diffusers has no suffix
+        ("head", "transformer.proj_out", ["head"], [""]),  # Wan has .head, Diffusers has no suffix
+    ]
+
+    for lora_base_proc, diffusers_base, lora_suffixes, diffusers_suffixes in top_level_mappings:
+        for l_suffix, d_suffix in zip(lora_suffixes, diffusers_suffixes):
+            actual_lora_path_proc = f"{lora_base_proc}.{l_suffix}" if l_suffix else lora_base_proc
+            actual_diffusers_path = f"{diffusers_base}.{d_suffix}" if d_suffix else diffusers_base
+
+            lora_down_key_proc = f"{actual_lora_path_proc}.lora_down.weight"
+            lora_up_key_proc = f"{actual_lora_path_proc}.lora_up.weight"
+            diff_b_key_proc = f"{actual_lora_path_proc}.diff_b"
+            diff_w_key_proc = f"{actual_lora_path_proc}.diff"
+
+            if lora_down_key_proc in processed_state_dict and lora_up_key_proc in processed_state_dict:
+                peft_state_dict[f"{actual_diffusers_path}.lora_A.weight"] = processed_state_dict[lora_down_key_proc]
+                peft_state_dict[f"{actual_diffusers_path}.lora_B.weight"] = processed_state_dict[lora_up_key_proc]
+                handled_original_keys.add(f"diffusion_model.{lora_down_key_proc}")
+                handled_original_keys.add(f"diffusion_model.{lora_up_key_proc}")
+            if diff_b_key_proc in processed_state_dict:
+                MANUAL_PATCHES_STORE[f"{actual_diffusers_path}.bias"] = ("diff_b", processed_state_dict[diff_b_key_proc])
+                handled_original_keys.add(f"diffusion_model.{diff_b_key_proc}")
+            if diff_w_key_proc in processed_state_dict:
+                MANUAL_PATCHES_STORE[f"{actual_diffusers_path}.weight"] = ("diff", processed_state_dict[diff_w_key_proc])
+                handled_original_keys.add(f"diffusion_model.{diff_w_key_proc}")
+
     # Patch Embedding
     patch_emb_diff_b_key = "patch_embedding.diff_b"
+    if patch_emb_diff_b_key in processed_state_dict:
+        MANUAL_PATCHES_STORE["transformer.patch_embedding.bias"] = ("diff_b", processed_state_dict[patch_emb_diff_b_key])
+        handled_original_keys.add(f"diffusion_model.{patch_emb_diff_b_key}")
+    # Assuming .diff might exist for patch_embedding.weight, though not explicitly in your example list
+    patch_emb_diff_w_key = "patch_embedding.diff"
+    if patch_emb_diff_w_key in processed_state_dict:
+        MANUAL_PATCHES_STORE["transformer.patch_embedding.weight"] = ("diff", processed_state_dict[patch_emb_diff_w_key])
+        handled_original_keys.add(f"diffusion_model.{patch_emb_diff_w_key}")
+
+
+    # Log unhandled keys
+    final_unhandled_keys = []
+    for k_orig in original_keys:
+        # Reconstruct the processed key to check if it was actually handled by diff/diff_b or lora A/B logic
+        k_proc = None
+        if k_orig.startswith("diffusion_model."):
+            k_proc = k_orig[len("diffusion_model."):]
+        elif k_orig.startswith("difusion_model."):
+            k_proc = k_orig[len("difusion_model."):]
+
+        if k_orig not in handled_original_keys and (k_proc is None or not any(k_proc.endswith(s) for s in [".lora_down.weight", ".lora_up.weight", ".diff", ".diff_b", ".alpha"])):
+            final_unhandled_keys.append(k_orig)
+
+    if final_unhandled_keys:
         logger.warning(
+            f"The following keys from the Wan 2.1 LoRA checkpoint were not converted to PEFT LoRA A/B format "
+            f"nor identified as manual diff patches: {final_unhandled_keys}."
         )

+    if not peft_state_dict and not MANUAL_PATCHES_STORE:
+        logger.warning("No valid LoRA A/B weights or manual diff patches found after conversion.")

+    return peft_state_dict

+def apply_manual_diff_patches(pipe_model_component, patches_store, strength_model=1.0):
+    if not patches_store:
         logger.info("No manual diff patches to apply.")
         return

+    logger.info(f"Applying {len(patches_store)} manual diff patches...")
+    for target_key, (patch_type, diff_tensor) in patches_store.items():
         try:
+            module_path, param_name = target_key.rsplit('.', 1)
+            module = pipe_model_component.get_submodule(module_path)
+            original_param = getattr(module, param_name)

+            if original_param.shape != diff_tensor.shape:
+                logger.warning(f"Shape mismatch for {target_key}: model {original_param.shape}, LoRA {diff_tensor.shape}. Skipping patch.")
                 continue

+            with torch.no_grad():
+                # Ensure diff_tensor is on the same device and dtype as the original parameter
+                diff_tensor_casted = diff_tensor.to(device=original_param.device, dtype=original_param.dtype)
+                scaled_diff = diff_tensor_casted * strength_model
+                original_param.add_(scaled_diff)
+            # logger.info(f"Applied {patch_type} to {target_key} with strength {strength_model}")
+        except AttributeError:
+            logger.warning(f"Could not find parameter {target_key} in the model component. Skipping patch.")
         except Exception as e:
+            logger.error(f"Error applying patch to {target_key}: {e}")
+    logger.info("Finished applying manual diff patches.")

 # --- Model Loading ---
 logger.info(f"Loading VAE for {MODEL_ID}...")

@@ -326,21 +290,19 @@ logger.info("Loading LoRA weights with custom converter...")
 from safetensors.torch import load_file as load_safetensors
 raw_lora_state_dict = load_safetensors(causvid_path)

 peft_state_dict = _custom_convert_non_diffusers_wan_lora_to_diffusers(raw_lora_state_dict)

 if peft_state_dict:
     pipe.load_lora_weights(
+        peft_state_dict,
         adapter_name="causvid_lora"
     )
     logger.info("PEFT LoRA A/B weights loaded.")
 else:
     logger.warning("No PEFT-compatible LoRA weights found after conversion.")

+lora_strength = 1.0
+apply_manual_diff_patches(pipe.transformer, MANUAL_PATCHES_STORE, strength_model=lora_strength)
 logger.info("Manual diff_b/diff patches applied.")

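For reference, a minimal sketch of how the two pieces changed in this commit fit together at load time. It assumes, as in the rest of app.py, that the Wan 2.1 Diffusers pipeline has already been constructed as pipe and that causvid_path points to the LoRA .safetensors file being loaded; the helper names are the ones defined in the diff above, and nothing here is additional committed code.

    from safetensors.torch import load_file as load_safetensors

    raw_lora_state_dict = load_safetensors(causvid_path)

    # 1) Split the checkpoint: lora_down/lora_up pairs become PEFT lora_A/lora_B keys,
    #    while diff/diff_b tensors are stashed in MANUAL_PATCHES_STORE as (type, tensor) tuples.
    peft_state_dict = _custom_convert_non_diffusers_wan_lora_to_diffusers(raw_lora_state_dict)

    # 2) Load the A/B matrices through the standard Diffusers LoRA path.
    if peft_state_dict:
        pipe.load_lora_weights(peft_state_dict, adapter_name="causvid_lora")

    # 3) Add the stored weight/bias offsets in place onto the transformer's parameters.
    apply_manual_diff_patches(pipe.transformer, MANUAL_PATCHES_STORE, strength_model=1.0)

This mirrors the call sequence in the second hunk; the one knob the new code introduces is strength_model, which scales the diff tensors before they are added to the base parameters.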