import logging
from typing import Any, Optional

import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    StableDiffusionXLPipeline,
)

logger = logging.getLogger(__name__)


def load_pipeline(
    model_name: str,
    device: torch.device,
    hf_token: Optional[str] = None,
    vae: Optional[AutoencoderKL] = None,
) -> Any:
    """Load a Stable Diffusion XL pipeline and move it to ``device``.

    Paths ending in ``.safetensors`` (case-insensitive) are loaded with
    ``StableDiffusionXLPipeline.from_single_file``. Anything else is treated
    as a Hub repo id or a diffusers-format directory and loaded with
    ``from_pretrained``. Weights are loaded in ``torch.float16`` with the
    ``lpw_stable_diffusion_xl`` community pipeline and no watermarker.

    Args:
        model_name: Hub repo id, local diffusers directory, or path/URL to a
            single ``.safetensors`` checkpoint.
        device: Device the loaded pipeline is moved to.
        hf_token: Optional Hugging Face access token, forwarded as ``token``
            for gated or private repositories.
        vae: Optional pre-loaded VAE that replaces the model's own. When
            omitted, the model's bundled VAE is loaded.

    Returns:
        The loaded pipeline, already moved to ``device``.

    Raises:
        Exception: Any error raised while loading or moving the pipeline is
            logged with its traceback and re-raised unchanged.
    """
    try:
        loader = (
            StableDiffusionXLPipeline.from_single_file
            if model_name.lower().endswith(".safetensors")
            else StableDiffusionXLPipeline.from_pretrained
        )

        load_kwargs: dict[str, Any] = {
            "torch_dtype": torch.float16,
            # NOTE(review): confirm that from_single_file honours
            # custom_pipeline in the pinned diffusers version.
            "custom_pipeline": "lpw_stable_diffusion_xl",
            "use_safetensors": True,
            "add_watermarker": False,
        }
        # An explicit `vae=None` makes diffusers skip loading the VAE
        # component entirely, leaving the pipeline without a decoder.
        # Only override the VAE when the caller actually supplied one.
        if vae is not None:
            load_kwargs["vae"] = vae
        # Previously hf_token was accepted but silently ignored.
        if hf_token is not None:
            load_kwargs["token"] = hf_token

        pipe = loader(model_name, **load_kwargs)
        pipe.to(device)
        return pipe
    except Exception as e:
        # Boundary handler: record the full traceback, then let the caller
        # decide how to recover.
        logger.exception("Failed to load pipeline: %s", e)
        raise