import os
from _utils.langchain_utils.LLM_class import LLM
from typing import Any, List, Dict, Tuple, Optional, Union, cast
from anthropic import Anthropic, AsyncAnthropic
import logging
from langchain.schema import Document
from llama_index import Document as Llama_Index_Document
import asyncio
from typing import List
from dataclasses import dataclass
from _utils.gerar_documento_utils.llm_calls import (
aclaude_answer,
agemini_answer,
agpt_answer,
)
from _utils.gerar_documento_utils.prompts import contextual_prompt
from _utils.models.gerar_documento import (
ContextualizedChunk,
DocumentChunk,
RetrievalConfig,
)
from langchain_core.messages import HumanMessage
from gerar_documento.serializer import (
GerarDocumentoComPDFProprioSerializerData,
GerarDocumentoSerializerData,
)
from setup.logging import Axiom
import re
class ContextualRetriever:
    """Enriches document chunks with LLM-generated context, 20 chunks per request.

    Each batch of 20 chunks is sent to the LLM together with an auxiliary
    summary; the parsed response supplies, per chunk, a document id, a context
    line and a contextual summary, which are folded into ContextualizedChunk
    objects.
    """

    def __init__(
        self,
        serializer: Union[
            GerarDocumentoSerializerData, GerarDocumentoComPDFProprioSerializerData, Any
        ],
    ):
        # One entry is appended per processed batch; the list length acts as a counter.
        self.lista_contador: List[int] = []
        self.contextual_retriever_utils = ContextualRetrieverUtils()
        self.config = RetrievalConfig(
            num_chunks=serializer.num_chunks_retrieval,
            embedding_weight=serializer.embedding_weight,
            bm25_weight=serializer.bm25_weight,
            context_window=serializer.context_window,
            chunk_overlap=serializer.chunk_overlap,
        )
        self.logger = logging.getLogger(__name__)
        self.bm25 = None
        self.claude_context_model = serializer.claude_context_model
        self.claude_api_key = os.environ.get("CLAUDE_API_KEY", "")
        self.claude_client = AsyncAnthropic(api_key=self.claude_api_key)

    async def llm_call_uma_lista_de_20_chunks(
        self,
        lista_com_20_chunks: List[DocumentChunk],
        resumo_auxiliar,
        axiom_instance: Axiom,
    ) -> List[List[str]]:
        """Ask the LLM to contextualize one batch of up to 20 chunks.

        Retries up to 4 times while the response fails format validation.
        Returns a list of [document_id, context, summary] triples, or [[""]]
        when every attempt fails or an unexpected error occurs.
        """
        all_chunks_contents, all_document_ids = (
            self.contextual_retriever_utils.get_all_document_ids_and_contents(
                lista_com_20_chunks
            )
        )
        send_axiom = axiom_instance.send_axiom
        utils = self.contextual_retriever_utils
        try:
            prompt = contextual_prompt(
                resumo_auxiliar, all_chunks_contents, len(lista_com_20_chunks)
            )
            result = None
            # Initialized so the failure log below is safe even if the first
            # request raises before `response` is assigned.
            response = ""
            for attempt in range(4):
                if attempt != 0:
                    send_axiom(
                        f"------------- FORMATAÇÃO DO CONTEXTUAL INCORRETA - TENTANDO NOVAMENTE (TENTATIVA: {attempt + 1}) -------------"
                    )
                send_axiom(
                    f"COMEÇANDO UMA REQUISIÇÃO DO CONTEXTUAL - TENTATIVA {attempt + 1}"
                )
                raw_response = await agemini_answer(prompt, "gemini-2.0-flash-lite")
                response = cast(str, raw_response)
                send_axiom(
                    f"TERMINOU UMA REQUISIÇÃO DO CONTEXTUAL - TENTATIVA {attempt + 1}"
                )
                matches = utils.validate_many_chunks_in_one_request(
                    response, all_document_ids
                )
                if matches:
                    send_axiom(
                        f"VALIDAÇÃO DO CONTEXTUAL FUNCIONOU NA TENTATIVA {attempt + 1} (ou seja, a função validate_many_chunks_in_one_request)"
                    )
                    result = utils.get_info_from_validated_chunks(matches)
                    break
            if result is None:
                axiom_instance.send_axiom(
                    f"-------------- UMA LISTA DE 20 CHUNKS FALHOU AS 4x NA FORMATAÇÃO ------------- ÚLTIMO RETORNO ERRADO: {response}"
                )
                result = [[""]]  # default value if no iteration succeeded
            return result
        except Exception as e:
            self.logger.error(f"Context generation failed for chunks .... : {str(e)}")
            return [[""]]

    async def contextualize_uma_lista_de_20_chunks(
        self,
        lista_com_20_chunks: List[DocumentChunk],
        response_auxiliar_summary,
        axiom_instance: Axiom,
    ) -> List[ContextualizedChunk]:
        """Build ContextualizedChunk objects for one batch of chunks.

        On a malformed LLM result (e.g. the [[""]] fallback) the IndexError /
        ValueError raised while unpacking is logged and whatever was built so
        far is returned.
        """
        self.lista_contador.append(0)
        print("contador: ", len(self.lista_contador))
        result = await self.llm_call_uma_lista_de_20_chunks(
            lista_com_20_chunks, response_auxiliar_summary, axiom_instance
        )
        lista_chunks: List[ContextualizedChunk] = []
        try:
            for index, chunk in enumerate(lista_com_20_chunks):
                lista_chunks.append(
                    ContextualizedChunk(
                        contextual_summary=result[index][2],
                        content=chunk.content,
                        page_number=chunk.page_number,
                        id_do_processo=int(result[index][0]),
                        chunk_id=chunk.chunk_id,
                        start_char=chunk.start_char,
                        end_char=chunk.end_char,
                        context=result[index][1],
                    )
                )
        # BUG FIX: was `except BaseException`, which also swallows
        # asyncio.CancelledError and breaks TaskGroup cancellation.
        except Exception as e:
            axiom_instance.send_axiom(
                f"ERRO EM UMA LISTA COM 20 CHUNKS CONTEXTUALS --------- lista: {lista_com_20_chunks} ------------ ERRO: {e}"
            )
        return lista_chunks

    async def contextualize_all_chunks(
        self,
        all_PDFs_chunks: List[DocumentChunk],
        response_auxiliar_summary,
        axiom_instance: Axiom,
    ) -> List[ContextualizedChunk]:
        """Add context to all chunks"""
        lista_de_listas_cada_com_20_chunks = (
            self.contextual_retriever_utils.get_lista_de_listas_cada_com_20_chunks(
                all_PDFs_chunks
            )
        )
        # All batch tasks run concurrently; the TaskGroup awaits them on exit,
        # so results are only collected after the `async with` block.
        async with asyncio.TaskGroup() as tg:
            tasks = [
                tg.create_task(
                    self.contextualize_uma_lista_de_20_chunks(
                        lista_com_20_chunks, response_auxiliar_summary, axiom_instance
                    )
                )
                for lista_com_20_chunks in lista_de_listas_cada_com_20_chunks
            ]
        contextualized_chunks: List[ContextualizedChunk] = []
        for task in tasks:
            # extend() instead of `list = list + other` (was quadratic).
            contextualized_chunks.extend(task.result())
        axiom_instance.send_axiom(
            "TERMINOU COM SUCESSO DE PROCESSAR TUDO DOS CONTEXTUALS"
        )
        return contextualized_chunks
