import math
from typing import Any, Callable

import numpy as np
import torch
from PIL import Image
def get_tiled_scale_steps(width: int, height: int, tile_x: int, tile_y: int, overlap: int) -> int:
"""#### Calculate the number of steps required for tiled scaling.
#### Args:
- `width` (int): The width of the image.
- `height` (int): The height of the image.
- `tile_x` (int): The width of each tile.
- `tile_y` (int): The height of each tile.
- `overlap` (int): The overlap between tiles.
#### Returns:
- `int`: The number of steps required for tiled scaling.
"""
    return math.ceil(height / (tile_y - overlap)) * math.ceil(width / (tile_x - overlap))
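
# Worked example (illustrative numbers): a 512x512 image scanned with 64x64
# tiles and an 8-pixel overlap advances 56 pixels per tile in each direction,
# so get_tiled_scale_steps(512, 512, 64, 64, 8) == ceil(512 / 56) ** 2 == 100.
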
@torch.inference_mode()
def tiled_scale(
samples: torch.Tensor,
    function: Callable,
tile_x: int = 64,
tile_y: int = 64,
overlap: int = 8,
upscale_amount: float = 4,
out_channels: int = 3,
    pbar: Any = None,
) -> torch.Tensor:
"""#### Perform tiled scaling on a batch of samples.
#### Args:
- `samples` (torch.Tensor): The input samples.
- `function` (callable): The function to apply to each tile.
- `tile_x` (int, optional): The width of each tile. Defaults to 64.
- `tile_y` (int, optional): The height of each tile. Defaults to 64.
- `overlap` (int, optional): The overlap between tiles. Defaults to 8.
- `upscale_amount` (float, optional): The upscale amount. Defaults to 4.
- `out_channels` (int, optional): The number of output channels. Defaults to 3.
- `pbar` (any, optional): The progress bar. Defaults to None.
#### Returns:
- `torch.Tensor`: The scaled output tensor.
"""
output = torch.empty(
(
samples.shape[0],
out_channels,
round(samples.shape[2] * upscale_amount),
round(samples.shape[3] * upscale_amount),
),
device="cpu",
)
for b in range(samples.shape[0]):
s = samples[b : b + 1]
out = torch.zeros(
(
s.shape[0],
out_channels,
round(s.shape[2] * upscale_amount),
round(s.shape[3] * upscale_amount),
),
device="cpu",
)
out_div = torch.zeros(
(
s.shape[0],
out_channels,
round(s.shape[2] * upscale_amount),
round(s.shape[3] * upscale_amount),
),
device="cpu",
)
for y in range(0, s.shape[2], tile_y - overlap):
for x in range(0, s.shape[3], tile_x - overlap):
s_in = s[:, :, y : y + tile_y, x : x + tile_x]
ps = function(s_in).cpu()
mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount)
for t in range(feather):
mask[:, :, t : 1 + t, :] *= (1.0 / feather) * (t + 1)
mask[:, :, mask.shape[2] - 1 - t : mask.shape[2] - t, :] *= (
1.0 / feather
) * (t + 1)
mask[:, :, :, t : 1 + t] *= (1.0 / feather) * (t + 1)
mask[:, :, :, mask.shape[3] - 1 - t : mask.shape[3] - t] *= (
1.0 / feather
) * (t + 1)
out[
:,
:,
round(y * upscale_amount) : round((y + tile_y) * upscale_amount),
round(x * upscale_amount) : round((x + tile_x) * upscale_amount),
] += ps * mask
                out_div[
                    :,
                    :,
                    round(y * upscale_amount) : round((y + tile_y) * upscale_amount),
                    round(x * upscale_amount) : round((x + tile_x) * upscale_amount),
                ] += mask
                # Advance the (optional) progress bar once per processed tile.
                if pbar is not None:
                    pbar.update(1)
output[b : b + 1] = out / out_div
return output
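
# Usage sketch (hypothetical: `upscale_fn` stands in for any callable that
# maps a [B, C, H, W] tile to its 4x-upscaled counterpart, such as an ESRGAN
# forward pass):
#   samples = torch.rand(1, 3, 128, 128)
#   upscaled = tiled_scale(samples, upscale_fn, tile_x=64, tile_y=64,
#                          overlap=8, upscale_amount=4)
#   upscaled.shape  # torch.Size([1, 3, 512, 512])
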
def flatten(img: Image.Image, bgcolor: str) -> Image.Image:
"""#### Replace transparency with a background color.
#### Args:
- `img` (Image.Image): The input image.
- `bgcolor` (str): The background color.
#### Returns:
- `Image.Image`: The image with transparency replaced by the background color.
"""
    if img.mode == "RGB":
        return img
    # alpha_composite requires two RGBA inputs, so normalize the mode first.
    return Image.alpha_composite(
        Image.new("RGBA", img.size, bgcolor), img.convert("RGBA")
    ).convert("RGB")
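
# Example: flattening a semi-transparent image onto white before JPEG export,
#   flatten(Image.new("RGBA", (64, 64), (255, 0, 0, 128)), "#ffffff").mode
# returns "RGB".
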
BLUR_KERNEL_SIZE = 15

def tensor_to_pil(img_tensor: torch.Tensor, batch_index: int = 0) -> Image.Image:
"""#### Convert a tensor to a PIL image.
#### Args:
    - `img_tensor` (torch.Tensor): The input image batch, shaped [B, H, W, C] with values in [0, 1].
- `batch_index` (int, optional): The batch index. Defaults to 0.
#### Returns:
- `Image.Image`: The converted PIL image.
"""
img_tensor = img_tensor[batch_index].unsqueeze(0)
i = 255.0 * img_tensor.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8).squeeze())
return img
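
# Example (assumes the [B, H, W, C] float layout in [0, 1] used throughout
# this module): tensor_to_pil(torch.rand(2, 64, 48, 3), batch_index=1).size
# is (48, 64), since PIL reports (width, height).
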
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
"""#### Convert a PIL image to a tensor.
#### Args:
- `image` (Image.Image): The input PIL image.
#### Returns:
    - `torch.Tensor`: The converted tensor, shaped [1, H, W, C] with values in [0, 1].
"""
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image).unsqueeze(0)
return image
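
# Round-trip sketch: pil_to_tensor(tensor_to_pil(t)) restores a [1, H, W, C]
# tensor equal to t up to uint8 quantization error (at most 1/255 per channel).
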
def get_crop_region(mask: Image.Image, pad: int = 0) -> tuple:
"""#### Get the coordinates of the white rectangular mask region.
#### Args:
- `mask` (Image.Image): The input mask image in 'L' mode.
- `pad` (int, optional): The padding to apply. Defaults to 0.
#### Returns:
- `tuple`: The coordinates of the crop region.
"""
coordinates = mask.getbbox()
if coordinates is not None:
x1, y1, x2, y2 = coordinates
    else:
        # No nonzero pixels: fall back to an inverted, empty region.
        x1, y1, x2, y2 = mask.width, mask.height, 0, 0
# Apply padding
x1 = max(x1 - pad, 0)
y1 = max(y1 - pad, 0)
x2 = min(x2 + pad, mask.width)
y2 = min(y2 + pad, mask.height)
return fix_crop_region((x1, y1, x2, y2), (mask.width, mask.height))
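
# Example: for a 256x256 'L'-mode mask whose white box spans (64, 64)-(192, 192),
# getbbox() yields (64, 64, 192, 192); pad=8 widens it to (56, 56, 200, 200),
# and fix_crop_region trims the exclusive edges to return (56, 56, 199, 199).
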
def fix_crop_region(region: tuple, image_size: tuple) -> tuple:
"""#### Remove the extra pixel added by the get_crop_region function.
#### Args:
- `region` (tuple): The crop region coordinates.
- `image_size` (tuple): The size of the image.
#### Returns:
- `tuple`: The fixed crop region coordinates.
"""
image_width, image_height = image_size
x1, y1, x2, y2 = region
if x2 < image_width:
x2 -= 1
if y2 < image_height:
y2 -= 1
return x1, y1, x2, y2
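
# Example: fix_crop_region((0, 0, 100, 100), (200, 200)) == (0, 0, 99, 99),
# while a region already touching the image border is returned unchanged.
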
def expand_crop(region: tuple, width: int, height: int, target_width: int, target_height: int) -> tuple:
"""#### Expand a crop region to a specified target size.
#### Args:
- `region` (tuple): The crop region coordinates.
- `width` (int): The width of the image.
- `height` (int): The height of the image.
- `target_width` (int): The desired width of the crop region.
- `target_height` (int): The desired height of the crop region.
#### Returns:
- `tuple`: The expanded crop region coordinates and the target size.
"""
x1, y1, x2, y2 = region
actual_width = x2 - x1
actual_height = y2 - y1
    # Try to expand the region to the right by half the difference
width_diff = target_width - actual_width
x2 = min(x2 + width_diff // 2, width)
    # Expand the region to the left by the remaining difference, including any pixels that could not be added on the right
width_diff = target_width - (x2 - x1)
x1 = max(x1 - width_diff, 0)
# Try the right again
width_diff = target_width - (x2 - x1)
x2 = min(x2 + width_diff, width)
    # Try to expand the region downward by half the difference
height_diff = target_height - actual_height
y2 = min(y2 + height_diff // 2, height)
    # Expand the region upward by the remaining difference, including any pixels that could not be added at the bottom
height_diff = target_height - (y2 - y1)
y1 = max(y1 - height_diff, 0)
# Try the bottom again
height_diff = target_height - (y2 - y1)
y2 = min(y2 + height_diff, height)
return (x1, y1, x2, y2), (target_width, target_height)
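
# Worked example: expanding the 32x32 region (10, 10, 42, 42) inside a 128x128
# image to 64x64 grows 16 px right (x2=58), then left until clipped at the
# border (x1=0), then the remaining 6 px right (x2=64); the y axis behaves the
# same, so expand_crop((10, 10, 42, 42), 128, 128, 64, 64) returns
# ((0, 0, 64, 64), (64, 64)).
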
def crop_cond(cond: list, region: tuple, init_size: tuple, canvas_size: tuple, tile_size: tuple, w_pad: int = 0, h_pad: int = 0) -> list:
"""#### Crop conditioning data to match a specific region.
#### Args:
- `cond` (list): The conditioning data.
- `region` (tuple): The crop region coordinates.
- `init_size` (tuple): The initial size of the image.
- `canvas_size` (tuple): The size of the canvas.
- `tile_size` (tuple): The size of the tile.
- `w_pad` (int, optional): The width padding. Defaults to 0.
- `h_pad` (int, optional): The height padding. Defaults to 0.
#### Returns:
- `list`: The cropped conditioning data.
"""
    cropped = []
    for emb, x in cond:
        # NOTE: the spatial arguments (region, init_size, canvas_size,
        # tile_size, w_pad, h_pad) are currently unused; each conditioning
        # entry is passed through with a shallow-copied options dict.
        cond_dict = x.copy()
        n = [emb, cond_dict]
        cropped.append(n)
    return cropped
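

if __name__ == "__main__":
    # Minimal smoke test (illustrative; not part of the original module).
    # An identity "model" with upscale_amount=1 should reproduce the input
    # up to the feathered-blend arithmetic.
    dummy = torch.rand(1, 3, 64, 64)
    passthrough = tiled_scale(
        dummy, lambda t: t, tile_x=32, tile_y=32, overlap=8, upscale_amount=1
    )
    print(passthrough.shape)  # torch.Size([1, 3, 64, 64])

    # PIL round trip on a [B, H, W, C] tensor.
    img = tensor_to_pil(torch.rand(1, 64, 64, 3))
    print(pil_to_tensor(img).shape)  # torch.Size([1, 64, 64, 3])

    # crop_cond currently copies entries unchanged (spatial args are unused);
    # the conditioning layout here is a hypothetical [embedding, dict] pair.
    cond = [[torch.zeros(1, 77, 768), {"strength": 1.0}]]
    print(crop_cond(cond, (0, 0, 32, 32), (64, 64), (64, 64), (32, 32)))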