Spaces:
Build error
Build error
| import torch | |
| from comet.src.data.utils import TextEncoder | |
| import comet.src.data.config as cfg | |
| import comet.src.data.data as data | |
| import comet.src.models.models as models | |
| from comet.src.evaluate.sampler import BeamSampler, GreedySampler, TopKSampler | |
| import comet.utils.utils as utils | |
def load_model_file(model_file):
    """Load a saved checkpoint and split it into its options and weights.

    Args:
        model_file: Checkpoint location, handed to ``data.load_checkpoint``.

    Returns:
        A ``(opt, state_dict)`` tuple taken from the checkpoint's ``"opt"``
        and ``"state_dict"`` entries.
    """
    checkpoint = data.load_checkpoint(model_file)
    return checkpoint["opt"], checkpoint["state_dict"]
def load_data(dataset, opt,
              encoder_path="comet/model/encoder_bpe_40000.json",
              bpe_path="comet/model/vocab_40000.bpe"):
    """Build the data loader for *dataset* and a text encoder bound to its vocab.

    Args:
        dataset: Either ``"atomic"`` or ``"conceptnet"``.
        opt: Options object from the model checkpoint; passed through to the
            dataset-specific loader (which may fill in missing pad lengths).
        encoder_path: BPE encoder JSON used to construct the ``TextEncoder``.
        bpe_path: BPE merges file used to construct the ``TextEncoder``.

    Returns:
        ``(data_loader, text_encoder)``. The encoder's ``encoder``/``decoder``
        tables are replaced by the loader's ``vocab_encoder``/``vocab_decoder``.

    Raises:
        ValueError: if *dataset* is not one of the supported names.
    """
    if dataset == "atomic":
        data_loader = load_atomic_data(opt)
    elif dataset == "conceptnet":
        data_loader = load_conceptnet_data(opt)
    else:
        # Fail fast with a clear message; otherwise ``data_loader`` would be
        # unbound below and surface as an UnboundLocalError.
        raise ValueError(
            "Unknown dataset {!r}; expected 'atomic' or 'conceptnet'".format(dataset))

    text_encoder = TextEncoder(encoder_path, bpe_path)
    # Use the loader's vocabulary tables (presumably so token ids match the
    # ones the checkpoint was trained with).
    text_encoder.encoder = data_loader.vocab_encoder
    text_encoder.decoder = data_loader.vocab_decoder
    return data_loader, text_encoder
def load_atomic_data(opt, path=None):
    """Create the ATOMIC data loader and load its pre-processed pickle.

    Args:
        opt: Checkpoint options. When ``opt.data`` has no ``maxe1``, the pad
            lengths ``maxe1=17``, ``maxe2=35`` and ``maxr=1`` are written into
            it in place.
        path: Optional override for the processed-data pickle. Defaults to
            the bundled file for all nine ATOMIC categories with
            maxe1=17, maxe2=35, maxr=1.

    Returns:
        The loader from ``data.make_data_loader`` after ``load_data(path)``
        has been called on it.
    """
    # Hacky workaround, you may have to change this
    # if your models use different pad lengths for e1, e2, r
    if opt.data.get("maxe1", None) is None:
        opt.data.maxe1 = 17
        opt.data.maxe2 = 35
        opt.data.maxr = 1
    if path is None:
        # NOTE: hard-coded instead of derived from
        # ``utils.make_name_string(opt.data)``; pass ``path`` explicitly if
        # your checkpoint was trained with other categories or pad lengths.
        path = ("comet/data/atomic/processed/generation/"
                "categories_oEffect#oReact#oWant#xAttr#xEffect#xIntent#xNeed#xReact#xWant"
                "-maxe1_17-maxe2_35-maxr_1.pickle")
    data_loader = data.make_data_loader(opt, opt.data.categories)
    # Called for its effect on ``data_loader``; the return value is not used.
    data_loader.load_data(path)
    return data_loader
def load_conceptnet_data(opt, path=None):
    """Create the ConceptNet data loader and load its pre-processed pickle.

    Args:
        opt: Checkpoint options. When ``opt.data`` has no ``maxr``, it is set
            in place to 5 if ``opt.data.rel == "language"`` and to 1 otherwise.
        path: Optional override for the processed-data pickle. By default the
            file name is derived from ``opt.data`` via
            ``utils.make_name_string`` (after the ``maxr`` default is applied).

    Returns:
        The loader from ``data.make_data_loader`` after ``load_data(path)``
        has been called on it.
    """
    # Hacky workaround, you may have to change this
    # if your models use different pad lengths for r
    if opt.data.get("maxr", None) is None:
        # Presumably "language" relations are spelled out in several word
        # tokens, while tag relations occupy a single token.
        opt.data.maxr = 5 if opt.data.rel == "language" else 1
    if path is None:
        path = "comet/data/conceptnet/processed/generation/{}.pickle".format(
            utils.make_name_string(opt.data))
    data_loader = data.make_data_loader(opt)
    # Called for its effect on ``data_loader``; the return value is not used.
    data_loader.load_data(path)
    return data_loader
