from typing import Any, Dict

import torch
from PIL import Image

from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
# Set device; this handler is GPU-only.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("need to run on GPU")

# Set mixed-precision dtype: bfloat16 on compute capability 8.x GPUs, float16 otherwise.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16

class EndpointHandler:
    def __init__(self, path: str = ""):
        self.stable_diffusion_id = "Lykon/dreamshaper-8"
        self.pipe = StableDiffusionPipeline.from_pretrained(
            self.stable_diffusion_id,
            torch_dtype=dtype,
            safety_checker=StableDiffusionSafetyChecker.from_pretrained(
                "CompVis/stable-diffusion-safety-checker", torch_dtype=dtype
            ),
        ).to(device.type)
        # Fixed seed so repeated requests produce reproducible results.
        self.generator = torch.Generator(device=device.type).manual_seed(3)
    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """
        :param data: A dictionary containing `inputs` (the prompt) and optional generation parameters.
        :return: The generated PIL image.
        """
        prompt = data.pop("inputs", None)
        num_inference_steps = data.pop("num_inference_steps", 30)
        guidance_scale = data.pop("guidance_scale", 7.4)
        negative_prompt = data.pop("negative_prompt", None)
        height = data.pop("height", None)
        width = data.pop("width", None)
        # Run the inference pipeline.
        out = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            height=height,
            width=width,
            generator=self.generator,
        )

        # Return the first generated PIL image.
        return out.images[0]
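
# --- Usage sketch (not part of the original handler) -------------------------
# A minimal local smoke test, assuming this file is saved as handler.py and a
# CUDA GPU is available so the models above can be loaded. On Hugging Face
# Inference Endpoints the toolkit constructs EndpointHandler and calls it with
# the parsed request payload; this block only mimics that flow for debugging.
# The prompt text and parameter values below are illustrative placeholders.
if __name__ == "__main__":
    handler = EndpointHandler()
    image = handler({
        "inputs": "a watercolor painting of a lighthouse at dawn",
        "negative_prompt": "blurry, low quality",
        "num_inference_steps": 25,
        "height": 512,
        "width": 512,
    })
    # `image` is a PIL.Image.Image; save it to disk to inspect the result.
    image.save("output.png")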