import math
import warnings

import torch
import torch.nn.functional as F
from PIL import Image

from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import CONFIG_NAME


class IPAdapterMaskProcessor(VaeImageProcessor):
    """
    Image processor for IP Adapter image masks.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this
            factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `False`):
            Whether to normalize the image to [-1, 1].
        do_binarize (`bool`, *optional*, defaults to `True`):
            Whether to binarize the image to 0/1.
        do_convert_grayscale (`bool`, *optional*, defaults to `True`):
            Whether to convert the images to grayscale format.
    """

    config_name = CONFIG_NAME

    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = False,
        do_binarize: bool = True,
        do_convert_grayscale: bool = True,
    ):
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )
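
    # Illustrative usage (a sketch; `mask_pil` stands in for any PIL mask image):
    #
    #     processor = IPAdapterMaskProcessor()
    #     masks = processor.preprocess([mask_pil], height=1024, width=1024)
    #
    # `preprocess` is inherited from `VaeImageProcessor`; with `do_convert_grayscale=True`
    # it returns a single-channel tensor of shape (num_masks, 1, height, width).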
    @staticmethod
    def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int) -> torch.Tensor:
        """
        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If
        the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.

        Args:
            mask (`torch.Tensor`):
                The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
            batch_size (`int`):
                The batch size.
            num_queries (`int`):
                The number of queries.
            value_embed_dim (`int`):
                The dimensionality of the value embeddings.

        Returns:
            `torch.Tensor`:
                The downsampled mask tensor.
        """
        o_h = mask.shape[1]
        o_w = mask.shape[2]
        ratio = o_w / o_h
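        # Solve for a (mask_h, mask_w) grid that keeps the mask's aspect ratio and whose
        # area approximates num_queries: mask_h ~ sqrt(num_queries / ratio), rounded up
        # when it does not divide num_queries evenly, and mask_w = num_queries // mask_h.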
        mask_h = int(math.sqrt(num_queries / ratio))
        mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
        mask_w = num_queries // mask_h

        mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)
        # Repeat batch_size times
        if mask_downsample.shape[0] < batch_size:
            mask_downsample = mask_downsample.repeat(batch_size, 1, 1)

        mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)
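        # The flattened axis (mask_h * mask_w) lines up with the query-token axis of the
        # attention map; it is padded or truncated below to exactly num_queries entries.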
        downsampled_area = mask_h * mask_w
        # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
        # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries
        if downsampled_area < num_queries:
            warnings.warn(
                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
                "Please update your masks or adjust the output size for optimal performance.",
                UserWarning,
            )
            mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
        # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries
        if downsampled_area > num_queries:
            warnings.warn(
                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
                "Please update your masks or adjust the output size for optimal performance.",
                UserWarning,
            )
            mask_downsample = mask_downsample[:, :num_queries]

        # Repeat last dimension to match SDPA output shape
        mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
            1, 1, value_embed_dim
        )

        return mask_downsample
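

if __name__ == "__main__":
    # Minimal smoke test: a hedged sketch, not part of the diffusers API. The mask
    # (a white square on a black background) and all sizes below are illustrative.
    mask_pil = Image.new("L", (1024, 1024), 0)
    mask_pil.paste(255, (0, 0, 512, 512))

    processor = IPAdapterMaskProcessor()
    # Shape (1, 1, 1024, 1024): one single-channel, binarized mask.
    mask = processor.preprocess(mask_pil, height=1024, width=1024)

    # Downsample to the query length of a 128x128 attention map, with an
    # (illustrative) value embedding dimension of 64 and a batch size of 2.
    out = IPAdapterMaskProcessor.downsample(
        mask[:, 0, :, :], batch_size=2, num_queries=128 * 128, value_embed_dim=64
    )
    print(out.shape)  # torch.Size([2, 16384, 64])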