|
import random |
|
|
|
import numpy as np |
|
|
|
from rich import get_console |
|
from rich.table import Table |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed, in order, to ``random.seed``,
            ``np.random.seed`` and ``torch.manual_seed``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
|
|
|
|
|
def print_table(title: str, metrics: dict) -> None:
    """Print ``metrics`` as a centered two-column table on the rich console.

    Args:
        title: Caption passed to the ``Table`` constructor.
        metrics: Mapping whose keys fill the "Metrics" column and whose
            values, converted with ``str``, fill the "Value" column.
    """
    summary = Table(title=title)
    summary.add_column("Metrics", style="cyan", no_wrap=True)
    summary.add_column("Value", style="magenta")

    # One row per entry, in the mapping's iteration order.
    for metric_name, metric_value in metrics.items():
        summary.add_row(metric_name, str(metric_value))

    get_console().print(summary, justify="center")
|
|
|
|
|
def move_batch_to_device(batch: dict, device: torch.device) -> dict:
    """Move every tensor value of ``batch`` to ``device``, in place.

    Args:
        batch: Dictionary to update. Values that are not ``torch.Tensor``
            instances are left untouched.
        device: Target device passed to ``Tensor.to``.

    Returns:
        The same ``batch`` object that was passed in.
    """
    # Only existing keys are reassigned, so the dict never changes size
    # while it is being iterated.
    for name, value in batch.items():
        if not isinstance(value, torch.Tensor):
            continue
        batch[name] = value.to(device)
    return batch
|
|
|
|
|
def count_parameters(module: nn.Module) -> float:
    """Return the size of ``module`` in millions of parameters.

    Args:
        module: Module whose ``parameters()`` are counted.

    Returns:
        Total element count over ``module.parameters()`` divided by 1e6
        and rounded to three decimal places.
    """
    total = 0
    for param in module.parameters():
        total += param.numel()
    millions = total / 1e6
    return round(millions, 3)
|
|
|
|
|
def get_guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 512,
                                 dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Build sinusoidal embeddings for a 1-D tensor of guidance scales.

    Each entry of ``w`` is multiplied by 1000 and then by
    ``embedding_dim // 2`` geometrically spaced frequencies (1 down to
    1/10000 when there are at least two of them). The sines and cosines of
    those products are concatenated along the last dimension; an odd
    ``embedding_dim`` gets one trailing zero column.

    Args:
        w: 1-D tensor of guidance scales.
        embedding_dim: Width of the returned embedding.
        dtype: Floating dtype used for the frequencies and the result.

    Returns:
        Tensor of shape ``(w.shape[0], embedding_dim)`` located on
        ``w.device``.

    Raises:
        ValueError: If ``w`` is not one-dimensional.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if w.dim() != 1:
        raise ValueError(
            f'w must be a 1-D tensor, got shape {tuple(w.shape)}')

    w = w * 1000.0
    half_dim = embedding_dim // 2

    # With a single frequency (embedding_dim of 2 or 3) ``half_dim - 1`` is
    # zero; dividing by it yields inf and the whole embedding turns into NaN.
    # Clamping the divisor leaves the lone frequency at exp(0) == 1.
    log_step = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)

    # Create the frequencies on the same device as ``w`` so the product below
    # does not fail with a device mismatch when ``w`` is not on the CPU.
    positions = torch.arange(half_dim, dtype=dtype, device=w.device)
    freqs = torch.exp(positions * -log_step)

    emb = w.to(dtype)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        # Zero-pad one column on the right to reach an odd width.
        emb = torch.nn.functional.pad(emb, (0, 1))

    # Internal sanity check on the construction above.
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
|
|
|
|
|
def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor:
    """Gather values of ``a`` at indices ``t`` and reshape them for broadcasting.

    Args:
        a: Tensor gathered from along its last dimension.
        t: Index tensor; its first dimension is used as the batch size.
        x_shape: Shape whose number of dimensions the result should match.

    Returns:
        The gathered values reshaped to ``(b, 1, ..., 1)``, where ``b`` is
        ``t.shape[0]`` and the total number of dimensions is ``len(x_shape)``.
    """
    batch_size, *_ = t.shape
    picked = torch.gather(a, -1, t)
    # One singleton axis for every dimension of x_shape after the first.
    singleton_axes = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch_size, *singleton_axes)
|
|
|
|
|
def sum_flat(tensor: torch.Tensor) -> torch.Tensor:
    """Sum ``tensor`` over every dimension except the first.

    Args:
        tensor: Input tensor.

    Returns:
        The result of ``tensor.sum`` over dimensions ``1 .. tensor.dim() - 1``.
    """
    trailing_dims = list(range(1, tensor.dim()))
    return torch.sum(tensor, dim=trailing_dims)
|
|
|
|
|
def control_loss_calculate(
    vaeloss_type: str, loss_func: str, src: torch.Tensor,
    tgt: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """Compute a masked element-wise loss between ``src`` and ``tgt``.

    Args:
        vaeloss_type: Reduction scheme, one of:
            ``'sum'``: sum the loss over the last dimension, apply the mask,
            then divide the grand total by ``mask.sum()``.
            ``'sum_mask'``: same masking, but normalise each sample by its
            own mask total before averaging over samples.
            ``'mask'``: apply the mask element-wise, normalise each sample
            by its mask total times ``src.shape[-1]``, then average.
        loss_func: ``'l1'``, ``'l1_smooth'`` or ``'l2'``.
        src: Predicted tensor.
        tgt: Target tensor.
        mask: Weights multiplied into the loss; presumably broadcastable
            against the loss with its last dimension kept as size 1 --
            TODO confirm against callers.

    Returns:
        The reduced loss tensor.

    Raises:
        ValueError: If ``loss_func`` or ``vaeloss_type`` is not recognised.
    """
    criteria = {
        'l1': F.l1_loss,
        'l1_smooth': F.smooth_l1_loss,
        'l2': F.mse_loss,
    }
    criterion = criteria.get(loss_func)
    if criterion is None:
        raise ValueError(f'Unknown loss func: {loss_func}')
    per_element = criterion(src, tgt, reduction='none')

    if vaeloss_type == 'sum':
        masked = per_element.sum(-1, keepdims=True) * mask
        return masked.sum() / mask.sum()

    if vaeloss_type == 'sum_mask':
        masked = per_element.sum(-1, keepdims=True) * mask
        per_sample = sum_flat(masked) / sum_flat(mask)
        return per_sample.mean()

    if vaeloss_type == 'mask':
        masked_total = sum_flat(per_element * mask)
        # Mask total scaled by the size of the last dimension of src.
        valid_entries = sum_flat(mask) * src.shape[-1]
        return (masked_total / valid_entries).mean()

    raise ValueError(f'Unsupported vaeloss_type: {vaeloss_type}')
|
|