Update app.py
Browse files
app.py
CHANGED
|
@@ -105,6 +105,18 @@ def download_file(url, directory=None):
|
|
| 105 |
|
| 106 |
return filepath
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
|
| 109 |
selected_index = evt.index
|
| 110 |
selected_indices = selected_indices or []
|
|
@@ -462,7 +474,9 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
|
|
| 462 |
for idx, lora in enumerate(selected_loras):
|
| 463 |
print(f"Inspecting LoRA {idx + 1}: {lora['title']}")
|
| 464 |
try:
|
| 465 |
-
|
|
|
|
|
|
|
| 466 |
print(f"LoRA Parameter Keys for {lora['title']}: {list(lora_weights.keys())}")
|
| 467 |
except Exception as e:
|
| 468 |
print(f"Error loading LoRA weights for {lora['title']} from {lora['repo']}: {e}")
|
|
@@ -499,16 +513,14 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
|
|
| 499 |
with calculateDuration("Loading LoRA weights"):
|
| 500 |
for idx, lora in enumerate(selected_loras):
|
| 501 |
lora_name = f"lora_{idx}"
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
pipe.load_lora_weights(
|
| 506 |
-
|
| 507 |
-
weight_name=lora_weights_path,
|
| 508 |
low_cpu_mem_usage=True,
|
| 509 |
adapter_name=lora_name,
|
| 510 |
-
merge_and_unload=True,
|
| 511 |
-
|
| 512 |
|
| 513 |
print("Adapter weights:", lora_weights)
|
| 514 |
try:
|
|
|
|
| 105 |
|
| 106 |
return filepath
|
| 107 |
|
| 108 |
+
def get_lora_weights(lora_repo, weight_name=None):
|
| 109 |
+
try:
|
| 110 |
+
# Download the weights from Hugging Face Hub
|
| 111 |
+
file_path = hf_hub_download(
|
| 112 |
+
repo_id=lora_repo,
|
| 113 |
+
filename=weight_name if weight_name else "pytorch_model.bin"
|
| 114 |
+
)
|
| 115 |
+
return file_path
|
| 116 |
+
except Exception as e:
|
| 117 |
+
print(f"Failed to fetch weights for {lora_repo}: {e}")
|
| 118 |
+
raise
|
| 119 |
+
|
| 120 |
def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
|
| 121 |
selected_index = evt.index
|
| 122 |
selected_indices = selected_indices or []
|
|
|
|
| 474 |
for idx, lora in enumerate(selected_loras):
|
| 475 |
print(f"Inspecting LoRA {idx + 1}: {lora['title']}")
|
| 476 |
try:
|
| 477 |
+
lora_weights_path = get_lora_weights(lora['repo'], lora.get("weights"))
|
| 478 |
+
print(f"LoRA weights fetched from: {lora_weights_path}")
|
| 479 |
+
lora_weights = torch.load(lora_weights_path, weights_only=True) #lora_weights = torch.load(lora_weights_path)
|
| 480 |
print(f"LoRA Parameter Keys for {lora['title']}: {list(lora_weights.keys())}")
|
| 481 |
except Exception as e:
|
| 482 |
print(f"Error loading LoRA weights for {lora['title']} from {lora['repo']}: {e}")
|
|
|
|
| 513 |
with calculateDuration("Loading LoRA weights"):
|
| 514 |
for idx, lora in enumerate(selected_loras):
|
| 515 |
lora_name = f"lora_{idx}"
|
| 516 |
+
print(f"Loading LoRA: {lora['title']} with adapter name: {lora_name}")
|
| 517 |
+
lora_weights_path = get_lora_weights(lora['repo'], lora.get("weights"))
|
|
|
|
| 518 |
pipe.load_lora_weights(
|
| 519 |
+
lora_weights_path,
|
|
|
|
| 520 |
low_cpu_mem_usage=True,
|
| 521 |
adapter_name=lora_name,
|
| 522 |
+
merge_and_unload=True,
|
| 523 |
+
)
|
| 524 |
|
| 525 |
print("Adapter weights:", lora_weights)
|
| 526 |
try:
|