DonImages commited on
Commit
b822313
·
verified ·
1 Parent(s): c3f8164

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -10
app.py CHANGED
@@ -25,19 +25,17 @@ pipeline = StableDiffusion3Pipeline.from_pretrained(
25
  # Load the LoRA trained weights once at the start
26
  lora_path = "lora_trained_model.pt" # Ensure this file is uploaded in the Space
27
  if os.path.exists(lora_path):
28
- lora_state_dict = torch.load(lora_path, map_location=device)
29
 
30
  try:
31
- # Assuming `pipeline` has a method `load_lora_weights` to load LoRA weights directly
32
- # If the pipeline does not support this method, we might need to apply the LoRA weights manually
33
- pipeline.load_lora_weights(lora_state_dict) # Load LoRA weights into the pipeline
34
  print("✅ LoRA weights loaded successfully!")
35
- except AttributeError:
36
- print("❌ pipeline does not support load_lora_weights method. Attempting manual application.")
37
- # Manual application of weights if load_lora_weights method does not exist
38
- # This is just a placeholder; you'll need to update this part based on how your LoRA weights should be applied
39
- # Example:
40
- # pipeline.model.load_state_dict(lora_state_dict, strict=False)
41
  else:
42
  print("⚠️ LoRA file not found! Running base model.")
43
 
 
25
  # Load the LoRA trained weights once at the start
26
  lora_path = "lora_trained_model.pt" # Ensure this file is uploaded in the Space
27
  if os.path.exists(lora_path):
28
+ lora_checkpoint = torch.load(lora_path, map_location=device)
29
 
30
  try:
31
+ # Assuming the checkpoint contains 'model' key with the model state dict
32
+ pipeline.unet.load_state_dict(lora_checkpoint['model'], strict=False) # Apply weights to the unet
33
+
34
  print("✅ LoRA weights loaded successfully!")
35
+ except KeyError as e:
36
+ print(f"❌ Error: Missing key in checkpoint: {e}")
37
+ except Exception as e:
38
+ print(f"❌ Error loading LoRA: {e}")
 
 
39
  else:
40
  print("⚠️ LoRA file not found! Running base model.")
41