import torch
import numpy as np
from typing import List

from sentence_transformers import SentenceTransformer

encoder_model_name = 'MPA/sambert'


class TextEmbedder:
    def __init__(self):
        """
        Initialize the Hebrew text embedder using the MPA/sambert
        sentence-transformer model.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SentenceTransformer(encoder_model_name)
        self.model.to(self.device)
        self.model.eval()
    def encode(self, text: str) -> np.ndarray:
        """
        Encode a single Hebrew text into a dense embedding.

        Args:
            text (str): Hebrew text to encode.

        Returns:
            numpy.ndarray: The text embedding vector.
        """
        # SentenceTransformer.encode returns a 1-D numpy array for a single string
        return self.model.encode(text)
    def encode_many(self, texts: List[str]) -> np.ndarray:
        """
        Encode a batch of Hebrew texts into dense embeddings.

        Args:
            texts (List[str]): Hebrew texts to encode.

        Returns:
            numpy.ndarray: One embedding vector per input text, stacked row-wise.
        """
        # Encoding in a single batch is faster than calling encode() per text
        return self.model.encode(texts)
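

# Example usage: a minimal sketch assuming the 'MPA/sambert' weights are
# available from the Hugging Face Hub; the Hebrew sample strings below are
# illustrative placeholders, not part of the original module.
if __name__ == "__main__":
    embedder = TextEmbedder()
    single = embedder.encode("שלום עולם")            # 1-D vector for one text
    batch = embedder.encode_many(["שלום", "עולם"])    # 2-D array, one row per text
    print(single.shape, batch.shape)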