from typing import Any, Dict, List
import base64
from io import BytesIO

import torch
from PIL import Image
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker

# This handler only supports GPU inference; fail fast at import time otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise ValueError("need to run on GPU")

# Mixed-precision dtype: bfloat16 is supported on compute capability 8.x
# (Ampere) and newer (e.g. 9.x Hopper); older GPUs fall back to float16.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

# Output size used when the request does not specify one. The prior pipeline
# derives its latent shape from height/width, so they must not be None.
DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024


class EndpointHandler:
    """Text-to-image inference handler built on Stable Cascade.

    Generation runs in two stages: the prior pipeline turns the prompt into
    image embeddings, then the decoder pipeline turns those embeddings into
    a PIL image.
    """

    def __init__(self, path: str = ""):
        """Load the Stable Cascade prior and decoder pipelines.

        Args:
            path: Accepted for handler-interface compatibility; not used, the
                models are loaded from the hub IDs below.
        """
        # NOTE(review): not used by this handler; kept in case external code
        # reads the attribute.
        self.stable_diffusion_id = "Lykon/dreamshaper-8"
        # The pipelines are not moved to the GPU here; __call__ does that.
        self.prior_pipeline = StableCascadePriorPipeline.from_pretrained(
            "stabilityai/stable-cascade-prior", torch_dtype=dtype
        )
        self.decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
            "stabilityai/stable-cascade", torch_dtype=dtype
        )
        # Shared, fixed-seed RNG: output is reproducible for the same sequence
        # of requests, but each request advances this generator's state.
        self.generator = torch.Generator(device=device.type).manual_seed(3)

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """Generate one image from a text prompt.

        Args:
            data: Request payload (read, not modified). Recognised keys:
                ``inputs`` (required): the text prompt.
                ``negative_prompt``: text to steer away from (default None).
                ``num_inference_steps``: prior denoising steps (default 30).
                ``guidance_scale``: prior guidance scale (default 7.4).
                ``decoder_num_inference_steps``: decoder steps (defaults to
                    ``num_inference_steps``).
                ``decoder_guidance_scale``: decoder guidance scale (defaults
                    to ``guidance_scale``).
                ``height`` / ``width``: output size in pixels (default 1024).
                ``seed``: if given, this request uses a fresh generator seeded
                    with it instead of the handler's shared generator.

        Returns:
            The generated image as a PIL image.

        Raises:
            ValueError: if the ``inputs`` prompt is missing.
        """
        prompt = data.get("inputs")
        if prompt is None:
            raise ValueError("request is missing the `inputs` prompt")
        negative_prompt = data.get("negative_prompt")
        num_inference_steps = data.get("num_inference_steps", 30)
        guidance_scale = data.get("guidance_scale", 7.4)
        # Decoder settings default to the prior's so existing clients see the
        # same behaviour; diffusers documents 10 steps / guidance 0.0 for it.
        decoder_num_inference_steps = data.get(
            "decoder_num_inference_steps", num_inference_steps
        )
        decoder_guidance_scale = data.get("decoder_guidance_scale", guidance_scale)

        height = data.get("height")
        width = data.get("width")
        height = DEFAULT_HEIGHT if height is None else height
        width = DEFAULT_WIDTH if width is None else width

        seed = data.get("seed")
        generator = (
            self.generator
            if seed is None
            else torch.Generator(device=device.type).manual_seed(int(seed))
        )

        # Lazily place both pipelines on the GPU; repeated calls find the
        # modules already there.
        self.prior_pipeline.to(device)
        self.decoder_pipeline.to(device)

        prior_output = self.prior_pipeline(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            generator=generator,
        )
        images = self.decoder_pipeline(
            image_embeddings=prior_output.image_embeddings,
            prompt=prompt,
            num_inference_steps=decoder_num_inference_steps,
            guidance_scale=decoder_guidance_scale,
            negative_prompt=negative_prompt,
            generator=generator,
            output_type="pil",
        ).images
        # One image per prompt was requested from the prior, so return it.
        return images[0]