# mini-gte / mteb_eval.py
# Uploaded by prdev ("Upload 14 files", commit 634c39b, verified); 2.34 kB.
import mteb
from mteb.encoder_interface import PromptType
from sentence_transformers import SentenceTransformer, models
import numpy as np
import torch
import os
import math
# Directory holding the saved SentenceTransformer model. Defaults to the
# current directory ("./"); set the MTEB_MODEL_PATH environment variable to
# evaluate a model stored elsewhere without editing this file.
model_save_path = os.environ.get("MTEB_MODEL_PATH", "./")
# Reload the prepared SentenceTransformer model from that path.
model = SentenceTransformer(model_save_path)
# -------- Define a custom encoder interface for MTEB --------
class CustomModel:
    """Adapter exposing a sentence encoder through an MTEB-style ``encode``.

    The ``encode`` signature (``sentences, task_name, prompt_type, **kwargs``)
    mirrors the one MTEB evaluators call. Sentences are forwarded to the
    wrapped model in fixed-size batches and the per-batch embeddings are
    stacked into a single numpy array.
    """

    def __init__(self, model):
        # Any object exposing ``encode(batch, convert_to_tensor=True)``; in this
        # script it is a SentenceTransformer.
        self.model = model

    def encode(
        self,
        sentences,
        task_name: str,
        prompt_type=None,
        max_batch_size: int = 32,  # Set default max batch size
        **kwargs,
    ) -> np.ndarray:
        """Encode ``sentences`` in batches of at most ``max_batch_size``.

        Args:
            sentences (Iterable): Texts to encode; every item is converted
                with ``str()`` first.
            task_name (str): Name of the MTEB task. Accepted for interface
                compatibility; not used.
            prompt_type (Optional[PromptType]): Prompt type requested by the
                caller. Accepted for interface compatibility; not used.
            max_batch_size (int): Maximum number of sentences per call to the
                wrapped model. Must be >= 1.
            **kwargs: Extra caller arguments. Accepted and ignored; they are
                NOT forwarded to the wrapped model.

        Returns:
            np.ndarray: Embeddings stacked row-wise, one row per input
            sentence, in input order. For empty input, an array of shape
            ``(0, dim)``, where ``dim`` is the wrapped model's reported
            embedding dimension (0 if it reports none).

        Raises:
            ValueError: If ``max_batch_size`` is less than 1.
        """
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")

        sentences = [str(sentence) for sentence in sentences]
        if not sentences:
            # np.vstack([]) raises, so return an explicitly empty result.
            return np.empty((0, self._embedding_dim()), dtype=np.float32)

        embeddings_list = []
        for start_idx in range(0, len(sentences), max_batch_size):
            batch_sentences = sentences[start_idx:start_idx + max_batch_size]
            batch_embeddings = self.model.encode(batch_sentences, convert_to_tensor=True)
            if not isinstance(batch_embeddings, torch.Tensor):
                batch_embeddings = torch.tensor(batch_embeddings)
            # numpy has no bfloat16 dtype, so .numpy() would fail; upcast first.
            if batch_embeddings.dtype == torch.bfloat16:
                batch_embeddings = batch_embeddings.float()
            embeddings_list.append(batch_embeddings.cpu().numpy())
        return np.vstack(embeddings_list)

    def _embedding_dim(self) -> int:
        """Return the wrapped model's embedding dimension, or 0 if unavailable."""
        get_dim = getattr(self.model, "get_sentence_embedding_dimension", None)
        dim = get_dim() if callable(get_dim) else None
        return dim or 0
# Choose what to evaluate: the classic English MTEB benchmark suite.
tasks = mteb.get_benchmark("MTEB(eng, classic)")
# Adapt the loaded SentenceTransformer to the encode() signature MTEB expects.
custom_model = CustomModel(model=model)
# Build the evaluation runner over the selected benchmark's tasks.
evaluation = mteb.MTEB(tasks=tasks)
# Run the evaluation; results are saved under the given output_folder.
results = evaluation.run(
    custom_model,
    output_folder="results/model_results",
)