# Tensor <-> PIL image conversion and manipulation utilities.
# NOTE(review): removed web-page extraction residue ("Spaces: Running on Zero").
import numpy as np
import torch
import torchvision
from PIL import Image

from modules.Device import Device
def _tensor_check_image(image: torch.Tensor) -> None: | |
"""#### Check if the input is a valid tensor image. | |
#### Args: | |
- `image` (torch.Tensor): The input tensor image. | |
""" | |
return | |
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """#### Convert a tensor image to a PIL image.

    #### Args:
        - `image` (torch.Tensor): Input tensor with a leading batch dimension;
          values are scaled by 255 and clipped, so they are expected in [0, 1].

    #### Returns:
        - `Image.Image`: The converted PIL image.
    """
    _tensor_check_image(image)
    # Drop the batch dimension and rescale to 8-bit pixel values.
    array = image.cpu().numpy().squeeze(0)
    array = np.clip(array * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(array)
def general_tensor_resize(image: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """#### Resize a channels-last tensor image with bilinear interpolation.

    #### Args:
        - `image` (torch.Tensor): Input tensor image, channels-last (N, H, W, C).
        - `w` (int): Target width.
        - `h` (int): Target height.

    #### Returns:
        - `torch.Tensor`: The resized tensor image, channels-last.
    """
    _tensor_check_image(image)
    # interpolate() operates on NCHW, so move channels first and back afterwards.
    nchw = image.permute(0, 3, 1, 2)
    resized = torch.nn.functional.interpolate(nchw, size=(h, w), mode="bilinear")
    return resized.permute(0, 2, 3, 1)
def pil2tensor(image: Image.Image) -> torch.Tensor:
    """#### Convert a PIL image to a float tensor with a leading batch dimension.

    #### Args:
        - `image` (Image.Image): The input PIL image.

    #### Returns:
        - `torch.Tensor`: Float32 tensor scaled to [0, 1], shape (1, H, W, C).
    """
    array = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(array)[None, ...]
class TensorBatchBuilder:
    """#### Incrementally build a batch of tensors along the batch dimension."""

    def __init__(self):
        # Accumulated batch; None until the first concat() call.
        self.tensor: torch.Tensor | None = None

    def concat(self, new_tensor: torch.Tensor) -> None:
        """#### Append a tensor to the batch.

        Fix: the previous implementation simply overwrote `self.tensor`,
        silently discarding everything accumulated so far, so multi-image
        batches kept only the last image. The new tensor is now concatenated
        along dim 0.

        #### Args:
            - `new_tensor` (torch.Tensor): The tensor to append (batch-first).
        """
        if self.tensor is None:
            self.tensor = new_tensor
        else:
            self.tensor = torch.cat((self.tensor, new_tensor), dim=0)
# Pillow >= 9.1 moved resampling filters into the Image.Resampling enum.
try:
    LANCZOS = Image.Resampling.LANCZOS
except AttributeError:  # older Pillow exposes the filter at module level
    LANCZOS = Image.LANCZOS
def tensor_resize(image: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """#### Resize a channels-last tensor image.

    Images with at least 3 channels are resized per-sample through PIL using
    LANCZOS resampling; images with fewer channels fall back to a bilinear
    tensor resize.

    #### Args:
        - `image` (torch.Tensor): Input tensor image (N, H, W, C).
        - `w` (int): Target width.
        - `h` (int): Target height.

    #### Returns:
        - `torch.Tensor`: The resized tensor image.
    """
    _tensor_check_image(image)
    if image.shape[3] < 3:
        # Too few channels for a PIL round-trip.
        return general_tensor_resize(image, w, h)

    batch = TensorBatchBuilder()
    for sample in image:
        pil_image = tensor2pil(sample.unsqueeze(0))
        resized = pil_image.resize((w, h), resample=LANCZOS)
        batch.concat(pil2tensor(resized))
    return batch.tensor
def tensor_paste(
    image1: torch.Tensor,
    image2: torch.Tensor,
    left_top: tuple[int, int],
    mask: torch.Tensor,
) -> None:
    """#### Paste `image2` onto `image1` in place, blended by `mask`.

    #### Args:
        - `image1` (torch.Tensor): Base tensor image (N, H, W, C); modified in place.
        - `image2` (torch.Tensor): Tensor image to paste (N, H, W, C).
        - `left_top` (tuple[int, int]): (x, y) position of the paste in `image1`.
        - `mask` (torch.Tensor): Blend mask; 1 selects `image2`, 0 keeps `image1`.
    """
    _tensor_check_image(image1)
    _tensor_check_image(image2)
    _tensor_check_mask(mask)
    x, y = left_top
    _, h1, w1, _ = image1.shape
    _, h2, w2, _ = image2.shape

    # Size of the overlapping patch, clipped to image1's bounds.
    w = min(w1, x + w2) - x
    h = min(h1, y + h2) - y

    # Fix: nothing to paste when the patch lies entirely outside image1;
    # previously this fell through to empty/negative-slice arithmetic.
    if w <= 0 or h <= 0:
        return

    mask = mask[:, :h, :w, :]
    image1[:, y : y + h, x : x + w, :] = (1 - mask) * image1[
        :, y : y + h, x : x + w, :
    ] + mask * image2[:, :h, :w, :]
    return
def tensor_convert_rgba(image: torch.Tensor, prefer_copy: bool = True) -> torch.Tensor:
    """#### Convert a channels-last tensor image to RGBA.

    #### Args:
        - `image` (torch.Tensor): Input tensor image (N, H, W, C).
        - `prefer_copy` (bool, optional): Whether to copy an already-RGBA input
          instead of returning it as-is. Defaults to True.

    #### Returns:
        - `torch.Tensor`: A 4-channel tensor; a fully opaque alpha channel is
          appended when the input has no alpha.
    """
    _tensor_check_image(image)
    # Fix: previously an extra channel was appended even to 4-channel input,
    # producing a bogus 5-channel tensor; RGBA input now passes through.
    if image.shape[-1] == 4:
        return image.clone() if prefer_copy else image
    # Fix: match the input's device and dtype so GPU or non-float32 images
    # do not fail inside torch.cat.
    alpha = torch.ones((*image.shape[:-1], 1), device=image.device, dtype=image.dtype)
    return torch.cat((image, alpha), dim=-1)
def tensor_convert_rgb(image: torch.Tensor, prefer_copy: bool = True) -> torch.Tensor:
    """#### Convert a channels-last tensor image to RGB.

    #### Args:
        - `image` (torch.Tensor): Input tensor image (N, H, W, C).
        - `prefer_copy` (bool, optional): Whether to copy when the alpha
          channel is dropped (avoids returning a view). Defaults to True.

    #### Returns:
        - `torch.Tensor`: A tensor with only the first 3 channels.
    """
    _tensor_check_image(image)
    # Fix: previously RGBA input was returned unchanged, so this function
    # never actually converted anything; drop the alpha channel instead.
    if image.shape[-1] == 4:
        rgb = image[..., :3]
        return rgb.clone() if prefer_copy else rgb
    return image
def tensor_get_size(image: torch.Tensor) -> tuple[int, int]:
    """#### Get the (width, height) of a channels-last tensor image.

    #### Args:
        - `image` (torch.Tensor): The input tensor image (N, H, W, C).

    #### Returns:
        - `tuple[int, int]`: (width, height).
    """
    _tensor_check_image(image)
    height, width = image.shape[1], image.shape[2]
    return (width, height)
def tensor_putalpha(image: torch.Tensor, mask: torch.Tensor) -> None:
    """#### Write a mask into a tensor image's last channel, in place.

    NOTE(review): this overwrites the final channel of `image` — it assumes
    `image` already has an alpha channel (e.g. RGBA); on an RGB tensor it
    would clobber the blue channel. Confirm callers pass RGBA input.

    #### Args:
        - `image` (torch.Tensor): The input tensor image (modified in place).
        - `mask` (torch.Tensor): The mask tensor; only its first channel is used.
    """
    _tensor_check_image(image)
    _tensor_check_mask(mask)
    image[..., -1] = mask[..., 0]
def _tensor_check_mask(mask: torch.Tensor) -> None: | |
"""#### Check if the input is a valid tensor mask. | |
#### Args: | |
- `mask` (torch.Tensor): The input tensor mask. | |
""" | |
return | |
def tensor_gaussian_blur_mask(
    mask: torch.Tensor | np.ndarray, kernel_size: int, sigma: float = 10.0
) -> torch.Tensor:
    """#### Apply a Gaussian blur to a tensor mask.

    #### Args:
        - `mask` (torch.Tensor | np.ndarray): Input mask, (H, W) or (N, H, W, 1).
        - `kernel_size` (int): Half-size of the kernel; the actual kernel spans
          `kernel_size * 2 + 1` pixels.
        - `sigma` (float, optional): Standard deviation of the Gaussian.
          Defaults to 10.0.

    #### Returns:
        - `torch.Tensor`: Blurred mask, (N, H, W, 1), on the caller's original device.
    """
    if isinstance(mask, np.ndarray):
        mask = torch.from_numpy(mask)
    if mask.ndim == 2:
        # Promote (H, W) to (1, H, W, 1).
        mask = mask[None, ..., None]
    _tensor_check_mask(mask)

    kernel_size = kernel_size * 2 + 1

    prev_device = mask.device
    device = Device.get_torch_device()
    # Fix: Tensor.to() is not in-place — its result was previously discarded,
    # so the blur never actually moved to the compute device.
    mask = mask.to(device)

    # GaussianBlur expects NCHW; move the trailing channel into dim 1.
    mask = mask[:, None, ..., 0]
    blurred_mask = torchvision.transforms.GaussianBlur(
        kernel_size=kernel_size, sigma=sigma
    )(mask)
    blurred_mask = blurred_mask[:, 0, ..., None]

    # Fix: the move back to the caller's device was also discarded before.
    return blurred_mask.to(prev_device)
def to_tensor(image: np.ndarray) -> torch.Tensor:
    """#### Wrap a numpy array as a torch tensor.

    The returned tensor shares memory with the input array
    (`torch.from_numpy` semantics).

    #### Args:
        - `image` (np.ndarray): The input numpy array.

    #### Returns:
        - `torch.Tensor`: Tensor viewing the same underlying data.
    """
    tensor = torch.from_numpy(image)
    return tensor