"""Interactive helpers for generating ATOMIC / ConceptNet inferences
with a trained COMET model: checkpoint loading, input construction,
sequence sampling, and console help text."""

import torch

from comet.src.data.utils import TextEncoder
import comet.src.data.config as cfg
import comet.src.data.data as data
import comet.src.models.models as models
from comet.src.evaluate.sampler import BeamSampler, GreedySampler, TopKSampler
import comet.utils.utils as utils


def load_model_file(model_file):
    """Load a COMET checkpoint file.

    Args:
        model_file: path to a checkpoint readable by data.load_checkpoint.

    Returns:
        (opt, state_dict) tuple: experiment options and model weights.
    """
    model_stuff = data.load_checkpoint(model_file)
    return model_stuff["opt"], model_stuff["state_dict"]


def load_data(dataset, opt):
    """Build the data loader and BPE text encoder for a dataset.

    Args:
        dataset: "atomic" or "conceptnet".
        opt: experiment options recovered from the checkpoint.

    Returns:
        (data_loader, text_encoder) tuple.

    Raises:
        ValueError: if dataset is not a recognized dataset name.
    """
    if dataset == "atomic":
        data_loader = load_atomic_data(opt)
    elif dataset == "conceptnet":
        data_loader = load_conceptnet_data(opt)
    else:
        # Previously an unknown name fell through and crashed later with an
        # unrelated NameError on data_loader; fail fast and clearly instead.
        raise ValueError("Unknown dataset: {}".format(dataset))

    # Initialize TextEncoder, then swap in the vocabulary recorded by the
    # data loader so token ids match the trained model.
    encoder_path = "comet/model/encoder_bpe_40000.json"
    bpe_path = "comet/model/vocab_40000.bpe"
    text_encoder = TextEncoder(encoder_path, bpe_path)
    text_encoder.encoder = data_loader.vocab_encoder
    text_encoder.decoder = data_loader.vocab_decoder

    return data_loader, text_encoder


def load_atomic_data(opt):
    """Create and populate the ATOMIC generation data loader."""
    # Hacky workaround, you may have to change this
    # if your models use different pad lengths for e1, e2, r
    if opt.data.get("maxe1", None) is None:
        opt.data.maxe1 = 17
        opt.data.maxe2 = 35
        opt.data.maxr = 1
    # path = "data/atomic/processed/generation/{}.pickle".format(
    #     utils.make_name_string(opt.data))
    path = "comet/data/atomic/processed/generation/categories_oEffect#oReact#oWant#xAttr#xEffect#xIntent#xNeed#xReact#xWant-maxe1_17-maxe2_35-maxr_1.pickle"
    data_loader = data.make_data_loader(opt, opt.data.categories)
    data_loader.load_data(path)
    return data_loader


def load_conceptnet_data(opt):
    """Create and populate the ConceptNet generation data loader."""
    # Hacky workaround, you may have to change this
    # if your models use different pad lengths for r
    if opt.data.get("maxr", None) is None:
        # Language-style relations are multi-token; symbol relations use 1.
        opt.data.maxr = 5 if opt.data.rel == "language" else 1
    path = "comet/data/conceptnet/processed/generation/{}.pickle".format(
        utils.make_name_string(opt.data))
    data_loader = data.make_data_loader(opt)
    data_loader.load_data(path)
    return data_loader


def make_model(opt, n_vocab, n_ctx, state_dict):
    """Build the transformer, load trained weights, and set eval mode.

    Args:
        opt: experiment options from the checkpoint.
        n_vocab: vocabulary size (including special tokens).
        n_ctx: maximum context length.
        state_dict: trained parameters from load_model_file.

    Returns:
        The model in eval() mode, ready for inference.
    """
    model = models.make_model(
        opt, n_vocab, n_ctx, None, load=False,
        return_acts=True, return_probs=False)
    models.load_state_dict(model, state_dict)
    model.eval()
    return model


def set_sampler(opt, sampling_algorithm, data_loader):
    """Construct a decoding sampler from a spec string.

    Accepted specs: "beam-N" (beam search, beam size N), "topk-N"
    (top-k sampling with k = N). Anything else falls back to greedy.
    """
    if "beam" in sampling_algorithm:
        opt.eval.bs = int(sampling_algorithm.split("-")[1])
        sampler = BeamSampler(opt, data_loader)
    elif "topk" in sampling_algorithm:
        opt.eval.k = int(sampling_algorithm.split("-")[1])
        sampler = TopKSampler(opt, data_loader)
    else:
        sampler = GreedySampler(opt, data_loader)
    return sampler


