from typing import Any, Dict, List, Union
import base64
from io import BytesIO

from PIL import Image
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
import torch
import numpy as np
import cv2
import controlnet_hinter

# set device; this handler refuses to start without a CUDA device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("need to run on GPU")

# set mixed precision dtype: bfloat16 on compute-capability 8.x GPUs, float16 otherwise.
# NOTE(review): only major version 8 selects bfloat16, so newer GPUs (9.x+) fall back
# to float16 — confirm this is intended.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16

# Base checkpoint used when `sd_model` is missing or not a key of SD_ID_MAPPING.
DEFAULT_SD_ID = "Lykon/dreamshaper-8"
# Conditioning image used when a request carries no `image_depth_map`
# (relative to the process working directory).
DEFAULT_IMAGE_PATH = "./default.jpg"
SAFETY_CHECKER_ID = "CompVis/stable-diffusion-safety-checker"

# NOTE(review): the "hinter" entry is not used by EndpointHandler; the request image
# is passed to the pipeline unchanged (presumably it is already a depth map — confirm).
CONTROLNET_MAPPING = {
    "depth": {
        "model_id": "lllyasviel/sd-controlnet-depth",
        "hinter": controlnet_hinter.hint_depth,
    },
}

# NOTE(review): the "dreamshaper" and "juggernaut" entries point at repos whose names
# end in "-xl". If these are SDXL checkpoints they are unlikely to work with
# StableDiffusionControlNetPipeline and the SD1.5 depth ControlNet above
# (an SDXL pipeline/ControlNet would be needed) — verify before relying on them.
SD_ID_MAPPING = {
    "dreamshaper": "stablediffusionapi/dreamshaper-xl",
    "juggernaut": "stablediffusionapi/juggernaut-xl-v8",
    "realistic-vision": "SG161222/Realistic_Vision_V1.4",
    "rev-animated": "s6yx/ReV_Animated",
}


class EndpointHandler():
    """Render one image from a text prompt, conditioned by a depth ControlNet.

    The ControlNet and the safety checker are loaded once at construction.
    The Stable Diffusion base pipeline is loaded on first use and reused for
    as long as consecutive requests ask for the same base model.
    """

    def __init__(self, path=""):
        """Load the depth ControlNet and safety checker.

        :param path: Accepted for the endpoint-handler interface; not used here.
        """
        self.control_type = "depth"
        self.controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype
        ).to(device)
        # Loaded once and shared by every base pipeline instead of on every request.
        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            SAFETY_CHECKER_ID, torch_dtype=dtype
        )
        # Cached base pipeline and the checkpoint id it was loaded from.
        self.stable_diffusion_id = None
        self.pipe = None
        # Shared generator used when a request does not supply its own `seed`.
        self.generator = torch.Generator(device=device.type).manual_seed(3)

    def __call__(self, data: Dict[str, Any]) -> Union[Image.Image, Dict[str, str]]:
        """Generate a single image for the request in `data`.

        :param data: Request dictionary; recognised keys are popped from it:
            - `inputs` (required): text prompt.
            - `sd_model`: key of SD_ID_MAPPING; missing/unknown keys use DEFAULT_SD_ID.
            - `image_depth_map`: base64-encoded conditioning image; when absent the
              image at DEFAULT_IMAGE_PATH is used.
            - `negative_prompt`, `height`, `width` (default None),
              `num_inference_steps` (default 25), `guidance_scale` (default 7),
              `controlnet_conditioning_scale` (default 1.0).
            - `seed`: optional int. When given, a fresh generator seeded with it is
              used so the result does not depend on earlier requests; otherwise the
              handler's shared generator is advanced.
        :return: The first generated PIL image, or ``{"error": ...}`` when no prompt
            is supplied.
        """
        # hyperparameters
        sd_model = data.pop("sd_model", None)
        prompt = data.pop("inputs", None)
        negative_prompt = data.pop("negative_prompt", None)
        image_depth_map = data.pop("image_depth_map", None)
        num_inference_steps = data.pop("num_inference_steps", 25)
        guidance_scale = data.pop("guidance_scale", 7)
        height = data.pop("height", None)
        width = data.pop("width", None)
        controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
        seed = data.pop("seed", None)

        # Validate before any expensive model loading.
        if prompt is None:
            return {"error": "Please provide a prompt"}

        model_id = SD_ID_MAPPING.get(sd_model, DEFAULT_SD_ID)
        print(f"Using stable diffusion model: {model_id}")
        pipe = self._get_pipeline(model_id)

        # Both branches bind `image`; previously a supplied depth map left it
        # unbound and raised UnboundLocalError.
        if image_depth_map is None:
            image = self._load_default_image()
        else:
            image = self.decode_base64_image(image_depth_map)

        if seed is None:
            generator = self.generator
        else:
            generator = torch.Generator(device=device.type).manual_seed(int(seed))

        # run inference pipeline
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            height=height,
            width=width,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=generator,
        )

        # return first generated PIL image
        return out.images[0]

    def _get_pipeline(self, model_id):
        """Return the ControlNet pipeline for `model_id`, loading it only when it changes."""
        if self.pipe is not None and self.stable_diffusion_id == model_id:
            return self.pipe
        # Release the previous base model before loading another one so two full
        # pipelines are not held at once.
        self.pipe = None
        torch.cuda.empty_cache()
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_id,
            controlnet=self.controlnet,
            torch_dtype=dtype,
            safety_checker=self.safety_checker,
        ).to(device)
        self.stable_diffusion_id = model_id
        return self.pipe

    @staticmethod
    def _load_default_image():
        """Read DEFAULT_IMAGE_PATH fully into memory and return it as a PIL image."""
        with open(DEFAULT_IMAGE_PATH, "rb") as image_file:
            return Image.open(BytesIO(image_file.read()))

    # helper to decode input image
    def decode_base64_image(self, image_string):
        """Decode a base64 string (or bytes) into a PIL image."""
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image