# Hugging Face Space "Testing2" (DonImages) — app.py
# Commit 03b91b0 "Update app.py" (verified), 2.12 kB
import os
import gradio as gr
import torch
from diffusers import StableDiffusion3Pipeline
import spaces
import random
from peft import PeftModel, get_peft_model
# Ensure GPU allocation in Hugging Face Spaces: the decorator requests a GPU
# for this call with a 65-second duration budget.
@spaces.GPU(duration=65)
def generate_image(prompt: str, seed: int = None):
    """Generate one image for *prompt* with the module-level SD 3.5 pipeline.

    The pipeline carries the fine-tuned LoRA only if the LoRA file was found
    and loaded at startup; otherwise the base model is used.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: RNG seed for reproducible output. ``None`` picks a random seed
            in [0, 100000]. UI number widgets may deliver a float, so the
            value is coerced to ``int`` before seeding.

    Returns:
        The first image of the pipeline output (presumably a PIL image, the
        diffusers default output type — confirm if output_type is changed).
    """
    if seed is None:
        seed = random.randint(0, 100000)
    # A private CPU generator seeded with `seed` yields the same random stream
    # as torch.manual_seed(seed), without reseeding the process-wide RNG that
    # concurrent requests share. Generator.manual_seed rejects floats, hence int().
    generator = torch.Generator(device="cpu").manual_seed(int(seed))
    image = pipeline(prompt, generator=generator).images[0]
    return image
# Device selection: GPU with half precision when available, otherwise
# full-precision CPU inference.
# NOTE(review): the SD 3.5 Large model card examples load with torch.bfloat16;
# confirm float16 does not produce NaN/black images on the target GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the Hugging Face token securely (e.g. from the Space secrets).
# SD 3.5 Large is a gated repository. huggingface_hub can also pick HF_TOKEN up
# implicitly; passing it below makes the dependency explicit instead of
# leaving this variable unused.
token = os.getenv("HF_TOKEN")

# Model ID for SD 3.5 Large
model_repo_id = "stabilityai/stable-diffusion-3.5-large"

# Load Stable Diffusion pipeline and move it to the selected device.
pipeline = StableDiffusion3Pipeline.from_pretrained(
    model_repo_id,
    torch_dtype=torch_dtype,
    use_safetensors=True,  # Use safetensors format if supported
    token=token,
).to(device)
# Load the LoRA trained weights, if present. If the file is missing, the base
# model is served unchanged.
lora_path = "lora_trained_model.pt"  # Ensure this file is uploaded in the Space
if os.path.exists(lora_path):
    # weights_only=True restricts unpickling to tensors and plain containers.
    lora_state_dict = torch.load(lora_path, map_location=device, weights_only=True)
    # A diffusers pipeline is not a torch.nn.Module, so peft.PeftModel cannot
    # wrap it. PeftModel.from_pretrained also expects an adapter directory with
    # adapter_config.json, not a .pt file. The pipeline's own LoRA loader
    # accepts the state dict directly and keeps `pipeline` callable.
    # NOTE(review): assumes the .pt holds a LoRA state dict whose keys follow a
    # format load_lora_weights understands (diffusers/PEFT/kohya) — confirm
    # against the training script.
    pipeline.load_lora_weights(lora_state_dict)
    print("✅ LoRA weights loaded successfully!")
else:
    print("⚠️ LoRA file not found! Running base model.")
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🖼️ LoRA Fine-Tuned SD 3.5 Image Generator")
    with gr.Row():
        prompt_input = gr.Textbox(label="Enter Prompt", value="A woman in her 20s with expressive black eyes, graceful face, elegant body, standing on the beach at sunset. Photorealistic, highly detailed.")
        # precision=0 makes the widget round and deliver an int (or None when
        # left empty) instead of a float, matching generate_image's int seed.
        seed_input = gr.Number(label="Seed (optional)", value=None, precision=0)
    generate_btn = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")
    # Inputs map positionally onto generate_image(prompt, seed).
    generate_btn.click(generate_image, inputs=[prompt_input, seed_input], outputs=output_image)

# Launch Gradio App
demo.launch()