# Epik/Model/COSMIC/feature_extraction/comet/csk_feature_extract.py
import os
from tqdm import tqdm
from nltk import tokenize
import numpy as np
import pickle
import torch
import comet.src.data.data as data
import comet.src.data.config as cfg
import comet.src.models.utils as model_utils
import comet.src.interactive.functions as interactive
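
# `comet` is expected to be a local copy of the COMET codebase
# (https://github.com/atcbosselut/comet-commonsense) on the import path, with
# its pretrained ATOMIC checkpoint downloaded (see `model_file` below).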
class CSKFeatureExtractor:
    """Extracts COMET commonsense features for utterances using a pretrained
    ATOMIC model. `dir` is the directory containing the `comet/` checkout and
    `device` is a GPU index (or "cpu" to run on the CPU)."""

    def __init__(self, dir=".", device=0):
        super().__init__()

        model_file = os.path.join(
            dir, "comet/pretrained_models/atomic_pretrained_model.pickle"
        )
        sampling_algorithm = "beam-5"  # unused: features come from hidden states, not decoding
        category = "all"  # unused: extract() queries all nine relations explicitly

        opt, state_dict = interactive.load_model_file(model_file)
        data_loader, text_encoder = interactive.load_data("atomic", opt, dir)

        self.opt = opt
        self.data_loader = data_loader
        self.text_encoder = text_encoder

        # Model size follows the COMET recipe: context = event + effect length.
        n_ctx = data_loader.max_event + data_loader.max_effect
        n_vocab = len(text_encoder.encoder) + n_ctx
        self.model = interactive.make_model(opt, n_vocab, n_ctx, state_dict)
        self.model.eval()

        # Place the model on the requested GPU, otherwise stay on the CPU.
        if device != "cpu":
            cfg.device = int(device)
            cfg.do_gpu = True
            torch.cuda.set_device(cfg.device)
            self.model.cuda(cfg.device)
        else:
            cfg.device = "cpu"
    def set_atomic_inputs(self, input_event, category, data_loader, text_encoder):
        # One row: the encoded event tokens followed by the <category> token.
        XMB = torch.zeros(1, data_loader.max_event + 1).long().to(cfg.device)
        prefix, suffix = data.atomic_data.do_example(
            text_encoder, input_event, None, True, None
        )

        # Truncate events longer than the model's maximum event length.
        if len(prefix) > data_loader.max_event + 1:
            prefix = prefix[: data_loader.max_event + 1]

        XMB[:, : len(prefix)] = torch.LongTensor(prefix)
        XMB[:, -1] = torch.LongTensor([text_encoder.encoder["<{}>".format(category)]])

        batch = {}
        batch["sequences"] = XMB
        batch["attention_mask"] = data.atomic_data.make_attention_mask(XMB)
        return batch
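
    # Illustrative shapes (assuming ATOMIC's default max_event): each call
    # returns batch["sequences"] of shape (1, max_event + 1) -- event tokens,
    # zero padding, then the <category> token in the last slot -- and
    # batch["attention_mask"] of the same shape, with nonzero entries marking
    # the positions that hold real tokens.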
    def extract(self, sentence):
        """Extract the nine ATOMIC relation features for every utterance.

        `sentence` maps each key (e.g., a dialogue id) to a list of utterance
        strings. Returns a list of nine dicts, one per relation in
        `atomic_keys`, each keyed like `sentence`.
        """
        atomic_keys = [
            "xIntent",
            "xAttr",
            "xNeed",
            "xWant",
            "xEffect",
            "xReact",
            "oWant",
            "oEffect",
            "oReact",
        ]
        map1 = [{} for _ in atomic_keys]

        for item in tqdm(list(sentence.keys())):
            list1 = [[] for _ in atomic_keys]
            for x in sentence[item]:
                # Drop non-ASCII characters the byte-pair encoder cannot handle.
                input_event = x.encode("ascii", errors="ignore").decode("utf-8")
                m1 = []
                for sent in tokenize.sent_tokenize(input_event):
                    # Build one input row per ATOMIC relation and batch them.
                    seqs, masks = [], []
                    for category in atomic_keys:
                        batch = self.set_atomic_inputs(
                            sent, category, self.data_loader, self.text_encoder
                        )
                        seqs.append(batch["sequences"])
                        masks.append(batch["attention_mask"])
                    XMB = torch.cat(seqs)
                    MMB = torch.cat(masks)
                    XMB = model_utils.prepare_position_embeddings(
                        self.opt, self.data_loader.vocab_encoder, XMB.unsqueeze(-1)
                    )
                    # Use the hidden state at the final (category-token) position
                    # as the feature vector for each relation.
                    with torch.no_grad():
                        h, _ = self.model(XMB.unsqueeze(1), sequence_mask=MMB)
                    m1.append(h[:, -1, :].detach().cpu().numpy())

                # Average the sentence-level features over the utterance.
                m1 = np.mean(np.array(m1), axis=0)
                for k, l1 in enumerate(list1):
                    l1.append(m1[k])

            for k, v1 in enumerate(map1):
                v1[item] = list1[k]

        return map1
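

# Minimal usage sketch, assuming the pretrained ATOMIC checkpoint lives under
# comet/pretrained_models/; the dialogue id and utterances are hypothetical.
if __name__ == "__main__":
    extractor = CSKFeatureExtractor(dir=".", device="cpu")
    utterances = {"dialogue_0": ["I finally passed my exam.", "Congratulations!"]}
    features = extractor.extract(utterances)
    # `features` is a list of nine dicts, one per ATOMIC relation
    # (xIntent, xAttr, xNeed, xWant, xEffect, xReact, oWant, oEffect, oReact),
    # each mapping the dialogue id to one feature vector per utterance.
    print(len(features), features[0]["dialogue_0"][0].shape)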