Spaces:
Sleeping
Sleeping
import logging | |
import re | |
from logging import Logger | |
from pathlib import Path | |
from typing import Dict, List, Tuple | |
import pandas as pd | |
from elasticsearch.exceptions import ConnectionError | |
from natasha import Doc, MorphVocab, NewsEmbedding, NewsMorphTagger, Segmenter | |
from common.common import ( | |
get_elastic_abbreviation_query, | |
get_elastic_group_query, | |
get_elastic_people_query, | |
get_elastic_query, | |
get_elastic_rocks_nn_query, | |
get_elastic_segmentation_query, | |
) | |
from common.configuration import Configuration, Query, SummaryChunks | |
from common.constants import PROMPT, PROMPT_CLASSIFICATION | |
from components.elastic import create_index_elastic_chunks | |
from components.elastic.elasticsearch_client import ElasticsearchClient | |
from components.embedding_extraction import EmbeddingExtractor | |
from components.nmd.aggregate_answers import aggregate_answers | |
from components.nmd.faiss_vector_search import FaissVectorSearch | |
from components.nmd.llm_chunk_search import LLMChunkSearch | |
from components.nmd.metadata_manager import MetadataManager | |
from components.nmd.query_classification import QueryClassification | |
from components.nmd.rancker import DocumentRanking | |
from components.services.dataset import DatasetService | |
# Module-level logger; note Dispatcher.__init__ also receives a separate Logger
# instance that is forwarded to the LLM sub-components.
logger = logging.getLogger(__name__)
class Dispatcher:
    """Routes a user query across the configured search backends.

    Combines Faiss vector search with several Elasticsearch full-text
    searches (people, chunks, groups, Rocks NN, segmentation), expands
    abbreviations in the query, aggregates all found chunks, and delegates
    final answer selection/classification to LLM components.
    """

    def __init__(
        self,
        embedding_model: EmbeddingExtractor,
        config: Configuration,
        logger: Logger,
        dataset_service: DatasetService,
    ):
        """
        Args:
            embedding_model: Extractor used to embed queries for vector search.
            config: Application configuration (db/search/llm sections).
            logger: Logger instance forwarded to the LLM sub-components.
            dataset_service: Service used to resolve the default dataset.
        """
        self.dataset_service = dataset_service
        self.config = config
        self.embedder = embedding_model
        # Id of the currently loaded dataset; None until a dataset is loaded.
        self.dataset_id = None
        self.try_load_default_dataset()
        self.llm_search = LLMChunkSearch(config.llm_config, PROMPT, logger)
        if self.config.db_config.elastic.use_elastic:
            self.elastic_search = ElasticsearchClient(
                host=f'{config.db_config.elastic.es_host}',
                port=config.db_config.elastic.es_port,
            )
        self.query_classification = QueryClassification(
            config.llm_config, PROMPT_CLASSIFICATION, logger
        )
        # Natasha NLP pipeline used for lemmatization in abbreviation matching.
        self.segmenter = Segmenter()
        self.morph_tagger = NewsMorphTagger(NewsEmbedding())
        self.morph_vocab = MorphVocab()

    def try_load_default_dataset(self) -> None:
        """Load the default dataset if it differs from the currently loaded one.

        Called on construction and before every search; a no-op when the
        default dataset is already loaded, so repeated searches do not
        re-read the dataset from disk.
        """
        default_dataset = self.dataset_service.get_default_dataset()
        if default_dataset is None or default_dataset.id is None:
            # No usable default dataset: disable vector/metadata search.
            self.faiss_search = None
            self.meta_database = None
        elif default_dataset.id != self.dataset_id:
            logger.info(f'Reloading dataset {default_dataset.id}')
            self.reset_dataset(default_dataset.id)
        # else: the default dataset is already loaded — keep existing structures.

    def reset_dataset(self, dataset_id: int) -> None:
        """(Re)build all search structures for the given dataset id.

        Loads the pickled dataset, initializes Faiss vector search, metadata
        lookup, the (optional) Elastic chunk index, and document ranking.
        """
        logger.info(f'Reset dataset to dataset_id: {dataset_id}')
        data_path = Path(self.config.db_config.faiss.path_to_metadata)
        df = pd.read_pickle(data_path / str(dataset_id) / 'dataset.pkl')
        logger.info(f'Dataset loaded from {data_path / str(dataset_id) / "dataset.pkl"}')
        logger.info(f'Dataset shape: {df.shape}')
        self.faiss_search = FaissVectorSearch(self.embedder, df, self.config.db_config)
        logger.info('Faiss search initialized')
        self.meta_database = MetadataManager(df, logger)
        logger.info('Meta database initialized')
        if self.config.db_config.elastic.use_elastic:
            create_index_elastic_chunks(df, logger)
            logger.info('Elastic index created')
        self.document_ranking = DocumentRanking(df, self.config)
        logger.info('Document ranking initialized')
        # Record the loaded dataset so try_load_default_dataset() can skip
        # redundant reloads on subsequent searches.
        self.dataset_id = dataset_id

    def __vector_search(self, query: str) -> Dict[int, Dict]:
        """Find the nearest chunks in the Faiss vector database.

        Args:
            query: The user query.

        Returns:
            Mapping of chunk index to chunk metadata.
        """
        query_embeds, scores, indexes = self.faiss_search.search_vectors(query)
        if self.config.db_config.ranker.use_ranging:
            indexes = self.document_ranking.doc_ranking(query_embeds, scores, indexes)
        return self.meta_database.search(indexes)

    def __elastic_search(
        self, query: str, index_name: str, search_function, size: int
    ) -> Dict:
        """Run a full-text search against one Elasticsearch index.

        Args:
            query: The user query.
            index_name: Name of the index to search.
            search_function: Builds the index-specific Elastic query body.
            size: Number of nearest neighbors (result set size).

        Returns:
            The Elasticsearch hits for the query.
        """
        self.elastic_search.set_index(index_name)
        return self.elastic_search.search(query=search_function(query), size=size)

    @staticmethod
    def _get_indexes_full_text_elastic_search(elastic_answer: Dict) -> List:
        """Extract chunk indexes from Elastic full-text search hits.

        Args:
            elastic_answer: Hits returned by the chunk full-text search.

        Returns:
            List of chunk indexes.
        """
        # @staticmethod is required: callers invoke this via the instance
        # (self._get_indexes_full_text_elastic_search(...)), which would
        # otherwise pass `self` as `elastic_answer` and raise a TypeError.
        return [answer_dict['_source']['index'] for answer_dict in elastic_answer]

    def _lemmatization_text(self, text: str) -> str:
        """Return *text* with every token replaced by its lemma (Natasha pipeline)."""
        doc = Doc(text)
        doc.segment(self.segmenter)
        doc.tag_morph(self.morph_tagger)
        for token in doc.tokens:
            token.lemmatize(self.morph_vocab)
        return ' '.join([token.lemma for token in doc.tokens])

    def _get_abbreviations(self, query: Query) -> Query:
        """Insert known abbreviations into the query via Elasticsearch lookup.

        For every stored long form whose lemmatized text occurs in the
        lemmatized query, the corresponding abbreviation is inserted in
        parentheses right after the long form inside ``query_abbreviation``.
        Elastic connectivity problems are logged and leave the query as-is.
        """
        query_abbreviation = query.query_abbreviation
        abbreviations_replaced = query.abbreviations_replaced
        try:
            if (
                self.config.db_config.elastic.use_elastic
                and self.config.db_config.search.abbreviation_search.use_abbreviation_search
            ):
                abbreviation_answer = self.__elastic_search(
                    query=query.query,
                    index_name=self.config.db_config.search.abbreviation_search.index_name,
                    search_function=get_elastic_abbreviation_query,
                    size=self.config.db_config.search.abbreviation_search.k_neighbors,
                )
                if len(abbreviation_answer) > 0:
                    query_lemmatization = self._lemmatization_text(query.query)
                    for abbreviation in abbreviation_answer:
                        abbreviation_lemmatization = self._lemmatization_text(
                            abbreviation['_source']['text'].lower()
                        )
                        if abbreviation_lemmatization not in query_lemmatization:
                            continue
                        query_abbreviation_lemmatization = self._lemmatization_text(
                            query_abbreviation
                        )
                        # re.escape: stored long forms may contain regex
                        # metacharacters and must be matched literally.
                        match = re.search(
                            re.escape(abbreviation_lemmatization),
                            query_abbreviation_lemmatization,
                        )
                        if match is None:
                            # The long form occurs in the raw query but not in
                            # the (possibly already expanded) query_abbreviation;
                            # the original code would have crashed here.
                            continue
                        index = match.span()[1]
                        space_index = query_abbreviation.find(' ', index)
                        if space_index != -1:
                            query_abbreviation = '{} ({}) {}'.format(
                                query_abbreviation[:space_index],
                                abbreviation["_source"]["abbreviation"],
                                query_abbreviation[space_index:],
                            )
                        else:
                            query_abbreviation = '{} ({})'.format(
                                query_abbreviation,
                                abbreviation["_source"]["abbreviation"],
                            )
        except ConnectionError:
            logger.info("Connection Error Elasticsearch")
        return Query(
            query=query.query,
            query_abbreviation=query_abbreviation,
            abbreviations_replaced=abbreviations_replaced,
        )

    def search_answer(self, query: Query) -> SummaryChunks:
        """Search all configured backends for chunks answering the user query.

        Args:
            query: The user query.

        Returns:
            Aggregated chunks found for the query.
        """
        self.try_load_default_dataset()
        query = self._get_abbreviations(query)
        logger.info(f'Start search for {query.query_abbreviation}')
        logger.info(f'Use elastic search: {self.config.db_config.elastic.use_elastic}')
        answer = {}
        if self.config.db_config.search.vector_search.use_vector_search:
            logger.info('Start vector search.')
            answer['vector_answer'] = self.__vector_search(query.query_abbreviation)
            logger.info(f'Vector search found {len(answer["vector_answer"])} chunks')
        try:
            if self.config.db_config.elastic.use_elastic:
                if self.config.db_config.search.people_elastic_search.use_people_search:
                    logger.info('Start people search.')
                    people_answer = self.__elastic_search(
                        query.query,
                        index_name=self.config.db_config.search.people_elastic_search.index_name,
                        search_function=get_elastic_people_query,
                        size=self.config.db_config.search.people_elastic_search.k_neighbors,
                    )
                    logger.info(f'People search found {len(people_answer)} chunks')
                    answer['people_answer'] = people_answer
                if self.config.db_config.search.chunks_elastic_search.use_chunks_search:
                    logger.info('Start full text chunks search.')
                    chunks_answer = self.__elastic_search(
                        query.query,
                        index_name=self.config.db_config.search.chunks_elastic_search.index_name,
                        search_function=get_elastic_query,
                        size=self.config.db_config.search.chunks_elastic_search.k_neighbors,
                    )
                    # Full-text hits are resolved back through the metadata
                    # manager so all answer types share one chunk format.
                    indexes = self._get_indexes_full_text_elastic_search(chunks_answer)
                    chunks_answer = self.meta_database.search(indexes)
                    logger.info(
                        f'Full text chunks search found {len(chunks_answer)} chunks'
                    )
                    answer['chunks_answer'] = chunks_answer
                if self.config.db_config.search.groups_elastic_search.use_groups_search:
                    logger.info('Start groups search.')
                    groups_answer = self.__elastic_search(
                        query.query,
                        index_name=self.config.db_config.search.groups_elastic_search.index_name,
                        search_function=get_elastic_group_query,
                        size=self.config.db_config.search.groups_elastic_search.k_neighbors,
                    )
                    if len(groups_answer) != 0:
                        logger.info(f'Groups search found {len(groups_answer)} chunks')
                        answer['groups_answer'] = groups_answer
                if (
                    self.config.db_config.search.rocks_nn_elastic_search.use_rocks_nn_search
                ):
                    logger.info('Start Rocks NN search.')
                    rocks_nn_answer = self.__elastic_search(
                        query.query,
                        index_name=self.config.db_config.search.rocks_nn_elastic_search.index_name,
                        search_function=get_elastic_rocks_nn_query,
                        size=self.config.db_config.search.rocks_nn_elastic_search.k_neighbors,
                    )
                    if len(rocks_nn_answer) != 0:
                        logger.info(
                            f'Rocks NN search found {len(rocks_nn_answer)} chunks'
                        )
                        answer['rocks_nn_answer'] = rocks_nn_answer
                if (
                    self.config.db_config.search.segmentation_elastic_search.use_segmentation_search
                ):
                    logger.info('Start Segmentation search.')
                    segmentation_answer = self.__elastic_search(
                        query.query,
                        index_name=self.config.db_config.search.segmentation_elastic_search.index_name,
                        search_function=get_elastic_segmentation_query,
                        size=self.config.db_config.search.segmentation_elastic_search.k_neighbors,
                    )
                    if len(segmentation_answer) != 0:
                        logger.info(
                            f'Segmentation search found {len(segmentation_answer)} chunks'
                        )
                        answer['segmentation_answer'] = segmentation_answer
        except ConnectionError:
            # Best effort: return whatever the other backends found.
            logger.info("Connection Error Elasticsearch")
        final_answer = aggregate_answers(**answer)
        logger.info(f'Final answer found {len(final_answer)} chunks')
        return SummaryChunks(**final_answer)

    def llm_classification(self, query: str) -> str:
        """Classify the query type via the LLM classifier."""
        return self.query_classification.classification(query)

    def llm_answer(
        self, query: str, answer_chunks: SummaryChunks
    ) -> Tuple[str, str, str, int]:
        """Select the best answer chunk with the LLM.

        Args:
            query: The user query.
            answer_chunks: Chunks found by vector and Elastic search.

        Returns:
            The source chunks from the searches and the chunk chosen by the model.
        """
        return self.llm_search.llm_chunk_search(query, answer_chunks, PROMPT)