ArturG9 committed
Commit 97134b4 · verified · 1 Parent(s): 05127d7

Update app.py

Files changed (1): app.py +75 -96
app.py CHANGED
@@ -16,59 +16,68 @@ from utills import load_txt_documents, split_docs, load_uploaded_documents, retr
 from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
 from langchain_community.document_loaders.directory import DirectoryLoader
 
-script_dir = os.path.dirname(os.path.abspath(__file__))
-data_path = os.path.join(script_dir, "data/")
-model_path = os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf')
-store = {}
-
-model_name = "sentence-transformers/all-mpnet-base-v2"
-model_kwargs = {'device': 'cpu'}
-encode_kwargs = {'normalize_embeddings': True}
-
-hf = HuggingFaceEmbeddings(
-    model_name=model_name,
-    model_kwargs=model_kwargs,
-    encode_kwargs=encode_kwargs
-)
-
-def get_vectorstore(text_chunks):
-    model_name = "sentence-transformers/all-mpnet-base-v2"
-    model_kwargs = {'device': 'cpu'}
-    encode_kwargs = {'normalize_embeddings': True}
-    hf = HuggingFaceEmbeddings(
-        model_name=model_name,
-        model_kwargs=model_kwargs,
-        encode_kwargs=encode_kwargs
-    )
-    vectorstore = Chroma.from_documents(documents=text_chunks, embedding=hf, persist_directory="docs/chroma/")
-    return vectorstore
+def main():
+    st.set_page_config(page_title="Chat with multiple PDFs", page_icon=":books:")
+    st.header("Chat with multiple PDFs :books:")
+
+    if "pdf_docs" not in st.session_state:
+        st.session_state.pdf_docs = []
+
+    if "conversation_chain" not in st.session_state:
+        st.session_state.conversation_chain = None
+
+    with st.sidebar:
+        st.subheader("Your documents")
+        pdf_docs = st.file_uploader("Upload your PDFs here and click on 'Process'", accept_multiple_files=True)
+        if pdf_docs:
+            st.session_state.pdf_docs.extend(pdf_docs)
+
+        if st.button("Process"):
+            with st.spinner("Processing"):
+                raw_text = get_pdf_text(st.session_state.pdf_docs)
+                text_chunks = get_text_chunks(raw_text)
+                vectorstore = get_vectorstore(text_chunks)
+                st.session_state.conversation_chain = get_conversation_chain(vectorstore)
+                st.success("Documents processed and conversation chain created successfully.")
+
+    user_question = st.text_input("Ask a question about your documents:")
+    if user_question:
+        handle_userinput(st.session_state.conversation_chain, user_question)
 
 def get_pdf_text(pdf_docs):
-    document_loader = DirectoryLoader(pdf_docs)
-    return document_loader.load()
+    text = ""
+    for pdf in pdf_docs:
+        pdf_reader = PdfReader(pdf)
+        for page in pdf_reader.pages:
+            text += page.extract_text()
+    return text
 
 def get_text_chunks(text):
     text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-        separator="\n",
-        chunk_size=1000,
-        chunk_overlap=200,
-        length_function=len
+        chunk_size=600, chunk_overlap=50,
+        separators=["\n \n \n", "\n \n", "\n1", "(?<=\. )", " ", ""],
     )
     chunks = text_splitter.split_text(text)
     return chunks
 
-def create_conversational_rag_chain(vectorstore):
-
-    script_dir = os.path.dirname(os.path.abspath(__file__))
-    model_path = os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf')
-
-    retriever = vectorstore.as_retriever(search_type='mmr', search_kwargs={"k": 7})
+def get_vectorstore(text_chunks):
+    model_name = "sentence-transformers/all-mpnet-base-v2"
+    model_kwargs = {'device': 'cpu'}
+    encode_kwargs = {'normalize_embeddings': True}
+    embeddings = HuggingFaceEmbeddings(
+        model_name=model_name,
+        model_kwargs=model_kwargs,
+        encode_kwargs=encode_kwargs
+    )
+    vectorstore = Chroma.from_texts(
+        texts=text_chunks, embedding=embeddings, persist_directory="docs/chroma/")
+    return vectorstore
 
+def get_conversation_chain(vectorstore):
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
-
+    script_dir = os.path.dirname(os.path.abspath(__file__))
     llm = llamacpp.LlamaCpp(
-        model_path=os.path.join(model_path),
+        model_path=os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf'),
         n_gpu_layers=1,
         temperature=0.1,
         top_p=0.9,
@@ -79,14 +88,26 @@ def create_conversational_rag_chain(vectorstore):
         verbose=False,
     )
 
-    contextualize_q_system_prompt = """Given a context, chat history and the latest user question
-    which maybe reference context in the chat history, formulate a standalone question
+    retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 7})
+
+    contextualize_q_system_prompt = """Given a context, chat history and the latest user question, formulate a standalone question
     which can be understood without the chat history. Do NOT answer the question,
     just reformulate it if needed and otherwise return it as is."""
 
-    ha_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_system_prompt)
+    contextualize_q_prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", contextualize_q_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ]
+    )
+
+    history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
 
-    qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Be as informative as possible, be polite and formal.\n{context}"""
+    qa_system_prompt = """Use the following pieces of retrieved context to answer the question{input}. \
+    Be informative but don't make too long answers, be polite and formal. \
+    If you don't know the answer, say "I don't know the answer." \
+    {context}"""
 
     qa_prompt = ChatPromptTemplate.from_messages(
     [
@@ -95,12 +116,12 @@ def create_conversational_rag_chain(vectorstore):
         ("human", "{input}"),
     ]
     )
-
+
     question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
 
-    rag_chain = create_retrieval_chain(ha_retriever, question_answer_chain)
+    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
     msgs = StreamlitChatMessageHistory(key="special_app_key")
-
+
     conversation_chain = RunnableWithMessageHistory(
         rag_chain,
         lambda session_id: msgs,
@@ -110,59 +131,17 @@ def create_conversational_rag_chain(vectorstore):
     )
     return conversation_chain
 
-def main():
-    """Main function for the Streamlit app."""
-    # Initialize chat history if not already present in session state
-
-    documents = []
-
-    script_dir = os.path.dirname(os.path.abspath(__file__))
-    data_path = os.path.join(script_dir, "data/")
-    for filename in os.listdir(data_path):
-
-        if filename.endswith('.txt'):
-
-            file_path = os.path.join(data_path, filename)
-
-            documents = TextLoader(file_path).load()
-
-            documents.extend(documents)
-
-    docs = split_docs(documents, 350, 40)
-
-    vectorstore = get_vectorstore(docs)
-
-    msgs = st.session_state.get("chat_history", StreamlitChatMessageHistory(key="special_app_key"))
-    chain_with_history = create_conversational_rag_chain(vectorstore)
-
-    st.title("Conversational RAG Chatbot")
-
-    if prompt := st.chat_input():
-        st.chat_message("human").write(prompt)
-
-        # Prepare the input dictionary with the correct keys
-        input_dict = {"input": prompt, "chat_history": msgs.messages}
-        config = {"configurable": {"session_id": "any"}}
-
-        # Process user input and handle response
-        response = chain_with_history.invoke(input_dict, config)
+def handle_userinput(conversation_chain, prompt):
+    msgs = StreamlitChatMessageHistory(key="special_app_key")
+    st.chat_message("human").write(prompt)
+    input_dict = {"input": prompt, "chat_history": msgs.messages}
+    config = {"configurable": {"session_id": 1}}
+
+    try:
+        response = conversation_chain.invoke(input_dict, config)
         st.chat_message("ai").write(response["answer"])
+    except Exception as e:
+        st.error(f"Error invoking conversation chain: {e}")
 
-    # Display retrieved documents (if any and present in response)
-    if "docs" in response and response["documents"]:
-        for index, doc in enumerate(response["documents"]):
-            with st.expander(f"Document {index + 1}"):
-                st.write(doc)
-
-    # Update chat history in session state
-    st.session_state["chat_history"] = msgs
-
-if __name__ == "__main__":
+if __name__ == '__main__':
     main()
 
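For reference, the extraction loop the new get_pdf_text relies on can be exercised outside Streamlit. A minimal sketch, assuming PdfReader comes from the pypdf package (the import in app.py sits outside this diff) and using "sample.pdf" as a placeholder path:

# Standalone sketch of the PdfReader loop in the new get_pdf_text.
# Assumption: PdfReader is pypdf's; "sample.pdf" is a placeholder, not a repo file.
from pypdf import PdfReader

def extract_text(paths):
    text = ""
    for path in paths:
        reader = PdfReader(path)
        for page in reader.pages:
            # extract_text() may return None for image-only pages
            text += page.extract_text() or ""
    return text

if __name__ == "__main__":
    print(extract_text(["sample.pdf"])[:500])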
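The Process button chains get_text_chunks and get_vectorstore. A runnable sketch of that chunk, embed, and index flow under the same splitter and embedding settings; the sample text, query, and persist directory are illustrative stand-ins:

# Sketch of the chunk -> embed -> index flow behind the Process button.
# The sample string and query stand in for real extracted PDF text.
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma

splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=600, chunk_overlap=50,
)
chunks = splitter.split_text("Some extracted PDF text. " * 200)

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)
vectorstore = Chroma.from_texts(
    texts=chunks, embedding=embeddings, persist_directory="docs/chroma/",
)
print(vectorstore.similarity_search("what does the text say?", k=3))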
 
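get_conversation_chain composes a history-aware retriever, a stuff-documents QA chain, and message history. A sketch of the same wiring that runs without the local GGUF model: FakeListLLM stands in for LlamaCpp, `vectorstore` carries over from the sketch above, the prompts are simplified placeholders, and the three *_messages_key names follow standard LangChain usage for create_retrieval_chain outputs:

# Sketch of the history-aware RAG wiring in get_conversation_chain.
# FakeListLLM replaces the local LlamaCpp model so this runs anywhere;
# `vectorstore` comes from the previous sketch.
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.llms.fake import FakeListLLM
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory

llm = FakeListLLM(responses=["standalone question", "final answer"])
retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 7})

contextualize_q_prompt = ChatPromptTemplate.from_messages([
    ("system", "Reformulate the question as standalone, given the chat history."),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
])
qa_prompt = ChatPromptTemplate.from_messages([
    ("system", "Answer using this context:\n{context}"),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
])

history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
rag_chain = create_retrieval_chain(
    history_aware_retriever, create_stuff_documents_chain(llm, qa_prompt))

store = {}  # one ChatMessageHistory per session id
conversation_chain = RunnableWithMessageHistory(
    rag_chain,
    lambda session_id: store.setdefault(session_id, ChatMessageHistory()),
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)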
 
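Continuing that sketch, the invoke pattern handle_userinput uses looks like this outside Streamlit; the session_id in the config only selects which history buffer RunnableWithMessageHistory reads from and appends to:

# Invoke pattern from handle_userinput, continued from the sketch above.
# RunnableWithMessageHistory injects chat_history itself, keyed by session_id.
config = {"configurable": {"session_id": "demo"}}

response = conversation_chain.invoke({"input": "What do the documents say?"}, config)
print(response["answer"])

# A follow-up in the same session sees the first exchange via chat_history.
followup = conversation_chain.invoke({"input": "Tell me more."}, config)
print(followup["answer"])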