import math
from typing import Any, Callable
import numpy as np
import torch
from PIL import Image
def get_tiled_scale_steps(width: int, height: int, tile_x: int, tile_y: int, overlap: int) -> int:
"""#### Calculate the number of steps required for tiled scaling.
#### Args:
- `width` (int): The width of the image.
- `height` (int): The height of the image.
- `tile_x` (int): The width of each tile.
- `tile_y` (int): The height of each tile.
- `overlap` (int): The overlap between tiles.
#### Returns:
- `int`: The number of steps required for tiled scaling.
"""
return math.ceil((height / (tile_y - overlap))) * math.ceil(
(width / (tile_x - overlap))
)
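# Example (illustrative): a 512x512 image tiled with 64x64 tiles and an overlap
# of 8 advances in strides of 56 pixels, so
# get_tiled_scale_steps(512, 512, 64, 64, 8) == ceil(512 / 56) ** 2 == 100.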
@torch.inference_mode()
def tiled_scale(
samples: torch.Tensor,
    function: Callable,
tile_x: int = 64,
tile_y: int = 64,
overlap: int = 8,
upscale_amount: float = 4,
out_channels: int = 3,
    pbar: Any = None,
) -> torch.Tensor:
"""#### Perform tiled scaling on a batch of samples.
#### Args:
- `samples` (torch.Tensor): The input samples.
- `function` (callable): The function to apply to each tile.
- `tile_x` (int, optional): The width of each tile. Defaults to 64.
- `tile_y` (int, optional): The height of each tile. Defaults to 64.
- `overlap` (int, optional): The overlap between tiles. Defaults to 8.
- `upscale_amount` (float, optional): The upscale amount. Defaults to 4.
- `out_channels` (int, optional): The number of output channels. Defaults to 3.
- `pbar` (any, optional): The progress bar. Defaults to None.
#### Returns:
- `torch.Tensor`: The scaled output tensor.
"""
output = torch.empty(
(
samples.shape[0],
out_channels,
round(samples.shape[2] * upscale_amount),
round(samples.shape[3] * upscale_amount),
),
device="cpu",
)
for b in range(samples.shape[0]):
s = samples[b : b + 1]
out = torch.zeros(
(
s.shape[0],
out_channels,
round(s.shape[2] * upscale_amount),
round(s.shape[3] * upscale_amount),
),
device="cpu",
)
out_div = torch.zeros(
(
s.shape[0],
out_channels,
round(s.shape[2] * upscale_amount),
round(s.shape[3] * upscale_amount),
),
device="cpu",
)
for y in range(0, s.shape[2], tile_y - overlap):
for x in range(0, s.shape[3], tile_x - overlap):
s_in = s[:, :, y : y + tile_y, x : x + tile_x]
ps = function(s_in).cpu()
mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount)
for t in range(feather):
mask[:, :, t : 1 + t, :] *= (1.0 / feather) * (t + 1)
mask[:, :, mask.shape[2] - 1 - t : mask.shape[2] - t, :] *= (
1.0 / feather
) * (t + 1)
mask[:, :, :, t : 1 + t] *= (1.0 / feather) * (t + 1)
mask[:, :, :, mask.shape[3] - 1 - t : mask.shape[3] - t] *= (
1.0 / feather
) * (t + 1)
out[
:,
:,
round(y * upscale_amount) : round((y + tile_y) * upscale_amount),
round(x * upscale_amount) : round((x + tile_x) * upscale_amount),
] += ps * mask
out_div[
:,
:,
round(y * upscale_amount) : round((y + tile_y) * upscale_amount),
round(x * upscale_amount) : round((x + tile_x) * upscale_amount),
] += mask
output[b : b + 1] = out / out_div
return output
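# Minimal usage sketch (not part of the original API): `_demo_tiled_scale` is a
# hypothetical helper that drives tiled_scale with a stand-in "model" doing a
# nearest-neighbour 4x upscale per tile, which is enough to exercise the tiling,
# feathering, and blending logic.
def _demo_tiled_scale() -> torch.Tensor:
    latent = torch.rand(1, 4, 64, 64)  # one 64x64 sample with 4 channels

    def fake_upscaler(tile: torch.Tensor) -> torch.Tensor:
        # Stand-in for a real upscaling model: 4x nearest-neighbour interpolation.
        return torch.nn.functional.interpolate(tile, scale_factor=4, mode="nearest")

    # Output shape is (1, 4, 256, 256); out_channels must match what the function returns.
    return tiled_scale(
        latent,
        fake_upscaler,
        tile_x=32,
        tile_y=32,
        overlap=8,
        upscale_amount=4,
        out_channels=4,
    )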
def flatten(img: Image.Image, bgcolor: str) -> Image.Image:
"""#### Replace transparency with a background color.
#### Args:
- `img` (Image.Image): The input image.
- `bgcolor` (str): The background color.
#### Returns:
- `Image.Image`: The image with transparency replaced by the background color.
"""
    if img.mode in ("RGB",):
return img
return Image.alpha_composite(Image.new("RGBA", img.size, bgcolor), img).convert(
"RGB"
)
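# Illustrative usage (hypothetical file name): composite an RGBA image onto a
# solid white background before processing that assumes three channels:
#     rgb = flatten(Image.open("input.png").convert("RGBA"), "#ffffff")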
BLUR_KERNEL_SIZE = 15
def tensor_to_pil(img_tensor: torch.Tensor, batch_index: int = 0) -> Image.Image:
"""#### Convert a tensor to a PIL image.
#### Args:
- `img_tensor` (torch.Tensor): The input tensor.
- `batch_index` (int, optional): The batch index. Defaults to 0.
#### Returns:
- `Image.Image`: The converted PIL image.
"""
img_tensor = img_tensor[batch_index].unsqueeze(0)
i = 255.0 * img_tensor.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8).squeeze())
return img
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
"""#### Convert a PIL image to a tensor.
#### Args:
- `image` (Image.Image): The input PIL image.
#### Returns:
- `torch.Tensor`: The converted tensor.
"""
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image).unsqueeze(0)
return image
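# Round-trip sketch (illustrative only): pil_to_tensor yields a (1, H, W, C)
# float32 tensor scaled to [0, 1], and tensor_to_pil maps such a tensor back to
# a PIL image, so the two are inverses up to uint8 quantisation:
#     tensor = pil_to_tensor(img)          # shape (1, H, W, 3) for an RGB input
#     restored = tensor_to_pil(tensor, 0)  # PIL.Image.Image in "RGB" mode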
def get_crop_region(mask: Image.Image, pad: int = 0) -> tuple:
"""#### Get the coordinates of the white rectangular mask region.
#### Args:
- `mask` (Image.Image): The input mask image in 'L' mode.
- `pad` (int, optional): The padding to apply. Defaults to 0.
#### Returns:
- `tuple`: The coordinates of the crop region.
"""
coordinates = mask.getbbox()
if coordinates is not None:
x1, y1, x2, y2 = coordinates
else:
x1, y1, x2, y2 = mask.width, mask.height, 0, 0
# Apply padding
x1 = max(x1 - pad, 0)
y1 = max(y1 - pad, 0)
x2 = min(x2 + pad, mask.width)
y2 = min(y2 + pad, mask.height)
return fix_crop_region((x1, y1, x2, y2), (mask.width, mask.height))
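# Example (illustrative numbers): for a 512x512 "L"-mode mask whose white area
# has bounding box (100, 120, 200, 220), get_crop_region(mask, pad=16) pads to
# (84, 104, 216, 236) and, after fix_crop_region, returns (84, 104, 215, 235).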
def fix_crop_region(region: tuple, image_size: tuple) -> tuple:
"""#### Remove the extra pixel added by the get_crop_region function.
#### Args:
- `region` (tuple): The crop region coordinates.
- `image_size` (tuple): The size of the image.
#### Returns:
- `tuple`: The fixed crop region coordinates.
"""
image_width, image_height = image_size
x1, y1, x2, y2 = region
if x2 < image_width:
x2 -= 1
if y2 < image_height:
y2 -= 1
return x1, y1, x2, y2
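# Example: fix_crop_region((10, 20, 100, 200), (512, 512)) returns (10, 20, 99, 199);
# a coordinate that already touches the image border is left unchanged.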
def expand_crop(region: tuple, width: int, height: int, target_width: int, target_height: int) -> tuple:
"""#### Expand a crop region to a specified target size.
#### Args:
- `region` (tuple): The crop region coordinates.
- `width` (int): The width of the image.
- `height` (int): The height of the image.
- `target_width` (int): The desired width of the crop region.
- `target_height` (int): The desired height of the crop region.
#### Returns:
- `tuple`: The expanded crop region coordinates and the target size.
"""
x1, y1, x2, y2 = region
actual_width = x2 - x1
actual_height = y2 - y1
    # Try to expand the region to the right by half of the width difference
width_diff = target_width - actual_width
x2 = min(x2 + width_diff // 2, width)
    # Expand the region to the left by the remaining difference, including any pixels that could not be added on the right
width_diff = target_width - (x2 - x1)
x1 = max(x1 - width_diff, 0)
# Try the right again
width_diff = target_width - (x2 - x1)
x2 = min(x2 + width_diff, width)
    # Try to expand the region downward by half of the height difference
height_diff = target_height - actual_height
y2 = min(y2 + height_diff // 2, height)
    # Expand the region upward by the remaining difference, including any pixels that could not be added at the bottom
height_diff = target_height - (y2 - y1)
y1 = max(y1 - height_diff, 0)
# Try the bottom again
height_diff = target_height - (y2 - y1)
y2 = min(y2 + height_diff, height)
return (x1, y1, x2, y2), (target_width, target_height)
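# Example (illustrative numbers): growing a 100x100 crop inside a 512x512 image
# to a 128x128 target expands symmetrically where the image allows it:
#     expand_crop((84, 104, 184, 204), 512, 512, 128, 128)
#     # -> ((70, 90, 198, 218), (128, 128))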
def crop_cond(cond: list, region: tuple, init_size: tuple, canvas_size: tuple, tile_size: tuple, w_pad: int = 0, h_pad: int = 0) -> list:
"""#### Crop conditioning data to match a specific region.
#### Args:
- `cond` (list): The conditioning data.
- `region` (tuple): The crop region coordinates.
- `init_size` (tuple): The initial size of the image.
- `canvas_size` (tuple): The size of the canvas.
- `tile_size` (tuple): The size of the tile.
- `w_pad` (int, optional): The width padding. Defaults to 0.
- `h_pad` (int, optional): The height padding. Defaults to 0.
#### Returns:
- `list`: The cropped conditioning data.
"""
cropped = []
for emb, x in cond:
cond_dict = x.copy()
n = [emb, cond_dict]
cropped.append(n)
return cropped