yamildiego commited on
Commit
a9ef0f5
·
1 Parent(s): 847ce27

test tiling

Browse files
Files changed (1) hide show
  1. handler.py +37 -123
handler.py CHANGED
@@ -2,154 +2,68 @@ from typing import Dict, List, Any
2
  import base64
3
  from PIL import Image
4
  from io import BytesIO
5
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
6
- #from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionSafetyChecker
7
- # import Safety Checker
8
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
9
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
10
 
11
  import torch
12
 
13
 
14
- import numpy as np
15
- import cv2
16
- import controlnet_hinter
17
 
18
- # set device
19
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
  if device.type != 'cuda':
21
  raise ValueError("need to run on GPU")
22
  # set mixed precision dtype
23
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
24
-
25
- # controlnet mapping for controlnet id and control hinter
26
- CONTROLNET_MAPPING = {
27
- "canny_edge": {
28
- "model_id": "lllyasviel/sd-controlnet-canny",
29
- "hinter": controlnet_hinter.hint_canny
30
- },
31
- "pose": {
32
- "model_id": "lllyasviel/sd-controlnet-openpose",
33
- "hinter": controlnet_hinter.hint_openpose
34
- },
35
- "depth": {
36
- "model_id": "lllyasviel/sd-controlnet-depth",
37
- "hinter": controlnet_hinter.hint_depth
38
- },
39
- "scribble": {
40
- "model_id": "lllyasviel/sd-controlnet-scribble",
41
- "hinter": controlnet_hinter.hint_scribble,
42
- },
43
- "segmentation": {
44
- "model_id": "lllyasviel/sd-controlnet-seg",
45
- "hinter": controlnet_hinter.hint_segmentation,
46
- },
47
- "normal": {
48
- "model_id": "lllyasviel/sd-controlnet-normal",
49
- "hinter": controlnet_hinter.hint_normal,
50
- },
51
- "hed": {
52
- "model_id": "lllyasviel/sd-controlnet-hed",
53
- "hinter": controlnet_hinter.hint_hed,
54
- },
55
- "hough": {
56
- "model_id": "lllyasviel/sd-controlnet-mlsd",
57
- "hinter": controlnet_hinter.hint_hough,
58
- }
59
- }
60
-
61
 
62
  class EndpointHandler():
63
- def __init__(self, path=""):
64
- # define default controlnet id and load controlnet
65
- self.control_type = "depth"
66
- self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],torch_dtype=dtype).to(device)
67
 
68
- #processor = AutoProcessor.from_pretrained("CompVis/stable-diffusion-safety-checker")
69
 
70
-
71
- # Load StableDiffusionControlNetPipeline
72
- #self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
73
- self.stable_diffusion_id = "Lykon/dreamshaper-8"
74
-
75
- self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
76
- controlnet=self.controlnet,
77
- torch_dtype=dtype,
78
- safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=dtype)).to("cuda")
79
-
80
- self.pipe_without_controlnet = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=dtype).to(device.type)
81
-
82
- # Define Generator with seed
83
- self.generator = torch.Generator(device=device.type).manual_seed(3)
84
-
85
- def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
86
- """
87
- :param data: A dictionary contains `inputs` and optional `image` field.
88
- :return: A dictionary with `image` field contains image in base64.
89
- """
90
- prompt = data.pop("inputs", None)
91
- image = data.pop("image", None)
92
- controlnet_type = data.pop("controlnet_type", None)
93
-
94
- # Check if neither prompt nor image is provided
95
- if prompt is None and image is None:
96
- return {"error": "Please provide a prompt and base64 encoded image."}
97
-
98
- # Check if a new controlnet is provided
99
- if controlnet_type is not None and controlnet_type != self.control_type:
100
- print(f"changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
101
- self.control_type = controlnet_type
102
- self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],
103
- torch_dtype=dtype).to(device)
104
- self.pipe.controlnet = self.controlnet
105
-
106
-
107
- targets = [self.pipe_without_controlnet.vae, self.pipe_without_controlnet.unet]
108
- for target in targets:
109
- for module in target.modules():
110
- if isinstance(module, torch.nn.Conv2d):
111
- module.padding_mode = "circular"
112
-
113
- # hyperparamters
114
- num_inference_steps = data.pop("num_inference_steps", 30)
115
- guidance_scale = data.pop("guidance_scale", 7.4)
116
- negative_prompt = data.pop("negative_prompt", None)
117
- height = data.pop("height", None)
118
- width = data.pop("width", None)
119
- controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
120
-
121
- test_var = data.pop("test_var", "DEFAULT")
122
- tiling = data.pop("tiling", True)
123
-
124
- print(f"prompt: {prompt}")
125
- print(f"prompt: {test_var}")
126
-
127
- # process image
128
- image = self.decode_base64_image(image)
129
- #control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
130
 
