Spaces:
Build error
Build error
import time | |
import torch | |
import comet.utils.utils as utils | |
import comet.src.data.config as cfg | |
class Evaluator(object):
    """Drive an evaluation pass of a model over one data split.

    This class only implements the outer loop. The hooks it calls on
    ``self`` -- ``batch``, ``initialize_losses``, ``counter``,
    ``compute_final_scores`` and ``print_result`` -- are not defined in
    this class body; presumably a subclass supplies them.
    """

    def __init__(self, opt, model, data_loader):
        """Store the options, model and data loader used for evaluation.

        Args:
            opt: Options object; ``epoch`` reads ``opt.eval.es`` from it.
            model: Model to evaluate; ``epoch`` calls ``model.eval()``.
            data_loader: Loader providing ``reset_offsets``,
                ``offset_summary`` and ``total_size``.
        """
        super(Evaluator, self).__init__()

        self.data_loader = data_loader
        self.model = model

        # State handed to ``self.batch`` on every step; ``validate`` adds
        # the "split" entry before the loop starts.
        self.batch_variables = {
            "model": model,
            "data": data_loader
        }

        self.opt = opt

    def validate(self, l, split="dev", losses=None, keyset=None):
        """Evaluate ``split`` and record every resulting loss under ``l``.

        Args:
            l: Key under which this run is stored, i.e.
                ``losses[loss_name][l] = loss_val``.
            split: Name of the data split to evaluate.
            losses: Dict that is updated in place with the results. When
                omitted, a fresh dict is used for this call only.
            keyset: Forwarded unchanged to ``epoch``.
        """
        # A ``{}`` default would be shared by every call that omits the
        # argument, so results would silently accumulate across calls.
        if losses is None:
            losses = {}

        self.batch_variables["split"] = split
        print("Evaluating {}".format(split))

        epoch_losses = self.epoch(
            self.opt, self.model, self.data_loader, split, keyset)

        self.print_result(split, epoch_losses)

        for loss_name, loss_val in epoch_losses.items():
            losses.setdefault(loss_name, {})
            losses[loss_name][l] = loss_val

    def epoch(self, opt, model, data_loader, split, keyset=None):
        """Run one pass over ``split`` and return the final scores.

        Args:
            opt: Options object; ``opt.eval.es`` is either ``"full"`` or a
                limit compared against ``self.counter(nums)``.
            model: Model to evaluate; switched to eval mode here.
            data_loader: Loader whose offsets are reset (no shuffling)
                before the pass begins.
            split: Name of the data split to evaluate.
            keyset: Accepted for interface compatibility; not used here.

        Returns:
            Whatever ``self.compute_final_scores(average_loss, nums)``
            returns.
        """
        average_loss, nums = self.initialize_losses()

        data_loader.reset_offsets(splits=split, shuffle=False)

        # Set evaluation mode
        model.eval()

        # Kept in its own variable: the loop below tracks loader offsets
        # separately, so the elapsed-time report stays a real duration.
        start_time = time.time()

        # Initialize progress bar
        bar = utils.set_progress_bar(
            data_loader.total_size[split])

        reset = False

        with torch.no_grad():
            while not reset:
                offset_before = data_loader.offset_summary(split)

                outputs = self.batch(
                    opt, nums, average_loss,
                    self.batch_variables, eval_mode=True)

                offset_after = data_loader.offset_summary(split)

                reset = outputs["reset"]

                if not reset:
                    # Advance the bar by the amount this batch consumed.
                    bar.update(offset_after - offset_before)
                else:
                    print(offset_after)

                # Early exits: toy configuration, or an explicit limit in
                # ``opt.eval.es``.
                if cfg.toy and self.counter(nums) > 100:
                    break
                if (opt.eval.es != "full" and
                        (self.counter(nums) > opt.eval.es)):
                    break

        nums = outputs["nums"]

        # ``torch.cuda.synchronize`` raises on hosts without CUDA, so only
        # wait for pending kernels when a device is actually available.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        print("{} evaluation completed in: {} s".format(
            split.capitalize(), time.time() - start_time))

        average_loss = self.compute_final_scores(
            average_loss, nums)

        return average_loss