import os

import pickle
import numpy as np
import torch
from nltk import tokenize
from tqdm import tqdm

import comet.src.data.data as data
import comet.src.data.config as cfg
import comet.src.models.utils as model_utils
import comet.src.interactive.functions as interactive


class CSKFeatureExtractor:
    """Extracts commonsense-knowledge features for utterances using a
    pretrained COMET model trained on the ATOMIC knowledge graph."""

    def __init__(self, dir=".", device=0):
        super(CSKFeatureExtractor, self).__init__()

        model_file = os.path.join(
            dir, "comet/pretrained_models/atomic_pretrained_model.pickle"
        )
        # Decoding settings carried over from the interactive-generation setup;
        # they are not used for feature extraction.
        sampling_algorithm = "beam-5"
        category = "all"

        opt, state_dict = interactive.load_model_file(model_file)
        data_loader, text_encoder = interactive.load_data("atomic", opt, dir)

        self.opt = opt
        self.data_loader = data_loader
        self.text_encoder = text_encoder

        # Context length = event tokens + effect tokens; the vocabulary is
        # extended with one position token per context slot.
        n_ctx = data_loader.max_event + data_loader.max_effect
        n_vocab = len(text_encoder.encoder) + n_ctx
        self.model = interactive.make_model(opt, n_vocab, n_ctx, state_dict)
        self.model.eval()

        if device != "cpu":
            cfg.device = int(device)
            cfg.do_gpu = True
            torch.cuda.set_device(cfg.device)
            self.model.cuda(cfg.device)
        else:
            cfg.device = "cpu"

    def set_atomic_inputs(self, input_event, category, data_loader, text_encoder):
        # Build a single-example batch: the encoded event followed by the
        # ATOMIC relation token (e.g. <xIntent>) in the last position.
        XMB = torch.zeros(1, data_loader.max_event + 1).long().to(cfg.device)
        prefix, suffix = data.atomic_data.do_example(
            text_encoder, input_event, None, True, None
        )

        # Truncate events that exceed the model's maximum event length.
        if len(prefix) > data_loader.max_event + 1:
            prefix = prefix[: data_loader.max_event + 1]

        XMB[:, : len(prefix)] = torch.LongTensor(prefix)
        XMB[:, -1] = torch.LongTensor([text_encoder.encoder["<{}>".format(category)]])

        batch = {}
        batch["sequences"] = XMB
        batch["attention_mask"] = data.atomic_data.make_attention_mask(XMB)
        return batch

    def extract(self, sentence):
        """Compute COMET hidden-state features for every utterance.

        `sentence` maps an id (e.g. a dialogue id) to a list of utterance
        strings. Returns a list of nine dicts, one per ATOMIC relation, each
        mapping the same ids to one feature vector per utterance.
        """
        atomic_keys = [
            "xIntent", "xAttr", "xNeed", "xWant", "xEffect", "xReact",
            "oWant", "oEffect", "oReact",
        ]
        map1 = [{} for _ in atomic_keys]
        all_keys = list(sentence.keys())

        for i in tqdm(range(len(all_keys))):
            item = all_keys[i]
            list1 = [[] for _ in atomic_keys]

            for x in sentence[item]:
                # Drop non-ASCII characters the byte-pair encoder cannot handle.
                input_event = x.encode("ascii", errors="ignore").decode("utf-8")
                m1 = []

                for sent in tokenize.sent_tokenize(input_event):
                    # One input sequence per ATOMIC relation for this sentence.
                    seqs = []
                    masks = []
                    for category in atomic_keys:
                        batch = self.set_atomic_inputs(
                            sent, category, self.data_loader, self.text_encoder
                        )
                        seqs.append(batch["sequences"])
                        masks.append(batch["attention_mask"])

                    XMB = torch.cat(seqs)
                    MMB = torch.cat(masks)
                    XMB = model_utils.prepare_position_embeddings(
                        self.opt, self.data_loader.vocab_encoder, XMB.unsqueeze(-1)
                    )
                    # Hidden states for all nine relation-conditioned inputs;
                    # the representation at the final position (the relation
                    # token) is kept as the feature.
                    h, _ = self.model(XMB.unsqueeze(1), sequence_mask=MMB)
                    last_index = MMB[0][:-1].nonzero()[-1].cpu().numpy()[0] + 1
                    m1.append(h[:, -1, :].detach().cpu().numpy())

                # Average the per-sentence features into one vector per utterance.
                m1 = np.mean(np.array(m1), axis=0)
                for k, l1 in enumerate(list1):
                    l1.append(m1[k])

            for k, v1 in enumerate(map1):
                v1[item] = list1[k]

        return map1
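

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the extractor itself). It assumes
# the pretrained ATOMIC COMET checkpoint has been downloaded to
# comet/pretrained_models/ and that NLTK's "punkt" sentence tokenizer is
# available (nltk.download("punkt")). The dialogue id and utterances below
# are hypothetical example data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    extractor = CSKFeatureExtractor(dir=".", device=0)

    # Hypothetical input: each key is a dialogue id, each value the list of
    # utterances in that dialogue.
    conversations = {
        "dialogue_0": [
            "I finally passed my driving test.",
            "Congratulations, that is great news!",
        ],
    }

    features = extractor.extract(conversations)
    # `features` is a list of nine dicts, ordered as
    # [xIntent, xAttr, xNeed, xWant, xEffect, xReact, oWant, oEffect, oReact];
    # features[0]["dialogue_0"] holds one xIntent feature vector per utterance.
    print(len(features), len(features[0]["dialogue_0"]))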