from skimage.metrics import structural_similarity, peak_signal_noise_ratio | |
import numpy as np | |
import lpips | |
import torch | |
from pytorch_fid.fid_score import calculate_frechet_distance | |
from pytorch_fid.inception import InceptionV3 | |
import torch.nn as nn | |
import cv2 | |
from scipy import stats | |
import os | |
def calc_ssim(pred_image, gt_image):
    '''
    Compute the Structural Similarity Index (SSIM) between two images.

    SSIM is a perceptual metric that quantifies image quality degradation
    caused by processing such as compression or transmission losses.
    Both inputs are converted to RGB and then to float32 arrays; the
    comparison treats the last axis as channels and uses a data range of 255.

    # Arguments
        pred_image: image object exposing `.convert('RGB')` (PIL.Image)
        gt_image: image object exposing `.convert('RGB')` (PIL.Image)
    # Returns
        ssim: float, the value returned by skimage's `structural_similarity`
    '''
    def _as_rgb_float(image):
        # HWC float32 array in the 0..255 range.
        return np.array(image.convert('RGB')).astype(np.float32)

    return structural_similarity(
        _as_rgb_float(pred_image),
        _as_rgb_float(gt_image),
        channel_axis=2,
        data_range=255.,
    )
def calc_psnr(pred_image, gt_image):
    '''
    Compute the Peak Signal-to-Noise Ratio (PSNR) between two images.

    PSNR expresses the ratio between the maximum possible power of a signal
    and the power of the distorting noise. Both inputs are converted to RGB
    float32 arrays; the ground-truth image is passed to skimage as the
    reference and the data range is 255.

    # Arguments
        pred_image: image object exposing `.convert('RGB')` (PIL.Image)
        gt_image: image object exposing `.convert('RGB')` (PIL.Image)
    # Returns
        psnr: float, the value returned by skimage's `peak_signal_noise_ratio`
    '''
    estimate, reference = (
        np.array(image.convert('RGB')).astype(np.float32)
        for image in (pred_image, gt_image)
    )
    return peak_signal_noise_ratio(reference, estimate, data_range=255.)
class LPIPS_utils:
    """Wrapper around an `lpips.LPIPS` network for image-pair distances.

    The network is built with the VGG backbone and `spatial=True`, so the
    forward pass produces a distance map which `compare_lpips` averages into
    a single float.
    """

    def __init__(self, device='cuda'):
        """Create the LPIPS network and move it to `device`.

        Parameters
        ----------
        device : str or torch.device
            Device on which the network and the compared images are placed.
        """
        # `net` can be set to 'squeeze', 'vgg' or 'alex'.
        self.loss_fn = lpips.LPIPS(net='vgg', spatial=True).to(device)
        self.device = device

    def compare_lpips(self, img_fake, img_real, data_range=255.):
        """Return the mean LPIPS distance between two images.

        Parameters
        ----------
        img_fake, img_real : array-like
            Anything `np.array` can convert. A 3-D input is treated as
            (h, w, c) and rearranged to (1, c, h, w); any other rank is
            passed through unchanged (expected to already be (1, c, h, w)).
        data_range : float
            Value the pixel data is divided by before the comparison
            (255 maps 8-bit images to [0, 1]).

        Returns
        -------
        float
            Mean of the distance map returned by the LPIPS network.
        """
        fake = self._to_batch(img_fake, data_range).to(self.device)
        real = self._to_batch(img_real, data_range).to(self.device)

        # NOTE(review): the lpips package documents inputs in [-1, 1] unless
        # `normalize=True` is passed to forward(); here they are scaled to
        # [0, 1]. Kept as-is so scores stay comparable with earlier runs --
        # confirm whether this is intended.
        with torch.no_grad():  # metric evaluation only, no gradients needed
            dist = self.loss_fn.forward(fake, real)
        return dist.mean().item()

    @staticmethod
    def _to_batch(image, data_range):
        """Convert an image to a float32 tensor scaled by `data_range`, NCHW if 3-D."""
        tensor = torch.from_numpy(np.array(image).astype(np.float32) / data_range)
        if tensor.ndim == 3:
            # (h, w, c) -> (1, c, h, w)
            tensor = tensor.permute(2, 0, 1).unsqueeze(0)
        return tensor