def make_model(opt, n_vocab, n_ctx, state_dict):
    """Instantiate the generation model and load trained weights into it.

    Args:
        opt: Checkpoint options forwarded to ``models.make_model``.
        n_vocab: Vocabulary size forwarded to ``models.make_model``.
        n_ctx: Context length forwarded to ``models.make_model``.
        state_dict: Weights applied with ``models.load_state_dict``.

    Returns:
        The model, switched to eval mode.
    """
    net = models.make_model(opt, n_vocab, n_ctx, None,
                            load=False, return_acts=True, return_probs=False)
    models.load_state_dict(net, state_dict)
    net.eval()
    return net
def _parse_sampler_size(sampling_algorithm):
    """Return the integer after the first '-' (e.g. 5 for 'beam-5'); raise ValueError otherwise."""
    parts = sampling_algorithm.split("-")
    if len(parts) < 2:
        raise ValueError(
            "Sampling algorithm {!r} must look like 'beam-<int>' or "
            "'topk-<int>'".format(sampling_algorithm))
    try:
        return int(parts[1])
    except ValueError as err:
        raise ValueError(
            "Sampling algorithm {!r} must look like 'beam-<int>' or "
            "'topk-<int>'".format(sampling_algorithm)) from err


def set_sampler(opt, sampling_algorithm, data_loader):
    """Create the sampler for the requested decoding strategy.

    Args:
        opt: Options object; ``opt.eval.bs`` (beam size) or ``opt.eval.k``
            (top-k) is set in place from *sampling_algorithm*.
        sampling_algorithm: ``"beam-<size>"``, ``"topk-<k>"`` or anything
            else (e.g. ``"greedy"``), which selects greedy decoding. The
            dataset check is a substring test for "beam" / "topk".
        data_loader: Passed through to the sampler constructor.

    Returns:
        A ``BeamSampler``, ``TopKSampler`` or ``GreedySampler``.

    Raises:
        ValueError: if a beam/top-k spec lacks an integer after the first '-'.
    """
    if "beam" in sampling_algorithm:
        opt.eval.bs = _parse_sampler_size(sampling_algorithm)
        return BeamSampler(opt, data_loader)
    if "topk" in sampling_algorithm:
        opt.eval.k = _parse_sampler_size(sampling_algorithm)
        return TopKSampler(opt, data_loader)
    return GreedySampler(opt, data_loader)
def get_atomic_sequence(input_event, model, sampler, data_loader, text_encoder, category):
    """Generate ATOMIC inferences for *input_event*.

    Args:
        input_event: Seed event text, e.g. "PersonX goes to the mall".
        model: Generation model passed to ``sampler.generate_sequence``.
        sampler: Sampler object exposing ``generate_sequence``.
        data_loader: ATOMIC loader; ``categories``, ``max_event`` and
            ``max_effect`` are read from it.
        text_encoder: Encoder passed to ``set_atomic_inputs``.
        category: One category name, a list of names, or ``"all"`` for every
            entry of ``data_loader.categories``.

    Returns:
        Dict mapping each category to ``{"event", "effect_type", "beams"}``.
    """
    # Fan-out case: recurse once per requested category and merge results
    # (later entries win on duplicate keys).
    if isinstance(category, list) or category == "all":
        wanted = category if isinstance(category, list) else data_loader.categories
        merged = {}
        for name in wanted:
            merged.update(get_atomic_sequence(
                input_event, model, sampler, data_loader, text_encoder, name))
        return merged

    with torch.no_grad():
        batch = set_atomic_inputs(input_event, category, data_loader, text_encoder)
        # The category delimiter count is added to the context length and
        # subtracted from the generation length.
        n_delim = data.atomic_data.num_delimiter_tokens["category"]
        sampling_result = sampler.generate_sequence(
            batch, model, data_loader,
            data_loader.max_event + n_delim,
            data_loader.max_effect - n_delim)
        result = {
            "event": input_event,
            "effect_type": category,
            "beams": sampling_result["beams"],
        }
    return {category: result}
def print_atomic_sequence(sequence_object):
    """Pretty-print one ATOMIC result.

    Args:
        sequence_object: Mapping with ``"event"``, ``"effect_type"`` and an
            iterable ``"beams"``; each beam is printed on its own line.
    """
    header = (
        "Input Event: {}".format(sequence_object["event"]),
        "Target Effect: {}".format(sequence_object["effect_type"]),
        "",
        "Candidate Sequences:",
    )
    for line in header:
        print(line)
    for beam in sequence_object["beams"]:
        print(beam)
    for line in ("", "====================================================", ""):
        print(line)
