from typing import Dict, Any
import base64
from io import BytesIO

import torch
from PIL import Image
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import controlnet_hinter
# Set device; ControlNet inference requires a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise ValueError("Need to run on GPU")

# Set mixed-precision dtype: bfloat16 on compute capability >= 8 (Ampere or newer), float16 otherwise.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
# Map each supported controlnet type to its pretrained model id and hint function.
CONTROLNET_MAPPING = {
    "canny_edge": {
        "model_id": "lllyasviel/sd-controlnet-canny",
        "hinter": controlnet_hinter.hint_canny,
    },
    "pose": {
        "model_id": "lllyasviel/sd-controlnet-openpose",
        "hinter": controlnet_hinter.hint_openpose,
    },
    "depth": {
        "model_id": "lllyasviel/sd-controlnet-depth",
        "hinter": controlnet_hinter.hint_depth,
    },
    "scribble": {
        "model_id": "lllyasviel/sd-controlnet-scribble",
        "hinter": controlnet_hinter.hint_scribble,
    },
    "segmentation": {
        "model_id": "lllyasviel/sd-controlnet-seg",
        "hinter": controlnet_hinter.hint_segmentation,
    },
    "normal": {
        "model_id": "lllyasviel/sd-controlnet-normal",
        "hinter": controlnet_hinter.hint_normal,
    },
    "hed": {
        "model_id": "lllyasviel/sd-controlnet-hed",
        "hinter": controlnet_hinter.hint_hed,
    },
    "hough": {
        "model_id": "lllyasviel/sd-controlnet-mlsd",
        "hinter": controlnet_hinter.hint_hough,
    },
}
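
# Each hinter converts a source image into the control image its matching
# checkpoint expects, e.g. CONTROLNET_MAPPING["canny_edge"]["hinter"](img)
# returns a canny edge map of `img`.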
class EndpointHandler:
    """
    Handle inference requests for a Stable Diffusion + ControlNet endpoint.
    """

    def __init__(self, path=""):
        # Load the default controlnet (depth).
        self.control_type = "depth"
        self.controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype
        ).to(device)

        # Load the StableDiffusionControlNetPipeline around the controlnet.
        self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.stable_diffusion_id,
            controlnet=self.controlnet,
            torch_dtype=dtype,
            safety_checker=None,
        ).to(device)

        # Seeded generator for reproducible outputs.
        self.generator = torch.Generator(device="cpu").manual_seed(3)
    def __call__(self, data: Dict[str, Any]) -> None:
        """
        Process input data and run inference.

        :param data: A dictionary with a required `inputs` prompt, an optional
            `controlnet_type`, and either an `image_path` or a base64-encoded
            `image` field, plus optional generation hyperparameters.
        :return: None. The generated image is saved to `output.jpg`.
        """
        prompt = data.pop("inputs", None)
        image_path = data.pop("image_path", None)
        controlnet_type = data.pop("controlnet_type", None)

        # A prompt and one image source (path or base64) are both required.
        if prompt is None or (image_path is None and not data.get("image")):
            raise ValueError("Please provide a prompt and either an image path or a base64-encoded image.")
        # Swap in a different controlnet if the request asks for one.
        if controlnet_type is not None and controlnet_type != self.control_type:
            print(f"Changing controlnet from {self.control_type} to {controlnet_type} "
                  f"using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
            self.control_type = controlnet_type
            self.controlnet = ControlNetModel.from_pretrained(
                CONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype
            ).to(device)
            self.pipe.controlnet = self.controlnet
        # Optional generation hyperparameters with defaults.
        num_inference_steps = data.pop("num_inference_steps", 30)
        guidance_scale = data.pop("guidance_scale", 7.5)
        negative_prompt = data.pop("negative_prompt", None)
        height = data.pop("height", None)
        width = data.pop("width", None)
        controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
        # Load the source image, then derive the control image with the active hinter.
        if image_path is not None:
            # Load the image from the specified path.
            image = Image.open(image_path)
        else:
            # Decode the base64-encoded image.
            image = self.decode_base64_image(data.pop("image", ""))
        control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
# run inference pipeline
out = self.pipe(
prompt=prompt,
negative_prompt=negative_prompt,
image=control_image,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=1,
height=height,
width=width,
controlnet_conditioning_scale=controlnet_conditioning_scale,
generator=self.generator
)
# save the generated image as a JPEG file
output_image = out.images[0]
output_image.save("output.jpg", format="JPEG")
    def decode_base64_image(self, image_string):
        """Decode a base64 string into a PIL image."""
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image
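
# A minimal client-side sketch of producing the base64 `image` field that
# decode_base64_image expects; encode_base64_image is an illustrative helper,
# not part of the handler itself.
def encode_base64_image(image_path: str) -> str:
    """Read an image file and return its contents as a base64 string."""
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")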
# Example usage; guarded so the handler is not run when this module is imported.
if __name__ == "__main__":
    payload = {
        "inputs": "Your prompt here",
        "image_path": "path/to/your/image.jpg",
        "controlnet_type": "depth",
        "num_inference_steps": 30,
        "guidance_scale": 7.5,
        "negative_prompt": None,
        "height": None,
        "width": None,
        "controlnet_conditioning_scale": 1.0,
    }
    handler = EndpointHandler()
    handler(payload)
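
    # A sketch of the same call with a base64-encoded image instead of a file
    # path, using the illustrative encode_base64_image helper defined above.
    base64_payload = {
        "inputs": "Your prompt here",
        "image": encode_base64_image("path/to/your/image.jpg"),
        "controlnet_type": "canny_edge",
    }
    handler(base64_payload)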