generic-chatbot-backend / components /embedding_extraction.py
muryshev's picture
update
0dffae9
raw
history blame
8.2 kB
import logging
from typing import Callable
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (AutoModel, AutoTokenizer, BatchEncoding,
XLMRobertaModel, PreTrainedTokenizer, PreTrainedTokenizerFast)
from transformers.modeling_outputs import \
BaseModelOutputWithPoolingAndCrossAttentions as EncoderOutput
from common.decorators import singleton
logger = logging.getLogger(__name__)
@singleton
class EmbeddingExtractor:
"""Класс обрабатывает текст вопроса и возвращает embedding"""
def __init__(
self,
model_id: str | None,
device: str | torch.device | None = None,
batch_size: int = 1,
do_normalization: bool = True,
max_len: int = 510,
model: XLMRobertaModel = None,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None
):
"""
Класс, соединяющий в себе модель, токенизатор и параметры векторизации.
Args:
model_id: Идентификатор модели.
device: Устройство для вычислений (по умолчанию - GPU, если доступен).
batch_size: Размер батча (по умолчанию - 1).
do_normalization: Нормировать ли вектора (по умолчанию - True).
max_len: Максимальная длина текста в токенах (по умолчанию - 510).
model: Экземпляр загруженной модели.
tokenizer: Экземпляр загруженного токенизатора.
"""
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
device = torch.device(device)
self.device = device
# Инициализация модели
if model is not None and tokenizer is not None:
self.tokenizer = tokenizer
self.model = model
elif model_id is not None:
print('EmbeddingExtractor: model loading '+model_id+' to '+str(self.device))
self.tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
self.model: XLMRobertaModel = AutoModel.from_pretrained(model_id, local_files_only=True).to(
self.device
)
print('EmbeddingExtractor: model loaded')
self.model.eval()
self.model.share_memory()
self.batch_size = batch_size if device.type != 'cpu' else 1
self.do_normalization = do_normalization
self.max_len = max_len
@staticmethod
def _average_pool(
last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
"""
Расчёт усредненного эмбеддинга по всем токенам
Args:
last_hidden_states: Матрица эмбеддингов отдельных токенов размерности (batch_size, seq_len, embedding_size) - последний скрытый слой
attention_mask: Маска, чтобы не учитывать при усреднении пустые токены
Returns:
torch.Tensor - Усредненный эмбеддинг размерности (batch_size, embedding_size)
"""
last_hidden = last_hidden_states.masked_fill(
~attention_mask[..., None].bool(), 0.0
)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
def _query_tokenization(self, text: str | list[str]) -> BatchEncoding:
"""
Преобразует текст в токены.
Args:
text: Текст.
max_len: Максимальная длина текста (510 токенов)
Returns:
BatchEncoding - Словарь с ключами "input_ids", "attention_mask" и т.п.
"""
if isinstance(text, str):
cleaned_text = text.replace('\n', ' ')
else:
cleaned_text = [t.replace('\n', ' ') for t in text]
return self.tokenizer(
cleaned_text,
return_tensors='pt',
padding=True,
truncation=True,
max_length=self.max_len,
)
@torch.no_grad()
def query_embed_extraction(
self,
text: str,
do_normalization: bool = True,
) -> np.ndarray:
"""
Функция преобразует один текст в эмбеддинг размерности (1, embedding_size)
Args:
text: Текст.
do_normalization: Нормировать ли вектора embedding
Returns:
np.array - Эмбеддинг размерности (1, embedding_size)
"""
inputs = self._query_tokenization(text).to(self.device)
outputs = self.model(**inputs)
mask = inputs["attention_mask"]
embedding = self._average_pool(outputs.last_hidden_state, mask)
if do_normalization:
embedding = F.normalize(embedding, dim=-1)
return embedding.cpu().numpy()
def vectorize(
self,
texts: list[str] | str,
progress_callback: Callable[[int, int], None] | None = None,
) -> np.ndarray:
"""
Векторизует все тексты в списке.
Во многом аналогичен методу query_embed_extraction, в будущем стоит объединить их.
Args:
texts: Список текстов или один текст.
progress_callback: Функция, которая будет вызываться при каждом шаге векторизации.
Принимает два аргумента: current и total.
current - текущий шаг векторизации.
total - общее количество шагов векторизации.
Returns:
np.array - Матрица эмбеддингов размерности (texts_count, embedding_size)
"""
if isinstance(texts, str):
texts = [texts]
loader = DataLoader(texts, batch_size=self.batch_size)
embeddings = []
logger.info(
'Vectorizing texts with batch size %d on %s', self.batch_size, self.device
)
for i, batch in enumerate(loader):
embeddings.append(self._vectorize_batch(batch))
if progress_callback is not None:
progress_callback(i * self.batch_size, len(texts))
else:
logger.info('Vectorized batch %d', i)
logger.info('Vectorized all %d batches', len(embeddings))
result = torch.cat(embeddings).numpy()
# Всегда возвращаем двумерный массив
if result.ndim == 1:
result = result.reshape(1, -1)
return result
@torch.no_grad()
def _vectorize_batch(
self,
texts: list[str],
) -> torch.Tensor:
"""
Векторизует один батч текстов.
Args:
texts: Список текстов.
Returns:
torch.Tensor - Матрица эмбеддингов размерности (batch_size, embedding_size)
"""
tokenized = self._query_tokenization(texts).to(self.device)
outputs: EncoderOutput = self.model(**tokenized)
mask = tokenized["attention_mask"]
embedding = self._average_pool(outputs.last_hidden_state, mask)
if self.do_normalization:
embedding = F.normalize(embedding, dim=-1)
return embedding.cpu()
def get_dim(self) -> int:
"""
Возвращает размерность эмбеддинга.
"""
return self.model.config.hidden_size