import os

import gradio as gr
import torch
from diffusers import StableDiffusion3Pipeline
from safetensors.torch import load_file
from spaces import GPU  # Hugging Face Spaces only; remove this import (and the decorator below) when running locally

# Access HF_TOKEN from the environment (needed for gated models such as SD 3.5)
hf_token = os.getenv("HF_TOKEN")

# Pre-trained model ID
model_id = "stabilityai/stable-diffusion-3.5-large"

# The pipeline lives at module level but is only created when "Load Model" is clicked
pipeline = None


def load_pipeline():
    """Load the Stable Diffusion 3.5 pipeline (called once via the "Load Model" button)."""
    global pipeline
    try:
        pipeline = StableDiffusion3Pipeline.from_pretrained(
            model_id,
            token=hf_token,             # `use_auth_token` is deprecated in recent diffusers
            torch_dtype=torch.float16,
            cache_dir="./model_cache",
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return f"Error loading model: {e}"  # Shown as the button label

    pipeline.enable_model_cpu_offload()
    pipeline.enable_attention_slicing()
    return "Model loaded successfully"      # Shown as the button label


@GPU(duration=65)  # ZeroGPU decorator; only valid inside a Hugging Face Space
def generate_image(prompt):
    global pipeline
    if pipeline is None:
        raise gr.Error("Model not loaded. Click 'Load Model' first.")

    # Load and apply the LoRA file that ships with the Space.
    # NOTE: merging the raw state dict into the text encoder only updates keys that
    # happen to match (strict=False silently skips the rest); pipeline.load_lora_weights()
    # is the usual diffusers API for applying LoRA checkpoints.
    lora_filename = "lora_trained_model.safetensors"  # Name of your LoRA file
    lora_path = os.path.join("./", lora_filename)
    print(f"Loading LoRA from: {lora_path}")
    try:
        if os.path.exists(lora_path):
            lora_weights = load_file(lora_path)
            pipeline.text_encoder.load_state_dict(lora_weights, strict=False)
        else:
            print(f"LoRA file not found at {lora_path}; generating with the base model.")
    except Exception as e:
        raise gr.Error(f"Error loading LoRA: {e}")

    try:
        image = pipeline(prompt).images[0]
        return image
    except Exception as e:
        raise gr.Error(f"Error generating image: {e}")


# Gradio interface
with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    image_output = gr.Image(label="Generated Image")
    generate_button = gr.Button("Generate")
    load_model_button = gr.Button("Load Model")

    # The status string returned by load_pipeline() replaces the button label
    load_model_button.click(fn=load_pipeline, outputs=load_model_button)
    generate_button.click(fn=generate_image, inputs=prompt_input, outputs=image_output)

demo.launch()