File size: 3,888 Bytes
1f7551f
 
 
 
061d7a5
1f7551f
 
 
 
 
061d7a5
 
 
1f7551f
061d7a5
1f7551f
 
 
 
 
061d7a5
 
 
 
 
 
 
a9ef0f5
 
061d7a5
 
 
 
 
 
a9ef0f5
061d7a5
 
 
 
 
471adc0
 
061d7a5
 
3af03cf
061d7a5
 
 
 
 
 
3af03cf
061d7a5
 
3af03cf
 
061d7a5
 
3af03cf
471adc0
061d7a5
a9ef0f5
061d7a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f7551f
 
061d7a5
1f7551f
 
061d7a5
3af03cf
 
1f7551f
 
 
061d7a5
06e49f5
1f7551f
 
061d7a5
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from typing import  Dict, List, Any
import base64
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker

import torch


import numpy as np
import cv2
import controlnet_hinter

# Select the compute device. This handler only supports GPUs, so refuse to
# start on CPU rather than running an unusably slow pipeline.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("need to run on GPU")
# Mixed-precision dtype: bfloat16 needs compute capability >= 8 (Ampere and
# every later generation, e.g. Ada 8.9, Hopper 9.0); older GPUs use float16.
# The previous `== 8` check wrongly sent 9.x+ GPUs down the float16 path.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    
# ControlNet variants known to this handler, keyed by control type. Each entry
# holds the Hugging Face model id and the matching hint function from
# controlnet_hinter.
CONTROLNET_MAPPING = dict(
    depth=dict(
        model_id="lllyasviel/sd-controlnet-depth",
        hinter=controlnet_hinter.hint_depth,
    ),
)


# Short names accepted in the request's "sd_model" field, mapped to Hugging
# Face model ids.
# NOTE(review): the "*-xl" ids look like SDXL checkpoints, while the handler
# builds a StableDiffusionControlNetPipeline with an SD1.5 ControlNet. Confirm
# these entries actually load and run.
SD_ID_MAPPING = {
    "dreamshaper": "stablediffusionapi/dreamshaper-xl",
    "juggernaut": "stablediffusionapi/juggernaut-xl-v8",
    "realistic-vision": "SG161222/Realistic_Vision_V1.4",
    "rev-animated": "s6yx/ReV_Animated",
}

class EndpointHandler():
    """Inference-endpoint handler for depth-conditioned Stable Diffusion.

    The depth ControlNet is loaded once at construction. The Stable Diffusion
    pipeline for the requested base model is built on demand and reused while
    consecutive requests ask for the same model id.
    """

    def __init__(self, path=""):
        # `path` is accepted for interface compatibility; it is not used here.
        self.control_type = "depth"
        self.controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype
        ).to(device)
        # Single generator seeded once. It is passed to every pipeline call,
        # so its state advances from request to request.
        self.generator = torch.Generator(device=device.type).manual_seed(3)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Generate one image from a prompt, conditioned on a depth map.

        :param data: Request dictionary; the keys below are popped from it.
            ``prompt`` (required), ``negative_prompt``, ``sd_model`` (a key of
            ``SD_ID_MAPPING``; a missing or unknown value falls back to
            ``"Lykon/dreamshaper-8"``), ``image_depth_map`` (base64-encoded
            image; ``./default.jpg`` is used when absent), ``steps`` (default
            25), ``scale`` (guidance scale, default 7), ``height``, ``width``
            and ``controlnet_conditioning_scale`` (default 1).
        :return: The first generated image (``PIL.Image.Image``), or
            ``{"error": "Please provide a prompt"}`` when no prompt is given.
        """
        # hyperparameters
        sd_model = data.pop("sd_model", None)
        prompt = data.pop("prompt", None)
        negative_prompt = data.pop("negative_prompt", None)
        image_depth_map = data.pop("image_depth_map", None)
        steps = data.pop("steps", 25)
        scale = data.pop("scale", 7)
        height = data.pop("height", None)
        width = data.pop("width", None)
        controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1)

        # Validate the request before doing any expensive model loading.
        if prompt is None:
            return {"error": "Please provide a prompt"}

        self.stable_diffusion_id = SD_ID_MAPPING.get(sd_model, "Lykon/dreamshaper-8")
        print(f"Using stable diffusion model: {self.stable_diffusion_id}")
        self.pipe = self._get_pipeline(self.stable_diffusion_id)

        # Use the caller's depth map if provided, otherwise the bundled default.
        if image_depth_map is None:
            with open("./default.jpg", "rb") as image_file:
                image_depth_map = base64.b64encode(image_file.read()).decode('utf-8')
        image = self.decode_base64_image(image_depth_map)

        # run inference pipeline
        out = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            num_inference_steps=steps,
            guidance_scale=scale,
            num_images_per_prompt=1,
            height=height,
            width=width,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=self.generator
        )

        # return first generated PIL image
        return out.images[0]

    def _get_pipeline(self, model_id):
        """Return the pipeline for ``model_id``, loading it only when the id changes."""
        if getattr(self, "pipe", None) is None or getattr(self, "_pipe_model_id", None) != model_id:
            # Release the previous pipeline before loading the next one, so that
            # two full pipelines are never held at the same time.
            self.pipe = None
            safety_checker = StableDiffusionSafetyChecker.from_pretrained(
                "CompVis/stable-diffusion-safety-checker", torch_dtype=dtype
            )
            self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
                model_id,
                controlnet=self.controlnet,
                torch_dtype=dtype,
                safety_checker=safety_checker,
            ).to(device)
            self._pipe_model_id = model_id
        return self.pipe

    # helper to decode input image
    def decode_base64_image(self, image_string):
        """Decode a base64 string into a ``PIL.Image.Image`` (opened from memory)."""
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image