Update app.py
Browse files
app.py
CHANGED
@@ -161,59 +161,67 @@ if "messages" not in st.session_state:
|
|
161 |
# return res['result']
|
162 |
|
163 |
def generate_response(prompt_input):
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
# Format the input for the chain with the retrieved context
|
177 |
-
formatted_input = (
|
178 |
-
f"Context:\n"
|
179 |
-
f"{' '.join([doc.page_content for doc in relevant_context])}\n\n"
|
180 |
-
f"Conversation:\n{conversation_history}"
|
181 |
-
)
|
182 |
-
|
183 |
-
# Invoke the RetrievalQA chain directly with the formatted input
|
184 |
-
res = st.session_state.chain.invoke({"query": formatted_input})
|
185 |
-
|
186 |
-
# Process the response text
|
187 |
-
result_text = res['result']
|
188 |
-
|
189 |
-
# Clean up prefixing phrases and capitalize the first letter
|
190 |
-
if result_text.startswith('According to the provided context, '):
|
191 |
-
result_text = result_text[35:].strip()
|
192 |
-
elif result_text.startswith('Based on the provided context, '):
|
193 |
-
result_text = result_text[31:].strip()
|
194 |
-
elif result_text.startswith('According to the provided text, '):
|
195 |
-
result_text = result_text[34:].strip()
|
196 |
-
elif result_text.startswith('According to the context, '):
|
197 |
-
result_text = result_text[26:].strip()
|
198 |
-
|
199 |
-
# Ensure the first letter is uppercase
|
200 |
-
result_text = result_text[0].upper() + result_text[1:] if result_text else result_text
|
201 |
-
|
202 |
-
# Extract and format sources (if available)
|
203 |
-
sources = []
|
204 |
-
for doc in relevant_context:
|
205 |
-
source_path = doc.metadata.get('source', '')
|
206 |
-
formatted_source = source_path[122:-4] if source_path else "Unknown source"
|
207 |
-
sources.append(formatted_source)
|
208 |
|
209 |
-
|
210 |
-
|
211 |
-
|
|
|
|
|
|
|
212 |
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
|
218 |
# Display chat messages
|
219 |
for message in st.session_state.messages:
|
|
|
161 |
# return res['result']
|
162 |
|
163 |
def generate_response(prompt_input):
|
164 |
+
try:
|
165 |
+
# Retrieve vector database context using ONLY the current user input
|
166 |
+
retriever = st.session_state.chain.retriever
|
167 |
+
relevant_context = retriever.get_relevant_documents(prompt_input) # Retrieve context only for the current prompt
|
168 |
+
|
169 |
+
# Prepare full conversation history for the LLM
|
170 |
+
conversation_history = ""
|
171 |
+
for message in st.session_state.messages:
|
172 |
+
conversation_history += f"{message['role']}: {message['content']}\n"
|
173 |
+
|
174 |
+
# Append the current user prompt to the conversation history
|
175 |
+
conversation_history += f"user: {prompt_input}\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
+
# Format the input for the chain with the retrieved context
|
178 |
+
formatted_input = (
|
179 |
+
f"Context:\n"
|
180 |
+
f"{' '.join([doc.page_content for doc in relevant_context])}\n\n"
|
181 |
+
f"Conversation:\n{conversation_history}"
|
182 |
+
)
|
183 |
|
184 |
+
# Invoke the RetrievalQA chain directly with the formatted input
|
185 |
+
res = st.session_state.chain.invoke({"query": formatted_input})
|
186 |
+
|
187 |
+
# Process the response text
|
188 |
+
result_text = res['result']
|
189 |
+
|
190 |
+
# Clean up prefixing phrases and capitalize the first letter
|
191 |
+
if result_text.startswith('According to the provided context, '):
|
192 |
+
result_text = result_text[35:].strip()
|
193 |
+
elif result_text.startswith('Based on the provided context, '):
|
194 |
+
result_text = result_text[31:].strip()
|
195 |
+
elif result_text.startswith('According to the provided text, '):
|
196 |
+
result_text = result_text[34:].strip()
|
197 |
+
elif result_text.startswith('According to the context, '):
|
198 |
+
result_text = result_text[26:].strip()
|
199 |
+
|
200 |
+
# Ensure the first letter is uppercase
|
201 |
+
result_text = result_text[0].upper() + result_text[1:] if result_text else result_text
|
202 |
|
203 |
+
# Extract and format sources (if available)
|
204 |
+
sources = []
|
205 |
+
for doc in relevant_context:
|
206 |
+
source_path = doc.metadata.get('source', '')
|
207 |
+
formatted_source = source_path[122:-4] if source_path else "Unknown source"
|
208 |
+
sources.append(formatted_source)
|
209 |
+
|
210 |
+
# Remove duplicates and combine into a single string
|
211 |
+
unique_sources = list(set(sources))
|
212 |
+
source_list = ", ".join(unique_sources)
|
213 |
+
|
214 |
+
# # Combine response text with sources
|
215 |
+
# result_text += f"\n\n**Sources:** {source_list}" if source_list else "\n\n**Sources:** None"
|
216 |
+
|
217 |
+
return result_text
|
218 |
+
|
219 |
+
except Exception as e:
|
220 |
+
# Handle rate limit or other errors gracefully
|
221 |
+
if "rate_limit_exceeded" in str(e).lower():
|
222 |
+
return "⚠️ Rate limit exceeded. Please clear the chat history and try again."
|
223 |
+
else:
|
224 |
+
return f"❌ An error occurred: {str(e)}"
|
225 |
|
226 |
# Display chat messages
|
227 |
for message in st.session_state.messages:
|