Spaces: Running on Zero
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| import torchvision | |
| from modules.Device import Device | |
| def _tensor_check_image(image: torch.Tensor) -> None: | |
| """#### Check if the input is a valid tensor image. | |
| #### Args: | |
| - `image` (torch.Tensor): The input tensor image. | |
| """ | |
| return | |
def tensor2pil(image: torch.Tensor) -> Image.Image:
    """#### Convert a tensor to a PIL image.

    #### Args:
        - `image` (torch.Tensor): The input tensor — assumed NHWC with a
          batch dimension of 1 and float values in [0, 1] (TODO confirm
          with callers; `tensor_resize` passes exactly that).

    #### Returns:
        - `Image.Image`: The converted PIL image.
    """
    _tensor_check_image(image)
    arr = image.cpu().numpy().squeeze(0) * 255.0
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    return Image.fromarray(arr)
def general_tensor_resize(image: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """#### Resize a tensor image using bilinear interpolation.

    #### Args:
        - `image` (torch.Tensor): The input NHWC tensor image.
        - `w` (int): The target width.
        - `h` (int): The target height.

    #### Returns:
        - `torch.Tensor`: The resized NHWC tensor image.
    """
    _tensor_check_image(image)
    # interpolate() expects NCHW, so round-trip through a channel-first layout.
    nchw = image.permute(0, 3, 1, 2)
    resized = torch.nn.functional.interpolate(nchw, size=(h, w), mode="bilinear")
    return resized.permute(0, 2, 3, 1)
def pil2tensor(image: Image.Image) -> torch.Tensor:
    """#### Convert a PIL image to a tensor.

    #### Args:
        - `image` (Image.Image): The input PIL image.

    #### Returns:
        - `torch.Tensor`: Float tensor scaled to [0, 1] with a leading
          batch dimension of 1.
    """
    arr = np.array(image).astype(np.float32)
    tensor = torch.from_numpy(arr / 255.0)
    return tensor.unsqueeze(0)
class TensorBatchBuilder:
    """#### Class for building a batch of tensors incrementally."""

    def __init__(self):
        # Accumulated batch; None until the first tensor is added.
        self.tensor: torch.Tensor | None = None

    def concat(self, new_tensor: torch.Tensor) -> None:
        """#### Concatenate a new tensor to the batch along dim 0.

        #### Args:
            - `new_tensor` (torch.Tensor): The new tensor to concatenate.
        """
        if self.tensor is None:
            self.tensor = new_tensor
        else:
            # BUG FIX: the old code assigned `self.tensor = new_tensor`,
            # replacing the batch so only the last item survived.
            self.tensor = torch.cat((self.tensor, new_tensor), dim=0)
# Pillow >= 9.1 moved resampling filters into Image.Resampling; fall back to
# the deprecated module-level constant on older Pillow versions.
LANCZOS = Image.Resampling.LANCZOS if hasattr(Image, "Resampling") else Image.LANCZOS
def tensor_resize(image: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """#### Resize a tensor image.

    Images with 3 or more channels are resized per-frame through PIL with
    Lanczos resampling; anything narrower falls back to bilinear
    interpolation via `general_tensor_resize`.

    #### Args:
        - `image` (torch.Tensor): The input NHWC tensor image.
        - `w` (int): The target width.
        - `h` (int): The target height.

    #### Returns:
        - `torch.Tensor`: The resized NHWC tensor image.
    """
    _tensor_check_image(image)
    if image.shape[3] < 3:
        return general_tensor_resize(image, w, h)

    batch = TensorBatchBuilder()
    for frame in image:
        pil_frame = tensor2pil(frame.unsqueeze(0))
        scaled = pil_frame.resize((w, h), resample=LANCZOS)
        batch.concat(pil2tensor(scaled))
    return batch.tensor
def tensor_paste(
    image1: torch.Tensor,
    image2: torch.Tensor,
    left_top: tuple[int, int],
    mask: torch.Tensor,
) -> None:
    """#### Paste one tensor image onto another using a mask, in place.

    #### Args:
        - `image1` (torch.Tensor): The base NHWC tensor image (mutated).
        - `image2` (torch.Tensor): The NHWC tensor image to paste.
        - `left_top` (tuple[int, int]): The (x, y) top-left corner where
          `image2` will be pasted. Assumed non-negative — TODO confirm;
          negative offsets are not clipped here.
        - `mask` (torch.Tensor): The mask tensor; 1 selects `image2`,
          0 keeps `image1`.
    """
    _tensor_check_image(image1)
    _tensor_check_image(image2)
    _tensor_check_mask(mask)

    x, y = left_top
    _, base_h, base_w, _ = image1.shape
    _, patch_h, patch_w, _ = image2.shape

    # Clip the patch so it never extends past the base image's bounds.
    w = min(base_w, x + patch_w) - x
    h = min(base_h, y + patch_h) - y

    mask = mask[:, :h, :w, :]
    region = image1[:, y : y + h, x : x + w, :]
    image1[:, y : y + h, x : x + w, :] = (1 - mask) * region + mask * image2[:, :h, :w, :]
    return
def tensor_convert_rgba(image: torch.Tensor, prefer_copy: bool = True) -> torch.Tensor:
    """#### Convert a tensor image to RGBA format by appending an opaque alpha channel.

    #### Args:
        - `image` (torch.Tensor): The input NHWC tensor image.
        - `prefer_copy` (bool, optional): Whether to prefer copying the tensor.
          Accepted for API symmetry but currently unused. Defaults to True.

    #### Returns:
        - `torch.Tensor`: The converted RGBA tensor image (one extra channel,
          filled with ones, on the input's device and dtype).
    """
    _tensor_check_image(image)
    # BUG FIX: the alpha plane was previously created with the default dtype
    # on the CPU, which makes torch.cat fail for GPU tensors and mismatch
    # non-float32 images. Match the input's dtype and device explicitly.
    alpha = torch.ones((*image.shape[:-1], 1), dtype=image.dtype, device=image.device)
    return torch.cat((image, alpha), dim=-1)
def tensor_convert_rgb(image: torch.Tensor, prefer_copy: bool = True) -> torch.Tensor:
    """#### Convert a tensor image to RGB format.

    Currently a validated pass-through: the input is returned unchanged.

    #### Args:
        - `image` (torch.Tensor): The input NHWC tensor image.
        - `prefer_copy` (bool, optional): Whether to prefer copying the
          tensor. Accepted for API symmetry but unused here. Defaults to True.

    #### Returns:
        - `torch.Tensor`: The same tensor image, unmodified.
    """
    _tensor_check_image(image)
    # NOTE(review): no channel conversion (e.g. alpha stripping) happens
    # here; RGBA inputs are returned as-is.
    return image
def tensor_get_size(image: torch.Tensor) -> tuple[int, int]:
    """#### Get the size of a tensor image.

    #### Args:
        - `image` (torch.Tensor): The input NHWC tensor image.

    #### Returns:
        - `tuple[int, int]`: The (width, height) of the tensor image.
    """
    _tensor_check_image(image)
    height = image.shape[1]
    width = image.shape[2]
    return (width, height)
def tensor_putalpha(image: torch.Tensor, mask: torch.Tensor) -> None:
    """#### Write a mask into a tensor image's last channel, in place.

    #### Args:
        - `image` (torch.Tensor): The NHWC tensor image to mutate. Its final
          channel is overwritten — assumes the image already carries an
          alpha channel (TODO confirm: no channel is added here).
        - `mask` (torch.Tensor): The NHWC mask tensor; its first channel
          supplies the alpha values.
    """
    _tensor_check_image(image)
    _tensor_check_mask(mask)
    alpha = mask[..., 0]
    image[..., -1] = alpha
| def _tensor_check_mask(mask: torch.Tensor) -> None: | |
| """#### Check if the input is a valid tensor mask. | |
| #### Args: | |
| - `mask` (torch.Tensor): The input tensor mask. | |
| """ | |
| return | |
def tensor_gaussian_blur_mask(
    mask: torch.Tensor | np.ndarray, kernel_size: int, sigma: float = 10.0
) -> torch.Tensor:
    """#### Apply Gaussian blur to a tensor mask.

    #### Args:
        - `mask` (torch.Tensor | np.ndarray): The input mask, either NHWC
          with a single trailing channel or a plain 2-D (H, W) array.
        - `kernel_size` (int): Radius of the Gaussian kernel; the kernel
          size actually used is `kernel_size * 2 + 1` (always odd, as
          torchvision requires).
        - `sigma` (float, optional): The standard deviation of the Gaussian
          kernel. Defaults to 10.0.

    #### Returns:
        - `torch.Tensor`: The blurred NHWC mask, returned on the mask's
          original device.
    """
    if isinstance(mask, np.ndarray):
        mask = torch.from_numpy(mask)

    if mask.ndim == 2:
        # Promote (H, W) to NHWC with batch and channel dims of 1.
        mask = mask[None, ..., None]

    _tensor_check_mask(mask)

    kernel_size = kernel_size * 2 + 1

    prev_device = mask.device
    device = Device.get_torch_device()
    # BUG FIX: Tensor.to() is not in-place — the old code discarded its
    # result, so the mask was never moved to the compute device.
    mask = mask.to(device)

    # NCHW layout for torchvision, blur, then back to NHWC (one channel).
    mask = mask[:, None, ..., 0]
    blurred_mask = torchvision.transforms.GaussianBlur(
        kernel_size=kernel_size, sigma=sigma
    )(mask)
    blurred_mask = blurred_mask[:, 0, ..., None]

    # BUG FIX: likewise reassign when restoring the original device.
    return blurred_mask.to(prev_device)
def to_tensor(image: np.ndarray) -> torch.Tensor:
    """#### Convert a numpy array to a tensor.

    Uses `torch.from_numpy`, so the returned tensor shares memory with the
    input array rather than copying it.

    #### Args:
        - `image` (np.ndarray): The input numpy array.

    #### Returns:
        - `torch.Tensor`: The converted tensor.
    """
    tensor = torch.from_numpy(image)
    return tensor