def get_atomic_sequence(input_event, model, sampler, data_loader,
                        text_encoder, category):
    """Generate ATOMIC inferences for an input event.

    Args:
        input_event: seed event string (e.g. "PersonX goes to the mall").
        model, sampler, data_loader, text_encoder: inference components.
        category: one category name, a list of names, or "all".

    Returns:
        Dict mapping each category to {"event", "effect_type", "beams"}.
    """
    if isinstance(category, list):
        # Recurse per category and merge the per-category results.
        outputs = {}
        for cat in category:
            new_outputs = get_atomic_sequence(
                input_event, model, sampler, data_loader, text_encoder, cat)
            outputs.update(new_outputs)
        return outputs
    elif category == "all":
        outputs = {}
        for category in data_loader.categories:
            new_outputs = get_atomic_sequence(
                input_event, model, sampler, data_loader, text_encoder,
                category)
            outputs.update(new_outputs)
        return outputs
    else:
        sequence_all = {}
        sequence_all["event"] = input_event
        sequence_all["effect_type"] = category

        with torch.no_grad():
            batch = set_atomic_inputs(
                input_event, category, data_loader, text_encoder)
            # Context length includes the category delimiter token(s);
            # the effect budget excludes them.
            sampling_result = sampler.generate_sequence(
                batch, model, data_loader,
                data_loader.max_event +
                data.atomic_data.num_delimiter_tokens["category"],
                data_loader.max_effect -
                data.atomic_data.num_delimiter_tokens["category"])

        sequence_all['beams'] = sampling_result["beams"]
        # print_atomic_sequence(sequence_all)
        return {category: sequence_all}


def print_atomic_sequence(sequence_object):
    """Pretty-print one ATOMIC generation result to stdout."""
    input_event = sequence_object["event"]
    category = sequence_object["effect_type"]

    print("Input Event: {}".format(input_event))
    print("Target Effect: {}".format(category))
    print("")
    print("Candidate Sequences:")
    for beam in sequence_object["beams"]:
        print(beam)
    print("")
    print("====================================================")
    print("")


def set_atomic_inputs(input_event, category, data_loader, text_encoder):
    """Encode an event + category into a model input batch.

    Returns:
        batch dict with "sequences" (1 x max_event+1 LongTensor) and
        "attention_mask".
    """
    XMB = torch.zeros(1, data_loader.max_event + 1).long().to(cfg.device)
    prefix, suffix = data.atomic_data.do_example(
        text_encoder, input_event, None, True, None)

    # Truncate overly long events to the buffer width.
    if len(prefix) > data_loader.max_event + 1:
        prefix = prefix[:data_loader.max_event + 1]

    XMB[:, :len(prefix)] = torch.LongTensor(prefix)
    # NOTE(review): the category token always occupies the last slot; if the
    # event fills the buffer, its final token is overwritten — presumably
    # intentional so the category delimiter is always present.
    XMB[:, -1] = torch.LongTensor(
        [text_encoder.encoder["<{}>".format(category)]])

    batch = {}
    batch["sequences"] = XMB
    batch["attention_mask"] = data.atomic_data.make_attention_mask(XMB)
    return batch


def get_conceptnet_sequence(e1, model, sampler, data_loader, text_encoder,
                            relation, force=False):
    """Generate ConceptNet inferences for an input entity.

    Args:
        e1: seed entity string (e.g. "go to the mall").
        model, sampler, data_loader, text_encoder: inference components.
        relation: one relation name, a list of names, or "all".
        force: if True, allow entities longer than the trained pad length.

    Returns:
        Dict mapping each relation to {"e1", "relation", "beams"} ("beams"
        is absent if input construction aborted).
    """
    if isinstance(relation, list):
        outputs = {}
        for rel in relation:
            new_outputs = get_conceptnet_sequence(
                e1, model, sampler, data_loader, text_encoder, rel)
            outputs.update(new_outputs)
        return outputs
    elif relation == "all":
        outputs = {}
        for relation in data.conceptnet_data.conceptnet_relations:
            new_outputs = get_conceptnet_sequence(
                e1, model, sampler, data_loader, text_encoder, relation)
            outputs.update(new_outputs)
        return outputs
    else:
        sequence_all = {}
        sequence_all["e1"] = e1
        sequence_all["relation"] = relation

        with torch.no_grad():
            # Multi-token ("language") relations are spelled out as words;
            # single-token relations use their special "<Relation>" token.
            if data_loader.max_r != 1:
                relation_sequence = \
                    data.conceptnet_data.split_into_words[relation]
            else:
                relation_sequence = "<{}>".format(relation)

            batch, abort = set_conceptnet_inputs(
                e1, relation_sequence, text_encoder,
                data_loader.max_e1, data_loader.max_r, force)

            # Entity was too long and force=False: return without beams.
            if abort:
                return {relation: sequence_all}

            sampling_result = sampler.generate_sequence(
                batch, model, data_loader,
                data_loader.max_e1 + data_loader.max_r,
                data_loader.max_e2)

        sequence_all['beams'] = sampling_result["beams"]
        print_conceptnet_sequence(sequence_all)
        return {relation: sequence_all}


