# utils/helpers.py
"""Helper functions for model loading and embedding generation"""

import gc
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, RobertaModel, RobertaTokenizer

def load_models() -> Dict:
    """
    Load both embedding models with memory optimization.

    Returns:
        Dict containing loaded models and tokenizers
    """
    models_cache = {}

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        # Load Jina model
        print("Loading Jina embeddings model...")
        jina_tokenizer = AutoTokenizer.from_pretrained(
            'jinaai/jina-embeddings-v2-base-es',
            trust_remote_code=True
        )
        jina_model = AutoModel.from_pretrained(
            'jinaai/jina-embeddings-v2-base-es',
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(device)
        jina_model.eval()

        # Load RoBERTalex model
        print("Loading RoBERTalex model...")
        robertalex_tokenizer = RobertaTokenizer.from_pretrained('PlanTL-GOB-ES/RoBERTalex')
        robertalex_model = RobertaModel.from_pretrained(
            'PlanTL-GOB-ES/RoBERTalex',
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(device)
        robertalex_model.eval()

        models_cache = {
            'jina': {
                'tokenizer': jina_tokenizer,
                'model': jina_model,
                'device': device
            },
            'robertalex': {
                'tokenizer': robertalex_tokenizer,
                'model': robertalex_model,
                'device': device
            }
        }

        # Force garbage collection after loading
        gc.collect()

        return models_cache

    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise

def mean_pooling(model_output, attention_mask):
    """
    Apply mean pooling to get sentence embeddings.

    Args:
        model_output: Model output containing token embeddings
        attention_mask: Attention mask for valid tokens

    Returns:
        Pooled embeddings
    """
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def get_embeddings(
    texts: List[str],
    model_name: str,
    models_cache: Dict,
    normalize: bool = True,
    max_length: Optional[int] = None
) -> List[List[float]]:
    """
    Generate embeddings for texts using the specified model.

    Args:
        texts: List of texts to embed
        model_name: Name of model to use ('jina' or 'robertalex')
        models_cache: Dictionary containing loaded models
        normalize: Whether to normalize embeddings
        max_length: Maximum sequence length

    Returns:
        List of embedding vectors
    """
    if model_name not in models_cache:
        raise ValueError(f"Model {model_name} not available. Choose 'jina' or 'robertalex'")

    tokenizer = models_cache[model_name]['tokenizer']
    model = models_cache[model_name]['model']
    device = models_cache[model_name]['device']

    # Set max length based on model capabilities
    if max_length is None:
        max_length = 8192 if model_name == 'jina' else 512

    # Process in batches for memory efficiency
    batch_size = 8 if len(texts) > 8 else len(texts)
    all_embeddings = []

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]

        # Tokenize inputs
        encoded_input = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        ).to(device)

        # Generate embeddings
        with torch.no_grad():
            model_output = model(**encoded_input)

            if model_name == 'jina':
                # Jina models require mean pooling
                embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
            else:
                # RoBERTalex: use [CLS] token embedding
                embeddings = model_output.last_hidden_state[:, 0, :]

        # Normalize if requested
        if normalize:
            embeddings = F.normalize(embeddings, p=2, dim=1)

        # Convert to CPU and list
        batch_embeddings = embeddings.cpu().numpy().tolist()
        all_embeddings.extend(batch_embeddings)

    return all_embeddings

def cleanup_memory():
    """Force garbage collection and clear the CUDA cache if available."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def validate_input_texts(texts: List[str]) -> List[str]:
    """
    Validate and clean input texts.

    Args:
        texts: List of input texts

    Returns:
        Cleaned texts
    """
    cleaned_texts = []
    for text in texts:
        # Collapse excess whitespace
        text = ' '.join(text.split())
        # Skip empty texts
        if text:
            cleaned_texts.append(text)

    if not cleaned_texts:
        raise ValueError("No valid texts provided after cleaning")

    return cleaned_texts

def get_model_info(model_name: str) -> Dict:
    """
    Get detailed information about a model.

    Args:
        model_name: Model identifier

    Returns:
        Dictionary with model information
    """
    model_info = {
        'jina': {
            'full_name': 'jinaai/jina-embeddings-v2-base-es',
            'dimensions': 768,
            'max_length': 8192,
            'pooling': 'mean',
            'languages': ['Spanish', 'English']
        },
        'robertalex': {
            'full_name': 'PlanTL-GOB-ES/RoBERTalex',
            'dimensions': 768,
            'max_length': 512,
            'pooling': 'cls',
            'languages': ['Spanish']
        }
    }
    return model_info.get(model_name, {})
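

# Minimal usage sketch (illustrative only, not part of the Space's runtime code):
# it assumes this module is run directly and that the model weights above can be
# downloaded. The example texts are made up for demonstration. It loads both
# models, cleans two short Spanish inputs, embeds them with the Jina model, and
# frees memory afterwards.
if __name__ == "__main__":
    models_cache = load_models()

    # Hypothetical input texts; validate_input_texts collapses whitespace and
    # drops empty strings before embedding.
    texts = validate_input_texts([
        "El contrato fue firmado por ambas partes.",
        "  La   sentencia fue apelada.  ",
    ])

    embeddings = get_embeddings(texts, 'jina', models_cache, normalize=True)
    print(f"Generated {len(embeddings)} embeddings of dimension {len(embeddings[0])}")
    print(get_model_info('jina'))

    cleanup_memory()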