DonImages commited on
Commit
4ef7068
·
verified ·
1 Parent(s): b822313

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -9
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import gradio as gr
3
  import torch
4
- from diffusers import StableDiffusion3Pipeline
5
  import spaces
6
  import random
7
 
@@ -25,16 +25,10 @@ pipeline = StableDiffusion3Pipeline.from_pretrained(
25
  # Load the LoRA trained weights once at the start
26
  lora_path = "lora_trained_model.pt" # Ensure this file is uploaded in the Space
27
  if os.path.exists(lora_path):
28
- lora_checkpoint = torch.load(lora_path, map_location=device)
29
-
30
  try:
31
- # Assuming the checkpoint contains 'model' key with the model state dict
32
- pipeline.unet.load_state_dict(lora_checkpoint['model'], strict=False) # Apply weights to the unet
33
-
34
  print("✅ LoRA weights loaded successfully!")
35
- except KeyError as e:
36
- print(f"❌ Error: Missing key in checkpoint: {e}")
37
- except Exception as e:
38
  print(f"❌ Error loading LoRA: {e}")
39
  else:
40
  print("⚠️ LoRA file not found! Running base model.")
 
1
  import os
2
  import gradio as gr
3
  import torch
4
+ from diffusers import StableDiffusion3Pipeline, SD3LoraLoaderMixin
5
  import spaces
6
  import random
7
 
 
25
  # Load the LoRA trained weights once at the start
26
  lora_path = "lora_trained_model.pt" # Ensure this file is uploaded in the Space
27
  if os.path.exists(lora_path):
 
 
28
  try:
29
+ pipeline.load_lora_weights(lora_path) # This automatically applies to the right components
 
 
30
  print("✅ LoRA weights loaded successfully!")
31
+ except ValueError as e:
 
 
32
  print(f"❌ Error loading LoRA: {e}")
33
  else:
34
  print("⚠️ LoRA file not found! Running base model.")