import random

import torch
from tqdm import tqdm

import comet.src.data.utils as data_utils
import comet.src.data.atomic as adata
import comet.src.data.config as cfg


def map_name(name, opt):
    if name == "train":
        return "train{}k.txt".format(opt.trainsize)
    elif name == "test":
        return "test.txt"
    else:
        return "dev{}.txt".format(opt.devversion)


conceptnet_relations = [
    'AtLocation', 'CapableOf', 'Causes', 'CausesDesire',
    'CreatedBy', 'DefinedAs', 'DesireOf', 'Desires', 'HasA',
    'HasFirstSubevent', 'HasLastSubevent', 'HasPainCharacter',
    'HasPainIntensity', 'HasPrerequisite', 'HasProperty',
    'HasSubevent', 'InheritsFrom', 'InstanceOf', 'IsA',
    'LocatedNear', 'LocationOfAction', 'MadeOf', 'MotivatedByGoal',
    'NotCapableOf', 'NotDesires', 'NotHasA', 'NotHasProperty',
    'NotIsA', 'NotMadeOf', 'PartOf', 'ReceivesAction', 'RelatedTo',
    'SymbolOf', 'UsedFor'
]

# Natural-language surface forms used when opt.data.rel == "language".
split_into_words = {
    'AtLocation': "at location",
    'CapableOf': "capable of",
    'Causes': "causes",
    'CausesDesire': "causes desire",
    'CreatedBy': "created by",
    'DefinedAs': "defined as",
    'DesireOf': "desire of",
    'Desires': "desires",
    'HasA': "has a",
    'HasFirstSubevent': "has first subevent",
    'HasLastSubevent': "has last subevent",
    'HasPainCharacter': "has pain character",
    'HasPainIntensity': "has pain intensity",
    'HasPrerequisite': "has prerequisite",
    'HasProperty': "has property",
    'HasSubevent': "has subevent",
    'InheritsFrom': "inherits from",
    'InstanceOf': "instance of",
    'IsA': "is a",
    'LocatedNear': "located near",
    'LocationOfAction': "location of action",
    'MadeOf': "made of",
    'MotivatedByGoal': "motivated by goal",
    'NotCapableOf': "not capable of",
    'NotDesires': "not desires",
    'NotHasA': "not has a",
    'NotHasProperty': "not has property",
    'NotIsA': "not is a",
    'NotMadeOf': "not made of",
    'PartOf': "part of",
    'ReceivesAction': "receives action",
    'RelatedTo': "related to",
    'SymbolOf': "symbol of",
    'UsedFor': "used for"
}


class GenerationDataLoader(adata.DataLoader):
    def __init__(self, opt, categories=None):
        super(GenerationDataLoader, self).__init__(opt)
        self.opt = opt

        for split in self.data:
            self.data[split] = {"total": []}
            self.offsets[split] = {"total": 0}

        self.vocab_encoder = None
        self.vocab_decoder = None
        self.special_chars = None
        self.max_e1 = None
        self.max_e2 = None
        self.max_r = None

    def offset_summary(self, split):
        return sum(self.offsets[split].values())

    def load_data(self, path):
        if ".pickle" in path:
            print("Loading data from: {}".format(path))
            data_utils.load_existing_data_loader(self, path)
            return True

        for split in self.data:
            file_name = map_name(split, self.opt.data)

            if split != "dev" or self.opt.data.devversion != "12":
                string_tuples = open("{}/{}".format(
                    path, file_name), "r").read().split("\n")
                tuples = [x.split("\t") for x in string_tuples if x]
            else:
                # devversion "12" concatenates both dev files.
                string_tuples = open("{}/{}".format(
                    path, "dev1.txt"), "r").read().split("\n")
                tuples = [x.split("\t") for x in string_tuples if x]
                string_tuples = open("{}/{}".format(
                    path, "dev2.txt"), "r").read().split("\n")
                tuples += [x.split("\t") for x in string_tuples if x]

            if split in ["dev", "test"]:
                if self.opt.data.rel == "language":
                    self.data[split]["total"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), int(i[3])) for i in tuples]
                    self.data[split]["positive"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), int(i[3]))
                         for i in tuples if int(i[3])]
                    self.data[split]["negative"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), int(i[3]))
                         for i in tuples if not int(i[3])]
                elif self.opt.data.rel == "relation":
