import math

import torch
import torch.nn.functional as F
from tqdm import tqdm

from .utils import get_tensor_items

def get_named_beta_schedule(schedule_name, timesteps):
    """Return a 1-D float32 tensor of betas for the named noise schedule."""
    if schedule_name == "linear":
        # Linear schedule from Ho et al. (2020), rescaled so the endpoints
        # match the original 1000-step schedule for any number of timesteps.
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return torch.linspace(
            beta_start, beta_end, timesteps, dtype=torch.float32
        )
    elif schedule_name == "cosine":
        # Cosine schedule from Nichol & Dhariwal (2021): betas come from the
        # ratio of consecutive cumulative alphas, capped at 0.999.
        alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
        betas = []
        for i in range(timesteps):
            t1 = i / timesteps
            t2 = (i + 1) / timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), 0.999))
        return torch.tensor(betas, dtype=torch.float32)
    else:
        raise ValueError(f"unknown beta schedule: {schedule_name}")
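
# Usage sketch: both schedules return a 1-D float32 tensor of length
# `timesteps`, e.g.
#
#   betas = get_named_beta_schedule("cosine", 1000)
#   betas.shape                                          # torch.Size([1000])
#   bool((betas > 0).all() and (betas <= 0.999).all())   # True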

class BaseDiffusion:
    def __init__(self, betas, percentile=None, gen_noise=torch.randn_like):
        self.betas = betas
        self.num_timesteps = betas.shape[0]

        alphas = 1. - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat(
            [torch.ones(1, dtype=betas.dtype), self.alphas_cumprod[:-1]]
        )

        # precompute the coefficients of q(x_t | x_0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # scale factor between this schedule and the canonical 1000-step one
        self.time_scale = 1000 // self.num_timesteps
        self.gen_noise = gen_noise

    def get_x_start(self, x, t, noise):
        # invert q_sample: x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
        sqrt_one_minus_alphas_cumprod = get_tensor_items(
            self.sqrt_one_minus_alphas_cumprod, t, noise.shape
        )
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, noise.shape)
        pred_x_start = (x - sqrt_one_minus_alphas_cumprod * noise) / sqrt_alphas_cumprod
        return pred_x_start

    def q_sample(self, x_start, t, noise=None):
        # forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
        if noise is None:
            noise = self.gen_noise(x_start)
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod = get_tensor_items(
            self.sqrt_one_minus_alphas_cumprod, t, noise.shape
        )
        x_t = sqrt_alphas_cumprod * x_start + sqrt_one_minus_alphas_cumprod * noise
        return x_t
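
    # Worked example (a sketch): at t = 0 the sample is almost unchanged, while
    # at the last step it is nearly pure noise, since sqrt(alpha_bar_t) -> 0:
    #
    #   x0 = torch.zeros(1, 3, 8, 8)
    #   x_t = diffusion.q_sample(x0, torch.tensor([999]))   # ~ N(0, I)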

    @torch.no_grad()
    def refine(self, model, img, context, context_mask):
        # one-step refinement: re-noise the image to an intermediate timestep,
        # predict the noise, and reconstruct x_0 from that prediction
        # for time in tqdm([479, 229]):
        for time in [229]:
            time = torch.tensor([time,] * img.shape[0], device=img.device)
            x_t = self.q_sample(img, time)
            pred_noise = model(x_t, time.type(x_t.dtype), context, context_mask.bool())
            img = self.get_x_start(x_t, time, pred_noise)
        return img
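
    # Usage sketch (hypothetical shapes): `model` is any callable with the
    # signature model(x_t, t, context, context_mask) that returns predicted
    # noise of the same shape as x_t:
    #
    #   out = diffusion.refine(model, img, context, context_mask)
    #   out.shape == img.shape   # True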

    def blend_v(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        # linearly crossfade the bottom rows of tile `a` into the top rows of `b`
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (
                1 - y / blend_extent
            ) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        # linearly crossfade the right columns of tile `a` into the left columns of `b`
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (
                1 - x / blend_extent
            ) + b[:, :, :, x] * (x / blend_extent)
        return b
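
    # Worked example: with blend_extent = 4 the weights on `a` over the seam
    # are [1.00, 0.75, 0.50, 0.25] and the weights on `b` are
    # [0.00, 0.25, 0.50, 0.75], i.e. a linear ramp from the earlier tile into
    # the current one.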

    def refine_tiled(self, model, img, context, context_mask):
        # refine overlapping tiles independently, then crossfade the overlaps
        # so that no seams are visible in the reassembled image
        tile_sample_min_size = 352
        tile_overlap_factor = 0.25

        overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))
        tile_latent_min_size = int(tile_sample_min_size)
        blend_extent = int(tile_latent_min_size * tile_overlap_factor)
        row_limit = tile_latent_min_size - blend_extent

        # Split the image into overlapping tiles and refine them separately.
        rows = []
        for i in tqdm(range(0, img.shape[2], overlap_size)):
            row = []
            for j in range(0, img.shape[3], overlap_size):
                tile = img[
                    :,
                    :,
                    i : i + tile_sample_min_size,
                    j : j + tile_sample_min_size,
                ]
                tile = self.refine(model, tile, context, context_mask)
                row.append(tile)
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the tile above and the tile to the left into the
                # current tile, then crop it to its non-overlapping region
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        refine_img = torch.cat(result_rows, dim=2)
        return refine_img
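
    # Worked example of the tile geometry: with tile_sample_min_size = 352 and
    # tile_overlap_factor = 0.25, the stride is overlap_size = 264, each seam
    # is blend_extent = 88 px wide, and each tile contributes its top-left
    # row_limit = 264 px. For a 1024-px side this yields tiles at offsets
    # 0, 264, 528, 792, and the cropped pieces (264 + 264 + 264 + 232) exactly
    # reassemble the original 1024 px.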

def get_diffusion(conf):
    # build the diffusion object from a config whose `schedule_params` feed
    # get_named_beta_schedule and whose `diffusion_params` feed BaseDiffusion
    betas = get_named_beta_schedule(**conf.schedule_params)
    base_diffusion = BaseDiffusion(betas, **conf.diffusion_params)
    return base_diffusion
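
if __name__ == "__main__":
    # Minimal smoke test (a sketch): it runs only from inside the package,
    # since this module imports `get_tensor_items` from a relative `.utils`.
    # The stand-in `model` below simply predicts zero noise; any callable with
    # the signature model(x_t, t, context, context_mask) would do.
    betas = get_named_beta_schedule("cosine", 1000)
    diffusion = BaseDiffusion(betas)

    def model(x_t, t, context, context_mask):
        return torch.zeros_like(x_t)

    img = torch.randn(1, 4, 352, 352)
    out = diffusion.refine(model, img, context=None, context_mask=torch.ones(1, 1))
    print(out.shape)  # torch.Size([1, 4, 352, 352])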