from typing import List

import torch
from sentence_transformers import SentenceTransformer


encoder_model_name = 'MPA/sambert'


class TextEmbedder:
    def __init__(self):
        """Initialize the Hebrew text embedder using the MPA/sambert sentence-transformers model."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # SentenceTransformer accepts the target device at construction time.
        self.model = SentenceTransformer(encoder_model_name, device=str(self.device))
        self.model.eval()

    def encode(self, text: str) -> List[float]:
        """
        Encode a Hebrew text into a dense embedding.

        Args:
            text (str): Hebrew text to encode.

        Returns:
            List[float]: Text embedding as a list of floats.
        """
        # model.encode returns a numpy array; convert each component to a
        # plain float so the embedding is JSON-serializable.
        embedding = self.model.encode([text])[0]
        return [float(x) for x in embedding]

    def encode_many(self, texts: List[str]) -> List[List[float]]:
        """
        Encode a batch of Hebrew texts in a single forward pass.

        Args:
            texts (List[str]): Hebrew texts to encode.

        Returns:
            List[List[float]]: One embedding per input text.
        """
        embeddings = self.model.encode(texts)
        return [[float(x) for x in embedding] for embedding in embeddings]
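

# A minimal usage sketch (the __main__ guard is illustrative, not part of the
# original module): build the embedder and encode a single Hebrew sentence.
if __name__ == "__main__":
    embedder = TextEmbedder()
    vector = embedder.encode("שלום עולם")  # "hello world"
    print(f"embedding dimension: {len(vector)}")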