import gc
import json
import logging
import os
import random
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, Dict, Optional, Tuple, Any, List

import numpy as np
import torch
from PIL import Image, PngImagePlugin
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    AutoencoderKL,
    StableDiffusionXLPipeline,
)


def load_pipeline(
    model_name: str,
    device: torch.device,
    hf_token: Optional[str] = None,
    vae: Optional[AutoencoderKL] = None,
) -> Any:
    """Load a Stable Diffusion XL pipeline and move it to ``device``.

    A ``model_name`` ending in ``.safetensors`` is loaded with
    ``StableDiffusionXLPipeline.from_single_file``; anything else goes through
    ``StableDiffusionXLPipeline.from_pretrained``. The pipeline is loaded in
    float16, with the ``lpw_stable_diffusion_xl`` custom pipeline and the
    watermarker disabled.

    Args:
        model_name: Path to a single ``.safetensors`` checkpoint, or a model
            id / directory accepted by ``from_pretrained``.
        device: Device the pipeline is moved to after loading.
        hf_token: Optional Hugging Face access token. When given, it is
            forwarded to the loader as ``token``.
        vae: Optional VAE that replaces the model's own. When ``None``, no
            ``vae`` argument is passed and the loader uses its default.

    Returns:
        The loaded pipeline, already moved to ``device``.

    Raises:
        Exception: Any error raised while loading is logged with its
            traceback and re-raised unchanged.
    """
    try:
        loader = (
            StableDiffusionXLPipeline.from_single_file
            if model_name.endswith(".safetensors")
            else StableDiffusionXLPipeline.from_pretrained
        )
        load_kwargs: Dict[str, Any] = {
            "torch_dtype": torch.float16,
            "custom_pipeline": "lpw_stable_diffusion_xl",
            "use_safetensors": True,
            "add_watermarker": False,
        }
        # Pass optional component/credential overrides only when supplied:
        # diffusers treats an explicit component kwarg of ``None`` as
        # "use no component" rather than "load the default one".
        if vae is not None:
            load_kwargs["vae"] = vae
        # Previously hf_token was accepted but silently ignored.
        if hf_token is not None:
            load_kwargs["token"] = hf_token

        pipe = loader(model_name, **load_kwargs)
        pipe.to(device)
        return pipe
    except Exception as e:
        logging.error(f"Failed to load pipeline: {str(e)}", exc_info=True)
        raise