class FID_utils(nn.Module):
    """Class for computing the Fréchet Inception Distance (FID) metric score.

    It is implemented as a class in order to hold the inception model instance
    in its state.

    Parameters
    ----------
    resize_input : bool (optional)
        Whether or not to resize the input images to the image size (299, 299)
        on which the inception model was trained. Since the model is fully
        convolutional, the score also works without resizing. In literature
        and when working with GANs people tend to set this value to True,
        however, for internal evaluation this is not necessary.
    device : str or torch.device
        The device on which to run the inception model. When None, "cuda" is
        used if available, otherwise "cpu".
    """

    def __init__(self, resize_input=True, device="cuda"):
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # Use the 2048-dimensional pooled feature block of the network.
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        # Move the model to the *resolved* device so it always matches the
        # device the input batches are sent to.
        model = InceptionV3([block_idx], resize_input=resize_input)
        self.model = model.to(self.device).eval()

    def get_activations(self, batch):
        """Run the model on a (n, c, h, w) batch and return (n, d, 1, 1) features.

        Gradients are disabled for the forward pass.
        """
        with torch.no_grad():
            pred = self.model(batch)[0]
        # If the model output is not 1x1 spatially, apply global average
        # pooling so the result can always be flattened to (n, d).
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = nn.functional.adaptive_avg_pool2d(pred, output_size=(1, 1))
        return pred

    def _get_mu_sigma(self, batch, data_range):
        """Compute the inception mu and sigma for a batch of images.

        Parameters
        ----------
        batch : torch.Tensor
            Images with shape (n_images, 3, h, w), or a single (h, w, 3) image.
        data_range : float
            Value the pixel data is divided by before the forward pass.

        Returns
        -------
        mu : np.ndarray
            Mean activation per feature, shape (d,) (d = 2048 for the model
            built in `__init__`).
        sigma : np.ndarray
            The covariance matrix of activations with shape (d, d).
        """
        if batch.ndim == 3 and batch.shape[2] == 3:
            # (h, w, 3) -> (1, 3, h, w)
            batch = batch.permute(2, 0, 1).unsqueeze(0)
        # Out-of-place division: the caller's tensor is left untouched.
        batch = (batch / data_range).to(self.device, torch.float32)

        activations = self.get_activations(batch)
        activations = activations.detach().cpu().numpy().squeeze(axis=(2, 3))

        # NOTE(review): with a single image the sample covariance is
        # undefined (NumPy yields NaN) -- confirm callers pass several images.
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return mu, sigma

    def score(self, images_1, images_2, data_range=255.):
        """Compute the FID score.

        The input batches should have the shape (n_images, 3, h, w) or (h, w, 3).

        Parameters
        ----------
        images_1 : array-like
            First batch of images.
        images_2 : array-like
            Second batch of images.
        data_range : float
            Value the pixel data is divided by before the forward pass.

        Returns
        -------
        score : float
            The FID score.
        """
        images_1 = torch.from_numpy(np.array(images_1).astype(np.float32))
        images_2 = torch.from_numpy(np.array(images_2).astype(np.float32))
        images_1 = images_1.to(self.device)
        images_2 = images_2.to(self.device)
        mu_1, sigma_1 = self._get_mu_sigma(images_1, data_range)
        mu_2, sigma_2 = self._get_mu_sigma(images_2, data_range)
        return calculate_frechet_distance(mu_1, sigma_1, mu_2, sigma_2)
