Minh Q. Le
Pushed COSMIC code
a446b0b
import time
import torch
import comet.utils.utils as utils
import comet.src.data.config as cfg
class Evaluator(object):
    """Run a model over one data split in evaluation mode and record losses.

    This class only drives the evaluation loop. The per-batch work and the
    loss bookkeeping are delegated to methods that are not defined here and
    are presumably supplied by subclasses: ``initialize_losses``, ``batch``,
    ``counter``, ``compute_final_scores`` and ``print_result``.
    """

    def __init__(self, opt, model, data_loader):
        """Store the run options, the model and the data loader.

        Args:
            opt: Run options; ``epoch`` reads ``opt.eval.es``.
            model: Model to evaluate; ``epoch`` calls ``model.eval()``.
            data_loader: Loader providing ``reset_offsets``,
                ``offset_summary`` and ``total_size`` for each split.
        """
        super(Evaluator, self).__init__()

        self.data_loader = data_loader
        self.model = model

        # Shared context handed to every ``batch`` call; ``validate`` adds
        # the current split under the "split" key.
        self.batch_variables = {
            "model": model,
            "data": data_loader
        }

        self.opt = opt

    def validate(self, l, split="dev", losses=None, keyset=None):
        """Evaluate on ``split`` and record each resulting loss under key ``l``.

        Args:
            l: Key under which every loss value is stored (presumably the
                current training step/epoch -- confirm with callers).
            split: Name of the data split to evaluate.
            losses: Dict updated in place, mapping loss name to
                ``{l: value}``. When omitted a fresh dict is created, so
                separate calls never share one default dict.
            keyset: Forwarded unchanged to ``epoch``.

        Returns:
            The updated ``losses`` dict.
        """
        if losses is None:
            losses = {}

        self.batch_variables["split"] = split
        print("Evaluating {}".format(split))

        epoch_losses = self.epoch(
            self.opt, self.model, self.data_loader, split, keyset)

        self.print_result(split, epoch_losses)

        for loss_name, loss_val in epoch_losses.items():
            losses.setdefault(loss_name, {})[l] = loss_val

        return losses

    def epoch(self, opt, model, data_loader, split, keyset=None):
        """Run one evaluation pass over ``split`` with gradients disabled.

        Args:
            opt: Run options; ``opt.eval.es`` is either "full" or a limit
                compared against ``self.counter(nums)`` for early stopping.
            model: Model to put in eval mode.
            data_loader: Loader for the split (see ``__init__``).
            split: Name of the split to iterate.
            keyset: Not used by this implementation; kept for interface
                compatibility.

        Returns:
            The result of ``self.compute_final_scores`` for the accumulated
            losses and counts (a mapping of loss name to value, as consumed
            by ``validate``).
        """
        average_loss, nums = self.initialize_losses()

        data_loader.reset_offsets(splits=split, shuffle=False)

        # Set evaluation mode
        model.eval()

        # Wall-clock timer for the whole pass. It must not share a name with
        # the per-batch loader offsets below, otherwise the reported
        # duration is computed against a loader offset instead of a time.
        start_time = time.time()

        # Initialize progress bar
        bar = utils.set_progress_bar(
            data_loader.total_size[split])

        reset = False

        with torch.no_grad():
            while not reset:
                # Loader position before/after the batch, used to advance
                # the progress bar by the number of items consumed.
                offset_before = data_loader.offset_summary(split)

                outputs = self.batch(
                    opt, nums, average_loss,
                    self.batch_variables, eval_mode=True)

                offset_after = data_loader.offset_summary(split)

                reset = outputs["reset"]
                # Use the counts reported by ``batch`` so the early-stop
                # checks below see the latest values even if ``batch``
                # returns a new object instead of updating ``nums`` in place.
                nums = outputs["nums"]

                if not reset:
                    bar.update(offset_after - offset_before)
                else:
                    print(offset_after)

                if cfg.toy and self.counter(nums) > 100:
                    break
                if (opt.eval.es != "full" and
                        (self.counter(nums) > opt.eval.es)):
                    break

        # Wait for queued GPU work so the timing is accurate; skipped on
        # CPU-only setups, where torch.cuda.synchronize() raises.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        print("{} evaluation completed in: {} s".format(
            split.capitalize(), time.time() - start_time))

        average_loss = self.compute_final_scores(
            average_loss, nums)

        return average_loss