def set_atomic_inputs(input_event, category, data_loader, text_encoder):
    """Encode *input_event* plus its category token into a one-row batch.

    The row has ``data_loader.max_event + 1`` slots. The encoded event is
    written left-aligned (zero-padded, truncated to the row width), and the
    final slot is then set to the ``<category>`` token, overwriting any event
    token that reached it.

    Returns:
        Dict with ``"sequences"`` (a ``(1, max_event + 1)`` long tensor on
        ``cfg.device``) and ``"attention_mask"`` built from it.
    """
    width = data_loader.max_event + 1
    tokens = torch.zeros(1, width).long().to(cfg.device)
    prefix, _ = data.atomic_data.do_example(text_encoder, input_event, None, True, None)
    prefix = prefix[:width]
    tokens[:, :len(prefix)] = torch.LongTensor(prefix)
    category_id = text_encoder.encoder["<{}>".format(category)]
    tokens[:, -1] = torch.LongTensor([category_id])
    return {
        "sequences": tokens,
        "attention_mask": data.atomic_data.make_attention_mask(tokens),
    }
def get_conceptnet_sequence(e1, model, sampler, data_loader, text_encoder, relation, force=False):
    """Generate ConceptNet tail entities for head entity *e1* and print them.

    Args:
        e1: Head entity text.
        model: Generation model passed to ``sampler.generate_sequence``.
        sampler: Sampler object exposing ``generate_sequence``.
        data_loader: ConceptNet loader; ``max_e1``, ``max_r`` and ``max_e2``
            are read from it.
        text_encoder: Encoder passed to ``set_conceptnet_inputs``.
        relation: One relation name, a list of names, or ``"all"`` for every
            relation in ``data.conceptnet_data.conceptnet_relations``.
        force: If True, an entity longer than ``max_e1`` is still encoded
            instead of skipped. Applies to every relation in the list/"all"
            case as well.

    Returns:
        Dict mapping each relation to ``{"e1", "relation", "beams"}``. A
        skipped (too long, not forced) input yields an entry without
        ``"beams"``.
    """
    if isinstance(relation, list):
        relations = relation
    elif relation == "all":
        relations = data.conceptnet_data.conceptnet_relations
    else:
        relations = None

    if relations is not None:
        outputs = {}
        for rel in relations:
            # Forward ``force`` so list/"all" requests honour it too.
            outputs.update(get_conceptnet_sequence(
                e1, model, sampler, data_loader, text_encoder, rel, force))
        return outputs

    sequence_all = {"e1": e1, "relation": relation}
    with torch.no_grad():
        if data_loader.max_r != 1:
            # Multi-slot relations are spelled out as words.
            relation_sequence = data.conceptnet_data.split_into_words[relation]
        else:
            relation_sequence = "<{}>".format(relation)
        batch, abort = set_conceptnet_inputs(
            e1, relation_sequence, text_encoder,
            data_loader.max_e1, data_loader.max_r, force)
        if abort:
            return {relation: sequence_all}
        # NOTE(review): with ``force`` the encoded row may be longer than
        # max_e1 + max_r, but that value is still passed as the context
        # length -- confirm the sampler copes with this.
        sampling_result = sampler.generate_sequence(
            batch, model, data_loader,
            data_loader.max_e1 + data_loader.max_r,
            data_loader.max_e2)
        sequence_all["beams"] = sampling_result["beams"]
        print_conceptnet_sequence(sequence_all)
    return {relation: sequence_all}
def set_conceptnet_inputs(input_event, relation, text_encoder, max_e1, max_r, force):
    """Encode a head entity and relation into a one-row batch.

    Args:
        input_event: Head entity text.
        relation: Relation string handed to ``do_example``.
        text_encoder: Encoder passed to ``do_example``.
        max_e1: Slots reserved for the head entity.
        max_r: Slots reserved for the relation.
        force: If True, an entity longer than *max_e1* is kept in full and
            the row grows to fit; otherwise such input aborts.

    Returns:
        ``(batch, abort)``. *batch* holds ``"sequences"`` (a ``(1, L)`` long
        tensor on ``cfg.device``) and ``"attention_mask"``. *abort* is True,
        with an empty batch, when the entity is too long and *force* is False.
    """
    e1_tokens, rel_tokens, _ = data.conceptnet_data.do_example(
        text_encoder, input_event, relation, None)

    if len(e1_tokens) > max_e1 and not force:
        return {}, True

    # The relation starts after the entity region: normally at max_e1 (the
    # entity is zero-padded); for a forced, over-long entity it goes right
    # after the last entity token so it does not overwrite entity tokens.
    rel_start = max(max_e1, len(e1_tokens))
    XMB = torch.zeros(1, rel_start + max_r).long().to(cfg.device)
    XMB[:, :len(e1_tokens)] = torch.LongTensor(e1_tokens)
    XMB[:, rel_start:rel_start + len(rel_tokens)] = torch.LongTensor(rel_tokens)

    batch = {
        "sequences": XMB,
        "attention_mask": data.conceptnet_data.make_attention_mask(XMB),
    }
    return batch, False
