"""Interactive helpers for generating ATOMIC / ConceptNet tuples with COMET.

Utilities to load a trained checkpoint, rebuild the matching data loader and
BPE text encoder, and sample effect/relation sequences for user-supplied
seed events or entities.
"""

import os

import torch

import comet.src.data.config as cfg
import comet.src.data.data as data
import comet.src.models.models as models
import comet.utils.utils as utils
from comet.src.data.utils import TextEncoder
from comet.src.evaluate.sampler import BeamSampler, GreedySampler, TopKSampler


def load_model_file(model_file):
    """Load a trained COMET checkpoint.

    Args:
        model_file: path to the serialized checkpoint.

    Returns:
        (opt, state_dict): training options and the model weights.
    """
    model_stuff = data.load_checkpoint(model_file)
    return model_stuff["opt"], model_stuff["state_dict"]


def load_data(dataset, opt, dir="."):
    """Build the data loader and BPE text encoder for ``dataset``.

    Args:
        dataset: "atomic" or "conceptnet".
        opt: checkpoint options (dot-accessible config from the checkpoint).
        dir: root directory containing the ``comet/`` model and data folders.

    Returns:
        (data_loader, text_encoder) tuple.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
            (Previously an unknown name fell through and crashed later
            with a NameError on ``data_loader``.)
    """
    if dataset == "atomic":
        data_loader = load_atomic_data(opt, dir)
    elif dataset == "conceptnet":
        data_loader = load_conceptnet_data(opt, dir)
    else:
        raise ValueError("Unknown dataset: {}".format(dataset))

    # Initialize TextEncoder from the pretrained BPE files, then swap in the
    # dataset-specific vocabulary built by the data loader.
    encoder_path = os.path.join(dir, "comet/model/encoder_bpe_40000.json")
    bpe_path = os.path.join(dir, "comet/model/vocab_40000.bpe")
    text_encoder = TextEncoder(encoder_path, bpe_path)
    text_encoder.encoder = data_loader.vocab_encoder
    text_encoder.decoder = data_loader.vocab_decoder

    return data_loader, text_encoder


def load_atomic_data(opt, dir="."):
    """Create and populate the ATOMIC data loader.

    Temporarily changes the working directory to ``dir`` so the relative
    pickle path resolves; the original directory is always restored, even
    if loading fails.
    """
    # Hacky workaround, you may have to change this
    # if your models use different pad lengths for e1, e2, r
    if opt.data.get("maxe1", None) is None:
        opt.data.maxe1 = 17
        opt.data.maxe2 = 35
        opt.data.maxr = 1

    # Temporarily change to the target directory.
    current_dir = os.getcwd()
    os.chdir(dir)
    try:
        path = (
            "comet/data/atomic/processed/generation/"
            "categories_oEffect#oReact#oWant#xAttr#xEffect#xIntent"
            "#xNeed#xReact#xWant-maxe1_17-maxe2_35-maxr_1.pickle"
        )
        data_loader = data.make_data_loader(opt, opt.data.categories)
        data_loader.load_data(path)
    finally:
        # Go back to the original working directory even on failure, so a
        # bad load does not leave the whole process in a different CWD.
        os.chdir(current_dir)

    return data_loader


def load_conceptnet_data(opt, dir="."):
    """Create and populate the ConceptNet data loader.

    Temporarily changes the working directory to ``dir`` so the relative
    pickle path resolves; the original directory is always restored, even
    if loading fails.
    """
    # Hacky workaround, you may have to change this
    # if your models use different pad lengths for r
    if opt.data.get("maxr", None) is None:
        if opt.data.rel == "language":
            opt.data.maxr = 5
        else:
            opt.data.maxr = 1

    # Temporarily change to the target directory.
    current_dir = os.getcwd()
    os.chdir(dir)
    try:
        path = "comet/data/conceptnet/processed/generation/{}.pickle".format(
            utils.make_name_string(opt.data)
        )
        data_loader = data.make_data_loader(opt)
        data_loader.load_data(path)
    finally:
        # Go back to the original working directory even on failure.
        os.chdir(current_dir)

    return data_loader


def make_model(opt, n_vocab, n_ctx, state_dict):
    """Instantiate the model, load ``state_dict`` into it, and set eval mode."""
    model = models.make_model(
        opt, n_vocab, n_ctx, None, load=False, return_acts=True, return_probs=False
    )
    models.load_state_dict(model, state_dict)
    model.eval()
    return model


def set_sampler(opt, sampling_algorithm, data_loader):
    """Build a decoding sampler from a spec string.

    Args:
        sampling_algorithm: "greedy", "beam-<size>", or "topk-<k>".

    Returns:
        A BeamSampler, TopKSampler, or GreedySampler instance.
    """
    if "beam" in sampling_algorithm:
        opt.eval.bs = int(sampling_algorithm.split("-")[1])
        sampler = BeamSampler(opt, data_loader)
    elif "topk" in sampling_algorithm:
        opt.eval.k = int(sampling_algorithm.split("-")[1])
        sampler = TopKSampler(opt, data_loader)
    else:
        # Anything unrecognized falls back to greedy decoding.
        sampler = GreedySampler(opt, data_loader)
    return sampler


