import os
from _utils.langchain_utils.LLM_class import LLM
from typing import Any, List, Dict, Tuple, Optional, Union, cast
from anthropic import Anthropic, AsyncAnthropic
import logging
from langchain.schema import Document
from llama_index import Document as Llama_Index_Document
import asyncio
from dataclasses import dataclass
from _utils.gerar_documento_utils.llm_calls import (
    aclaude_answer,
    agemini_answer,
    agpt_answer,
)
from _utils.gerar_documento_utils.prompts import contextual_prompt
from _utils.models.gerar_documento import (
    ContextualizedChunk,
    DocumentChunk,
    RetrievalConfig,
)
from langchain_core.messages import HumanMessage
from gerar_documento.serializer import (
    GerarDocumentoComPDFProprioSerializerData,
    GerarDocumentoSerializerData,
)
from setup.logging import Axiom
import re


class ContextualRetriever:
    """Adds LLM-generated context to document chunks before retrieval.

    Chunks are processed in batches of up to 20: each batch is sent to the
    LLM in a single prompt and the response is parsed/validated into one
    (document_id, title, context) triple per chunk.
    """

    def __init__(
        self,
        serializer: Union[
            GerarDocumentoSerializerData, GerarDocumentoComPDFProprioSerializerData, Any
        ],
    ):
        # One element is appended per processed batch; len() is the batch counter.
        self.lista_contador: List[int] = []
        self.contextual_retriever_utils = ContextualRetrieverUtils()
        self.config = RetrievalConfig(
            num_chunks=serializer.num_chunks_retrieval,
            embedding_weight=serializer.embedding_weight,
            bm25_weight=serializer.bm25_weight,
            context_window=serializer.context_window,
            chunk_overlap=serializer.chunk_overlap,
        )
        self.logger = logging.getLogger(__name__)
        self.bm25 = None
        self.claude_context_model = serializer.claude_context_model
        self.claude_api_key = os.environ.get("CLAUDE_API_KEY", "")
        self.claude_client = AsyncAnthropic(api_key=self.claude_api_key)

    async def llm_call_uma_lista_de_20_chunks(
        self,
        lista_com_20_chunks: List[DocumentChunk],
        resumo_auxiliar,
        axiom_instance: Axiom,
    ) -> List[List[str]]:
        """Asks the LLM for per-chunk context for one batch of up to 20 chunks.

        Retries up to 4 times when the LLM response fails format validation.
        Returns a list of [document_id, title, context] lists on success, or
        ``[[""]]`` when every attempt (or the whole call) fails — callers
        treat that sentinel as "no context available".
        """
        utils = self.contextual_retriever_utils
        all_chunks_contents, all_document_ids = (
            utils.get_all_document_ids_and_contents(lista_com_20_chunks)
        )
        send_axiom = axiom_instance.send_axiom
        try:
            prompt = contextual_prompt(
                resumo_auxiliar, all_chunks_contents, len(lista_com_20_chunks)
            )
            result = None
            response = ""  # keeps the failure log below safe even if the loop is skipped
            for attempt in range(4):
                if attempt != 0:
                    send_axiom(
                        f"------------- FORMATAÇÃO DO CONTEXTUAL INCORRETA - TENTANDO NOVAMENTE (TENTATIVA: {attempt + 1}) -------------"
                    )
                send_axiom(
                    f"COMEÇANDO UMA REQUISIÇÃO DO CONTEXTUAL - TENTATIVA {attempt + 1}"
                )
                raw_response = await agemini_answer(prompt, "gemini-2.0-flash-lite")
                response = cast(str, raw_response)
                send_axiom(
                    f"TERMINOU UMA REQUISIÇÃO DO CONTEXTUAL - TENTATIVA {attempt + 1}"
                )
                matches = utils.validate_many_chunks_in_one_request(
                    response, all_document_ids
                )
                if matches:
                    send_axiom(
                        f"VALIDAÇÃO DO CONTEXTUAL FUNCIONOU NA TENTATIVA {attempt + 1} (ou seja, a função validate_many_chunks_in_one_request)"
                    )
                    result = utils.get_info_from_validated_chunks(matches)
                    break
            if result is None:
                axiom_instance.send_axiom(
                    f"-------------- UMA LISTA DE 20 CHUNKS FALHOU AS 4x NA FORMATAÇÃO ------------- ÚLTIMO RETORNO ERRADO: {response}"
                )
                result = [[""]]  # default value if no iteration succeeded
            return result
        except Exception as e:
            self.logger.error(f"Context generation failed for chunks .... : {str(e)}")
            return [[""]]

    async def contextualize_uma_lista_de_20_chunks(
        self,
        lista_com_20_chunks: List[DocumentChunk],
        response_auxiliar_summary,
        axiom_instance: Axiom,
    ) -> List[ContextualizedChunk]:
        """Builds ContextualizedChunk objects for one batch of up to 20 chunks.

        On any per-chunk failure (e.g. the LLM returned the ``[[""]]``
        sentinel, so ``result[index]`` has no fields) the error is reported
        and the chunks contextualized so far are returned.
        """
        self.lista_contador.append(0)
        print("contador: ", len(self.lista_contador))
        result = await self.llm_call_uma_lista_de_20_chunks(
            lista_com_20_chunks, response_auxiliar_summary, axiom_instance
        )
        lista_chunks: List[ContextualizedChunk] = []
        try:
            for index, chunk in enumerate(lista_com_20_chunks):
                lista_chunks.append(
                    ContextualizedChunk(
                        contextual_summary=result[index][2],
                        content=chunk.content,
                        page_number=chunk.page_number,
                        id_do_processo=int(result[index][0]),
                        chunk_id=chunk.chunk_id,
                        start_char=chunk.start_char,
                        end_char=chunk.end_char,
                        context=result[index][1],
                    )
                )
        except Exception as e:
            # Was `except BaseException`, which would also swallow
            # KeyboardInterrupt/SystemExit and asyncio.CancelledError —
            # the latter must propagate for task cancellation to work.
            axiom_instance.send_axiom(
                f"ERRO EM UMA LISTA COM 20 CHUNKS CONTEXTUALS --------- lista: {lista_com_20_chunks} ------------ ERRO: {e}"
            )
        return lista_chunks

    async def contextualize_all_chunks(
        self,
        all_PDFs_chunks: List[DocumentChunk],
        response_auxiliar_summary,
        axiom_instance: Axiom,
    ) -> List[ContextualizedChunk]:
        """Add context to all chunks, processing batches of 20 concurrently."""
        lista_de_listas_cada_com_20_chunks = (
            self.contextual_retriever_utils.get_lista_de_listas_cada_com_20_chunks(
                all_PDFs_chunks
            )
        )
        # TaskGroup (3.11+) runs every batch concurrently and waits for all
        # of them before the `async with` block exits.
        async with asyncio.TaskGroup() as tg:
            tasks = [
                tg.create_task(
                    self.contextualize_uma_lista_de_20_chunks(
                        lista_com_20_chunks, response_auxiliar_summary, axiom_instance
                    )
                )
                for lista_com_20_chunks in lista_de_listas_cada_com_20_chunks
            ]
        contextualized_chunks: List[ContextualizedChunk] = []
        for task in tasks:
            contextualized_chunks += task.result()
        axiom_instance.send_axiom(
            "TERMINOU COM SUCESSO DE PROCESSAR TUDO DOS CONTEXTUALS"
        )
        return contextualized_chunks


