Aatricks's picture
Upload folder using huggingface_hub
d9a2e19 verified
import numpy as np
import torch
from PIL import Image
import torchvision
from modules.Device import Device
def _tensor_check_image(image: torch.Tensor) -> None:
"""#### Check if the input is a valid tensor image.
#### Args:
- `image` (torch.Tensor): The input tensor image.
"""
return
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """#### Turn a single-image tensor into a PIL image.

    The tensor is moved to the CPU, its leading axis is squeezed away, and the
    values are scaled by 255, clamped to `[0, 255]` and cast to `uint8`.

    #### Args:
        - `image` (torch.Tensor): The input tensor; its first axis is removed with `squeeze(0)`.

    #### Returns:
        - `Image.Image`: The PIL image built from the resulting array.
    """
    _tensor_check_image(image)
    pixels = image.cpu().numpy().squeeze(0)
    # Scale to the 8-bit range and clamp before the integer cast.
    pixels = np.clip(255.0 * pixels, 0, 255)
    return Image.fromarray(pixels.astype(np.uint8))
def general_tensor_resize(image: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """#### Bilinearly resize a tensor image to `w` x `h`.

    The channel axis is moved from last to second position for
    `torch.nn.functional.interpolate` and moved back afterwards.

    #### Args:
        - `image` (torch.Tensor): The input tensor image (4-D, channels on the last axis).
        - `w` (int): The target width.
        - `h` (int): The target height.

    #### Returns:
        - `torch.Tensor`: The resized tensor image, channels on the last axis again.
    """
    _tensor_check_image(image)
    channels_first = image.permute(0, 3, 1, 2)
    resized = torch.nn.functional.interpolate(
        channels_first, size=(h, w), mode="bilinear"
    )
    return resized.permute(0, 2, 3, 1)
def pil2tensor(image: Image.Image) -> torch.Tensor:
    """#### Turn a PIL image into a float tensor with a leading batch axis.

    The pixel data is converted to `float32`, divided by 255 and wrapped in a
    tensor with one extra axis inserted at position 0.

    #### Args:
        - `image` (Image.Image): The input PIL image.

    #### Returns:
        - `torch.Tensor`: The converted tensor.
    """
    normalized = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(normalized).unsqueeze(0)
class TensorBatchBuilder:
    """#### Class for building a batch of tensors.

    Tensors passed to `concat` are joined along dimension 0. The `tensor`
    attribute is `None` until the first tensor has been added.
    """

    def __init__(self):
        # Accumulated batch; stays None until the first concat() call.
        self.tensor: torch.Tensor | None = None

    def concat(self, new_tensor: torch.Tensor) -> None:
        """#### Concatenate a new tensor to the batch.

        The first tensor is stored as-is; every later tensor is appended along
        dimension 0 instead of replacing what was collected so far.

        #### Args:
            - `new_tensor` (torch.Tensor): The new tensor to concatenate.
        """
        if self.tensor is None:
            self.tensor = new_tensor
        else:
            self.tensor = torch.cat((self.tensor, new_tensor), dim=0)
# Use `Image.Resampling.LANCZOS` when Pillow provides the `Resampling`
# namespace; otherwise fall back to the module-level `Image.LANCZOS`.
if hasattr(Image, "Resampling"):
    LANCZOS = Image.Resampling.LANCZOS
else:
    LANCZOS = Image.LANCZOS
def tensor_resize(image: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """#### Resize a tensor image to `w` x `h`.

    Images with fewer than 3 channels on the last axis are resized with
    `general_tensor_resize` (bilinear). Otherwise every image along the first
    axis is converted to PIL, resized with the Lanczos filter, converted back
    and handed to a `TensorBatchBuilder`, whose `tensor` attribute is returned.

    #### Args:
        - `image` (torch.Tensor): The input tensor image (4-D, channels on the last axis).
        - `w` (int): The target width.
        - `h` (int): The target height.

    #### Returns:
        - `torch.Tensor`: The resized tensor image.
    """
    _tensor_check_image(image)

    if image.shape[3] < 3:
        return general_tensor_resize(image, w, h)

    batch = TensorBatchBuilder()
    for frame in image:
        # tensor2pil expects a leading axis of size 1, so restore it first.
        pil_frame = tensor2pil(frame.unsqueeze(0))
        resized_frame = pil_frame.resize((w, h), resample=LANCZOS)
        batch.concat(pil2tensor(resized_frame))
    return batch.tensor
def tensor_paste(
    image1: torch.Tensor,
    image2: torch.Tensor,
    left_top: tuple[int, int],
    mask: torch.Tensor,
) -> None:
    """#### Paste one tensor image onto another using a mask.

    `image1` is modified in place: inside the pasted region each value becomes
    `(1 - mask) * image1 + mask * image2`. The region is clipped to the right
    and bottom edges of `image1`; when nothing overlaps, `image1` is left
    untouched. Negative offsets in `left_top` are not clamped here.

    #### Args:
        - `image1` (torch.Tensor): The base tensor image (modified in place).
        - `image2` (torch.Tensor): The tensor image to paste.
        - `left_top` (tuple[int, int]): The `(x, y)` top-left corner where `image2` will be pasted.
        - `mask` (torch.Tensor): The mask tensor used as the blend weight.
    """
    _tensor_check_image(image1)
    _tensor_check_image(image2)
    _tensor_check_mask(mask)

    x, y = left_top
    _, h1, w1, _ = image1.shape
    _, h2, w2, _ = image2.shape

    # calculate image patch size, clipped to image1's right/bottom edges
    w = min(w1, x + w2) - x
    h = min(h1, y + h2) - y

    # An offset at or beyond image1's edge leaves no overlap. Without this
    # guard a negative size would turn the slices below into mismatched
    # shapes instead of an empty paste.
    if w <= 0 or h <= 0:
        return

    mask = mask[:, :h, :w, :]
    region = image1[:, y : y + h, x : x + w, :]
    image1[:, y : y + h, x : x + w, :] = (1 - mask) * region + mask * image2[
        :, :h, :w, :
    ]
def tensor_convert_rgba(image: torch.Tensor, prefer_copy: bool = True) -> torch.Tensor:
    """#### Convert a tensor image to RGBA format.

    A channel filled with ones is appended on the last axis. The alpha plane
    is created with the same dtype and device as `image`, so the
    concatenation also works for images that do not live on the CPU.

    NOTE(review): a channel is appended regardless of how many channels the
    input already has - confirm callers only pass 3-channel images.

    #### Args:
        - `image` (torch.Tensor): The input tensor image.
        - `prefer_copy` (bool, optional): Currently unused; kept for interface compatibility. Defaults to True.

    #### Returns:
        - `torch.Tensor`: A new tensor with one extra, fully opaque channel.
    """
    _tensor_check_image(image)
    # new_ones inherits dtype and device from `image`; a plain torch.ones()
    # would be allocated on the CPU and break torch.cat for GPU tensors.
    alpha = image.new_ones((*image.shape[:-1], 1))
    return torch.cat((image, alpha), dim=-1)
def tensor_convert_rgb(image: torch.Tensor, prefer_copy: bool = True) -> torch.Tensor:
    """#### Return a tensor image as RGB.

    The input is currently passed through unchanged: the very same tensor
    object is returned and no channels are added or removed.

    #### Args:
        - `image` (torch.Tensor): The input tensor image.
        - `prefer_copy` (bool, optional): Currently unused. Defaults to True.

    #### Returns:
        - `torch.Tensor`: The tensor that was passed in.
    """
    _tensor_check_image(image)
    rgb = image
    return rgb
def tensor_get_size(image: torch.Tensor) -> tuple[int, int]:
    """#### Report the spatial size of a tensor image.

    The shape is unpacked as four axes; the second is taken as the height and
    the third as the width.

    #### Args:
        - `image` (torch.Tensor): The input tensor image (must have exactly 4 dimensions).

    #### Returns:
        - `tuple[int, int]`: `(width, height)` of the tensor image.
    """
    _tensor_check_image(image)
    _batch, height, width, _channels = image.shape
    return (width, height)
def tensor_putalpha(image: torch.Tensor, mask: torch.Tensor) -> None:
    """#### Write a mask into the last channel of a tensor image.

    `image` is modified in place: its last channel is overwritten with the
    first channel of `mask`.

    #### Args:
        - `image` (torch.Tensor): The tensor image whose last channel is replaced.
        - `mask` (torch.Tensor): The mask tensor; only index 0 of its last axis is used.
    """
    _tensor_check_image(image)
    _tensor_check_mask(mask)
    alpha = mask[..., 0]
    image[..., -1] = alpha
def _tensor_check_mask(mask: torch.Tensor) -> None:
"""#### Check if the input is a valid tensor mask.
#### Args:
- `mask` (torch.Tensor): The input tensor mask.
"""
return
def tensor_gaussian_blur_mask(
    mask: torch.Tensor | np.ndarray, kernel_size: int, sigma: float = 10.0
) -> torch.Tensor:
    """#### Apply Gaussian blur to a tensor mask.

    NumPy input is wrapped in a tensor first, and a 2-D mask gets a leading
    and a trailing singleton axis. The blur runs on the device reported by
    `Device.get_torch_device()`, and the result is moved back to the device
    the mask originally lived on.

    #### Args:
        - `mask` (torch.Tensor | np.ndarray): The input tensor mask.
        - `kernel_size` (int): Kernel half-size; the kernel actually used is `2 * kernel_size + 1` wide, so it is always odd.
        - `sigma` (float, optional): The standard deviation of the Gaussian kernel. Defaults to 10.0.

    #### Returns:
        - `torch.Tensor`: The blurred tensor mask, with a trailing singleton axis.
    """
    if isinstance(mask, np.ndarray):
        mask = torch.from_numpy(mask)

    if mask.ndim == 2:
        mask = mask[None, ..., None]

    _tensor_check_mask(mask)

    kernel_size = kernel_size * 2 + 1

    prev_device = mask.device
    device = Device.get_torch_device()
    # Tensor.to() returns a new tensor instead of moving in place, so the
    # result has to be rebound for the transfer to take effect.
    mask = mask.to(device)

    # apply gaussian blur
    # Take index 0 of the last axis and insert a channel axis at position 1
    # for the blur, then restore the trailing axis afterwards.
    mask = mask[:, None, ..., 0]
    blurred_mask = torchvision.transforms.GaussianBlur(
        kernel_size=kernel_size, sigma=sigma
    )(mask)
    blurred_mask = blurred_mask[:, 0, ..., None]

    # Hand the result back on the device the caller's mask came from.
    return blurred_mask.to(prev_device)
def to_tensor(image: np.ndarray) -> torch.Tensor:
    """#### Wrap a numpy array in a tensor.

    Uses `torch.from_numpy`, so the returned tensor shares memory with the array.

    #### Args:
        - `image` (np.ndarray): The input numpy array.

    #### Returns:
        - `torch.Tensor`: The tensor backed by `image`.
    """
    tensor = torch.from_numpy(image)
    return tensor