def get_atomic_sequence(input_event, model, sampler, data_loader, text_encoder, category):
    """Generate ATOMIC effect sequences for ``input_event``.

    Args:
        category: a single effect type (e.g. "xWant"), a list of effect
            types, or "all" for every category known to the data loader.

    Returns:
        dict mapping each requested category to its result object
        ({"event", "effect_type", "beams"}).
    """
    if isinstance(category, list):
        outputs = {}
        for cat in category:
            outputs.update(get_atomic_sequence(
                input_event, model, sampler, data_loader, text_encoder, cat))
        return outputs

    if category == "all":
        outputs = {}
        for cat in data_loader.categories:
            outputs.update(get_atomic_sequence(
                input_event, model, sampler, data_loader, text_encoder, cat))
        return outputs

    sequence_all = {"event": input_event, "effect_type": category}

    with torch.no_grad():
        batch = set_atomic_inputs(input_event, category, data_loader, text_encoder)
        sampling_result = sampler.generate_sequence(
            batch,
            model,
            data_loader,
            data_loader.max_event
            + data.atomic_data.num_delimiter_tokens["category"],
            data_loader.max_effect
            - data.atomic_data.num_delimiter_tokens["category"],
        )

    sequence_all["beams"] = sampling_result["beams"]
    return {category: sequence_all}


def print_atomic_sequence(sequence_object):
    """Pretty-print one ATOMIC generation result to stdout."""
    print("Input Event: {}".format(sequence_object["event"]))
    print("Target Effect: {}".format(sequence_object["effect_type"]))
    print("")
    print("Candidate Sequences:")
    for beam in sequence_object["beams"]:
        print(beam)
    print("")
    print("====================================================")
    print("")


def set_atomic_inputs(input_event, category, data_loader, text_encoder):
    """Encode ``input_event`` plus the category token into a model batch.

    Returns:
        dict with "sequences" (LongTensor [1, max_event + 1]; the final
        position holds the effect-category control token) and
        "attention_mask".
    """
    XMB = torch.zeros(1, data_loader.max_event + 1).long().to(cfg.device)
    prefix, _ = data.atomic_data.do_example(
        text_encoder, input_event, None, True, None
    )

    # Truncate events longer than the model's padded event length.
    if len(prefix) > data_loader.max_event + 1:
        prefix = prefix[: data_loader.max_event + 1]

    XMB[:, : len(prefix)] = torch.LongTensor(prefix)
    # The last slot always carries the category token (overwriting any
    # truncated event token that landed there).
    XMB[:, -1] = torch.LongTensor([text_encoder.encoder["<{}>".format(category)]])

    batch = {}
    batch["sequences"] = XMB
    batch["attention_mask"] = data.atomic_data.make_attention_mask(XMB)
    return batch


def get_conceptnet_sequence(e1, model, sampler, data_loader, text_encoder, relation, force=False):
    """Generate ConceptNet tail entities for entity ``e1``.

    Args:
        relation: a single relation name, a list of relations, or "all".
        force: if True, allow entities longer than the model's e1 pad
            length instead of aborting.

    Returns:
        dict mapping each requested relation to its result object
        ({"e1", "relation"} plus "beams" unless generation was aborted).
    """
    if isinstance(relation, list):
        outputs = {}
        for rel in relation:
            # Propagate ``force`` (it was silently dropped when recursing
            # in the original implementation).
            outputs.update(get_conceptnet_sequence(
                e1, model, sampler, data_loader, text_encoder, rel, force))
        return outputs

    if relation == "all":
        outputs = {}
        for rel in data.conceptnet_data.conceptnet_relations:
            outputs.update(get_conceptnet_sequence(
                e1, model, sampler, data_loader, text_encoder, rel, force))
        return outputs

    sequence_all = {"e1": e1, "relation": relation}

    with torch.no_grad():
        # max_r != 1 means relations were trained as natural-language word
        # sequences; otherwise each relation is a single special token.
        if data_loader.max_r != 1:
            relation_sequence = data.conceptnet_data.split_into_words[relation]
        else:
            relation_sequence = "<{}>".format(relation)

        batch, abort = set_conceptnet_inputs(
            e1,
            relation_sequence,
            text_encoder,
            data_loader.max_e1,
            data_loader.max_r,
            force,
        )

        if abort:
            # Entity too long and force=False: return without "beams".
            return {relation: sequence_all}

        sampling_result = sampler.generate_sequence(
            batch,
            model,
            data_loader,
            data_loader.max_e1 + data_loader.max_r,
            data_loader.max_e2,
        )

    sequence_all["beams"] = sampling_result["beams"]
    print_conceptnet_sequence(sequence_all)
    return {relation: sequence_all}


