# Portrait-Animation/src/utils/mask_processer.py
import warnings

import torch
import torch.nn.functional as F
from diffusers.configuration_utils import register_to_config
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import CONFIG_NAME


class IPAdapterMaskProcessor(VaeImageProcessor):
"""
Image processor for IP Adapter image masks.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
vae_scale_factor (`int`, *optional*, defaults to `8`):
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
resample (`str`, *optional*, defaults to `lanczos`):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `False`):
Whether to normalize the image to [-1,1].
do_binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the image to 0/1.
do_convert_grayscale (`bool`, *optional*, defaults to be `True`):
Whether to convert the images to grayscale format.
"""
    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = False,
        do_binarize: bool = True,
        do_convert_grayscale: bool = True,
    ):
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )
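
    # NOTE: `preprocess()` is inherited from `VaeImageProcessor`; with the
    # defaults above it converts a single mask image into a (1, 1, H, W)
    # float tensor binarized to {0, 1}. A runnable usage sketch is included
    # at the bottom of this module.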

    @staticmethod
    def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
        """
        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If
        the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.

        Args:
            mask (`torch.Tensor`):
                The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
            batch_size (`int`):
                The batch size.
            num_queries (`int`):
                The number of queries.
            value_embed_dim (`int`):
                The dimensionality of the value embeddings.

        Returns:
            `torch.Tensor`:
                The downsampled mask tensor.
        """
        o_h = mask.shape[1]
        o_w = mask.shape[2]
        ratio = o_w / o_h
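        # Pick a (mask_h, mask_w) grid that roughly preserves the mask's aspect
        # ratio while keeping mask_h * mask_w close to num_queries; any mismatch
        # is handled by the padding/truncation branches below.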
        mask_h = int(torch.sqrt(torch.FloatTensor([num_queries / ratio]))[0])
        mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
        mask_w = num_queries // mask_h

        mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)

        # Repeat batch_size times
        if mask_downsample.shape[0] < batch_size:
            mask_downsample = mask_downsample.repeat(batch_size, 1, 1)

        mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)

        downsampled_area = mask_h * mask_w
        # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
        # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries
        if downsampled_area < num_queries:
            warnings.warn(
                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
                "Please update your masks or adjust the output size for optimal performance.",
                UserWarning,
            )
            mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
        # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries
        if downsampled_area > num_queries:
            warnings.warn(
                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
                "Please update your masks or adjust the output size for optimal performance.",
                UserWarning,
            )
            mask_downsample = mask_downsample[:, :num_queries]

        # Repeat last dimension to match SDPA output shape
        mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
            1, 1, value_embed_dim
        )

        return mask_downsample
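

# ---------------------------------------------------------------------------
# Minimal usage sketch: builds a synthetic 768x512 mask, preprocesses it and
# downsamples it for attention. The batch size, query count and value
# embedding dimension below are illustrative placeholders; num_queries =
# 96 * 64 corresponds to the 1/8-scale latent of a 768x512 image.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    processor = IPAdapterMaskProcessor()

    # Synthetic grayscale mask: white left half, black right half.
    pil_mask = Image.new("L", (512, 768), 0)  # PIL size is (width, height)
    pil_mask.paste(255, (0, 0, 256, 768))

    # `preprocess()` resizes to multiples of `vae_scale_factor`, converts to
    # grayscale and binarizes the mask to {0, 1}.
    mask = processor.preprocess(pil_mask)  # shape: (1, 1, 768, 512)
    print("preprocessed:", tuple(mask.shape))

    # Downsample to the shape expected by scaled dot-product attention.
    mask_downsample = IPAdapterMaskProcessor.downsample(
        mask[:, 0], batch_size=2, num_queries=96 * 64, value_embed_dim=64
    )
    print("downsampled:", tuple(mask_downsample.shape))  # (2, 6144, 64)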