SamuelM0422 committed on
Commit 46a68f8 · verified · 1 Parent(s): 2812ca8

Upload 3 files

Files changed (3)
  1. app.py +103 -0
  2. graph.py +81 -0
  3. utils.py +33 -0
app.py ADDED
@@ -0,0 +1,103 @@
import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.messages import HumanMessage, AIMessageChunk, AIMessage
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.vectorstores import InMemoryVectorStore
import os
from langchain_core.chat_history import InMemoryChatMessageHistory, BaseChatMessageHistory
import time
from graph import get_graph

if 'read_file' not in st.session_state:
    st.session_state.read_file = False
    st.session_state.retriever = None

if 'chat_history' not in st.session_state:
    st.session_state.chat_history = {}
    st.session_state.first_msg = True

def get_session_by_id(session_id: str) -> BaseChatMessageHistory:
    # Create the history for this session on first access, then return it.
    if session_id not in st.session_state.chat_history:
        st.session_state.chat_history[session_id] = InMemoryChatMessageHistory()
    return st.session_state.chat_history[session_id]

if not st.session_state.read_file:
    st.title('🤓 Upload your PDF to talk with it', anchor=False)
    file = st.file_uploader('Upload a PDF file', type='pdf')
    if file:
        with st.status('🤗 Booting up the things!', expanded=True):
            with st.spinner('📁 Uploading the PDF...', show_time=True):
                with open('file.pdf', 'wb') as f:
                    f.write(file.read())
                loader = PyPDFLoader('file.pdf')
                documents = loader.load_and_split(RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200))
            st.success('📁 File uploaded successfully!!!')
            with st.spinner('🧐 Reading the file...', show_time=True):
                vstore = InMemoryVectorStore.from_documents(documents, HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2'))
                st.session_state.retriever = vstore.as_retriever()
            st.success('🧐 File read successfully!!!')
            os.remove('file.pdf')
            with st.spinner('😴 Waking up the LLM...', show_time=True):
                st.session_state.graph = get_graph(st.session_state.retriever)
            st.success('😁 LLM awakened!!!')
        st.balloons()
        placeholder = st.empty()
        for i in range(5, -1, -1):
            placeholder.write(f'⏳ Chat starting in 0{i} sec.')
            time.sleep(1)
        st.session_state.read_file = True
        st.rerun()

if st.session_state.read_file:
    st.title('🤗 DocAI', anchor=False)
    st.subheader('Chat with your document!', anchor=False)

    if st.session_state.first_msg:
        st.session_state.first_msg = False
        get_session_by_id('chat42').add_message(AIMessage(content='Hello, how are you? How about we talk about the '
                                                                  'document you sent me to read?'))

    for msg in get_session_by_id('chat42').messages:
        with st.chat_message(name='user' if isinstance(msg, HumanMessage) else 'ai'):
            st.write(msg.content)

    prompt = st.chat_input('Try to ask something about your file!')
    if prompt:
        with st.chat_message(name='user'):
            st.write(prompt)

        response = st.session_state.graph.stream(
            {
                'question': prompt,
                'scratchpad': None,
                'answer': None,
                'next_node': None,
                'history': get_session_by_id('chat42').messages,
            },
            stream_mode='messages'
        )

        get_session_by_id('chat42').add_message(HumanMessage(content=prompt))

        def get_message():
            # Forward only non-empty AI token chunks from the (message, metadata) stream.
            for chunk, _ in response:
                if chunk.content and isinstance(chunk, AIMessageChunk):
                    yield chunk.content

        with st.chat_message(name='ai'):
            full_response = ''
            placeholder = st.empty()

            for msg in get_message():
                full_response += msg
                if '</tool>' in full_response:
                    # A complete routing marker was streamed: discard it.
                    full_response = ''
                    continue
                if '<tool>' in full_response:
                    # Marker still open: hold output until it resolves.
                    continue
                placeholder.write(full_response)
            get_session_by_id('chat42').add_message(AIMessage(content=full_response))
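
The chat loop above filters the model's routing markers out of the token stream before anything reaches the UI. A standalone sketch of that logic (the chunk strings below are invented for illustration):

def filter_stream(chunks):
    # Accumulate streamed text; suppress anything wrapped in <tool>...</tool>
    # so routing markers never reach the user.
    full = ''
    for piece in chunks:
        full += piece
        if '</tool>' in full:
            full = ''        # marker complete: discard it and start over
            continue
        if '<tool>' in full:
            continue         # marker still open: hold output
        yield full

# The marker is split across chunks, yet only the real answer surfaces.
assert list(filter_stream(['<tool>retr', 'iever</tool>', 'Hi'])) == ['Hi']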
graph.py ADDED
@@ -0,0 +1,81 @@
from utils import MainState, generate_uuid, llm
from langchain_core.messages import AIMessage, ToolMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, START, END

def get_graph(retriever):
    def retriever_node(state: MainState):
        # Fetch documents for the question and record them on the scratchpad
        # as a ToolMessage tied to the pending tool_call_id.
        docs = retriever.invoke(state['question'].content)
        return {
            'question': state['question'],
            'scratchpad': state['scratchpad'] + [ToolMessage(
                content='\n\n'.join(doc.page_content for doc in docs),
                tool_call_id=state['scratchpad'][-1].tool_call_id)],
            'answer': state['answer'],
            'next_node': 'model',
            'history': state['history']
        }

    def model_node(state: MainState):
        prompt = ChatPromptTemplate.from_template(
            """
            You are an AI assistant. Answer the question below as accurately as possible.

            If you lack the information to answer the question, **return only** a response in the following format:
            <tool>retriever</tool>
            The task will then be handed off to an agent that fills in the missing information.
            If the question can be answered without consulting the uploaded documents, give a **concise, objective** answer of at most three sentences.

            ### Context:
            - Scratchpad: {scratchpad}
            - Chat History: {chat_history}

            **Question:** {question}
            """
        )

        if isinstance(state['question'], str):
            state['question'] = HumanMessage(content=state['question'])

        qa_chain = prompt | llm

        response = qa_chain.invoke({
            'question': state['question'].content,
            'scratchpad': state['scratchpad'],
            'chat_history': [
                f'AI: {msg.content}' if isinstance(msg, AIMessage) else f'Human: {msg.content}'
                for msg in state['history']
            ],
        })

        if '<tool>' in response.content:
            # The model asked for retrieval: append an empty AIMessage carrying
            # a fresh tool_call_id and route to the retriever node.
            tool_request = AIMessage(content='', tool_call_id=generate_uuid())
            return {
                'question': state['question'],
                'scratchpad': state['scratchpad'] + [tool_request] if state['scratchpad'] else [tool_request],
                'answer': state['answer'],
                'next_node': 'retriever',
                'history': state['history']
            }

        return {
            'question': state['question'],
            'scratchpad': state['scratchpad'],
            'answer': response,
            'next_node': END,
            'history': state['history'] + [HumanMessage(content=state['question'].content), response]
        }

    def next_node(state: MainState):
        # Conditional-edge router: the state itself names the next node.
        return state['next_node']

    graph = StateGraph(MainState)
    graph.add_node('model', model_node)
    graph.add_node('retriever', retriever_node)
    graph.add_edge(START, 'model')
    graph.add_edge('retriever', 'model')
    graph.add_conditional_edges('model', next_node)

    return graph.compile()
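
A minimal smoke test for the compiled graph, assuming OPENAI_API_KEY is set; the stub retriever and sample state below are illustrative, not part of the commit:

from langchain_core.documents import Document
from graph import get_graph

class StubRetriever:
    # Stands in for the vector-store retriever built in app.py.
    def invoke(self, query: str):
        return [Document(page_content='Example passage relevant to the query.')]

graph = get_graph(StubRetriever())
result = graph.invoke({
    'question': 'What does the document say?',
    'scratchpad': [],
    'answer': None,
    'next_node': None,
    'history': [],
})
print(result['answer'].content)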
utils.py ADDED
@@ -0,0 +1,33 @@
import uuid
from typing import TypedDict
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
import os
import re
from dotenv import load_dotenv

load_dotenv()

class MainState(TypedDict):
    question: HumanMessage | str | None
    scratchpad: list[AIMessage | ToolMessage] | None
    answer: AIMessage | None
    next_node: str | None
    history: list[HumanMessage | AIMessage]

llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
    api_key=os.environ.get('OPENAI_API_KEY'),  # Insert your key here
)

def generate_uuid():
    return str(uuid.uuid4())

def post_process(message: AIMessage) -> AIMessage:
    # Extract a custom tool call wrapped in [SOT]...[EOT] markers, if any.
    matches = re.findall(r"\[SOT\](.*)\[EOT\]", message.content, re.DOTALL)
    match = matches[0] if matches else None

    if match:
        return AIMessage(content='', additional_kwargs={'custom_tool_call': match}, tool_call_id=generate_uuid())
    return AIMessage(content=message.content)
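
post_process is defined here but not yet called anywhere in this commit. A quick sketch of its intended behavior, assuming the model wraps a tool name in [SOT]...[EOT] markers:

msg = AIMessage(content='[SOT]retriever[EOT]')
processed = post_process(msg)
print(processed.additional_kwargs['custom_tool_call'])  # -> 'retriever'
print(post_process(AIMessage(content='plain answer')).content)  # -> 'plain answer'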