File size: 4,210 Bytes
1f7551f
 
 
 
061d7a5
1f7551f
 
 
 
 
061d7a5
 
 
1f7551f
061d7a5
1f7551f
 
 
 
 
061d7a5
 
 
 
 
 
 
a9ef0f5
061d7a5
 
 
 
3ece45d
4cdf1fd
f559cf4
82cac35
 
 
3ece45d
f559cf4
 
 
 
 
061d7a5
471adc0
 
061d7a5
 
3af03cf
061d7a5
 
 
 
 
f559cf4
80111fe
061d7a5
 
3af03cf
 
061d7a5
 
66c9a4e
471adc0
4cdf1fd
 
 
 
061d7a5
 
 
 
7e47e3f
 
4cdf1fd
 
 
061d7a5
64d3ad8
1f7551f
 
f1bbd17
1f7551f
 
061d7a5
3af03cf
 
1f7551f
 
 
061d7a5
06e49f5
1f7551f
 
061d7a5
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from typing import  Dict, List, Any
import base64
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker

import torch


import numpy as np
import cv2
import controlnet_hinter

# Select the compute device. This module only supports GPU execution, so
# importing it on a machine without CUDA fails fast with a ValueError.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("need to run on GPU")
# Mixed-precision dtype: use bfloat16 on GPUs whose compute-capability major
# version is 8 or newer, float16 otherwise. The comparison is ">=" rather than
# "==" so that architectures newer than 8.x are not silently downgraded to
# float16.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    
# Registry of supported ControlNet conditioning types. Each entry maps a
# control type name to the checkpoint id of its ControlNet weights
# ("model_id") and the controlnet_hinter function associated with that
# control type ("hinter").
CONTROLNET_MAPPING = dict(
    depth=dict(
        model_id="lllyasviel/sd-controlnet-depth",
        hinter=controlnet_hinter.hint_depth,
    ),
)

class EndpointHandler():
    """Request handler that runs depth-conditioned Stable Diffusion pipelines.

    On construction it loads one depth ControlNet and two
    ``StableDiffusionControlNetPipeline`` instances that share it. Calling the
    handler with a request dictionary runs one of those pipelines and returns
    the first generated image.
    """

    # Names of the instance attributes that hold runnable pipelines. A request
    # may only select one of these through its ``sd_model`` field; any other
    # attribute name (methods, the generator, dunder attributes, ...) is
    # rejected instead of being looked up with getattr().
    _PIPELINE_NAMES = ("dreamshaper", "dreamshaper_2")

    def __init__(self, path=""):
        """Load the ControlNet, the safety checker and both pipelines.

        :param path: Accepted for interface compatibility; not used here.
        """
        self.control_type = "depth"
        self.controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_MAPPING[self.control_type]["model_id"],
            torch_dtype=dtype,
        ).to(device)

        # Load the safety checker once and share it between both pipelines,
        # the same way the ControlNet is shared, instead of keeping two
        # identical copies in GPU memory.
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker",
            torch_dtype=dtype,
        )

        self.stable_diffusion_id_0 = "Lykon/dreamshaper-8"
        self.dreamshaper = StableDiffusionControlNetPipeline.from_pretrained(
            self.stable_diffusion_id_0,
            controlnet=self.controlnet,
            torch_dtype=dtype,
            safety_checker=safety_checker,
        ).to(device)

        # NOTE(review): this checkpoint id looks like an SDXL model, yet it is
        # loaded with the non-XL pipeline class and the same depth ControlNet
        # as above -- confirm that this combination actually loads and runs.
        self.stable_diffusion_id_1 = "lykon/dreamshaper-xl-v2-turbo"
        self.dreamshaper_2 = StableDiffusionControlNetPipeline.from_pretrained(
            self.stable_diffusion_id_1,
            controlnet=self.controlnet,
            torch_dtype=dtype,
            safety_checker=safety_checker,
        ).to(device)

        # Generator seeded once here and passed to every pipeline call.
        self.generator = torch.Generator(device=device.type).manual_seed(3)

    def __call__(self, data: Any) -> Any:
        """Run one generation request.

        Every recognised key is removed from ``data`` with ``pop``.

        :param data: Request dictionary. Recognised keys:
            ``inputs`` (the prompt, required),
            ``image_depth_map`` (base64-encoded image, required),
            ``sd_model`` (pipeline name, default ``"dreamshaper"``),
            ``negative_prompt`` (default ``None``),
            ``steps`` (default 25), ``scale`` (default 7),
            ``height`` / ``width`` (default ``None``),
            ``controlnet_conditioning_scale`` (default 1.0).
        :return: The first image of the pipeline output, or a dictionary with
            a single ``error`` key when the request is invalid.
        """
        # Request parameters, with defaults for the optional ones.
        sd_model = data.pop("sd_model", "dreamshaper")
        prompt = data.pop("inputs", None)
        negative_prompt = data.pop("negative_prompt", None)
        image_depth_map = data.pop("image_depth_map", None)
        steps = data.pop("steps", 25)
        scale = data.pop("scale", 7)
        height = data.pop("height", None)
        width = data.pop("width", None)
        controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)

        # Only allow-listed pipeline attributes may be selected. The hasattr
        # check additionally covers a pipeline that was never loaded.
        if sd_model not in self._PIPELINE_NAMES or not hasattr(self, sd_model):
            return {"error": "Modelo SD no especificado o no válido"}

        if prompt is None:
            return {"error": "Please provide a prompt"}

        if image_depth_map is None:
            return {"error": "Please provide a image_depth_map"}

        pipe = getattr(self, sd_model)

        # Decode the conditioning image sent by the client.
        image = self.decode_base64_image(image_depth_map)

        # Run the selected pipeline.
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            num_inference_steps=steps,
            guidance_scale=scale,
            num_images_per_prompt=1,
            height=height,
            width=width,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=self.generator,
        )

        # Only one image per prompt is requested, so return that one.
        return out.images[0]

    def decode_base64_image(self, image_string):
        """Decode a base64-encoded image and open it with PIL.

        :param image_string: Base64 text (or bytes) of an encoded image file.
        :return: The object returned by ``Image.open`` for the decoded bytes.
        """
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image