Spaces:
Running
on
Zero
Running
on
Zero
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary | |
# | |
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual | |
# property and proprietary rights in and to this material, related | |
# documentation and any modifications thereto. Any use, reproduction, | |
# disclosure or distribution of this material and related documentation | |
# without an express license agreement from NVIDIA CORPORATION or | |
# its affiliates is strictly prohibited. | |
"""Loss functions.""" | |
import torch | |
from torch_utils import training_stats | |
from torch_utils.ops import conv2d_gradfix | |
from torch_utils.ops import upfirdn2d | |
# Distortion loss in MipNeRF-360: https://github.com/google-research/multinerf | |
def lossfun_distortion(t, w, reduction='mean'): | |
"""Compute iint w[i] w[j] |t[i] - t[j]| di dj.""" | |
# The loss incurred between all pairs of intervals. | |
# print("t:",t.shape) | |
# print("w:",w.shape) | |
ut = (t[..., 1:] + t[..., :-1]) / 2 | |
dut = torch.abs(ut[..., :, None] - ut[..., None, :]) | |
loss_inter = torch.sum(w * torch.sum(w[..., None, :] * dut, dim=-1), dim=-1) | |
# The loss incurred within each individual interval with itself. | |
loss_intra = torch.sum(w**2 * (t[..., 1:] - t[..., :-1]), dim=-1) / 3 | |
if reduction is None: | |
return loss_inter + loss_intra | |
return (loss_inter + loss_intra).mean() |