Spaces:
Runtime error
Runtime error
File size: 3,707 Bytes
e2b1d98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
### Imports
from sentence_transformers import SentenceTransformer, util
### Classes and functions
##==========================================================================================================
class SentTransfUtilities:
    """Convenience wrapper around a ``SentenceTransformer`` model.

    Bundles model loading, embedding computation and the similarity helpers
    from ``sentence_transformers.util`` behind a single object.
    """

    ##==========================================================================================================
    # Class-level defaults. ``__init__`` rebinds both names on the instance,
    # so every instance ends up with its own model and model name.
    model = None
    __model_name = None

    ##==========================================================================================================
    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Load the Sentence Transformer model called ``model_name``.

        Arguments:
            - model_name: name handed to ``SentenceTransformer``.
              Options listed by the original author:
                - 'all-MiniLM-L6-v2'
                - 'nq-distilbert-base-v1'
                - 'paraphrase-multilingual-MiniLM-L12-v2'
        """
        self.__model_name = model_name
        # Identity check: ``None`` is a singleton, and ``==`` could be
        # overridden by whatever object ends up stored in ``model``.
        if self.model is None:
            print("Initializing the Sentence Transformer model")
            self.model = SentenceTransformer(self.__model_name)

    ##==========================================================================================================
    def get_embeddings(self, src_data):
        """Encode ``src_data`` with the loaded model.

        Arguments:
            - src_data: passed unchanged to ``model.encode``
              (presumably a sentence or a list of sentences).

        Returns whatever ``model.encode`` returns when called with
        ``convert_to_tensor=True`` and ``device='cpu'``.
        """
        return self.model.encode(src_data, convert_to_tensor=True, device='cpu')

    ##==========================================================================================================
    def compute_cosine_similarity(self, query_embeddings, passage_embeddings):
        """Return ``util.cos_sim(query_embeddings, passage_embeddings)``.

        Arguments:
            - query_embeddings
            - passage_embeddings
        """
        return util.cos_sim(query_embeddings, passage_embeddings)

    ##==========================================================================================================
    def compute_dot_similarity(self, query_embeddings, passage_embeddings):
        """Return ``util.dot_score(query_embeddings, passage_embeddings)``.

        Arguments:
            - query_embeddings
            - passage_embeddings
        """
        return util.dot_score(query_embeddings, passage_embeddings)

    ##==========================================================================================================
    def compute_semantic_search(self, query_embeddings, corpus_embeddings):
        """Return ``util.semantic_search(query_embeddings, corpus_embeddings)``.

        Arguments:
            - query_embeddings
            - corpus_embeddings

        The call relies on the library defaults for ranking and result size.
        The result is the search hits produced by ``util.semantic_search``,
        not a raw dot-product score matrix.
        """
        hits = util.semantic_search(query_embeddings, corpus_embeddings)
        return hits

    ##==========================================================================================================
    def compute_sentences_similarity(self, sentence_1, sentence_2, sim_func="cosine"):
        """Embed two inputs and score them against each other.

        Arguments:
            - sentence_1
            - sentence_2
            - sim_func: { "cosine", "dot" }

        Returns the result of ``compute_cosine_similarity`` or
        ``compute_dot_similarity`` for the two embeddings, or ``None`` when
        ``sim_func`` is neither "cosine" nor "dot".
        """
        # Both inputs are embedded before ``sim_func`` is inspected, so the
        # model is invoked even when the similarity name is unsupported.
        embeddings_1 = self.get_embeddings(sentence_1)
        embeddings_2 = self.get_embeddings(sentence_2)

        if sim_func == "cosine":
            return self.compute_cosine_similarity(embeddings_1, embeddings_2)
        if sim_func == "dot":
            return self.compute_dot_similarity(embeddings_1, embeddings_2)
        # Unknown similarity name: kept as a silent ``None`` so existing
        # callers that rely on this fallback keep working.
        return None
##==========================================================================================================
##==========================================================================================================
|