Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -36,23 +36,67 @@ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow
|
|
36 |
# flow_shift=flow_shift # Retain flow_shift for WanPipeline compatibility
|
37 |
# )
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
CAUSVID_LORA_REPO = "WanVideo_comfy"
|
40 |
CAUSVID_LORA_FILENAME = "Wan21_CausVid_bidirect2_T2V_1_3B_lora_rank32.safetensors"
|
41 |
-
|
42 |
try:
|
43 |
causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
|
44 |
-
pipe.load_lora_weights(causvid_path, adapter_name=
|
45 |
-
print("✅
|
46 |
except Exception as e:
|
47 |
-
print(f"⚠️
|
48 |
-
|
|
|
|
|
|
|
49 |
|
50 |
|
51 |
@spaces.GPU()
|
52 |
def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
|
53 |
-
if lora_id and lora_id.strip() != "":
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
pipe.to("cuda")
|
57 |
# apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
|
58 |
apply_cache_on_pipe(
|
|
|
36 |
# flow_shift=flow_shift # Retain flow_shift for WanPipeline compatibility
|
37 |
# )
|
38 |
|
39 |
+
# --- LoRA State Management ---
|
40 |
+
# Define unique names for our adapters
|
41 |
+
DEFAULT_LORA_NAME = "causvid_lora"
|
42 |
+
CUSTOM_LORA_NAME = "custom_lora"
|
43 |
+
# Track which custom LoRA is currently loaded to avoid reloading
|
44 |
+
CURRENTLY_LOADED_CUSTOM_LORA = None
|
45 |
+
|
46 |
+
# Load the default base LoRA ONCE at startup
|
47 |
+
print("Loading base LoRA...")
|
48 |
CAUSVID_LORA_REPO = "WanVideo_comfy"
|
49 |
CAUSVID_LORA_FILENAME = "Wan21_CausVid_bidirect2_T2V_1_3B_lora_rank32.safetensors"
|
|
|
50 |
try:
|
51 |
causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
|
52 |
+
pipe.load_lora_weights(causvid_path, adapter_name=DEFAULT_LORA_NAME)
|
53 |
+
print(f"✅ Default LoRA '{DEFAULT_LORA_NAME}' loaded successfully.")
|
54 |
except Exception as e:
|
55 |
+
print(f"⚠️ Default LoRA could not be loaded: {e}")
|
56 |
+
DEFAULT_LORA_NAME = None
|
57 |
+
|
58 |
+
print("Initialization complete. Gradio is starting...")
|
59 |
+
|
60 |
|
61 |
|
62 |
@spaces.GPU()
|
63 |
def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
|
64 |
+
# if lora_id and lora_id.strip() != "":
|
65 |
+
# pipe.unload_lora_weights()
|
66 |
+
# pipe.load_lora_weights(lora_id.strip())
|
67 |
+
|
68 |
+
global CURRENTLY_LOADED_CUSTOM_LORA
|
69 |
+
|
70 |
+
active_adapters = []
|
71 |
+
adapter_weights = []
|
72 |
+
|
73 |
+
# Always add the default LoRA if it was loaded successfully
|
74 |
+
if DEFAULT_LORA_NAME:
|
75 |
+
active_adapters.append(DEFAULT_LORA_NAME)
|
76 |
+
adapter_weights.append(1.0) # Strength for base LoRA
|
77 |
+
|
78 |
+
# Handle the user-provided custom LoRA
|
79 |
+
clean_lora_id = lora_id.strip() if lora_id else ""
|
80 |
+
if clean_lora_id:
|
81 |
+
try:
|
82 |
+
# If the requested LoRA is different from the one in memory, swap it
|
83 |
+
if clean_lora_id != CURRENTLY_LOADED_CUSTOM_LORA:
|
84 |
+
print(f"Switching custom LoRA to: {clean_lora_id}")
|
85 |
+
# Unload the old custom LoRA to save memory
|
86 |
+
if CURRENTLY_LOADED_CUSTOM_LORA is not None:
|
87 |
+
pipe.unload_lora_weights(CUSTOM_LORA_NAME)
|
88 |
+
|
89 |
+
# Load the new one with its unique name
|
90 |
+
pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME)
|
91 |
+
CURRENTLY_LOADED_CUSTOM_LORA = clean_lora_id
|
92 |
+
|
93 |
+
# Add the custom LoRA to the active list for this generation
|
94 |
+
active_adapters.append(CUSTOM_LORA_NAME)
|
95 |
+
adapter_weights.append(1.0) # Strength for custom LoRA
|
96 |
+
|
97 |
+
except Exception as e:
|
98 |
+
print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}'. Error: {e}")
|
99 |
+
|
100 |
pipe.to("cuda")
|
101 |
# apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
|
102 |
apply_cache_on_pipe(
|