Vijish commited on
Commit
e4fdcb5
·
1 Parent(s): 2af39a9

Delete handler.py

Browse files
Files changed (1) hide show
  1. handler.py +0 -129
handler.py DELETED
@@ -1,129 +0,0 @@
1
- from typing import Dict, List, Any
2
- import base64
3
- from PIL import Image
4
- from io import BytesIO
5
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
6
- import torch
7
-
8
-
9
- import numpy as np
10
- import cv2
11
- import controlnet_hinter
12
-
13
- # set device
14
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
- if device.type != 'cuda':
16
- raise ValueError("need to run on GPU")
17
- # set mixed precision dtype
18
- dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
19
-
20
- # controlnet mapping for controlnet id and control hinter
21
- CONTROLNET_MAPPING = {
22
- "canny_edge": {
23
- "model_id": "lllyasviel/sd-controlnet-canny",
24
- "hinter": controlnet_hinter.hint_canny
25
- },
26
- "pose": {
27
- "model_id": "lllyasviel/sd-controlnet-openpose",
28
- "hinter": controlnet_hinter.hint_openpose
29
- },
30
- "depth": {
31
- "model_id": "lllyasviel/sd-controlnet-depth",
32
- "hinter": controlnet_hinter.hint_depth
33
- },
34
- "scribble": {
35
- "model_id": "lllyasviel/sd-controlnet-scribble",
36
- "hinter": controlnet_hinter.hint_scribble,
37
- },
38
- "segmentation": {
39
- "model_id": "lllyasviel/sd-controlnet-seg",
40
- "hinter": controlnet_hinter.hint_segmentation,
41
- },
42
- "normal": {
43
- "model_id": "lllyasviel/sd-controlnet-normal",
44
- "hinter": controlnet_hinter.hint_normal,
45
- },
46
- "hed": {
47
- "model_id": "lllyasviel/sd-controlnet-hed",
48
- "hinter": controlnet_hinter.hint_hed,
49
- },
50
- "hough": {
51
- "model_id": "lllyasviel/sd-controlnet-mlsd",
52
- "hinter": controlnet_hinter.hint_hough,
53
- }
54
- }
55
-
56
-
57
- class EndpointHandler():
58
- def __init__(self, path=""):
59
- # define default controlnet id and load controlnet
60
- self.control_type = "normal"
61
- self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],torch_dtype=dtype).to(device)
62
-
63
- # Load StableDiffusionControlNetPipeline
64
- self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
65
- self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
66
- controlnet=self.controlnet,
67
- torch_dtype=dtype,
68
- safety_checker=None).to(device)
69
- # Define Generator with seed
70
- self.generator = torch.Generator(device="cpu").manual_seed(3)
71
-
72
- def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
73
- """
74
- :param data: A dictionary contains `inputs` and optional `image` field.
75
- :return: A dictionary with `image` field contains image in base64.
76
- """
77
- prompt = data.pop("inputs", None)
78
- image = data.pop("image", None)
79
- controlnet_type = data.pop("controlnet_type", None)
80
-
81
- # Check if neither prompt nor image is provided
82
- if prompt is None and image is None:
83
- return {"error": "Please provide a prompt and base64 encoded image."}
84
-
85
- # Check if a new controlnet is provided
86
- if controlnet_type is not None and controlnet_type != self.control_type:
87
- print(f"changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
88
- self.control_type = controlnet_type
89
- self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],
90
- torch_dtype=dtype).to(device)
91
- self.pipe.controlnet = self.controlnet
92
-
93
-
94
- # hyperparamters
95
- num_inference_steps = data.pop("num_inference_steps", 30)
96
- guidance_scale = data.pop("guidance_scale", 7.5)
97
- negative_prompt = data.pop("negative_prompt", None)
98
- height = data.pop("height", None)
99
- width = data.pop("width", None)
100
- controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
101
-
102
- # process image
103
- image = self.decode_base64_image(image)
104
- control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
105
-
106
- # run inference pipeline
107
- out = self.pipe(
108
- prompt=prompt,
109
- negative_prompt=negative_prompt,
110
- image=control_image,
111
- num_inference_steps=num_inference_steps,
112
- guidance_scale=guidance_scale,
113
- num_images_per_prompt=1,
114
- height=height,
115
- width=width,
116
- controlnet_conditioning_scale=controlnet_conditioning_scale,
117
- generator=self.generator
118
- )
119
-
120
-
121
- # return first generate PIL image
122
- return out.images[0]
123
-
124
- # helper to decode input image
125
- def decode_base64_image(self, image_string):
126
- base64_image = base64.b64decode(image_string)
127
- buffer = BytesIO(base64_image)
128
- image = Image.open(buffer)
129
- return image