"""Gradio app that serves a Stable Diffusion 3 pipeline with a LoRA applied.

The pipeline and LoRA are loaded once at import time, before the Gradio UI is
built, so a missing model or LoRA file stops the app at startup rather than on
the first request.
"""

import os
import sys

import gradio as gr
import torch
from diffusers import StableDiffusion3Pipeline
from safetensors.torch import load_file  # noqa: F401  (no longer used below)

try:
    # ZeroGPU decorator; only importable inside a Hugging Face Space.
    from spaces import GPU
except ImportError:

    def GPU(*_args, **_kwargs):
        """No-op stand-in for ``spaces.GPU`` when running outside HF Spaces."""

        def decorator(fn):
            return fn

        return decorator


# NOTE(review): the original script referred to HF_TOKEN / model_id as defined
# "as before" in an earlier version that is not part of this file. They are
# read from the environment here; the default model id is an assumption —
# confirm it matches the checkpoint the LoRA was trained against.
hf_token = os.environ.get("HF_TOKEN")
model_id = os.environ.get(
    "MODEL_ID", "stabilityai/stable-diffusion-3-medium-diffusers"
)

lora_dir = "./"
lora_filename = "lora_trained_model.safetensors"
lora_path = os.path.join(lora_dir, lora_filename)

pipeline = None  # Global pipeline, populated once at startup below.


def _load_pipeline():
    """Load the SD3 pipeline, attach the LoRA, configure memory savers, return it.

    Raises:
        FileNotFoundError: if the LoRA file does not exist.
        Exception: whatever diffusers raises while downloading/loading weights.
    """
    print(f"Loading LoRA from: {lora_path}")
    # Check the small local file first so a missing LoRA fails fast instead of
    # after downloading the full base model.
    if not os.path.exists(lora_path):
        raise FileNotFoundError(f"LoRA file not found at {lora_path}")

    pipe = StableDiffusion3Pipeline.from_pretrained(
        model_id,
        token=hf_token,  # `use_auth_token` is deprecated in diffusers
        torch_dtype=torch.float16,
        cache_dir="./model_cache",
    )

    # The previous `text_encoder.load_state_dict(weights, strict=False)` silently
    # ignored every key that did not match the text encoder, so the "LoRA" could
    # be reported as loaded without changing any weights. `load_lora_weights`
    # maps LoRA keys onto the pipeline's components.
    # NOTE: diffusers' LoRA loading relies on the `peft` package being installed.
    pipe.load_lora_weights(lora_dir, weight_name=lora_filename)
    print("LoRA loaded successfully!")

    # Enable offloading only after the LoRA is attached, so LoRA loading does not
    # have to work around already-installed offload hooks.
    pipe.enable_model_cpu_offload()
    pipe.enable_attention_slicing()
    return pipe


# Load Stable Diffusion and LoRA *immediately* (before Gradio).
try:
    pipeline = _load_pipeline()
except Exception as e:  # top-level startup boundary: report and stop the app
    print(f"Error loading Stable Diffusion or LoRA: {e}", file=sys.stderr)
    # Non-zero status so the hosting environment sees the startup failure
    # (bare `exit()` reported success).
    raise SystemExit(1) from e

print("Stable Diffusion model loaded successfully!")


@GPU(duration=65)  # Request a GPU slot for up to 65 s per call (HF ZeroGPU).
def generate_image(prompt):
    """Generate a single image for *prompt* and return it.

    Errors are raised as ``gr.Error`` so Gradio displays them to the user;
    returning an error string to a ``gr.Image`` output cannot be rendered.
    """
    if pipeline is None:  # Should not happen: startup exits if loading fails.
        raise gr.Error("Error: Stable Diffusion model not loaded!")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    try:
        return pipeline(prompt).images[0]
    except Exception as e:
        raise gr.Error(f"Error generating image: {e}") from e


# Gradio interface (no "Load Model" button needed; the model is already loaded).
with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    image_output = gr.Image(label="Generated Image")
    generate_button = gr.Button("Generate")
    generate_button.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=image_output,
    )

if __name__ == "__main__":
    demo.launch()