self.data[split]["total"] = \ [(i[1].lower().strip(), "<{}>".format(i[0]), i[2].lower().strip(), int(i[3])) for i in tuples] self.data[split]["positive"] = \ [(i[1].lower().strip(), "<{}>".format(i[0]), i[2].lower().strip(), int(i[3])) for i in tuples if int(i[3])] self.data[split]["negative"] = \ [(i[1].lower().strip(), "<{}>".format(i[0]), i[2].lower().strip(), int(i[3])) for i in tuples if not int(i[3])] else: if self.opt.data.rel == "language": self.data[split]["total"] = \ [(i[1].lower().strip(), split_into_words[i[0]], i[2].lower().strip(), i[3]) for i in tuples] elif self.opt.data.rel == "relation": self.data[split]["total"] = \ [(i[1].lower().strip(), "<{}>".format(i[0]), i[2].lower().strip(), i[3]) for i in tuples] return False def make_tensors(self, text_encoder, special, splits=["train", "dev", "test"], test=False): self.vocab_encoder = text_encoder.encoder self.vocab_decoder = text_encoder.decoder self.special_chars = special sequences = {} for split in splits: sequences[split], discarded = get_generation_sequences( self.data, split, text_encoder, test, self.opt.data.maxe1, self.opt.data.maxe2) if split == "train": self.data[split]["total"] = [j for i, j in enumerate( self.data[split]["total"]) if i not in set(discarded)] self.masks[split]["total"] = [(len(i[0]), len(i[1]), len(i[2])) for i in sequences[split]] self.max_e1 = max([max([l[0] for l in self.masks[split]["total"]]) for split in self.masks]) self.max_r = max([max([l[1] for l in self.masks[split]["total"]]) for split in self.masks]) self.max_e2 = max([max([l[2] for l in self.masks[split]["total"]]) for split in self.masks]) print(self.max_e1) print(self.max_r) print(self.max_e2) for split in splits: num_elements = len(sequences[split]) self.sequences[split]["total"] = torch.LongTensor( num_elements, self.max_e1 + self.max_e2 + self.max_r).fill_(0) for i, seq in enumerate(sequences[split]): # print(self.sequences[split]["total"][i, :len(seq[0])].size()) # print(torch.FloatTensor(seq[0]).size()) self.sequences[split]["total"][i, :len(seq[0])] = \ torch.LongTensor(seq[0]) start_r = self.max_e1 end_r = self.max_e1 + len(seq[1]) self.sequences[split]["total"][i, start_r:end_r] = \ torch.LongTensor(seq[1]) start_e2 = self.max_e1 + self.max_r end_e2 = self.max_e1 + self.max_r + len(seq[2]) self.sequences[split]["total"][i, start_e2:end_e2] = \ torch.LongTensor(seq[2]) if split in ["test", "dev"]: print(split) self.sequences[split]["negative"] = \ self.sequences[split]["total"].index_select( 0, torch.LongTensor([i for i, j in enumerate( self.data[split]['total']) if not j[3]])) # self.data[split]['total'][:self.sequences[split]["total"].size(0)]) if not j[3]])) self.sequences[split]["positive"] = \ self.sequences[split]["total"].index_select( 0, torch.LongTensor([i for i, j in enumerate( self.data[split]['total']) if j[3]])) # self.data[split]['total'][:self.sequences[split]["total"].size(0)]) if j[3]])) def sample_batch(self, split, bs, cat="total", idxs=None): offset = self.offsets[split][cat] batch = {} # Decided not to reduce computation on here because it's all parallel # anyway and we don't want to run out of memory in cases where we # don't see the longest version quickly enough if idxs: seqs = self.sequences[split][cat].index_select( 0, torch.LongTensor(idxs).to( self.sequences[split][cat].device)) else: seqs = self.sequences[split][cat][offset:offset + bs] batch["sequences"] = seqs.to(cfg.device) batch["attention_mask"] = make_attention_mask(seqs) batch["loss_mask"] = make_loss_mask(seqs, self.max_e1 + self.max_r) 
batch["key"] = (cat, offset, offset + bs) offset += seqs.size(0) self.offsets[split][cat] = offset if split == "train" and offset + bs > len(self.sequences[split][cat]): return batch, True elif offset >= len(self.sequences[split][cat]): return batch, True else: return batch, False def reset_offsets(self, splits=["train", "test", "dev"], shuffle=True, keys=None): if isinstance(splits, str): splits = [splits] for split in splits: if keys is None: keys = ["total", "positive", "negative"] for key in keys: self.offsets[split][key] = 0 if shuffle: self.shuffle_sequences(split, keys) def shuffle_sequences(self, split="train", keys=None): if keys is None: # print(type(self.data)) # print(type(self.data.keys())) keys = self.data[split].keys() for key in keys: if key in ["positive", "negative"]: continue idxs = list(range(len(self.data[split][key]))) random.shuffle(idxs) self.sequences[split][key] = \ self.sequences[split][key].index_select( 0, torch.LongTensor(idxs)) temp = [self.data[split][key][i] for i in idxs] self.data[split][key] = temp temp = [self.masks[split][key][i] for i in idxs] self.masks[split][key] = temp def make_attention_mask(sequences): return (sequences != 0).float().to(cfg.device) def make_loss_mask(sequences, max_event): # print(sequences.size()) mask = (sequences != 0).float() mask[:, :max_event] = 0 return mask[:, 1:].to(cfg.device) def get_generation_sequences(data, split, text_encoder, test, max_e1=10, max_e2=15): sequences = [] count = 0 final_event1 = None final_event2 = None final_relation = None discarded = [] for event1, relation, event2, _ in tqdm(data[split]["total"]): e1, r, e2 = do_example(text_encoder, event1, relation, event2) if (split == "train" and len(e1) > max_e1 or len(e2) > max_e2): discarded.append(count) count += 1 continue final = compile_final_sequence( e1, e2, r, text_encoder) sequences.append(final) count += 1 if count > 10 and test: break return sequences, discarded def do_example(text_encoder, event1, relation, event2): final_event1 = text_encoder.encode([event1], verbose=False)[0] if relation.lower() != relation: final_relation = [text_encoder.encoder[relation]] else: final_relation = text_encoder.encode( [relation], verbose=False)[0] if event2 is not None: final_event2 = text_encoder.encode([event2], verbose=False)[0] else: final_event2 = None return final_event1, final_relation, final_event2 def compile_final_sequence(final_event1, final_event2, final_relation, text_encoder): final = [] final.append(final_event1) final.append(final_relation) final.append(final_event2) final[-1].append(text_encoder.encoder[""]) return final