def set_conceptnet_inputs(input_event, relation, text_encoder,
                          max_e1, max_r, force):
    """Encode an entity + relation into a model input batch.

    Returns:
        (batch, abort): abort is True (and batch empty) when the entity
        exceeds max_e1 and force is False.
    """
    e1_tokens, rel_tokens, _ = data.conceptnet_data.do_example(
        text_encoder, input_event, relation, None)

    if len(e1_tokens) > max_e1:
        if force:
            # Oversized buffer to fit the whole entity.
            # NOTE(review): the relation is still written at offset max_e1,
            # which overlaps the long entity — confirm this is intended.
            XMB = torch.zeros(1, len(e1_tokens) + max_r).long().to(cfg.device)
        else:
            return {}, True
    else:
        XMB = torch.zeros(1, max_e1 + max_r).long().to(cfg.device)

    XMB[:, :len(e1_tokens)] = torch.LongTensor(e1_tokens)
    XMB[:, max_e1:max_e1 + len(rel_tokens)] = torch.LongTensor(rel_tokens)

    batch = {}
    batch["sequences"] = XMB
    batch["attention_mask"] = data.conceptnet_data.make_attention_mask(XMB)
    return batch, False


def print_conceptnet_sequence(sequence_object):
    """Pretty-print one ConceptNet generation result to stdout."""
    e1 = sequence_object["e1"]
    relation = sequence_object["relation"]

    print("Input Entity: {}".format(e1))
    print("Target Relation: {}".format(relation))
    print("")
    print("Candidate Sequences:")
    for beam in sequence_object["beams"]:
        print(beam)
    print("")
    print("====================================================")
    print("")


def print_help(data):
    """Print usage guidance for seeding the given dataset's inputs."""
    # NOTE: the parameter intentionally shadows the imported `data` module
    # inside this function; kept for interface compatibility.
    print("")
    if data == "atomic":
        print("Provide a seed event such as \"PersonX goes to the mall\"")
        print("Don't include names, instead replacing them with PersonX, PersonY, etc.")
        print("The event should always have PersonX included")
    if data == "conceptnet":
        print("Provide a seed entity such as \"go to the mall\"")
        print("Because the model was trained on lemmatized entities,")
        print("it works best if the input entities are also lemmatized")
    print("")


def print_relation_help(data):
    """Alias for print_category_help (relations == categories here)."""
    print_category_help(data)


def print_category_help(data):
    """Print the valid effect types / relations for the given dataset.

    Raises:
        ValueError: if data is neither "atomic" nor "conceptnet".
    """
    print("")
    if data == "atomic":
        print("Enter a possible effect type from the following effect types:")
        # Fixed: double braces printed literally in the original plain string.
        print("all - compute the output for all effect types {oEffect, oReact, oWant, xAttr, xEffect, xIntent, xNeed, xReact, xWant}")
        print("oEffect - generate the effect of the event on participants other than PersonX")
        print("oReact - generate the reactions of participants other than PersonX to the event")
        # Fixed: this line was mislabeled "oEffect" in the original, but the
        # description ("may want after the event") is the oWant category.
        print("oWant - generate what participants other than PersonX may want after the event")
    elif data == "conceptnet":
        print("Enter a possible relation from the following list:")
        print("")
        relations = (
            'AtLocation', 'CapableOf', 'Causes', 'CausesDesire',
            'CreatedBy', 'DefinedAs', 'DesireOf', 'Desires', 'HasA',
            'HasFirstSubevent', 'HasLastSubevent', 'HasPainCharacter',
            'HasPainIntensity', 'HasPrerequisite', 'HasProperty',
            'HasSubevent', 'InheritsFrom', 'InstanceOf', 'IsA',
            'LocatedNear', 'LocationOfAction', 'MadeOf',
            'MotivatedByGoal', 'NotCapableOf', 'NotDesires', 'NotHasA',
            'NotHasProperty', 'NotIsA', 'NotMadeOf', 'PartOf',
            'ReceivesAction', 'RelatedTo', 'SymbolOf', 'UsedFor',
        )
        for rel in relations:
            print(rel)
        print("")
        print("NOTE: Capitalization is important")
    else:
        # Fixed: was a bare `raise` with no active exception, which produced
        # an uninformative RuntimeError.
        raise ValueError("Unknown dataset: {}".format(data))
    print("")


def print_sampling_help():
    """Print the accepted sampling-algorithm spec strings."""
    print("")
    print("Provide a sampling algorithm to produce the sequence with from the following:")
    print("")
    print("greedy")
    print("beam-# where # is the beam size")
    print("topk-# where # is k")
    print("")