Jatin Mehra committed
Commit 24b32e6 · 1 Parent(s): 63ed7c1

Refactor session management and enhance PDF upload handling for improved error handling and model consistency

Files changed (1)
  1. app.py +101 -88
app.py CHANGED
@@ -18,9 +18,11 @@ from preprocessing import (
     build_faiss_index,
     retrieve_similar_chunks,
     agentic_rag,
-    tools
+    tools as global_base_tools,
+    create_vector_search_tool
 )
 from sentence_transformers import SentenceTransformer
+from langchain.memory import ConversationBufferMemory
 
 # Load environment variables
 dotenv.load_dotenv()
@@ -57,50 +59,44 @@ class SessionRequest(BaseModel):
 
 # Function to save session data
 def save_session(session_id, data):
-    sessions[session_id] = data
+    sessions[session_id] = data  # Keep non-picklable in memory for active session
 
-    # Create a copy of data that is safe to pickle
     pickle_safe_data = {
         "file_path": data.get("file_path"),
         "file_name": data.get("file_name"),
-        "chunks": data.get("chunks"),
+        "chunks": data.get("chunks"),  # Chunks with metadata (list of dicts)
         "chat_history": data.get("chat_history", [])
+        # FAISS index, embedding model, and LLM model are not pickled, will be reloaded/recreated
     }
 
-    # Persist to disk
     with open(f"{UPLOAD_DIR}/{session_id}_session.pkl", "wb") as f:
         pickle.dump(pickle_safe_data, f)
 
+
 # Function to load session data
-def load_session(session_id, model_name="meta-llama/llama-4-scout-17b-16e-instruct"):
+def load_session(session_id, model_name="llama3-8b-8192"):  # Ensure model_name matches default
     try:
-        # Check if session is already in memory
         if session_id in sessions:
-            # Ensure the LLM in the cached session matches the requested model_name
-            # If not, update it. This handles cases where model_name might change for an existing session.
-            if sessions[session_id].get("llm") is None or sessions[session_id]["llm"].model_name != model_name:
-                try:
-                    sessions[session_id]["llm"] = model_selection(model_name)
-                except Exception as e:
-                    print(f"Error updating LLM for in-memory session {session_id} to {model_name}: {str(e)}")
-                    # Decide if this is a critical error; for now, we'll proceed with the old LLM or handle as error
-                    # For simplicity, if LLM update fails, we might want to indicate session load failure or use existing.
-                    # Here, we'll let it proceed, but this could be a point of further refinement.
-            return sessions[session_id], True
+            cached_session = sessions[session_id]
+            # Ensure LLM and potentially other non-pickled parts are up-to-date or loaded
+            if cached_session.get("llm") is None or (hasattr(cached_session["llm"], "model_name") and cached_session["llm"].model_name != model_name):
+                cached_session["llm"] = model_selection(model_name)
+            if cached_session.get("model") is None:  # Embedding model
+                cached_session["model"] = SentenceTransformer('BAAI/bge-large-en-v1.5')
+            if cached_session.get("index") is None and cached_session.get("chunks"):  # FAISS index
+                embeddings, _ = create_embeddings(cached_session["chunks"], cached_session["model"])
+                cached_session["index"] = build_faiss_index(embeddings)
+            return cached_session, True
 
-        # Try to load from disk
         file_path_pkl = f"{UPLOAD_DIR}/{session_id}_session.pkl"
         if os.path.exists(file_path_pkl):
             with open(file_path_pkl, "rb") as f:
-                data = pickle.load(f)  # This is pickle_safe_data
+                data = pickle.load(f)
 
-            # Recreate non-pickled objects
-            # Ensure 'chunks' and 'file_path' (for the original PDF) are present in the loaded data
-            # and the original PDF file still exists.
             original_pdf_path = data.get("file_path")
             if data.get("chunks") and original_pdf_path and os.path.exists(original_pdf_path):
                 embedding_model_instance = SentenceTransformer('BAAI/bge-large-en-v1.5')
-                # data["chunks"] is already the list of dicts: {text: ..., metadata: ...}
+                # Chunks are already {text: ..., metadata: ...}
                 recreated_embeddings, _ = create_embeddings(data["chunks"], embedding_model_instance)
                 recreated_index = build_faiss_index(recreated_embeddings)
                 recreated_llm = model_selection(model_name)
@@ -108,25 +104,23 @@ def load_session(session_id, model_name="meta-llama/llama-4-scout-17b-16e-instru
                 full_session_data = {
                     "file_path": original_pdf_path,
                     "file_name": data.get("file_name"),
-                    "chunks": data.get("chunks"),  # These are chunks_with_metadata
+                    "chunks": data.get("chunks"),  # chunks_with_metadata
                     "chat_history": data.get("chat_history", []),
                     "model": embedding_model_instance,  # SentenceTransformer model
                     "index": recreated_index,  # FAISS index
                     "llm": recreated_llm  # LLM
                 }
-                sessions[session_id] = full_session_data  # Store in memory cache
+                sessions[session_id] = full_session_data
                 return full_session_data, True
             else:
-                # If essential data for reconstruction is missing from pickle or the original PDF is gone
-                print(f"Warning: Session data for {session_id} is incomplete or its PDF file '{original_pdf_path}' is missing. Cannot reconstruct session.")
-                # Optionally, remove the stale .pkl file
-                # os.remove(file_path_pkl)
+                print(f"Warning: Session data for {session_id} is incomplete or PDF missing. Cannot reconstruct.")
+                if os.path.exists(file_path_pkl): os.remove(file_path_pkl)  # Clean up stale pkl
                 return None, False
 
-        return None, False  # Session not in memory and not found on disk, or reconstruction failed
+        return None, False
     except Exception as e:
         print(f"Error loading session {session_id}: {str(e)}")
-        print(traceback.format_exc())  # Print full traceback for debugging
+        print(traceback.format_exc())
         return None, False
 
 # Function to remove PDF file
@@ -168,44 +162,46 @@ async def read_root():
 @app.post("/upload-pdf")
 async def upload_pdf(
     file: UploadFile = File(...),
-    model_name: str = Form("meta-llama/llama-4-scout-17b-16e-instruct")
+    model_name: str = Form("llama3-8b-8192")  # Default model
 ):
-    # Generate a unique session ID
     session_id = str(uuid.uuid4())
     file_path = None
 
     try:
-        # Save the uploaded file
        file_path = f"{UPLOAD_DIR}/{session_id}_{file.filename}"
         with open(file_path, "wb") as buffer:
             shutil.copyfileobj(file.file, buffer)
 
-        # Check if API keys are set
-        if not os.getenv("GROQ_API_KEY"):
-            raise ValueError("GROQ_API_KEY is not set in the environment variables")
-
-        # Process the PDF
-        documents = process_pdf_file(file_path)  # Returns list of Document objects
-        chunks = chunk_text(documents, max_length=1500)  # Updated to handle documents
-
-        # Create embeddings
-        model = SentenceTransformer('BAAI/bge-large-en-v1.5')  # Updated embedding model
-        embeddings, chunks_with_metadata = create_embeddings(chunks, model)  # Unpack tuple
+        if not os.getenv("GROQ_API_KEY") and "llama" in model_name:  # Llama specific check for Groq
+            raise ValueError("GROQ_API_KEY is not set for Groq Llama models.")
+        if not os.getenv("TAVILY_API_KEY"):  # Needed for TavilySearchResults
+            print("Warning: TAVILY_API_KEY is not set. Web search will not function.")
+
+        documents = process_pdf_file(file_path)
+        # Ensure max_length for chunk_text is appropriate.
+        # The value 1500 might be too large if estimate_tokens is text_len // 4, as it means ~6000 characters.
+        # Let's use a smaller max_length for chunks for better granularity in RAG retrieval.
+        # For `bge-large-en-v1.5` (max sequence length 512 tokens), chunks around 250-400 tokens are often good.
+        # If estimate_tokens is len(text)//4, then max_length of 250 tokens is roughly 1000 characters.
+        # Let's use max_length=256 (tokens) for chunker config, so about 1024 characters.
+        # The chunk_text function uses max_length as character count / 4. So if we want 256 tokens, max_length = 256*4 = 1024
+        # However, the current chunk_text logic is `estimate_tokens(current_chunk + paragraph) <= max_length // 4`.
+        # This means `max_length` is already considered a token limit. So `max_length=256` (tokens) is the target.
+        chunks_with_metadata = chunk_text(documents, max_length=256)  # max_length in tokens
 
-        # Build FAISS index
-        index = build_faiss_index(embeddings)  # Pass only embeddings array
+        embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5')
+        embeddings, _ = create_embeddings(chunks_with_metadata, embedding_model)  # Chunks are already with metadata
 
-        # Initialize LLM
+        index = build_faiss_index(embeddings)
         llm = model_selection(model_name)
 
-        # Save session data
         session_data = {
             "file_path": file_path,
             "file_name": file.filename,
-            "chunks": chunks_with_metadata,  # Store chunks with metadata
-            "model": model,
-            "index": index,
-            "llm": llm,
+            "chunks": chunks_with_metadata,  # Store chunks with metadata
+            "model": embedding_model,  # SentenceTransformer instance
+            "index": index,  # FAISS index instance
+            "llm": llm,  # LLM instance
             "chat_history": []
         }
         save_session(session_id, session_data)
@@ -213,71 +209,89 @@ async def upload_pdf(
         return {"status": "success", "session_id": session_id, "message": f"Processed {file.filename}"}
 
     except Exception as e:
-        # Clean up on error
         if file_path and os.path.exists(file_path):
             os.remove(file_path)
-
         error_msg = str(e)
         stack_trace = traceback.format_exc()
-        print(f"Error processing PDF: {error_msg}")
-        print(f"Stack trace: {stack_trace}")
-
+        print(f"Error processing PDF: {error_msg}\nStack trace: {stack_trace}")
         return JSONResponse(
-            status_code=400,
-            content={
-                "status": "error",
-                "detail": error_msg,
-                "type": type(e).__name__
-            }
+            status_code=500,  # Internal server error for processing issues
+            content={"status": "error", "detail": error_msg, "type": type(e).__name__}
         )
 
 # Route to chat with the document
 @app.post("/chat")
 async def chat(request: ChatRequest):
-    # Try to load session if not in memory
     session, found = load_session(request.session_id, model_name=request.model_name)
     if not found:
-        raise HTTPException(status_code=404, detail="Session not found. Please upload a document first.")
+        raise HTTPException(status_code=404, detail="Session not found or expired. Please upload a document first.")
 
     try:
-        from langchain.memory import ConversationBufferMemory
-        agent_memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
-
+        # Per-request memory to ensure chat history is correctly loaded for the agent
+        agent_memory = ConversationBufferMemory(memory_key="chat_history", input_key="input", return_messages=True)
         for entry in session.get("chat_history", []):
             agent_memory.chat_memory.add_user_message(entry["user"])
             agent_memory.chat_memory.add_ai_message(entry["assistant"])
-
 
-        # Retrieve similar chunks
-        similar_chunks = retrieve_similar_chunks(
+        # Prepare tools for the agent for THIS request
+        current_request_tools = []
+
+        # 1. Add the document-specific vector search tool
+        if "index" in session and "chunks" in session and "model" in session:
+            vector_search_tool_instance = create_vector_search_tool(
+                faiss_index=session["index"],
+                document_chunks_with_metadata=session["chunks"],  # Pass the correct variable
+                embedding_model=session["model"]  # This is the SentenceTransformer model
+            )
+            current_request_tools.append(vector_search_tool_instance)
+        else:
+            print(f"Warning: Session {request.session_id} missing data for vector_database_search tool.")
+
+        # 2. Conditionally add Tavily (web search) tool
+        if request.use_search:
+            if os.getenv("TAVILY_API_KEY"):
+                tavily_tool = next((tool for tool in global_base_tools if tool.name == "tavily_search_results_json"), None)
+                if tavily_tool:
+                    current_request_tools.append(tavily_tool)
+                else:  # Should not happen if global_base_tools is defined correctly
+                    print("Warning: Tavily search requested, but tool misconfigured.")
+            else:
+                print("Warning: Tavily search requested, but TAVILY_API_KEY is not set.")
+
+        # Retrieve initial similar chunks for RAG context (can be empty if no good match)
+        # This context is given to the agent *before* it decides to use tools.
+        # k=5 means we retrieve up to 5 chunks for initial context.
+        # The agent can then use `vector_database_search` to search more if needed.
+        initial_similar_chunks = retrieve_similar_chunks(
             request.query,
             session["index"],
-            session["chunks"],
-            session["model"],
-            k=10
+            session["chunks"],  # list of dicts {text:..., metadata:...}
+            session["model"],  # SentenceTransformer model
+            k=5  # Number of chunks for initial context
        )
 
-        # Generate response using agentic_rag
         response = agentic_rag(
             session["llm"],
-            tools,
+            current_request_tools,  # Pass the dynamically assembled list of tools
             query=request.query,
-            context_chunks=similar_chunks,  # Pass the list of tuples
-            Use_Tavily=request.use_search,
+            context_chunks=initial_similar_chunks,
+            Use_Tavily=request.use_search,  # Still passed to agentic_rag for potential fine-grained logic, though prompt adapts to tools
            memory=agent_memory
         )
 
-        # Update chat history
-        session["chat_history"].append({"user": request.query, "assistant": response["output"]})
-        save_session(request.session_id, session)
+        response_output = response.get("output", "Sorry, I could not generate a response.")
+        session["chat_history"].append({"user": request.query, "assistant": response_output})
+        save_session(request.session_id, session)  # Save updated history and potentially other modified session state
 
         return {
             "status": "success",
-            "answer": response["output"],
-            "context_used": [{"text": chunk, "score": float(score)} for chunk, score, _ in similar_chunks]
+            "answer": response_output,
+            # Return context that was PRE-FETCHED for the agent, not necessarily all context it might have used via tools
+            "context_used": [{"text": chunk, "score": float(score), "metadata": meta} for chunk, score, meta in initial_similar_chunks]
         }
-
+
     except Exception as e:
+        print(f"Error processing chat query: {str(e)}\nTraceback: {traceback.format_exc()}")
         raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
 
 
@@ -330,5 +344,4 @@ async def get_models():
 
 # Run the application if this file is executed directly
 if __name__ == "__main__":
-    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
-
+    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
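
For reference, a minimal client sketch of the refactored flow in this commit: upload a PDF, then chat against it. This is only an illustrative example, not part of the commit. The base URL assumes the local dev server started by the __main__ block, "example.pdf" is a hypothetical file, and the request fields (model_name, session_id, query, use_search) are the ones visible in the diff above.

import requests

BASE_URL = "http://localhost:8000"  # assumption: server run locally via `python app.py` (uvicorn on port 8000)

# 1. Upload a PDF; the server chunks it, builds embeddings and a FAISS index, and returns a session_id.
with open("example.pdf", "rb") as f:  # hypothetical local file
    upload = requests.post(
        f"{BASE_URL}/upload-pdf",
        files={"file": ("example.pdf", f, "application/pdf")},
        data={"model_name": "llama3-8b-8192"},  # new default model in this commit
    )
upload.raise_for_status()
session_id = upload.json()["session_id"]

# 2. Chat against the uploaded document; use_search=True would also enable the Tavily web-search tool.
chat = requests.post(
    f"{BASE_URL}/chat",
    json={
        "session_id": session_id,
        "query": "Summarize the document.",
        "model_name": "llama3-8b-8192",
        "use_search": False,
    },
)
chat.raise_for_status()
print(chat.json()["answer"])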