File size: 3,707 Bytes
e2b1d98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
### Imports
from sentence_transformers import SentenceTransformer, util

### Classes and functions

##==========================================================================================================
class SentTransfUtilities:
    """
    Thin convenience wrapper around the `sentence_transformers` package:
    loads a SentenceTransformer model and exposes helpers for embeddings,
    cosine / dot-product similarity, and semantic search.
    """
    ##==========================================================================================================
    # Class-level defaults; __init__ shadows them with per-instance attributes.
    model = None          # SentenceTransformer instance once loaded
    __model_name = None   # name of the model passed to the constructor
    ##==========================================================================================================
    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """
        Store the model name and load the SentenceTransformer model.

        Arguments:
            - model_name: HuggingFace model identifier. Options include:
                - 'all-MiniLM-L6-v2'
                - 'nq-distilbert-base-v1'
                - 'paraphrase-multilingual-MiniLM-L12-v2'
        """
        self.__model_name = model_name
        # Identity check per PEP 8 ("is None", not "== None").  Because the
        # class attribute defaults to None, a fresh instance always loads its
        # own model here; the guard only skips loading if `model` was pre-set.
        if self.model is None:
            print("Initializing the Sentence Transformer model")
            self.model = SentenceTransformer(self.__model_name)
    ##==========================================================================================================
    def get_embeddings(self, src_data):
        """
        Encode `src_data` (a string or list of strings) into embedding
        tensor(s) on the CPU.
        """
        return self.model.encode(src_data, convert_to_tensor=True, device='cpu')
    ##==========================================================================================================
    def compute_cosine_similarity(self, query_embeddings, passage_embeddings):
        """
        Return the matrix of cosine similarities between every query
        embedding and every passage embedding.
        """
        cosine_scores = util.cos_sim(query_embeddings, passage_embeddings)
        return cosine_scores
    ##==========================================================================================================
    def compute_dot_similarity(self, query_embeddings, passage_embeddings):
        """
        Return the matrix of dot-product scores between every query
        embedding and every passage embedding.
        """
        dot_scores = util.dot_score(query_embeddings, passage_embeddings)
        return dot_scores
    ##==========================================================================================================
    def compute_semantic_search(self, query_embeddings, corpus_embeddings):
        """
        Run cosine-similarity semantic search of the queries against the
        corpus; returns one ranked hit list per query (see
        `sentence_transformers.util.semantic_search`).
        """
        # NOTE: this returns ranked search hits, not a raw score matrix.
        hits = util.semantic_search(query_embeddings, corpus_embeddings)
        return hits
    ##==========================================================================================================
    def compute_sentences_similarity(self, sentence_1, sentence_2, sim_func="cosine"):
        """
        Embed both inputs and return their similarity scores.

        Arguments:
            - sentence_1, sentence_2: strings (or lists of strings) to compare.
            - sim_func: one of { "cosine", "dot" }.

        Raises:
            - ValueError: if `sim_func` is not a supported option
              (previously this silently returned None).
        """
        embeddings_1 = self.get_embeddings(sentence_1)
        embeddings_2 = self.get_embeddings(sentence_2)

        if sim_func == "cosine":
            return self.compute_cosine_similarity(embeddings_1, embeddings_2)
        elif sim_func == "dot":
            return self.compute_dot_similarity(embeddings_1, embeddings_2)
        raise ValueError(f"Unsupported sim_func: {sim_func!r}; expected 'cosine' or 'dot'")
    ##==========================================================================================================

##==========================================================================================================