yonishafir commited on
Commit
e035438
·
verified ·
1 Parent(s): 55801fe

Upload image_processor.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. image_processor.py +991 -0
image_processor.py ADDED
@@ -0,0 +1,991 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import warnings
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from PIL import Image, ImageFilter, ImageOps
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
27
+ # from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
28
+
29
+
30
+ PipelineImageInput = Union[
31
+ PIL.Image.Image,
32
+ np.ndarray,
33
+ torch.FloatTensor,
34
+ List[PIL.Image.Image],
35
+ List[np.ndarray],
36
+ List[torch.FloatTensor],
37
+ ]
38
+
39
+ PipelineDepthInput = PipelineImageInput
40
+
41
+
42
+ class VaeImageProcessor(ConfigMixin):
43
+ """
44
+ Image processor for VAE.
45
+
46
+ Args:
47
+ do_resize (`bool`, *optional*, defaults to `True`):
48
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
49
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
50
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
51
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
52
+ resample (`str`, *optional*, defaults to `lanczos`):
53
+ Resampling filter to use when resizing the image.
54
+ do_normalize (`bool`, *optional*, defaults to `True`):
55
+ Whether to normalize the image to [-1,1].
56
+ do_binarize (`bool`, *optional*, defaults to `False`):
57
+ Whether to binarize the image to 0/1.
58
+ do_convert_rgb (`bool`, *optional*, defaults to be `False`):
59
+ Whether to convert the images to RGB format.
60
+ do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
61
+ Whether to convert the images to grayscale format.
62
+ """
63
+
64
+ config_name = CONFIG_NAME
65
+
66
+ @register_to_config
67
+ def __init__(
68
+ self,
69
+ do_resize: bool = True,
70
+ vae_scale_factor: int = 8,
71
+ resample: str = "lanczos",
72
+ do_normalize: bool = True,
73
+ do_binarize: bool = False,
74
+ do_convert_rgb: bool = False,
75
+ do_convert_grayscale: bool = False,
76
+ ):
77
+ super().__init__()
78
+ if do_convert_rgb and do_convert_grayscale:
79
+ raise ValueError(
80
+ "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
81
+ " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
82
+ " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
83
+ )
84
+ self.config.do_convert_rgb = False
85
+
86
+ @staticmethod
87
+ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
88
+ """
89
+ Convert a numpy image or a batch of images to a PIL image.
90
+ """
91
+ if images.ndim == 3:
92
+ images = images[None, ...]
93
+ images = (images * 255).round().astype("uint8")
94
+ if images.shape[-1] == 1:
95
+ # special case for grayscale (single channel) images
96
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
97
+ else:
98
+ pil_images = [Image.fromarray(image) for image in images]
99
+
100
+ return pil_images
101
+
102
+ @staticmethod
103
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
104
+ """
105
+ Convert a PIL image or a list of PIL images to NumPy arrays.
106
+ """
107
+ if not isinstance(images, list):
108
+ images = [images]
109
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
110
+ images = np.stack(images, axis=0)
111
+
112
+ return images
113
+
114
+ @staticmethod
115
+ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
116
+ """
117
+ Convert a NumPy image to a PyTorch tensor.
118
+ """
119
+ if images.ndim == 3:
120
+ images = images[..., None]
121
+
122
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
123
+ return images
124
+
125
+ @staticmethod
126
+ def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
127
+ """
128
+ Convert a PyTorch tensor to a NumPy image.
129
+ """
130
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
131
+ return images
132
+
133
+ @staticmethod
134
+ def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
135
+ """
136
+ Normalize an image array to [-1,1].
137
+ """
138
+ return 2.0 * images - 1.0
139
+
140
+ @staticmethod
141
+ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
142
+ """
143
+ Denormalize an image array to [0,1].
144
+ """
145
+ return (images / 2 + 0.5).clamp(0, 1)
146
+
147
+ @staticmethod
148
+ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
149
+ """
150
+ Converts a PIL image to RGB format.
151
+ """
152
+ image = image.convert("RGB")
153
+
154
+ return image
155
+
156
+ @staticmethod
157
+ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
158
+ """
159
+ Converts a PIL image to grayscale format.
160
+ """
161
+ image = image.convert("L")
162
+
163
+ return image
164
+
165
+ @staticmethod
166
+ def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
167
+ """
168
+ Applies Gaussian blur to an image.
169
+ """
170
+ image = image.filter(ImageFilter.GaussianBlur(blur_factor))
171
+
172
+ return image
173
+
174
+ @staticmethod
175
+ def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
176
+ """
177
+ Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image;
178
+ for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.
179
+
180
+ Args:
181
+ mask_image (PIL.Image.Image): Mask image.
182
+ width (int): Width of the image to be processed.
183
+ height (int): Height of the image to be processed.
184
+ pad (int, optional): Padding to be added to the crop region. Defaults to 0.
185
+
186
+ Returns:
187
+ tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio.
188
+ """
189
+
190
+ mask_image = mask_image.convert("L")
191
+ mask = np.array(mask_image)
192
+
193
+ # 1. find a rectangular region that contains all masked ares in an image
194
+ h, w = mask.shape
195
+ crop_left = 0
196
+ for i in range(w):
197
+ if not (mask[:, i] == 0).all():
198
+ break
199
+ crop_left += 1
200
+
201
+ crop_right = 0
202
+ for i in reversed(range(w)):
203
+ if not (mask[:, i] == 0).all():
204
+ break
205
+ crop_right += 1
206
+
207
+ crop_top = 0
208
+ for i in range(h):
209
+ if not (mask[i] == 0).all():
210
+ break
211
+ crop_top += 1
212
+
213
+ crop_bottom = 0
214
+ for i in reversed(range(h)):
215
+ if not (mask[i] == 0).all():
216
+ break
217
+ crop_bottom += 1
218
+
219
+ # 2. add padding to the crop region
220
+ x1, y1, x2, y2 = (
221
+ int(max(crop_left - pad, 0)),
222
+ int(max(crop_top - pad, 0)),
223
+ int(min(w - crop_right + pad, w)),
224
+ int(min(h - crop_bottom + pad, h)),
225
+ )
226
+
227
+ # 3. expands crop region to match the aspect ratio of the image to be processed
228
+ ratio_crop_region = (x2 - x1) / (y2 - y1)
229
+ ratio_processing = width / height
230
+
231
+ if ratio_crop_region > ratio_processing:
232
+ desired_height = (x2 - x1) / ratio_processing
233
+ desired_height_diff = int(desired_height - (y2 - y1))
234
+ y1 -= desired_height_diff // 2
235
+ y2 += desired_height_diff - desired_height_diff // 2
236
+ if y2 >= mask_image.height:
237
+ diff = y2 - mask_image.height
238
+ y2 -= diff
239
+ y1 -= diff
240
+ if y1 < 0:
241
+ y2 -= y1
242
+ y1 -= y1
243
+ if y2 >= mask_image.height:
244
+ y2 = mask_image.height
245
+ else:
246
+ desired_width = (y2 - y1) * ratio_processing
247
+ desired_width_diff = int(desired_width - (x2 - x1))
248
+ x1 -= desired_width_diff // 2
249
+ x2 += desired_width_diff - desired_width_diff // 2
250
+ if x2 >= mask_image.width:
251
+ diff = x2 - mask_image.width
252
+ x2 -= diff
253
+ x1 -= diff
254
+ if x1 < 0:
255
+ x2 -= x1
256
+ x1 -= x1
257
+ if x2 >= mask_image.width:
258
+ x2 = mask_image.width
259
+
260
+ return x1, y1, x2, y2
261
+
262
+ def _resize_and_fill(
263
+ self,
264
+ image: PIL.Image.Image,
265
+ width: int,
266
+ height: int,
267
+ ) -> PIL.Image.Image:
268
+ """
269
+ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
270
+
271
+ Args:
272
+ image: The image to resize.
273
+ width: The width to resize the image to.
274
+ height: The height to resize the image to.
275
+ """
276
+
277
+ ratio = width / height
278
+ src_ratio = image.width / image.height
279
+
280
+ src_w = width if ratio < src_ratio else image.width * height // image.height
281
+ src_h = height if ratio >= src_ratio else image.height * width // image.width
282
+
283
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
284
+ res = Image.new("RGB", (width, height))
285
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
286
+
287
+ if ratio < src_ratio:
288
+ fill_height = height // 2 - src_h // 2
289
+ if fill_height > 0:
290
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
291
+ res.paste(
292
+ resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
293
+ box=(0, fill_height + src_h),
294
+ )
295
+ elif ratio > src_ratio:
296
+ fill_width = width // 2 - src_w // 2
297
+ if fill_width > 0:
298
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
299
+ res.paste(
300
+ resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
301
+ box=(fill_width + src_w, 0),
302
+ )
303
+
304
+ return res
305
+
306
+ def _resize_and_crop(
307
+ self,
308
+ image: PIL.Image.Image,
309
+ width: int,
310
+ height: int,
311
+ ) -> PIL.Image.Image:
312
+ """
313
+ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
314
+
315
+ Args:
316
+ image: The image to resize.
317
+ width: The width to resize the image to.
318
+ height: The height to resize the image to.
319
+ """
320
+ ratio = width / height
321
+ src_ratio = image.width / image.height
322
+
323
+ src_w = width if ratio > src_ratio else image.width * height // image.height
324
+ src_h = height if ratio <= src_ratio else image.height * width // image.width
325
+
326
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
327
+ res = Image.new("RGB", (width, height))
328
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
329
+ return res
330
+
331
+ def resize(
332
+ self,
333
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
334
+ height: int,
335
+ width: int,
336
+ resize_mode: str = "default", # "default", "fill", "crop"
337
+ ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
338
+ """
339
+ Resize image.
340
+
341
+ Args:
342
+ image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
343
+ The image input, can be a PIL image, numpy array or pytorch tensor.
344
+ height (`int`):
345
+ The height to resize to.
346
+ width (`int`):
347
+ The width to resize to.
348
+ resize_mode (`str`, *optional*, defaults to `default`):
349
+ The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
350
+ within the specified width and height, and it may not maintaining the original aspect ratio.
351
+ If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
352
+ within the dimensions, filling empty with data from image.
353
+ If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
354
+ within the dimensions, cropping the excess.
355
+ Note that resize_mode `fill` and `crop` are only supported for PIL image input.
356
+
357
+ Returns:
358
+ `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
359
+ The resized image.
360
+ """
361
+ if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
362
+ raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
363
+ if isinstance(image, PIL.Image.Image):
364
+ if resize_mode == "default":
365
+ image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
366
+ elif resize_mode == "fill":
367
+ image = self._resize_and_fill(image, width, height)
368
+ elif resize_mode == "crop":
369
+ image = self._resize_and_crop(image, width, height)
370
+ else:
371
+ raise ValueError(f"resize_mode {resize_mode} is not supported")
372
+
373
+ elif isinstance(image, torch.Tensor):
374
+ image = torch.nn.functional.interpolate(
375
+ image,
376
+ size=(height, width),
377
+ )
378
+ elif isinstance(image, np.ndarray):
379
+ image = self.numpy_to_pt(image)
380
+ image = torch.nn.functional.interpolate(
381
+ image,
382
+ size=(height, width),
383
+ )
384
+ image = self.pt_to_numpy(image)
385
+ return image
386
+
387
+ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
388
+ """
389
+ Create a mask.
390
+
391
+ Args:
392
+ image (`PIL.Image.Image`):
393
+ The image input, should be a PIL image.
394
+
395
+ Returns:
396
+ `PIL.Image.Image`:
397
+ The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
398
+ """
399
+ image[image < 0.5] = 0
400
+ image[image >= 0.5] = 1
401
+
402
+ return image
403
+
404
+ def get_default_height_width(
405
+ self,
406
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
407
+ height: Optional[int] = None,
408
+ width: Optional[int] = None,
409
+ ) -> Tuple[int, int]:
410
+ """
411
+ This function return the height and width that are downscaled to the next integer multiple of
412
+ `vae_scale_factor`.
413
+
414
+ Args:
415
+ image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
416
+ The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
417
+ shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
418
+ have shape `[batch, channel, height, width]`.
419
+ height (`int`, *optional*, defaults to `None`):
420
+ The height in preprocessed image. If `None`, will use the height of `image` input.
421
+ width (`int`, *optional*`, defaults to `None`):
422
+ The width in preprocessed. If `None`, will use the width of the `image` input.
423
+ """
424
+
425
+ if height is None:
426
+ if isinstance(image, PIL.Image.Image):
427
+ height = image.height
428
+ elif isinstance(image, torch.Tensor):
429
+ height = image.shape[2]
430
+ else:
431
+ height = image.shape[1]
432
+
433
+ if width is None:
434
+ if isinstance(image, PIL.Image.Image):
435
+ width = image.width
436
+ elif isinstance(image, torch.Tensor):
437
+ width = image.shape[3]
438
+ else:
439
+ width = image.shape[2]
440
+
441
+ width, height = (
442
+ x - x % self.config.vae_scale_factor for x in (width, height)
443
+ ) # resize to integer multiple of vae_scale_factor
444
+
445
+ return height, width
446
+
447
+ def preprocess(
448
+ self,
449
+ image: PipelineImageInput,
450
+ height: Optional[int] = None,
451
+ width: Optional[int] = None,
452
+ resize_mode: str = "default", # "default", "fill", "crop"
453
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
454
+ ) -> torch.Tensor:
455
+ """
456
+ Preprocess the image input.
457
+
458
+ Args:
459
+ image (`pipeline_image_input`):
460
+ The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats.
461
+ height (`int`, *optional*, defaults to `None`):
462
+ The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height.
463
+ width (`int`, *optional*`, defaults to `None`):
464
+ The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
465
+ resize_mode (`str`, *optional*, defaults to `default`):
466
+ The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit
467
+ within the specified width and height, and it may not maintaining the original aspect ratio.
468
+ If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
469
+ within the dimensions, filling empty with data from image.
470
+ If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
471
+ within the dimensions, cropping the excess.
472
+ Note that resize_mode `fill` and `crop` are only supported for PIL image input.
473
+ crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
474
+ The crop coordinates for each image in the batch. If `None`, will not crop the image.
475
+ """
476
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
477
+
478
+ # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
479
+ if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
480
+ if isinstance(image, torch.Tensor):
481
+ # if image is a pytorch tensor could have 2 possible shapes:
482
+ # 1. batch x height x width: we should insert the channel dimension at position 1
483
+ # 2. channel x height x width: we should insert batch dimension at position 0,
484
+ # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
485
+ # for simplicity, we insert a dimension of size 1 at position 1 for both cases
486
+ image = image.unsqueeze(1)
487
+ else:
488
+ # if it is a numpy array, it could have 2 possible shapes:
489
+ # 1. batch x height x width: insert channel dimension on last position
490
+ # 2. height x width x channel: insert batch dimension on first position
491
+ if image.shape[-1] == 1:
492
+ image = np.expand_dims(image, axis=0)
493
+ else:
494
+ image = np.expand_dims(image, axis=-1)
495
+
496
+ if isinstance(image, supported_formats):
497
+ image = [image]
498
+ elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
499
+ raise ValueError(
500
+ f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
501
+ )
502
+
503
+ if isinstance(image[0], PIL.Image.Image):
504
+ if crops_coords is not None:
505
+ image = [i.crop(crops_coords) for i in image]
506
+ if self.config.do_resize:
507
+ height, width = self.get_default_height_width(image[0], height, width)
508
+ image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
509
+ if self.config.do_convert_rgb:
510
+ image = [self.convert_to_rgb(i) for i in image]
511
+ elif self.config.do_convert_grayscale:
512
+ image = [self.convert_to_grayscale(i) for i in image]
513
+ image = self.pil_to_numpy(image) # to np
514
+ image = self.numpy_to_pt(image) # to pt
515
+
516
+ elif isinstance(image[0], np.ndarray):
517
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
518
+
519
+ image = self.numpy_to_pt(image)
520
+
521
+ height, width = self.get_default_height_width(image, height, width)
522
+ if self.config.do_resize:
523
+ image = self.resize(image, height, width)
524
+
525
+ elif isinstance(image[0], torch.Tensor):
526
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
527
+
528
+ if self.config.do_convert_grayscale and image.ndim == 3:
529
+ image = image.unsqueeze(1)
530
+
531
+ channel = image.shape[1]
532
+ # don't need any preprocess if the image is latents
533
+ if channel >= 4:
534
+ return image
535
+
536
+ height, width = self.get_default_height_width(image, height, width)
537
+ if self.config.do_resize:
538
+ image = self.resize(image, height, width)
539
+
540
+ # expected range [0,1], normalize to [-1,1]
541
+ do_normalize = self.config.do_normalize
542
+ if do_normalize and image.min() < 0:
543
+ warnings.warn(
544
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
545
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
546
+ FutureWarning,
547
+ )
548
+ do_normalize = False
549
+
550
+ if do_normalize:
551
+ image = self.normalize(image)
552
+
553
+ if self.config.do_binarize:
554
+ image = self.binarize(image)
555
+
556
+ return image
557
+
558
+ def postprocess(
559
+ self,
560
+ image: torch.FloatTensor,
561
+ output_type: str = "pil",
562
+ do_denormalize: Optional[List[bool]] = None,
563
+ ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
564
+ """
565
+ Postprocess the image output from tensor to `output_type`.
566
+
567
+ Args:
568
+ image (`torch.FloatTensor`):
569
+ The image input, should be a pytorch tensor with shape `B x C x H x W`.
570
+ output_type (`str`, *optional*, defaults to `pil`):
571
+ The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
572
+ do_denormalize (`List[bool]`, *optional*, defaults to `None`):
573
+ Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
574
+ `VaeImageProcessor` config.
575
+
576
+ Returns:
577
+ `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
578
+ The postprocessed image.
579
+ """
580
+ if not isinstance(image, torch.Tensor):
581
+ raise ValueError(
582
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
583
+ )
584
+ if output_type not in ["latent", "pt", "np", "pil"]:
585
+ deprecation_message = (
586
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
587
+ "`pil`, `np`, `pt`, `latent`"
588
+ )
589
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
590
+ output_type = "np"
591
+
592
+ if output_type == "latent":
593
+ return image
594
+
595
+ if do_denormalize is None:
596
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
597
+
598
+ image = torch.stack(
599
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
600
+ )
601
+
602
+ if output_type == "pt":
603
+ return image
604
+
605
+ image = self.pt_to_numpy(image)
606
+
607
+ if output_type == "np":
608
+ return image
609
+
610
+ if output_type == "pil":
611
+ return self.numpy_to_pil(image)
612
+
613
+ def apply_overlay(
614
+ self,
615
+ mask: PIL.Image.Image,
616
+ init_image: PIL.Image.Image,
617
+ image: PIL.Image.Image,
618
+ crop_coords: Optional[Tuple[int, int, int, int]] = None,
619
+ ) -> PIL.Image.Image:
620
+ """
621
+ overlay the inpaint output to the original image
622
+ """
623
+
624
+ width, height = image.width, image.height
625
+
626
+ init_image = self.resize(init_image, width=width, height=height)
627
+ mask = self.resize(mask, width=width, height=height)
628
+
629
+ init_image_masked = PIL.Image.new("RGBa", (width, height))
630
+ init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
631
+ init_image_masked = init_image_masked.convert("RGBA")
632
+
633
+ if crop_coords is not None:
634
+ x, y, x2, y2 = crop_coords
635
+ w = x2 - x
636
+ h = y2 - y
637
+ base_image = PIL.Image.new("RGBA", (width, height))
638
+ image = self.resize(image, height=h, width=w, resize_mode="crop")
639
+ base_image.paste(image, (x, y))
640
+ image = base_image.convert("RGB")
641
+
642
+ image = image.convert("RGBA")
643
+ image.alpha_composite(init_image_masked)
644
+ image = image.convert("RGB")
645
+
646
+ return image
647
+
648
+
649
+ class VaeImageProcessorLDM3D(VaeImageProcessor):
650
+ """
651
+ Image processor for VAE LDM3D.
652
+
653
+ Args:
654
+ do_resize (`bool`, *optional*, defaults to `True`):
655
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
656
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
657
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
658
+ resample (`str`, *optional*, defaults to `lanczos`):
659
+ Resampling filter to use when resizing the image.
660
+ do_normalize (`bool`, *optional*, defaults to `True`):
661
+ Whether to normalize the image to [-1,1].
662
+ """
663
+
664
+ config_name = CONFIG_NAME
665
+
666
+ @register_to_config
667
+ def __init__(
668
+ self,
669
+ do_resize: bool = True,
670
+ vae_scale_factor: int = 8,
671
+ resample: str = "lanczos",
672
+ do_normalize: bool = True,
673
+ ):
674
+ super().__init__()
675
+
676
+ @staticmethod
677
+ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
678
+ """
679
+ Convert a NumPy image or a batch of images to a PIL image.
680
+ """
681
+ if images.ndim == 3:
682
+ images = images[None, ...]
683
+ images = (images * 255).round().astype("uint8")
684
+ if images.shape[-1] == 1:
685
+ # special case for grayscale (single channel) images
686
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
687
+ else:
688
+ pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
689
+
690
+ return pil_images
691
+
692
+ @staticmethod
693
+ def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
694
+ """
695
+ Convert a PIL image or a list of PIL images to NumPy arrays.
696
+ """
697
+ if not isinstance(images, list):
698
+ images = [images]
699
+
700
+ images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images]
701
+ images = np.stack(images, axis=0)
702
+ return images
703
+
704
+ @staticmethod
705
+ def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
706
+ """
707
+ Args:
708
+ image: RGB-like depth image
709
+
710
+ Returns: depth map
711
+
712
+ """
713
+ return image[:, :, 1] * 2**8 + image[:, :, 2]
714
+
715
+ def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
716
+ """
717
+ Convert a NumPy depth image or a batch of images to a PIL image.
718
+ """
719
+ if images.ndim == 3:
720
+ images = images[None, ...]
721
+ images_depth = images[:, :, :, 3:]
722
+ if images.shape[-1] == 6:
723
+ images_depth = (images_depth * 255).round().astype("uint8")
724
+ pil_images = [
725
+ Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
726
+ ]
727
+ elif images.shape[-1] == 4:
728
+ images_depth = (images_depth * 65535.0).astype(np.uint16)
729
+ pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
730
+ else:
731
+ raise Exception("Not supported")
732
+
733
+ return pil_images
734
+
735
+ def postprocess(
736
+ self,
737
+ image: torch.FloatTensor,
738
+ output_type: str = "pil",
739
+ do_denormalize: Optional[List[bool]] = None,
740
+ ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
741
+ """
742
+ Postprocess the image output from tensor to `output_type`.
743
+
744
+ Args:
745
+ image (`torch.FloatTensor`):
746
+ The image input, should be a pytorch tensor with shape `B x C x H x W`.
747
+ output_type (`str`, *optional*, defaults to `pil`):
748
+ The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
749
+ do_denormalize (`List[bool]`, *optional*, defaults to `None`):
750
+ Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
751
+ `VaeImageProcessor` config.
752
+
753
+ Returns:
754
+ `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
755
+ The postprocessed image.
756
+ """
757
+ if not isinstance(image, torch.Tensor):
758
+ raise ValueError(
759
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
760
+ )
761
+ if output_type not in ["latent", "pt", "np", "pil"]:
762
+ deprecation_message = (
763
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
764
+ "`pil`, `np`, `pt`, `latent`"
765
+ )
766
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
767
+ output_type = "np"
768
+
769
+ if do_denormalize is None:
770
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
771
+
772
+ image = torch.stack(
773
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
774
+ )
775
+
776
+ image = self.pt_to_numpy(image)
777
+
778
+ if output_type == "np":
779
+ if image.shape[-1] == 6:
780
+ image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
781
+ else:
782
+ image_depth = image[:, :, :, 3:]
783
+ return image[:, :, :, :3], image_depth
784
+
785
+ if output_type == "pil":
786
+ return self.numpy_to_pil(image), self.numpy_to_depth(image)
787
+ else:
788
+ raise Exception(f"This type {output_type} is not supported")
789
+
790
+ def preprocess(
791
+ self,
792
+ rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
793
+ depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
794
+ height: Optional[int] = None,
795
+ width: Optional[int] = None,
796
+ target_res: Optional[int] = None,
797
+ ) -> torch.Tensor:
798
+ """
799
+ Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
800
+ """
801
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
802
+
803
+ # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
804
+ if self.config.do_convert_grayscale and isinstance(rgb, (torch.Tensor, np.ndarray)) and rgb.ndim == 3:
805
+ raise Exception("This is not yet supported")
806
+
807
+ if isinstance(rgb, supported_formats):
808
+ rgb = [rgb]
809
+ depth = [depth]
810
+ elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)):
811
+ raise ValueError(
812
+ f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(supported_formats)}"
813
+ )
814
+
815
+ if isinstance(rgb[0], PIL.Image.Image):
816
+ if self.config.do_convert_rgb:
817
+ raise Exception("This is not yet supported")
818
+ # rgb = [self.convert_to_rgb(i) for i in rgb]
819
+ # depth = [self.convert_to_depth(i) for i in depth] #TODO define convert_to_depth
820
+ if self.config.do_resize or target_res:
821
+ height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res
822
+ rgb = [self.resize(i, height, width) for i in rgb]
823
+ depth = [self.resize(i, height, width) for i in depth]
824
+ rgb = self.pil_to_numpy(rgb) # to np
825
+ rgb = self.numpy_to_pt(rgb) # to pt
826
+
827
+ depth = self.depth_pil_to_numpy(depth) # to np
828
+ depth = self.numpy_to_pt(depth) # to pt
829
+
830
+ elif isinstance(rgb[0], np.ndarray):
831
+ rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0)
832
+ rgb = self.numpy_to_pt(rgb)
833
+ height, width = self.get_default_height_width(rgb, height, width)
834
+ if self.config.do_resize:
835
+ rgb = self.resize(rgb, height, width)
836
+
837
+ depth = np.concatenate(depth, axis=0) if rgb[0].ndim == 4 else np.stack(depth, axis=0)
838
+ depth = self.numpy_to_pt(depth)
839
+ height, width = self.get_default_height_width(depth, height, width)
840
+ if self.config.do_resize:
841
+ depth = self.resize(depth, height, width)
842
+
843
+ elif isinstance(rgb[0], torch.Tensor):
844
+ raise Exception("This is not yet supported")
845
+ # rgb = torch.cat(rgb, axis=0) if rgb[0].ndim == 4 else torch.stack(rgb, axis=0)
846
+
847
+ # if self.config.do_convert_grayscale and rgb.ndim == 3:
848
+ # rgb = rgb.unsqueeze(1)
849
+
850
+ # channel = rgb.shape[1]
851
+
852
+ # height, width = self.get_default_height_width(rgb, height, width)
853
+ # if self.config.do_resize:
854
+ # rgb = self.resize(rgb, height, width)
855
+
856
+ # depth = torch.cat(depth, axis=0) if depth[0].ndim == 4 else torch.stack(depth, axis=0)
857
+
858
+ # if self.config.do_convert_grayscale and depth.ndim == 3:
859
+ # depth = depth.unsqueeze(1)
860
+
861
+ # channel = depth.shape[1]
862
+ # # don't need any preprocess if the image is latents
863
+ # if depth == 4:
864
+ # return rgb, depth
865
+
866
+ # height, width = self.get_default_height_width(depth, height, width)
867
+ # if self.config.do_resize:
868
+ # depth = self.resize(depth, height, width)
869
+ # expected range [0,1], normalize to [-1,1]
870
+ do_normalize = self.config.do_normalize
871
+ if rgb.min() < 0 and do_normalize:
872
+ warnings.warn(
873
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
874
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
875
+ FutureWarning,
876
+ )
877
+ do_normalize = False
878
+
879
+ if do_normalize:
880
+ rgb = self.normalize(rgb)
881
+ depth = self.normalize(depth)
882
+
883
+ if self.config.do_binarize:
884
+ rgb = self.binarize(rgb)
885
+ depth = self.binarize(depth)
886
+
887
+ return rgb, depth
888
+
889
+
890
+ class IPAdapterMaskProcessor(VaeImageProcessor):
891
+ """
892
+ Image processor for IP Adapter image masks.
893
+
894
+ Args:
895
+ do_resize (`bool`, *optional*, defaults to `True`):
896
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
897
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
898
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
899
+ resample (`str`, *optional*, defaults to `lanczos`):
900
+ Resampling filter to use when resizing the image.
901
+ do_normalize (`bool`, *optional*, defaults to `False`):
902
+ Whether to normalize the image to [-1,1].
903
+ do_binarize (`bool`, *optional*, defaults to `True`):
904
+ Whether to binarize the image to 0/1.
905
+ do_convert_grayscale (`bool`, *optional*, defaults to be `True`):
906
+ Whether to convert the images to grayscale format.
907
+
908
+ """
909
+
910
+ config_name = CONFIG_NAME
911
+
912
+ @register_to_config
913
+ def __init__(
914
+ self,
915
+ do_resize: bool = True,
916
+ vae_scale_factor: int = 8,
917
+ resample: str = "lanczos",
918
+ do_normalize: bool = False,
919
+ do_binarize: bool = True,
920
+ do_convert_grayscale: bool = True,
921
+ ):
922
+ super().__init__(
923
+ do_resize=do_resize,
924
+ vae_scale_factor=vae_scale_factor,
925
+ resample=resample,
926
+ do_normalize=do_normalize,
927
+ do_binarize=do_binarize,
928
+ do_convert_grayscale=do_convert_grayscale,
929
+ )
930
+
931
+ @staticmethod
932
+ def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
933
+ """
934
+ Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention.
935
+ If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
936
+
937
+ Args:
938
+ mask (`torch.FloatTensor`):
939
+ The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
940
+ batch_size (`int`):
941
+ The batch size.
942
+ num_queries (`int`):
943
+ The number of queries.
944
+ value_embed_dim (`int`):
945
+ The dimensionality of the value embeddings.
946
+
947
+ Returns:
948
+ `torch.FloatTensor`:
949
+ The downsampled mask tensor.
950
+
951
+ """
952
+ o_h = mask.shape[1]
953
+ o_w = mask.shape[2]
954
+ ratio = o_w / o_h
955
+ mask_h = int(math.sqrt(num_queries / ratio))
956
+ mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
957
+ mask_w = num_queries // mask_h
958
+
959
+ mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)
960
+
961
+ # Repeat batch_size times
962
+ if mask_downsample.shape[0] < batch_size:
963
+ mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
964
+
965
+ mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)
966
+
967
+ downsampled_area = mask_h * mask_w
968
+ # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
969
+ # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries
970
+ if downsampled_area < num_queries:
971
+ warnings.warn(
972
+ "The aspect ratio of the mask does not match the aspect ratio of the output image. "
973
+ "Please update your masks or adjust the output size for optimal performance.",
974
+ UserWarning,
975
+ )
976
+ mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
977
+ # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries
978
+ if downsampled_area > num_queries:
979
+ warnings.warn(
980
+ "The aspect ratio of the mask does not match the aspect ratio of the output image. "
981
+ "Please update your masks or adjust the output size for optimal performance.",
982
+ UserWarning,
983
+ )
984
+ mask_downsample = mask_downsample[:, :num_queries]
985
+
986
+ # Repeat last dimension to match SDPA output shape
987
+ mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
988
+ 1, 1, value_embed_dim
989
+ )
990
+
991
+ return mask_downsample