# utils/helpers.py
"""Helper functions for model loading and embedding generation."""
import gc
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, RobertaTokenizer, RobertaModel


def load_models() -> Dict:
    """
    Load both embedding models with memory optimization.

    Returns:
        Dict containing loaded models and tokenizers
    """
    models_cache = {}

    # Select GPU when available, otherwise fall back to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        # Load Jina model
        print("Loading Jina embeddings model...")
        jina_tokenizer = AutoTokenizer.from_pretrained(
            'jinaai/jina-embeddings-v2-base-es',
            trust_remote_code=True
        )
        jina_model = AutoModel.from_pretrained(
            'jinaai/jina-embeddings-v2-base-es',
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(device)
        jina_model.eval()

        # Load RoBERTalex model
        print("Loading RoBERTalex model...")
        robertalex_tokenizer = RobertaTokenizer.from_pretrained('PlanTL-GOB-ES/RoBERTalex')
        robertalex_model = RobertaModel.from_pretrained(
            'PlanTL-GOB-ES/RoBERTalex',
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(device)
        robertalex_model.eval()

        models_cache = {
            'jina': {
                'tokenizer': jina_tokenizer,
                'model': jina_model,
                'device': device
            },
            'robertalex': {
                'tokenizer': robertalex_tokenizer,
                'model': robertalex_model,
                'device': device
            }
        }

        # Force garbage collection after loading
        gc.collect()

        return models_cache

    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise


def mean_pooling(model_output, attention_mask):
    """
    Apply mean pooling to get sentence embeddings.

    Args:
        model_output: Model output containing token embeddings
        attention_mask: Attention mask for valid tokens

    Returns:
        Pooled embeddings
    """
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
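

# Illustrative sketch (not part of the original module): a tiny check of what
# mean_pooling does. Per-token embeddings of shape (batch, seq_len, hidden) are
# averaged over the non-padded positions flagged by the attention mask, giving
# one (batch, hidden) vector per text. The tensor sizes below are dummies.
def _mean_pooling_shape_example():
    dummy_tokens = torch.rand(2, 4, 16)                 # (batch=2, seq_len=4, hidden=16)
    dummy_mask = torch.tensor([[1, 1, 1, 0],             # last position is padding
                               [1, 1, 0, 0]])            # last two positions are padding
    pooled = mean_pooling((dummy_tokens,), dummy_mask)   # model_output[0] == dummy_tokens
    assert pooled.shape == (2, 16)
    return pooled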


def get_embeddings(
    texts: List[str],
    model_name: str,
    models_cache: Dict,
    normalize: bool = True,
    max_length: Optional[int] = None
) -> List[List[float]]:
    """
    Generate embeddings for texts using the specified model.

    Args:
        texts: List of texts to embed
        model_name: Name of model to use ('jina' or 'robertalex')
        models_cache: Dictionary containing loaded models
        normalize: Whether to L2-normalize embeddings
        max_length: Maximum sequence length

    Returns:
        List of embedding vectors
    """
    if model_name not in models_cache:
        raise ValueError(f"Model {model_name} not available. Choose 'jina' or 'robertalex'")

    tokenizer = models_cache[model_name]['tokenizer']
    model = models_cache[model_name]['model']
    device = models_cache[model_name]['device']

    # Set max length based on model capabilities
    if max_length is None:
        max_length = 8192 if model_name == 'jina' else 512

    # Process in batches for memory efficiency (guard against an empty input list)
    batch_size = min(len(texts), 8) or 1
    all_embeddings = []

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]

        # Tokenize inputs
        encoded_input = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        ).to(device)

        # Generate embeddings
        with torch.no_grad():
            model_output = model(**encoded_input)

            if model_name == 'jina':
                # Jina embeddings use mean pooling over token embeddings
                embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
            else:
                # RoBERTalex: use the [CLS] token embedding
                embeddings = model_output.last_hidden_state[:, 0, :]

            # Normalize if requested
            if normalize:
                embeddings = F.normalize(embeddings, p=2, dim=1)

        # Convert to CPU and plain Python lists
        batch_embeddings = embeddings.cpu().numpy().tolist()
        all_embeddings.extend(batch_embeddings)

    return all_embeddings
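

# Illustrative sketch (not in the original module): with normalize=True the
# returned vectors have unit length, so cosine similarity between two
# embeddings reduces to a plain dot product.
def _cosine_similarity_example(vec_a: List[float], vec_b: List[float]) -> float:
    return float(torch.dot(torch.tensor(vec_a), torch.tensor(vec_b)))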


def cleanup_memory():
    """Force garbage collection and clear the CUDA cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def validate_input_texts(texts: List[str]) -> List[str]:
    """
    Validate and clean input texts.

    Args:
        texts: List of input texts

    Returns:
        Cleaned texts
    """
    cleaned_texts = []
    for text in texts:
        # Collapse excess whitespace
        text = ' '.join(text.split())
        # Skip empty texts
        if text:
            cleaned_texts.append(text)

    if not cleaned_texts:
        raise ValueError("No valid texts provided after cleaning")

    return cleaned_texts


def get_model_info(model_name: str) -> Dict:
    """
    Get detailed information about a model.

    Args:
        model_name: Model identifier

    Returns:
        Dictionary with model information
    """
    model_info = {
        'jina': {
            'full_name': 'jinaai/jina-embeddings-v2-base-es',
            'dimensions': 768,
            'max_length': 8192,
            'pooling': 'mean',
            'languages': ['Spanish', 'English']
        },
        'robertalex': {
            'full_name': 'PlanTL-GOB-ES/RoBERTalex',
            'dimensions': 768,
            'max_length': 512,
            'pooling': 'cls',
            'languages': ['Spanish']
        }
    }
    return model_info.get(model_name, {})
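

# Minimal usage sketch (added for illustration; it assumes the module is run
# directly and that the model weights can be downloaded from the Hugging Face Hub).
if __name__ == "__main__":
    sample_texts = validate_input_texts([
        "  El contrato se firmó en Madrid.  ",
        "La sentencia fue apelada ante el tribunal superior."
    ])

    # Load both models once and reuse the cache for every call
    cache = load_models()

    for name in ("jina", "robertalex"):
        info = get_model_info(name)
        vectors = get_embeddings(sample_texts, name, cache)
        print(f"{info['full_name']}: {len(vectors)} embeddings of dimension {len(vectors[0])}")

    # Release memory when done
    cleanup_memory()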