import math

import numpy as np
import torch
from PIL import Image


def get_tiled_scale_steps(
    width: int, height: int, tile_x: int, tile_y: int, overlap: int
) -> int:
    """#### Calculate the number of steps required for tiled scaling.

    #### Args:
        - `width` (int): The width of the image.
        - `height` (int): The height of the image.
        - `tile_x` (int): The width of each tile.
        - `tile_y` (int): The height of each tile.
        - `overlap` (int): The overlap between tiles.

    #### Returns:
        - `int`: The number of steps required for tiled scaling.
    """
    return math.ceil(height / (tile_y - overlap)) * math.ceil(width / (tile_x - overlap))
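

# Illustrative sketch (not part of the original module): for a 512x512 input with
# 64x64 tiles and an overlap of 8, each axis advances in strides of 64 - 8 = 56,
# so each axis needs ceil(512 / 56) = 10 tiles and the total is 10 * 10 = 100 steps.
def _example_tile_step_count() -> int:
    return get_tiled_scale_steps(width=512, height=512, tile_x=64, tile_y=64, overlap=8)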


def tiled_scale(
    samples: torch.Tensor,
    function: callable,
    tile_x: int = 64,
    tile_y: int = 64,
    overlap: int = 8,
    upscale_amount: float = 4,
    out_channels: int = 3,
    pbar: any = None,
) -> torch.Tensor:
    """#### Perform tiled scaling on a batch of samples.

    #### Args:
        - `samples` (torch.Tensor): The input samples.
        - `function` (callable): The function applied to each tile; it must scale a tile by `upscale_amount`.
        - `tile_x` (int, optional): The width of each tile. Defaults to 64.
        - `tile_y` (int, optional): The height of each tile. Defaults to 64.
        - `overlap` (int, optional): The overlap between tiles. Defaults to 8.
        - `upscale_amount` (float, optional): The upscale amount. Defaults to 4.
        - `out_channels` (int, optional): The number of output channels. Defaults to 3.
        - `pbar` (any, optional): The progress bar. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The scaled output tensor.
    """
    # Output canvas for the whole batch, kept on the CPU.
    output = torch.empty(
        (
            samples.shape[0],
            out_channels,
            round(samples.shape[2] * upscale_amount),
            round(samples.shape[3] * upscale_amount),
        ),
        device="cpu",
    )
    for b in range(samples.shape[0]):
        s = samples[b : b + 1]
        # `out` accumulates the feathered tile outputs; `out_div` accumulates the
        # matching feather weights so overlapping regions can be normalized.
        out = torch.zeros(
            (
                s.shape[0],
                out_channels,
                round(s.shape[2] * upscale_amount),
                round(s.shape[3] * upscale_amount),
            ),
            device="cpu",
        )
        out_div = torch.zeros_like(out)
        for y in range(0, s.shape[2], tile_y - overlap):
            for x in range(0, s.shape[3], tile_x - overlap):
                s_in = s[:, :, y : y + tile_y, x : x + tile_x]
                ps = function(s_in).cpu()
                # Feather the tile edges so that overlapping tiles blend smoothly.
                mask = torch.ones_like(ps)
                feather = round(overlap * upscale_amount)
                for t in range(feather):
                    weight = (1.0 / feather) * (t + 1)
                    mask[:, :, t : 1 + t, :] *= weight
                    mask[:, :, mask.shape[2] - 1 - t : mask.shape[2] - t, :] *= weight
                    mask[:, :, :, t : 1 + t] *= weight
                    mask[:, :, :, mask.shape[3] - 1 - t : mask.shape[3] - t] *= weight
                out[
                    :,
                    :,
                    round(y * upscale_amount) : round((y + tile_y) * upscale_amount),
                    round(x * upscale_amount) : round((x + tile_x) * upscale_amount),
                ] += ps * mask
                out_div[
                    :,
                    :,
                    round(y * upscale_amount) : round((y + tile_y) * upscale_amount),
                    round(x * upscale_amount) : round((x + tile_x) * upscale_amount),
                ] += mask
                # Advance the external progress bar, if one was provided, once per tile.
                if pbar is not None:
                    pbar.update(1)
        # Normalize by the accumulated feather weights to average the overlaps.
        output[b : b + 1] = out / out_div
    return output
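

# Illustrative usage sketch (not part of the original module): tiled_scale applied to a
# random 3-channel image, with torch.nn.functional.interpolate standing in for a real
# upscaling model. Any callable that maps a (B, C, h, w) tile to a tile scaled by
# `upscale_amount` can be used the same way.
def _example_tiled_upscale() -> torch.Tensor:
    def fake_upscaler(tile: torch.Tensor) -> torch.Tensor:
        # Nearest-neighbour 4x upscale as a stand-in for a neural upscaler.
        return torch.nn.functional.interpolate(tile, scale_factor=4, mode="nearest")

    image = torch.rand(1, 3, 128, 128)
    return tiled_scale(image, fake_upscaler, tile_x=64, tile_y=64, overlap=8, upscale_amount=4)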


def flatten(img: Image.Image, bgcolor: str) -> Image.Image:
    """#### Replace transparency with a background color.

    #### Args:
        - `img` (Image.Image): The input image.
        - `bgcolor` (str): The background color.

    #### Returns:
        - `Image.Image`: The image with transparency replaced by the background color.
    """
    if img.mode == "RGB":
        return img
    return Image.alpha_composite(
        Image.new("RGBA", img.size, bgcolor), img.convert("RGBA")
    ).convert("RGB")
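

# Illustrative sketch (not part of the original module): flattening a semi-transparent
# RGBA image onto a white background before RGB-only processing.
def _example_flatten() -> Image.Image:
    rgba = Image.new("RGBA", (64, 64), (255, 0, 0, 128))
    return flatten(rgba, "#ffffff")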


BLUR_KERNEL_SIZE = 15


def tensor_to_pil(img_tensor: torch.Tensor, batch_index: int = 0) -> Image.Image:
    """#### Convert a tensor to a PIL image.

    #### Args:
        - `img_tensor` (torch.Tensor): The input tensor.
        - `batch_index` (int, optional): The batch index. Defaults to 0.

    #### Returns:
        - `Image.Image`: The converted PIL image.
    """
    img_tensor = img_tensor[batch_index].unsqueeze(0)
    i = 255.0 * img_tensor.cpu().numpy()
    img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8).squeeze())
    return img


def pil_to_tensor(image: Image.Image) -> torch.Tensor:
    """#### Convert a PIL image to a tensor.

    #### Args:
        - `image` (Image.Image): The input PIL image.

    #### Returns:
        - `torch.Tensor`: The converted tensor.
    """
    image = np.array(image).astype(np.float32) / 255.0
    image = torch.from_numpy(image).unsqueeze(0)
    return image
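

# Illustrative round-trip sketch (not part of the original module): pil_to_tensor
# produces a (1, H, W, C) float tensor in [0, 1], and tensor_to_pil converts such a
# tensor back into a PIL image.
def _example_roundtrip() -> Image.Image:
    img = Image.new("RGB", (32, 32), "blue")
    tensor = pil_to_tensor(img)   # shape (1, 32, 32, 3), values in [0, 1]
    return tensor_to_pil(tensor)  # back to a 32x32 PIL image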


def get_crop_region(mask: Image.Image, pad: int = 0) -> tuple:
    """#### Get the coordinates of the white rectangular mask region.

    #### Args:
        - `mask` (Image.Image): The input mask image in 'L' mode.
        - `pad` (int, optional): The padding to apply. Defaults to 0.

    #### Returns:
        - `tuple`: The coordinates of the crop region.
    """
    coordinates = mask.getbbox()
    if coordinates is not None:
        x1, y1, x2, y2 = coordinates
    else:
        x1, y1, x2, y2 = mask.width, mask.height, 0, 0
    # Apply padding
    x1 = max(x1 - pad, 0)
    y1 = max(y1 - pad, 0)
    x2 = min(x2 + pad, mask.width)
    y2 = min(y2 + pad, mask.height)
    return fix_crop_region((x1, y1, x2, y2), (mask.width, mask.height))
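

# Illustrative sketch (not part of the original module): locating the white region of an
# inpainting-style mask with a small amount of padding around it.
def _example_crop_region() -> tuple:
    mask = Image.new("L", (256, 256), 0)
    # Paint a white rectangle into the mask to stand in for a painted area.
    mask.paste(255, (64, 64, 128, 128))
    return get_crop_region(mask, pad=16)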


def fix_crop_region(region: tuple, image_size: tuple) -> tuple:
    """#### Remove the extra pixel added by the get_crop_region function.

    #### Args:
        - `region` (tuple): The crop region coordinates.
        - `image_size` (tuple): The size of the image.

    #### Returns:
        - `tuple`: The fixed crop region coordinates.
    """
    image_width, image_height = image_size
    x1, y1, x2, y2 = region
    if x2 < image_width:
        x2 -= 1
    if y2 < image_height:
        y2 -= 1
    return x1, y1, x2, y2


def expand_crop(
    region: tuple, width: int, height: int, target_width: int, target_height: int
) -> tuple:
    """#### Expand a crop region to a specified target size.

    #### Args:
        - `region` (tuple): The crop region coordinates.
        - `width` (int): The width of the image.
        - `height` (int): The height of the image.
        - `target_width` (int): The desired width of the crop region.
        - `target_height` (int): The desired height of the crop region.

    #### Returns:
        - `tuple`: The expanded crop region coordinates and the target size.
    """
    x1, y1, x2, y2 = region
    actual_width = x2 - x1
    actual_height = y2 - y1

    # Try to expand the region to the right by half of the width difference.
    width_diff = target_width - actual_width
    x2 = min(x2 + width_diff // 2, width)
    # Expand to the left by the remaining difference, including any pixels that
    # could not be added on the right.
    width_diff = target_width - (x2 - x1)
    x1 = max(x1 - width_diff, 0)
    # Try the right side again with whatever is still missing.
    width_diff = target_width - (x2 - x1)
    x2 = min(x2 + width_diff, width)

    # Try to expand the region downwards by half of the height difference.
    height_diff = target_height - actual_height
    y2 = min(y2 + height_diff // 2, height)
    # Expand upwards by the remaining difference, including any pixels that
    # could not be added at the bottom.
    height_diff = target_height - (y2 - y1)
    y1 = max(y1 - height_diff, 0)
    # Try the bottom again with whatever is still missing.
    height_diff = target_height - (y2 - y1)
    y2 = min(y2 + height_diff, height)

    return (x1, y1, x2, y2), (target_width, target_height)
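

# Illustrative sketch (not part of the original module): growing a tight mask crop to a
# fixed tile size while staying inside the image bounds. For the values below the 80x60
# region is expanded to (176, 166, 304, 294), i.e. exactly 128x128.
def _example_expand_crop() -> tuple:
    region = (200, 200, 280, 260)  # an 80x60 crop inside a 512x512 image
    return expand_crop(region, 512, 512, target_width=128, target_height=128)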


def crop_cond(
    cond: list,
    region: tuple,
    init_size: tuple,
    canvas_size: tuple,
    tile_size: tuple,
    w_pad: int = 0,
    h_pad: int = 0,
) -> list:
    """#### Crop conditioning data to match a specific region.

    #### Args:
        - `cond` (list): The conditioning data.
        - `region` (tuple): The crop region coordinates.
        - `init_size` (tuple): The initial size of the image.
        - `canvas_size` (tuple): The size of the canvas.
        - `tile_size` (tuple): The size of the tile.
        - `w_pad` (int, optional): The width padding. Defaults to 0.
        - `h_pad` (int, optional): The height padding. Defaults to 0.

    #### Returns:
        - `list`: The cropped conditioning data.
    """
    # In this simplified version the conditioning entries are copied unchanged for each
    # tile; the region, size, and padding arguments are kept for API compatibility but
    # no spatial conditioning is actually cropped.
    cropped = []
    for emb, x in cond:
        cond_dict = x.copy()
        n = [emb, cond_dict]
        cropped.append(n)
    return cropped
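

# Illustrative sketch (not part of the original module): the conditioning format assumed
# here is a list of [embedding, options_dict] pairs; the embedding shape and the
# "strength" key are made up purely for illustration.
def _example_crop_cond() -> list:
    cond = [[torch.zeros(1, 77, 768), {"strength": 1.0}]]
    return crop_cond(
        cond,
        region=(0, 0, 64, 64),
        init_size=(512, 512),
        canvas_size=(512, 512),
        tile_size=(64, 64),
    )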