"""Gradio demo: text-to-image with Stable Diffusion 3 plus a local LoRA file.

At start-up the script loads the base pipeline and applies the LoRA weights
found next to this file, then serves a minimal prompt -> image UI.
"""

import os

import gradio as gr
import torch
from diffusers import StableDiffusion3Pipeline
from safetensors.torch import load_file  # noqa: F401  (kept for compatibility; unused below)
from spaces import GPU  # Remove if not in HF Space

# 1. Model ID and access token (read once, at the very beginning).
model_id = "stabilityai/stable-diffusion-3.5-large"  # Or your preferred model ID
hf_token = os.getenv("HF_TOKEN")  # For private/gated models (set in HF Space settings)

# 2. Pipeline handle; stays None until loading below succeeds.
pipeline = None

# 3. Load Stable Diffusion and the LoRA before the Gradio UI is built.
try:
    load_kwargs = {
        "torch_dtype": torch.float16,
        "cache_dir": "./model_cache",  # For caching
    }
    # Only forward the token when one is configured, so anonymous access to
    # public models keeps working.
    if hf_token:
        load_kwargs["token"] = hf_token
    pipeline = StableDiffusion3Pipeline.from_pretrained(model_id, **load_kwargs)

    lora_filename = "lora_trained_model.safetensors"  # EXACT filename of your LoRA
    lora_path = os.path.join("./", lora_filename)

    if not os.path.exists(lora_path):
        print(f"Error: LoRA file not found at: {lora_path}")
        # SystemExit is not an Exception subclass, so the handler below does
        # not intercept it; the non-zero status marks the start-up as failed.
        raise SystemExit(1)

    # Use the pipeline's LoRA loader instead of a non-strict load_state_dict
    # on the text encoder: a non-strict load silently skips every key that
    # does not match, so the adapter could be ignored without any error.
    pipeline.load_lora_weights(lora_path)
    print(f"LoRA loaded successfully from: {lora_path}")

    # The weights are float16, which is intended for GPU inference; move the
    # pipeline to the GPU whenever one is visible to torch.
    if torch.cuda.is_available():
        pipeline = pipeline.to("cuda")

    print("Stable Diffusion model loaded successfully!")
except Exception as e:
    print(f"Error loading model or LoRA: {e}")
    raise SystemExit(1) from e  # Stop with a failure status if loading fails


# 4. Image generation function.
@GPU(duration=65)  # ONLY if in a HF Space, remove if not
def generate_image(prompt):
    """Generate one image for ``prompt`` with the module-level pipeline.

    Args:
        prompt: Text prompt taken from the Gradio textbox.

    Returns:
        The first image produced by the pipeline call.

    Raises:
        gr.Error: If the pipeline was not loaded or generation fails. The
            output component is an image, so a plain error string cannot be
            returned to it; raising lets Gradio show the message in the UI.
    """
    if pipeline is None:  # Should not happen, but good to check
        raise gr.Error("Error: Model not loaded!")

    try:
        # The pipeline returns a list of images; only the first is shown.
        return pipeline(prompt).images[0]
    except Exception as e:
        raise gr.Error(f"Error generating image: {e}") from e


# 5. Gradio interface
with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    image_output = gr.Image(label="Generated Image")
    generate_button = gr.Button("Generate")

    generate_button.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=image_output,
    )

if __name__ == "__main__":
    demo.launch()