muryshev's picture
init
b24d496
raw
history blame
2.96 kB
import numpy as np
import torch
from torch.cuda.amp import autocast
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
class EmbeddingExtractor:
"""Класс обрабатывает текст вопроса и возвращает embedding"""
def __init__(self, model_id: str, device: str):
self.model_id = model_id
self.device = device
self.__init_model()
def __init_model(self):
"""Инициализация моделей"""
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
self.model = AutoModel.from_pretrained(self.model_id).to(self.device)
self.model.eval()
@staticmethod
def _average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""
Усреднение векторов по всем токенам
Args:
last_hidden_states: Вектор с последнего скрытого слоя модели
attention_mask: Маска, чтобы не усреднять пустые токены
Returns:
Vector Embeddings
"""
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
def query_embed_extraction(self, tokens, do_normalization: bool = True) -> np.array:
"""
Функция преобразует токены в вектор embedding
Args:
tokens: Tokens
do_normalization: Нормировать ли вектора embedding
Returns:
Возвращает вектор embedding
"""
with torch.no_grad():
with autocast():
inputs = {k: v[:, :].to(self.device) for k, v in tokens.items()}
outputs = self.model(**inputs)
# Использование эмбеддинга первого токена для представления всего текста
# embedding = outputs.last_hidden_state[:, 0]
embedding = self._average_pool(outputs.last_hidden_state, inputs['attention_mask'])
if do_normalization:
embedding = F.normalize(embedding, dim=-1)
return embedding.cpu().numpy()
def query_tokenization(self, text: str, max_len: int = 510):
"""
Преобразует текст в токены.
Args:
text: Текст.
max_len: Максимальная длина текст (510 токенов)
Returns:
Массив токенов
"""
return self.tokenizer(text,
return_tensors="pt",
padding='max_length',
truncation=True,
max_length=max_len)