import torch


cached_multiplier = None


def get_multiplier(timesteps, num_timesteps=1000):
    global cached_multiplier
    if cached_multiplier is None:
        # Build a bell-shaped weighting curve over the timestep range, peaked at
        # the middle timestep, then normalize it so the weights sum to
        # num_timesteps (i.e. an average weight of 1).
        x = torch.arange(num_timesteps, dtype=torch.float32)
        y = torch.exp(-2 * ((x - num_timesteps / 2) / num_timesteps) ** 2)
        y_shifted = y - y.min()
        cached_multiplier = y_shifted * (num_timesteps / y_shifted.sum())
    # Look up the multiplier for each timestep in the batch. Indices are
    # clamped at 0 so a timestep of 0 does not wrap around to the last entry.
    scale_list = []
    for i in range(timesteps.shape[0]):
        idx = max(int(timesteps[i].item()) - 1, 0)
        scale_list.append(cached_multiplier[idx:idx + 1])

    scales = torch.cat(scale_list, dim=0)
    batch_multiplier = scales.view(-1, 1, 1, 1)
    return batch_multiplier
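
# Example usage (illustrative only; `pred`, `target`, and `batch_size` are
# assumed to come from the surrounding training code, not from this module):
#
#     timesteps = torch.randint(1, 1000, (batch_size,))
#     weights = get_multiplier(timesteps)              # (batch_size, 1, 1, 1)
#     weighted_loss = (weights * (pred - target) ** 2).mean()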


def get_blended_blur_noise(latents, noise, timestep):
    # Note: `timestep` is accepted here but is not used in the function body.
    # Split the batch into per-sample chunks along the batch dimension.
    latent_chunks = torch.chunk(latents, latents.shape[0], dim=0)

    # Blur each sample and keep the difference between the blurred copy and
    # the original latents.
    blurred_latent_chunks = []
    for i in range(len(latent_chunks)):
        latent_chunk = latent_chunks[i]

        # Downscale to 25% of the spatial resolution and upscale back to
        # produce a blurred copy of the latent.
        scaler1 = 0.25
        scaler2 = scaler1
        blur_latents = torch.nn.functional.interpolate(
            latent_chunk,
            size=(int(latents.shape[2] * scaler1), int(latents.shape[3] * scaler2)),
            mode='bilinear',
            align_corners=False
        )
        blur_latents = torch.nn.functional.interpolate(
            blur_latents,
            size=(latents.shape[2], latents.shape[3]),
            mode='bilinear',
            align_corners=False
        )

        # Keep only the residual (blurred copy minus the original).
        blur_latents = blur_latents - latent_chunk
        blurred_latent_chunks.append(blur_latents)

    blur_latents = torch.cat(blurred_latent_chunks, dim=0)

    # Random per-sample blur strength in [0, 2).
    blur_strength = torch.rand((latents.shape[0], 1, 1, 1), device=latents.device, dtype=latents.dtype) * 2

    # Blend the scaled blur residual into the noise.
    blur_latents = blur_latents * blur_strength
    noise = noise + blur_latents
    return noise
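

# Minimal smoke test (illustrative only: the latent shape, channel count, and
# timestep range below are assumptions, not requirements of this module).
if __name__ == "__main__":
    batch_size = 4
    latents = torch.randn(batch_size, 4, 64, 64)
    noise = torch.randn_like(latents)
    timesteps = torch.randint(1, 1000, (batch_size,))

    weights = get_multiplier(timesteps)
    blended_noise = get_blended_blur_noise(latents, noise, timesteps)

    print("multiplier shape:", tuple(weights.shape))           # (4, 1, 1, 1)
    print("blended noise shape:", tuple(blended_noise.shape))  # (4, 4, 64, 64)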