DonImages commited on
Commit
7195e76
·
verified ·
1 Parent(s): 24d5f49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -55
app.py CHANGED
@@ -1,65 +1,59 @@
1
- import torch
2
- import spaces
3
- from diffusers import StableDiffusion3Pipeline
4
- from huggingface_hub import login
5
  import os
6
  import gradio as gr
 
 
 
 
7
 
8
- # Retrieve the token from the environment variable
9
- token = os.getenv("HF_TOKEN") # Hugging Face token from the secret
10
- if token:
11
- login(token=token) # Log in with the retrieved token
12
- else:
13
- raise ValueError("Hugging Face token not found. Please set it as a repository secret in the Space settings.")
 
14
 
15
- # Load the Stable Diffusion 3.5 model with lower precision (float16) if GPU is available
16
- model_id = "stabilityai/stable-diffusion-3.5-large"
17
- pipe = StableDiffusion3Pipeline.from_pretrained(model_id)
18
 
19
- # Check if GPU is available, then move the model to the appropriate device
20
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
21
- pipe.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # Define the path to the LoRA model
24
- lora_model_path = "./lora_model.pth" # Assuming the file is saved locally
 
25
 
26
- # Custom method to load and apply LoRA weights to the Stable Diffusion pipeline
27
- def load_lora_model(pipe, lora_model_path):
28
- # When loading the LoRA weights
29
- lora_weights = torch.load(lora_model_path, map_location=device, weights_only=True)
30
 
31
- # Check if the transformer folder has the necessary attributes
32
- print(dir(pipe.transformer)) # List available attributes of the transformer (formerly 'unet')
33
-
34
- # Apply weights to the transformer submodule
35
- try:
36
- for name, param in pipe.transformer.named_parameters(): # Accessing transformer parameters
37
- if name in lora_weights:
38
- param.data += lora_weights[name]
39
- except AttributeError:
40
- print("The model doesn't have 'transformer' attributes. Please check the model structure.")
41
- # Add alternative handling or exit
42
 
43
- return pipe
44
-
45
- # Load and apply the LoRA model weights
46
- pipe = load_lora_model(pipe, lora_model_path)
47
-
48
- # Use the @spaces.gpu decorator to ensure compatibility with GPU or CPU as needed
49
- @spaces.GPU(duration=65) # This ensures GPU is allocated for 65 seconds
50
- def generate(prompt, seed=None):
51
- generator = torch.manual_seed(seed) if seed is not None else None
52
- # Generate the image using the prompt
53
- image = pipe(prompt, height=512, width=512, generator=generator).images[0]
54
- return image
55
 
56
- # Gradio interface
57
- iface = gr.Interface(
58
- fn=generate,
59
- inputs=[
60
- gr.Textbox(label="Enter your prompt"), # For the prompt
61
- gr.Number(label="Enter a seed (optional)", value=None), # For the seed
62
- ],
63
- outputs="image"
64
- )
65
- iface.launch()
 
 
 
 
 
1
  import os
2
  import gradio as gr
3
+ import torch
4
+ from diffusers import StableDiffusion3Pipeline
5
+ import spaces
6
+ import random
7
 
8
+ # Ensure GPU allocation in Hugging Face Spaces
9
+ @spaces.GPU(duration=65)
10
+ def generate_image(prompt: str, seed: int = None):
11
+ """Generates an image using Stable Diffusion 3.5 with LoRA fine-tuning."""
12
+ if seed is None:
13
+ seed = random.randint(0, 100000)
14
+ generator = torch.manual_seed(seed)
15
 
16
+ image = pipeline(prompt, generator=generator).images[0]
17
+ return image
 
18
 
19
+ # Device selection
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
22
+
23
+ # Load the Hugging Face token securely
24
+ token = os.getenv("HF_TOKEN")
25
+
26
+ # Model ID for SD 3.5 Large
27
+ model_repo_id = "stabilityai/stable-diffusion-3.5-large"
28
+
29
+ # Load Stable Diffusion pipeline
30
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
31
+ model_repo_id,
32
+ torch_dtype=torch_dtype,
33
+ use_safetensors=True, # Use safetensors format if supported
34
+ ).to(device)
35
+
36
+ # Load the LoRA trained weights
37
+ lora_path = "lora_trained_model.pt" # Ensure this file is uploaded in the Space
38
+ if os.path.exists(lora_path):
39
+ lora_state_dict = torch.load(lora_path, map_location=device)
40
+ pipeline.load_lora_weights(lora_state_dict)
41
+ print("✅ LoRA weights loaded successfully!")
42
+ else:
43
+ print("⚠️ LoRA file not found! Running base model.")
44
 
45
+ # Gradio Interface
46
+ with gr.Blocks() as demo:
47
+ gr.Markdown("# 🖼️ LoRA Fine-Tuned SD 3.5 Image Generator")
48
 
49
+ with gr.Row():
50
+ prompt_input = gr.Textbox(label="Enter Prompt", value="A woman in her 20s with expressive black eyes, graceful face, elegant body, standing on the beach at sunset. Photorealistic, highly detailed.")
51
+ seed_input = gr.Number(label="Seed (optional)", value=None)
 
52
 
53
+ generate_btn = gr.Button("Generate Image")
54
+ output_image = gr.Image(label="Generated Image")
 
 
 
 
 
 
 
 
 
55
 
56
+ generate_btn.click(generate_image, inputs=[prompt_input, seed_input], outputs=output_image)
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ # Launch Gradio App
59
+ demo.launch()