import time

import torch

import comet.src.evaluate.generate as base_generate
import comet.src.evaluate.sampler as sampling
import comet.utils.utils as utils
import comet.src.data.config as cfg


def make_generator(opt, *args):
    """Factory used by the evaluation harness to build a generator."""
    return ConceptNetGenerator(opt, *args)


class ConceptNetGenerator(base_generate.Generator):
    """Generates e2 completions for ConceptNet (e1, relation) prompts.

    Wraps a trained model, a data loader, and a decoding sampler
    (selected from ``opt.eval.sample``).
    """

    def __init__(self, opt, model, data_loader):
        self.opt = opt
        self.model = model
        self.data_loader = data_loader

        # Decoding strategy (e.g. greedy / beam / top-k) chosen by config
        self.sampler = sampling.make_sampler(
            opt.eval.sample, opt, data_loader)

    def reset_sequences(self):
        """Return a fresh, empty buffer for generated sequences."""
        return []

    def generate(self, split="dev"):
        """Decode sequences for every example in the given data split.

        Returns a tuple ``(sequences, avg_scores, indiv_scores)``.
        The two score entries are currently always ``None`` because
        score computation is commented out below.
        """
        print("Generating Sequences")

        # Set evaluation mode
        self.model.eval()

        # Reset evaluation set for dataset split
        self.data_loader.reset_offsets(splits=split, shuffle=False)

        start_time = time.time()
        count = 0

        # Reset generated sequence buffer
        sequences = self.reset_sequences()

        # Initialize progress bar
        bar = utils.set_progress_bar(
            self.data_loader.total_size[split] / 2)

        reset = False

        with torch.no_grad():
            # Cycle through the split until the data loader signals a wrap
            while not reset:
                num_before = len(sequences)

                # Generate a single batch
                reset = self.generate_batch(sequences, split, bs=1)

                num_after = len(sequences)

                if not reset:
                    bar.update(num_after - num_before)
                else:
                    print(num_after)

                count += 1

                # Early exit in toy mode or when a generation cap is set
                if cfg.toy and count > 10:
                    break
                # BUG FIX: original referenced bare `opt` (NameError);
                # the option object lives on self.
                if (self.opt.eval.gs != "full" and
                        (count > self.opt.eval.gs)):
                    break

        # Only meaningful (and safe) when a GPU is present
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # BUG FIX: the original clobbered its timing variable with
        # len(sequences) inside the loop, so the elapsed time printed
        # here was garbage; start_time is now dedicated to timing.
        print("{} generations completed in: {} s".format(
            split, time.time() - start_time))

        # Compute scores for sequences (e.g., BLEU, ROUGE)
        # Computes scores that the generator is initialized with
        # Change define_scorers to add more scorers as possibilities
        # avg_scores, indiv_scores = self.compute_sequence_scores(
        #     sequences, split)
        avg_scores, indiv_scores = None, None

        return sequences, avg_scores, indiv_scores

    def generate_batch(self, sequences, split, verbose=False, bs=1):
        """Sample one batch, decode a completion, and append the result.

        Appends a dict holding the sampler's output plus the example's
        key, decoded e1 string, and decoded relation string to
        ``sequences``.  Returns the loader's ``reset`` flag (True once
        the split has been exhausted).
        """
        # Sample batch from data loader
        batch, reset = self.data_loader.sample_batch(
            split, bs=bs, cat="positive")

        # Context = [e1 tokens | relation tokens]; completion starts after it
        start_idx = self.data_loader.max_e1 + self.data_loader.max_r
        max_end_len = self.data_loader.max_e2

        context = batch["sequences"][:, :start_idx]

        # Decode the e1 tokens back to text.  BPE tokens carry a "</w>"
        # end-of-word marker which becomes a word boundary; index 0 is
        # padding and is skipped by the `if i` filter.
        # (BUG FIX: the marker string was garbled to '' in the original,
        # which inserted a space between every character.)
        init = "".join([self.data_loader.vocab_decoder[i].replace(
            '</w>', ' ') for i in context[
                :, :self.data_loader.max_e1].squeeze(0).tolist() if i]).strip()

        # Decode the relation tokens the same way
        start = self.data_loader.max_e1
        end = self.data_loader.max_e1 + self.data_loader.max_r
        attr = "".join([self.data_loader.vocab_decoder[i].replace(
            '</w>', ' ') for i in context[
                :, start:end].squeeze(0).tolist() if i]).strip()

        # Decode sequence
        sampling_result = self.sampler.generate_sequence(
            batch, self.model, self.data_loader, start_idx, max_end_len)

        sampling_result["key"] = batch["key"]
        sampling_result["e1"] = init
        sampling_result["r"] = attr

        sequences.append(sampling_result)

        return reset