DonImages commited on
Commit
6583c62
·
verified ·
1 Parent(s): 39efbe4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -6,7 +6,7 @@ import random
6
  from diffusers import StableDiffusion3Pipeline
7
  from diffusers.loaders import SD3LoraLoaderMixin
8
 
9
- # Device selection NOTE
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
12
 
@@ -24,10 +24,10 @@ pipeline = StableDiffusion3Pipeline.from_pretrained(
24
  ).to(device)
25
 
26
  # Load the LoRA trained weights once at the start
27
- lora_path = "lora_trained_model.pt" # Ensure this file is uploaded in the Space
28
  if os.path.exists(lora_path):
29
  try:
30
- pipeline.load_lora_weights(lora_path) # Use the correct method for loading LoRA weights
31
  print("✅ LoRA weights loaded successfully!")
32
  except Exception as e:
33
  print(f"❌ Error loading LoRA: {e}")
@@ -67,4 +67,4 @@ with gr.Blocks() as demo:
67
  generate_btn.click(generate_image, inputs=[prompt_input, seed_input], outputs=output_image)
68
 
69
  # Launch the Gradio app
70
- demo.launch()
 
6
  from diffusers import StableDiffusion3Pipeline
7
  from diffusers.loaders import SD3LoraLoaderMixin
8
 
9
+ # Device selection
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
12
 
 
24
  ).to(device)
25
 
26
  # Load the LoRA trained weights once at the start
27
+ lora_path = "lora_trained_model.safetensors" # Use the correct file name
28
  if os.path.exists(lora_path):
29
  try:
30
+ SD3LoraLoaderMixin.load_lora_into_model(pipeline, lora_path) # Correct method
31
  print("✅ LoRA weights loaded successfully!")
32
  except Exception as e:
33
  print(f"❌ Error loading LoRA: {e}")
 
67
  generate_btn.click(generate_image, inputs=[prompt_input, seed_input], outputs=output_image)
68
 
69
  # Launch the Gradio app
70
+ demo.launch()