@dataclass
class ContextualRetrieverUtils:
    """Stateless helpers for batching chunks and parsing/validating the
    contextual LLM responses."""

    def get_all_document_ids_and_contents(
        self, lista_com_20_chunks: "List[DocumentChunk]"
    ):
        """Concatenate chunk contents into one prompt body and extract the
        document id ("Num. <digits>") embedded in each chunk (0 when absent).

        Returns (all_chunks_contents, all_document_ids).
        """
        # Pattern hoisted out of the loop; the in-loop `import re` was removed
        # (the module already imports `re` at top level).
        num_pattern = re.compile(r"Num\. (\d+)")
        all_chunks_contents = ""
        all_document_ids: List[int] = []
        for contador, chunk in enumerate(lista_com_20_chunks, start=1):
            all_chunks_contents += f"\n\nCHUNK {contador}:\n"
            all_chunks_contents += chunk.content
            match = num_pattern.search(chunk.content)
            number = match.group(1) if match else 0
            all_document_ids.append(int(number))
        return all_chunks_contents, all_document_ids

    def get_info_from_validated_chunks(self, matches):
        """Normalize validated (doc_id, title, content) tuples into
        [int, str, str] lists with whitespace stripped."""
        return [
            [int(doc_id), title.strip(), content.strip()]
            for doc_id, title, content in matches
        ]

    def get_lista_de_listas_cada_com_20_chunks(
        self, all_PDFs_chunks: "List[DocumentChunk]"
    ):
        """Split the chunk list into consecutive batches of at most 20."""
        return [all_PDFs_chunks[i : i + 20] for i in range(0, len(all_PDFs_chunks), 20)]

    def validate_many_chunks_in_one_request(
        self, response: str, lista_de_document_ids: List[int]
    ):
        """Parse the LLM response into up to 20 (doc_id, title, content) tuples.

        Returns False when the response cannot be parsed so the caller can
        retry. The doc_id slot is taken from `lista_de_document_ids` (the ids
        extracted from the chunks themselves), not from what the LLM echoed.
        """
        context = (
            response.replace("document_id: ", "")
            .replace("document_id:", "")
            .replace("DOCUMENT_ID: ", "")
            # BUG FIX: this line previously repeated "DOCUMENT_ID: "; the
            # no-space variant was clearly intended (mirrors the lowercase pair).
            .replace("DOCUMENT_ID:", "")
        )
        matches = self.check_regex_patterns(context, lista_de_document_ids)
        if not matches:
            print(
                "----------- ERROU NA TENTATIVA ATUAL DE FORMATAR O CONTEXTUAL -----------"
            )
            return False
        matches_as_list = []
        # Cap at one batch (20 entries). The original computed a cleaned id
        # from match[0] and immediately discarded it (dead assignment); only
        # the known-good id from lista_de_document_ids is used.
        for index, match in enumerate(matches):
            if index >= 20:
                break
            matches_as_list.append((lista_de_document_ids[index], match[1], match[2]))
        return matches_as_list

    def check_regex_patterns(self, context: str, lista_de_document_ids: List[int]):
        """Try several response layouts; return a list of (id, title, content)
        3-tuples when exactly one tuple per expected document id was found,
        else None."""
        patterns = [
            r"\[(.*?)\] --- \[(.*?)\] --- \[(.*?)\](?=\n|\s*$)",
            r"\[([^\[\]]+?)\]\s*---\s*\[([^\[\]]+?)\]\s*---\s*(.*?)",
            r"\s*(\d+)(?:\s*-\s*Pág\.\s*\d+)?\s*---\s*([^-\n]+)\s*---\s*([^<]+)",
            r"\s*(?:\[*([\d]+)\]*\s*[-–]*\s*(?:Pág\.\s*\d+\s*[-–]*)?)?\s*\[*([^\]]+)\]*\s*[-–]*\s*\[*([^\]]+)\]*\s*[-–]*\s*\[*([^\]]+)\]*\s*",
            r"\s*(.*?)\s*---\s*(.*?)\s*---\s*(.*?)\s*",
        ]
        num_pattern = re.compile(r"Num\.\s*(\d+)\s*-")
        resultado = None
        for pattern in patterns:
            matches = re.findall(pattern, context, re.DOTALL)
            # Patterns with 4 capture groups yield 4-tuples and are rejected here.
            condition_tuples_3_items = all(len(m) == 3 for m in matches)
            if len(matches) == len(lista_de_document_ids) and condition_tuples_3_items:
                print("\n--------------- REGEX DO CONTEXTUAL FUNCIONOU")
                resultado = []
                for m in matches:
                    page_id = num_pattern.search(m[0])
                    first_item = page_id.group(1) if page_id else "0"
                    resultado.append((first_item, m[1], m[2]))
                break
        if not resultado:
            # Fallback: split the raw text into entries and parse each
            # "id --- title --- content" triple manually.
            # BUG FIX: the original stripped two characters and split on literals
            # that survive only as empty strings ("" — str.split("") raises
            # ValueError, which aborted the caller's retry loop instead of
            # signaling a format failure). Blank-line splitting is the closest
            # reconstruction; any parsing error now just yields None.
            # TODO(review): confirm the intended entry delimiter.
            try:
                raw_chunks = re.split(r"\n{2,}", context.strip())[0:20]
                resultado_temporario = []
                for entry in raw_chunks:
                    lista_3_itens = entry.split("---")
                    if len(lista_3_itens) < 3:
                        continue  # malformed entry: cannot supply (id, title, content)
                    head = lista_3_itens[0].strip()
                    page_id = num_pattern.search(head)
                    page_id_tentativa_2 = re.search(
                        r"\d+\.\s+(\d+)\s+-\s+Pág\.", head
                    )
                    if page_id:
                        first_item = page_id.group(1)
                    elif page_id_tentativa_2:
                        first_item = page_id_tentativa_2.group(1)
                    else:
                        first_item = "0"
                    resultado_temporario.append(
                        (first_item, lista_3_itens[1], lista_3_itens[2])
                    )
                if len(resultado_temporario) == len(lista_de_document_ids):
                    resultado = resultado_temporario
            except Exception:
                resultado = None
        return resultado
# The commented-out code below was meant to read the pages surrounding the chunk's current page
# page_content = ""
# for i in range(
# max(0, chunk.page_number - 1),
# min(len(single_page_text), chunk.page_number + 2),
# ):
# page_content += single_page_text[i].page_content if single_page_text[i] else ""