# text_embedder_encoder.py
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List

encoder_model_name = 'MPA/sambert'


class TextEmbedder:
    def __init__(self):
        """
        Initialize the Hebrew text embedder using the 'MPA/sambert' sentence-transformer model.
        """
        self.model = SentenceTransformer(encoder_model_name)
        # Move the model to GPU when available; fall back to CPU otherwise.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def encode(self, text: str) -> List[float]:
        """
        Encode a single Hebrew text with the sentence-transformer model.

        Args:
            text (str): Hebrew text to encode

        Returns:
            List[float]: Text embedding as a plain Python list of floats
        """
        # model.encode on a one-element list returns one vector; take that first row.
        embedding = [float(x) for x in self.model.encode([text])[0]]
        return embedding

    def encode_many(self, texts: List[str]) -> List[List[float]]:
        """
        Encode a batch of Hebrew texts with the sentence-transformer model.

        Args:
            texts (List[str]): Hebrew texts to encode

        Returns:
            List[List[float]]: One embedding (list of floats) per input text
        """
        # Encode the whole batch at once, then convert each vector to a plain list of floats.
        embeddings = self.model.encode(texts)
        embeddings = [[float(x) for x in embedding] for embedding in embeddings]
        return embeddings
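

# Minimal usage sketch (not part of the original file): instantiate the embedder and
# encode one sentence and a small batch. Assumes the 'MPA/sambert' weights can be
# downloaded from the Hugging Face Hub; the Hebrew example strings are placeholders.
if __name__ == "__main__":
    embedder = TextEmbedder()

    single = embedder.encode("שלום עולם")
    batch = embedder.encode_many(["שלום עולם", "מה שלומך?"])

    print(len(single))                 # embedding dimension of a single text
    print(len(batch), len(batch[0]))   # number of texts, embedding dimension per text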