import os
import gradio as gr
import torch
import spaces
import random
from typing import Optional
from diffusers import StableDiffusion3Pipeline
from diffusers.loaders import SD3LoraLoaderMixin

# Device selection: use the GPU in half precision when available, otherwise
# fall back to CPU in full precision.
# NOTE(review): the SD 3.5 Large model card loads the model in bfloat16;
# confirm float16 gives stable output for this model.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Hugging Face access token (None when HF_TOKEN is unset). Passed explicitly to
# from_pretrained below; the model repository may be gated.
token = os.getenv("HF_TOKEN")

# Model ID for SD 3.5 Large
model_repo_id = "stabilityai/stable-diffusion-3.5-large"

# Load the Stable Diffusion pipeline once at import time so every request
# reuses the same weights.
pipeline = StableDiffusion3Pipeline.from_pretrained(
    model_repo_id,
    torch_dtype=torch_dtype,
    token=token,
    use_safetensors=True,  # Use safetensors format if supported
).to(device)

# Load the LoRA trained weights once at the start. The file is expected to be
# uploaded alongside this script; if it is missing, or diffusers rejects it with
# a ValueError, the app keeps serving the base model.
lora_path = "lora_trained_model.pt"
if os.path.exists(lora_path):
    try:
        pipeline.load_lora_weights(lora_path)
        print("✅ LoRA weights loaded successfully!")
    except ValueError as e:
        print(f"❌ Error loading LoRA: {e}")
else:
    print("⚠️ LoRA file not found! Running base model.")

# Upper bound (inclusive) for randomly chosen seeds.
MAX_SEED = 100000


# Request GPU allocation on Hugging Face Spaces for up to 65 seconds per call.
@spaces.GPU(duration=65)
def generate_image(prompt: str, seed: Optional[int] = None):
    """Generate one image for *prompt* with the SD 3.5 pipeline.

    LoRA weights are applied if they loaded successfully at startup.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for reproducible output. May arrive as a float from the
            Gradio number field. ``None`` picks a random seed in
            ``[0, MAX_SEED]``.

    Returns:
        The first entry of the pipeline output's ``.images``.
    """
    if seed is None:
        seed = random.randint(0, MAX_SEED)

    # Use a private generator instead of torch.manual_seed(), which reseeds the
    # process-wide RNGs shared by all requests. A fresh CPU generator seeded
    # with the same value produces the same sequence the old call did.
    # Generator.manual_seed requires an int, so convert explicitly.
    generator = torch.Generator(device="cpu").manual_seed(int(seed))

    image = pipeline(prompt, generator=generator).images[0]
    return image


DEFAULT_PROMPT = (
    "A woman in her 20s with expressive black eyes, graceful face, elegant body, "
    "standing on the beach at sunset. Photorealistic, highly detailed."
)

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🖼️ LoRA Fine-Tuned SD 3.5 Image Generator")
    with gr.Row():
        prompt_input = gr.Textbox(label="Enter Prompt", value=DEFAULT_PROMPT)
        # precision=0 makes the field deliver whole-number seeds.
        seed_input = gr.Number(label="Seed (optional)", value=None, precision=0)
    generate_btn = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")
    generate_btn.click(
        generate_image,
        inputs=[prompt_input, seed_input],
        outputs=output_image,
    )

# Launch Gradio App
demo.launch()