Commit 24b32e6 · Jatin Mehra committed · Parent(s): 63ed7c1
Refactor session management and enhance PDF upload handling for improved error handling and model consistency
app.py
CHANGED
@@ -18,9 +18,11 @@ from preprocessing import (
     build_faiss_index,
     retrieve_similar_chunks,
     agentic_rag,
-    tools
+    tools as global_base_tools,
+    create_vector_search_tool
 )
 from sentence_transformers import SentenceTransformer
+from langchain.memory import ConversationBufferMemory
 
 # Load environment variables
 dotenv.load_dotenv()

@@ -57,50 +59,44 @@ class SessionRequest(BaseModel):
 
 # Function to save session data
 def save_session(session_id, data):
-    sessions[session_id] = data
+    sessions[session_id] = data  # Keep non-picklable in memory for active session
 
-    # Create a copy of data that is safe to pickle
     pickle_safe_data = {
         "file_path": data.get("file_path"),
         "file_name": data.get("file_name"),
-        "chunks": data.get("chunks"),
+        "chunks": data.get("chunks"),  # Chunks with metadata (list of dicts)
         "chat_history": data.get("chat_history", [])
+        # FAISS index, embedding model, and LLM model are not pickled, will be reloaded/recreated
     }
 
-    # Persist to disk
     with open(f"{UPLOAD_DIR}/{session_id}_session.pkl", "wb") as f:
         pickle.dump(pickle_safe_data, f)
 
+
 # Function to load session data
-def load_session(session_id, model_name="meta-llama/llama-4-scout-17b-16e-instruct"):
+def load_session(session_id, model_name="llama3-8b-8192"):  # Ensure model_name matches default
     try:
-        # Check if session is already in memory
         if session_id in sessions:
-            return sessions[session_id], True
+            cached_session = sessions[session_id]
+            # Ensure LLM and potentially other non-pickled parts are up-to-date or loaded
+            if cached_session.get("llm") is None or (hasattr(cached_session["llm"], "model_name") and cached_session["llm"].model_name != model_name):
+                cached_session["llm"] = model_selection(model_name)
+            if cached_session.get("model") is None:  # Embedding model
+                cached_session["model"] = SentenceTransformer('BAAI/bge-large-en-v1.5')
+            if cached_session.get("index") is None and cached_session.get("chunks"):  # FAISS index
+                embeddings, _ = create_embeddings(cached_session["chunks"], cached_session["model"])
+                cached_session["index"] = build_faiss_index(embeddings)
+            return cached_session, True
 
-        # Try to load from disk
         file_path_pkl = f"{UPLOAD_DIR}/{session_id}_session.pkl"
         if os.path.exists(file_path_pkl):
             with open(file_path_pkl, "rb") as f:
-                data = pickle.load(f)
+                data = pickle.load(f)
 
-            # Recreate non-pickled objects
-            # Ensure 'chunks' and 'file_path' (for the original PDF) are present in the loaded data
-            # and the original PDF file still exists.
             original_pdf_path = data.get("file_path")
             if data.get("chunks") and original_pdf_path and os.path.exists(original_pdf_path):
                 embedding_model_instance = SentenceTransformer('BAAI/bge-large-en-v1.5')
+                # Chunks are already {text: ..., metadata: ...}
                 recreated_embeddings, _ = create_embeddings(data["chunks"], embedding_model_instance)
                 recreated_index = build_faiss_index(recreated_embeddings)
                 recreated_llm = model_selection(model_name)

@@ -108,25 +104,23 @@ def load_session(session_id, model_name="meta-llama/llama-4-scout-17b-16e-instruct"):
                 full_session_data = {
                     "file_path": original_pdf_path,
                     "file_name": data.get("file_name"),
-                    "chunks": data.get("chunks"),
+                    "chunks": data.get("chunks"),  # chunks_with_metadata
                     "chat_history": data.get("chat_history", []),
                     "model": embedding_model_instance,  # SentenceTransformer model
                     "index": recreated_index,  # FAISS index
                     "llm": recreated_llm  # LLM
                 }
-                sessions[session_id] = full_session_data
+                sessions[session_id] = full_session_data
                 return full_session_data, True
             else:
-                # Optionally, remove the stale .pkl file
-                # os.remove(file_path_pkl)
+                print(f"Warning: Session data for {session_id} is incomplete or PDF missing. Cannot reconstruct.")
+                if os.path.exists(file_path_pkl): os.remove(file_path_pkl)  # Clean up stale pkl
                 return None, False
 
-        return None, False
+        return None, False
     except Exception as e:
         print(f"Error loading session {session_id}: {str(e)}")
-        print(traceback.format_exc())
+        print(traceback.format_exc())
         return None, False
 
 # Function to remove PDF file
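The save/load pair above defines a persistence contract: only plain data (file path, file name, chunks, chat history) is pickled, while the FAISS index, the SentenceTransformer embedding model, and the LLM are rebuilt on demand. Below is a minimal sketch of that round trip, not part of the commit, assuming the functions defined in app.py above, a GROQ_API_KEY for model_selection(), and that the bge-large-en-v1.5 weights can be downloaded on first use.

# Sketch: session persistence round trip using the helpers shown above (illustrative only).
from app import save_session, load_session, UPLOAD_DIR

session_id = "demo"
save_session(session_id, {
    "file_path": f"{UPLOAD_DIR}/demo.pdf",  # placeholder; only checked on a cold (from-disk) load
    "file_name": "demo.pdf",
    "chunks": [{"text": "hello world", "metadata": {"page": 1}}],
    "chat_history": [],
})

# In-memory path: the embedding model, FAISS index, and LLM are created lazily here.
session, found = load_session(session_id, model_name="llama3-8b-8192")
assert found and {"index", "model", "llm"} <= session.keys()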
@@ -168,44 +162,46 @@ async def read_root():
 @app.post("/upload-pdf")
 async def upload_pdf(
     file: UploadFile = File(...),
-    model_name: str = Form("
+    model_name: str = Form("llama3-8b-8192")  # Default model
 ):
-    # Generate a unique session ID
     session_id = str(uuid.uuid4())
     file_path = None
 
     try:
-        # Save the uploaded file
         file_path = f"{UPLOAD_DIR}/{session_id}_{file.filename}"
         with open(file_path, "wb") as buffer:
             shutil.copyfileobj(file.file, buffer)
 
-        documents = process_pdf_file(file_path)
+        if not os.getenv("GROQ_API_KEY") and "llama" in model_name:  # Llama-specific check for Groq
+            raise ValueError("GROQ_API_KEY is not set for Groq Llama models.")
+        if not os.getenv("TAVILY_API_KEY"):  # Needed for TavilySearchResults
+            print("Warning: TAVILY_API_KEY is not set. Web search will not function.")
+
+        documents = process_pdf_file(file_path)
+        # Ensure max_length for chunk_text is appropriate. The previous value of 1500 is too large:
+        # with estimate_tokens roughly text_len // 4, it allows ~6000-character chunks, well beyond
+        # the 512-token window of bge-large-en-v1.5; chunks of around 250-400 tokens retrieve better.
+        # chunk_text compares estimate_tokens(current_chunk + paragraph) against max_length, i.e.
+        # max_length is already a token limit, so max_length=256 tokens (~1024 characters) is the target.
+        chunks_with_metadata = chunk_text(documents, max_length=256)  # max_length in tokens
 
+        embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5')
+        embeddings, _ = create_embeddings(chunks_with_metadata, embedding_model)  # Chunks are already with metadata
 
+        index = build_faiss_index(embeddings)
         llm = model_selection(model_name)
 
-        # Save session data
         session_data = {
             "file_path": file_path,
             "file_name": file.filename,
-            "chunks": chunks_with_metadata,
-            "model": 
-            "index": index,
-            "llm": llm,
+            "chunks": chunks_with_metadata,  # Store chunks with metadata
+            "model": embedding_model,  # SentenceTransformer instance
+            "index": index,  # FAISS index instance
+            "llm": llm,  # LLM instance
             "chat_history": []
         }
         save_session(session_id, session_data)
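To make the chunk-size reasoning in the comments above concrete, here is the arithmetic spelled out as a standalone sketch. estimate_tokens is assumed to be the simple length // 4 heuristic the comments refer to, which may differ from the actual implementation in preprocessing.py.

# Rough token arithmetic behind max_length=256 (assumed heuristic, not the actual preprocessing.py code).
def estimate_tokens(text: str) -> int:
    return len(text) // 4                  # ~4 characters per token

MAX_TOKENS = 256                           # chunk_text(documents, max_length=256)
APPROX_CHARS = MAX_TOKENS * 4              # ~1024 characters per chunk

BGE_MAX_SEQ_TOKENS = 512                   # BAAI/bge-large-en-v1.5 truncates beyond this
assert MAX_TOKENS <= BGE_MAX_SEQ_TOKENS    # chunks fit inside the embedding window

# The old setting of 1500 would have allowed ~6000-character chunks, so most of each
# chunk would fall outside the 512-token window and never be embedded.
print(APPROX_CHARS, estimate_tokens("x" * APPROX_CHARS))   # -> 1024 256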
@@ -213,71 +209,89 @@ async def upload_pdf(
         return {"status": "success", "session_id": session_id, "message": f"Processed {file.filename}"}
 
     except Exception as e:
-        # Clean up on error
         if file_path and os.path.exists(file_path):
             os.remove(file_path)
         error_msg = str(e)
         stack_trace = traceback.format_exc()
-        print(f"Error processing PDF: {error_msg}")
-        print(f"Stack trace: {stack_trace}")
-
+        print(f"Error processing PDF: {error_msg}\nStack trace: {stack_trace}")
         return JSONResponse(
-            status_code=
-            content={
-                "status": "error",
-                "detail": error_msg,
-                "type": type(e).__name__
-            }
+            status_code=500,  # Internal server error for processing issues
+            content={"status": "error", "detail": error_msg, "type": type(e).__name__}
         )
 
 # Route to chat with the document
 @app.post("/chat")
 async def chat(request: ChatRequest):
-    # Try to load session if not in memory
     session, found = load_session(request.session_id, model_name=request.model_name)
     if not found:
-        raise HTTPException(status_code=404, detail="Session not found. Please upload a document first.")
+        raise HTTPException(status_code=404, detail="Session not found or expired. Please upload a document first.")
 
     try:
-        agent_memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+        # Per-request memory to ensure chat history is correctly loaded for the agent
+        agent_memory = ConversationBufferMemory(memory_key="chat_history", input_key="input", return_messages=True)
         for entry in session.get("chat_history", []):
             agent_memory.chat_memory.add_user_message(entry["user"])
             agent_memory.chat_memory.add_ai_message(entry["assistant"])
 
+        # Prepare tools for the agent for THIS request
+        current_request_tools = []
+
+        # 1. Add the document-specific vector search tool
+        if "index" in session and "chunks" in session and "model" in session:
+            vector_search_tool_instance = create_vector_search_tool(
+                faiss_index=session["index"],
+                document_chunks_with_metadata=session["chunks"],  # Pass the correct variable
+                embedding_model=session["model"]  # This is the SentenceTransformer model
+            )
+            current_request_tools.append(vector_search_tool_instance)
+        else:
+            print(f"Warning: Session {request.session_id} missing data for vector_database_search tool.")
+
+        # 2. Conditionally add Tavily (web search) tool
+        if request.use_search:
+            if os.getenv("TAVILY_API_KEY"):
+                tavily_tool = next((tool for tool in global_base_tools if tool.name == "tavily_search_results_json"), None)
+                if tavily_tool:
+                    current_request_tools.append(tavily_tool)
+                else:  # Should not happen if global_base_tools is defined correctly
+                    print("Warning: Tavily search requested, but tool misconfigured.")
+            else:
+                print("Warning: Tavily search requested, but TAVILY_API_KEY is not set.")
+
+        # Retrieve initial similar chunks for RAG context (can be empty if no good match).
+        # This context is given to the agent *before* it decides to use tools.
+        # k=5 means we retrieve up to 5 chunks for initial context.
+        # The agent can then use vector_database_search to search more if needed.
+        initial_similar_chunks = retrieve_similar_chunks(
             request.query,
             session["index"],
-            session["chunks"],
-            session["model"],
-            k=
+            session["chunks"],  # list of dicts {text: ..., metadata: ...}
+            session["model"],  # SentenceTransformer model
+            k=5  # Number of chunks for initial context
        )
 
-        # Generate response using agentic_rag
         response = agentic_rag(
             session["llm"],
+            current_request_tools,  # Pass the dynamically assembled list of tools
             query=request.query,
-            context_chunks=
-            Use_Tavily=request.use_search,
+            context_chunks=initial_similar_chunks,
+            Use_Tavily=request.use_search,  # Still passed to agentic_rag for potential fine-grained logic, though prompt adapts to tools
             memory=agent_memory
         )
 
-        session["chat_history"].append({"user": request.query, "assistant":
-        save_session(request.session_id, session)
+        response_output = response.get("output", "Sorry, I could not generate a response.")
+        session["chat_history"].append({"user": request.query, "assistant": response_output})
+        save_session(request.session_id, session)  # Save updated history and potentially other modified session state
 
         return {
             "status": "success",
-            "answer":
+            "answer": response_output,
+            # Return context that was pre-fetched for the agent, not necessarily all context it might have used via tools
+            "context_used": [{"text": chunk, "score": float(score), "metadata": meta} for chunk, score, meta in initial_similar_chunks]
         }
 
     except Exception as e:
+        print(f"Error processing chat query: {str(e)}\nTraceback: {traceback.format_exc()}")
         raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
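The vector_database_search tool assembled above comes from create_vector_search_tool in preprocessing.py, which this diff does not show. Purely to illustrate the pattern (a per-session FAISS index wrapped as a LangChain Tool), here is a hypothetical sketch; the real signature and behaviour may differ.

# Hypothetical sketch only; the real create_vector_search_tool lives in preprocessing.py.
import numpy as np
from langchain.tools import Tool

def create_vector_search_tool(faiss_index, document_chunks_with_metadata, embedding_model, k: int = 5):
    def _search(query: str) -> str:
        # Embed the query with the session's SentenceTransformer and search the FAISS index.
        query_vec = np.asarray(embedding_model.encode([query]), dtype=np.float32)
        _, indices = faiss_index.search(query_vec, k)
        hits = [document_chunks_with_metadata[i]["text"] for i in indices[0] if i != -1]
        return "\n\n".join(hits) if hits else "No matching passages found."

    return Tool(
        name="vector_database_search",  # the name referenced by the warning message above
        description="Search the uploaded PDF for passages relevant to the query.",
        func=_search,
    )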
@@ -330,5 +344,4 @@ async def get_models():
 
 # Run the application if this file is executed directly
 if __name__ == "__main__":
-    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
-
+    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
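Finally, a hypothetical client-side walkthrough of the two routes touched by this commit. The base URL and file name are placeholders; the request fields follow the upload form and ChatRequest usage shown above.

# Hypothetical end-to-end check of /upload-pdf and /chat (placeholder URL and file).
import requests

BASE = "http://localhost:8000"

with open("paper.pdf", "rb") as f:
    upload = requests.post(
        f"{BASE}/upload-pdf",
        files={"file": ("paper.pdf", f, "application/pdf")},
        data={"model_name": "llama3-8b-8192"},
    ).json()

reply = requests.post(
    f"{BASE}/chat",
    json={
        "session_id": upload["session_id"],
        "query": "Summarise the main contribution of this paper.",
        "model_name": "llama3-8b-8192",
        "use_search": False,   # set True to let the agent use the Tavily tool
    },
).json()

print(reply["answer"])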