import os
import sys

import gradio as gr
import torch
from diffusers import StableDiffusion3Pipeline
from safetensors.torch import load_file
from spaces import GPU  # Remove if not in HF Space

# 1. Model ID and Hugging Face token.
# HF_TOKEN is read from the environment (set it in the HF Space settings) and is
# passed to from_pretrained so gated/private models can be downloaded. When it is
# unset, None is passed and the library's default credential lookup applies.
model_id = "stabilityai/stable-diffusion-3.5-large"  # Or your preferred model ID
hf_token = os.getenv("HF_TOKEN")

# 2. LoRA file (EXACT filename), resolved relative to the current working directory.
lora_filename = "lora_trained_model.safetensors"
lora_path = os.path.join("./", lora_filename)


def _apply_lora(pipe, path):
    """Load the safetensors state dict at *path* into ``pipe.text_encoder``.

    ``load_state_dict(strict=False)`` silently skips every key that does not
    name a text-encoder parameter, so the number of tensors actually applied
    is counted and reported instead of assuming success.

    Raises:
        FileNotFoundError: If *path* does not exist.
        RuntimeError: If none of the file's tensors match a text-encoder key.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"LoRA file not found at: {path}")
    lora_weights = load_file(path)
    # NOTE(review): a standard diffusers/PEFT LoRA file (lora_A/lora_B keys)
    # will not match text-encoder parameter names; pipe.load_lora_weights()
    # is the supported loader for that format (it requires the `peft` package).
    incompatible = pipe.text_encoder.load_state_dict(lora_weights, strict=False)
    applied = len(lora_weights) - len(incompatible.unexpected_keys)
    if applied == 0:
        raise RuntimeError(
            f"None of the {len(lora_weights)} tensors in {path} match a "
            "text-encoder parameter; the LoRA was not applied."
        )
    print(
        f"LoRA loaded successfully from: {path} "
        f"({applied}/{len(lora_weights)} tensors applied)"
    )


def _load_pipeline():
    """Load the SD3 pipeline, apply the LoRA, and move it to the GPU if present.

    Returns:
        The ready-to-use ``StableDiffusion3Pipeline``.

    Exits the process with status 1 on any loading failure, because the app
    cannot serve requests without the model.
    """
    try:
        pipe = StableDiffusion3Pipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            cache_dir="./model_cache",  # For caching
            token=hf_token,
        )
        _apply_lora(pipe, lora_path)
        # Without an explicit move the float16 pipeline stays on the CPU,
        # where inference is typically very slow or unsupported.
        # NOTE(review): ZeroGPU Spaces expect the pipeline on "cuda" at import
        # time — confirm torch.cuda.is_available() reports True there.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        pipe.to(device)
    except Exception as e:
        print(f"Error loading model or LoRA: {e}")
        sys.exit(1)  # Non-zero status so the failure is visible to the host
    print("Stable Diffusion model loaded successfully!")
    return pipe


# 3. Load Stable Diffusion and LoRA (before Gradio).
pipeline = _load_pipeline()


# 4. Image generation function.
@GPU(duration=65)  # Only if in HF Space
def generate_image(prompt):
    """Generate one image for *prompt* with the module-level pipeline.

    Args:
        prompt: Text prompt from the Gradio textbox.

    Returns:
        The first image of the pipeline output (``.images[0]``).

    Raises:
        gr.Error: If the model is not loaded or generation fails. Gradio shows
            the message in the UI; the ``gr.Image`` output cannot display a
            plain error string.
    """
    if pipeline is None:
        raise gr.Error("Error: Model not loaded!")
    try:
        image = pipeline(prompt).images[0]
    except Exception as e:
        error_message = f"Error during image generation: {e}"
        print(error_message)  # Also log to the console for debugging
        raise gr.Error(error_message) from e
    print("Image generated successfully!")
    return image


# 5. Gradio interface.
with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    image_output = gr.Image(label="Generated Image")
    generate_button = gr.Button("Generate")
    generate_button.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=image_output,
    )

if __name__ == "__main__":
    demo.launch()