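# Data loaders for the ATOMIC knowledge graph (v4_atomic_*.csv), adapted from
# the COMET codebase (comet.src.data.*) for use in COSMIC.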
import comet.utils.utils as utils
import comet.src.data.utils as data_utils
import comet.src.data.config as cfg
import pandas
import json
import random
import math
import torch
from tqdm import tqdm


def map_name(name):
    if name == "train":
        return "trn"
    elif name == "test":
        return "tst"
    else:
        return "dev"


class DataLoader(object):
    def __init__(self, opt):
        self.data = {}
        self.data["train"] = {}
        self.data["dev"] = {}
        self.data["test"] = {}

        self.sequences = {}
        self.sequences["train"] = {}
        self.sequences["dev"] = {}
        self.sequences["test"] = {}

        self.masks = {}
        self.masks["train"] = {}
        self.masks["dev"] = {}
        self.masks["test"] = {}

        self.offsets = {}
        self.offsets["train"] = {}
        self.offsets["dev"] = {}
        self.offsets["test"] = {}

    def offset_summary(self, split):
        return self.offsets[split]["total"]


def do_take_partial_dataset(data_opts):
    if data_opts.get("kr", None) is None:
        return False
    if data_opts.kr == 1:
        return False
    return True


def select_partial_dataset(data_opts, data):
    num_selections = math.ceil(data_opts.kr * len(data))
    return random.sample(data, num_selections)
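
# Example (hypothetical numbers): with data_opts.kr == 0.25 and 1,000 training
# triples, select_partial_dataset keeps math.ceil(0.25 * 1000) == 250 randomly
# sampled triples; a missing kr or kr == 1 leaves the dataset untouched.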


class GenerationDataLoader(DataLoader):
    def __init__(self, opt, categories):
        super(GenerationDataLoader, self).__init__(opt)

        self.categories = categories
        self.opt = opt

        for split in self.data:
            self.data[split] = {"total": []}
            self.offsets[split] = {"total": 0}

        self.vocab_encoder = None
        self.vocab_decoder = None
        self.special_chars = None
        self.max_event = None
        self.max_effect = None
    def load_data(self, path):
        if ".pickle" in path:
            print("Loading data from: {}".format(path))
            data_utils.load_existing_data_loader(self, path)
            return True

        for split in self.data:
            file_name = "v4_atomic_{}.csv".format(map_name(split))

            df = pandas.read_csv("{}/{}".format(path, file_name), index_col=0)
            # The first nine columns hold the ATOMIC relation annotations as
            # JSON-encoded lists of strings, so decode them before use.
            df.iloc[:, :9] = df.iloc[:, :9].apply(
                lambda col: col.apply(json.loads))

            for cat in self.categories:
                attr = df[cat]
                self.data[split]["total"] += utils.zipped_flatten(zip(
                    attr.index, ["<{}>".format(cat)] * len(attr), attr.values))

        if do_take_partial_dataset(self.opt.data):
            self.data["train"]["total"] = select_partial_dataset(
                self.opt.data, self.data["train"]["total"])

        return False
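
    # After load_data, self.data[split]["total"] holds (event, "<category>", effect)
    # triples; a hypothetical entry might look like
    # ("PersonX goes to the mall", "<xIntent>", "to buy clothes").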
    def make_tensors(self, text_encoder, special,
                     splits=["train", "dev", "test"], test=False):
        self.vocab_encoder = text_encoder.encoder
        self.vocab_decoder = text_encoder.decoder
        self.special_chars = special

        sequences = {}
        for split in splits:
            sequences[split] = get_generation_sequences(
                self.opt, self.data, split, text_encoder, test)

            self.masks[split]["total"] = [(len(i[0]), len(i[1])) for
                                          i in sequences[split]]

        self.max_event = max([max([l[0] for l in self.masks[split]["total"]])
                              for split in self.masks])
        self.max_effect = max([max([l[1] for l in self.masks[split]["total"]])
                               for split in self.masks])
        print(self.max_event)
        print(self.max_effect)

        for split in splits:
            num_elements = len(sequences[split])
            self.sequences[split]["total"] = torch.LongTensor(
                num_elements, self.max_event + self.max_effect).fill_(0)

            for i, seq in enumerate(sequences[split]):
                # print(self.sequences[split]["total"][i, :len(seq[0])].size())
                # print(torch.FloatTensor(seq[0]).size())
                self.sequences[split]["total"][i, :len(seq[0])] = \
                    torch.LongTensor(seq[0])
                self.sequences[split]["total"][i, self.max_event:self.max_event + len(seq[1])] = \
                    torch.LongTensor(seq[1])
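
    # Layout of self.sequences[split]["total"] (one row per example, 0-padded):
    #   columns [0, max_event)                      -> encoded event tokens
    #   columns [max_event, max_event + max_effect) -> <category> token, encoded
    #                                                  effect tokens, <END>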
    def sample_batch(self, split, bs, idxs=None):
        offset = self.offsets[split]["total"]

        batch = {}

        # Decided not to reduce computation here because it's all parallel
        # anyway and we don't want to run out of memory in cases where we
        # don't see the longest version quickly enough

        if idxs:
            seqs = self.sequences[split]["total"].index_select(
                0, torch.LongTensor(idxs).to(
                    self.sequences[split]["total"].device))
        else:
            seqs = self.sequences[split]["total"][offset:offset + bs]
        batch["sequences"] = seqs.to(cfg.device)
        batch["attention_mask"] = make_attention_mask(seqs)
        batch["loss_mask"] = make_loss_mask(
            seqs, self.max_event, 1)
        batch["key"] = ("total", offset, offset + bs)

        offset += seqs.size(0)

        self.offsets[split]["total"] = offset

        if split == "train" and offset + bs > len(self.sequences[split]["total"]):
            return batch, True
        elif offset >= len(self.sequences[split]["total"]):
            return batch, True
        else:
            return batch, False
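
    # A minimal training-loop sketch (hypothetical; `model` and `loss_fn` are
    # assumed to exist elsewhere) showing how the (batch, done) return value
    # is consumed:
    #
    #   loader.reset_offsets("train", shuffle=True)
    #   done = False
    #   while not done:
    #       batch, done = loader.sample_batch("train", bs=32)
    #       logits = model(batch["sequences"], batch["attention_mask"])
    #       loss = loss_fn(logits, batch["sequences"], batch["loss_mask"])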
    def reset_offsets(self, splits=["train", "test", "dev"],
                      shuffle=True, keys=None):
        if isinstance(splits, str):
            splits = [splits]

        for split in splits:
            if keys is None:
                keys = ["total"]

            for key in keys:
                self.offsets[split][key] = 0

            if shuffle:
                self.shuffle_sequences(split, keys)
    def shuffle_sequences(self, split="train", keys=None):
        if keys is None:
            # print(type(self.data))
            # print(type(self.data.keys()))
            keys = self.data[split].keys()

        for key in keys:
            idxs = list(range(len(self.data[split][key])))

            random.shuffle(idxs)

            self.sequences[split][key] = \
                self.sequences[split][key].index_select(
                    0, torch.LongTensor(idxs))

            temp = [self.data[split][key][i] for i in idxs]
            self.data[split][key] = temp

            temp = [self.masks[split][key][i] for i in idxs]
            self.masks[split][key] = temp


def prune_data_for_evaluation(data_loader, categories, split):
    indices = []
    for i, example in enumerate(data_loader.data[split]["total"]):
        if example[1] in categories:
            indices.append(i)

    data_loader.masks[split]["total"] = [data_loader.masks[split]["total"][i]
                                         for i in indices]
    data_loader.sequences[split]["total"] = \
        data_loader.sequences[split]["total"].index_select(
            0, torch.LongTensor(indices))
    data_loader.data[split]["total"] = [data_loader.data[split]["total"][i]
                                        for i in indices]


def make_attention_mask(sequences):
    return (sequences != 0).float().to(cfg.device)


def make_loss_mask(sequences, max_event, num_delim_tokens):
    # print(num_delim_tokens)
    # print(sequences.size())
    mask = (sequences != 0).float()
    mask[:, :max_event + num_delim_tokens] = 0
    return mask[:, 1:].to(cfg.device)
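
# Mask sketch: sample_batch calls this with num_delim_tokens == 1 (the single
# <category> delimiter), so columns [0, max_event + 1) are zeroed; dropping the
# first column aligns the mask with next-token targets, leaving loss only on
# the effect tokens and the <END> token.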


def find_underscore_length(seq):
    start = "_"

    while start in seq:
        start += "_"
    return start[:-1]


def handle_underscores(suffix, text_encoder, prefix=False):
    encoder = text_encoder.encoder
    if prefix:
        tok = "___"
    else:
        tok = find_underscore_length(suffix)

    suffix_parts = [i.strip() for i in suffix.split("{}".format(tok))]
    to_flatten = []
    for i, part in enumerate(suffix_parts):
        if part:
            to_flatten.append(text_encoder.encode([part], verbose=False)[0])

            if i != len(suffix_parts) - 1 and suffix_parts[i+1]:
                to_flatten.append([encoder["<blank>"]])
        else:
            to_flatten.append([encoder["<blank>"]])

    final_suffix = utils.flatten(to_flatten)

    return final_suffix
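
# Example (hypothetical text): for the suffix "buys ___ at the store",
# handle_underscores encodes "buys" and "at the store" separately and places
# the <blank> token id between them, so blanks survive BPE encoding as a
# single special token instead of being split into underscore pieces.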


def get_generation_sequences(opt, data, split, text_encoder, test):
    sequences = []
    count = 0

    final_prefix = None
    final_suffix = None

    for prefix, category, suffix in tqdm(data[split]["total"]):
        final_prefix, final_suffix = do_example(
            text_encoder, prefix, suffix, True, True)
        # if do_prefix:
        #     if "___" in prefix:
        #         final_prefix = handle_underscores(prefix, text_encoder, True)
        #     else:
        #         final_prefix = text_encoder.encode([prefix], verbose=False)[0]
        # if do_suffix:
        #     if "_" in suffix:
        #         final_suffix = handle_underscores(suffix, text_encoder)
        #     else:
        #         final_suffix = text_encoder.encode([suffix], verbose=False)[0]

        final = compile_final_sequence(
            opt, final_prefix, final_suffix, category, text_encoder)

        sequences.append(final)

        count += 1

        if count > 10 and test:
            break

    return sequences


def do_example(text_encoder, prefix, suffix, do_prefix, do_suffix):
    final_prefix = None
    final_suffix = None

    if do_prefix:
        if "___" in prefix:
            final_prefix = handle_underscores(prefix, text_encoder, True)
        else:
            final_prefix = text_encoder.encode([prefix], verbose=False)[0]
    if do_suffix:
        if "_" in suffix:
            final_suffix = handle_underscores(suffix, text_encoder)
        else:
            final_suffix = text_encoder.encode([suffix], verbose=False)[0]

    return final_prefix, final_suffix


def compile_final_sequence(opt, final_prefix, final_suffix, category, text_encoder):
    final = []

    final.append(final_prefix)
    final.append(
        [text_encoder.encoder[category]]
        + final_suffix)

    final[-1].append(text_encoder.encoder["<END>"])

    return final
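
# compile_final_sequence returns a two-element list of token-id lists, e.g.
# [[event ids...], [encoder["<xWant>"], effect ids..., encoder["<END>"]]];
# make_tensors later writes these two pieces into the event and effect
# column ranges of the padded sequence tensor.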


num_delimiter_tokens = {
    "category": 1,
    "hierarchy": 3,
    "hierarchy+label": 4,
    "category+hierarchy": 4,
    "category+hierarchy+label": 5
}
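
# End-to-end usage sketch (hypothetical; the actual `opt` structure,
# `text_encoder`, and `special` token list come from the COMET training setup):
#
#   categories = ["oEffect", "oReact", "oWant", "xAttr", "xEffect",
#                 "xIntent", "xNeed", "xReact", "xWant"]
#   loader = GenerationDataLoader(opt, categories)
#   loader.load_data("data/atomic")             # directory with v4_atomic_*.csv
#   loader.make_tensors(text_encoder, special)  # builds padded sequence tensors
#   loader.reset_offsets("train", shuffle=True)
#   batch, done = loader.sample_batch("train", bs=64)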