# File size: 2,810 Bytes
# 17cd746
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from einops import rearrange
# Public API exported by `from <this module> import *`.
__all__ = ['LPIPSLoss']
class LPIPSLoss(nn.Module):
    """
    Compute the LPIPS perceptual loss between two batches of images.

    LPIPS networks are built lazily on first use (or eagerly when
    ``prefech=True``) and cached per backbone name. They are stored in a
    plain ``dict`` rather than registered as submodules, so they are not
    part of ``state_dict()`` and are not moved by ``.to()`` / ``.cuda()``
    on this module; they stay on the ``device`` given at construction.
    """

    def __init__(self, device, prefech: bool = False):
        """
        Args:
            device: Device each LPIPS network is moved to when it is built.
            prefech: If True, build and cache both the 'alex' and 'vgg'
                networks immediately. (The keyword is misspelled; the name is
                kept unchanged for backward compatibility with callers.)
        """
        super().__init__()
        self.device = device
        # Backbone name ('alex' / 'vgg') -> compiled LPIPS module.
        self.cached_models = {}
        if prefech:
            self.prefetch_models()

    def _get_model(self, model_name: str):
        """Return the cached LPIPS model for ``model_name``, building it on first use."""
        if model_name not in self.cached_models:
            import warnings
            # Suppress any UserWarning raised while importing lpips and building the net.
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', category=UserWarning)
                import lpips
                _model = lpips.LPIPS(net=model_name, eval_mode=True, verbose=False).to(self.device)
            _model = torch.compile(_model)
            self.cached_models[model_name] = _model
        return self.cached_models[model_name]

    def prefetch_models(self):
        """Eagerly build and cache both LPIPS backbones ('alex' and 'vgg')."""
        for model_name in ('alex', 'vgg'):
            self._get_model(model_name)

    def forward(self, x, y, is_training: bool = True, conf_sigma=None, only_sym_conf=False):
        """
        Compute the mean LPIPS loss between ``x`` and ``y``.

        Images are assumed to be scaled to [0, 1] and channel-first; they are
        passed to LPIPS with ``normalize=True``.

        Args:
            x: Either a 5-D tensor [N, M, C, H, W] (N batch items with M images
                each) or a 4-D tensor [B, C, H, W].
            y: Same layout as ``x``.
            is_training: Selects the backbone: VGG when True, AlexNet when False.
            conf_sigma: Optional confidence tensor (sigma). For 4-D inputs the raw
                LPIPS output L is replaced by
                ``L / (2 * sigma**2 + EPS) + log(sigma + EPS)`` before reduction,
                so it must broadcast against the LPIPS output. It is NOT applied
                to 5-D inputs; passing it there emits a UserWarning.
            only_sym_conf: Accepted for interface compatibility; not used.

        Returns:
            Scalar tensor: the LPIPS loss averaged over all images.
        """
        model_name = 'vgg' if is_training else 'alex'
        loss_fn = self._get_model(model_name)
        EPS = 1e-7

        if x.dim() == 5:
            if conf_sigma is not None:
                # Previously dropped silently; surface it so callers notice.
                import warnings
                warnings.warn(
                    "LPIPSLoss: conf_sigma is ignored for 5-D [N, M, C, H, W] inputs.",
                    UserWarning,
                )
            N, M, C, H, W = x.shape
            # Fold the M axis into the batch so LPIPS sees [N*M, C, H, W].
            x = x.reshape(N * M, C, H, W)
            y = y.reshape(N * M, C, H, W)
            image_loss = loss_fn(x, y, normalize=True).mean(dim=[1, 2, 3])
            # Average over the M images of each batch item, then over the batch.
            batch_loss = image_loss.reshape(N, M).mean(dim=1)
            return batch_loss.mean()

        image_loss = loss_fn(x, y, normalize=True)
        if conf_sigma is not None:
            # Confidence weighting: large sigma down-weights the loss, while the
            # log(sigma) term penalises inflating sigma.
            image_loss = image_loss / (2 * conf_sigma ** 2 + EPS) + (conf_sigma + EPS).log()
        image_loss = image_loss.mean(dim=[1, 2, 3])
        return image_loss.mean()
|