Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import os | |
import glob | |
import numpy as np | |
from . import metric as metric_path | |
from . import predictor as predictor_path | |
class Evaluator(object):
    """Perform evaluation on a single (downstream) task.

    Supports two modes:

    * offline (``__call__``): score prediction files already saved on disk
      under ``self.predictor.pred_dir``;
    * online (``evaluate``): run the predictor over a dataloader with a model
      and score the result directly.

    TODO(huxu) saving evaluation results.
    """

    def __init__(self, config, eval_dataloader=None):
        """Build the metric and predictor named in ``config``.

        Args:
            config: object exposing ``metric`` and ``predictor`` attributes,
                each the name of a class in the sibling ``metric`` /
                ``predictor`` modules. The whole ``config`` is passed to both
                constructors.
            eval_dataloader: default dataloader used by ``evaluate`` when none
                is passed explicitly.

        Raises:
            ValueError: if ``config.metric`` or ``config.predictor`` is None.
            AttributeError: if the named class does not exist in its module.
        """
        # Validate both settings before constructing anything, so a bad
        # config fails fast instead of after a metric has been built.
        if config.metric is None:
            raise ValueError(f"config.metric is {config.metric}")
        if config.predictor is None:
            raise ValueError(f"config.predictor is {config.predictor}")
        metric_cls = getattr(metric_path, config.metric)
        self.metric = metric_cls(config)
        predictor_cls = getattr(predictor_path, config.predictor)
        self.predictor = predictor_cls(config)
        self.eval_dataloader = eval_dataloader

    def __call__(self):
        """Score prediction files previously saved in the predictor's dir.

        Every ``*_merged.npy`` file in ``self.predictor.pred_dir`` is loaded,
        scored and its metrics printed (in sorted filename order). Then
        ``merged.npy`` is loaded and scored, and that result is returned.

        Returns:
            ``{"results": results, "metric": self.metric}``, where ``results``
            is whatever ``self.metric.compute_metrics`` returns for
            ``merged.npy``. Returns ``{}`` (after printing a ``[missing]``
            notice) if a prediction file cannot be found.
        """
        pred_dir = self.predictor.pred_dir
        print(pred_dir)
        # glob order depends on the filesystem; sort so the per-file
        # metrics are printed in a reproducible order.
        pattern = os.path.join(pred_dir, "*_merged.npy")
        try:
            for pred_file in sorted(glob.glob(pattern)):
                results = self.metric.compute_metrics(np.load(pred_file))
                self.metric.print_computed_metrics(results)
            outputs = np.load(os.path.join(pred_dir, "merged.npy"))
            results = self.metric.compute_metrics(outputs)
        except FileNotFoundError:
            print("\n[missing]", pred_dir)
            return {}
        return {"results": results, "metric": self.metric}

    def evaluate(self, model, eval_dataloader=None, output_file="merged"):
        """Run the predictor with ``model`` and score its outputs.

        Args:
            model: model handed to ``self.predictor.predict_loop``.
            eval_dataloader: dataloader to evaluate on; falls back to the one
                given at construction time when None.
            output_file: name forwarded to ``predict_loop`` as-is.

        Returns:
            Whatever ``self.metric.compute_metrics`` returns.
        """
        if eval_dataloader is None:
            eval_dataloader = self.eval_dataloader
        outputs = self.predictor.predict_loop(
            model, eval_dataloader, output_file)
        # predict_loop's result is unpacked as keyword arguments, so it must
        # be a mapping here. NOTE(review): ``__call__`` passes the loaded
        # array positionally instead; the metric presumably accepts both.
        results = self.metric.compute_metrics(**outputs)
        return results