import os
import random
from typing import Optional

import gradio as gr
import spaces
import torch
from diffusers import StableDiffusion3Pipeline
from diffusers.loaders import SD3LoraLoaderMixin
from safetensors.torch import load_file, save_file


def _convert_lora_to_safetensors(pt_path: str, safetensors_path: str) -> None:
    """Convert a ``.pt`` LoRA checkpoint to ``.safetensors`` if needed.

    Does nothing when ``pt_path`` is missing or ``safetensors_path`` already
    exists.

    Raises:
        ValueError: if the checkpoint is not a flat ``str -> Tensor`` mapping,
            which is the only shape ``safetensors.torch.save_file`` accepts.
    """
    if not os.path.exists(pt_path) or os.path.exists(safetensors_path):
        return

    print("🔄 Converting LoRA .pt to .safetensors...")
    # weights_only=True refuses to unpickle arbitrary Python objects, so a
    # tampered checkpoint cannot run code while being loaded.
    lora_weights = torch.load(pt_path, map_location="cpu", weights_only=True)
    if not isinstance(lora_weights, dict) or not all(
        isinstance(key, str) and isinstance(value, torch.Tensor)
        for key, value in lora_weights.items()
    ):
        raise ValueError(
            f"{pt_path} is not a flat state dict of tensors; cannot convert to safetensors"
        )

    # save_file rejects non-contiguous tensors and tensors that share storage.
    # A contiguous clone of each entry satisfies both requirements.
    contiguous_weights = {
        key: value.detach().clone(memory_format=torch.contiguous_format)
        for key, value in lora_weights.items()
    }
    save_file(contiguous_weights, safetensors_path)
    print(f"✅ LoRA saved as {safetensors_path}")


def _apply_lora(pipeline, lora_path: str) -> None:
    """Load LoRA weights from ``lora_path`` into ``pipeline`` and fuse them.

    This is best-effort. A missing file or a loading error is reported and
    the base model keeps running unchanged.
    """
    if not os.path.exists(lora_path):
        print("⚠️ LoRA file not found! Running base model.")
        return
    try:
        pipeline.load_lora_weights(lora_path)
        pipeline.fuse_lora()
    except Exception as e:  # deliberate: fall back to the base model
        print(f"❌ Error loading LoRA: {e}")
    else:
        print("✅ LoRA weights loaded and fused successfully!")


def _lora_layers_present(pipeline) -> bool:
    """Return True if any LoRA-named parameter exists in a LoRA-capable component.

    The check covers the transformer as well as the CLIP text encoders,
    because SD3 LoRAs may target only the transformer.
    NOTE(review): this only detects parameters whose name contains "lora".
    If a backend removes LoRA modules after fusing, it reports False even
    though the weights were merged. Confirm against the installed diffusers/peft.
    """
    components = (
        getattr(pipeline, name, None)
        for name in ("transformer", "text_encoder", "text_encoder_2")
    )
    return any(
        "lora" in param_name.lower()
        for component in components
        if component is not None
        for param_name, _ in component.named_parameters()
    )


def _resolve_seed(seed: Optional[float]) -> int:
    """Return ``seed`` as an int, or a random seed in [0, 100000] when it is None.

    An explicit ``0`` is kept as a valid seed. Floats such as ``42.0`` (what
    ``gr.Number`` produces) are converted to int, as ``manual_seed`` requires.
    """
    if seed is None:
        return random.randint(0, 100000)
    return int(seed)


def main():
    """Load SD 3.5 Large (optionally with a fused LoRA) and serve a Gradio UI."""
    # Device selection: fp16 on GPU, fp32 on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    # Hugging Face token from the environment. It is forwarded to
    # from_pretrained so that gated or private repositories can be downloaded.
    token = os.getenv("HF_TOKEN")

    # Model ID for SD 3.5 Large
    model_repo_id = "stabilityai/stable-diffusion-3.5-large"

    lora_pt_path = "lora_trained_model.pt"
    lora_safetensors_path = "lora_trained_model.safetensors"
    _convert_lora_to_safetensors(lora_pt_path, lora_safetensors_path)

    pipeline = StableDiffusion3Pipeline.from_pretrained(
        model_repo_id,
        torch_dtype=torch_dtype,
        use_safetensors=True,
        token=token,
    ).to(device)

    _apply_lora(pipeline, lora_safetensors_path)
    print(f"✅ LoRA Applied: {_lora_layers_present(pipeline)}")

    # The spaces.GPU decorator requests a GPU slot for each call (up to 65 s).
    @spaces.GPU(duration=65)
    def generate_image(prompt: str, seed: Optional[float] = None):
        """Generate one image for ``prompt`` using a reproducible seed.

        Args:
            prompt: Text prompt passed to the pipeline.
            seed: Optional seed. If None, a random seed is drawn.

        Returns:
            The first image from the pipeline output.
        """
        generator = torch.Generator(device).manual_seed(_resolve_seed(seed))
        return pipeline(prompt, generator=generator).images[0]

    # Gradio Interface
    with gr.Blocks() as demo:
        gr.Markdown("# 🖼️ LoRA Fine-Tuned SD 3.5 Image Generator")
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter Prompt",
                value="A woman in her 20s with expressive black eyes, graceful face, elegant body, standing on the beach at sunset. Photorealistic, highly detailed.",
            )
            # precision=0 makes Gradio round the value and hand back an int.
            seed_input = gr.Number(label="Seed (optional)", value=None, precision=0)
        generate_btn = gr.Button("Generate Image")
        output_image = gr.Image(label="Generated Image")
        generate_btn.click(
            generate_image, inputs=[prompt_input, seed_input], outputs=output_image
        )

    # Launch Gradio app
    demo.launch()


if __name__ == "__main__":
    main()