@dataclass
class ContextualRetrieverUtils:
    """Stateless helpers for batching chunks and parsing LLM responses."""

    def get_all_document_ids_and_contents(
        self, lista_com_20_chunks: List[DocumentChunk]
    ) -> Tuple[str, List[int]]:
        """Concatenates chunk contents into one prompt body and extracts the
        document id ("Num. NNN") embedded in each chunk's text (0 when absent).
        """
        all_chunks_contents = ""
        all_document_ids: List[int] = []
        for contador, chunk in enumerate(lista_com_20_chunks, start=1):
            all_chunks_contents += f"\n\nCHUNK {contador}:\n"
            all_chunks_contents += chunk.content
            match = re.search(r"Num\. (\d+)", chunk.content)
            all_document_ids.append(int(match.group(1)) if match else 0)
        return all_chunks_contents, all_document_ids

    def get_info_from_validated_chunks(self, matches) -> List[List[Any]]:
        """Normalises (doc_id, title, content) tuples into
        [int(doc_id), stripped title, stripped content] lists."""
        return [
            [int(doc_id), title.strip(), content.strip()]
            for doc_id, title, content in matches
        ]

    def get_lista_de_listas_cada_com_20_chunks(
        self, all_PDFs_chunks: List[DocumentChunk]
    ) -> List[List[DocumentChunk]]:
        """Splits the chunk list into batches of 20 (last batch may be shorter)."""
        return [all_PDFs_chunks[i : i + 20] for i in range(0, len(all_PDFs_chunks), 20)]

    def validate_many_chunks_in_one_request(
        self, response: str, lista_de_document_ids: List[int]
    ):
        """Validates the LLM batch response against the expected document ids.

        Returns up to 20 (doc_id, title, content) tuples, or ``False`` when
        the response does not match any known format.
        """
        # Strip "document_id" labels the LLM sometimes prepends.
        context = (
            response.replace("document_id: ", "")
            .replace("document_id:", "")
            .replace("DOCUMENT_ID: ", "")
        )
        # pattern = r"\[(\d+|[-.]+)\] --- (.+?) --- (.+?)"  # worked when the
        # LLM reply does not spell out "document_id"
        matches = self.check_regex_patterns(context, lista_de_document_ids)
        if not matches:
            print(
                "----------- ERROU NA TENTATIVA ATUAL DE FORMATAR O CONTEXTUAL -----------"
            )
            return False
        matches_as_list = []
        for index, match in enumerate(list(matches)):
            if index >= 20:
                break
            # The id parsed from the LLM output is deliberately discarded:
            # the trusted id is taken positionally from the original chunks.
            matches_as_list.append((lista_de_document_ids[index], match[1], match[2]))
        return matches_as_list

    def check_regex_patterns(self, context: str, lista_de_document_ids: List[int]):
        """Tries several regex layouts for the "[id] --- [title] --- [content]"
        response format; falls back to a delimiter split.

        Returns a list of (doc_id_str, title, content) tuples, or ``None``
        when nothing matched (callers treat any falsy value as failure).
        """
        patterns = [
            r"\[(.*?)\] --- \[(.*?)\] --- \[(.*?)\](?=\n|\s*$)",
            r"\[([^\[\]]+?)\]\s*---\s*\[([^\[\]]+?)\]\s*---\s*(.*?)",
            r"\s*(\d+)(?:\s*-\s*Pág\.\s*\d+)?\s*---\s*([^-\n]+)\s*---\s*([^<]+)",
            # NOTE(review): this pattern has FOUR capture groups, so its
            # matches can never pass the 3-item check below — confirm intent.
            r"\s*(?:\[*([\d]+)\]*\s*[-–]*\s*(?:Pág\.\s*\d+\s*[-–]*)?)?\s*\[*([^\]]+)\]*\s*[-–]*\s*\[*([^\]]+)\]*\s*[-–]*\s*\[*([^\]]+)\]*\s*",
            r"\s*(.*?)\s*---\s*(.*?)\s*---\s*(.*?)\s*",
            # -------------- older patterns kept for reference
            # r"\[*([\d.\-]+)\]*\s*---\s*\[*([^]]+)\]*\s*---\s*\[*([^]]+)\]*\s*",  # the very first one
            # r"\s*([\d.\-]+)\s*---\s*([^<]+)\s*---\s*([^<]+)\s*",
            # r"\[([\d.\-]+)\]\s*---\s*\[([^]]+)\]\s*---\s*\[([^]]+)\]\s*",
            # r"\s*\[?([\d.\-]+)\]?\s*---\s*\[?([^\]\[]+?)\]?\s*---\s*\[?([^<]+?)\]?\s*",
            # r"\s*\[([\d.\-]+)\]\s*---\s*\[([^\]]+)\]\s*---\s*\[([^\]]+)\]\s*"
            # r"\s*\[?([\d.\-\s]+)\]?\s*---\s*\[?([^\]\[]+?)\]?\s*---\s*\[?([\s\S]+?)\]?\s*",
        ]
        resultado = None
        for pattern in patterns:
            matches = re.findall(pattern, context, re.DOTALL)
            # Accept only when every expected chunk matched and each match
            # yielded exactly the 3 capture groups (id, title, content).
            tuples_have_3_items = all(len(m) == 3 for m in matches)
            if len(matches) == len(lista_de_document_ids) and tuples_have_3_items:
                print("\n--------------- REGEX DO CONTEXTUAL FUNCIONOU")
                resultado = []
                for m in matches:
                    page_id = re.search(r"Num\.\s*(\d+)\s*-", m[0])
                    first_item = page_id.group(1) if page_id else "0"
                    resultado.append((first_item, m[1], m[2]))
                break
        if not resultado:
            # Fallback: split the response on a delimiter instead of regex.
            # NOTE(review): the delimiter literals below are empty strings —
            # they look like special tokens lost in a refactor; confirm the
            # intended separators.  ``str.split("")`` raises ValueError, so
            # this branch previously aborted the caller's whole retry loop;
            # it is now guarded so a parse failure just means "no match".
            try:
                context = context.replace("", "").replace("", "").strip()
                raw_chunks = context.split("")[0:20]
                resultado_temporario = []
                for r in raw_chunks:
                    lista_3_itens = r.split("---")
                    cabecalho = lista_3_itens[0].strip()
                    page_id = re.search(r"Num\.\s*(\d+)\s*-", cabecalho)
                    page_id_tentativa_2 = re.search(
                        r"\d+\.\s+(\d+)\s+-\s+Pág\.", cabecalho
                    )
                    if page_id:
                        first_item = page_id.group(1)
                    elif page_id_tentativa_2:
                        first_item = page_id_tentativa_2.group(1)
                    else:
                        first_item = "0"
                    resultado_temporario.append(
                        (first_item, lista_3_itens[1], lista_3_itens[2])
                    )
                # Tuples above always have 3 items by construction, so only
                # the count needs checking (the original also re-checked len 3).
                if len(resultado_temporario) == len(lista_de_document_ids):
                    resultado = resultado_temporario
            except Exception:
                return resultado
        return resultado


# Commented-out code below was for reading the pages surrounding the chunk's
# current page:
# page_content = ""
# for i in range(
#     max(0, chunk.page_number - 1),
#     min(len(single_page_text), chunk.page_number + 2),
# ):
#     page_content += single_page_text[i].page_content if single_page_text[i] else ""