# hebrew-dentsit / text_embedder_encoder.py
# Author: borodache
# Change note: split retrieval and reranking into a two-step search using two
# different indexes, which should substantially lower latency.
# Commit: a983ce0 (verified)
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
encoder_model_name = 'MPA/sambert'
class TextEmbedder:
    """Sentence-level embedder for Hebrew text using the 'MPA/sambert' model."""

    def __init__(self):
        """
        Load the SentenceTransformer model named by the module-level
        ``encoder_model_name`` ('MPA/sambert') and move it to GPU when
        available, falling back to CPU otherwise.
        """
        self.model = SentenceTransformer(encoder_model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Inference-only usage: disable dropout/batch-norm training behavior.
        self.model.eval()

    def encode(self, text: str) -> list:
        """
        Encode a single Hebrew text into an embedding vector.

        Args:
            text (str): Hebrew text to encode.

        Returns:
            list[float]: The embedding as plain Python floats
            (JSON-serializable, unlike a numpy array).
        """
        # model.encode on a 1-element list returns a 2-D array; take row 0
        # and convert each component to a native float.
        return [float(x) for x in self.model.encode([text])[0]]

    def encode_many(self, texts) -> list:
        """
        Encode a batch of Hebrew texts in a single model call, which is
        faster than calling :meth:`encode` once per text.

        Args:
            texts (list[str]): Hebrew texts to encode.

        Returns:
            list[list[float]]: One embedding (plain Python floats) per
            input text, in input order.
        """
        embeddings = self.model.encode(texts)
        return [[float(x) for x in row] for row in embeddings]