import gc
import traceback
from legal_info_search_utils.rules_utils import use_rules
from itertools import islice
import os
import torch
import numpy as np
from faiss import IndexFlatIP
from datasets import Dataset as dataset
from transformers import AutoTokenizer, AutoModel
from legal_info_search_utils.utils import query_tokenization, query_embed_extraction
import requests
import re
import json
import pymorphy3
from torch.cuda.amp import autocast
from elasticsearch_module import search_company
import torch.nn.functional as F
import pickle
from llm.prompts import LLM_PROMPT_QE, LLM_PROMPT_OLYMPIC, LLM_PROMPT_KEYS
from llm.vllm_api import LlmApi, LlmParams
global_data_path = os.environ.get("GLOBAL_DATA_PATH", "./legal_info_search_data/")
global_model_path = os.environ.get("GLOBAL_MODEL_PATH", "./models/20240202_204910_ep8")
data_path_consult = global_data_path + "external_data"
internal_docs_data_path = global_data_path + "nmd_full"
spec_internal_docs_data_path = global_data_path + "nmd_short"
accounting_data_path = global_data_path + "bu"
companies_map_path = global_data_path + "companies_map/companies_map.json"
dict_path = global_data_path + "dict/dict_20241030.pkl"
general_nmd_path = global_data_path + "companies_map/general_nmd.json"
consultations_dataset_path = global_data_path + "consult_data"
explanations_dataset_path = global_data_path + "explanations"
explanations_for_llm_path = global_data_path + "explanations_for_llm/explanations_for_llm.json"
rules_list_path = global_data_path + "rules_list/terms.txt"

db_data_types = ['НКРФ', 'ГКРФ', 'ТКРФ', 'Федеральный закон', 'Письмо Минфина', 'Письмо ФНС',
                 'Приказ ФНС', 'Постановление Правительства', 'Судебный документ', 'ВНД',
                 'Бухгалтерский документ']

device = os.environ.get("MODEL_DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')

# Hugging Face access token. If set, the model is loaded from HF.
hf_token = os.environ.get("HF_TOKEN", "")
hf_model_name = os.environ.get("HF_MODEL_NAME", "")

llm_api_endpoint = os.environ.get("LLM_API_ENDPOINT", "")
headers = {'Content-Type': 'application/json'}

def_k = 15
class SemanticSearch:
    def __init__(self, do_normalization: bool = True):
        self.device = device
        self.do_normalization = do_normalization
        self.load_model()
        # Main database indexes
        self.full_base_search = True
        self.index_consult = IndexFlatIP(self.embedding_dim)
        self.index_explanations = IndexFlatIP(self.embedding_dim)
        self.index_all_docs_with_accounting = IndexFlatIP(self.embedding_dim)
        self.index_internal_docs = IndexFlatIP(self.embedding_dim)
        self.spec_index_internal_docs = IndexFlatIP(self.embedding_dim)
        self.index_teaser = IndexFlatIP(self.embedding_dim)
        self.load_data()

        # Stack per-document embeddings and (optionally) L2-normalize them so that
        # inner-product search behaves like cosine similarity.
        def process_embeddings(docs):
            embeddings = torch.cat([torch.unsqueeze(torch.Tensor(x['doc_embedding']), 0) for x in docs], dim=0)
            if self.do_normalization:
                embeddings = F.normalize(embeddings, dim=-1)
            return embeddings.numpy()

        # Internal documents (ВНД) databases
        self.internal_docs_embeddings = process_embeddings(self.internal_docs)
        self.index_internal_docs.add(self.internal_docs_embeddings)
        self.spec_internal_docs_embeddings = process_embeddings(self.spec_internal_docs)
        self.spec_index_internal_docs.add(self.spec_internal_docs_embeddings)
        self.all_docs_with_accounting_embeddings = process_embeddings(self.all_docs_with_accounting)
        self.index_all_docs_with_accounting.add(self.all_docs_with_accounting_embeddings)
        # Consultations database
        self.consult_embeddings = process_embeddings(self.all_consultations)
        self.index_consult.add(self.consult_embeddings)
        # Explanations database
        self.explanations_embeddings = process_embeddings(self.all_explanations)
        self.index_explanations.add(self.explanations_embeddings)
    def get_main_info_with_llm(self, prompt: str):
        response = requests.post(
            url=llm_api_endpoint,
            headers=headers,
            json={'prompt': ' [INST] ' + prompt + ' [/INST]',
                  'temperature': 0.0,
                  'n_predict': 2500,
                  'top_p': 0.95,
                  'min_p': 0.05,
                  'repeat_penalty': 1.2,
                  'stop': []})
        answer = response.json()['content']
        return answer
    def rerank_by_avg_score(self, refs, scores_to_take=3):
        # Group chunk-level refs (suffix '_<n>') by their base document name and
        # rank documents by the mean of their top `scores_to_take` chunk scores.
        docs = {}
        regex = r'_(\d{1,3})$'
        refs = [(re.sub(regex, '', ref[0]), ref[1], float(ref[2])) for ref in refs]
        for ref in refs:
            if ref[0] not in docs:
                docs[ref[0]] = {'contents': [ref[1]], 'scores': [ref[2]]}
            elif len(docs[ref[0]]['scores']) < scores_to_take:
                docs[ref[0]]['contents'].append(ref[1])
                docs[ref[0]]['scores'].append(ref[2])
        for ref in docs:
            docs[ref]['avg_score'] = np.mean(docs[ref]['scores'])
        sorted_docs = sorted(docs.items(), key=lambda x: x[1]['avg_score'], reverse=True)
        result_refs = [ref[0] for ref in sorted_docs]
        return result_refs
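    # Illustrative call (made-up refs and scores):
    #   rerank_by_avg_score([('НКРФ ст.252_1', '...', 0.91),
    #                        ('НКРФ ст.252_2', '...', 0.87),
    #                        ('ГКРФ ст.10_4', '...', 0.80)])
    #   -> ['НКРФ ст.252', 'ГКРФ ст.10']  # averages 0.89 vs 0.80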
    async def olymp_think(self, query, sources, llm_params: LlmParams = None):
        sources_text = ''
        res = ''
        for i, source in enumerate(sources):
            sources_text += f'Источник [{i + 1}]: {sources[source]}\n'
        # If llm_params is not passed, fall back to Mixtral via the legacy endpoint
        # TODO: build an API wrapper for Mixtral (is it needed?)
        if llm_params is None:
            step = LLM_PROMPT_OLYMPIC.format(query=query, sources=sources_text)
            res = self.get_main_info_with_llm(step)
        else:
            llm_api = LlmApi(llm_params)
            query_for_trim = LLM_PROMPT_OLYMPIC.format(query=query, sources='')
            trimmed_sources_result = await llm_api.trim_sources(sources_text, query_for_trim)
            prompt = LLM_PROMPT_OLYMPIC.format(query=query, sources=trimmed_sources_result["result"])
            res = await llm_api.predict(prompt)
        return res
    def parse_step(self, text):
        # Extract the free-text comment between the '(4)' and '(5)' markers and
        # the source numbers listed in brackets after '(5)'.
        step4_start = text.find('(4)')
        if step4_start == -1:
            step4_start = 0
        step5_start = text.find('(5)')
        if step5_start == -1:
            step5_start = 0
        if step4_start + 3 < step5_start:
            extracted_comment = text[step4_start + 3:step5_start]
        else:
            extracted_comment = ''
        if '$$' in text:
            extracted_comment = ''
        extracted_content = re.findall(r'\[(.*?)\]', text[step5_start:])
        extracted_numbers = []
        for item in extracted_content:
            if item.isdigit():
                extracted_numbers.append(int(item))
        return extracted_comment, extracted_numbers
    def lemmatize_query(self, text):
        morph = pymorphy3.MorphAnalyzer()
        signs = ',.<>?;\'\":}{!)(][-'
        words = text.split()
        lemmas = []
        for word in words:
            if not word.isupper():
                word = morph.parse(word)[0].normal_form
            lemmas.append(word)
        # Trim leading/trailing punctuation from every lemma.
        for i, lemma in enumerate(lemmas):
            while lemma[0] in signs and len(lemma) > 1:
                lemma = lemma[1:]
                lemmas[i] = lemma
            while lemma[-1] in signs and len(lemma) > 1:
                lemma = lemma[:-1]
                lemmas[i] = lemma
        return " ".join(lemmas)
    def mark_for_one_word_dict(self, lem_dict):
        # A term may be matched by its first word alone only when that first word
        # is unique across the dictionary; otherwise all colliding terms are marked.
        terms_first_word = set()
        first_word_matching_names = {}
        first_word_names_to_remove = {}
        for name in lem_dict:
            first_word = name.split()[0]
            if first_word in terms_first_word:
                lem_dict[name]['one_word_searchable'] = False
                first_word_names_to_remove[first_word] = first_word_matching_names[first_word]
            else:
                terms_first_word.add(first_word)
                first_word_matching_names[first_word] = name
        for first_word in first_word_names_to_remove:
            name = first_word_names_to_remove[first_word]
            lem_dict[name]['one_word_searchable'] = False
        return lem_dict
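    # Illustrative marking (hypothetical terms): for {'налоговый агент',
    # 'налоговый вычет', 'акциз'} the first word 'налоговый' collides, so both
    # multi-word terms end up with one_word_searchable=False, while 'акциз'
    # remains findable by its single word.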
    def lemmatize_dict(self, terms_dict):
        lem_dict = {}
        morph = pymorphy3.MorphAnalyzer()
        for name in terms_dict:
            if not name.isupper():
                lem_name = morph.parse(name)[0].normal_form
            else:
                lem_name = name
            lem_dict[lem_name] = {
                'name': name,
                'definitions': terms_dict[name]['definitions'],
                'titles': terms_dict[name]['titles'],
                'sources': terms_dict[name]['sources'],
                'is_multi_def': terms_dict[name]['is_multi_def'],
                'one_word_searchable': True,
            }
        lem_dict = self.mark_for_one_word_dict(lem_dict)
        return lem_dict
    def separate_one_word_searchable_dict(self, lem_dict):
        lem_dict_fast = {}
        lem_dict_slow = {}
        copied_keys = ('name', 'definitions', 'titles', 'sources', 'is_multi_def')
        for name, entry in lem_dict.items():
            target = lem_dict_fast if entry['one_word_searchable'] else lem_dict_slow
            target[name] = {key: entry[key] for key in copied_keys}
        return lem_dict_fast, lem_dict_slow
    def extract_original_phrase(self, original_text, lemmatized_text, lemmatized_phrase):
        # Recover the original (non-lemmatized) wording of a phrase by aligning
        # word positions of the lemmatized text with the original text.
        words = original_text.split()
        words_lem = lemmatized_text.split()
        words_lem_phrase = lemmatized_phrase.split()
        for i, word in enumerate(words_lem):
            if word == words_lem_phrase[0]:
                words_full = ' '.join(words_lem[i:i + len(words_lem_phrase)])
                if words_full == lemmatized_phrase:
                    original_phrase = ' '.join(words[i:i + len(words_lem_phrase)])
                    return original_phrase
        return False
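    # Illustrative alignment (hypothetical query): with
    #   original_text='налоговые вычеты по НДФЛ'
    #   lemmatized_text='налоговый вычет по НДФЛ'
    #   lemmatized_phrase='налоговый вычет'
    # the phrase matches at word 0, so 'налоговые вычеты' is returned.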
    def substitute_definitions(self, original_text, lem_dict, lem_dict_fast, lem_dict_slow, for_llm=False):
        lemmatized_text = self.lemmatize_query(original_text)
        found_phrases = set()
        phrases_to_add1 = []
        phrases_to_add2 = []
        words = lemmatized_text.split()
        # Sorting by length could be hoisted outside to save a few milliseconds.
        sorted_lem_dict = sorted(lem_dict_slow.items(), key=lambda x: len(x[0]), reverse=True)
        for lemmatized_phrase_tuple in sorted_lem_dict:
            lemmatized_phrase = lemmatized_phrase_tuple[0]
            is_new_phrase = True
            is_one_word = True
            lem_phrase_words = lemmatized_phrase.split()
            if len(lem_phrase_words) > 1:
                is_one_word = False
            if lemmatized_phrase in lemmatized_text and not is_one_word:
                if lemmatized_phrase in found_phrases:
                    is_new_phrase = False
                else:
                    found_phrases.add(lemmatized_phrase)
                    original_phrase = self.extract_original_phrase(original_text, lemmatized_text, lemmatized_phrase)
                    phrases_to_add2.append((lemmatized_phrase, original_phrase))
            if is_one_word and lemmatized_phrase in words:
                for phrase in found_phrases:
                    if lemmatized_phrase in phrase:
                        is_new_phrase = False
                if is_new_phrase:
                    found_phrases.add(lemmatized_phrase)
                    original_phrase = self.extract_original_phrase(original_text, lemmatized_text, lemmatized_phrase)
                    phrases_to_add2.append((lemmatized_phrase, original_phrase))
        for word in words:
            is_new_phrase = True
            if word in lem_dict_fast:
                for phrase in found_phrases:
                    if word in phrase:
                        is_new_phrase = False
                        break
                if is_new_phrase:
                    found_phrases.add(word)
                    original_phrase = self.extract_original_phrase(original_text, lemmatized_text, word)
                    phrases_to_add1.append((word, original_phrase))
        phrases_to_add = phrases_to_add1 + phrases_to_add2
        definition_num = 0
        definitions_info = []
        substituted_text = original_text
        try:
            for term, original_phrase in phrases_to_add:
                if lem_dict[term]['is_multi_def']:
                    # Context-dependent selection of the right definition could go here.
                    definition_num = 0
                definitions = lem_dict[term]['definitions']
                definition = definitions[definition_num] if isinstance(definitions, list) else definitions
                if for_llm:
                    definitions_info.append(f"{term}-{definition}")
                else:
                    insert_at = substituted_text.find(original_phrase) + len(original_phrase)
                    substituted_text = (substituted_text[:insert_at]
                                        + f" ({definition})"
                                        + substituted_text[insert_at:])
            if for_llm and definitions_info:
                definitions_str = ", ".join(definitions_info)
                substituted_text = f"{original_text}. Дополнительная информация: {definitions_str}"
        except Exception as e:
            print(f'error processing\n {original_text}\n {term}: {e}')
        return substituted_text, phrases_to_add
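    # Illustrative substitution (hypothetical dictionary entry
    # 'аванс' -> 'предварительная оплата'):
    #   for_llm=False: 'как учесть аванс (предварительная оплата) у поставщика'
    #   for_llm=True:  'как учесть аванс у поставщика. Дополнительная информация:
    #                   аванс-предварительная оплата'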
    def filter_by_types(self,
                        pred: list[str] = None,
                        scores: list[float] = None,
                        indexes: list[int] = None,
                        docs_embeddings: list = None,
                        ctgs: dict = None):
        ctgs = [ctg for ctg in ctgs.keys() if ctgs[ctg]]
        filtered_pred, filtered_scores, filtered_indexes, filtered_docs_embeddings = [], [], [], []
        for doc_name, score, index, doc_embedding in zip(pred, scores, indexes, docs_embeddings):
            if ('ВНД' in doc_name and 'ВНД' in ctgs) or self.all_docs_with_accounting[index]['doc_type'] in ctgs:
                filtered_pred.append(doc_name)
                filtered_scores.append(score)
                filtered_indexes.append(index)
                filtered_docs_embeddings.append(doc_embedding)
        return filtered_pred, filtered_scores, filtered_indexes, filtered_docs_embeddings
    def get_types_of_docs(self, all_docs):
        def type_determiner(doc_name):
            # The recognized types mirror the module-level db_data_types list.
            for ctg in db_data_types:
                if ctg in doc_name:
                    return ctg
        for doc in all_docs:
            doc_type = type_determiner(doc['doc_name'])
            doc['doc_type'] = doc_type
        return all_docs
    def load_model(self):
        if hf_token and hf_model_name:
            self.tokenizer = AutoTokenizer.from_pretrained(hf_model_name, token=hf_token)
            self.model = AutoModel.from_pretrained(hf_model_name, token=hf_token).to(self.device)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(global_model_path)
            self.model = AutoModel.from_pretrained(global_model_path).to(self.device)
        self.max_len = self.tokenizer.max_len_single_sentence
        self.embedding_dim = self.model.config.hidden_size
    def load_data(self):
        with open(dict_path, "rb") as f:
            self.terms_dict = pickle.load(f)
        with open(companies_map_path, "r", encoding='utf-8') as f:
            self.companies_map = json.load(f)
        with open(general_nmd_path, "r", encoding='utf-8') as f:
            self.general_nmd = json.load(f)
        with open(explanations_for_llm_path, "r", encoding='utf-8') as f:
            self.explanations_for_llm = json.load(f)
        with open(rules_list_path, 'r', encoding='utf-8') as f:
            self.rules_list = f.read().splitlines()
        self.all_docs_info = dataset.load_from_disk(data_path_consult).to_list()  # ONLY EXTERNAL DOCS
        self.internal_docs = dataset.load_from_disk(internal_docs_data_path).to_list()
        self.accounting_docs = dataset.load_from_disk(accounting_data_path).to_list()
        self.spec_internal_docs = dataset.load_from_disk(spec_internal_docs_data_path).to_list()
        self.all_docs_with_accounting = self.all_docs_info + self.accounting_docs
        self.all_docs_with_accounting = self.get_types_of_docs(self.all_docs_with_accounting)
        # Per-type score weights; region- and rules-specific types are toggled at query time.
        self.type_weights_nu = {'НКРФ': 1,
                                'ТКРФ': 1,
                                'ГКРФ': 1,
                                'Письмо Минфина': 0.9,
                                'Письмо ФНС': 0.6,
                                'Приказ ФНС': 1,
                                'Постановление Правительства': 1,
                                'Федеральный закон': 0.9,
                                'Судебный документ': 0.2,
                                'ВНД': 0.2,
                                'Бухгалтерский документ': 0.7,
                                'Закон Красноярского края': 1.2,
                                'Правила заполнения': 1.2,
                                'Правила ведения': 1.2}
        self.all_consultations = dataset.load_from_disk(consultations_dataset_path).to_list()
        self.all_explanations = dataset.load_from_disk(explanations_dataset_path).to_list()
    def remove_duplicate_paragraphs(self, paragraphs):
        unique_paragraphs = []
        seen = set()
        for paragraph in paragraphs:
            stripped_paragraph = paragraph.strip()
            if stripped_paragraph and stripped_paragraph not in seen:
                unique_paragraphs.append(paragraph)
                seen.add(stripped_paragraph)
        return '\n'.join(unique_paragraphs)
    def construct_base(self, idx_list, base):
        # Stitch each hit chunk together with up to two neighboring chunks of the
        # same document, joining on overlapping text, then add ellipses around
        # segments that start or end mid-sentence.
        concatenated_text = ""
        seen_ids = set()
        pattern = re.compile(r'_(\d{1,3})')

        def find_overlap(a: str, b: str) -> int:
            max_overlap = min(len(a), len(b))
            for i in range(max_overlap, 0, -1):
                if a[-i:] == b[:i]:
                    return i
            return 0

        def add_ellipsis(text: str) -> str:
            if not text:
                return text
            segments = text.split('\n\n')
            processed_segments = []
            for segment in segments:
                if segment and not (
                        segment[0].isupper() or segment[0].isdigit() or segment[0] in ['•', '-', '—', '.']):
                    segment = '...' + segment
                if segment and not (segment.endswith('.') or segment.endswith(';')):
                    segment += '...'
                processed_segments.append(segment)
            return '\n\n'.join(processed_segments)

        for current_index in idx_list:
            if current_index in seen_ids:
                continue
            start_index = max(0, current_index - 2)
            end_index = min(len(base), current_index + 3)
            current_name_base = pattern.sub('', base[current_index]['doc_name'])
            current_doc_text = base[current_index]['doc_text']
            texts_to_concatenate = [current_doc_text]
            # Extend backwards over preceding chunks of the same document.
            for i in range(current_index - 1, start_index - 1, -1):
                if i in seen_ids:
                    continue
                surrounding_name_base = pattern.sub('', base[i]['doc_name'])
                if current_name_base != surrounding_name_base:
                    break
                surrounding_text = base[i]['doc_text']
                overlap_length = find_overlap(surrounding_text, texts_to_concatenate[0])
                if overlap_length == 0:
                    break
                new_text = surrounding_text + texts_to_concatenate[0][overlap_length:]
                texts_to_concatenate[0] = new_text
                seen_ids.add(i)
            # Extend forwards over following chunks of the same document.
            for i in range(current_index + 1, end_index):
                if i in seen_ids:
                    continue
                surrounding_name_base = pattern.sub('', base[i]['doc_name'])
                if current_name_base != surrounding_name_base:
                    break
                surrounding_text = base[i]['doc_text']
                overlap_length = find_overlap(texts_to_concatenate[-1], surrounding_text)
                if overlap_length == 0:
                    break
                new_text = texts_to_concatenate[-1] + surrounding_text[overlap_length:]
                texts_to_concatenate[-1] = new_text
                seen_ids.add(i)
            combined_text = ' '.join(texts_to_concatenate)
            concatenated_text += combined_text + '\n\n'
            seen_ids.add(current_index)
        concatenated_text = add_ellipsis(concatenated_text)
        return concatenated_text.rstrip('\n')
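    # Illustrative stitching (hypothetical chunks of one letter): if chunk N ends
    # with '...применяется упрощенная' and chunk N+1 begins with 'упрощенная
    # система...', find_overlap() detects the shared 'упрощенная' and the chunks
    # are merged without duplicating it; segments that still start or end
    # mid-sentence get '...' from add_ellipsis().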
    def search_results_multiply_weights(self,
                                        pred: list[str] = None,
                                        scores: list[float] = None,
                                        indexes: list[int] = None,
                                        docs_embeddings: list = None) -> tuple[list[str], list[float], list[int], list]:
        if pred is None or scores is None or indexes is None or docs_embeddings is None:
            return [], [], [], []
        weights = self.type_weights_nu
        weighted_scores = [(weights.get(ctg, 0) * score, prediction, idx, emb)
                           for prediction, score, idx, emb in zip(pred, scores, indexes, docs_embeddings)
                           for ctg in weights if ctg in prediction]
        weighted_scores.sort(reverse=True, key=lambda x: x[0])
        if weighted_scores:
            sorted_scores, sorted_preds, sorted_indexes, sorted_docs_embeddings = zip(*weighted_scores)
        else:
            sorted_scores, sorted_preds, sorted_indexes, sorted_docs_embeddings = [], [], [], []
        return list(sorted_preds), list(sorted_scores), list(sorted_indexes), list(sorted_docs_embeddings)
    def get_uniq_relevant_docs(self,
                               top_k: int,
                               query_refs_all: list[str],
                               scores: list[float],
                               indexes: list[int],
                               docs_embeddings: list[list[float]]
                               ) -> tuple[dict[str, list[str]], dict[str, list[float]],
                                          dict[str, list[int]], dict[str, list[list[float]]]]:
        regex = r'_\d{1,3}'
        base_ref_dict = {}
        for i, ref in enumerate(query_refs_all):
            base_ref = re.sub(regex, '', ref).strip()
            if base_ref not in base_ref_dict:
                if len(base_ref_dict) >= top_k:
                    continue
                base_ref_dict[base_ref] = {
                    'refs': [],
                    'scores': [],
                    'indexes': [],
                    'embeddings': []
                }
            base_ref_dict[base_ref]['refs'].append(ref)
            base_ref_dict[base_ref]['scores'].append(scores[i])
            base_ref_dict[base_ref]['indexes'].append(indexes[i])
            base_ref_dict[base_ref]['embeddings'].append(docs_embeddings[i])

        def get_suffix_number(ref: str):
            match = re.findall(regex, ref)
            if match:
                return int(match[0].replace('_', ''))
            return None

        def sort_key(item):
            suffix = get_suffix_number(item[0])
            return (0 if suffix is None else 1, suffix if suffix is not None else -1)

        for base_ref, data in base_ref_dict.items():
            combined = list(zip(data['refs'], data['scores'], data['indexes'], data['embeddings']))
            combined_sorted = sorted(combined, key=sort_key)
            sorted_refs, sorted_scores, sorted_indexes, sorted_embeddings = zip(*combined_sorted)
            base_ref_dict[base_ref]['refs'] = list(sorted_refs)[:20]
            base_ref_dict[base_ref]['scores'] = list(sorted_scores)[:20]
            base_ref_dict[base_ref]['indexes'] = list(sorted_indexes)[:20]
            base_ref_dict[base_ref]['embeddings'] = list(sorted_embeddings)[:20]
        unique_refs = {k: v['refs'] for k, v in base_ref_dict.items()}
        filtered_scores = {k: v['scores'] for k, v in base_ref_dict.items()}
        filtered_indexes = {k: v['indexes'] for k, v in base_ref_dict.items()}
        filtered_docs_embeddings = {k: v['embeddings'] for k, v in base_ref_dict.items()}
        return unique_refs, filtered_scores, filtered_indexes, filtered_docs_embeddings
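    # Illustrative grouping (hypothetical refs): ['НКРФ ст.146_3', 'НКРФ ст.146_1',
    # 'ГКРФ ст.432'] yields up to top_k base documents, each with its chunk refs
    # re-ordered by suffix number:
    #   {'НКРФ ст.146': ['НКРФ ст.146_1', 'НКРФ ст.146_3'],
    #    'ГКРФ ст.432': ['ГКРФ ст.432']}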
    def filter_results(self, pred_internal, scores_internal, indices_internal, docs_embeddings_internal,
                       companies_files):
        # Keep internal docs that are either company-independent (general NMD)
        # or belong to one of the companies detected in the query.
        filt_pred_internal, filt_scores_internal, \
            filt_indices_internal, filt_docs_embeddings_internal = [], [], [], []

        def add_data(pred, ind, score, emb):
            filt_pred_internal.append(pred)
            filt_indices_internal.append(ind)
            filt_scores_internal.append(score)
            filt_docs_embeddings_internal.append(emb)

        for pred, score, ind, emb in zip(pred_internal, scores_internal, indices_internal, docs_embeddings_internal):
            if any(doc in pred for doc in self.general_nmd):
                add_data(pred, ind, score, emb)
                continue
            for company in companies_files:
                if company in pred:
                    add_data(pred, ind, score, emb)
        return filt_pred_internal, filt_scores_internal, filt_indices_internal, filt_docs_embeddings_internal
    def merge_dictionaries(self, dicts: list = None):
        # Interleave entries from several dicts: first keys of each dict, then
        # second keys, and so on (later duplicates overwrite earlier ones).
        merged_dict = {}
        max_length = max(len(d) for d in dicts)
        for i in range(max_length):
            for d in dicts:
                keys = list(d.keys())
                values = list(d.values())
                if i < len(keys):
                    merged_dict[keys[i]] = values[i]
        return merged_dict
    def check_specific_key(self, dictionary, key):
        # True only when `key` is the single truthy flag in `dictionary`.
        if key in dictionary and dictionary[key] is True:
            for k, v in dictionary.items():
                if k != key and v is True:
                    return False
            return True
        return False
    def remove_duplicates(self, input_list):
        # Deduplicate while preserving order (dicts keep insertion order).
        return list(dict.fromkeys(input_list))
    async def search_engine(self,
                            query: str = None,
                            use_qe: bool = False,
                            categories: dict = None,
                            llm_params: LlmParams = None):
        # Choose how many chunks/references to take from the internal (ВНД) and
        # external bases depending on which category checkboxes are ticked.
        if True in list(categories.values()) and not all(categories.values()):
            self.full_base_search = False
            if self.check_specific_key(categories, 'ВНД'):
                nmd_chunks = 120
                nmd_refs = 45
                extra_chunks = 1
                extra_refs = 1
            elif not categories['ВНД']:
                extra_chunks = 120
                extra_refs = 45
                nmd_chunks = 1
                nmd_refs = 1
            else:
                nmd_chunks = 60
                nmd_refs = 23
                extra_chunks = 60
                extra_refs = 23
        else:
            self.full_base_search = True
            nmd_chunks = 50
            nmd_refs = 15
            extra_chunks = 75
            extra_refs = 30
        # LLM responses to send to the frontend
        llm_responses = []
        # Query tokenization and embedding
        query_tokens = query_tokenization(query, self.tokenizer)
        query_embeds = query_embed_extraction(query_tokens, self.model, self.do_normalization)
        # Search over the external documents database
        distances, indices = self.index_all_docs_with_accounting.search(query_embeds, len(self.all_docs_with_accounting))
        pred = [self.all_docs_with_accounting[x]['doc_name'] for x in indices[0]]
        docs_embeddings = [self.all_docs_with_accounting[x]['doc_embedding'] for x in indices[0]]
        # Region- and rules-specific document types only count when the query
        # actually mentions them; otherwise their weight is zeroed out.
        if not re.search('[Кк]расноярск', query):
            self.type_weights_nu['Закон Красноярского края'] = 0
        else:
            self.type_weights_nu['Закон Красноярского края'] = 1.2
        if not use_rules(query, self.rules_list):
            self.type_weights_nu['Правила ведения'] = 0
            self.type_weights_nu['Правила заполнения'] = 0
        else:
            self.type_weights_nu['Правила ведения'] = 1.2
            self.type_weights_nu['Правила заполнения'] = 1.2
        preds, scores, indexes, docs_embeddings = pred[:5000], list(distances[0])[:5000], \
            list(indices[0])[:5000], docs_embeddings[:5000]
        # Search over the internal documents database
        if self.full_base_search or categories['ВНД']:
            # Note: the index built over internal_docs is searched, but names and
            # embeddings are taken from spec_internal_docs; the two collections
            # are assumed to be index-aligned (nmd_full vs nmd_short).
            distances_internal, indices_internal = self.index_internal_docs.search(
                query_embeds, len(self.spec_internal_docs))
            pred_internal = [self.spec_internal_docs[x]['doc_name'] for x in indices_internal[0]]
            docs_embeddings_internal = [self.spec_internal_docs[x]['doc_embedding'] for x in indices_internal[0]]
            indices_internal = indices_internal[0]
            # Boost accounting-policy (КУП) documents.
            scores_internal = []
            for title, score in zip(pred_internal, distances_internal[0]):
                if 'КУП' in title:
                    scores_internal.append(score * 1.2)
                else:
                    scores_internal.append(score)
            companies_files = search_company.find_nmd_docs(query, self.companies_map)
            pred_internal, scores_internal, indices_internal, docs_embeddings_internal = self.filter_results(
                pred_internal, scores_internal, indices_internal, docs_embeddings_internal, companies_files)
            combined = list(zip(pred_internal, scores_internal, indices_internal, docs_embeddings_internal))
            sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
            top_nmd = sorted_combined[:nmd_chunks]
            # Hard-pin two КУП sections for ЕГДС-related queries if the search missed them.
            if 'ЕГДС' in query:
                if not [x for x in top_nmd if 'п.5. Положение о КУП_262 (ВНД)' in x[0]]:
                    ch262 = self.internal_docs[22976]
                    ch262 = (ch262['doc_name'], 1.0, 22976, ch262['chunks_embeddings'][0])
                    top_nmd.insert(0, ch262)
                if not [x for x in top_nmd if 'п.5. Положение о КУП_130 (ВНД)' in x[0]]:
                    ch130 = self.internal_docs[22844]
                    ch130 = (ch130['doc_name'], 1.0, 22844, ch130['chunks_embeddings'][0])
                    top_nmd.insert(1, ch130)
                top_nmd = top_nmd[:nmd_chunks]
            preds_internal, scores_internal, indexes_internal, internal_docs_embeddings = map(list, zip(*top_nmd))
            # Collect unique internal documents
            preds_internal, scores_internal, indexes_internal, internal_docs_embeddings = \
                self.get_uniq_relevant_docs(
                    top_k=nmd_refs,
                    query_refs_all=preds_internal,
                    scores=scores_internal,
                    indexes=indexes_internal,
                    docs_embeddings=internal_docs_embeddings)
        # Filter by categories (or not) depending on the ticked checkboxes
        if not self.full_base_search:
            preds, scores, indexes, docs_embeddings = self.filter_by_types(preds, scores, indexes,
                                                                           docs_embeddings, categories)
        # Apply type weights on top of the scores
        sorted_preds, sorted_scores, sorted_indexes, sorted_docs_embeddings = self.search_results_multiply_weights(
            pred=preds,
            scores=scores,
            indexes=indexes,
            docs_embeddings=docs_embeddings)
        sorted_preds, sorted_scores, sorted_indexes, sorted_docs_embeddings = sorted_preds[:extra_chunks], \
            sorted_scores[:extra_chunks], \
            sorted_indexes[:extra_chunks], \
            sorted_docs_embeddings[:extra_chunks]
        # Collect unique external documents
        preds, scores, indexes, docs_embeddings = self.get_uniq_relevant_docs(
            top_k=extra_refs,
            query_refs_all=sorted_preds,
            scores=sorted_scores,
            indexes=sorted_indexes,
            docs_embeddings=sorted_docs_embeddings
        )
        if use_qe:
            # Query expansion: ask the LLM for key terms, run a second search
            # with them, and interleave the two result sets.
            try:
                prompt = LLM_PROMPT_KEYS.format(query=query)
                if llm_params is None:
                    keyword_query = self.get_main_info_with_llm(prompt)
                else:
                    llm_api = LlmApi(llm_params)
                    keyword_query = await llm_api.predict(prompt)
                llm_responses.append(keyword_query)
                keyword_query = re.sub(r'\[1\].*?(?=\[\d+\]|$)', '', keyword_query,
                                       flags=re.DOTALL).replace(' [2]', '').replace('[3]', '').strip()
                keyword_query_tokens = query_tokenization(keyword_query, self.tokenizer)
                keyword_query_embeds = query_embed_extraction(keyword_query_tokens,
                                                              self.model,
                                                              self.do_normalization)
                keyword_distances, keyword_indices = self.index_all_docs_with_accounting.search(
                    keyword_query_embeds, len(self.all_docs_with_accounting))
                keyword_pred = [self.all_docs_with_accounting[x]['doc_name'] for x in keyword_indices[0]]
                keyword_docs_embeddings = [self.all_docs_with_accounting[x]['doc_embedding']
                                           for x in keyword_indices[0]]
                if not self.full_base_search:
                    keyword_preds, keyword_scores, \
                        keyword_indexes, keyword_docs_embeddings = self.filter_by_types(keyword_pred,
                                                                                        keyword_distances[0],
                                                                                        keyword_indices[0],
                                                                                        keyword_docs_embeddings,
                                                                                        categories)
                else:
                    keyword_preds, keyword_scores, \
                        keyword_indexes, keyword_docs_embeddings = keyword_pred, keyword_distances[0], \
                        keyword_indices[0], keyword_docs_embeddings
                keyword_preds, keyword_scores, \
                    keyword_indexes, keyword_docs_embeddings = self.search_results_multiply_weights(
                        pred=keyword_preds, scores=keyword_scores,
                        indexes=keyword_indexes, docs_embeddings=keyword_docs_embeddings)
                keyword_unique_preds, keyword_unique_scores, \
                    keyword_unique_indexes, keyword_unique_docs_embeddings = self.get_uniq_relevant_docs(
                        top_k=45,
                        query_refs_all=keyword_preds,
                        scores=keyword_scores,
                        indexes=keyword_indexes,
                        docs_embeddings=keyword_docs_embeddings)
                preds = dict(list(self.merge_dictionaries([preds, keyword_unique_preds]).items())[:30])
                scores = dict(list(self.merge_dictionaries([scores, keyword_unique_scores]).items())[:30])
                indexes = dict(list(self.merge_dictionaries([indexes, keyword_unique_indexes]).items())[:30])
            except Exception:
                traceback.print_exc()
                print("Error applying keys (possibly the LLM is not available)")
        if self.full_base_search or categories['ВНД']:
            # Merge the top internal documents into the output
            preds = self.merge_dictionaries([preds, preds_internal])
            scores = self.merge_dictionaries([scores, scores_internal])
            indexes = self.merge_dictionaries([indexes, indexes_internal])
        # Assemble readable chunks for the LLM
        texts_for_llm, docs, teasers = [], [], []
        for key, idx_list in indexes.items():
            collected_text = []
            if 'ВНД' in key:
                base = self.internal_docs
            else:
                base = self.all_docs_with_accounting
            if re.search('Минфин|Бухгалтерский документ|ФНС|Судебный документ|Постановление Правительства|Федеральный закон', key):
                text = self.construct_base(idx_list, base)
                collected_text.append(text)
            else:
                for idx in idx_list:
                    if idx < len(base):
                        for text in base[idx]['doc_text'].split('\n'):
                            collected_text.append(text)
            collected_text = self.remove_duplicate_paragraphs(collected_text)
            texts_for_llm.append(collected_text)
        # Search for relevant consultations
        distances_consult, indices_consult = self.index_consult.search(query_embeds, len(self.all_consultations))
        predicted_consultations = {self.all_consultations[x]['doc_name']: self.all_consultations[x]['doc_text']
                                   for x in indices_consult[0]}
        # Search for relevant explanations
        distances_explanations, indices_explanations = self.index_explanations.search(query_embeds, len(self.all_explanations))
        predicted_explanations = {self.all_explanations[x]['doc_name']: self.all_explanations[x]['doc_text']
                                  for x in indices_explanations[0]}
        results = list(zip(list(predicted_explanations.keys()),
                           list(predicted_explanations.values()),
                           distances_explanations[0]))
        explanation_titles = self.rerank_by_avg_score(results)[:3]
        try:
            predicted_explanation = {explanation_title: self.explanations_for_llm[explanation_title]
                                     for explanation_title in explanation_titles}
        except KeyError:
            predicted_explanation = {}
            print('The relevant document was not found in the system.')
        return (query,
                [x.replace('ФЕДЕРАЛЬНЫЙ СТАНДАРТ БУХГАЛТЕРСКОГО УЧЕТА',
                           'Федеральный стандарт бухгалтерского учета ФСБУ') for x in list(preds.keys())],
                texts_for_llm,
                dict(list(predicted_consultations.items())[:def_k]),
                predicted_explanation,
                llm_responses)
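    # search_engine() returns a 6-tuple: the (possibly rewritten) query, the
    # ranked document names, the stitched texts for the LLM, the top-def_k
    # consultations, up to three explanations, and any raw LLM responses
    # produced along the way.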
    async def olympic_branch(self,
                             query: str = None,
                             sources: dict = None,
                             categories: dict = None,
                             llm_params: LlmParams = None):
        # Collect all LLM responses to send to the frontend
        llm_responses = []
        text = await self.olymp_think(query, sources, llm_params)
        llm_responses.append(text)
        saved_sources = {}
        saved_step_by_step = []
        comment1, sources_choice = self.parse_step(text)
        # The LLM numbers sources from 1; the dict is indexed from 0.
        sources_choice = [source - 1 for source in sources_choice]
        for idx, ref in enumerate(sources):
            if idx in sources_choice and ref not in saved_sources:
                saved_sources.update({ref: sources[ref]})
        # An empty first comment means the LLM needs no further iterations.
        if comment1 == '':
            count = 4
        else:
            count = 0
        while count < 4:
            query, preds, \
                texts_for_llm, predicted_consultations, \
                predicted_explanation, skip_llm_responses = await self.search_engine(query, use_qe=False,
                                                                                     categories=categories)
            sources = dict(zip(preds, texts_for_llm))
            sources = dict(islice(sources.items(), 20))
            text = await self.olymp_think(query, sources, llm_params)
            llm_responses.append(text)
            comment2, sources_choice = self.parse_step(text)
            sources_choice = [source - 1 for source in sources_choice]
            saved_step_by_step.append(sources_choice)
            for idx, ref in enumerate(sources):
                if idx in sources_choice and ref not in saved_sources:
                    saved_sources.update({ref: sources[ref]})
            if comment2 == '':
                break
            comment1 = comment2
            count += 1
        return saved_sources, saved_step_by_step, llm_responses
    async def search(self,
                     query: str = None,
                     use_qe: bool = False,
                     use_olympic: bool = False,
                     categories: dict = None,
                     llm_params: LlmParams = None):
        # Query transformation: inject dictionary definitions into the query
        lem_dict = self.lemmatize_dict(self.terms_dict)
        lem_dict_fast, lem_dict_slow = self.separate_one_word_searchable_dict(lem_dict)
        query_for_llm, _ = self.substitute_definitions(query, lem_dict, lem_dict_fast, lem_dict_slow, for_llm=True)
        query, _ = self.substitute_definitions(query, lem_dict, lem_dict_fast, lem_dict_slow, for_llm=False)
        # Base search
        query, base_preds, base_texts_for_llm, \
            predicted_consultations, predicted_explanation, llm_responses = await self.search_engine(
                query, use_qe, categories, llm_params)
        if use_olympic:
            sources = dict(zip(base_preds, base_texts_for_llm))
            sources = dict(islice(sources.items(), 20))
            olymp_results, olymp_step_by_step, llm_responses = await self.olympic_branch(query, sources,
                                                                                         categories, llm_params)
            olymp_preds, olymp_texts_for_llm = list(olymp_results.keys()), list(olymp_results.values())
            if len(olymp_preds) <= 45:
                preds = olymp_preds + base_preds
                preds = self.remove_duplicates(preds)[:45]
                texts_for_llm = olymp_texts_for_llm + base_texts_for_llm
                texts_for_llm = self.remove_duplicates(texts_for_llm)[:45]
                return query_for_llm, preds, texts_for_llm, predicted_consultations, predicted_explanation, llm_responses
            else:
                # More than 45 sources survived: truncate the olympic selection itself.
                olymp_results = dict(islice(olymp_results.items(), 45))
                preds, texts_for_llm = list(olymp_results.keys()), list(olymp_results.values())
                return query_for_llm, preds, texts_for_llm, predicted_consultations, predicted_explanation, llm_responses
        else:
            return query_for_llm, base_preds, base_texts_for_llm, predicted_consultations, predicted_explanation, llm_responses
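

# Minimal usage sketch (assumes the data and model paths above exist and, for
# use_qe/use_olympic, a configured LLM endpoint; the query and category flags
# below are illustrative only).
if __name__ == "__main__":
    import asyncio

    searcher = SemanticSearch()
    categories = {doc_type: True for doc_type in db_data_types}
    result = asyncio.run(searcher.search(
        query="Какие расходы уменьшают налог на прибыль?",
        use_qe=False,
        use_olympic=False,
        categories=categories))
    query_for_llm, preds, texts_for_llm, consultations, explanation, llm_responses = result
    print(query_for_llm)
    print(preds[:5])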