Nioooor committed
Commit 4c4dd9e · verified · 1 Parent(s): 3ca76f6

Update app.py

Files changed (1):
  1. app.py +58 -50
app.py CHANGED
@@ -161,59 +161,67 @@ if "messages" not in st.session_state:
 # return res['result']
 
 def generate_response(prompt_input):
-    # Retrieve vector database context using ONLY the current user input
-    retriever = st.session_state.chain.retriever
-    relevant_context = retriever.get_relevant_documents(prompt_input)  # Retrieve context only for the current prompt
-
-    # Prepare full conversation history for the LLM
-    conversation_history = ""
-    for message in st.session_state.messages:
-        conversation_history += f"{message['role']}: {message['content']}\n"
-
-    # Append the current user prompt to the conversation history
-    conversation_history += f"user: {prompt_input}\n"
-
-    # Format the input for the chain with the retrieved context
-    formatted_input = (
-        f"Context:\n"
-        f"{' '.join([doc.page_content for doc in relevant_context])}\n\n"
-        f"Conversation:\n{conversation_history}"
-    )
-
-    # Invoke the RetrievalQA chain directly with the formatted input
-    res = st.session_state.chain.invoke({"query": formatted_input})
-
-    # Process the response text
-    result_text = res['result']
-
-    # Clean up prefixing phrases and capitalize the first letter
-    if result_text.startswith('According to the provided context, '):
-        result_text = result_text[35:].strip()
-    elif result_text.startswith('Based on the provided context, '):
-        result_text = result_text[31:].strip()
-    elif result_text.startswith('According to the provided text, '):
-        result_text = result_text[34:].strip()
-    elif result_text.startswith('According to the context, '):
-        result_text = result_text[26:].strip()
-
-    # Ensure the first letter is uppercase
-    result_text = result_text[0].upper() + result_text[1:] if result_text else result_text
-
-    # Extract and format sources (if available)
-    sources = []
-    for doc in relevant_context:
-        source_path = doc.metadata.get('source', '')
-        formatted_source = source_path[122:-4] if source_path else "Unknown source"
-        sources.append(formatted_source)
-
-    # Remove duplicates and combine into a single string
-    unique_sources = list(set(sources))
-    source_list = ", ".join(unique_sources)
-
-    # # Combine response text with sources
-    # result_text += f"\n\n**Sources:** {source_list}" if source_list else "\n\n**Sources:** None"
-
-    return result_text
+    try:
+        # Retrieve vector database context using ONLY the current user input
+        retriever = st.session_state.chain.retriever
+        relevant_context = retriever.get_relevant_documents(prompt_input)  # Retrieve context only for the current prompt
+
+        # Prepare full conversation history for the LLM
+        conversation_history = ""
+        for message in st.session_state.messages:
+            conversation_history += f"{message['role']}: {message['content']}\n"
+
+        # Append the current user prompt to the conversation history
+        conversation_history += f"user: {prompt_input}\n"
+
+        # Format the input for the chain with the retrieved context
+        formatted_input = (
+            f"Context:\n"
+            f"{' '.join([doc.page_content for doc in relevant_context])}\n\n"
+            f"Conversation:\n{conversation_history}"
+        )
+
+        # Invoke the RetrievalQA chain directly with the formatted input
+        res = st.session_state.chain.invoke({"query": formatted_input})
+
+        # Process the response text
+        result_text = res['result']
+
+        # Clean up prefixing phrases and capitalize the first letter
+        if result_text.startswith('According to the provided context, '):
+            result_text = result_text[35:].strip()
+        elif result_text.startswith('Based on the provided context, '):
+            result_text = result_text[31:].strip()
+        elif result_text.startswith('According to the provided text, '):
+            result_text = result_text[34:].strip()
+        elif result_text.startswith('According to the context, '):
+            result_text = result_text[26:].strip()
+
+        # Ensure the first letter is uppercase
+        result_text = result_text[0].upper() + result_text[1:] if result_text else result_text
+
+        # Extract and format sources (if available)
+        sources = []
+        for doc in relevant_context:
+            source_path = doc.metadata.get('source', '')
+            formatted_source = source_path[122:-4] if source_path else "Unknown source"
+            sources.append(formatted_source)
+
+        # Remove duplicates and combine into a single string
+        unique_sources = list(set(sources))
+        source_list = ", ".join(unique_sources)
+
+        # # Combine response text with sources
+        # result_text += f"\n\n**Sources:** {source_list}" if source_list else "\n\n**Sources:** None"
+
+        return result_text
+
+    except Exception as e:
+        # Handle rate limit or other errors gracefully
+        if "rate_limit_exceeded" in str(e).lower():
+            return "⚠️ Rate limit exceeded. Please clear the chat history and try again."
+        else:
+            return f"❌ An error occurred: {str(e)}"
 
 # Display chat messages
 for message in st.session_state.messages:
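
Two notes for anyone reusing the committed function. The prefix-stripping ladder depends on hard-coded character counts that must match the prefix lengths exactly; the [34:] slice under the 32-character "According to the provided text, " prefix also drops the first two characters of the answer. Likewise, source_path[122:-4] assumes a fixed-length directory prefix and a three-character file extension. Below is a minimal sketch of offset-free alternatives, assuming LangChain-style documents as in the diff; the helper names strip_known_prefixes and format_source are hypothetical and not part of this commit.

import os

# Hypothetical helpers, not part of the committed app.py.
PREFIXES = (
    'According to the provided context, ',
    'Based on the provided context, ',
    'According to the provided text, ',
    'According to the context, ',
)

def strip_known_prefixes(text: str) -> str:
    # Strip one leading boilerplate phrase; len(prefix) keeps the offset
    # in sync with the prefix string, unlike the hard-coded slices above.
    for prefix in PREFIXES:
        if text.startswith(prefix):
            text = text[len(prefix):].strip()
            break
    # Re-capitalize the first letter, as the committed code does.
    return text[:1].upper() + text[1:]

def format_source(path: str) -> str:
    # Derive a display name from the file name itself rather than from
    # fixed slice positions in the absolute path.
    if not path:
        return "Unknown source"
    return os.path.splitext(os.path.basename(path))[0]

Inside generate_response, these could replace the if/elif ladder and the source_path[122:-4] slice without relying on any particular path length or extension.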