|
|
|
import torch |
|
|
|
|
|
_dwt = None |
|
|
|
|
|
def _get_wavelet_loss(device, dtype): |
|
global _dwt |
|
if _dwt is not None: |
|
return _dwt |
|
|
|
|
|
from pytorch_wavelets import DWTForward |
|
|
|
dwt = DWTForward(J=1, mode='zero', wave='haar').to( |
|
device=device, dtype=dtype) |
|
_dwt = dwt |
|
return dwt |
|
|
|
|
|
def _dwt_subbands(dwt, x):
    """Apply ``dwt`` to ``x`` and concatenate its four sub-bands along dim 1.

    Args:
        dwt: A single-level ``DWTForward``-style callable returning
            ``(low, highs)``, where ``highs[0]`` stacks the three detail bands
            along dim 2.
        x: Input tensor — assumes (N, C, H, W) as DWTForward expects; TODO
            confirm for all callers.

    Returns:
        ``cat([LL, LH, HL, HH], dim=1)``: all LL channels first, then each
        detail band's channels in the order ``highs[0]`` stores them.
    """
    low, highs = dwt(x)
    # Split the stacked detail bands and append them after the approximation
    # band, preserving the original [LL, LH, HL, HH] channel layout.
    return torch.cat([low, *torch.unbind(highs[0], dim=2)], dim=1)


def wavelet_loss(model_pred, latents, noise):
    """Element-wise MSE between Haar-DWT sub-bands of predicted and true latents.

    The predicted clean latents are reconstructed as ``noise - model_pred``
    and both they and ``latents`` are decomposed with a single-level Haar DWT.
    NOTE(review): that reconstruction matches a flow-matching velocity target
    (``v = noise - x0``); confirm the caller's parameterisation.

    Args:
        model_pred: Model output tensor.
        latents: Ground-truth clean latents (the regression target).
        noise: Noise tensor used to build the noisy input.

    Returns:
        Unreduced squared-error tensor (``reduction="none"``). Its layout is the
        sub-band concatenation, not the input layout — presumably
        (N, 4*C, H', W') for 4-D inputs; callers must reduce it themselves.
    """
    # Compute the transform in float32 regardless of the training dtype.
    model_pred = model_pred.float()
    latents = latents.float()
    noise = noise.float()

    dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)

    # The target side needs no gradient.
    with torch.no_grad():
        target_bands = _dwt_subbands(dwt, latents)

    # Must stay outside no_grad: gradients flow back through this branch.
    pred_latents = noise - model_pred
    pred_bands = _dwt_subbands(dwt, pred_latents)

    return torch.nn.functional.mse_loss(pred_bands, target_bands, reduction="none")