Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,862 Bytes
1d117d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
from typing import List, Tuple

import torch
def bislerp(samples: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """#### Resize samples with separable spherical linear interpolation (bislerp).

    The width axis is resampled first, then the height axis. Along each axis every
    output position blends its two neighbouring source positions, treating the
    channel values at a position as one vector: the direction is slerped while the
    magnitude is interpolated linearly.

    #### Args:
        - `samples` (torch.Tensor): The input samples, unpacked as `(n, c, h, w)`.
        - `width` (int): The target width.
        - `height` (int): The target height.

    #### Returns:
        - `torch.Tensor`: The interpolated samples, shaped `(n, c, height, width)` and cast back to the input dtype.
    """

    def slerp(b1: torch.Tensor, b2: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
        """#### Perform spherical linear interpolation between two batches of vectors.

        #### Args:
            - `b1` (torch.Tensor): The first vectors, flattened to `(N, C)`.
            - `b2` (torch.Tensor): The second vectors, flattened to `(N, C)`.
            - `r` (torch.Tensor): The interpolation ratios, shaped `(N, 1)`.

        #### Returns:
            - `torch.Tensor`: The interpolated vectors, shaped `(N, C)`.
        """
        # norms
        b1_norms = torch.norm(b1, dim=-1, keepdim=True)
        b2_norms = torch.norm(b2, dim=-1, keepdim=True)

        # normalize; a zero vector has no direction, so map it to zero rather than NaN
        b1_normalized = torch.where(
            b1_norms == 0.0, torch.zeros_like(b1), b1 / b1_norms
        )
        b2_normalized = torch.where(
            b2_norms == 0.0, torch.zeros_like(b2), b2 / b2_norms
        )

        # slerp
        dot = (b1_normalized * b2_normalized).sum(1)
        # rounding can push the cosine marginally outside [-1, 1]; keep acos finite
        omega = torch.acos(dot.clamp(-1.0, 1.0))
        so = torch.sin(omega)
        ratio = r.squeeze(1)

        # technically not mathematically correct, but more pleasing?
        # (directions are slerped, magnitudes are blended linearly)
        res = (torch.sin((1.0 - ratio) * omega) / so).unsqueeze(
            1
        ) * b1_normalized + (torch.sin(ratio * omega) / so).unsqueeze(
            1
        ) * b2_normalized
        res = res * (b1_norms * (1.0 - r) + b2_norms * r)

        # Edge cases: for (near) parallel or polar-opposite vectors sin(omega) is ~0
        # and the slerp weights are ill-conditioned. Fall back to the plain linear
        # blend, which is the limit of the formula above for parallel vectors and
        # still honours the ratio when the two magnitudes differ.
        degenerate = (dot > 1 - 1e-5) | (dot < 1e-5 - 1)
        lerp = b1 * (1.0 - r) + b2 * r
        return torch.where(degenerate.unsqueeze(1), lerp, res)

    def generate_bilinear_data(
        length_old: int, length_new: int, device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """#### Generate neighbour indices and blend ratios for resampling one axis.

        #### Args:
            - `length_old` (int): The old length.
            - `length_new` (int): The new length.
            - `device` (torch.device): The device to use.

        #### Returns:
            - `torch.Tensor`: The ratios (fractional part of each source coordinate).
            - `torch.Tensor`: The first (lower) source indices, as int64.
            - `torch.Tensor`: The second (upper) source indices, as int64.
        """
        # Resample the index ramp itself: each output position then carries its
        # fractional source coordinate.
        coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape(
            (1, 1, 1, -1)
        )
        coords_1 = torch.nn.functional.interpolate(
            coords_1, size=(1, length_new), mode="bilinear"
        )
        ratios = coords_1 - coords_1.floor()
        coords_1 = coords_1.to(torch.int64)

        # Same ramp shifted by one gives the upper neighbour; the last entry is
        # pulled back so the index never runs past the end of the axis.
        coords_2 = (
            torch.arange(length_old, dtype=torch.float32, device=device).reshape(
                (1, 1, 1, -1)
            )
            + 1
        )
        coords_2[:, :, :, -1] -= 1
        coords_2 = torch.nn.functional.interpolate(
            coords_2, size=(1, length_new), mode="bilinear"
        )
        coords_2 = coords_2.to(torch.int64)
        return ratios, coords_1, coords_2

    orig_dtype = samples.dtype
    # do the math in float32 and cast back at the end
    samples = samples.float()
    n, c, h, w = samples.shape
    h_new, w_new = (height, width)

    # linear w
    ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
    coords_1 = coords_1.expand((n, c, h, -1))
    coords_2 = coords_2.expand((n, c, h, -1))
    ratios = ratios.expand((n, 1, h, -1))

    # move channels last and flatten so every row is one channel vector
    pass_1 = samples.gather(-1, coords_1).movedim(1, -1).reshape((-1, c))
    pass_2 = samples.gather(-1, coords_2).movedim(1, -1).reshape((-1, c))
    ratios = ratios.movedim(1, -1).reshape((-1, 1))

    result = slerp(pass_1, pass_2, ratios)
    result = result.reshape(n, h, w_new, c).movedim(-1, 1)

    # linear h
    ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
    coords_1 = coords_1.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
    coords_2 = coords_2.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
    ratios = ratios.reshape((1, 1, -1, 1)).expand((n, 1, -1, w_new))

    pass_1 = result.gather(-2, coords_1).movedim(1, -1).reshape((-1, c))
    pass_2 = result.gather(-2, coords_2).movedim(1, -1).reshape((-1, c))
    ratios = ratios.movedim(1, -1).reshape((-1, 1))

    result = slerp(pass_1, pass_2, ratios)
    result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
    return result.to(orig_dtype)
def common_upscale(samples: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """#### Upscale the given samples to the specified width and height using bislerp.

    #### Args:
        - `samples` (torch.Tensor): The samples to be upscaled; passed straight to `bislerp`.
        - `width` (int): The target width for the upscaled samples.
        - `height` (int): The target height for the upscaled samples.

    #### Returns:
        - `torch.Tensor`: The upscaled samples returned by `bislerp`.
    """
    return bislerp(samples, width, height)
class LatentUpscale:
    """#### A class to upscale latent codes."""

    def upscale(self, samples: dict, width: int, height: int) -> tuple:
        """#### Upscales the given latent codes.

        A dimension of `0` means "not specified". If both are `0` the input is
        returned unchanged. If exactly one is `0` it is derived from the other so
        the aspect ratio of `samples["samples"]` is kept. Requested sizes are
        raised to at least 64 and divided by 8 before being handed to
        `common_upscale`.

        #### Args:
            - `samples` (dict): The latent codes to be upscaled, stored under the `"samples"` key.
            - `width` (int): The target width for the upscaled samples.
            - `height` (int): The target height for the upscaled samples.

        #### Returns:
            - `tuple`: A one-element tuple holding the dict with the upscaled samples.
        """
        if width == 0 and height == 0:
            # Nothing requested: hand back the caller's dict untouched.
            return (samples,)

        latent = samples["samples"]
        if width == 0:
            # Only the height was given: scale the width by the latent's aspect ratio.
            height = max(64, height)
            width = max(64, round(latent.shape[-1] * height / latent.shape[-2]))
        elif height == 0:
            # Only the width was given: scale the height by the latent's aspect ratio.
            width = max(64, width)
            height = max(64, round(latent.shape[-2] * width / latent.shape[-1]))
        else:
            width = max(64, width)
            height = max(64, height)

        # Shallow copy so the caller's dict is not modified and other keys carry over.
        s = samples.copy()
        s["samples"] = common_upscale(latent, width // 8, height // 8)
        return (s,)
|