DonImages committed (verified)
Commit af98809 · Parent: 7ee03e6

Update app.py

Files changed (1):
1. app.py +12 -22
app.py CHANGED
@@ -1,19 +1,10 @@
 import torch
-import zerogpu  # Import ZeroGPU
+import spaces
 from diffusers import StableDiffusion3Pipeline
 from huggingface_hub import login
 import os
 import gradio as gr
 
-# Automatically choose GPU if available, otherwise CPU
-device = zerogpu.select_device()  # ZeroGPU will automatically choose 'cuda' or 'cpu'
-
-# Check and print if the selected device is GPU or CPU
-if device == "cuda":
-    print(f"Using GPU: {torch.cuda.get_device_name()}")
-else:
-    print("Using CPU")
-
 # Retrieve the token from the environment variable
 token = os.getenv("HF_TOKEN")  # Hugging Face token from the secret
 if token:
@@ -23,12 +14,10 @@ else:
 
-# Load the Stable Diffusion 3.5 model with lower precision (float16) if GPU is available
+# Load the Stable Diffusion 3.5 model
 model_id = "stabilityai/stable-diffusion-3.5-large"
-if device == "cuda":
-    pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch.float16)  # Use float16 precision
-else:
-    pipe = StableDiffusion3Pipeline.from_pretrained(model_id)  # Default precision for CPU
+pipe = StableDiffusion3Pipeline.from_pretrained(model_id)
 
-pipe.to(device)  # Ensuring the model is on the correct device (GPU or CPU)
+# Check if a GPU is available, then move the model to the appropriate device
+pipe.to('cuda' if torch.cuda.is_available() else 'cpu')
 
 # Define the path to the LoRA model
 lora_model_path = "./lora_model.pth"  # Assuming the file is saved locally
@@ -36,7 +25,7 @@ lora_model_path = "./lora_model.pth"  # Assuming the file is saved locally
 # Custom method to load and apply LoRA weights to the Stable Diffusion pipeline
 def load_lora_model(pipe, lora_model_path):
     # Load the LoRA weights
-    lora_weights = torch.load(lora_model_path, map_location=device)  # Load LoRA model to the correct device
+    lora_weights = torch.load(lora_model_path, map_location=pipe.device)  # Load LoRA model to the correct device
 
     # Print available attributes of the model to check access to `unet` (optional)
     print(dir(pipe))  # This will list all attributes and methods of the `pipe` object
@@ -55,16 +44,17 @@ def load_lora_model(pipe, lora_model_path):
 # Load and apply the LoRA model weights
 pipe = load_lora_model(pipe, lora_model_path)
 
-# Function to generate an image from a text prompt
-def generate_image(prompt, seed=None):
+# Use the @spaces.GPU decorator so ZeroGPU attaches a GPU while this function runs
+@spaces.GPU
+def generate(prompt, seed=None):
     generator = torch.manual_seed(seed) if seed is not None else None
-    # Reduce image size for less memory usage
-    image = pipe(prompt, height=512, width=512, generator=generator).images[0]  # Changed image size
+    # Generate a 512x512 image for the prompt to limit memory usage
+    image = pipe(prompt, height=512, width=512, generator=generator).images[0]
    return image
 
 # Gradio interface
 iface = gr.Interface(
-    fn=generate_image,
+    fn=generate,
     inputs=[
         gr.Textbox(label="Enter your prompt"),  # For the prompt
         gr.Number(label="Enter a seed (optional)", value=None),  # For the seed
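Note: `spaces` is Hugging Face's ZeroGPU helper package. The Space boots on CPU, and the `@spaces.GPU` decorator requests a GPU only for the duration of the decorated call. A minimal sketch of the pattern this commit moves toward (the bfloat16 dtype and the dedicated per-call generator are suggestions, not part of the commit):

    import spaces
    import torch
    from diffusers import StableDiffusion3Pipeline

    # Load once at import time; on ZeroGPU, .to("cuda") is safe here because
    # the spaces package patches CUDA initialization
    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large",
        torch_dtype=torch.bfloat16,  # assumption: lower memory than float32
    )
    pipe.to("cuda")

    @spaces.GPU  # a GPU is attached only while this function executes
    def generate(prompt, seed=None):
        generator = None
        if seed is not None:
            # Seed a dedicated generator rather than the global RNG
            generator = torch.Generator("cuda").manual_seed(int(seed))
        return pipe(prompt, height=512, width=512, generator=generator).images[0]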
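The hunk context cuts off the body of the `if token:` check. Given the `login` import already present, the usual pattern on Spaces looks roughly like this (the else-branch message is illustrative):

    import os
    from huggingface_hub import login

    token = os.getenv("HF_TOKEN")  # injected as a Space secret
    if token:
        login(token)  # authenticate so the gated SD 3.5 weights can be downloaded
    else:
        print("HF_TOKEN is not set; downloading gated models will fail")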
 
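The diff also truncates `load_lora_model` before the loaded state dict is applied to the pipeline. With diffusers, the built-in LoRA loader is usually preferable to patching raw tensors by hand; a hedged sketch, assuming ./lora_model.pth is stored in a format diffusers can parse (the adapter name "my_lora" and the 0.8 scale are illustrative):

    from diffusers import StableDiffusion3Pipeline

    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large"
    )

    # load_lora_weights accepts a local directory plus weight_name,
    # as well as a Hub repo id
    pipe.load_lora_weights(".", weight_name="lora_model.pth", adapter_name="my_lora")

    # Optionally scale the adapter's influence (1.0 = full strength)
    pipe.set_adapters(["my_lora"], adapter_weights=[0.8])

A checkpoint written by a plain torch.save may not match the key layout diffusers expects, so exporting the LoRA in the diffusers/PEFT safetensors format is the safer route.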