import os

import gradio as gr
import torch
from diffusers import StableDiffusion3Pipeline
from spaces import GPU

# Read the Hugging Face access token from the Space's environment variables
hf_token = os.getenv("HF_TOKEN")

# Pre-trained base model ID
model_id = "stabilityai/stable-diffusion-3.5-large"

# Name of the LoRA weights file bundled with the Space
lora_filename = "lora_trained_model.safetensors"

# The pipeline is created lazily on the first request
pipeline = None


@GPU(duration=65)  # Request a ZeroGPU slot for up to 65 seconds per call
def generate_image(prompt):
    global pipeline
    if pipeline is None:
        try:
            pipeline = StableDiffusion3Pipeline.from_pretrained(
                model_id,
                token=hf_token,
                torch_dtype=torch.float16,
                cache_dir="./model_cache",
            )
        except Exception as e:
            print(f"Error loading from cache: {e}")
            pipeline = StableDiffusion3Pipeline.from_pretrained(
                model_id,
                token=hf_token,
                torch_dtype=torch.float16,
                local_files_only=False,
            )
        pipeline.enable_model_cpu_offload()
        # Not every pipeline class exposes attention slicing; guard the call
        if hasattr(pipeline, "enable_attention_slicing"):
            pipeline.enable_attention_slicing()

        # Apply the LoRA once, right after the pipeline is built, using
        # diffusers' built-in LoRA loader instead of manually patching the
        # text encoder's state dict (load_state_dict with strict=False would
        # silently drop every mismatched key)
        lora_path = os.path.join("./", lora_filename)
        print(f"Loading LoRA from: {lora_path}")
        if not os.path.exists(lora_path):
            raise gr.Error(f"LoRA file not found at {lora_path}")
        try:
            pipeline.load_lora_weights(lora_path)
        except Exception as e:
            raise gr.Error(f"Error loading LoRA: {e}")

    try:
        return pipeline(prompt).images[0]
    except Exception as e:
        # Raise instead of returning a string: the output is a gr.Image
        # component, which cannot render an error message
        raise gr.Error(f"Error generating image: {e}")


# Gradio UI: a prompt box, a generate button, and the resulting image
with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    image_output = gr.Image(label="Generated Image")
    generate_button = gr.Button("Generate")

    generate_button.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=image_output,
    )

demo.launch()