Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| from torch.cuda.amp import autocast | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModel | |
| class EmbeddingExtractor: | |
| """Класс обрабатывает текст вопроса и возвращает embedding""" | |
| def __init__(self, model_id: str, device: str): | |
| self.model_id = model_id | |
| self.device = device | |
| self.__init_model() | |
| def __init_model(self): | |
| """Инициализация моделей""" | |
| self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) | |
| self.model = AutoModel.from_pretrained(self.model_id).to(self.device) | |
| self.model.eval() | |
| def _average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Усреднение векторов по всем токенам | |
| Args: | |
| last_hidden_states: Вектор с последнего скрытого слоя модели | |
| attention_mask: Маска, чтобы не усреднять пустые токены | |
| Returns: | |
| Vector Embeddings | |
| """ | |
| last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) | |
| return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] | |
| def query_embed_extraction(self, tokens, do_normalization: bool = True) -> np.array: | |
| """ | |
| Функция преобразует токены в вектор embedding | |
| Args: | |
| tokens: Tokens | |
| do_normalization: Нормировать ли вектора embedding | |
| Returns: | |
| Возвращает вектор embedding | |
| """ | |
| with torch.no_grad(): | |
| with autocast(): | |
| inputs = {k: v[:, :].to(self.device) for k, v in tokens.items()} | |
| outputs = self.model(**inputs) | |
| # Использование эмбеддинга первого токена для представления всего текста | |
| # embedding = outputs.last_hidden_state[:, 0] | |
| embedding = self._average_pool(outputs.last_hidden_state, inputs['attention_mask']) | |
| if do_normalization: | |
| embedding = F.normalize(embedding, dim=-1) | |
| return embedding.cpu().numpy() | |
| def query_tokenization(self, text: str, max_len: int = 510): | |
| """ | |
| Преобразует текст в токены. | |
| Args: | |
| text: Текст. | |
| max_len: Максимальная длина текст (510 токенов) | |
| Returns: | |
| Массив токенов | |
| """ | |
| return self.tokenizer(text, | |
| return_tensors="pt", | |
| padding='max_length', | |
| truncation=True, | |
| max_length=max_len) | |