Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	DCWIR-Offcial-Demo
				/
					textattack
								
							/constraints
								
							/semantics
								
							/sentence_encoders
								
							/infer_sent
								
							/infer_sent.py
					
			| """ | |
| InferSent for sentence similarity | |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |
| """ | |
| import os | |
| import torch | |
| from textattack.constraints.semantics.sentence_encoders import SentenceEncoder | |
| from textattack.shared import utils | |
| from .infer_sent_model import InferSentModel | |
class InferSent(SentenceEncoder):
    """Constraint using similarity between sentence encodings of x and x_adv
    where the text embeddings are created using InferSent."""

    # Both paths are handed to ``utils.download_from_s3``, which returns the
    # local folder the files are read from.
    MODEL_PATH = "constraints/semantics/sentence-encoders/infersent-encoder"
    WORD_EMBEDDING_PATH = "word_embeddings"

    def __init__(self, *args, **kwargs):
        """Initialise the parent encoder, then load InferSent onto ``utils.device``.

        All positional and keyword arguments are forwarded unchanged to
        ``SentenceEncoder.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.model = self.get_infersent_model()
        self.model.to(utils.device)

    def get_infersent_model(self):
        """Retrieves the InferSent model.

        Fetches the encoder checkpoint and the fastText word vectors through
        ``utils.download_from_s3``, restores the weights into a fresh
        ``InferSentModel`` and builds its vocabulary.

        Returns:
            The pretrained InferSent model.
        """
        infersent_version = 2

        model_folder_path = utils.download_from_s3(InferSent.MODEL_PATH)
        model_path = os.path.join(
            model_folder_path, f"infersent{infersent_version}.pkl"
        )

        params_model = {
            "bsize": 64,
            "word_emb_dim": 300,
            "enc_lstm_dim": 2048,
            "pool_type": "max",
            "dpout_model": 0.0,
            "version": infersent_version,
        }
        infersent = InferSentModel(params_model)

        # Deserialize the checkpoint tensors onto the CPU so that loading does
        # not depend on the device they were saved from (a CUDA-saved
        # checkpoint otherwise fails on a CPU-only host). ``__init__`` moves
        # the finished model to ``utils.device`` afterwards.
        state_dict = torch.load(model_path, map_location="cpu")
        infersent.load_state_dict(state_dict)

        word_embedding_path = utils.download_from_s3(InferSent.WORD_EMBEDDING_PATH)
        w2v_path = os.path.join(word_embedding_path, "fastText", "crawl-300d-2M.vec")
        infersent.set_w2v_path(w2v_path)
        infersent.build_vocab_k_words(K=100000)
        return infersent

    def encode(self, sentences):
        """Encode ``sentences`` with the loaded InferSent model.

        The call is delegated to ``self.model.encode`` with ``tokenize=True``;
        its result is returned unchanged.
        """
        return self.model.encode(sentences, tokenize=True)