131
  # run inference pipeline
132
- out = self.pipe_without_controlnet(
133
  prompt=prompt,
134
  negative_prompt=negative_prompt,
135
- #image=control_image,
136
- #image=image,
137
  num_inference_steps=num_inference_steps,
138
  guidance_scale=guidance_scale,
139
  num_images_per_prompt=1,
140
  height=height,
141
  width=width,
142
- #controlnet_conditioning_scale=controlnet_conditioning_scale,
143
  generator=self.generator
144
  )
145
 
146
 
147
- # return first generate PIL image
148
- return out.images[0]
149
-
150
- # helper to decode input image
151
- def decode_base64_image(self, image_string):
152
- base64_image = base64.b64decode(image_string)
153
- buffer = BytesIO(base64_image)
154
- image = Image.open(buffer)
155
- return image
 
2
  import base64
3
  from PIL import Image
4
  from io import BytesIO
 
 
 
5
  from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
+ from diffusers import StableDiffusionPipeline
7
 
8
  import torch
9
 
10
 
11
+ # import numpy as np
12
+ # import cv2
 
13
 
14
+ # # set device
15
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
  if device.type != 'cuda':
17
  raise ValueError("need to run on GPU")
18
  # set mixed precision dtype
19
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  class EndpointHandler():
22
+ def __init__(self, path=""):
23
+ self.stable_diffusion_id = "Lykon/dreamshaper-8"
 
 
24
 
 
25
 
26
+
27
+ def patch_conv(cls):
28
+ init = cls.__init__
29
+ def __init__(self, *args, **kwargs):
30
+ return init(self, *args, **kwargs, padding_mode='circular')
31
+ cls.__init__ = __init__
32
+
33
+ patch_conv(torch.nn.Conv2d)
34
+
35
+
36
+
37
+
38
+ self.pipe = StableDiffusionPipeline.from_pretrained(self.stable_diffusion_id,torch_dtype=dtype,safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=dtype)).to(device.type)
39
+
40
+ self.generator = torch.Generator(device=device.type).manual_seed(3)
41
+
42
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
43
+ # """
44
+ # :param data: A dictionary contains `inputs` and optional `image` field.
45
+ # :return: A dictionary with `image` field contains image in base64.
46
+ # """
47
+ prompt = data.pop("inputs", None)
48
+ num_inference_steps = data.pop("num_inference_steps", 30)
49
+ guidance_scale = data.pop("guidance_scale", 7.4)
50
+ negative_prompt = data.pop("negative_prompt", None)
51
+ height = data.pop("height", None)
52
+ width = data.pop("width", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # run inference pipeline
55
+ out = self.pipe(
56
  prompt=prompt,
57
  negative_prompt=negative_prompt,
 
 
58
  num_inference_steps=num_inference_steps,
59
  guidance_scale=guidance_scale,
60
  num_images_per_prompt=1,
61
  height=height,
62
  width=width,
 
63
  generator=self.generator
64
  )
65
 
66
 
67
+ # return first generate PIL image
68
+ return out.images[0]
69
+