import cv2
import numpy as np
import torch
import torch.nn as nn


def set_requires_grad(module: nn.Module, requires_grad: bool):
    """Enable or disable gradients for every parameter of `module`."""
    for p in module.parameters():
        p.requires_grad_(requires_grad)
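
# Illustrative usage (a sketch, not part of the original module): freeze a
# sub-network so an optimizer only updates the remaining parameters. The
# `image_encoder` name below is a placeholder for any nn.Module.
#
#   image_encoder = nn.Linear(512, 256)
#   set_requires_grad(image_encoder, False)
#   assert all(not p.requires_grad for p in image_encoder.parameters())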


def compute_distance_transform(mask: torch.Tensor):
    """Normalized distance transform of a (B, 1, H, W) binary mask: each pixel's
    distance to the nearest mask pixel, scaled by half the image size and
    clipped to [0, 1]."""
    image_size = mask.shape[-1]
    distance_transform = torch.stack([
        # cv2.distanceTransform assigns each non-zero pixel of (1 - m), i.e. each
        # background pixel, its distance to the nearest mask pixel; mask pixels get 0.
        torch.from_numpy(cv2.distanceTransform(
            (1 - m), distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3
        ) / (image_size / 2))
        for m in mask.squeeze(1).detach().cpu().numpy().astype(np.uint8)
    ]).unsqueeze(1).clip(0, 1).to(mask.device)
    return distance_transform
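
# Illustrative usage (a sketch): a batch of two 64x64 masks with a filled square.
# The result is 0 inside the square and grows towards 1 away from it.
#
#   mask = torch.zeros(2, 1, 64, 64)
#   mask[:, :, 16:48, 16:48] = 1.0
#   dt = compute_distance_transform(mask)   # shape (2, 1, 64, 64), values in [0, 1]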


def default(x, d):
    """Return `x`, falling back to the default `d` when `x` is None."""
    return d if x is None else x


def get_custom_betas(beta_start: float, beta_end: float, warmup_frac: float = 0.3, num_train_timesteps: int = 1000):
    """Custom beta schedule: linear from beta_start to beta_end, with the first
    `warmup_frac` of the timesteps replaced by a shorter warm-up linspace."""
    betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
    warmup_time = int(num_train_timesteps * warmup_frac)
    warmup_steps = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
    warmup_time = min(warmup_time, num_train_timesteps)
    betas[:warmup_time] = warmup_steps[:warmup_time]
    return betas
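

# A minimal, runnable smoke test of the helpers above (a sketch; the schedule
# endpoints 1e-4 and 0.02 are illustrative, not values from any training recipe).
if __name__ == "__main__":
    betas = get_custom_betas(beta_start=1e-4, beta_end=0.02)
    assert betas.shape == (1000,)

    mask = torch.zeros(2, 1, 64, 64)
    mask[:, :, 16:48, 16:48] = 1.0
    dt = compute_distance_transform(mask)
    assert dt.shape == (2, 1, 64, 64)
    assert 0.0 <= dt.min().item() and dt.max().item() <= 1.0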