# Spaces: Running on Zero
import gc | |
import os | |
import random | |
import numpy as np | |
import json | |
import torch | |
import uuid | |
from PIL import Image, PngImagePlugin | |
from datetime import datetime | |
from dataclasses import dataclass | |
from typing import Callable, Dict, Optional, Tuple, Any, List | |
from diffusers import ( | |
DDIMScheduler, | |
DPMSolverMultistepScheduler, | |
DPMSolverSinglestepScheduler, | |
EulerAncestralDiscreteScheduler, | |
EulerDiscreteScheduler, | |
AutoencoderKL, | |
StableDiffusionXLPipeline, | |
) | |
import logging | |
def load_pipeline(
    model_name: str,
    device: torch.device,
    hf_token: Optional[str] = None,
    vae: Optional[AutoencoderKL] = None,
) -> Any:
    """Load a Stable Diffusion XL pipeline and move it to ``device``.

    A ``model_name`` ending in ``.safetensors`` is loaded as a single
    checkpoint file via ``StableDiffusionXLPipeline.from_single_file``;
    anything else goes through ``StableDiffusionXLPipeline.from_pretrained``.
    Weights are requested in ``torch.float16``, with the
    ``lpw_stable_diffusion_xl`` custom pipeline and ``add_watermarker=False``.

    Args:
        model_name: A ``.safetensors`` checkpoint path/URL, or a model id /
            directory accepted by ``from_pretrained``.
        device: Device the pipeline is moved to after loading.
        hf_token: Optional access token. It is forwarded to the loader as
            ``token`` only when provided.
        vae: Optional pre-loaded VAE to use in place of the checkpoint's own.
            When ``None``, the checkpoint's VAE is loaded as usual.

    Returns:
        The loaded pipeline, already moved to ``device``.

    Raises:
        Exception: Any error raised while selecting the loader, loading, or
            moving the pipeline is logged with its traceback and re-raised
            unchanged.
    """
    try:
        loader = (
            StableDiffusionXLPipeline.from_single_file
            if model_name.endswith(".safetensors")
            else StableDiffusionXLPipeline.from_pretrained
        )
        # NOTE(review): confirm that from_single_file honours
        # `custom_pipeline` in the pinned diffusers version.
        load_kwargs: Dict[str, Any] = {
            "torch_dtype": torch.float16,
            "custom_pipeline": "lpw_stable_diffusion_xl",
            "use_safetensors": True,
            "add_watermarker": False,
        }
        # diffusers interprets an explicitly passed `None` component as
        # "skip loading this component". Always passing `vae=vae` therefore
        # produced a pipeline with no VAE whenever the caller omitted it.
        # Only override the VAE when one was actually supplied.
        if vae is not None:
            load_kwargs["vae"] = vae
        # This parameter was previously accepted but ignored. Forward it so
        # gated or private checkpoints can be downloaded.
        if hf_token is not None:
            load_kwargs["token"] = hf_token
        pipe = loader(model_name, **load_kwargs)
        pipe.to(device)
        return pipe
    except Exception as e:
        logging.error(f"Failed to load pipeline: {str(e)}", exc_info=True)
        raise