Spaces: Running
Jatin Mehra committed
Commit · 1dc0983
Parent(s): ba76b7d
Refactor PDF processing and embedding creation; update chunking to include metadata
Files changed:
- app.py +16 -15
- preprocessing.py +49 -33
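
The practical effect of the refactor, as the diffs below show, is that chunks now carry their source metadata and retrieval returns (text, distance, metadata) triples instead of bare strings. A rough illustration of the new shapes (values are made up; the exact metadata keys depend on PyMuPDFLoader):

```python
# Illustrative only: data shapes introduced by this commit (values are made up).
chunk = {
    "text": "First paragraph of the page...",
    "metadata": {"source": "example.pdf", "page": 0},  # PyMuPDFLoader-style metadata
}

# retrieve_similar_chunks() now yields one (text, distance, metadata) triple per hit.
text, distance, metadata = ("First paragraph of the page...", 0.37, {"source": "example.pdf", "page": 0})
```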
app.py
CHANGED

@@ -2,7 +2,7 @@ import os
 import dotenv
 import pickle
 import uuid
-from fastapi import FastAPI, UploadFile, File, Form, HTTPException
+from fastapi import FastAPI, UploadFile, File, Form, HTTPException
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles

@@ -16,8 +16,7 @@ from preprocessing import (
     build_faiss_index,
     retrieve_similar_chunks,
     agentic_rag,
-    tools
-    memory
+    tools
 )
 from sentence_transformers import SentenceTransformer
 import shutil

@@ -88,8 +87,8 @@ def load_session(session_id, model_name="meta-llama/llama-4-scout-17b-16e-instru
     # Recreate non-pickled objects
     if data.get("chunks") and data.get("file_path") and os.path.exists(data["file_path"]):
         # Recreate model, embeddings and index
-        model = SentenceTransformer('
-        embeddings = create_embeddings(data["chunks"], model)
+        model = SentenceTransformer('BAAI/bge-large-en-v1.5')
+        embeddings, _ = create_embeddings(data["chunks"], model)  # Unpack tuple
         index = build_faiss_index(embeddings)

         # Recreate LLM

@@ -165,13 +164,15 @@ async def upload_pdf(
         raise ValueError("GROQ_API_KEY is not set in the environment variables")

     # Process the PDF
-
-    chunks = chunk_text(
+    documents = process_pdf_file(file_path)  # Returns list of Document objects
+    chunks = chunk_text(documents, max_length=1000)  # Updated to handle documents

     # Create embeddings
-    model = SentenceTransformer('
-    embeddings = create_embeddings(chunks, model)
-
+    model = SentenceTransformer('BAAI/bge-large-en-v1.5')  # Updated embedding model
+    embeddings, chunks_with_metadata = create_embeddings(chunks, model)  # Unpack tuple
+
+    # Build FAISS index
+    index = build_faiss_index(embeddings)  # Pass only embeddings array

     # Initialize LLM
     llm = model_selection(model_name)

@@ -180,7 +181,7 @@ async def upload_pdf(
     session_data = {
         "file_path": file_path,
         "file_name": file.filename,
-        "chunks": chunks
+        "chunks": chunks_with_metadata,  # Store chunks with metadata
         "model": model,
         "index": index,
         "llm": llm,

@@ -224,16 +225,15 @@ async def chat(request: ChatRequest):
             session["index"],
             session["chunks"],
             session["model"],
-            k=
+            k=10
         )
-        context = "\n".join([chunk for chunk, _ in similar_chunks])

         # Generate response using agentic_rag
         response = agentic_rag(
             session["llm"],
             tools,
             query=request.query,
-
+            context_chunks=similar_chunks,  # Pass the list of tuples
             Use_Tavily=request.use_search
         )

@@ -244,12 +244,13 @@ async def chat(request: ChatRequest):
         return {
             "status": "success",
             "answer": response["output"],
-            "context_used": [{"text": chunk, "score": float(score)} for chunk, score in similar_chunks]
+            "context_used": [{"text": chunk, "score": float(score)} for chunk, score, _ in similar_chunks]
         }

     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")

+
 # Route to get chat history
 @app.post("/chat-history")
 async def get_chat_history(request: SessionRequest):
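
For reference, a minimal sketch of the upload flow that app.py follows after this change (the PDF path is hypothetical, and this assumes preprocessing.py is importable from the Space):

```python
from sentence_transformers import SentenceTransformer

from preprocessing import (
    process_pdf_file,
    chunk_text,
    create_embeddings,
    build_faiss_index,
)

documents = process_pdf_file("uploads/example.pdf")    # hypothetical path; returns Document objects
chunks = chunk_text(documents, max_length=1000)        # list of {"text": ..., "metadata": ...} dicts
model = SentenceTransformer("BAAI/bge-large-en-v1.5")
embeddings, chunks_with_metadata = create_embeddings(chunks, model)
index = build_faiss_index(embeddings)                  # HNSW index over the embedding matrix
```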
preprocessing.py
CHANGED

@@ -25,49 +25,70 @@ def estimate_tokens(text):
     return len(text) // 4

 def process_pdf_file(file_path):
-    """Load a PDF file and extract its text."""
+    """Load a PDF file and extract its text with metadata."""
     if not os.path.exists(file_path):
         raise FileNotFoundError(f"The file {file_path} does not exist.")
     loader = PyMuPDFLoader(file_path)
     documents = loader.load()
-
-    return text
+    return documents  # Return list of Document objects with metadata

-def chunk_text(
-    """Split
-    paragraphs = text.split("\n\n")
+def chunk_text(documents, max_length=1000):
+    """Split documents into chunks with metadata."""
     chunks = []
-
-
-
-
-
-
-
-
-
+    for doc in documents:
+        text = doc.page_content
+        metadata = doc.metadata
+        paragraphs = text.split("\n\n")
+        current_chunk = ""
+        current_metadata = metadata.copy()
+        for paragraph in paragraphs:
+            if estimate_tokens(current_chunk + paragraph) <= max_length // 4:
+                current_chunk += paragraph + "\n\n"
+            else:
+                chunks.append({"text": current_chunk.strip(), "metadata": current_metadata})
+                current_chunk = paragraph + "\n\n"
+        if current_chunk:
+            chunks.append({"text": current_chunk.strip(), "metadata": current_metadata})
     return chunks

-def create_embeddings(
-    """Create embeddings for a list of texts
+def create_embeddings(chunks, model):
+    """Create embeddings for a list of chunk texts."""
+    texts = [chunk["text"] for chunk in chunks]
     embeddings = model.encode(texts, show_progress_bar=True, convert_to_tensor=True)
-    return embeddings.cpu().numpy()
+    return embeddings.cpu().numpy(), chunks

 def build_faiss_index(embeddings):
-    """Build a FAISS index from embeddings for similarity search."""
+    """Build a FAISS HNSW index from embeddings for similarity search."""
     dim = embeddings.shape[1]
-    index = faiss.
+    index = faiss.IndexHNSWFlat(dim, 32)  # 32 = number of neighbors in HNSW graph
+    index.hnsw.efConstruction = 200  # Higher = better quality, slower build
+    index.hnsw.efSearch = 50  # Higher = better accuracy, slower search
     index.add(embeddings)
     return index

-def retrieve_similar_chunks(query, index,
+def retrieve_similar_chunks(query, index, chunks, model, k=10, max_chunk_length=1000):
     """Retrieve top k similar chunks to the query from the FAISS index."""
     query_embedding = model.encode([query], convert_to_tensor=True).cpu().numpy()
     distances, indices = index.search(query_embedding, k)
-    return [(
+    return [(chunks[i]["text"][:max_chunk_length], distances[0][j], chunks[i]["metadata"]) for j, i in enumerate(indices[0])]

-def agentic_rag(llm, tools, query,
-    #
+def agentic_rag(llm, tools, query, context_chunks, Use_Tavily=False):
+    # Sort chunks by relevance (lower distance = more relevant)
+    context_chunks = sorted(context_chunks, key=lambda x: x[1])  # Sort by distance
+    context = ""
+    total_tokens = 0
+    max_tokens = 7000  # Leave room for prompt and response
+
+    # Aggregate relevant chunks until token limit is reached
+    for chunk, _, _ in context_chunks:  # Unpack three elements
+        chunk_tokens = estimate_tokens(chunk)
+        if total_tokens + chunk_tokens <= max_tokens:
+            context += chunk + "\n\n"
+            total_tokens += chunk_tokens
+        else:
+            break
+
+    # Define prompt template
     search_instructions = (
         "Use the search tool if the context is insufficient to answer the question or you are unsure. Give source links if you use the search tool."
         if Use_Tavily

@@ -80,35 +101,30 @@ def agentic_rag(llm, tools, query, context, Use_Tavily=False):
     Instructions:
     1. Use the provided context to answer the user's question.
     2. Provide a clear answer, if you don't know the answer, say 'I don't know'.
+    3. Prioritize information from the most relevant context chunks.
     """),
         ("human", "Context: {context}\n\nQuestion: {input}"),
         MessagesPlaceholder(variable_name="chat_history"),
         MessagesPlaceholder(variable_name="agent_scratchpad"),
     ])
-
-    # Only use tools when Tavily is enabled
-    agent_tools = tools if Use_Tavily else []

+    agent_tools = tools if Use_Tavily else []
     try:
-        # Create the agent and executor with appropriate tools
         agent = create_tool_calling_agent(llm, agent_tools, prompt)
         agent_executor = AgentExecutor(agent=agent, tools=agent_tools, memory=memory, verbose=True)
-
-        # Execute the agent
         return agent_executor.invoke({
-            "input": query,
+            "input": query,
             "context": context,
             "search_instructions": search_instructions
         })
     except Exception as e:
         print(f"Error during agent execution: {str(e)}")
-        # Fallback to direct LLM call without agent framework
         fallback_prompt = ChatPromptTemplate.from_messages([
             ("system", "You are a helpful assistant. Use the provided context to answer the user's question."),
             ("human", "Context: {context}\n\nQuestion: {input}")
         ])
         response = llm.invoke(fallback_prompt.format(context=context, input=query))
-        return {"output": response.content}
+        return {"output": response.content}

 if __name__ == "__main__":
     # Process PDF and prepare index
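
A sketch of the query side under the same assumptions (the question text is made up; index, chunks_with_metadata, model, llm and tools are the objects set up in app.py): retrieve_similar_chunks now returns (text, distance, metadata) triples, with a lower distance meaning a closer match, and agentic_rag assembles its context from those triples itself.

```python
from preprocessing import retrieve_similar_chunks, agentic_rag

results = retrieve_similar_chunks(
    "What does the report conclude?",   # hypothetical query
    index,
    chunks_with_metadata,
    model,
    k=10,
)
for text, distance, metadata in results:
    # metadata comes straight from PyMuPDFLoader, e.g. a "page" key
    print(round(float(distance), 3), metadata.get("page"), text[:80])

response = agentic_rag(
    llm,
    tools,
    query="What does the report conclude?",
    context_chunks=results,             # (text, distance, metadata) triples
    Use_Tavily=False,                    # True enables the web-search tool
)
print(response["output"])
```

With estimate_tokens counting roughly four characters per token, the defaults fit together comfortably: chunk_text caps a chunk at about 1000 characters (around 250 estimated tokens), retrieval truncates each returned text to 1000 characters, so ten retrieved chunks amount to roughly 2500 estimated tokens, well inside the 7000-token context budget used in agentic_rag.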