File size: 3,029 Bytes
117183e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import logging
from torchmetrics import PeakSignalNoiseRatio as PSNR
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
from lpips import LPIPS
from utils.deltaE import deltaEab, deltaE00

class Evaluator:
    """Score a model on one dataset split and checkpoint its best version.

    Calling the evaluator runs ``model`` over every batch of ``dataloader``,
    compares each prediction against the batch's ``target_image`` with every
    configured metric, logs the per-sample average of each metric and, when
    ``best_metric`` improves on the best value seen so far, saves the model's
    ``state_dict`` (together with the averaged metrics) to
    ``{log_dirpath}/{split_name}_best_model.pth``.
    """

    # Distance-style metrics: a smaller value means a closer match to the
    # target, so "improvement" has to be measured downwards for them.
    _LOWER_IS_BETTER = frozenset({'LPIPS', 'deltaEab', 'deltaE00'})

    def __init__(self, dataloader, metrics, split_name, log_dirpath, best_metric, device='cuda'):
        """
        Args:
            dataloader: iterable of dicts holding at least ``'input_image'``
                and ``'target_image'`` tensors (presumably batched along
                dim 0 -- TODO confirm against the dataset).
            metrics: iterable of config objects exposing ``.type`` (one of
                ``PSNR``, ``SSIM``, ``LPIPS``, ``deltaEab``, ``deltaE00``)
                and, for the first three, ``.params`` forwarded as keyword
                arguments to the metric constructor.
            split_name: label used in log lines and in the checkpoint filename.
            log_dirpath: directory the best checkpoint is written into.
            best_metric: metric name (one of the ``.type`` values above) that
                decides which evaluation counts as the best one.
            device: device that metric modules and batches are moved to.

        Raises:
            NotImplementedError: if a metric ``.type`` is not supported.
        """
        self.dataloader = dataloader
        self.device = device
        self._create_metrics(metrics)
        self.split_name = split_name
        self.log_dirpath = log_dirpath
        self.best_metric = best_metric
        # Start from the worst possible value so the first finite result wins,
        # whichever direction the chosen metric improves in.
        self.best_value = float('inf') if best_metric in self._LOWER_IS_BETTER else float('-inf')

    def _create_metrics(self, metrics):
        """Instantiate every configured metric and zero its accumulator.

        Raises:
            NotImplementedError: if a metric ``.type`` is not supported.
        """
        builders = {
            'PSNR': lambda cfg: PSNR(**cfg.params).to(self.device),
            'SSIM': lambda cfg: SSIM(**cfg.params).to(self.device),
            'LPIPS': lambda cfg: LPIPS(**cfg.params).to(self.device),
            # The colour-difference metrics are constructed without parameters.
            'deltaEab': lambda cfg: deltaEab(),
            'deltaE00': lambda cfg: deltaE00(),
        }
        self.metrics = {}
        for metric in metrics:
            build = builders.get(metric.type)
            if build is None:
                raise NotImplementedError(f"Metric {metric.type} not implemented")
            self.metrics[metric.type] = build(metric)
        self._reset_metrics()

    def _compute_metrics(self, input_image, target_image):
        """Add one batch's per-sample metric sums to the running totals.

        Each metric's output is reduced with ``.mean()`` (the lpips package's
        ``LPIPS`` returns one value per sample rather than a scalar), then
        multiplied by the batch size so a smaller final batch is not
        over-weighted in the final average.
        """
        # NOTE(review): assumes dim 0 is the batch dimension.
        batch_size = int(input_image.shape[0]) if input_image.dim() > 0 else 1
        for name, metric in self.metrics.items():
            value = metric(input_image, target_image)
            if torch.is_tensor(value):
                value = value.detach().mean()
            self.cumulative_values[name] += value * batch_size
        self._num_samples += batch_size

    def _compute_average_metrics(self):
        """Return ``{metric name: per-sample mean}``; all NaN if no sample was seen."""
        if self._num_samples == 0:
            return {name: float('nan') for name in self.cumulative_values}
        return {name: float(total / self._num_samples)
                for name, total in self.cumulative_values.items()}

    def _reset_metrics(self):
        """Zero the running totals and clear any state held by the metric objects."""
        self.cumulative_values = {name: 0 for name in self.metrics}
        self._num_samples = 0
        for metric in self.metrics.values():
            # torchmetrics objects update internal state on every forward call;
            # reset it so it does not keep growing across evaluations.
            # NOTE(review): also fires for any project metric exposing reset().
            reset = getattr(metric, 'reset', None)
            if callable(reset):
                reset()

    def _is_improvement(self, value):
        """Return True if ``value`` beats ``self.best_value`` for ``self.best_metric``."""
        if self.best_metric in self._LOWER_IS_BETTER:
            return value < self.best_value
        return value > self.best_value

    def __call__(self, model, save_results=True):
        """Evaluate ``model`` on this split and checkpoint it if it is the best so far.

        Args:
            model: module mapping a batch of input images to predicted images.
                NOTE(review): assumes ``model(input)`` returns a tensor shaped
                like ``target_image`` -- confirm against the model class.
            save_results: when False, metrics are only logged; no checkpoint is
                written and ``best_value`` is left unchanged.

        Returns:
            dict mapping each metric name to its per-sample average.
        """
        was_training = model.training
        model.eval()
        self._reset_metrics()
        try:
            with torch.no_grad():
                for data in self.dataloader:
                    input_image = data['input_image'].to(self.device)
                    target_image = data['target_image'].to(self.device)
                    # Score the model's prediction, not the raw input image.
                    prediction = model(input_image)
                    self._compute_metrics(prediction, target_image)
        finally:
            # Hand the model back in the mode it was received in.
            model.train(was_training)

        avg_metrics = self._compute_average_metrics()
        logging.info(f"{self.split_name} metrics: " + ", ".join(f'{key}: {value:.4f}' for key, value in avg_metrics.items()))

        if save_results and self._is_improvement(avg_metrics[self.best_metric]):
            self.best_value = avg_metrics[self.best_metric]
            save_path = f"{self.log_dirpath}/{self.split_name}_best_model.pth"
            torch.save({'model_state_dict': model.state_dict(), **avg_metrics}, save_path)
            logging.info(f"New best model saved at {save_path}")
        return avg_metrics