"""Gradio app serving Stable Diffusion 3.5 Large with optional LoRA weights.

At import time the script:
  1. converts ``lora_trained_model.pt`` to ``.safetensors`` if only the
     ``.pt`` file exists,
  2. loads the base pipeline (authenticating with ``HF_TOKEN`` when set),
  3. loads and fuses the LoRA weights when the file is present,
  4. builds and launches the Gradio UI.
"""

import os
import random
from typing import Optional

import gradio as gr
import spaces
import torch
from diffusers import StableDiffusion3Pipeline
from diffusers.loaders import SD3LoraLoaderMixin
from safetensors.torch import load_file, save_file

# Device selection
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load Hugging Face token securely (None when the variable is unset)
token = os.getenv("HF_TOKEN")

# Model ID for SD 3.5 Large
model_repo_id = "stabilityai/stable-diffusion-3.5-large"

# Convert .pt to .safetensors if needed
lora_pt_path = "lora_trained_model.pt"
lora_safetensors_path = "lora_trained_model.safetensors"

if os.path.exists(lora_pt_path) and not os.path.exists(lora_safetensors_path):
    print("🔄 Converting LoRA .pt to .safetensors...")
    # weights_only=True restricts unpickling to tensors/primitive containers,
    # so a tampered .pt file cannot execute arbitrary code on load.
    lora_weights = torch.load(
        lora_pt_path, map_location="cpu", weights_only=True
    )
    # safetensors refuses non-contiguous tensors, so normalise them first.
    lora_weights = {
        key: tensor.contiguous() for key, tensor in lora_weights.items()
    }
    save_file(lora_weights, lora_safetensors_path)
    print(f"✅ LoRA saved as {lora_safetensors_path}")

# Load Stable Diffusion pipeline
pipeline = StableDiffusion3Pipeline.from_pretrained(
    model_repo_id,
    torch_dtype=torch_dtype,
    use_safetensors=True,  # Use safetensors format if supported
    token=token,  # Pass the token through; it was read but never used before
).to(device)

# Load and fuse LoRA trained weights
if os.path.exists(lora_safetensors_path):
    # Best-effort: on failure the app keeps running with whatever state the
    # pipeline is in and reports the error on stdout.
    try:
        pipeline.load_lora_weights(".", weight_name=lora_safetensors_path)
        pipeline.fuse_lora()  # Merges LoRA into the base model
        print("✅ LoRA weights loaded and fused successfully!")
    except Exception as e:
        print(f"❌ Error loading LoRA: {e}")
else:
    print("⚠️ LoRA file not found! Running base model.")

# Verify if LoRA is applied. Inspect every component that may carry LoRA
# parameters, not just the first text encoder; components the pipeline does
# not expose (or that are None) are skipped.
for component_name in (
    "transformer",
    "text_encoder",
    "text_encoder_2",
    "text_encoder_3",
):
    component = getattr(pipeline, component_name, None)
    if component is None:
        continue
    for name, param in component.named_parameters():
        if "lora" in name.lower():
            print(f"✅ LoRA applied to: {name}, requires_grad={param.requires_grad}")


# Ensure GPU allocation in Hugging Face Spaces
@spaces.GPU(duration=65)
def generate_image(prompt: str, seed: Optional[int] = None):
    """Generates an image using Stable Diffusion 3.5 with LoRA fine-tuning.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for the random generator. When ``None`` a seed is drawn
            uniformly from ``[0, 100000]``. Numeric values coming from the
            UI may be floats and are truncated to ``int``.

    Returns:
        The first image produced by the pipeline for ``prompt``.
    """
    if seed is None:
        seed = random.randint(0, 100000)

    # Create a generator with the seed; int() guards against float input
    # from the gr.Number widget.
    generator = torch.manual_seed(int(seed))

    # Generate the image using the pipeline
    image = pipeline(prompt, generator=generator).images[0]
    return image


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🖼️ LoRA Fine-Tuned SD 3.5 Image Generator")

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter Prompt",
            value="A woman in her 20s with expressive black eyes, graceful face, elegant body, standing on the beach at sunset. Photorealistic, highly detailed.",
        )
        # precision=0 makes the widget deliver an integer seed.
        seed_input = gr.Number(label="Seed (optional)", value=None, precision=0)

    generate_btn = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")

    generate_btn.click(
        generate_image,
        inputs=[prompt_input, seed_input],
        outputs=output_image,
    )

# Launch the Gradio app
demo.launch()