Spaces:
Build error
Build error
import random | |
import comet.src.train.train as base_train | |
import comet.src.train.batch as batch | |
import comet.src.evaluate.atomic_evaluate as evaluate | |
# import comet.src.evaluate.atomic_generate as gen | |
def make_trainer(opt, *args):
    """Build the trainer used for ATOMIC generation.

    Args:
        opt: Configuration object, forwarded unchanged to the trainer.
        *args: Extra positional arguments, forwarded unchanged.

    Returns:
        A new ``AtomicGenerationIteratorTrainer`` instance.
    """
    trainer = AtomicGenerationIteratorTrainer(opt, *args)
    return trainer
class AtomicGenerationIteratorTrainer(base_train.IteratorTrainer):
    """Iterator trainer specialised for ATOMIC generation.

    Extends ``base_train.IteratorTrainer`` with ATOMIC-specific evaluator
    construction, batching, per-category loss bookkeeping and best-score
    tracking based on the dev ``total_micro`` loss.
    """

    def __init__(self, opt, *args):
        """Initialise the base trainer, then set up the training losses.

        Args:
            opt: Configuration object; ``opt.data.get("categories", [])``
                supplies the category names passed to
                ``initialize_losses``.
            *args: Forwarded unchanged to the base-class constructor.
        """
        super(AtomicGenerationIteratorTrainer, self).__init__(opt, *args)

        categories = opt.data.get("categories", [])
        # The returned counter dict is not kept here.
        self.initialize_losses(categories)

    def set_evaluator(self, opt, model, data_loader):
        """Create the evaluator and store it on ``self.evaluator``."""
        evaluator = evaluate.make_evaluator(opt, model, data_loader)
        self.evaluator = evaluator

    # def set_generator(self, opt, model, data_loader, scores, reward=None):
    #     self.generator = gen.make_generator(
    #         opt, model, data_loader, scores, reward)

    def set_sampler(self, opt):
        """Ensure a sampler exists for ``opt.train.static.samp``.

        A sampler is created only when ``self.samplers`` has no entry for
        the configured name; afterwards the whole ``self.samplers`` mapping
        is stored under ``self.batch_variables["sampler"]``.
        """
        # NOTE(review): `sampling` is not imported in the visible part of
        # this module, so the line below looks like it would raise NameError
        # when a sampler has to be created -- confirm the intended import.
        if opt.train.static.samp not in self.samplers:
            self.samplers[opt.train.static.samp] = sampling.make_sampler(
                opt.train.static.samp,
                opt,
                self.data_loader,
                batch_mode=True,
            )

        self.batch_variables["sampler"] = self.samplers

    def batch(self, opt, *args):
        """Run one ATOMIC generation batch.

        Returns:
            A ``(token_loss, nums, reset)`` tuple taken from the ``"loss"``,
            ``"nums"`` and ``"reset"`` entries of the batch outputs.
        """
        # Inside the method body, `batch` resolves to the imported module,
        # not to this method (class scope is not an enclosing scope).
        results = batch.batch_atomic_generate(opt, *args)
        return results["loss"], results["nums"], results["reset"]

    def initialize_losses(self, categories):
        """Reset the training-loss records and build matching counters.

        ``self.losses["train"]`` is replaced by a dict mapping
        ``total_micro`` / ``total_macro`` and, for every category,
        ``<category>_micro`` / ``<category>_macro`` to a fresh ``[0]`` list.

        Args:
            categories: Iterable of category names.

        Returns:
            A dict with the same keys as ``self.losses["train"]``, each
            mapped to ``0``.
        """
        self.losses["train"] = {
            "total_micro": [0],
            "total_macro": [0],
        }
        counts = {"total_micro": 0, "total_macro": 0}

        for category in categories:
            for suffix in ("micro", "macro"):
                key = "{}_{}".format(category, suffix)
                self.losses["train"][key] = [0]
                counts[key] = 0

        return counts

    def update_top_score(self, opt):
        """Record the current epoch's score if it beats the stored one.

        ``self.top_score`` holds an ``(epoch, score)`` tuple. It is replaced
        when it is still ``None`` or when the tracked score is strictly
        lower than the stored score. The value is printed before and after
        the update. ``opt`` is not read; ``self.opt`` is used instead.
        """
        print(self.top_score)

        if (self.top_score is None
                or self.get_tracked_score() < self.top_score[-1]):
            self.top_score = (
                self.opt.train.dynamic.epoch,
                self.get_tracked_score(),
            )

        print(self.top_score)

    def get_tracked_score(self):
        """Return the dev ``total_micro`` loss at the current epoch index."""
        dev_micro = self.losses["dev"]["total_micro"]
        return dev_micro[self.opt.train.dynamic.epoch]

    def counter(self, nums):
        """Return the ``total_macro`` entry of ``nums``."""
        macro_total = nums["total_macro"]
        return macro_total