import torch
from comet.src.data.utils import TextEncoder
import comet.src.data.config as cfg
import comet.src.data.data as data
import comet.src.models.models as models
from comet.src.evaluate.sampler import BeamSampler, GreedySampler, TopKSampler
import comet.utils.utils as utils


def load_model_file(model_file):
    model_stuff = data.load_checkpoint(model_file)
    opt = model_stuff["opt"]
    state_dict = model_stuff["state_dict"]
    return opt, state_dict
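
# Usage sketch -- the checkpoint path below is a placeholder; point it at
# whichever trained COMET checkpoint you have downloaded:
#
#   opt, state_dict = load_model_file(
#       "pretrained_models/atomic_pretrained_model.pickle")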


def load_data(dataset, opt):
    if dataset == "atomic":
        data_loader = load_atomic_data(opt)
    elif dataset == "conceptnet":
        data_loader = load_conceptnet_data(opt)
    else:
        raise ValueError("dataset must be 'atomic' or 'conceptnet'")

    # Initialize TextEncoder
    encoder_path = "comet/model/encoder_bpe_40000.json"
    bpe_path = "comet/model/vocab_40000.bpe"
    text_encoder = TextEncoder(encoder_path, bpe_path)
    text_encoder.encoder = data_loader.vocab_encoder
    text_encoder.decoder = data_loader.vocab_decoder

    return data_loader, text_encoder


def load_atomic_data(opt):
    # Hacky workaround; you may have to change this if your models
    # use different pad lengths for e1, e2, r
    if opt.data.get("maxe1", None) is None:
        opt.data.maxe1 = 17
        opt.data.maxe2 = 35
        opt.data.maxr = 1
    # path = "data/atomic/processed/generation/{}.pickle".format(
    #     utils.make_name_string(opt.data))
    path = "comet/data/atomic/processed/generation/categories_oEffect#oReact#oWant#xAttr#xEffect#xIntent#xNeed#xReact#xWant-maxe1_17-maxe2_35-maxr_1.pickle"
    data_loader = data.make_data_loader(opt, opt.data.categories)
    loaded = data_loader.load_data(path)
    return data_loader


def load_conceptnet_data(opt):
    # Hacky workaround; you may have to change this if your models
    # use a different pad length for r
    if opt.data.get("maxr", None) is None:
        if opt.data.rel == "language":
            opt.data.maxr = 5
        else:
            opt.data.maxr = 1
    path = "comet/data/conceptnet/processed/generation/{}.pickle".format(
        utils.make_name_string(opt.data))
    data_loader = data.make_data_loader(opt)
    loaded = data_loader.load_data(path)
    return data_loader


def make_model(opt, n_vocab, n_ctx, state_dict):
    model = models.make_model(
        opt, n_vocab, n_ctx, None, load=False,
        return_acts=True, return_probs=False)
    models.load_state_dict(model, state_dict)
    model.eval()
    return model
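
# End-to-end loading sketch (assumed wiring, modeled on the COMET interactive
# scripts; the checkpoint path is a placeholder):
#
#   opt, state_dict = load_model_file(
#       "pretrained_models/atomic_pretrained_model.pickle")
#   data_loader, text_encoder = load_data("atomic", opt)
#
#   n_ctx = data_loader.max_event + data_loader.max_effect
#   n_vocab = len(text_encoder.encoder) + n_ctx
#   model = make_model(opt, n_vocab, n_ctx, state_dict)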


def set_sampler(opt, sampling_algorithm, data_loader):
    if "beam" in sampling_algorithm:
        opt.eval.bs = int(sampling_algorithm.split("-")[1])
        sampler = BeamSampler(opt, data_loader)
    elif "topk" in sampling_algorithm:
        # print("Still bugs in the topk sampler. Use beam or greedy instead")
        # raise NotImplementedError
        opt.eval.k = int(sampling_algorithm.split("-")[1])
        sampler = TopKSampler(opt, data_loader)
    else:
        sampler = GreedySampler(opt, data_loader)
    return sampler
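
# The decoding strategy is selected by string, e.g. (sketch):
#
#   sampler = set_sampler(opt, "beam-5", data_loader)   # beam search, beam size 5
#   sampler = set_sampler(opt, "topk-10", data_loader)  # top-k sampling, k = 10
#   sampler = set_sampler(opt, "greedy", data_loader)   # greedy decoding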


def get_atomic_sequence(input_event, model, sampler, data_loader, text_encoder, category):
    if isinstance(category, list):
        outputs = {}
        for cat in category:
            new_outputs = get_atomic_sequence(
                input_event, model, sampler, data_loader, text_encoder, cat)
            outputs.update(new_outputs)
        return outputs
    elif category == "all":
        outputs = {}
        for category in data_loader.categories:
            new_outputs = get_atomic_sequence(
                input_event, model, sampler, data_loader, text_encoder, category)
            outputs.update(new_outputs)
        return outputs
    else:
        sequence_all = {}
        sequence_all["event"] = input_event
        sequence_all["effect_type"] = category

        with torch.no_grad():
            batch = set_atomic_inputs(
                input_event, category, data_loader, text_encoder)
            sampling_result = sampler.generate_sequence(
                batch, model, data_loader,
                data_loader.max_event +
                data.atomic_data.num_delimiter_tokens["category"],
                data_loader.max_effect -
                data.atomic_data.num_delimiter_tokens["category"])

        sequence_all['beams'] = sampling_result["beams"]

        # print_atomic_sequence(sequence_all)

        return {category: sequence_all}
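
# Usage sketch -- `category` may be a single effect type, a list of effect
# types, or "all"; the result maps each category to its candidate generations:
#
#   outputs = get_atomic_sequence(
#       "PersonX goes to the mall", model, sampler, data_loader,
#       text_encoder, "xReact")
#   print(outputs["xReact"]["beams"])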


def print_atomic_sequence(sequence_object):
    input_event = sequence_object["event"]
    category = sequence_object["effect_type"]

    print("Input Event: {}".format(input_event))
    print("Target Effect: {}".format(category))
    print("")
    print("Candidate Sequences:")
    for beam in sequence_object["beams"]:
        print(beam)
    print("")
    print("====================================================")
    print("")


def set_atomic_inputs(input_event, category, data_loader, text_encoder):
    XMB = torch.zeros(1, data_loader.max_event + 1).long().to(cfg.device)
    prefix, suffix = data.atomic_data.do_example(text_encoder, input_event, None, True, None)

    # Truncate events that exceed the model's padded event length
    if len(prefix) > data_loader.max_event + 1:
        prefix = prefix[:data_loader.max_event + 1]

    XMB[:, :len(prefix)] = torch.LongTensor(prefix)
    XMB[:, -1] = torch.LongTensor([text_encoder.encoder["<{}>".format(category)]])

    batch = {}
    batch["sequences"] = XMB
    batch["attention_mask"] = data.atomic_data.make_attention_mask(XMB)

    return batch
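
# Note on set_atomic_inputs: the returned "sequences" tensor is laid out as
# [event tokens | padding | category token] -- the event is padded out to
# data_loader.max_event positions and the final position holds the special
# token for the requested category (e.g. "<xReact>").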


def get_conceptnet_sequence(e1, model, sampler, data_loader, text_encoder, relation, force=False):
    if isinstance(relation, list):
        outputs = {}
        for rel in relation:
            new_outputs = get_conceptnet_sequence(
                e1, model, sampler, data_loader, text_encoder, rel)
            outputs.update(new_outputs)
        return outputs
    elif relation == "all":
        outputs = {}
        for relation in data.conceptnet_data.conceptnet_relations:
            new_outputs = get_conceptnet_sequence(
                e1, model, sampler, data_loader, text_encoder, relation)
            outputs.update(new_outputs)
        return outputs
    else:
        sequence_all = {}
        sequence_all["e1"] = e1
        sequence_all["relation"] = relation

        with torch.no_grad():
            if data_loader.max_r != 1:
                relation_sequence = data.conceptnet_data.split_into_words[relation]
            else:
                relation_sequence = "<{}>".format(relation)

            batch, abort = set_conceptnet_inputs(
                e1, relation_sequence, text_encoder,
                data_loader.max_e1, data_loader.max_r, force)

            if abort:
                return {relation: sequence_all}

            sampling_result = sampler.generate_sequence(
                batch, model, data_loader,
                data_loader.max_e1 + data_loader.max_r,
                data_loader.max_e2)

        sequence_all['beams'] = sampling_result["beams"]

        print_conceptnet_sequence(sequence_all)

        return {relation: sequence_all}
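
# Usage sketch -- `relation` may be a single ConceptNet relation, a list of
# relations, or "all"; entities work best lemmatized (see print_help below):
#
#   outputs = get_conceptnet_sequence(
#       "go to the mall", model, sampler, data_loader, text_encoder, "CausesDesire")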


def set_conceptnet_inputs(input_event, relation, text_encoder, max_e1, max_r, force):
    abort = False

    e1_tokens, rel_tokens, _ = data.conceptnet_data.do_example(text_encoder, input_event, relation, None)

    if len(e1_tokens) > max_e1:
        if force:
            XMB = torch.zeros(1, len(e1_tokens) + max_r).long().to(cfg.device)
        else:
            XMB = torch.zeros(1, max_e1 + max_r).long().to(cfg.device)
            return {}, True
    else:
        XMB = torch.zeros(1, max_e1 + max_r).long().to(cfg.device)

    XMB[:, :len(e1_tokens)] = torch.LongTensor(e1_tokens)
    XMB[:, max_e1:max_e1 + len(rel_tokens)] = torch.LongTensor(rel_tokens)

    batch = {}
    batch["sequences"] = XMB
    batch["attention_mask"] = data.conceptnet_data.make_attention_mask(XMB)

    return batch, abort
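
# Note on set_conceptnet_inputs: e1 tokens fill the first max_e1 positions and
# relation tokens start at offset max_e1, so a row reads [e1 | padding | relation].
# With force=True an over-long e1 widens the tensor instead of aborting, though
# the relation tokens are still written starting at offset max_e1.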


def print_conceptnet_sequence(sequence_object):
    e1 = sequence_object["e1"]
    relation = sequence_object["relation"]

    print("Input Entity: {}".format(e1))
    print("Target Relation: {}".format(relation))
    print("")
    print("Candidate Sequences:")
    for beam in sequence_object["beams"]:
        print(beam)
    print("")
    print("====================================================")
    print("")


def print_help(data):
    print("")
    if data == "atomic":
        print("Provide a seed event such as \"PersonX goes to the mall\"")
        print("Don't include names; replace them with PersonX, PersonY, etc.")
        print("The event should always include PersonX")
    if data == "conceptnet":
        print("Provide a seed entity such as \"go to the mall\"")
        print("Because the model was trained on lemmatized entities,")
        print("it works best if the input entities are also lemmatized")
    print("")


def print_relation_help(data):
    print_category_help(data)


def print_category_help(data):
    print("")
    if data == "atomic":
        print("Enter a possible effect type from the following effect types:")
        print("all - compute the output for all effect types {oEffect, oReact, oWant, xAttr, xEffect, xIntent, xNeed, xReact, xWant}")
        print("oEffect - generate the effect of the event on participants other than PersonX")
        print("oReact - generate the reactions of participants other than PersonX to the event")
        print("oWant - generate what participants other than PersonX may want after the event")
        print("xAttr - generate how PersonX might be described given the event")
        print("xEffect - generate the effect of the event on PersonX")
        print("xIntent - generate why PersonX causes the event")
        print("xNeed - generate what PersonX needs to do before the event")
        print("xReact - generate PersonX's reaction to the event")
        print("xWant - generate what PersonX may want after the event")
    elif data == "conceptnet":
        print("Enter a possible relation from the following list:")
        print("")
        print('AtLocation')
        print('CapableOf')
        print('Causes')
        print('CausesDesire')
        print('CreatedBy')
        print('DefinedAs')
        print('DesireOf')
        print('Desires')
        print('HasA')
        print('HasFirstSubevent')
        print('HasLastSubevent')
        print('HasPainCharacter')
        print('HasPainIntensity')
        print('HasPrerequisite')
        print('HasProperty')
        print('HasSubevent')
        print('InheritsFrom')
        print('InstanceOf')
        print('IsA')
        print('LocatedNear')
        print('LocationOfAction')
        print('MadeOf')
        print('MotivatedByGoal')
        print('NotCapableOf')
        print('NotDesires')
        print('NotHasA')
        print('NotHasProperty')
        print('NotIsA')
        print('NotMadeOf')
        print('PartOf')
        print('ReceivesAction')
        print('RelatedTo')
        print('SymbolOf')
        print('UsedFor')
        print("")
        print("NOTE: Capitalization is important")
    else:
        raise ValueError("Unknown dataset: {}".format(data))
    print("")


def print_sampling_help():
    print("")
    print("Provide a sampling algorithm from the following list to generate the sequence:")
    print("")
    print("greedy")
    print("beam-# where # is the beam size")
    print("topk-# where # is k")
    print("")