|
import gradio as gr |
|
import torch |
|
from diffusers import StableDiffusion3Pipeline |
|
import os |
|
import spaces |
|
|
|
|
|
# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
hf_token = os.environ.get("HF_TOKEN")

# Hub repository id of the model this app serves.
model_id = "stabilityai/stable-diffusion-3.5-large"

# Module-level cache for the loaded pipeline; stays None until
# initialize_pipeline() populates it.
pipeline = None
|
|
|
|
|
def initialize_pipeline():
    """Load the Stable Diffusion 3.5 pipeline once and return the cached instance.

    On the first call, ``model_id`` is loaded in float16 (authenticated with
    ``hf_token``). Model CPU offload and attention slicing are then enabled,
    and the result is stored in the module-level ``pipeline``. Subsequent
    calls return that cached object without reloading.

    Returns:
        The configured ``StableDiffusion3Pipeline``.

    Raises:
        RuntimeError: if loading or configuring the pipeline fails. The
            original exception is chained as ``__cause__``, and the cache
            is left empty so a later call can retry.
    """
    global pipeline

    if pipeline is not None:
        return pipeline

    try:
        # `token` replaces the deprecated `use_auth_token` keyword.
        pipe = StableDiffusion3Pipeline.from_pretrained(
            model_id,
            token=hf_token,
            torch_dtype=torch.float16,
        )
        # Per the diffusers docs, model CPU offload keeps sub-models on the CPU
        # and moves each to the GPU only while it runs, trading some speed for
        # lower VRAM use.
        pipe.enable_model_cpu_offload()
        pipe.enable_attention_slicing()
    except Exception as e:
        print(f"Error loading the model: {e}")
        raise RuntimeError("Failed to initialize the model pipeline.") from e

    # Publish to the cache only after configuration fully succeeded, so a
    # failure above cannot leave a half-configured pipeline cached for later calls.
    pipeline = pipe
    print("Pipeline initialized and cached.")
    return pipeline
|
|
|
|
|
@spaces.GPU(duration=65)
def generate_image(prompt):
    """Generate a single image for ``prompt`` using the cached SD3.5 pipeline.

    The ``spaces.GPU(duration=65)`` decorator requests a GPU for the call on
    Hugging Face Spaces. The duration is presumably in seconds; confirm
    against the ``spaces`` package docs.

    Args:
        prompt: Text prompt coming from the Gradio text input.

    Returns:
        The first image of the pipeline output (``.images[0]``).

    Raises:
        RuntimeError: if the pipeline cannot be initialized or generation
            fails. For generation failures the underlying exception is
            chained as ``__cause__``.
    """
    pipe = initialize_pipeline()

    try:
        image = pipe(prompt).images[0]
    except Exception as e:
        print(f"Error during image generation: {e}")
        # Chain the cause so the real failure stays visible in tracebacks.
        raise RuntimeError("Image generation failed.") from e

    return image
|
|
|
|
|
# Gradio UI: one text box in, one generated image out.
interface = gr.Interface(fn=generate_image, inputs="text", outputs="image")

# Load (and cache) the pipeline before the server starts, so load or
# configuration errors surface at startup instead of on the first request.
# NOTE(review): on ZeroGPU Spaces, confirm that loading outside a
# @spaces.GPU-decorated call is acceptable for this setup.
pipe = initialize_pipeline()

# Sanity-check that the loaded pipeline exposes a `transformer` component.
if pipe is None or not hasattr(pipe, "transformer"):
    raise ValueError("Failed to load the model or the transformer component is missing.")

# The pipeline is intentionally NOT moved with `pipe.to(device)`:
# initialize_pipeline() enables model CPU offload, which places sub-models on
# the accelerator on demand. Moving an offloaded pipeline manually defeats it.

if __name__ == "__main__":
    # launch() blocks until the server stops, so it must come last; anything
    # placed after it would only run at shutdown.
    interface.launch()