Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
from torch.cuda.amp import autocast | |
import torch.nn.functional as F | |
from transformers import AutoTokenizer, AutoModel | |
class EmbeddingExtractor: | |
"""Класс обрабатывает текст вопроса и возвращает embedding""" | |
def __init__(self, model_id: str, device: str): | |
self.model_id = model_id | |
self.device = device | |
self.__init_model() | |
def __init_model(self): | |
"""Инициализация моделей""" | |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) | |
self.model = AutoModel.from_pretrained(self.model_id).to(self.device) | |
self.model.eval() | |
def _average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: | |
""" | |
Усреднение векторов по всем токенам | |
Args: | |
last_hidden_states: Вектор с последнего скрытого слоя модели | |
attention_mask: Маска, чтобы не усреднять пустые токены | |
Returns: | |
Vector Embeddings | |
""" | |
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) | |
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] | |
def query_embed_extraction(self, tokens, do_normalization: bool = True) -> np.array: | |
""" | |
Функция преобразует токены в вектор embedding | |
Args: | |
tokens: Tokens | |
do_normalization: Нормировать ли вектора embedding | |
Returns: | |
Возвращает вектор embedding | |
""" | |
with torch.no_grad(): | |
with autocast(): | |
inputs = {k: v[:, :].to(self.device) for k, v in tokens.items()} | |
outputs = self.model(**inputs) | |
# Использование эмбеддинга первого токена для представления всего текста | |
# embedding = outputs.last_hidden_state[:, 0] | |
embedding = self._average_pool(outputs.last_hidden_state, inputs['attention_mask']) | |
if do_normalization: | |
embedding = F.normalize(embedding, dim=-1) | |
return embedding.cpu().numpy() | |
def query_tokenization(self, text: str, max_len: int = 510): | |
""" | |
Преобразует текст в токены. | |
Args: | |
text: Текст. | |
max_len: Максимальная длина текст (510 токенов) | |
Returns: | |
Массив токенов | |
""" | |
return self.tokenizer(text, | |
return_tensors="pt", | |
padding='max_length', | |
truncation=True, | |
max_length=max_len) | |