import comet.utils.utils as utils
import comet.src.data.utils as data_utils
import comet.src.data.config as cfg

import pandas
import json
import random
import math
import torch

from tqdm import tqdm


def map_name(name):
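    """Map a canonical split name to the suffix used in the ATOMIC csv
    file names ("train" -> "trn", "test" -> "tst", anything else -> "dev")."""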
    if name == "train":
        return "trn"
    elif name == "test":
        return "tst"
    else:
        return "dev"


class DataLoader(object):
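    """Base container for raw examples, encoded tensors, length masks, and
    batching offsets, each keyed by split ("train", "dev", "test")."""
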
    def __init__(self, opt):
        self.data = {}
        self.data["train"] = {}
        self.data["dev"] = {}
        self.data["test"] = {}

        self.sequences = {}
        self.sequences["train"] = {}
        self.sequences["dev"] = {}
        self.sequences["test"] = {}

        self.masks = {}
        self.masks["train"] = {}
        self.masks["dev"] = {}
        self.masks["test"] = {}

        self.offsets = {}
        self.offsets["train"] = {}
        self.offsets["dev"] = {}
        self.offsets["test"] = {}

    def offset_summary(self, split):
        return self.offsets[split]["total"]


def do_take_partial_dataset(data_opts):
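    """Return True only when a keep ratio kr is set and kr != 1, i.e. when
    a strict subset of the training data should be used."""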
    if data_opts.get("kr", None) is None:
        return False
    if data_opts.kr == 1:
        return False
    return True


def select_partial_dataset(data_opts, data):
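    """Randomly sample ceil(kr * len(data)) examples, where kr is the
    fraction of the data to keep."""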
    num_selections = math.ceil(data_opts.kr * len(data))
    return random.sample(data, num_selections)


class GenerationDataLoader(DataLoader):
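    """DataLoader for ATOMIC generation: reads the csv files, encodes
    (event, category, effect) triples, and serves padded tensor batches."""
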
    def __init__(self, opt, categories):
        super(GenerationDataLoader, self).__init__(opt)

        self.categories = categories
        self.opt = opt

        for split in self.data:
            self.data[split] = {"total": []}
            self.offsets[split] = {"total": 0}

        self.vocab_encoder = None
        self.vocab_decoder = None
        self.special_chars = None
        self.max_event = None
        self.max_effect = None

    def load_data(self, path):
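        """Load data either from a cached .pickle data loader or from the raw
        v4_atomic_{trn,dev,tst}.csv files under path. Returns True iff the
        pickle path was taken."""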
if ".pickle" in path: | |
print("Loading data from: {}".format(path)) | |
data_utils.load_existing_data_loader(self, path) | |
return True | |
for split in self.data: | |
file_name = "v4_atomic_{}.csv".format(map_name(split)) | |
df = pandas.read_csv("{}/{}".format(path, file_name), index_col=0) | |
df.iloc[:, :9] = df.iloc[:, :9].apply( | |
lambda col: col.apply(json.loads)) | |
for cat in self.categories: | |
attr = df[cat] | |
self.data[split]["total"] += utils.zipped_flatten(zip( | |
attr.index, ["<{}>".format(cat)] * len(attr), attr.values)) | |
if do_take_partial_dataset(self.opt.data): | |
self.data["train"]["total"] = select_partial_dataset( | |
self.opt.data, self.data["train"]["total"]) | |
return False | |
    def make_tensors(self, text_encoder, special,
                     splits=["train", "dev", "test"], test=False):
        self.vocab_encoder = text_encoder.encoder
        self.vocab_decoder = text_encoder.decoder
        self.special_chars = special

        sequences = {}
        for split in splits:
            sequences[split] = get_generation_sequences(
                self.opt, self.data, split, text_encoder, test)

            self.masks[split]["total"] = [(len(i[0]), len(i[1])) for
                                          i in sequences[split]]

        self.max_event = max([max([l[0] for l in self.masks[split]["total"]])
                              for split in self.masks])
        self.max_effect = max([max([l[1] for l in self.masks[split]["total"]])
                               for split in self.masks])

        print(self.max_event)
        print(self.max_effect)

        for split in splits:
            num_elements = len(sequences[split])
            self.sequences[split]["total"] = torch.LongTensor(
                num_elements, self.max_event + self.max_effect).fill_(0)

            for i, seq in enumerate(sequences[split]):
                self.sequences[split]["total"][i, :len(seq[0])] = \
                    torch.LongTensor(seq[0])
                self.sequences[split]["total"][i, self.max_event:self.max_event + len(seq[1])] = \
                    torch.LongTensor(seq[1])

    def sample_batch(self, split, bs, idxs=None):
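        """Return (batch, epoch_done): the next bs rows of the split (or the
        rows selected by idxs), with attention and loss masks attached, and
        advance the split's offset."""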
        offset = self.offsets[split]["total"]

        batch = {}

        # Decided not to reduce computation on here because it's all parallel
        # anyway and we don't want to run out of memory in cases where we
        # don't see the longest version quickly enough

        if idxs:
            seqs = self.sequences[split]["total"].index_select(
                0, torch.LongTensor(idxs).to(
                    self.sequences[split]["total"].device))
        else:
            seqs = self.sequences[split]["total"][offset:offset + bs]
        batch["sequences"] = seqs.to(cfg.device)
        batch["attention_mask"] = make_attention_mask(seqs)
        batch["loss_mask"] = make_loss_mask(
            seqs, self.max_event, 1)
        batch["key"] = ("total", offset, offset + bs)

        offset += seqs.size(0)

        self.offsets[split]["total"] = offset

        if split == "train" and offset + bs > len(self.sequences[split]["total"]):
            return batch, True
        elif offset >= len(self.sequences[split]["total"]):
            return batch, True
        else:
            return batch, False

    def reset_offsets(self, splits=["train", "test", "dev"],
                      shuffle=True, keys=None):
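        """Rewind the batching offsets for the given splits back to 0,
        reshuffling the sequences first when shuffle is True."""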
        if isinstance(splits, str):
            splits = [splits]

        for split in splits:
            if keys is None:
                keys = ["total"]

            for key in keys:
                self.offsets[split][key] = 0

            if shuffle:
                self.shuffle_sequences(split, keys)

    def shuffle_sequences(self, split="train", keys=None):
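        """Apply one random permutation jointly to the raw data, the encoded
        sequences, and the masks so the three views stay aligned."""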
        if keys is None:
            keys = self.data[split].keys()

        for key in keys:
            idxs = list(range(len(self.data[split][key])))

            random.shuffle(idxs)

            self.sequences[split][key] = \
                self.sequences[split][key].index_select(
                    0, torch.LongTensor(idxs))

            temp = [self.data[split][key][i] for i in idxs]
            self.data[split][key] = temp

            temp = [self.masks[split][key][i] for i in idxs]
            self.masks[split][key] = temp


def prune_data_for_evaluation(data_loader, categories, split):
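    """Filter the loader in place so only examples whose category token is in
    categories remain, keeping data, masks, and sequences consistent."""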
    indices = []
    for i, example in enumerate(data_loader.data[split]["total"]):
        if example[1] in categories:
            indices.append(i)

    data_loader.masks[split]["total"] = [data_loader.masks[split]["total"][i]
                                         for i in indices]
    data_loader.sequences[split]["total"] = \
        data_loader.sequences[split]["total"].index_select(
            0, torch.LongTensor(indices))
    data_loader.data[split]["total"] = [data_loader.data[split]["total"][i]
                                        for i in indices]


def make_attention_mask(sequences):
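    """1.0 at real token positions, 0.0 at padding (token id 0)."""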
    return (sequences != 0).float().to(cfg.device)


def make_loss_mask(sequences, max_event, num_delim_tokens):
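    """Zero out padding plus the event and delimiter positions so the loss is
    computed only on effect tokens; drop the first column to align the mask
    with next-token prediction targets."""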
    mask = (sequences != 0).float()
    mask[:, :max_event + num_delim_tokens] = 0
    return mask[:, 1:].to(cfg.device)


def find_underscore_length(seq):
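    """Return the longest run of consecutive underscores that occurs in seq."""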
start = "_" | |
while start in seq: | |
start += "_" | |
return start[:-1] | |
def handle_underscores(suffix, text_encoder, prefix=False):
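    """Encode text containing underscore blanks: each blank run becomes the
    <blank> token id and the pieces around it are BPE-encoded. Event prefixes
    always split on "___"; other text splits on its longest underscore run."""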
    encoder = text_encoder.encoder
    if prefix:
        tok = "___"
    else:
        tok = find_underscore_length(suffix)

    suffix_parts = [i.strip() for i in suffix.split("{}".format(tok))]
    to_flatten = []
    for i, part in enumerate(suffix_parts):
        if part:
            to_flatten.append(text_encoder.encode([part], verbose=False)[0])

            if i != len(suffix_parts) - 1 and suffix_parts[i+1]:
                to_flatten.append([encoder["<blank>"]])
        else:
            to_flatten.append([encoder["<blank>"]])

    final_suffix = utils.flatten(to_flatten)

    return final_suffix


def get_generation_sequences(opt, data, split, text_encoder, test):
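    """Encode every (event, category, effect) triple in the split into id
    sequences; in test mode, stop after the first 11 examples."""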
    sequences = []
    count = 0

    final_prefix = None
    final_suffix = None

    for prefix, category, suffix in tqdm(data[split]["total"]):
        final_prefix, final_suffix = do_example(
            text_encoder, prefix, suffix, True, True)

        final = compile_final_sequence(
            opt, final_prefix, final_suffix, category, text_encoder)

        sequences.append(final)

        count += 1

        if count > 10 and test:
            break

    return sequences


def do_example(text_encoder, prefix, suffix, do_prefix, do_suffix):
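    """BPE-encode the event (prefix) and effect (suffix), routing any string
    that contains underscore blanks through handle_underscores."""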
    final_prefix = None
    final_suffix = None

    if do_prefix:
        if "___" in prefix:
            final_prefix = handle_underscores(prefix, text_encoder, True)
        else:
            final_prefix = text_encoder.encode([prefix], verbose=False)[0]
    if do_suffix:
        if "_" in suffix:
            final_suffix = handle_underscores(suffix, text_encoder)
        else:
            final_suffix = text_encoder.encode([suffix], verbose=False)[0]

    return final_prefix, final_suffix


def compile_final_sequence(opt, final_prefix, final_suffix, category, text_encoder):
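    """Assemble a two-element list: [event_ids, [<category>] + effect_ids + [<END>]]."""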
    final = []

    final.append(final_prefix)
    final.append(
        [text_encoder.encoder[category]]
        + final_suffix)

    final[-1].append(text_encoder.encoder["<END>"])

    return final


num_delimiter_tokens = {
    "category": 1,
    "hierarchy": 3,
    "hierarchy+label": 4,
    "category+hierarchy": 4,
    "category+hierarchy+label": 5
}
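

# Minimal usage sketch (assumptions flagged): `opt` and `text_encoder` are
# built elsewhere in the codebase -- `text_encoder` is assumed to be a BPE
# encoder exposing `encoder`/`decoder` vocab dicts and an
# `encode(texts, verbose=False)` method -- and `special_tokens` is a
# hypothetical list of special-character tokens. The category names below
# are the nine standard ATOMIC dimensions.
#
#     categories = ["oEffect", "oReact", "oWant", "xAttr", "xEffect",
#                   "xIntent", "xNeed", "xReact", "xWant"]
#     loader = GenerationDataLoader(opt, categories)
#     loader.load_data("data/atomic")
#     loader.make_tensors(text_encoder, special_tokens)
#     loader.reset_offsets("train", shuffle=True)
#     batch, epoch_done = loader.sample_batch("train", bs=32)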