import logging from typing import Callable import numpy as np import torch import torch.nn.functional as F from torch.utils.data import DataLoader from transformers import AutoModel, AutoTokenizer, BatchEncoding, XLMRobertaModel from transformers.modeling_outputs import ( BaseModelOutputWithPoolingAndCrossAttentions as EncoderOutput, ) logger = logging.getLogger(__name__) class EmbeddingExtractor: """Класс обрабатывает текст вопроса и возвращает embedding""" def __init__( self, model_id: str, device: str | torch.device | None = None, batch_size: int = 1, do_normalization: bool = True, max_len: int = 510, ): """ Класс, соединяющий в себе модель, токенизатор и параметры векторизации. Args: model_id: Идентификатор модели. device: Устройство для вычислений (по умолчанию - GPU, если доступен). batch_size: Размер батча (по умолчанию - 1). do_normalization: Нормировать ли вектора (по умолчанию - True). max_len: Максимальная длина текста в токенах (по умолчанию - 510). """ if device is None: device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') else: device = torch.device(device) self.device = device # Инициализация модели self.tokenizer = AutoTokenizer.from_pretrained(model_id) self.model: XLMRobertaModel = AutoModel.from_pretrained(model_id).to( self.device ) self.model.eval() self.model.share_memory() self.batch_size = batch_size if device.type != 'cpu' else 1 self.do_normalization = do_normalization self.max_len = max_len @staticmethod def _average_pool( last_hidden_states: torch.Tensor, attention_mask: torch.Tensor ) -> torch.Tensor: """ Расчёт усредненного эмбеддинга по всем токенам Args: last_hidden_states: Матрица эмбеддингов отдельных токенов размерности (batch_size, seq_len, embedding_size) - последний скрытый слой attention_mask: Маска, чтобы не учитывать при усреднении пустые токены Returns: torch.Tensor - Усредненный эмбеддинг размерности (batch_size, embedding_size) """ last_hidden = last_hidden_states.masked_fill( ~attention_mask[..., None].bool(), 0.0 ) return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] def _query_tokenization(self, text: str | list[str]) -> BatchEncoding: """ Преобразует текст в токены. Args: text: Текст. max_len: Максимальная длина текста (510 токенов) Returns: BatchEncoding - Словарь с ключами "input_ids", "attention_mask" и т.п. """ if isinstance(text, str): cleaned_text = text.replace('\n', ' ') else: cleaned_text = [t.replace('\n', ' ') for t in text] return self.tokenizer( cleaned_text, return_tensors='pt', padding=True, truncation=True, max_length=self.max_len, ) @torch.no_grad() def query_embed_extraction( self, text: str, do_normalization: bool = True, ) -> np.ndarray: """ Функция преобразует один текст в эмбеддинг размерности (1, embedding_size) Args: text: Текст. do_normalization: Нормировать ли вектора embedding Returns: np.array - Эмбеддинг размерности (1, embedding_size) """ inputs = self._query_tokenization(text).to(self.device) outputs = self.model(**inputs) mask = inputs["attention_mask"] embedding = self._average_pool(outputs.last_hidden_state, mask) if do_normalization: embedding = F.normalize(embedding, dim=-1) return embedding.cpu().numpy() # TODO: В будущем стоит объединить vectorize и query_embed_extraction def vectorize( self, texts: list[str] | str, progress_callback: Callable[[int, int], None] | None = None, ) -> np.ndarray: """ Векторизует все тексты в списке. Во многом аналогичен методу query_embed_extraction, в будущем стоит объединить их. Args: texts: Список текстов или один текст. progress_callback: Функция, которая будет вызываться при каждом шаге векторизации. Принимает два аргумента: current и total. current - текущий шаг векторизации. total - общее количество шагов векторизации. Returns: np.array - Матрица эмбеддингов размерности (texts_count, embedding_size) """ if isinstance(texts, str): texts = [texts] loader = DataLoader(texts, batch_size=self.batch_size) embeddings = [] logger.info( 'Vectorizing texts with batch size %d on %s', self.batch_size, self.device ) for i, batch in enumerate(loader): embeddings.append(self._vectorize_batch(batch)) if progress_callback is not None: progress_callback(i * self.batch_size, len(texts)) else: logger.info('Vectorized batch %d', i) logger.info('Vectorized all %d batches', len(embeddings)) return torch.cat(embeddings).numpy() @torch.no_grad() def _vectorize_batch( self, texts: list[str], ) -> torch.Tensor: """ Векторизует один батч текстов. Args: texts: Список текстов. Returns: torch.Tensor - Матрица эмбеддингов размерности (batch_size, embedding_size) """ tokenized = self._query_tokenization(texts).to(self.device) outputs: EncoderOutput = self.model(**tokenized) mask = tokenized["attention_mask"] embedding = self._average_pool(outputs.last_hidden_state, mask) if self.do_normalization: embedding = F.normalize(embedding, dim=-1) return embedding.cpu() def get_dim(self) -> int: """ Возвращает размерность эмбеддинга. """ return self.model.config.hidden_size