import os

import torch

from comet.src.data.utils import TextEncoder
import comet.src.data.config as cfg
import comet.src.data.data as data
import comet.src.models.models as models
from comet.src.evaluate.sampler import BeamSampler, GreedySampler, TopKSampler
import comet.utils.utils as utils

def load_model_file(model_file):
    model_stuff = data.load_checkpoint(model_file)
    opt = model_stuff["opt"]
    state_dict = model_stuff["state_dict"]
    return opt, state_dict

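# Example (a sketch; the checkpoint path is an assumption following the
# pretrained COMET release layout):
#
#   opt, state_dict = load_model_file(
#       "comet/pretrained_models/atomic_pretrained_model.pickle")
#
# `opt` holds the training configuration and `state_dict` the model weights;
# both are consumed by make_model() below.
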
def load_data(dataset, opt, dir="."):
    if dataset == "atomic":
        data_loader = load_atomic_data(opt, dir)
    elif dataset == "conceptnet":
        data_loader = load_conceptnet_data(opt, dir)
    else:
        raise ValueError("dataset must be 'atomic' or 'conceptnet', got {}".format(dataset))

    # Initialize the BPE text encoder, then swap in the vocabulary
    # the data loader was serialized with
    encoder_path = os.path.join(dir, "comet/model/encoder_bpe_40000.json")
    bpe_path = os.path.join(dir, "comet/model/vocab_40000.bpe")
    text_encoder = TextEncoder(encoder_path, bpe_path)
    text_encoder.encoder = data_loader.vocab_encoder
    text_encoder.decoder = data_loader.vocab_decoder

    return data_loader, text_encoder

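# Example (a sketch):
#
#   data_loader, text_encoder = load_data("atomic", opt)
#
# The encoder/decoder swap above keeps token ids consistent between the
# text encoder and the vocabulary the data loader was built with.
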
def load_atomic_data(opt, dir="."):
    # Hacky workaround; you may have to change this
    # if your models use different pad lengths for e1, e2, r
    if opt.data.get("maxe1", None) is None:
        opt.data.maxe1 = 17
        opt.data.maxe2 = 35
        opt.data.maxr = 1

    # Temporarily change to the target directory
    current_dir = os.getcwd()
    os.chdir(dir)

    path = "comet/data/atomic/processed/generation/categories_oEffect#oReact#oWant#xAttr#xEffect#xIntent#xNeed#xReact#xWant-maxe1_17-maxe2_35-maxr_1.pickle"
    data_loader = data.make_data_loader(opt, opt.data.categories)
    loaded = data_loader.load_data(path)

    # Go back to the original working directory
    os.chdir(current_dir)

    return data_loader

def load_conceptnet_data(opt, dir="."):
    # Hacky workaround; you may have to change this
    # if your models use different pad lengths for r
    if opt.data.get("maxr", None) is None:
        if opt.data.rel == "language":
            opt.data.maxr = 5
        else:
            opt.data.maxr = 1

    # Temporarily change to the target directory
    current_dir = os.getcwd()
    os.chdir(dir)

    path = "comet/data/conceptnet/processed/generation/{}.pickle".format(
        utils.make_name_string(opt.data)
    )
    data_loader = data.make_data_loader(opt)
    loaded = data_loader.load_data(path)

    # Go back to the original working directory
    os.chdir(current_dir)

    return data_loader

def make_model(opt, n_vocab, n_ctx, state_dict):
    model = models.make_model(
        opt, n_vocab, n_ctx, None, load=False, return_acts=True, return_probs=False
    )
    models.load_state_dict(model, state_dict)
    model.eval()
    return model

def set_sampler(opt, sampling_algorithm, data_loader):
    if "beam" in sampling_algorithm:
        opt.eval.bs = int(sampling_algorithm.split("-")[1])
        sampler = BeamSampler(opt, data_loader)
    elif "topk" in sampling_algorithm:
        # print("Still bugs in the topk sampler. Use beam or greedy instead")
        # raise NotImplementedError
        opt.eval.k = int(sampling_algorithm.split("-")[1])
        sampler = TopKSampler(opt, data_loader)
    else:
        sampler = GreedySampler(opt, data_loader)
    return sampler

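# Example (a sketch): the algorithm string carries its own hyperparameter,
# in the format described by print_sampling_help() at the end of this file:
#
#   sampler = set_sampler(opt, "beam-5", data_loader)   # beam search, width 5
#   sampler = set_sampler(opt, "topk-10", data_loader)  # top-k sampling, k = 10
#   sampler = set_sampler(opt, "greedy", data_loader)   # greedy decoding
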
def get_atomic_sequence(
    input_event, model, sampler, data_loader, text_encoder, category
):
    if isinstance(category, list):
        outputs = {}
        for cat in category:
            new_outputs = get_atomic_sequence(
                input_event, model, sampler, data_loader, text_encoder, cat
            )
            outputs.update(new_outputs)
        return outputs
    elif category == "all":
        outputs = {}
        for category in data_loader.categories:
            new_outputs = get_atomic_sequence(
                input_event, model, sampler, data_loader, text_encoder, category
            )
            outputs.update(new_outputs)
        return outputs
    else:
        sequence_all = {}
        sequence_all["event"] = input_event
        sequence_all["effect_type"] = category

        with torch.no_grad():
            batch = set_atomic_inputs(input_event, category, data_loader, text_encoder)
            sampling_result = sampler.generate_sequence(
                batch,
                model,
                data_loader,
                data_loader.max_event
                + data.atomic_data.num_delimiter_tokens["category"],
                data_loader.max_effect
                - data.atomic_data.num_delimiter_tokens["category"],
            )

        sequence_all["beams"] = sampling_result["beams"]

        # print_atomic_sequence(sequence_all)

        return {category: sequence_all}

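# Example (a sketch; the beams shown are illustrative, not real model output):
#
#   outputs = get_atomic_sequence(
#       "PersonX goes to the mall", model, sampler, data_loader,
#       text_encoder, "xIntent")
#   # => {"xIntent": {"event": "PersonX goes to the mall",
#   #                 "effect_type": "xIntent",
#   #                 "beams": ["to buy something", ...]}}
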
def print_atomic_sequence(sequence_object):
    input_event = sequence_object["event"]
    category = sequence_object["effect_type"]

    print("Input Event: {}".format(input_event))
    print("Target Effect: {}".format(category))
    print("")
    print("Candidate Sequences:")
    for beam in sequence_object["beams"]:
        print(beam)
    print("")
    print("====================================================")
    print("")

def set_atomic_inputs(input_event, category, data_loader, text_encoder):
    XMB = torch.zeros(1, data_loader.max_event + 1).long().to(cfg.device)
    prefix, suffix = data.atomic_data.do_example(
        text_encoder, input_event, None, True, None
    )

    if len(prefix) > data_loader.max_event + 1:
        prefix = prefix[: data_loader.max_event + 1]

    XMB[:, : len(prefix)] = torch.LongTensor(prefix)
    XMB[:, -1] = torch.LongTensor([text_encoder.encoder["<{}>".format(category)]])

    batch = {}
    batch["sequences"] = XMB
    batch["attention_mask"] = data.atomic_data.make_attention_mask(XMB)

    return batch

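# For reference, the batch returned above is:
#   batch["sequences"]      - (1, max_event + 1) LongTensor: the BPE-encoded
#                             event, left-aligned, with the category token
#                             (e.g. <xIntent>) in the final position
#   batch["attention_mask"] - mask over the non-padding positions
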
def get_conceptnet_sequence(
    e1, model, sampler, data_loader, text_encoder, relation, force=False
):
    if isinstance(relation, list):
        outputs = {}
        for rel in relation:
            new_outputs = get_conceptnet_sequence(
                e1, model, sampler, data_loader, text_encoder, rel
            )
            outputs.update(new_outputs)
        return outputs
    elif relation == "all":
        outputs = {}
        for relation in data.conceptnet_data.conceptnet_relations:
            new_outputs = get_conceptnet_sequence(
                e1, model, sampler, data_loader, text_encoder, relation
            )
            outputs.update(new_outputs)
        return outputs
    else:
        sequence_all = {}
        sequence_all["e1"] = e1
        sequence_all["relation"] = relation

        with torch.no_grad():
            if data_loader.max_r != 1:
                relation_sequence = data.conceptnet_data.split_into_words[relation]
            else:
                relation_sequence = "<{}>".format(relation)

            batch, abort = set_conceptnet_inputs(
                e1,
                relation_sequence,
                text_encoder,
                data_loader.max_e1,
                data_loader.max_r,
                force,
            )

            if abort:
                # Input entity was too long and force=False, so skip generation
                return {relation: sequence_all}

            sampling_result = sampler.generate_sequence(
                batch,
                model,
                data_loader,
                data_loader.max_e1 + data_loader.max_r,
                data_loader.max_e2,
            )

        sequence_all["beams"] = sampling_result["beams"]

        print_conceptnet_sequence(sequence_all)

        return {relation: sequence_all}

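# Example (a sketch; the beams are illustrative):
#
#   outputs = get_conceptnet_sequence(
#       "go to the mall", model, sampler, data_loader, text_encoder, "UsedFor")
#   # => {"UsedFor": {"e1": "go to the mall", "relation": "UsedFor",
#   #                 "beams": ["buy clothes", ...]}}
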
def set_conceptnet_inputs(input_event, relation, text_encoder, max_e1, max_r, force):
    abort = False

    e1_tokens, rel_tokens, _ = data.conceptnet_data.do_example(
        text_encoder, input_event, relation, None
    )

    if len(e1_tokens) > max_e1:
        if force:
            # Widen the input tensor to fit the oversized entity
            XMB = torch.zeros(1, len(e1_tokens) + max_r).long().to(cfg.device)
        else:
            # Entity is too long and force=False: signal the caller to abort
            return {}, True
    else:
        XMB = torch.zeros(1, max_e1 + max_r).long().to(cfg.device)

    XMB[:, : len(e1_tokens)] = torch.LongTensor(e1_tokens)
    XMB[:, max_e1 : max_e1 + len(rel_tokens)] = torch.LongTensor(rel_tokens)

    batch = {}
    batch["sequences"] = XMB
    batch["attention_mask"] = data.conceptnet_data.make_attention_mask(XMB)

    return batch, abort

def print_conceptnet_sequence(sequence_object):
    e1 = sequence_object["e1"]
    relation = sequence_object["relation"]

    print("Input Entity: {}".format(e1))
    print("Target Relation: {}".format(relation))
    print("")
    print("Candidate Sequences:")
    for beam in sequence_object["beams"]:
        print(beam)
    print("")
    print("====================================================")
    print("")

def print_help(data):
    print("")
    if data == "atomic":
        print('Provide a seed event such as "PersonX goes to the mall"')
        print("Don't include names; replace them with PersonX, PersonY, etc.")
        print("The event should always include PersonX")
    if data == "conceptnet":
        print('Provide a seed entity such as "go to the mall"')
        print("Because the model was trained on lemmatized entities,")
        print("it works best if the input entities are also lemmatized")
    print("")

def print_relation_help(data):
    print_category_help(data)

def print_category_help(data):
    print("")
    if data == "atomic":
        print("Enter a possible effect type from the following effect types:")
        print(
            "all - compute the output for all effect types "
            "{oEffect, oReact, oWant, xAttr, xEffect, xIntent, xNeed, xReact, xWant}"
        )
        print(
            "oEffect - generate the effect of the event on participants other than PersonX"
        )
        print(
            "oReact - generate the reactions of participants other than PersonX to the event"
        )
        print(
            "oWant - generate what participants other than PersonX may want after the event"
        )
    elif data == "conceptnet":
        print("Enter a possible relation from the following list:")
        print("")
        print("AtLocation")
        print("CapableOf")
        print("Causes")
        print("CausesDesire")
        print("CreatedBy")
        print("DefinedAs")
        print("DesireOf")
        print("Desires")
        print("HasA")
        print("HasFirstSubevent")
        print("HasLastSubevent")
        print("HasPainCharacter")
        print("HasPainIntensity")
        print("HasPrerequisite")
        print("HasProperty")
        print("HasSubevent")
        print("InheritsFrom")
        print("InstanceOf")
        print("IsA")
        print("LocatedNear")
        print("LocationOfAction")
        print("MadeOf")
        print("MotivatedByGoal")
        print("NotCapableOf")
        print("NotDesires")
        print("NotHasA")
        print("NotHasProperty")
        print("NotIsA")
        print("NotMadeOf")
        print("PartOf")
        print("ReceivesAction")
        print("RelatedTo")
        print("SymbolOf")
        print("UsedFor")
        print("")
        print("NOTE: Capitalization is important")
    else:
        raise ValueError("unknown dataset: {}".format(data))
    print("")

def print_sampling_help():
    print("")
    print("Provide a sampling algorithm to produce the sequence, from the following:")
    print("")
    print("greedy")
    print("beam-# where # is the beam size")
    print("topk-# where # is k")
    print("")
