from typing import Any
import base64
from io import BytesIO

from PIL import Image
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
import controlnet_hinter
# set device; this handler requires a CUDA-capable GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise ValueError("This handler needs to run on a GPU")

# set mixed-precision dtype: bfloat16 on Ampere or newer (compute capability >= 8),
# float16 otherwise
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
# control types supported by this handler, each paired with its ControlNet
# checkpoint and the controlnet_hinter function that produces its hint image
CONTROLNET_MAPPING = {
    "depth": {
        "model_id": "lllyasviel/sd-controlnet-depth",
        "hinter": controlnet_hinter.hint_depth,
    },
}
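
# The mapping above can be extended with further control types. A minimal
# sketch (assuming the lllyasviel/sd-controlnet-canny checkpoint and the
# controlnet_hinter.hint_canny helper, neither of which the original handler
# uses):
#
# CONTROLNET_MAPPING["canny"] = {
#     "model_id": "lllyasviel/sd-controlnet-canny",
#     "hinter": controlnet_hinter.hint_canny,
# }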
# short names accepted in the `sd_model` request field, mapped to the
# Stable Diffusion checkpoints they load
SD_ID_MAPPING = {
    "default": "Lykon/dreamshaper-8",
    "dreamshaper": "stablediffusionapi/dreamshaper-xl",
    "juggernaut": "stablediffusionapi/juggernaut-xl-v8",
    "realistic": "SG161222/Realistic_Vision_V1.4",
    "rev": "s6yx/ReV_Animated",
}
class EndpointHandler:
    def __init__(self, path=""):
        self.control_type = "depth"
        self.controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype
        ).to(device)

        # Load one ControlNet pipeline per supported checkpoint, sharing the
        # ControlNet and the safety checker across all of them. Each pipeline
        # is stored under an attribute named after its SD_ID_MAPPING key, so
        # the names match the values accepted in the `sd_model` request field
        # (the original code assigned checkpoints to mismatched attribute
        # names and never defined a `default` attribute).
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", torch_dtype=dtype
        )
        for name, model_id in SD_ID_MAPPING.items():
            pipeline = StableDiffusionControlNetPipeline.from_pretrained(
                model_id,
                controlnet=self.controlnet,
                torch_dtype=dtype,
                safety_checker=safety_checker,
            ).to(device)
            setattr(self, name, pipeline)

        # define generator with a fixed seed for reproducible outputs
        self.generator = torch.Generator(device=device.type).manual_seed(3)
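
        # Keeping five full SD pipelines resident on one GPU is memory-heavy.
        # One possible mitigation (a standard diffusers API, not used by the
        # original handler) is to trade some speed for memory:
        #
        # for name in SD_ID_MAPPING:
        #     getattr(self, name).enable_attention_slicing()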
    def __call__(self, data: Any) -> Any:
        """
        :param data: A dictionary with a required `inputs` (prompt) field, a
            required `image_depth_map` field (base64-encoded depth map), and
            optional `sd_model`, `negative_prompt`, `steps`, `scale`,
            `height`, `width`, and `controlnet_conditioning_scale` fields.
        :return: The first generated image as a `PIL.Image`, or a dictionary
            with an `error` field on invalid input.
        """
        # hyperparameters
        sd_model = data.pop("sd_model", "default")
        prompt = data.pop("inputs", None)
        negative_prompt = data.pop("negative_prompt", None)
        image_depth_map = data.pop("image_depth_map", None)
        steps = data.pop("steps", 25)
        scale = data.pop("scale", 7)
        height = data.pop("height", None)
        width = data.pop("width", None)
        controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)

        # validate input
        if sd_model is None or not hasattr(self, sd_model):
            return {"error": "SD model not specified or not valid"}
        if prompt is None:
            return {"error": "Please provide a prompt"}
        if image_depth_map is None:
            return {"error": "Please provide an image_depth_map"}
        # look up the requested pipeline by name
        pipe = getattr(self, sd_model)

        # decode the base64-encoded depth map
        image = self.decode_base64_image(image_depth_map)

        # run inference on the selected pipeline (the original always ran
        # self.pipe, ignoring the sd_model selection)
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            num_inference_steps=steps,
            guidance_scale=scale,
            num_images_per_prompt=1,
            height=height,
            width=width,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=self.generator,
        )

        # return the first generated PIL image
        return out.images[0]
    # helper to decode a base64-encoded input image
    def decode_base64_image(self, image_string):
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        # convert to RGB so the pipeline always receives a three-channel image
        return image.convert("RGB")
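
# Minimal local smoke test, not part of the Inference Endpoints contract: it
# base64-encodes a depth-map file and calls the handler directly. The file
# names "depth_map.png" and "output.png" are hypothetical examples.
if __name__ == "__main__":
    with open("depth_map.png", "rb") as f:
        depth_b64 = base64.b64encode(f.read()).decode("utf-8")
    handler = EndpointHandler()
    result = handler({
        "inputs": "a cozy wooden cabin in a snowy forest, highly detailed",
        "image_depth_map": depth_b64,
        "sd_model": "default",
    })
    if isinstance(result, dict):  # error path returns a dict
        print(result["error"])
    else:
        result.save("output.png")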