DonImages commited on
Commit
3819d16
·
verified ·
1 Parent(s): 2e7987e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -14,7 +14,10 @@ else:
14
  # Load the Stable Diffusion 3.5 model with lower precision (float16)
15
  model_id = "stabilityai/stable-diffusion-3.5-large"
16
  pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch.float16) # Use float16 precision
17
- pipe.to("cuda") # Ensuring it runs on GPU
 
 
 
18
 
19
  # Define the path to the LoRA model
20
  lora_model_path = "./lora_model.pth" # Assuming the file is saved locally
@@ -22,7 +25,7 @@ lora_model_path = "./lora_model.pth" # Assuming the file is saved locally
22
  # Custom method to load and apply LoRA weights to the Stable Diffusion pipeline
23
  def load_lora_model(pipe, lora_model_path):
24
  # Load the LoRA weights
25
- lora_weights = torch.load(lora_model_path, map_location="cpu")
26
 
27
  # Apply weights to the UNet submodule
28
  for name, param in pipe.unet.named_parameters(): # Accessing unet parameters
@@ -50,4 +53,4 @@ iface = gr.Interface(
50
  ],
51
  outputs="image"
52
  )
53
- iface.launch()
 
14
  # Load the Stable Diffusion 3.5 model with lower precision (float16)
15
  model_id = "stabilityai/stable-diffusion-3.5-large"
16
  pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch.float16) # Use float16 precision
17
+
18
+ # Check for GPU availability and set device accordingly
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ pipe.to(device) # Use GPU if available, otherwise fallback to CPU
21
 
22
  # Define the path to the LoRA model
23
  lora_model_path = "./lora_model.pth" # Assuming the file is saved locally
 
25
  # Custom method to load and apply LoRA weights to the Stable Diffusion pipeline
26
  def load_lora_model(pipe, lora_model_path):
27
  # Load the LoRA weights
28
+ lora_weights = torch.load(lora_model_path, map_location=device) # Use correct device
29
 
30
  # Apply weights to the UNet submodule
31
  for name, param in pipe.unet.named_parameters(): # Accessing unet parameters
 
53
  ],
54
  outputs="image"
55
  )
56
+ iface.launch()