# Testing / app.py
# DonImages — Update app.py (commit a9b26e4, verified)
import gradio as gr
import torch
from diffusers import StableDiffusion3Pipeline
import os
import spaces
# Hugging Face access token read from the Space's secrets (None when unset).
hf_token = os.environ.get("HF_TOKEN")

# Hub repository ID of the pre-trained model to load.
model_id = "stabilityai/stable-diffusion-3.5-large"

# Pipeline cache: starts empty and is filled by initialize_pipeline() on first use.
pipeline = None
# Function for initializing and caching the pipeline
def initialize_pipeline():
    """Return the shared SD3.5 pipeline, loading and configuring it on first call.

    The first call loads ``model_id`` in FP16, enables model CPU offload and
    attention slicing, and stores the result in the module-level ``pipeline``
    global. Later calls return the cached object without reloading.

    Returns:
        The cached ``StableDiffusion3Pipeline`` instance.

    Raises:
        RuntimeError: If loading or configuring the pipeline fails. The
            original exception is chained as ``__cause__``. The cache is left
            empty, so the next call retries.
    """
    global pipeline
    if pipeline is not None:
        return pipeline
    try:
        # Load the pipeline with mixed precision (FP16).
        # `token` is the current name of the deprecated `use_auth_token` argument.
        # NOTE(review): reference examples for SD3.5 appear to use torch.bfloat16;
        # confirm FP16 output quality for this model.
        loaded = StableDiffusion3Pipeline.from_pretrained(
            model_id,
            token=hf_token,
            torch_dtype=torch.float16,
        )
        # Enable model offloading and attention slicing for memory efficiency.
        loaded.enable_model_cpu_offload()
        loaded.enable_attention_slicing()
    except Exception as e:
        # Report the underlying error, then surface one chained failure.
        print(f"Error loading the model: {e}")
        raise RuntimeError("Failed to initialize the model pipeline.") from e
    # Cache the pipeline only after it is fully configured. Caching it before
    # the offload/slicing calls would let a failure there leave a
    # half-configured pipeline that later calls return without retrying.
    pipeline = loaded
    print("Pipeline initialized and cached.")
    return pipeline
# Function for image generation, decorated to use GPU
@spaces.GPU(duration=65)
def generate_image(prompt):
    """Generate one image for ``prompt`` using the cached SD3.5 pipeline.

    The ``spaces.GPU`` decorator (duration=65) requests GPU access for each call.

    Args:
        prompt: Text prompt entered in the Gradio textbox.

    Returns:
        The first image of the pipeline output (``.images[0]``).

    Raises:
        RuntimeError: If the pipeline cannot be initialized, or if generation
            fails. Generation errors are chained as ``__cause__``.
    """
    # Loads the model on the first call only; later calls reuse the cache.
    pipe = initialize_pipeline()
    try:
        image = pipe(prompt).images[0]
    except Exception as e:
        # Catch errors during image generation (e.g., GPU/memory errors) and
        # chain them so the original traceback is reported as the direct cause.
        print(f"Error during image generation: {e}")
        raise RuntimeError("Image generation failed.") from e
    return image
# Device/dtype selection for CUDA or CPU. These names are kept for
# compatibility; the startup code below no longer uses them (see the note on
# device placement).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load and validate the model before serving, so a broken model fails at
# startup rather than on the first request. This must come BEFORE
# interface.launch(): launch() blocks the main thread while the server runs,
# so code placed after it only executes once the server has shut down.
pipe = initialize_pipeline()  # Ensure the model is initialized and cached
if pipe is None or not hasattr(pipe, "transformer"):
    raise ValueError("Failed to load the model or the transformer component is missing.")
# Intentionally no `pipe.to(device)`: initialize_pipeline() enabled model CPU
# offload, which handles device placement itself. Moving the whole pipeline
# to the GPU afterwards defeats the offloading.

# Set up Gradio interface with a simple input for text and output for image
interface = gr.Interface(fn=generate_image, inputs="text", outputs="image")
# Launch the interface; this call blocks while the server is running.
interface.launch()