ovi054 committed on
Commit
542414a
·
verified ·
1 Parent(s): 1b24a66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -8
app.py CHANGED
@@ -36,23 +36,67 @@ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow
36
  # flow_shift=flow_shift # Retain flow_shift for WanPipeline compatibility
37
  # )
38
 
 
 
 
 
 
 
 
 
 
39
# Fetch the CausVid LoRA weights from the Hub and attach them to the pipeline.
# This is best-effort: generation still works without the LoRA, so any failure
# (network, missing repo, incompatible weights) only prints a warning.
CAUSVID_LORA_REPO = "WanVideo_comfy"
CAUSVID_LORA_FILENAME = "Wan21_CausVid_bidirect2_T2V_1_3B_lora_rank32.safetensors"

try:
    # Download (or reuse the local cache of) the safetensors file.
    causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
    pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
    print("✅ CausVid LoRA loaded (strength: 1.0)")
except Exception as e:
    print(f"⚠️ CausVid LoRA not loaded: {e}")
    # Signal "no LoRA available" to any downstream code that checks the path.
    causvid_path = None
 
 
 
49
 
50
 
51
  @spaces.GPU()
52
  def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
53
- if lora_id and lora_id.strip() != "":
54
- # pipe.unload_lora_weights()
55
- pipe.load_lora_weights(lora_id.strip())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  pipe.to("cuda")
57
  # apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
58
  apply_cache_on_pipe(
 
36
  # flow_shift=flow_shift # Retain flow_shift for WanPipeline compatibility
37
  # )
38
 
39
# --- LoRA State Management ---
# Adapter names used throughout the app: the always-on base adapter and the
# slot reserved for a user-supplied custom LoRA.
DEFAULT_LORA_NAME = "causvid_lora"
CUSTOM_LORA_NAME = "custom_lora"
# Remembers which custom LoRA repo id currently occupies the custom slot so
# generate() can skip a redundant reload when the same id is requested again.
CURRENTLY_LOADED_CUSTOM_LORA = None

# The default base LoRA is loaded exactly once, at startup.
print("Loading base LoRA...")
CAUSVID_LORA_REPO = "WanVideo_comfy"
CAUSVID_LORA_FILENAME = "Wan21_CausVid_bidirect2_T2V_1_3B_lora_rank32.safetensors"
try:
    # Download (or hit the local cache for) the safetensors file, then
    # register it on the pipeline under the default adapter name.
    causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
    pipe.load_lora_weights(causvid_path, adapter_name=DEFAULT_LORA_NAME)
    print(f"✅ Default LoRA '{DEFAULT_LORA_NAME}' loaded successfully.")
except Exception as e:
    print(f"⚠️ Default LoRA could not be loaded: {e}")
    # Clearing the name marks the base adapter as unavailable; downstream
    # code treats a falsy DEFAULT_LORA_NAME as "do not activate it".
    DEFAULT_LORA_NAME = None

print("Initialization complete. Gradio is starting...")
60
 
61
 
62
  @spaces.GPU()
63
  def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
64
+ # if lora_id and lora_id.strip() != "":
65
+ # pipe.unload_lora_weights()
66
+ # pipe.load_lora_weights(lora_id.strip())
67
+
68
+ global CURRENTLY_LOADED_CUSTOM_LORA
69
+
70
+ active_adapters = []
71
+ adapter_weights = []
72
+
73
+ # Always add the default LoRA if it was loaded successfully
74
+ if DEFAULT_LORA_NAME:
75
+ active_adapters.append(DEFAULT_LORA_NAME)
76
+ adapter_weights.append(1.0) # Strength for base LoRA
77
+
78
+ # Handle the user-provided custom LoRA
79
+ clean_lora_id = lora_id.strip() if lora_id else ""
80
+ if clean_lora_id:
81
+ try:
82
+ # If the requested LoRA is different from the one in memory, swap it
83
+ if clean_lora_id != CURRENTLY_LOADED_CUSTOM_LORA:
84
+ print(f"Switching custom LoRA to: {clean_lora_id}")
85
+ # Unload the old custom LoRA to save memory
86
+ if CURRENTLY_LOADED_CUSTOM_LORA is not None:
87
+ pipe.unload_lora_weights(CUSTOM_LORA_NAME)
88
+
89
+ # Load the new one with its unique name
90
+ pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME)
91
+ CURRENTLY_LOADED_CUSTOM_LORA = clean_lora_id
92
+
93
+ # Add the custom LoRA to the active list for this generation
94
+ active_adapters.append(CUSTOM_LORA_NAME)
95
+ adapter_weights.append(1.0) # Strength for custom LoRA
96
+
97
+ except Exception as e:
98
+ print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}'. Error: {e}")
99
+
100
  pipe.to("cuda")
101
  # apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
102
  apply_cache_on_pipe(