def JS_divergence(p, q):
    """Return the Jensen-Shannon divergence between distributions p and q.

    Computed as the average of the two Kullback-Leibler divergences (via
    `scipy.stats.entropy`) of p and q from their midpoint M = (p + q) / 2.

    Parameters
    ----------
    p, q : array-like
        Distributions of identical shape. Plain sequences are accepted; they
        are converted to arrays so that `+` adds element-wise.

    Returns
    -------
    float or np.ndarray
        Whatever `scipy.stats.entropy` returns for the given shapes, reduced
        along its default axis (axis 0).
    """
    p = np.asarray(p)
    q = np.asarray(q)
    midpoint = (p + q) / 2
    return 0.5 * stats.entropy(p, midpoint) + 0.5 * stats.entropy(q, midpoint)
def compute_JS_bgr(input_dir, dilation=1):
    """Compute per-channel JS divergences between frames `dilation` apart.

    Every entry of `input_dir` is read with OpenCV in filename-sorted order.
    For each frame a 256-bin histogram is built per channel and divided by
    the pixel count; then frame i is compared with frame i + dilation using
    `JS_divergence`.

    Parameters
    ----------
    input_dir : str
        Directory containing the frames. Ordering is lexicographic on the
        file name, so frame numbers should be zero-padded.
    dilation : int
        Non-negative distance (in frames) between the two compared frames.

    Returns
    -------
    (JS_b_list, JS_g_list, JS_r_list) : tuple of lists
        One divergence per compared pair, for the B, G and R channels.
        The lists are empty when there are not more than `dilation` frames.

    Raises
    ------
    ValueError
        If `dilation` is negative or a directory entry cannot be read as an
        image.
    """
    if dilation < 0:
        raise ValueError(f"dilation must be non-negative, got {dilation}")

    # One histogram list per channel, in OpenCV's B, G, R order.
    channel_hists = ([], [], [])
    for frame_name in sorted(os.listdir(input_dir)):
        frame_path = os.path.join(input_dir, frame_name)
        frame = cv2.imread(frame_path)
        if frame is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise ValueError(f"could not read image: {frame_path}")
        pixel_count = frame.shape[0] * frame.shape[1]
        for channel, hists in enumerate(channel_hists):
            hist = cv2.calcHist([frame], [channel], None, [256], [0, 256])
            hists.append(hist / pixel_count)

    # zip stops at the shorter sequence, i.e. at the last valid (i, i + dilation).
    JS_b_list, JS_g_list, JS_r_list = (
        [JS_divergence(first, second)
         for first, second in zip(hists, hists[dilation:])]
        for hists in channel_hists
    )
    return JS_b_list, JS_g_list, JS_r_list
def calc_cdc(vid_folder, dilation=(1, 2, 4), weight=(1/3, 1/3, 1/3)):
    """Compute a weighted, multi-dilation colour-consistency score for a video.

    For each (dilation, weight) pair the per-channel JS divergences returned
    by `compute_JS_bgr` are averaged over the frame pairs and accumulated with
    the given weight; the result is the mean of the three channel totals.
    (Presumably "CDC" stands for Color Distribution Consistency -- confirm.)

    Parameters
    ----------
    vid_folder : str
        Directory of frames, passed to `compute_JS_bgr`.
    dilation : sequence of int
        Frame distances to evaluate.
    weight : sequence of float
        Weight applied to each dilation; must have the same length as
        `dilation`.

    Returns
    -------
    cdc : float
        Mean over the B, G and R weighted totals. NOTE(review): if the folder
        has too few frames for some dilation, the mean of an empty list is
        NaN and propagates to the result.

    Raises
    ------
    ValueError
        If `dilation` and `weight` differ in length.
    """
    if len(dilation) != len(weight):
        raise ValueError(
            f"dilation and weight must have the same length, "
            f"got {len(dilation)} and {len(weight)}"
        )

    mean_b = mean_g = mean_r = 0
    for d, w in zip(dilation, weight):
        js_b, js_g, js_r = compute_JS_bgr(vid_folder, d)
        mean_b += w * np.mean(js_b)
        mean_g += w * np.mean(js_g)
        mean_r += w * np.mean(js_r)

    return np.mean([mean_b, mean_g, mean_r])