Update app.py
app.py
CHANGED
@@ -16,59 +16,68 @@ from utills import load_txt_documents, split_docs, load_uploaded_documents, retr
  from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
  from langchain_community.document_loaders.directory import DirectoryLoader

- store = {}

- encode_kwargs = {'normalize_embeddings': True}

-     model_kwargs=model_kwargs,
-     encode_kwargs=encode_kwargs
- )

  def get_pdf_text(pdf_docs):

  def get_text_chunks(text):
      text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-         chunk_overlap=200,
-         length_function=len
      )
      chunks = text_splitter.split_text(text)
      return chunks

- def

      callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
      llm = llamacpp.LlamaCpp(
-         model_path=os.path.join(
          n_gpu_layers=1,
          temperature=0.1,
          top_p=0.9,
@@ -79,14 +88,26 @@ def create_conversational_rag_chain(vectorstore):
          verbose=False,
      )

      which can be understood without the chat history. Do NOT answer the question,
      just reformulate it if needed and otherwise return it as is."""

-     qa_system_prompt = """

      qa_prompt = ChatPromptTemplate.from_messages(
          [
@@ -95,12 +116,12 @@ def create_conversational_rag_chain(vectorstore):
              ("human", "{input}"),
          ]
      )
      question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

-     rag_chain = create_retrieval_chain(
      msgs = StreamlitChatMessageHistory(key="special_app_key")
      conversation_chain = RunnableWithMessageHistory(
          rag_chain,
          lambda session_id: msgs,
@@ -110,59 +131,17 @@ def create_conversational_rag_chain(vectorstore):
      )
      return conversation_chain

- def
-     script_dir = os.path.dirname(os.path.abspath(__file__))
-     data_path = os.path.join(script_dir, "data/")
-     for filename in os.listdir(data_path):
-         if filename.endswith('.txt'):
-             file_path = os.path.join(data_path, filename)
-             documents = TextLoader(file_path).load()
-             documents.extend(documents)
-     docs = split_docs(documents, 350, 40)
-     vectorstore = get_vectorstore(docs)

-     st.title("Conversational RAG Chatbot")
-     if prompt := st.chat_input():
-         st.chat_message("human").write(prompt)
-         # Prepare the input dictionary with the correct keys
-         input_dict = {"input": prompt, "chat_history": msgs.messages}
-         config = {"configurable": {"session_id": "any"}}
-         # Process user input and handle response
-         response = chain_with_history.invoke(input_dict, config)
          st.chat_message("ai").write(response["answer"])

-         if "docs" in response and response["documents"]:
-             for index, doc in enumerate(response["documents"]):
-                 with st.expander(f"Document {index + 1}"):
-                     st.write(doc)
-         # Update chat history in session state
-         st.session_state["chat_history"] = msgs

- if __name__ == "__main__":
      main()
@@ -16,59 +16,68 @@
  from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
  from langchain_community.document_loaders.directory import DirectoryLoader

+ def main():
+     st.set_page_config(page_title="Chat with multiple PDFs", page_icon=":books:")
+     st.header("Chat with multiple PDFs :books:")

      # Session state keeps the uploads and the chain alive across Streamlit reruns.
+     if "pdf_docs" not in st.session_state:
+         st.session_state.pdf_docs = []

+     if "conversation_chain" not in st.session_state:
+         st.session_state.conversation_chain = None

+     with st.sidebar:
+         st.subheader("Your documents")
+         pdf_docs = st.file_uploader("Upload your PDFs here and click on 'Process'", accept_multiple_files=True)
+         if pdf_docs:
+             st.session_state.pdf_docs.extend(pdf_docs)
+
+         if st.button("Process"):
+             with st.spinner("Processing"):
+                 raw_text = get_pdf_text(st.session_state.pdf_docs)
+                 text_chunks = get_text_chunks(raw_text)
+                 vectorstore = get_vectorstore(text_chunks)
+                 st.session_state.conversation_chain = get_conversation_chain(vectorstore)
+                 st.success("Documents processed and conversation chain created successfully.")

+     user_question = st.text_input("Ask a question about your documents:")
+     if user_question:
+         handle_userinput(st.session_state.conversation_chain, user_question)

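  # Extract the raw text from every page of every uploaded PDF.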
  def get_pdf_text(pdf_docs):
+     text = ""
+     for pdf in pdf_docs:
+         pdf_reader = PdfReader(pdf)
+         for page in pdf_reader.pages:
+             text += page.extract_text() or ""  # guard against pages with no extractable text
+     return text

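  # Split the raw text into overlapping chunks; sizes are counted in tiktoken tokens.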
  def get_text_chunks(text):
      text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+         chunk_size=600, chunk_overlap=50,
+         separators=["\n \n \n", "\n \n", "\n1", "(?<=\. )", " ", ""],
      )
      chunks = text_splitter.split_text(text)
      return chunks

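  # Embed the chunks with a sentence-transformers model on CPU and persist them to a local Chroma store.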
+ def get_vectorstore(text_chunks):
+     model_name = "sentence-transformers/all-mpnet-base-v2"
+     model_kwargs = {'device': 'cpu'}
+     encode_kwargs = {'normalize_embeddings': True}
+     embeddings = HuggingFaceEmbeddings(
+         model_name=model_name,
+         model_kwargs=model_kwargs,
+         encode_kwargs=encode_kwargs
+     )
+     vectorstore = Chroma.from_texts(
+         texts=text_chunks, embedding=embeddings, persist_directory="docs/chroma/")
+     return vectorstore

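  # Assemble the RAG chain: local llama.cpp LLM, history-aware MMR retriever, and per-session chat history.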
+ def get_conversation_chain(vectorstore):
      callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+     script_dir = os.path.dirname(os.path.abspath(__file__))
      llm = llamacpp.LlamaCpp(
+         model_path=os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf'),
          n_gpu_layers=1,
          temperature=0.1,
          top_p=0.9,
@@ -79,14 +88,26 @@
          verbose=False,
      )

      # Maximal-marginal-relevance retrieval, returning 7 chunks per query.
+     retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 7})
+
      # Rewrite the latest question as a standalone question before retrieval.
+     contextualize_q_system_prompt = """Given a context, chat history and the latest user question, formulate a standalone question
      which can be understood without the chat history. Do NOT answer the question,
      just reformulate it if needed and otherwise return it as is."""

+     contextualize_q_prompt = ChatPromptTemplate.from_messages(
+         [
+             ("system", contextualize_q_system_prompt),
+             MessagesPlaceholder("chat_history"),
+             ("human", "{input}"),
+         ]
+     )
+
+     history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)

+     qa_system_prompt = """Use the following pieces of retrieved context to answer the question {input}. \
+     Be informative but don't make too long answers, be polite and formal. \
+     If you don't know the answer, say "I don't know the answer." \
+     {context}"""

      qa_prompt = ChatPromptTemplate.from_messages(
          [
@@ -95,12 +116,12 @@
              ("human", "{input}"),
          ]
      )
+
      question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

+     rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
      msgs = StreamlitChatMessageHistory(key="special_app_key")
+
      conversation_chain = RunnableWithMessageHistory(
          rag_chain,
          lambda session_id: msgs,
@@ -110,59 +131,17 @@
      )
      return conversation_chain

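  # Echo the user's message, invoke the chain with the stored history, and surface any error in the UI.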
+ def handle_userinput(conversation_chain, prompt):
+     msgs = StreamlitChatMessageHistory(key="special_app_key")
+     st.chat_message("human").write(prompt)
+     input_dict = {"input": prompt, "chat_history": msgs.messages}
+     config = {"configurable": {"session_id": 1}}

+     try:
+         response = conversation_chain.invoke(input_dict, config)
          st.chat_message("ai").write(response["answer"])
+     except Exception as e:
+         st.error(f"Error invoking conversation chain: {e}")

+ if __name__ == '__main__':
      main()