import os

import gradio as gr
import spaces
import torch
from diffusers import StableDiffusion3Pipeline

# Use the token saved in the Space's secrets (the model repo is gated).
hf_token = os.getenv("HF_TOKEN")

# Pre-trained model ID.
model_id = "stabilityai/stable-diffusion-3.5-large"

# Pick the dtype up front: FP16 mixed precision on CUDA, FP32 on CPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Global variable for the pipeline (initialized only once, then cached).
pipeline = None


def initialize_pipeline():
    """Load the pipeline on first use and cache it in the module-level global."""
    global pipeline
    if pipeline is None:
        try:
            pipeline = StableDiffusion3Pipeline.from_pretrained(
                model_id,
                token=hf_token,  # `use_auth_token` is deprecated in recent diffusers
                torch_dtype=torch_dtype,
            )
            # Offload idle submodules to the CPU and slice the attention
            # computation to lower peak GPU memory usage. Note that
            # enable_model_cpu_offload() manages device placement itself,
            # so the pipeline must NOT also be moved with .to(device).
            pipeline.enable_model_cpu_offload()
            pipeline.enable_attention_slicing()
            # Optional sanity check that the core component actually loaded.
            if not hasattr(pipeline, "transformer"):
                raise ValueError("The transformer component is missing from the pipeline.")
            print("Pipeline initialized and cached.")
        except Exception as e:
            # Error handling for model-loading issues.
            print(f"Error loading the model: {e}")
            raise RuntimeError("Failed to initialize the model pipeline.") from e
    return pipeline


# Image-generation function, decorated to request a GPU on ZeroGPU Spaces.
@spaces.GPU(duration=65)
def generate_image(prompt):
    pipe = initialize_pipeline()  # Loads once, then returns the cached pipeline.
    try:
        image = pipe(prompt).images[0]
    except Exception as e:
        # Catch errors during generation (e.g., CUDA out-of-memory).
        print(f"Error during image generation: {e}")
        raise RuntimeError("Image generation failed.") from e
    return image


# Simple Gradio interface: text prompt in, generated image out.
interface = gr.Interface(fn=generate_image, inputs="text", outputs="image")

# Launch the interface. launch() blocks, so all setup must happen before
# this call; anything placed after it would never run while the app serves.
interface.launch()
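
# --- Optional extension (a sketch, not part of the original app) ---
# StableDiffusion3Pipeline.__call__ accepts standard diffusers arguments such
# as num_inference_steps and guidance_scale, so the handler and interface
# could expose them as sliders. The default values below are illustrative
# assumptions, not tuned settings:
#
#   @spaces.GPU(duration=65)
#   def generate_image(prompt, steps=28, guidance=4.5):
#       pipe = initialize_pipeline()
#       return pipe(
#           prompt,
#           num_inference_steps=int(steps),  # more steps: slower, often sharper
#           guidance_scale=float(guidance),  # higher: stronger prompt adherence
#       ).images[0]
#
#   interface = gr.Interface(
#       fn=generate_image,
#       inputs=["text",
#               gr.Slider(1, 50, value=28, label="Steps"),
#               gr.Slider(0.0, 10.0, value=4.5, label="Guidance")],
#       outputs="image",
#   )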