File size: 5,703 Bytes
05773fe
ac20447
 
 
 
 
f36a3e1
 
ac20447
 
 
 
 
05773fe
 
ac20447
 
 
f36a3e1
ac20447
f36a3e1
 
 
 
 
 
 
 
ac20447
 
 
f36a3e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac20447
 
 
 
05773fe
 
 
ac20447
 
b754394
05773fe
f36a3e1
 
ac20447
f36a3e1
 
ac20447
 
 
 
 
05773fe
f36a3e1
05773fe
 
 
 
f36a3e1
 
05773fe
f36a3e1
05773fe
 
 
 
 
f36a3e1
 
05773fe
f36a3e1
05773fe
f36a3e1
 
 
05773fe
f36a3e1
 
 
 
 
 
 
 
05773fe
 
 
 
 
 
 
 
f36a3e1
 
ac20447
f36a3e1
ac20447
f36a3e1
 
ac20447
 
 
 
f36a3e1
 
ac20447
 
05773fe
 
 
 
ac20447
 
 
 
 
05773fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from typing import Dict, List, Any
import base64
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import torch
import numpy as np
import cv2
import controlnet_hinter

# Select the compute device; this module only supports CUDA and refuses to load otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("Need to run on GPU")

# Mixed-precision dtype: bfloat16 is available from compute capability 8.0
# (Ampere) onward, including newer architectures (9.x, 10.x, ...); older GPUs
# fall back to float16. The previous `== 8` check wrongly excluded GPUs newer
# than Ampere.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

# Registry of supported control types. Each entry pairs the pretrained ControlNet
# checkpoint id (passed to ControlNetModel.from_pretrained) with the
# controlnet_hinter function that turns an input image into that control type's
# conditioning image.
CONTROLNET_MAPPING = {
    control_name: {"model_id": model_id, "hinter": hint_fn}
    for control_name, model_id, hint_fn in (
        ("canny_edge", "lllyasviel/sd-controlnet-canny", controlnet_hinter.hint_canny),
        ("pose", "lllyasviel/sd-controlnet-openpose", controlnet_hinter.hint_openpose),
        ("depth", "lllyasviel/sd-controlnet-depth", controlnet_hinter.hint_depth),
        ("scribble", "lllyasviel/sd-controlnet-scribble", controlnet_hinter.hint_scribble),
        ("segmentation", "lllyasviel/sd-controlnet-seg", controlnet_hinter.hint_segmentation),
        ("normal", "lllyasviel/sd-controlnet-normal", controlnet_hinter.hint_normal),
        ("hed", "lllyasviel/sd-controlnet-hed", controlnet_hinter.hint_hed),
        # The "hough" key is served by the mlsd checkpoint.
        ("hough", "lllyasviel/sd-controlnet-mlsd", controlnet_hinter.hint_hough),
    )
}