def print_conceptnet_sequence(sequence_object):
    """Pretty-print one ConceptNet result.

    Args:
        sequence_object: Mapping with ``"e1"``, ``"relation"`` and an
            iterable ``"beams"``; each beam is printed on its own line.
    """
    head = sequence_object["e1"]
    rel = sequence_object["relation"]
    print("Input Entity: {}".format(head),
          "Target Relation: {}".format(rel),
          "",
          "Candidate Sequences:",
          sep="\n")
    for candidate in sequence_object["beams"]:
        print(candidate)
    print("", "====================================================", "", sep="\n")
def print_help(data):
    """Print guidance on writing a seed input for dataset *data*.

    Args:
        data: ``"atomic"`` or ``"conceptnet"``; any other value prints only
            the surrounding blank lines.
    """
    guidance = (
        ("atomic", (
            'Provide a seed event such as "PersonX goes to the mall"',
            "Don't include names, instead replacing them with PersonX, PersonY, etc.",
            "The event should always have PersonX included",
        )),
        ("conceptnet", (
            'Provide a seed entity such as "go to the mall"',
            "Because the model was trained on lemmatized entities,",
            "it works best if the input entities are also lemmatized",
        )),
    )
    print("")
    for name, lines in guidance:
        if data == name:
            for line in lines:
                print(line)
    print("")
def print_relation_help(data):
    """Print the valid relations/effect types for dataset *data*.

    Alias kept for callers that ask for "relation" help; it simply delegates
    to ``print_category_help``.
    """
    print_category_help(data=data)
def print_category_help(data):
    """Print the effect types (ATOMIC) or relations (ConceptNet) a user may enter.

    Args:
        data: ``"atomic"`` or ``"conceptnet"``.

    Raises:
        ValueError: for any other dataset name (after the leading blank line
            has been printed).
    """
    print("")
    if data == "atomic":
        print("Enter a possible effect type from the following effect types:")
        # Plain string (not .format), so single braces print literally.
        print("all - compute the output for all effect types {oEffect, oReact, oWant, xAttr, xEffect, xIntent, xNeed, xReact, xWant}")
        for line in (
                "oEffect - generate the effect of the event on participants other than PersonX",
                "oReact - generate the reactions of participants other than PersonX to the event",
                "oWant - generate what participants other than PersonX may want after the event",
                "xAttr - generate how PersonX might be described given their part in the event",
                "xEffect - generate the effect of the event on PersonX",
                "xIntent - generate the likely intent of PersonX in the event",
                "xNeed - generate what PersonX may need to do before the event",
                "xReact - generate the reactions of PersonX to the event",
                "xWant - generate what PersonX may want after the event",
        ):
            print(line)
    elif data == "conceptnet":
        print("Enter a possible relation from the following list:")
        print("")
        for relation in (
                'AtLocation', 'CapableOf', 'Causes', 'CausesDesire', 'CreatedBy',
                'DefinedAs', 'DesireOf', 'Desires', 'HasA', 'HasFirstSubevent',
                'HasLastSubevent', 'HasPainCharacter', 'HasPainIntensity',
                'HasPrerequisite', 'HasProperty', 'HasSubevent', 'InheritsFrom',
                'InstanceOf', 'IsA', 'LocatedNear', 'LocationOfAction', 'MadeOf',
                'MotivatedByGoal', 'NotCapableOf', 'NotDesires', 'NotHasA',
                'NotHasProperty', 'NotIsA', 'NotMadeOf', 'PartOf', 'ReceivesAction',
                'RelatedTo', 'SymbolOf', 'UsedFor',
        ):
            print(relation)
        print("")
        print("NOTE: Capitalization is important")
    else:
        # A bare ``raise`` here has no active exception to re-raise and only
        # produces an unhelpful RuntimeError.
        raise ValueError(
            "Unknown dataset {!r}; expected 'atomic' or 'conceptnet'".format(data))
    print("")
def print_sampling_help():
    """Print the accepted formats for the sampling-algorithm prompt."""
    print(
        "",
        "Provide a sampling algorithm to produce the sequence with from the following:",
        "",
        "greedy",
        "beam-# where # is the beam size",
        "topk-# where # is k",
        "",
        sep="\n",
    )