File size: 1,285 Bytes
6aefd85
 
 
 
 
781a759
6aefd85
 
 
 
 
781a759
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import gc
import os
import random
import numpy as np
import json
import torch
import uuid
from PIL import Image, PngImagePlugin
from datetime import datetime
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple, Any, List
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    AutoencoderKL,
    StableDiffusionXLPipeline,
)
import logging

def load_pipeline(model_name: str, device: torch.device, hf_token: Optional[str] = None, vae: Optional[AutoencoderKL] = None) -> Any:
    """Load the Stable Diffusion XL pipeline and move it to ``device``.

    A ``model_name`` ending in ``.safetensors`` is loaded with
    ``StableDiffusionXLPipeline.from_single_file``; anything else is loaded
    with ``StableDiffusionXLPipeline.from_pretrained``. The pipeline is always
    requested in float16, with the ``lpw_stable_diffusion_xl`` custom pipeline,
    safetensors weights, and the watermarker disabled.

    Args:
        model_name: Path/URL of a single ``.safetensors`` checkpoint, or a
            model id / directory accepted by ``from_pretrained``.
        device: Device the loaded pipeline is moved to.
        hf_token: Optional Hugging Face access token, forwarded to the loader
            as ``token`` (e.g. for gated or private repositories).
        vae: Optional VAE that replaces the model's own. When ``None`` the
            model's bundled VAE is loaded.

    Returns:
        The loaded pipeline, already moved to ``device``.

    Raises:
        Exception: Any error raised while loading or moving the pipeline is
            logged with its traceback and re-raised unchanged.
    """
    try:
        loader = (
            StableDiffusionXLPipeline.from_single_file
            if model_name.endswith(".safetensors")
            else StableDiffusionXLPipeline.from_pretrained
        )

        load_kwargs: Dict[str, Any] = {
            "torch_dtype": torch.float16,
            "custom_pipeline": "lpw_stable_diffusion_xl",
            "use_safetensors": True,
            "add_watermarker": False,
        }
        # diffusers interprets an explicitly passed ``vae=None`` as "skip this
        # component", which would leave the pipeline without a VAE. Only
        # override the VAE when the caller actually provided one.
        if vae is not None:
            load_kwargs["vae"] = vae
        # ``hf_token`` was previously accepted but ignored; forward it so
        # authenticated downloads work. Omitted when absent so the loader's
        # own default (e.g. a cached login) still applies.
        if hf_token is not None:
            load_kwargs["token"] = hf_token

        pipe = loader(model_name, **load_kwargs)
        pipe.to(device)
        return pipe
    except Exception as e:
        logging.error(f"Failed to load pipeline: {str(e)}", exc_info=True)
        raise