File size: 1,815 Bytes
4943752 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
"""
InferSent for Sentence Similarity
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
import os
import torch
from textattack.constraints.semantics.sentence_encoders import SentenceEncoder
from textattack.shared import utils
from .infer_sent_model import InferSentModel
class InferSent(SentenceEncoder):
    """Constraint using similarity between sentence encodings of x and x_adv
    where the text embeddings are created using InferSent."""

    # S3 key of the folder holding the pretrained InferSent weights.
    MODEL_PATH = "constraints/semantics/sentence-encoders/infersent-encoder"
    # S3 key of the folder holding word embeddings; the fastText vectors are
    # read from its "fastText" subfolder in ``get_infersent_model``.
    WORD_EMBEDDING_PATH = "word_embeddings"

    def __init__(self, *args, **kwargs):
        """Build the constraint and load the InferSent encoder.

        All positional and keyword arguments are forwarded unchanged to
        ``SentenceEncoder.__init__``. The loaded model is then moved to
        ``utils.device``.
        """
        super().__init__(*args, **kwargs)
        self.model = self.get_infersent_model()
        self.model.to(utils.device)

    def get_infersent_model(self):
        """Retrieves the InferSent model.

        Fetches (via ``utils.download_from_s3``) the version-2 InferSent
        checkpoint ``infersent2.pkl`` and the fastText
        ``crawl-300d-2M.vec`` word vectors, loads the checkpoint into an
        ``InferSentModel``, and builds its vocabulary with
        ``build_vocab_k_words(K=100000)``.

        Returns:
            The pretrained InferSent model. Checkpoint tensors are mapped to
            the CPU; moving the model to the target device is left to the
            caller (``__init__`` does this).
        """
        infersent_version = 2
        model_folder_path = utils.download_from_s3(InferSent.MODEL_PATH)
        model_path = os.path.join(
            model_folder_path, f"infersent{infersent_version}.pkl"
        )
        params_model = {
            "bsize": 64,
            "word_emb_dim": 300,
            "enc_lstm_dim": 2048,
            "pool_type": "max",
            "dpout_model": 0.0,
            "version": infersent_version,
        }
        infersent = InferSentModel(params_model)
        # Map the checkpoint onto the CPU. Without ``map_location``, a
        # checkpoint saved with CUDA tensors fails to load on a CPU-only
        # host. ``__init__`` moves the model to ``utils.device`` afterwards.
        # NOTE(review): ``torch.load`` unpickles the file; this is acceptable
        # only because the checkpoint comes from the project's own S3 store.
        infersent.load_state_dict(torch.load(model_path, map_location="cpu"))
        word_embedding_path = utils.download_from_s3(InferSent.WORD_EMBEDDING_PATH)
        w2v_path = os.path.join(word_embedding_path, "fastText", "crawl-300d-2M.vec")
        infersent.set_w2v_path(w2v_path)
        # NOTE(review): presumably restricts the vocabulary to the first
        # 100000 words of the vector file; confirm in InferSentModel.
        infersent.build_vocab_k_words(K=100000)
        return infersent

    def encode(self, sentences):
        """Encode ``sentences`` with the InferSent model.

        Delegates to ``self.model.encode`` with ``tokenize=True``, which
        presumably has the model tokenize the raw strings itself. The return
        value is presumably one embedding per sentence; confirm against
        ``InferSentModel.encode``.
        """
        return self.model.encode(sentences, tokenize=True)
|