Spaces:
Build error
Build error
| import comet.src.data.data as data | |
| import comet.src.data.config as cfg | |
| import comet.src.evaluate.sampler as sampling | |
| def do_gen_run(opt, generator, l, split="dev", scores={}): | |
| # Generate sequences for examples in evaluation set using | |
| # current trained model | |
| if opt.eval.gs == "full": | |
| sequences, avg_scores, indiv_scores = generator.generate(split) | |
| else: | |
| sequences, avg_scores, indiv_scores = generator.generate_some(split) | |
| if avg_scores is not None: | |
| # Record scores from generated sequences | |
| for score_name, score_val in avg_scores.items(): | |
| scores.setdefault(score_name, {}) | |
| scores[score_name].setdefault(l, []) | |
| scores[score_name][l] += [score_val] | |
| # Save generated sequences | |
| save_sequences(opt, sequences, avg_scores, indiv_scores, | |
| l, split, opt.eval.gs == "full", | |
| generator.data_loader) | |
| def save_sequences(opt, sequences, avg_scores, indiv_scores, | |
| l, split, full, data_loader): | |
| # This seems a bit roundabout since l = opt.train.dynamic in train.py | |
| # But it's in case we start checkpointing outside of epoch boundaries | |
| opt.train.dynamic.epoch = l | |
| if cfg.save: | |
| if full: | |
| names = {"gens": "gens", "scores": "scores", | |
| "indiv": "indiv.scores"} | |
| else: | |
| names = {"gens": "gens.small", "scores": "scores.small", | |
| "indiv": "indiv.scores.small"} | |
| # Save generated sequences | |
| data.save_eval_file(opt, sequences, names["gens"], split) | |
| if avg_scores is not None: | |
| # Save average scores over evaluation set for generated sequences | |
| # Scores computed are the ones the generator was initialized with | |
| data.save_eval_file(opt, avg_scores, names["scores"], split) | |
| if split == "dev": | |
| # Save individual scores | |
| data.save_eval_file( | |
| opt, indiv_scores, names["indiv"], split) | |
| class Generator(object): | |
| def __init__(self, opt, model, data_loader, scorers, reward_function=None): | |
| super(Generator, self).__init__() | |
| self.opt = opt | |
| self.model = model | |
| self.data_loader = data_loader | |
| self.sampler = sampling.make_sampler( | |
| opt.eval.sample, opt, data_loader) | |
| def generate(self, split="dev"): | |
| pass | |
| def generate_batch(self, sequences, split, verbose=False, bs=32): | |
| pass | |