# utils/helpers.py
"""Helper functions for model loading and embedding generation."""

import gc
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, RobertaModel, RobertaTokenizer


def load_models() -> Dict:
    """
    Load both embedding models with memory optimization.

    Returns:
        Dict containing loaded models, tokenizers, and the target device.
    """
    models_cache = {}

    # Select GPU if available, otherwise fall back to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        # Load the Jina Spanish/English embedding model; trust_remote_code is
        # required because the checkpoint ships custom modeling code
        print("Loading Jina embeddings model...")
        jina_tokenizer = AutoTokenizer.from_pretrained(
            'jinaai/jina-embeddings-v2-base-es',
            trust_remote_code=True
        )
        jina_model = AutoModel.from_pretrained(
            'jinaai/jina-embeddings-v2-base-es',
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(device)
        jina_model.eval()

        # Load the RoBERTalex Spanish legal-domain model
        print("Loading RoBERTalex model...")
        robertalex_tokenizer = RobertaTokenizer.from_pretrained('PlanTL-GOB-ES/RoBERTalex')
        robertalex_model = RobertaModel.from_pretrained(
            'PlanTL-GOB-ES/RoBERTalex',
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(device)
        robertalex_model.eval()

        models_cache = {
            'jina': {
                'tokenizer': jina_tokenizer,
                'model': jina_model,
                'device': device
            },
            'robertalex': {
                'tokenizer': robertalex_tokenizer,
                'model': robertalex_model,
                'device': device
            }
        }

        # Force garbage collection after loading
        gc.collect()

        return models_cache

    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise
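
# Example usage (an illustrative sketch, not part of the original module): in
# a Hugging Face Space the cache is typically built once at startup and then
# reused across requests; the sample text below is made up:
#
#     models_cache = load_models()
#     vectors = get_embeddings(["Hola mundo"], "jina", models_cache)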


def mean_pooling(model_output, attention_mask):
    """
    Apply mean pooling to get sentence embeddings.

    Args:
        model_output: Model output containing token embeddings
        attention_mask: Attention mask for valid tokens

    Returns:
        Pooled embeddings
    """
    # First element holds the token-level hidden states, shape (batch, seq, hidden)
    token_embeddings = model_output[0]
    # Broadcast the mask over the hidden dimension so padding tokens contribute 0
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum valid token embeddings and divide by the (clamped) count of valid tokens
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
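
# Shape sketch (added for illustration; the tensors below are made up): for a
# batch of B sequences of length T with hidden size H, model_output[0] is
# (B, T, H), attention_mask is (B, T), and the pooled result is (B, H):
#
#     out = torch.randn(2, 5, 768)       # (B=2, T=5, H=768)
#     mask = torch.ones(2, 5)
#     mean_pooling((out,), mask).shape   # torch.Size([2, 768])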


def get_embeddings(
    texts: List[str],
    model_name: str,
    models_cache: Dict,
    normalize: bool = True,
    max_length: Optional[int] = None
) -> List[List[float]]:
    """
    Generate embeddings for texts using the specified model.

    Args:
        texts: List of texts to embed
        model_name: Name of model to use ('jina' or 'robertalex')
        models_cache: Dictionary containing loaded models
        normalize: Whether to L2-normalize embeddings
        max_length: Maximum sequence length

    Returns:
        List of embedding vectors
    """
    if model_name not in models_cache:
        raise ValueError(f"Model {model_name} not available. Choose 'jina' or 'robertalex'")

    tokenizer = models_cache[model_name]['tokenizer']
    model = models_cache[model_name]['model']
    device = models_cache[model_name]['device']

    # Set max length based on model capabilities
    if max_length is None:
        max_length = 8192 if model_name == 'jina' else 512

    # Process in batches for memory efficiency; clamp to at least 1 so an
    # empty input list cannot produce a zero step size in range()
    batch_size = min(8, max(1, len(texts)))
    all_embeddings = []

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]

        # Tokenize the batch and move the tensors to the model's device
        encoded_input = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        ).to(device)

        # Generate embeddings without tracking gradients
        with torch.no_grad():
            model_output = model(**encoded_input)

            if model_name == 'jina':
                # Jina models require mean pooling over token embeddings
                embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
            else:
                # RoBERTalex: use the [CLS] token embedding
                embeddings = model_output.last_hidden_state[:, 0, :]

            # L2-normalize if requested, so cosine similarity reduces to a dot product
            if normalize:
                embeddings = F.normalize(embeddings, p=2, dim=1)

        # Move to CPU and convert to plain Python lists
        batch_embeddings = embeddings.cpu().numpy().tolist()
        all_embeddings.extend(batch_embeddings)

    return all_embeddings
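
# Example call (a sketch, assuming load_models() has already populated the
# cache; the sample sentences are made up):
#
#     cache = load_models()
#     vecs = get_embeddings(
#         ["El contrato fue firmado ayer.", "Cláusula de confidencialidad"],
#         "robertalex",
#         cache,
#     )
#     len(vecs), len(vecs[0])   # (2, 768)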


def cleanup_memory():
    """Force garbage collection and clear the CUDA cache if available."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def validate_input_texts(texts: List[str]) -> List[str]:
    """
    Validate and clean input texts.

    Args:
        texts: List of input texts

    Returns:
        Cleaned texts
    """
    cleaned_texts = []
    for text in texts:
        # Collapse runs of whitespace (including newlines) into single spaces
        text = ' '.join(text.split())
        # Skip texts that are empty after cleaning
        if text:
            cleaned_texts.append(text)

    if not cleaned_texts:
        raise ValueError("No valid texts provided after cleaning")

    return cleaned_texts
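
# Behaviour sketch (illustrative inputs, not from the original file):
#
#     validate_input_texts(["  Hola   mundo \n", "", "   "])
#     # -> ["Hola mundo"]; raises ValueError if nothing survives cleaning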


def get_model_info(model_name: str) -> Dict:
    """
    Get detailed information about a model.

    Args:
        model_name: Model identifier ('jina' or 'robertalex')

    Returns:
        Dictionary with model information (empty dict for unknown names)
    """
    model_info = {
        'jina': {
            'full_name': 'jinaai/jina-embeddings-v2-base-es',
            'dimensions': 768,
            'max_length': 8192,
            'pooling': 'mean',
            'languages': ['Spanish', 'English']
        },
        'robertalex': {
            'full_name': 'PlanTL-GOB-ES/RoBERTalex',
            'dimensions': 768,
            'max_length': 512,
            'pooling': 'cls',
            'languages': ['Spanish']
        }
    }
    return model_info.get(model_name, {})
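

# Minimal smoke test (added as a sketch, not part of the original module). It
# assumes network access to the Hugging Face Hub and enough memory to load
# both checkpoints, so it is best run locally rather than inside the Space.
if __name__ == "__main__":
    cache = load_models()
    sample = validate_input_texts(["  Hola   mundo  ", ""])
    vecs = get_embeddings(sample, "jina", cache)
    print(f"jina: {len(vecs)} embedding(s) of dim {len(vecs[0])}")
    print(get_model_info("robertalex"))
    cleanup_memory()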