class EndpointHandler:
    """
    Run Stable Diffusion v1.5 conditioned on a swappable ControlNet.

    The handler starts with the "depth" ControlNet; any request may switch to
    another key of ``CONTROLNET_MAPPING`` through its ``controlnet_type`` field,
    and the new ControlNet stays active for later requests. Each call writes the
    generated image to ``output.jpg`` in the current working directory.
    """

    def __init__(self, path=""):
        """
        Load the default ControlNet and the Stable Diffusion pipeline onto ``device``.

        :param path: Not used by this handler; accepted so the constructor keeps
            its ``EndpointHandler(path)`` signature.
        """
        self.control_type = "depth"
        self.controlnet = self._load_controlnet(self.control_type)

        # Load the StableDiffusionControlNetPipeline (safety checker explicitly disabled).
        self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.stable_diffusion_id,
            controlnet=self.controlnet,
            torch_dtype=dtype,
            safety_checker=None,
        ).to(device)
        # One seeded CPU generator shared by every call: a fixed sequence of calls
        # is reproducible, but each call advances its state, so consecutive calls
        # with the same payload produce different images.
        self.generator = torch.Generator(device="cpu").manual_seed(3)

    @staticmethod
    def _load_controlnet(control_type):
        """Load the ControlNet checkpoint registered for ``control_type`` in ``dtype`` on ``device``."""
        return ControlNetModel.from_pretrained(
            CONTROLNET_MAPPING[control_type]["model_id"], torch_dtype=dtype
        ).to(device)

    def _load_image(self, image_path, image_b64):
        """
        Return the conditioning image as an RGB PIL image.

        ``image_path`` takes precedence over ``image_b64`` when both are given.
        """
        if image_path is not None:
            # The context manager closes the file once convert() has read the pixels.
            with Image.open(image_path) as img:
                return img.convert("RGB")
        # Normalise to RGB so RGBA / palette / grayscale inputs reach the hinter
        # as 3-channel images.
        return self.decode_base64_image(image_b64).convert("RGB")

    def __call__(self, data: Any) -> None:
        """
        Generate one image from a prompt plus a conditioning image and save it as ``output.jpg``.

        ``data`` is only read, never modified, so the same payload can be reused.

        :param data: Request mapping with keys:
            - ``inputs`` (required): the text prompt.
            - ``image_path`` or ``image``: path to the conditioning image, or the
              image as a base64 string; ``image_path`` wins if both are present.
            - ``controlnet_type`` (optional): a key of ``CONTROLNET_MAPPING``.
            - ``num_inference_steps`` (default 30), ``guidance_scale`` (default 7.5),
              ``negative_prompt``, ``height``, ``width`` (default None) and
              ``controlnet_conditioning_scale`` (default 1.0): forwarded to the pipeline.
        :return: None
        :raises ValueError: if the prompt or the conditioning image is missing,
            or ``controlnet_type`` is not a key of ``CONTROLNET_MAPPING``.
        """
        prompt = data.get("inputs")
        image_path = data.get("image_path")
        image_b64 = data.get("image")
        controlnet_type = data.get("controlnet_type")

        # Both a prompt AND an image source are required, as the message states.
        if prompt is None or (image_path is None and not image_b64):
            raise ValueError("Please provide a prompt and either an image path or a base64-encoded image.")

        # Switch ControlNet only when a different, known type is requested.
        if controlnet_type is not None and controlnet_type != self.control_type:
            if controlnet_type not in CONTROLNET_MAPPING:
                raise ValueError(
                    f"Unknown controlnet_type {controlnet_type!r}; "
                    f"expected one of {sorted(CONTROLNET_MAPPING)}."
                )
            print(f"Changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
            # Load first and update state afterwards, so a failed load leaves
            # control_type, controlnet and pipe.controlnet consistent.
            new_controlnet = self._load_controlnet(controlnet_type)
            self.controlnet = new_controlnet
            self.pipe.controlnet = new_controlnet
            self.control_type = controlnet_type

        image = self._load_image(image_path, image_b64)
        control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)

        # Run the inference pipeline.
        out = self.pipe(
            prompt=prompt,
            negative_prompt=data.get("negative_prompt"),
            image=control_image,
            num_inference_steps=data.get("num_inference_steps", 30),
            guidance_scale=data.get("guidance_scale", 7.5),
            num_images_per_prompt=1,
            height=data.get("height"),
            width=data.get("width"),
            controlnet_conditioning_scale=data.get("controlnet_conditioning_scale", 1.0),
            generator=self.generator,
        )

        # Save the single generated image as a JPEG file.
        out.images[0].save("output.jpg", format="JPEG")

    def decode_base64_image(self, image_string):
        """
        Decode a base64-encoded image into a PIL image.

        :param image_string: base64 text (str or bytes) of an encoded image file.
        :return: the image opened from an in-memory buffer (mode not normalised).
        """
        buffer = BytesIO(base64.b64decode(image_string))
        return Image.open(buffer)

# Example usage. It is guarded so it runs only when this file is executed
# directly: importing the module no longer loads the models and runs inference
# against the placeholder image path below.
if __name__ == "__main__":
    payload = {
        "inputs": "Your prompt here",
        "image_path": "path/to/your/image.jpg",
        "controlnet_type": "depth",
        "num_inference_steps": 30,
        "guidance_scale": 7.5,
        "negative_prompt": None,
        "height": None,
        "width": None,
        "controlnet_conditioning_scale": 1.0,
    }

    handler = EndpointHandler()
    handler(payload)