def set_conceptnet_inputs(input_event, relation, text_encoder, max_e1, max_r, force):
    """Encode an entity/relation pair into a model batch.

    Returns:
        (batch, abort): ``batch`` has "sequences" and "attention_mask";
        ``abort`` is True (with an empty batch) when the entity exceeds
        ``max_e1`` and ``force`` is False.
    """
    abort = False

    e1_tokens, rel_tokens, _ = data.conceptnet_data.do_example(
        text_encoder, input_event, relation, None
    )

    if len(e1_tokens) > max_e1:
        if force:
            # Oversized entity: widen the buffer instead of truncating.
            # NOTE(review): rel_tokens below are still written starting at
            # max_e1, overlapping the entity tokens — confirm intended.
            XMB = torch.zeros(1, len(e1_tokens) + max_r).long().to(cfg.device)
        else:
            return {}, True
    else:
        XMB = torch.zeros(1, max_e1 + max_r).long().to(cfg.device)

    XMB[:, : len(e1_tokens)] = torch.LongTensor(e1_tokens)
    XMB[:, max_e1 : max_e1 + len(rel_tokens)] = torch.LongTensor(rel_tokens)

    batch = {}
    batch["sequences"] = XMB
    batch["attention_mask"] = data.conceptnet_data.make_attention_mask(XMB)
    return batch, abort


def print_conceptnet_sequence(sequence_object):
    """Pretty-print one ConceptNet generation result to stdout."""
    print("Input Entity: {}".format(sequence_object["e1"]))
    print("Target Relation: {}".format(sequence_object["relation"]))
    print("")
    print("Candidate Sequences:")
    for beam in sequence_object["beams"]:
        print(beam)
    print("")
    print("====================================================")
    print("")


def print_help(data):
    """Print seed-input guidance for the chosen dataset ("atomic"/"conceptnet")."""
    print("")
    if data == "atomic":
        print('Provide a seed event such as "PersonX goes to the mall"')
        print("Don't include names, instead replacing them with PersonX, PersonY, etc.")
        print("The event should always have PersonX included")
    if data == "conceptnet":
        print('Provide a seed entity such as "go to the mall"')
        print("Because the model was trained on lemmatized entities,")
        print("it works best if the input entities are also lemmatized")
    print("")


def print_relation_help(data):
    """Alias for :func:`print_category_help` (relations == categories here)."""
    print_category_help(data)


# ConceptNet relations the pretrained model understands; printed one per line.
CONCEPTNET_RELATION_NAMES = (
    "AtLocation", "CapableOf", "Causes", "CausesDesire", "CreatedBy",
    "DefinedAs", "DesireOf", "Desires", "HasA", "HasFirstSubevent",
    "HasLastSubevent", "HasPainCharacter", "HasPainIntensity",
    "HasPrerequisite", "HasProperty", "HasSubevent", "InheritsFrom",
    "InstanceOf", "IsA", "LocatedNear", "LocationOfAction", "MadeOf",
    "MotivatedByGoal", "NotCapableOf", "NotDesires", "NotHasA",
    "NotHasProperty", "NotIsA", "NotMadeOf", "PartOf", "ReceivesAction",
    "RelatedTo", "SymbolOf", "UsedFor",
)


def print_category_help(data):
    """Print the valid category/relation names for the chosen dataset.

    Raises:
        RuntimeError: if ``data`` is neither "atomic" nor "conceptnet".
    """
    print("")
    if data == "atomic":
        print("Enter a possible effect type from the following effect types:")
        # Single braces: this string is printed directly, never .format()-ed,
        # so the original "{{...}}" escaping printed literal double braces.
        print(
            "all - compute the output for all effect types {oEffect, oReact, "
            "oWant, xAttr, xEffect, xIntent, xNeed, xReact, xWant}"
        )
        print(
            "oEffect - generate the effect of the event on participants other than PersonX"
        )
        print(
            "oReact - generate the reactions of participants other than PersonX to the event"
        )
        # Fixed label: this description is for oWant (it was mislabeled
        # "oEffect", duplicating the first entry's label).
        print(
            "oWant - generate what participants other than PersonX may want after the event"
        )
    elif data == "conceptnet":
        print("Enter a possible relation from the following list:")
        print("")
        for relation_name in CONCEPTNET_RELATION_NAMES:
            print(relation_name)
        print("")
        print("NOTE: Capitalization is important")
    else:
        # A bare ``raise`` here (as in the original) fails with
        # "No active exception to re-raise" (a RuntimeError); raise the
        # same type with an informative message instead.
        raise RuntimeError("Unknown dataset: {}".format(data))
    print("")


def print_sampling_help():
    """Print the accepted sampling-algorithm spec strings."""
    print("")
    print(
        "Provide a sampling algorithm to produce the sequence with from the following:"
    )
    print("")
    print("greedy")
    print("beam-# where # is the beam size")
    print("topk-# where # is k")
    print("")