import math
import warnings

import torch
import torch.nn.functional as F

from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import CONFIG_NAME


class IPAdapterMaskProcessor(VaeImageProcessor):
    """
    Image processor for IP Adapter image masks.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this
            factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `False`):
            Whether to normalize the image to [-1, 1].
        do_binarize (`bool`, *optional*, defaults to `True`):
            Whether to binarize the image to 0/1.
        do_convert_grayscale (`bool`, *optional*, defaults to `True`):
            Whether to convert the images to grayscale format.
    """

    config_name = CONFIG_NAME

    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = False,
        do_binarize: bool = True,
        do_convert_grayscale: bool = True,
    ):
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )

    @staticmethod
    def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
        """
        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If
        the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.

        Args:
            mask (`torch.Tensor`):
                The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
            batch_size (`int`):
                The batch size.
            num_queries (`int`):
                The number of queries.
            value_embed_dim (`int`):
                The dimensionality of the value embeddings.

        Returns:
            `torch.Tensor`:
                The downsampled mask tensor.
        """
        o_h = mask.shape[1]
        o_w = mask.shape[2]
        ratio = o_w / o_h

        # Estimate a (mask_h, mask_w) grid that matches the mask's aspect ratio and whose area approximates
        # num_queries; round mask_h up when it does not divide num_queries evenly.
        mask_h = int(math.sqrt(num_queries / ratio))
        mask_h = mask_h + int((num_queries % mask_h) != 0)
        mask_w = num_queries // mask_h

        mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)

        # Repeat the mask batch_size times so every sample in the batch gets a copy.
        if mask_downsample.shape[0] < batch_size:
            mask_downsample = mask_downsample.repeat(batch_size, 1, 1)

        # Flatten the spatial dimensions into a single query axis.
        mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)

        downsampled_area = mask_h * mask_w
        # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match.
        # Pad the tensor if downsampled_mask.shape[1] is smaller than num_queries.
        if downsampled_area < num_queries:
            warnings.warn(
                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
                "Please update your masks or adjust the output size for optimal performance.",
                UserWarning,
            )
            mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)

        # Discard the last embeddings if downsampled_mask.shape[1] is bigger than num_queries.
        if downsampled_area > num_queries:
            warnings.warn(
                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
                "Please update your masks or adjust the output size for optimal performance.",
                UserWarning,
            )
            mask_downsample = mask_downsample[:, :num_queries]

        # Repeat the last dimension to match the SDPA output shape (batch_size, num_queries, value_embed_dim).
        mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
            1, 1, value_embed_dim
        )

        return mask_downsample
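

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the class above). The mask file name, the
# output resolution, and the attention sizes (num_queries=1024,
# value_embed_dim=1280) are hypothetical values chosen for illustration;
# substitute the ones from your own pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    processor = IPAdapterMaskProcessor()

    # preprocess() (inherited from VaeImageProcessor) converts the image to
    # grayscale, binarizes it to 0/1, and returns a (1, 1, H, W) tensor.
    mask = processor.preprocess(Image.open("mask.png"), height=1024, width=1024)

    # For a 1024x1024 output, an attention layer running at 1/32 of the pixel
    # resolution sees (1024 / 32) * (1024 / 32) = 1024 queries.
    mask_downsample = IPAdapterMaskProcessor.downsample(
        mask[0],  # (1, H, W): drop the batch dimension
        batch_size=2,
        num_queries=1024,
        value_embed_dim=1280,  # hypothetical cross-attention value dimension
    )
    print(mask_downsample.shape)  # torch.Size([2, 1024, 1280])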