Spaces:
Sleeping
Sleeping
DCWIR-Offcial-Demo
/
textattack
/constraints
/semantics
/sentence_encoders
/infer_sent
/infer_sent.py
"""
InferSent for sentence similarity
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
import os | |
import torch | |
from textattack.constraints.semantics.sentence_encoders import SentenceEncoder | |
from textattack.shared import utils | |
from .infer_sent_model import InferSentModel | |
class InferSent(SentenceEncoder):
    """Constraint using similarity between sentence encodings of x and x_adv
    where the text embeddings are created using InferSent."""

    # Storage keys handed to ``utils.download_from_s3`` to locate the
    # pretrained encoder weights and the word-embedding files.
    MODEL_PATH = "constraints/semantics/sentence-encoders/infersent-encoder"
    WORD_EMBEDDING_PATH = "word_embeddings"

    def __init__(self, *args, **kwargs):
        """Build the constraint and load the InferSent model.

        All positional and keyword arguments are forwarded unchanged to
        ``SentenceEncoder.__init__``. The model is loaded once here and moved
        to ``utils.device``.
        """
        super().__init__(*args, **kwargs)
        self.model = self.get_infersent_model()
        self.model.to(utils.device)

    def get_infersent_model(self):
        """Retrieves the InferSent model.

        Loads the version-2 InferSent weights, points the model at the
        fastText ``crawl-300d-2M.vec`` vectors and builds its vocabulary
        with ``K=100000``.

        Returns:
            The pretrained InferSent model.
        """
        infersent_version = 2
        model_folder_path = utils.download_from_s3(InferSent.MODEL_PATH)
        model_path = os.path.join(
            model_folder_path, f"infersent{infersent_version}.pkl"
        )
        params_model = {
            "bsize": 64,
            "word_emb_dim": 300,
            "enc_lstm_dim": 2048,
            "pool_type": "max",
            "dpout_model": 0.0,
            "version": infersent_version,
        }
        infersent = InferSentModel(params_model)
        # Remap the stored tensors onto the device in use. Without
        # ``map_location`` a checkpoint saved from CUDA tensors cannot be
        # deserialized on a machine that has no GPU.
        # NOTE(review): ``torch.load`` unpickles the file, so this is only
        # safe as long as the downloaded checkpoint is trusted.
        state_dict = torch.load(model_path, map_location=utils.device)
        infersent.load_state_dict(state_dict)
        word_embedding_path = utils.download_from_s3(InferSent.WORD_EMBEDDING_PATH)
        w2v_path = os.path.join(word_embedding_path, "fastText", "crawl-300d-2M.vec")
        infersent.set_w2v_path(w2v_path)
        infersent.build_vocab_k_words(K=100000)
        return infersent

    def encode(self, sentences):
        """Encode ``sentences`` with the loaded InferSent model.

        Args:
            sentences: The sentences to encode; passed straight through to
                the model's ``encode`` with ``tokenize=True``.

        Returns:
            Whatever ``InferSentModel.encode`` returns for the input
            (presumably one embedding per sentence -- confirm against
            ``infer_sent_model``).
        """
        return self.model.encode(sentences, tokenize=True)