Jatin Mehra committed
Commit 4dbeb79 · Parent: 4a31622

Enhance model selection and tool creation with improved error handling, add content validation in chunking, and refine agent response logic for better user interaction and reliability

Files changed (1)
  1. preprocessing.py +263 -87
preprocessing.py CHANGED
@@ -14,13 +14,32 @@ dotenv.load_dotenv()
 # Initialize LLM and tools globally
 
 def model_selection(model_name):
-    llm = ChatGroq(model=model_name, api_key=os.getenv("GROQ_API_KEY"))
+    llm = ChatGroq(
+        model=model_name,
+        api_key=os.getenv("GROQ_API_KEY"),
+        temperature=0.1,  # Lower temperature for more consistent tool calling
+        max_tokens=2048   # Reasonable limit for responses
+    )
     return llm
 
-tools = [TavilySearchResults(max_results=5)]
+# Create tools with better error handling
+def create_tavily_tool():
+    try:
+        return TavilySearchResults(
+            max_results=5,
+            search_depth="advanced",
+            include_answer=True,
+            include_raw_content=False
+        )
+    except Exception as e:
+        print(f"Warning: Could not create Tavily tool: {e}")
+        return None
 
-# Initialize memory for conversation history
-memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+# Initialize tools globally but with error handling
+_tavily_tool = create_tavily_tool()
+tools = [_tavily_tool] if _tavily_tool else []
+
+# Note: Memory should be created per session, not globally
 
 def estimate_tokens(text):
     """Estimate the number of tokens in a text (rough approximation)."""
@@ -44,12 +63,19 @@ def chunk_text(documents, max_length=1000):
         current_chunk = ""
         current_metadata = metadata.copy()
         for paragraph in paragraphs:
+            # Skip very short paragraphs (less than 10 characters)
+            if len(paragraph.strip()) < 10:
+                continue
+
             if estimate_tokens(current_chunk + paragraph) <= max_length // 4:
                 current_chunk += paragraph + "\n\n"
             else:
-                chunks.append({"text": current_chunk.strip(), "metadata": current_metadata})
+                # Only add chunks with meaningful content
+                if current_chunk.strip() and len(current_chunk.strip()) > 20:
+                    chunks.append({"text": current_chunk.strip(), "metadata": current_metadata})
                 current_chunk = paragraph + "\n\n"
-        if current_chunk:
+        # Add the last chunk if it has meaningful content
+        if current_chunk.strip() and len(current_chunk.strip()) > 20:
             chunks.append({"text": current_chunk.strip(), "metadata": current_metadata})
     return chunks
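
The two guards added above are plain character-length thresholds. A tiny, self-contained illustration of the same conditions (in chunk_text the second check applies to the accumulated chunk, not to a single paragraph):

    paragraphs = ["Hi.", "Short but okay.", "This paragraph is clearly long enough to be worth indexing for retrieval."]

    # Guard 1: paragraphs shorter than 10 characters are skipped outright.
    kept = [p for p in paragraphs if len(p.strip()) >= 10]   # drops "Hi."

    # Guard 2: a chunk is only emitted if it is longer than 20 characters.
    emitted = [p for p in kept if len(p.strip()) > 20]       # drops "Short but okay."

    print(emitted)  # only the long paragraph survives both checks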
 
@@ -73,57 +99,109 @@ def retrieve_similar_chunks(query, index, chunks_with_metadata, embedding_model,
     query_embedding = embedding_model.encode([query], convert_to_tensor=True).cpu().numpy()
     distances, indices = index.search(query_embedding, k)
 
-    # Ensure indices are within bounds of chunks_with_metadata
-    valid_indices = [i for i in indices[0] if 0 <= i < len(chunks_with_metadata)]
+    # Ensure indices are within bounds and create mapping for correct distances
+    valid_results = []
+    for idx_pos, chunk_idx in enumerate(indices[0]):
+        if 0 <= chunk_idx < len(chunks_with_metadata):
+            chunk_text = chunks_with_metadata[chunk_idx]["text"][:max_chunk_length]
+            # Only include chunks with meaningful content
+            if chunk_text.strip():  # Skip empty chunks
+                valid_results.append((
+                    chunk_text,
+                    distances[0][idx_pos],  # Use original position for correct distance
+                    chunks_with_metadata[chunk_idx]["metadata"]
+                ))
 
-    return [
-        (chunks_with_metadata[i]["text"][:max_chunk_length], distances[0][j], chunks_with_metadata[i]["metadata"])
-        for j, i in enumerate(valid_indices)  # Use valid_indices
-    ]
+    return valid_results
 
 
 def create_vector_search_tool(faiss_index, document_chunks_with_metadata, embedding_model, k=3, max_chunk_length=1000):
     @tool
     def vector_database_search(query: str) -> str:
+        """Search the uploaded PDF document for information related to the query.
+
+        Args:
+            query: The search query string to find relevant information in the document.
+
+        Returns:
+            A string containing relevant information found in the document.
         """
-        Searches the currently uploaded PDF document for information semantically similar to the query.
-        Use this tool when the user's question is likely answerable from the content of the specific document they provided.
-        Input should be the search query.
-        """
-        # Retrieve similar chunks using the provided session-specific components
-        similar_chunks_data = retrieve_similar_chunks(
-            query,
-            faiss_index,
-            document_chunks_with_metadata,  # This is the list of dicts {text: ..., metadata: ...}
-            embedding_model,
-            k=k,
-            max_chunk_length=max_chunk_length
-        )
-        # Format the response
-        if not similar_chunks_data:
-            return "No relevant information found in the document for that query."
+        # Handle very short or empty queries
+        if not query or len(query.strip()) < 3:
+            return "Please provide a more specific search query with at least 3 characters."
 
-        context = "\n\n---\n\n".join([chunk_text for chunk_text, _, _ in similar_chunks_data])
-        return f"The following information was found in the document regarding '{query}':\n{context}"
+        try:
+            # Retrieve similar chunks using the provided session-specific components
+            similar_chunks_data = retrieve_similar_chunks(
+                query,
+                faiss_index,
+                document_chunks_with_metadata,  # This is the list of dicts {text: ..., metadata: ...}
+                embedding_model,
+                k=k,
+                max_chunk_length=max_chunk_length
+            )
+
+            # Format the response
+            if not similar_chunks_data:
+                return "No relevant information found in the document for that query. Please try rephrasing your question or using different keywords."
+
+            # Filter out chunks with very high distance (low similarity)
+            filtered_chunks = [chunk for chunk in similar_chunks_data if chunk[1] < 1.5]  # Adjust threshold as needed
+
+            if not filtered_chunks:
+                return "No sufficiently relevant information found in the document for that query. Please try rephrasing your question or using different keywords."
+
+            context = "\n\n---\n\n".join([chunk_text for chunk_text, _, _ in filtered_chunks])
+            return f"The following information was found in the document regarding '{query}':\n{context}"
+
+        except Exception as e:
+            print(f"Error in vector search tool: {e}")
+            return f"Error searching the document: {str(e)}"
 
     return vector_database_search
 
-def agentic_rag(llm, agent_specific_tools, query, context_chunks, memory, Use_Tavily=False):  # Renamed 'tools' to 'agent_specific_tools'
+def agentic_rag(llm, agent_specific_tools, query, context_chunks, memory, Use_Tavily=False):
+    # Validate inputs
+    if not query or not query.strip():
+        return {"output": "Please provide a valid question."}
+
+    if not agent_specific_tools:
+        print("Warning: No tools provided, using direct LLM response")
+        # Use direct LLM call without agent if no tools
+        fallback_prompt = ChatPromptTemplate.from_messages([
+            ("system", "You are a helpful assistant that answers questions about documents. Use the provided context to answer the user's question."),
+            ("human", "Context: {context}\n\nQuestion: {input}")
+        ])
+        try:
+            formatted_prompt = fallback_prompt.format_prompt(context="No context available", input=query).to_messages()
+            response = llm.invoke(formatted_prompt)
+            return {"output": response.content if hasattr(response, 'content') else str(response)}
+        except Exception as e:
+            print(f"Direct LLM call failed: {e}")
+            return {"output": "I'm sorry, I encountered an error processing your request."}
+
+    print(f"Available tools: {[tool.name for tool in agent_specific_tools]}")
+
     # Sort chunks by relevance (lower distance = more relevant)
     context_chunks = sorted(context_chunks, key=lambda x: x[1]) if context_chunks else []
     context = ""
     total_tokens = 0
     max_tokens = 7000  # Leave room for prompt and response
 
-    for chunk, _, _ in context_chunks:
-        chunk_tokens = estimate_tokens(chunk)
-        if total_tokens + chunk_tokens <= max_tokens:
-            context += chunk + "\n\n"
-            total_tokens += chunk_tokens
-        else:
-            break
+    # Filter out chunks with very high distance scores (low similarity)
+    relevant_chunks = [chunk for chunk in context_chunks if len(chunk) >= 3 and chunk[1] < 1.5]
+
+    for chunk, _, _ in relevant_chunks:
+        if chunk and chunk.strip():  # Ensure chunk has content
+            chunk_tokens = estimate_tokens(chunk)
+            if total_tokens + chunk_tokens <= max_tokens:
+                context += chunk + "\n\n"
+                total_tokens += chunk_tokens
+            else:
+                break
 
     context = context.strip() if context else "No initial context provided from preliminary search."
+    print(f"Using context length: {len(context)} characters")
 
 
     # Dynamically build the tool guidance for the prompt
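
retrieve_similar_chunks and create_vector_search_tool expect a FAISS index built over embeddings of the same chunk list. A rough sketch of that setup, assuming a sentence-transformers model and a flat L2 index (neither the embedding model nor the index type is shown in this diff, and documents stands in for whatever the PDF loader produced):

    import faiss
    import numpy as np
    from sentence_transformers import SentenceTransformer

    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed embedding model
    chunks = chunk_text(documents)  # -> [{"text": ..., "metadata": ...}, ...]

    vectors = embedding_model.encode([c["text"] for c in chunks], convert_to_tensor=True).cpu().numpy()
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(np.asarray(vectors, dtype="float32"))

    # Preliminary retrieval plus a document-search tool bound to this session's index.
    context_chunks = retrieve_similar_chunks("What does the report conclude?", index, chunks, embedding_model, k=3)
    doc_tool = create_vector_search_tool(index, chunks, embedding_model, k=3)
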
@@ -131,70 +209,168 @@ def agentic_rag(llm, agent_specific_tools, query, context_chunks, memory, Use_Ta
     has_document_search = any(t.name == "vector_database_search" for t in agent_specific_tools)
     has_web_search = any(t.name == "tavily_search_results_json" for t in agent_specific_tools)
 
-    guidance_parts = []
+    # Simplified tool guidance
+    tool_instructions = ""
     if has_document_search:
-        guidance_parts.append(
-            "If the direct context (if any from preliminary search) is insufficient and the question seems answerable from the uploaded document, "
-            "use the 'vector_database_search' tool to find relevant information within the document."
-        )
-    if has_web_search:  # Tavily tool would only be in agent_specific_tools if Use_Tavily was true
-        guidance_parts.append(
-            "If the information is not found in the document (after using 'vector_database_search' if appropriate) "
-            "or the question is of a general nature not specific to the document, "
-            "use the 'tavily_search_results_json' tool for web searches."
-        )
-
-    if not guidance_parts:
-        search_behavior_instructions = "If the context is insufficient, you *must* state that you don't know."
-    else:
-        search_behavior_instructions = " ".join(guidance_parts)
-        search_behavior_instructions += ("\n    * If, after all steps and tool use (if any), you cannot find an answer, "
-                                         "respond with: \"Based on the available information, I don't know the answer.\"")
+        tool_instructions += "Use vector_database_search to find information in the uploaded document. "
+    if has_web_search:
+        tool_instructions += "Use tavily_search_results_json for web searches when document search is insufficient. "
+
+    if not tool_instructions:
+        tool_instructions = "Answer based on the provided context only. "
 
     prompt = ChatPromptTemplate.from_messages([
-        ("system", f"""
-        You are an expert Q&A system. Your primary function is to answer questions using a given set of documents (Context) and available tools.
-
-        **Your Process:**
-
-        1. **Analyze the Question:** Understand exactly what the user is asking.
-        2. **Scan the Context:** Thoroughly review the 'Context' provided (if any) to find relevant information. This context is derived from a preliminary similarity search in the document.
-        3. **Formulate the Answer:**
-            * If the initially provided context contains a clear answer, synthesize it into a concise response. Start your answer with "Based on the Document, ...".
-            * {search_behavior_instructions}
-            * When using the 'vector_database_search' tool, the information comes from the document. Prepend your answer with "Based on the Document, ...".
-            * When using the 'tavily_search_results_json' tool, the information comes from the web. Prepend your answer with "According to a web search, ...". If no useful information is found, state that.
-        4. **Clarity:** Ensure your final answer is clear, direct, and avoids jargon if possible.
-
-        **Important Rules:**
-
-        * **Stick to Sources:** Do *not* use any information outside of the provided 'Context', document search results ('vector_database_search'), or web search results ('tavily_search_results_json').
-        * **No Speculation:** Do not make assumptions or infer information not explicitly present.
-        * **Cite Sources (If Web Searching):** If you use the 'tavily_search_results_json' tool and it provides source links, you MUST include them in your response.
-        """),
-        ("human", "Context: {{context}}\n\nQuestion: {{input}}"),  # Double braces for f-string in f-string
+        ("system", f"""You are a helpful AI assistant that answers questions about documents.
+
+Context: {{context}}
+
+Tools available: {tool_instructions}
+
+Instructions:
+- Use the provided context first
+- If context is insufficient, use available tools to search for more information
+- Provide clear, helpful answers
+- If you cannot find an answer, say so clearly"""),
+        ("human", "{input}"),
         MessagesPlaceholder(variable_name="chat_history"),
         MessagesPlaceholder(variable_name="agent_scratchpad"),
     ])
 
     try:
+        print(f"Creating agent with {len(agent_specific_tools)} tools")
+
+        # Validate that tools are properly formatted
+        for tool in agent_specific_tools:
+            print(f"Tool: {tool.name} - {type(tool)}")
+            # Ensure tool has required attributes
+            if not hasattr(tool, 'name') or not hasattr(tool, 'description'):
+                raise ValueError(f"Tool {tool} is missing required attributes")
+
         agent = create_tool_calling_agent(llm, agent_specific_tools, prompt)
-        agent_executor = AgentExecutor(agent=agent, tools=agent_specific_tools, memory=memory, verbose=True)
-        response_payload = agent_executor.invoke({
+        agent_executor = AgentExecutor(
+            agent=agent,
+            tools=agent_specific_tools,
+            memory=memory,
+            verbose=True,
+            handle_parsing_errors=True,
+            max_iterations=2,  # Reduced further to prevent issues
+            return_intermediate_steps=False,
+            early_stopping_method="generate"
+        )
+
+        print(f"Invoking agent with query: '{query}' and context length: {len(context)} chars")
+
+        # Create input with simpler structure
+        agent_input = {
             "input": query,
             "context": context,
-        })
-        return response_payload  # Expecting dict like {'output': '...'}
+        }
+
+        response_payload = agent_executor.invoke(agent_input)
+
+        print(f"Agent response keys: {response_payload.keys() if response_payload else 'None'}")
+
+        # Extract and validate the output
+        agent_output = response_payload.get("output", "") if response_payload else ""
+        print(f"Agent output length: {len(agent_output)} chars")
+        print(f"Agent output preview: {agent_output[:100]}..." if len(agent_output) > 100 else f"Agent output: {agent_output}")
+
+        # Validate response quality
+        if not agent_output or len(agent_output.strip()) < 10:
+            print(f"Warning: Agent returned insufficient response (length: {len(agent_output)}), using fallback")
+            raise ValueError("Insufficient response from agent")
+
+        # Check if response is just a prefix without content
+        problematic_prefixes = [
+            "Based on the Document,",
+            "According to a web search,",
+            "Based on the available information,",
+            "I need to",
+            "Let me"
+        ]
+
+        stripped_output = agent_output.strip()
+        if any(stripped_output == prefix.strip() or stripped_output == prefix.strip() + "." for prefix in problematic_prefixes):
+            print(f"Warning: Agent returned only prefix without content: '{stripped_output}', using fallback")
+            raise ValueError("Agent returned incomplete response")
+
+        return response_payload
     except Exception as e:
-        print(f"Error during agent execution: {str(e)} \nTraceback: {traceback.format_exc()}")
+        error_msg = str(e)
+        print(f"Error during agent execution: {error_msg} \nTraceback: {traceback.format_exc()}")
+
+        # Check if it's a specific Groq function calling error
+        if "Failed to call a function" in error_msg or "function" in error_msg.lower():
+            print("Detected Groq function calling error, trying simpler approach...")
+
+            # Try with a simpler agent setup or direct LLM call
+            try:
+                # First, try to use tools individually without agent framework
+                if agent_specific_tools:
+                    print("Attempting manual tool usage...")
+                    tool_results = []
+
+                    # Try vector search first if available
+                    vector_tool = next((t for t in agent_specific_tools if t.name == "vector_database_search"), None)
+                    if vector_tool:
+                        try:
+                            search_result = vector_tool.run(query)
+                            if search_result and "No relevant information" not in search_result:
+                                tool_results.append(f"Document Search: {search_result}")
+                        except Exception as tool_error:
+                            print(f"Vector tool error: {tool_error}")
+
+                    # Try web search if needed and available
+                    if Use_Tavily:
+                        web_tool = next((t for t in agent_specific_tools if t.name == "tavily_search_results_json"), None)
+                        if web_tool:
+                            try:
+                                web_result = web_tool.run(query)
+                                if web_result:
+                                    tool_results.append(f"Web Search: {web_result}")
+                            except Exception as tool_error:
+                                print(f"Web tool error: {tool_error}")
+
+                    # Combine tool results with context
+                    enhanced_context = context
+                    if tool_results:
+                        enhanced_context += "\n\nAdditional Information:\n" + "\n\n".join(tool_results)
+
+                    # Use direct LLM call with enhanced context
+                    direct_prompt = ChatPromptTemplate.from_messages([
+                        ("system", "You are a helpful assistant. Use the provided context and information to answer the user's question clearly and completely."),
+                        ("human", "Context and Information: {context}\n\nQuestion: {input}")
+                    ])
+
+                    formatted_prompt = direct_prompt.format_prompt(context=enhanced_context, input=query).to_messages()
+                    response = llm.invoke(formatted_prompt)
+                    direct_output = response.content if hasattr(response, 'content') else str(response)
+                    print(f"Direct tool usage response length: {len(direct_output)} chars")
+                    return {"output": direct_output}
+
+            except Exception as manual_error:
+                print(f"Manual tool usage also failed: {manual_error}")
+
+        print("Using fallback direct LLM response...")
+
         fallback_prompt_template = ChatPromptTemplate.from_messages([
-            ("system", "You are a helpful assistant. Use the provided context to answer the user's question. If the context is insufficient, say you don't know."),
+            ("system", """You are a helpful assistant that answers questions about documents.
+            Use the provided context to answer the user's question.
+            If the context contains relevant information, start your answer with "Based on the Document, ..."
+            If the context is insufficient, clearly state what you don't know."""),
             ("human", "Context: {context}\n\nQuestion: {input}")
         ])
-        # Format the prompt with the actual context and query
-        formatted_fallback_prompt = fallback_prompt_template.format_prompt(context=context, input=query).to_messages()
-        response = llm.invoke(formatted_fallback_prompt)
-        return {"output": response.content if hasattr(response, 'content') else str(response)}
+
+        try:
+            # Format the prompt with the actual context and query
+            formatted_fallback_prompt = fallback_prompt_template.format_prompt(context=context, input=query).to_messages()
+            response = llm.invoke(formatted_fallback_prompt)
+            fallback_output = response.content if hasattr(response, 'content') else str(response)
+            print(f"Fallback response length: {len(fallback_output)} chars")
+            return {"output": fallback_output}
+        except Exception as fallback_error:
+            print(f"Fallback also failed: {str(fallback_error)}")
+            return {"output": "I'm sorry, I encountered an error processing your request. Please try again."}
 
 """if __name__ == "__main__":
     # Process PDF and prepare index
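
Building on the sketches above, a session-level call into the refactored agentic_rag might look roughly like this (the surrounding session plumbing lives in the application code, not in preprocessing.py):

    result = agentic_rag(
        llm,
        [doc_tool] + session_tools,   # vector_database_search plus Tavily when enabled
        "Summarize the key findings of the uploaded document.",
        context_chunks,               # (text, distance, metadata) tuples from retrieve_similar_chunks
        session_memory,
        Use_Tavily=bool(session_tools),
    )
    print(result["output"])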
 