import torch
from swapae.evaluation.base_evaluator import BaseEvaluator
import swapae.util as util

def find_evaluator_using_name(filename):
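    """Import the BaseEvaluator subclass named by <filename> from the
    swapae.evaluation.<filename> module, via util.find_class_in_module."""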
    target_class_name = filename
    module_name = 'swapae.evaluation.' + filename
    eval_class = util.find_class_in_module(target_class_name, module_name)

    assert issubclass(eval_class, BaseEvaluator), \
        "Class %s should be a subclass of BaseEvaluator" % eval_class

    return eval_class


def find_evaluator_classes(opt):
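    """Parse the comma-separated opt.evaluation_metrics string into a list of
    evaluator classes and a parallel list of target phases. A metric prefixed
    with "train" or "test" runs on that split; unprefixed metrics default to
    "test"."""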
    if len(opt.evaluation_metrics) == 0:
        # Return two empty lists so callers can always unpack (classes, phases).
        return [], []

    eval_metrics = opt.evaluation_metrics.split(",")

    all_classes = []
    target_phases = []
    for metric in eval_metrics:
        if metric.startswith("train"):
            target_phases.append("train")
            metric = metric[len("train"):]
        elif metric.startswith("test"):
            target_phases.append("test")
            metric = metric[len("test"):]
        else:
            target_phases.append("test")

        metric_class = find_evaluator_using_name("%s_evaluator" % metric)
        all_classes.append(metric_class)

    return all_classes, target_phases


class GroupEvaluator(BaseEvaluator):
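    """Runs every evaluator requested via --evaluation_metrics on its target
    phase and merges the resulting metrics into a single dictionary."""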
    @staticmethod
    def modify_commandline_options(parser, is_train):
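        """Register --evaluation_metrics, then give each requested evaluator a
        chance to add its own command-line options."""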
        parser.add_argument("--evaluation_metrics", default="structure_style_grid_generation")

        opt, _ = parser.parse_known_args()
        evaluator_classes, _ = find_evaluator_classes(opt)

        for eval_class in evaluator_classes:
            parser = eval_class.modify_commandline_options(parser, is_train)

        return parser

    def __init__(self, opt, target_phase=None):
        super().__init__(opt, target_phase=target_phase)
        self.opt = opt
        evaluator_classes, target_phases = find_evaluator_classes(opt)
        self.evaluators = [cls(opt, target_phase=phs) for cls, phs in zip(evaluator_classes, target_phases)]

    def evaluate(self, model, dataset, nsteps=None):
        original_phase = dataset.phase
        metrics = {}
        for evaluator in self.evaluators:
            print("Entering evaluation using %s on %s images" % (type(evaluator).__name__, evaluator.target_phase))
            # Switch the dataset to the split this evaluator was configured for.
            dataset.set_phase(evaluator.target_phase)
            with torch.no_grad():
                new_metrics = evaluator.evaluate(model, dataset, nsteps)
                metrics.update(new_metrics)
            print("Finished evaluation of %s" % type(evaluator).__name__)
        # Restore the split the dataset was using before evaluation started.
        dataset.set_phase(original_phase)
        return metrics
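

# A minimal usage sketch (not part of the original module). It assumes an `opt`
# namespace exposing `evaluation_metrics`, plus `model` and `dataset` objects
# matching the interfaces the individual evaluators expect (`dataset` must
# provide `phase` and `set_phase()`, as used above):
#
#     opt.evaluation_metrics = "structure_style_grid_generation"
#     evaluator = GroupEvaluator(opt)
#     metrics = evaluator.evaluate(model, dataset, nsteps=None)
#     print(metrics)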