from utils import MainState, generate_uuid, llm
from langchain_core.messages import AIMessage, ToolMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, START, END
import re
def get_graph(retriever):
def retriever_node(state: MainState):
return {
'question': state['question'],
'scratchpad': state['scratchpad'] + [ToolMessage(content=retriever.invoke(state['question'].content),
tool_call_id=state['scratchpad'][-1].tool_call_id)],
'answer': state['answer'],
'next_node': 'model_node',
'history': state['history']
}
import re
def model_node(state: MainState):
prompt = ChatPromptTemplate.from_template(
"""
Você é um assistente de IA chamado DocAI. Responda à pergunta abaixo da forma mais precisa possível.
Caso não tenha informações para responder à pergunte **retorne apenas** uma resposta no seguinte formato:
retriever,
ao fazer isso a task será repassada para um agente que irá complementar as informações.
Se a pergunta puder ser respondida sem acessar documentos enviados, forneça uma resposta **concisa e objetiva**, com no máximo três sentenças.
### Contexto:
- Bloco de Notas: {scratchpad}
- Histórico de Conversas: {chat_history}
**Pergunta:** {question}
"""
)
if isinstance(state['question'], str):
state['question'] = HumanMessage(content=state['question'])
qa_chain = prompt | llm
response = qa_chain.invoke({'question': state['question'].content,
'scratchpad': state['scratchpad'],
'chat_history': [
f'AI: {msg.content}' if isinstance(msg, AIMessage) else f'Human: {msg.content}'
for msg in state['history']],
})
if '' in response.content:
return {
'question': state['question'],
'scratchpad': state['scratchpad'] + [AIMessage(content='', tool_call_id=generate_uuid())] if state[
'scratchpad'] else [AIMessage(content='', tool_call_id=generate_uuid())],
'answer': state['answer'],
'next_node': 'retriever',
'history': state['history']
}
# print(state['scratchpad'])
return {
'question': state['question'],
'scratchpad': state['scratchpad'],
'answer': response,
'next_node': END,
'history': state['history'] + [HumanMessage(content=state['question'].content), response]
}
def next_node(state: MainState):
return state['next_node']
graph = StateGraph(MainState)
graph.add_node('model', model_node)
graph.add_node('retriever', retriever_node)
graph.add_edge(START, 'model')
graph.add_edge('retriever', 'model')
graph.add_conditional_edges('model', next_node)
chain = graph.compile()
return chain