Refactor and reorganize codebase for improved maintainability and clarity
- Updated import paths in gen_dataset.py to reflect the new module structure.
- Introduced models.py to define Pydantic models for the PDF Insight Beta application.
- Created preprocessing_refactored.py to modularize preprocessing functionality while maintaining backward compatibility.
- Initialized services module with easy imports for all service classes and functions.
- Developed llm_service.py for LLM model management and interaction.
- Implemented rag_service.py for RAG operations, including tool creation and agent management.
- Established session_service.py for high-level session management operations.
- Added test_refactored.py to verify functionality and backward compatibility of refactored code.
- Created utility modules for session management and data persistence.
- api/__init__.py +34 -0
- api/chat_routes.py +109 -0
- api/session_routes.py +84 -0
- api/upload_routes.py +79 -0
- api/utility_routes.py +31 -0
- configs/config.py +129 -0
- development_scripts/app.py +357 -0
- preprocessing.py → development_scripts/preprocessing.py +0 -0
- gen_dataset.py +1 -1
- models/models.py +114 -0
- preprocessing_refactored.py +78 -0
- services/__init__.py +39 -0
- services/llm_service.py +103 -0
- services/rag_service.py +425 -0
- services/session_service.py +253 -0
- test_refactored.py +176 -0
- utils/__init__.py +62 -0
- utils/session_utils.py +219 -0
api/__init__.py
@@ -0,0 +1,34 @@
"""
API routes module initialization.

This module provides easy imports for all API route handlers.
"""

from .upload_routes import upload_pdf_handler
from .chat_routes import chat_handler
from .session_routes import (
    get_chat_history_handler,
    clear_history_handler,
    remove_pdf_handler
)
from .utility_routes import (
    root_handler,
    get_models_handler
)

__all__ = [
    # Upload routes
    "upload_pdf_handler",

    # Chat routes
    "chat_handler",

    # Session routes
    "get_chat_history_handler",
    "clear_history_handler",
    "remove_pdf_handler",

    # Utility routes
    "root_handler",
    "get_models_handler"
]
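For orientation, a minimal sketch of how these handlers could be mounted on a FastAPI app. The entry-point module below is hypothetical (it is not part of this commit); the handler names come from api/__init__.py above, and the endpoint paths mirror the legacy app.py shown further down:

from fastapi import FastAPI

from api import (
    upload_pdf_handler,
    chat_handler,
    get_chat_history_handler,
    clear_history_handler,
    remove_pdf_handler,
    root_handler,
    get_models_handler,
)

app = FastAPI(title="PDF Insight Beta")

# Register the refactored handlers under the same paths the legacy app used.
app.get("/")(root_handler)
app.get("/models")(get_models_handler)
app.post("/upload-pdf")(upload_pdf_handler)
app.post("/chat")(chat_handler)
app.post("/chat-history")(get_chat_history_handler)
app.post("/clear-history")(clear_history_handler)
app.post("/remove-pdf")(remove_pdf_handler)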
api/chat_routes.py
@@ -0,0 +1,109 @@
"""
Chat API routes.

This module handles chat and conversation endpoints.
"""

import traceback
from fastapi import HTTPException
from langchain.memory import ConversationBufferMemory

from configs.config import Config, ErrorMessages
from models.models import ChatRequest, ChatResponse
from services import session_manager, rag_service
from utils import retrieve_similar_chunks


async def chat_handler(request: ChatRequest) -> ChatResponse:
    """
    Handle chat requests with document context.

    Args:
        request: Chat request containing query and session info

    Returns:
        Chat response with answer and context

    Raises:
        HTTPException: If processing fails
    """
    # Validate query
    if not request.query or not request.query.strip():
        raise HTTPException(status_code=400, detail=ErrorMessages.EMPTY_QUERY)

    if len(request.query.strip()) < 3:
        raise HTTPException(status_code=400, detail=ErrorMessages.QUERY_TOO_SHORT)

    # Get session data
    session_data, found = session_manager.get_session(request.session_id, request.model_name)
    if not found:
        raise HTTPException(status_code=404, detail=ErrorMessages.SESSION_EXPIRED)

    try:
        # Validate session data integrity
        is_valid, missing_keys = session_manager.validate_session(request.session_id)
        if not is_valid:
            raise HTTPException(status_code=500, detail=ErrorMessages.SESSION_INCOMPLETE)

        # Prepare agent memory with chat history
        agent_memory = ConversationBufferMemory(
            memory_key="chat_history",
            input_key="input",
            return_messages=True
        )

        for entry in session_data.get("chat_history", []):
            agent_memory.chat_memory.add_user_message(entry["user"])
            agent_memory.chat_memory.add_ai_message(entry["assistant"])

        # Retrieve initial similar chunks for context
        initial_similar_chunks = retrieve_similar_chunks(
            request.query,
            session_data["index"],
            session_data["chunks"],
            session_data["model"],
            k=Config.INITIAL_CONTEXT_CHUNKS
        )

        # Generate response using RAG service
        response = rag_service.generate_response(
            llm=session_data["llm"],
            query=request.query,
            context_chunks=initial_similar_chunks,
            faiss_index=session_data["index"],
            document_chunks=session_data["chunks"],
            embedding_model=session_data["model"],
            memory=agent_memory,
            use_tavily=request.use_search
        )

        response_output = response.get("output", ErrorMessages.RESPONSE_GENERATION_ERROR)

        # Save chat history
        session_manager.add_chat_entry(
            request.session_id,
            request.query,
            response_output
        )

        return ChatResponse(
            status="success",
            answer=response_output,
            context_used=[
                {
                    "text": chunk,
                    "score": float(score),
                    "metadata": meta
                }
                for chunk, score, meta in initial_similar_chunks
            ]
        )

    except HTTPException:
        # Re-raise HTTP exceptions as-is
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=ErrorMessages.PROCESSING_ERROR.format(error=str(e))
        )
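As a usage sketch (assuming chat_handler is mounted at /chat as in the legacy app.py below, and a server on localhost:8000 — both assumptions, not part of this commit), a client call looks like:

import requests

resp = requests.post(
    "http://localhost:8000/chat",
    json={
        "session_id": "<session-id-from-upload>",  # placeholder; returned by the upload endpoint
        "query": "What is the main conclusion of the document?",
        "use_search": False,
        "model_name": "meta-llama/llama-4-scout-17b-16e-instruct",
    },
)
body = resp.json()
print(body["answer"])        # generated answer
print(body["context_used"])  # pre-fetched chunks with scores and metadata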
api/session_routes.py
@@ -0,0 +1,84 @@
"""
Session management API routes.

This module handles session-related endpoints like history and cleanup.
"""

from fastapi import HTTPException

from configs.config import ErrorMessages, SuccessMessages
from models.models import SessionRequest, ChatHistoryResponse, StatusResponse
from services import session_manager


async def get_chat_history_handler(request: SessionRequest) -> ChatHistoryResponse:
    """
    Get chat history for a session.

    Args:
        request: Session request with session ID

    Returns:
        Chat history response

    Raises:
        HTTPException: If session not found
    """
    session_data, found = session_manager.get_session(request.session_id)
    if not found:
        raise HTTPException(status_code=404, detail=ErrorMessages.SESSION_NOT_FOUND)

    return ChatHistoryResponse(
        status="success",
        history=session_data.get("chat_history", [])
    )


async def clear_history_handler(request: SessionRequest) -> StatusResponse:
    """
    Clear chat history for a session.

    Args:
        request: Session request with session ID

    Returns:
        Status response

    Raises:
        HTTPException: If session not found
    """
    success = session_manager.clear_chat_history(request.session_id)
    if not success:
        raise HTTPException(status_code=404, detail=ErrorMessages.SESSION_NOT_FOUND)

    return StatusResponse(
        status="success",
        message=SuccessMessages.CHAT_HISTORY_CLEARED
    )


async def remove_pdf_handler(request: SessionRequest) -> StatusResponse:
    """
    Remove PDF and session data.

    Args:
        request: Session request with session ID

    Returns:
        Status response

    Raises:
        HTTPException: If session not found or removal failed
    """
    success = session_manager.remove_session(request.session_id)

    if success:
        return StatusResponse(
            status="success",
            message=SuccessMessages.PDF_REMOVED
        )
    else:
        raise HTTPException(
            status_code=404,
            detail=ErrorMessages.SESSION_REMOVAL_FAILED
        )
api/upload_routes.py
@@ -0,0 +1,79 @@
"""
File upload API routes.

This module handles PDF file upload and processing endpoints.
"""

import os
import shutil
import traceback
import uuid
from fastapi import UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse

from configs.config import Config, ErrorMessages, SuccessMessages
from models.models import UploadResponse
from services import session_manager, validate_api_keys
from utils import process_pdf_file, chunk_text, create_upload_file_path


async def upload_pdf_handler(
    file: UploadFile = File(...),
    model_name: str = Form(Config.DEFAULT_MODEL)
) -> UploadResponse:
    """
    Handle PDF file upload and processing.

    Args:
        file: Uploaded PDF file
        model_name: LLM model name to use

    Returns:
        Upload response with session ID

    Raises:
        HTTPException: If processing fails
    """
    session_id = str(uuid.uuid4())
    file_path = None

    try:
        # Validate API keys
        validate_api_keys(model_name, use_search=False)

        # Save uploaded file
        file_path = create_upload_file_path(session_id, file.filename)
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Process PDF file
        documents = process_pdf_file(file_path)
        chunks_with_metadata = chunk_text(documents, max_length=Config.DEFAULT_CHUNK_SIZE)

        # Create session
        session_id = session_manager.create_session(
            file_path=file_path,
            file_name=file.filename,
            chunks_with_metadata=chunks_with_metadata,
            model_name=model_name
        )

        return UploadResponse(
            status="success",
            session_id=session_id,
            message=SuccessMessages.PDF_PROCESSED.format(filename=file.filename)
        )

    except Exception as e:
        # Clean up file on error
        if file_path and os.path.exists(file_path):
            os.remove(file_path)

        error_msg = str(e)
        stack_trace = traceback.format_exc()
        print(f"Error processing PDF: {error_msg}\nStack trace: {stack_trace}")

        raise HTTPException(
            status_code=500,
            detail=ErrorMessages.PDF_PROCESSING_ERROR.format(error=error_msg)
        )
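A matching upload sketch (again assuming the /upload-pdf path from the legacy app and a hypothetical local file report.pdf):

import requests

with open("report.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/upload-pdf",
        files={"file": ("report.pdf", f, "application/pdf")},
        data={"model_name": "llama-3.1-8b-instant"},
    )

session_id = resp.json()["session_id"]  # needed for all subsequent session endpoints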
api/utility_routes.py
@@ -0,0 +1,31 @@
"""
Utility API routes.

This module handles utility endpoints like model listing and health checks.
"""

from fastapi.responses import RedirectResponse

from models.models import ModelsResponse
from services import get_available_models


async def root_handler():
    """
    Root endpoint that redirects to the main application.

    Returns:
        Redirect response to static index.html
    """
    return RedirectResponse(url="/static/index.html")


async def get_models_handler() -> ModelsResponse:
    """
    Get list of available models.

    Returns:
        Models response with available model configurations
    """
    models = get_available_models()
    return ModelsResponse(models=models)
configs/config.py
@@ -0,0 +1,129 @@
"""
Configuration module for PDF Insight Beta application.

This module centralizes all configuration settings, constants, and environment variables.
"""

import os
from typing import List, Dict, Any
import dotenv

# Load environment variables
dotenv.load_dotenv()


class Config:
    """Application configuration class."""

    # API Configuration
    GROQ_API_KEY: str = os.getenv("GROQ_API_KEY", "")
    TAVILY_API_KEY: str = os.getenv("TAVILY_API_KEY", "")

    # Application Settings
    UPLOAD_DIR: str = "uploads"
    MAX_FILE_SIZE: int = 50 * 1024 * 1024  # 50MB

    # Model Configuration
    DEFAULT_MODEL: str = "llama3-8b-8192"
    EMBEDDING_MODEL: str = "BAAI/bge-large-en-v1.5"

    # Text Processing Settings
    DEFAULT_CHUNK_SIZE: int = 1000
    MIN_CHUNK_LENGTH: int = 20
    MIN_PARAGRAPH_LENGTH: int = 10

    # RAG Configuration
    DEFAULT_K_CHUNKS: int = 10
    INITIAL_CONTEXT_CHUNKS: int = 5
    MAX_CONTEXT_TOKENS: int = 7000
    SIMILARITY_THRESHOLD: float = 1.5

    # LLM Settings
    LLM_TEMPERATURE: float = 0.1
    MAX_TOKENS: int = 4500

    # FAISS Index Configuration
    FAISS_NEIGHBORS: int = 32
    FAISS_EF_CONSTRUCTION: int = 200
    FAISS_EF_SEARCH: int = 50

    # Agent Configuration
    AGENT_MAX_ITERATIONS: int = 2
    AGENT_VERBOSE: bool = False

    # Tavily Search Configuration
    TAVILY_MAX_RESULTS: int = 5
    TAVILY_SEARCH_DEPTH: str = "advanced"
    TAVILY_INCLUDE_ANSWER: bool = True
    TAVILY_INCLUDE_RAW_CONTENT: bool = False

    # CORS Configuration
    CORS_ORIGINS: List[str] = ["*"]
    CORS_CREDENTIALS: bool = True
    CORS_METHODS: List[str] = ["*"]
    CORS_HEADERS: List[str] = ["*"]


class ModelConfig:
    """Model configuration and metadata."""

    AVAILABLE_MODELS: List[Dict[str, str]] = [
        {"id": "meta-llama/llama-4-scout-17b-16e-instruct", "name": "Llama 4 Scout 17B"},
        {"id": "llama-3.1-8b-instant", "name": "Llama 3.1 8B Instant"},
        {"id": "llama-3.3-70b-versatile", "name": "Llama 3.3 70b Versatile"},
    ]

    @classmethod
    def get_model_ids(cls) -> List[str]:
        """Get list of available model IDs."""
        return [model["id"] for model in cls.AVAILABLE_MODELS]

    @classmethod
    def is_valid_model(cls, model_id: str) -> bool:
        """Check if a model ID is valid."""
        return model_id in cls.get_model_ids()


class ErrorMessages:
    """Centralized error messages."""

    # Validation Errors
    EMPTY_QUERY = "Query cannot be empty"
    QUERY_TOO_SHORT = "Query must be at least 3 characters long"

    # Session Errors
    SESSION_NOT_FOUND = "Session not found"
    SESSION_EXPIRED = "Session not found or expired. Please upload a document first."
    SESSION_INCOMPLETE = "Session data is incomplete. Please upload the document again."
    SESSION_REMOVAL_FAILED = "Session not found or could not be removed"

    # File Errors
    FILE_NOT_FOUND = "The file {file_path} does not exist."
    PDF_PROCESSING_ERROR = "Error processing PDF: {error}"

    # API Key Errors
    GROQ_API_KEY_MISSING = "GROQ_API_KEY is not set for Groq Llama models."
    TAVILY_API_KEY_MISSING = "TAVILY_API_KEY is not set. Web search will not function."

    # Processing Errors
    PROCESSING_ERROR = "Error processing query: {error}"
    RESPONSE_GENERATION_ERROR = "Sorry, I could not generate a response."


class SuccessMessages:
    """Centralized success messages."""

    PDF_PROCESSED = "Processed {filename}"
    PDF_REMOVED = "PDF file and session removed successfully"
    CHAT_HISTORY_CLEARED = "Chat history cleared"


# Initialize directories
def initialize_directories():
    """Create necessary directories if they don't exist."""
    if not os.path.exists(Config.UPLOAD_DIR):
        os.makedirs(Config.UPLOAD_DIR)


# Initialize on import
initialize_directories()
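One consequence of this design worth noting: Config reads the environment once at import time, and initialize_directories() also runs on import, so API keys must be exported (or present in .env) before configs.config is first imported. A minimal sketch, with a placeholder key value:

import os
os.environ.setdefault("GROQ_API_KEY", "gsk-placeholder")  # hypothetical value; must be set before the import below

from configs.config import Config, ModelConfig

print(Config.DEFAULT_CHUNK_SIZE)                           # -> 1000
print(ModelConfig.is_valid_model("llama-3.1-8b-instant"))  # -> True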
development_scripts/app.py
@@ -0,0 +1,357 @@
import os
import dotenv
import pickle
import uuid
import shutil
import traceback
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import uvicorn
from development_scripts.preprocessing import (
    model_selection,
    process_pdf_file,
    chunk_text,
    create_embeddings,
    build_faiss_index,
    retrieve_similar_chunks,
    agentic_rag,
    tools as global_base_tools,
    create_vector_search_tool
)
from sentence_transformers import SentenceTransformer
from langchain.memory import ConversationBufferMemory

# Load environment variables
dotenv.load_dotenv()

# Initialize FastAPI app
app = FastAPI(title="PDF Insight Beta", description="Agentic RAG for PDF documents")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create upload directory if it doesn't exist
UPLOAD_DIR = "uploads"
if not os.path.exists(UPLOAD_DIR):
    os.makedirs(UPLOAD_DIR)

# Store active sessions
sessions = {}

# Define model for chat request
class ChatRequest(BaseModel):
    session_id: str
    query: str
    use_search: bool = False
    model_name: str = "meta-llama/llama-4-scout-17b-16e-instruct"

class SessionRequest(BaseModel):
    session_id: str

# Function to save session data
def save_session(session_id, data):
    sessions[session_id] = data  # Keep non-picklable in memory for active session

    pickle_safe_data = {
        "file_path": data.get("file_path"),
        "file_name": data.get("file_name"),
        "chunks": data.get("chunks"),  # Chunks with metadata (list of dicts)
        "chat_history": data.get("chat_history", [])
        # FAISS index, embedding model, and LLM model are not pickled, will be reloaded/recreated
    }

    with open(f"{UPLOAD_DIR}/{session_id}_session.pkl", "wb") as f:
        pickle.dump(pickle_safe_data, f)


# Function to load session data
def load_session(session_id, model_name="llama3-8b-8192"):  # Ensure model_name matches default
    try:
        if session_id in sessions:
            cached_session = sessions[session_id]
            # Ensure LLM and potentially other non-pickled parts are up-to-date or loaded
            if cached_session.get("llm") is None or (hasattr(cached_session["llm"], "model_name") and cached_session["llm"].model_name != model_name):
                cached_session["llm"] = model_selection(model_name)
            if cached_session.get("model") is None:  # Embedding model
                cached_session["model"] = SentenceTransformer('BAAI/bge-large-en-v1.5')
            if cached_session.get("index") is None and cached_session.get("chunks"):  # FAISS index
                embeddings, _ = create_embeddings(cached_session["chunks"], cached_session["model"])
                cached_session["index"] = build_faiss_index(embeddings)
            return cached_session, True

        file_path_pkl = f"{UPLOAD_DIR}/{session_id}_session.pkl"
        if os.path.exists(file_path_pkl):
            with open(file_path_pkl, "rb") as f:
                data = pickle.load(f)

            original_pdf_path = data.get("file_path")
            if data.get("chunks") and original_pdf_path and os.path.exists(original_pdf_path):
                embedding_model_instance = SentenceTransformer('BAAI/bge-large-en-v1.5')
                # Chunks are already {text: ..., metadata: ...}
                recreated_embeddings, _ = create_embeddings(data["chunks"], embedding_model_instance)
                recreated_index = build_faiss_index(recreated_embeddings)
                recreated_llm = model_selection(model_name)

                full_session_data = {
                    "file_path": original_pdf_path,
                    "file_name": data.get("file_name"),
                    "chunks": data.get("chunks"),  # chunks_with_metadata
                    "chat_history": data.get("chat_history", []),
                    "model": embedding_model_instance,  # SentenceTransformer model
                    "index": recreated_index,  # FAISS index
                    "llm": recreated_llm  # LLM
                }
                sessions[session_id] = full_session_data
                return full_session_data, True
            else:
                print(f"Warning: Session data for {session_id} is incomplete or PDF missing. Cannot reconstruct.")
                if os.path.exists(file_path_pkl): os.remove(file_path_pkl)  # Clean up stale pkl
                return None, False

        return None, False
    except Exception as e:
        print(f"Error loading session {session_id}: {str(e)}")
        print(traceback.format_exc())
        return None, False

# Function to remove PDF file
def remove_pdf_file(session_id):
    try:
        # Check if the session exists
        session_path = f"{UPLOAD_DIR}/{session_id}_session.pkl"
        if os.path.exists(session_path):
            # Load session data
            with open(session_path, "rb") as f:
                data = pickle.load(f)

            # Delete PDF file if it exists
            if data.get("file_path") and os.path.exists(data["file_path"]):
                os.remove(data["file_path"])

            # Remove session file
            os.remove(session_path)

        # Remove from memory if exists
        if session_id in sessions:
            del sessions[session_id]

        return True
    except Exception as e:
        print(f"Error removing PDF file: {str(e)}")
        return False

# Mount static files (we'll create these later)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Route for the home page
@app.get("/")
async def read_root():
    from fastapi.responses import RedirectResponse
    return RedirectResponse(url="/static/index.html")

# Route to upload a PDF file
@app.post("/upload-pdf")
async def upload_pdf(
    file: UploadFile = File(...),
    model_name: str = Form("llama3-8b-8192")  # Default model
):
    session_id = str(uuid.uuid4())
    file_path = None

    try:
        file_path = f"{UPLOAD_DIR}/{session_id}_{file.filename}"
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        if not os.getenv("GROQ_API_KEY") and "llama" in model_name:  # Llama specific check for Groq
            raise ValueError("GROQ_API_KEY is not set for Groq Llama models.")
        if not os.getenv("TAVILY_API_KEY"):  # Needed for TavilySearchResults
            print("Warning: TAVILY_API_KEY is not set. Web search will not function.")

        documents = process_pdf_file(file_path)
        chunks_with_metadata = chunk_text(documents, max_length=1000)  # Increased from 256 to 1000 tokens for better context

        embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5')
        embeddings, _ = create_embeddings(chunks_with_metadata, embedding_model)  # Chunks are already with metadata

        index = build_faiss_index(embeddings)
        llm = model_selection(model_name)

        session_data = {
            "file_path": file_path,
            "file_name": file.filename,
            "chunks": chunks_with_metadata,  # Store chunks with metadata
            "model": embedding_model,  # SentenceTransformer instance
            "index": index,  # FAISS index instance
            "llm": llm,  # LLM instance
            "chat_history": []
        }
        save_session(session_id, session_data)

        return {"status": "success", "session_id": session_id, "message": f"Processed {file.filename}"}

    except Exception as e:
        if file_path and os.path.exists(file_path):
            os.remove(file_path)
        error_msg = str(e)
        stack_trace = traceback.format_exc()
        print(f"Error processing PDF: {error_msg}\nStack trace: {stack_trace}")
        return JSONResponse(
            status_code=500,  # Internal server error for processing issues
            content={"status": "error", "detail": error_msg, "type": type(e).__name__}
        )

# Route to chat with the document
@app.post("/chat")
async def chat(request: ChatRequest):
    # Validate query
    if not request.query or not request.query.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    if len(request.query.strip()) < 3:
        raise HTTPException(status_code=400, detail="Query must be at least 3 characters long")

    session, found = load_session(request.session_id, model_name=request.model_name)
    if not found:
        raise HTTPException(status_code=404, detail="Session not found or expired. Please upload a document first.")

    try:
        # Validate session data integrity
        required_keys = ["index", "chunks", "model", "llm"]
        missing_keys = [key for key in required_keys if key not in session]
        if missing_keys:
            print(f"Warning: Session {request.session_id} missing required data: {missing_keys}")
            raise HTTPException(status_code=500, detail="Session data is incomplete. Please upload the document again.")

        # Per-request memory to ensure chat history is correctly loaded for the agent
        agent_memory = ConversationBufferMemory(memory_key="chat_history", input_key="input", return_messages=True)
        for entry in session.get("chat_history", []):
            agent_memory.chat_memory.add_user_message(entry["user"])
            agent_memory.chat_memory.add_ai_message(entry["assistant"])

        # Prepare tools for the agent for THIS request
        current_request_tools = []

        # 1. Add the document-specific vector search tool
        vector_search_tool_instance = create_vector_search_tool(
            faiss_index=session["index"],
            document_chunks_with_metadata=session["chunks"],  # Pass the correct variable
            embedding_model=session["model"],  # This is the SentenceTransformer model
            max_chunk_length=1000,
            k=10
        )
        current_request_tools.append(vector_search_tool_instance)

        # 2. Conditionally add Tavily (web search) tool
        if request.use_search:
            if os.getenv("TAVILY_API_KEY"):
                tavily_tool = next((tool for tool in global_base_tools if tool.name == "tavily_search_results_json"), None)
                if tavily_tool:
                    current_request_tools.append(tavily_tool)
                else:  # Should not happen if global_base_tools is defined correctly
                    print("Warning: Tavily search requested, but tool misconfigured.")
            else:
                print("Warning: Tavily search requested, but TAVILY_API_KEY is not set.")

        # Retrieve initial similar chunks for RAG context (can be empty if no good match)
        # This context is given to the agent *before* it decides to use tools.
        # k=5 means we retrieve up to 5 chunks for initial context.
        # The agent can then use `vector_database_search` to search more if needed.
        initial_similar_chunks = retrieve_similar_chunks(
            request.query,
            session["index"],
            session["chunks"],  # list of dicts {text:..., metadata:...}
            session["model"],  # SentenceTransformer model
            k=5  # Number of chunks for initial context
        )

        print(f"Query: '{request.query}' - Found {len(initial_similar_chunks)} initial chunks")
        if initial_similar_chunks:
            print(f"Best chunk score: {initial_similar_chunks[0][1]:.4f}")

        response = agentic_rag(
            session["llm"],
            current_request_tools,  # Pass the dynamically assembled list of tools
            query=request.query,
            context_chunks=initial_similar_chunks,
            Use_Tavily=request.use_search,  # Still passed to agentic_rag for potential fine-grained logic, though prompt adapts to tools
            memory=agent_memory
        )

        response_output = response.get("output", "Sorry, I could not generate a response.")
        print(f"Generated response length: {len(response_output)} characters")

        session["chat_history"].append({"user": request.query, "assistant": response_output})
        save_session(request.session_id, session)  # Save updated history and potentially other modified session state

        return {
            "status": "success",
            "answer": response_output,
            # Return context that was PRE-FETCHED for the agent, not necessarily all context it might have used via tools
            "context_used": [{"text": chunk, "score": float(score), "metadata": meta} for chunk, score, meta in initial_similar_chunks]
        }

    except Exception as e:
        print(f"Error processing chat query: {str(e)}\nTraceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")


# Route to get chat history
@app.post("/chat-history")
async def get_chat_history(request: SessionRequest):
    # Try to load session if not in memory
    session, found = load_session(request.session_id)
    if not found:
        raise HTTPException(status_code=404, detail="Session not found")

    return {
        "status": "success",
        "history": session.get("chat_history", [])
    }

# Route to clear chat history
@app.post("/clear-history")
async def clear_history(request: SessionRequest):
    # Try to load session if not in memory
    session, found = load_session(request.session_id)
    if not found:
        raise HTTPException(status_code=404, detail="Session not found")

    session["chat_history"] = []
    save_session(request.session_id, session)

    return {"status": "success", "message": "Chat history cleared"}

# Route to remove PDF from session
@app.post("/remove-pdf")
async def remove_pdf(request: SessionRequest):
    success = remove_pdf_file(request.session_id)

    if success:
        return {"status": "success", "message": "PDF file and session removed successfully"}
    else:
        raise HTTPException(status_code=404, detail="Session not found or could not be removed")

# Route to list available models
@app.get("/models")
async def get_models():
    # You can expand this list as needed
    models = [
        {"id": "meta-llama/llama-4-scout-17b-16e-instruct", "name": "Llama 4 Scout 17B"},
        {"id": "llama-3.1-8b-instant", "name": "Llama 3.1 8B Instant"},
        {"id": "llama-3.3-70b-versatile", "name": "Llama 3.3 70B Versatile"},
    ]
    return {"models": models}

# Run the application if this file is executed directly
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
preprocessing.py → development_scripts/preprocessing.py
File renamed without changes.
gen_dataset.py
@@ -4,7 +4,7 @@ ds = load_dataset("neural-bridge/rag-dataset-12000")
 
 # Test the RAG system with DS dataset
 from sentence_transformers import SentenceTransformer
-from preprocessing import model_selection, create_embeddings, build_faiss_index, retrieve_similar_chunks, agentic_rag
+from development_scripts.preprocessing import model_selection, create_embeddings, build_faiss_index, retrieve_similar_chunks, agentic_rag
 import dotenv
 from langchain_community.tools.tavily_search import TavilySearchResults
 import json
models/models.py
@@ -0,0 +1,114 @@
"""
Pydantic models and data structures for PDF Insight Beta application.

This module defines all the data models used throughout the application.
"""

from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field


class ChatRequest(BaseModel):
    """Request model for chat endpoint."""
    session_id: str = Field(..., description="Session identifier")
    query: str = Field(..., description="User query")
    use_search: bool = Field(default=False, description="Whether to use web search")
    model_name: str = Field(
        default="meta-llama/llama-4-scout-17b-16e-instruct",
        description="LLM model to use"
    )


class SessionRequest(BaseModel):
    """Request model for session-related endpoints."""
    session_id: str = Field(..., description="Session identifier")


class UploadResponse(BaseModel):
    """Response model for PDF upload."""
    status: str
    session_id: str
    message: str


class ChatResponse(BaseModel):
    """Response model for chat endpoint."""
    status: str
    answer: str
    context_used: List[Dict[str, Any]]


class ChatHistoryResponse(BaseModel):
    """Response model for chat history endpoint."""
    status: str
    history: List[Dict[str, str]]


class StatusResponse(BaseModel):
    """Generic status response model."""
    status: str
    message: str


class ErrorResponse(BaseModel):
    """Error response model."""
    status: str
    detail: str
    type: Optional[str] = None


class ModelInfo(BaseModel):
    """Model information."""
    id: str
    name: str


class ModelsResponse(BaseModel):
    """Response model for models endpoint."""
    models: List[ModelInfo]


class ChunkMetadata(BaseModel):
    """Metadata for document chunks."""
    source: Optional[str] = None
    page: Optional[int] = None

    class Config:
        extra = "allow"  # Allow additional metadata fields


class DocumentChunk(BaseModel):
    """Document chunk with text and metadata."""
    text: str
    metadata: ChunkMetadata

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format used in processing."""
        return {
            "text": self.text,
            "metadata": self.metadata.dict()
        }


class SessionData(BaseModel):
    """Session data structure."""
    file_path: str
    file_name: str
    chunks: List[Dict[str, Any]]  # List of chunk dictionaries
    chat_history: List[Dict[str, str]] = Field(default_factory=list)

    class Config:
        arbitrary_types_allowed = True  # Allow non-Pydantic types like FAISS index


class ChatHistoryEntry(BaseModel):
    """Single chat history entry."""
    user: str
    assistant: str


class ContextChunk(BaseModel):
    """Context chunk with similarity score."""
    text: str
    score: float
    metadata: Dict[str, Any]
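A quick sketch of how these models behave in isolation (plain Pydantic, no server required; the "heading" field is a made-up extra to illustrate extra = "allow"):

from models.models import ChatRequest, ChunkMetadata, DocumentChunk

req = ChatRequest(session_id="abc123", query="Summarize section 2")
print(req.use_search)  # -> False (default)

chunk = DocumentChunk(
    text="Example paragraph from the PDF.",
    metadata=ChunkMetadata(source="report.pdf", page=3, heading="2. Methods"),  # extra field accepted
)
print(chunk.to_dict())  # -> {"text": ..., "metadata": {"source": "report.pdf", "page": 3, "heading": "2. Methods"}}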
preprocessing_refactored.py
@@ -0,0 +1,78 @@
"""
Refactored preprocessing module for PDF Insight Beta.

This module provides the core preprocessing functionality with improved organization.
The original logic has been preserved while breaking it into more maintainable components.

This module maintains backward compatibility with the original preprocessing.py interface.
"""

# Re-export everything from the new modular structure for backward compatibility
from configs.config import Config
from services import (
    create_llm_model as model_selection,
    create_tavily_search_tool,
    rag_service
)
from utils import (
    process_pdf_file,
    chunk_text,
    create_embeddings,
    build_faiss_index,
    retrieve_similar_chunks,
    estimate_tokens
)

# Create global tools for backward compatibility
def create_global_tools():
    """Create global tools list for backward compatibility."""
    tavily_tool = create_tavily_search_tool()
    return [tavily_tool] if tavily_tool else []

# Global tools instance (for backward compatibility)
tools = create_global_tools()

# Alias for the main RAG function to maintain original interface
def agentic_rag(llm, agent_specific_tools, query, context_chunks, memory, Use_Tavily=False):
    """
    Main RAG function with original interface for backward compatibility.

    Args:
        llm: Language model instance
        agent_specific_tools: List of tools for the agent
        query: User query
        context_chunks: Context chunks from retrieval
        memory: Conversation memory
        Use_Tavily: Whether to use web search

    Returns:
        Dictionary with 'output' key containing the response
    """
    # Convert parameters to work with new RAG service
    return rag_service.generate_response(
        llm=llm,
        query=query,
        context_chunks=context_chunks,
        faiss_index=None,  # Will be handled internally by tools
        document_chunks=[],  # Will be handled internally by tools
        embedding_model=None,  # Will be handled internally by tools
        memory=memory,
        use_tavily=Use_Tavily
    )

# Re-export the vector search tool creator for backward compatibility
from services.rag_service import create_vector_search_tool

# Maintain all original exports
__all__ = [
    'model_selection',
    'process_pdf_file',
    'chunk_text',
    'create_embeddings',
    'build_faiss_index',
    'retrieve_similar_chunks',
    'agentic_rag',
    'tools',
    'create_vector_search_tool',
    'estimate_tokens'
]
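The practical effect of this shim is that legacy call sites keep working without edits, e.g.:

# Old-style imports now resolve through the refactored services/utils modules:
from preprocessing_refactored import model_selection, retrieve_similar_chunks, agentic_rag, tools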
services/__init__.py
@@ -0,0 +1,39 @@
"""
Services module initialization.

This module provides easy imports for all service classes and functions.
"""

from .llm_service import (
    create_llm_model,
    create_tavily_search_tool,
    validate_api_keys,
    get_available_models,
    is_model_supported
)

from .session_service import SessionManager, session_manager

from .rag_service import (
    create_vector_search_tool,
    RAGService,
    rag_service
)

__all__ = [
    # LLM service
    "create_llm_model",
    "create_tavily_search_tool",
    "validate_api_keys",
    "get_available_models",
    "is_model_supported",

    # Session service
    "SessionManager",
    "session_manager",

    # RAG service
    "create_vector_search_tool",
    "RAGService",
    "rag_service"
]
services/llm_service.py
@@ -0,0 +1,103 @@
"""
LLM service for model management and interaction.

This module provides services for LLM model creation and management.
"""

import os
from typing import Optional
from langchain_groq import ChatGroq
from langchain_community.tools.tavily_search import TavilySearchResults

from configs.config import Config, ErrorMessages


def create_llm_model(model_name: str) -> ChatGroq:
    """
    Create and configure an LLM model.

    Args:
        model_name: Name of the model to create

    Returns:
        Configured ChatGroq instance

    Raises:
        ValueError: If API key is missing for the model
    """
    if not os.getenv("GROQ_API_KEY") and "llama" in model_name:
        raise ValueError(ErrorMessages.GROQ_API_KEY_MISSING)

    llm = ChatGroq(
        model=model_name,
        api_key=os.getenv("GROQ_API_KEY"),
        temperature=Config.LLM_TEMPERATURE,
        max_tokens=Config.MAX_TOKENS
    )
    return llm


def create_tavily_search_tool() -> Optional[TavilySearchResults]:
    """
    Create Tavily search tool with error handling.

    Returns:
        TavilySearchResults instance or None if creation fails
    """
    try:
        if not os.getenv("TAVILY_API_KEY"):
            print(f"Warning: {ErrorMessages.TAVILY_API_KEY_MISSING}")
            return None

        return TavilySearchResults(
            max_results=Config.TAVILY_MAX_RESULTS,
            search_depth=Config.TAVILY_SEARCH_DEPTH,
            include_answer=Config.TAVILY_INCLUDE_ANSWER,
            include_raw_content=Config.TAVILY_INCLUDE_RAW_CONTENT
        )
    except Exception as e:
        print(f"Warning: Could not create Tavily tool: {e}")
        return None


def validate_api_keys(model_name: str, use_search: bool = False) -> None:
    """
    Validate that required API keys are available.

    Args:
        model_name: LLM model name
        use_search: Whether web search is requested

    Raises:
        ValueError: If required API keys are missing
    """
    if not os.getenv("GROQ_API_KEY") and "llama" in model_name:
        raise ValueError(ErrorMessages.GROQ_API_KEY_MISSING)

    if use_search and not os.getenv("TAVILY_API_KEY"):
        print(f"Warning: {ErrorMessages.TAVILY_API_KEY_MISSING}")


def get_available_models() -> list:
    """
    Get list of available models.

    Returns:
        List of available model configurations
    """
    from configs.config import ModelConfig
    return ModelConfig.AVAILABLE_MODELS


def is_model_supported(model_name: str) -> bool:
    """
    Check if a model is supported.

    Args:
        model_name: Model name to check

    Returns:
        True if model is supported, False otherwise
    """
    from configs.config import ModelConfig
    return ModelConfig.is_valid_model(model_name)
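A sketch of the intended call pattern (requires GROQ_API_KEY in the environment; the model id comes from ModelConfig.AVAILABLE_MODELS above):

from services.llm_service import create_llm_model, is_model_supported, validate_api_keys

model_name = "llama-3.1-8b-instant"
assert is_model_supported(model_name)

validate_api_keys(model_name, use_search=True)  # raises ValueError if GROQ_API_KEY is missing
llm = create_llm_model(model_name)              # ChatGroq configured with temperature=0.1, max_tokens=4500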
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
RAG (Retrieval Augmented Generation) service.
|
3 |
+
|
4 |
+
This module provides the RAG implementation with tool creation and agent management.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import traceback
|
8 |
+
from typing import List, Dict, Any, Optional, Tuple
|
9 |
+
from langchain.tools import tool
|
10 |
+
from langchain.agents import AgentExecutor, create_tool_calling_agent
|
11 |
+
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
12 |
+
from langchain.memory import ConversationBufferMemory
|
13 |
+
from sentence_transformers import SentenceTransformer
|
14 |
+
import faiss
|
15 |
+
|
16 |
+
from configs.config import Config
|
17 |
+
from utils import (
|
18 |
+
retrieve_similar_chunks,
|
19 |
+
filter_relevant_chunks,
|
20 |
+
prepare_context_from_chunks
|
21 |
+
)
|
22 |
+
from services.llm_service import create_tavily_search_tool
|
23 |
+
|
24 |
+
|
25 |
+
def create_vector_search_tool(
|
26 |
+
faiss_index: faiss.IndexHNSWFlat,
|
27 |
+
document_chunks_with_metadata: List[Dict[str, Any]],
|
28 |
+
embedding_model: SentenceTransformer,
|
29 |
+
k: int = None,
|
30 |
+
max_chunk_length: int = None
|
31 |
+
):
|
32 |
+
"""
|
33 |
+
Create a vector search tool for document retrieval.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
faiss_index: FAISS index for similarity search
|
37 |
+
document_chunks_with_metadata: List of document chunks
|
38 |
+
embedding_model: SentenceTransformer model
|
39 |
+
k: Number of chunks to retrieve
|
40 |
+
max_chunk_length: Maximum chunk length
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
LangChain tool for vector search
|
44 |
+
"""
|
45 |
+
if k is None:
|
46 |
+
k = Config.DEFAULT_K_CHUNKS // 3 # Use fewer chunks for tool
|
47 |
+
if max_chunk_length is None:
|
48 |
+
max_chunk_length = Config.DEFAULT_CHUNK_SIZE
|
49 |
+
|
50 |
+
@tool
|
51 |
+
def vector_database_search(query: str) -> str:
|
52 |
+
"""Search the uploaded PDF document for information related to the query.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
query: The search query string to find relevant information in the document.
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
A string containing relevant information found in the document.
|
59 |
+
"""
|
60 |
+
# Handle very short or empty queries
|
61 |
+
if not query or len(query.strip()) < 3:
|
62 |
+
return "Please provide a more specific search query with at least 3 characters."
|
63 |
+
|
64 |
+
try:
|
65 |
+
# Retrieve similar chunks using the provided session-specific components
|
66 |
+
similar_chunks_data = retrieve_similar_chunks(
|
67 |
+
query,
|
68 |
+
faiss_index,
|
69 |
+
document_chunks_with_metadata,
|
70 |
+
embedding_model,
|
71 |
+
k=k,
|
72 |
+
max_chunk_length=max_chunk_length
|
73 |
+
)
|
74 |
+
|
75 |
+
# Format the response
|
76 |
+
if not similar_chunks_data:
|
77 |
+
return "No relevant information found in the document for that query. Please try rephrasing your question or using different keywords."
|
78 |
+
|
79 |
+
# Filter out chunks with very high distance (low similarity)
|
80 |
+
filtered_chunks = filter_relevant_chunks(similar_chunks_data)
|
81 |
+
|
82 |
+
if not filtered_chunks:
|
83 |
+
return "No sufficiently relevant information found in the document for that query. Please try rephrasing your question or using different keywords."
|
84 |
+
|
85 |
+
context = "\n\n---\n\n".join([chunk_text for chunk_text, _, _ in filtered_chunks])
|
86 |
+
return f"The following information was found in the document regarding '{query}':\n{context}"
|
87 |
+
|
88 |
+
except Exception as e:
|
89 |
+
print(f"Error in vector search tool: {e}")
|
90 |
+
return f"Error searching the document: {str(e)}"
|
91 |
+
|
92 |
+
return vector_database_search
|
93 |
+
|
94 |
+
|
95 |
+
class RAGService:
|
96 |
+
"""Service for RAG operations."""
|
97 |
+
|
98 |
+
def __init__(self):
|
99 |
+
"""Initialize RAG service."""
|
100 |
+
self.tavily_tool = create_tavily_search_tool()
|
101 |
+
|
102 |
+
def create_agent_tools(
|
103 |
+
self,
|
104 |
+
faiss_index: faiss.IndexHNSWFlat,
|
105 |
+
document_chunks: List[Dict[str, Any]],
|
106 |
+
embedding_model: SentenceTransformer,
|
107 |
+
use_web_search: bool = False
|
108 |
+
) -> List:
|
109 |
+
"""
|
110 |
+
Create tools for the RAG agent.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
faiss_index: FAISS index
|
114 |
+
document_chunks: Document chunks
|
115 |
+
embedding_model: Embedding model
|
116 |
+
use_web_search: Whether to include web search tool
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
List of tools for the agent
|
120 |
+
"""
|
121 |
+
tools = []
|
122 |
+
|
123 |
+
# Add vector search tool
|
124 |
+
vector_tool = create_vector_search_tool(
|
125 |
+
faiss_index=faiss_index,
|
126 |
+
document_chunks_with_metadata=document_chunks,
|
127 |
+
embedding_model=embedding_model,
|
128 |
+
max_chunk_length=Config.DEFAULT_CHUNK_SIZE,
|
129 |
+
k=10
|
130 |
+
)
|
131 |
+
tools.append(vector_tool)
|
132 |
+
|
133 |
+
# Add web search tool if requested and available
|
134 |
+
if use_web_search and self.tavily_tool:
|
135 |
+
tools.append(self.tavily_tool)
|
136 |
+
|
137 |
+
return tools
|
138 |
+
|
139 |
+
    def create_agent_prompt(self, has_document_search: bool, has_web_search: bool) -> ChatPromptTemplate:
        """
        Create prompt template for the agent.

        Args:
            has_document_search: Whether document search is available
            has_web_search: Whether web search is available

        Returns:
            ChatPromptTemplate for the agent
        """
        # Build tool instructions dynamically
        tool_instructions = ""
        if has_document_search:
            tool_instructions += "Use vector_database_search to find information in the uploaded document. "
        if has_web_search:
            tool_instructions += "Use tavily_search_results_json for web searches when document search is insufficient. "

        if not tool_instructions:
            tool_instructions = "Answer based on the provided context only. "

        return ChatPromptTemplate.from_messages([
            ("system", f"""You are a helpful AI assistant that answers questions about documents.

Context: {{context}}

Tools available: {tool_instructions}

Instructions:
- Use the provided context first
- If context is insufficient, use available tools to search for more information
- Provide clear, helpful answers
- If you cannot find an answer, say so clearly"""),
            # Prior turns come before the latest question so the conversation
            # reaches the model in chronological order.
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])

    def execute_agent(
        self,
        llm,
        tools: List,
        query: str,
        context: str,
        memory: ConversationBufferMemory
    ) -> Dict[str, Any]:
        """
        Execute the RAG agent with given tools and context.

        Args:
            llm: Language model
            tools: List of tools
            query: User query
            context: Context string
            memory: Conversation memory

        Returns:
            Agent response
        """
        try:
            # Validate tools
            for tool in tools:
                if not hasattr(tool, 'name') or not hasattr(tool, 'description'):
                    raise ValueError(f"Tool {tool} is missing required attributes")

            # Create prompt
            has_document_search = any(t.name == "vector_database_search" for t in tools)
            has_web_search = any(t.name == "tavily_search_results_json" for t in tools)
            prompt = self.create_agent_prompt(has_document_search, has_web_search)

            # Create agent
            agent = create_tool_calling_agent(llm, tools, prompt)
            agent_executor = AgentExecutor(
                agent=agent,
                tools=tools,
                memory=memory,
                verbose=Config.AGENT_VERBOSE,
                handle_parsing_errors=True,
                max_iterations=Config.AGENT_MAX_ITERATIONS,
                return_intermediate_steps=False,
                early_stopping_method="generate"
            )

            # Execute agent
            agent_input = {
                "input": query,
                "context": context,
            }

            response_payload = agent_executor.invoke(agent_input)

            # Validate response
            agent_output = response_payload.get("output", "") if response_payload else ""

            if not agent_output or len(agent_output.strip()) < 10:
                raise ValueError("Insufficient response from agent")

            # Check for incomplete responses
            problematic_prefixes = [
                "Based on the Document,",
                "According to a web search,",
                "Based on the available information,",
                "I need to",
                "Let me"
            ]

            stripped_output = agent_output.strip()
            if any(stripped_output == prefix.strip() or
                   stripped_output == prefix.strip() + "."
                   for prefix in problematic_prefixes):
                raise ValueError("Agent returned incomplete response")

            return response_payload

        except Exception:
            # Propagate to the caller; generate_response catches this and
            # falls back to fallback_response.
            raise

    def fallback_response(
        self,
        llm,
        tools: List,
        query: str,
        context: str,
        use_tavily: bool = False
    ) -> Dict[str, Any]:
        """
        Generate fallback response using direct tool usage or LLM.

        Args:
            llm: Language model
            tools: List of available tools
            query: User query
            context: Context string
            use_tavily: Whether to use web search

        Returns:
            Fallback response
        """
        try:
            tool_results = []

            # Try vector search first if available
            vector_tool = next((t for t in tools if t.name == "vector_database_search"), None)
            if vector_tool:
                try:
                    search_result = vector_tool.run(query)
                    if search_result and "No relevant information" not in search_result:
                        tool_results.append(f"Document Search: {search_result}")
                except Exception:
                    # A failed tool call is non-fatal; continue with whatever
                    # context is already available.
                    pass

            # Try web search if needed and available
            if use_tavily:
                web_tool = next((t for t in tools if t.name == "tavily_search_results_json"), None)
                if web_tool:
                    try:
                        web_result = web_tool.run(query)
                        if web_result:
                            tool_results.append(f"Web Search: {web_result}")
                    except Exception:
                        pass

            # Combine tool results with context
            enhanced_context = context
            if tool_results:
                enhanced_context += "\n\nAdditional Information:\n" + "\n\n".join(tool_results)

            # Use direct LLM call with enhanced context
            direct_prompt = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful assistant. Use the provided context and information to answer the user's question clearly and completely."),
                ("human", "Context and Information: {context}\n\nQuestion: {input}")
            ])

            formatted_prompt = direct_prompt.format_prompt(
                context=enhanced_context,
                input=query
            ).to_messages()

            response = llm.invoke(formatted_prompt)
            direct_output = response.content if hasattr(response, 'content') else str(response)

            return {"output": direct_output}

        except Exception:
            # Final fallback - simple LLM response without any tool output
            fallback_prompt = ChatPromptTemplate.from_messages([
                ("system", """You are a helpful assistant that answers questions about documents.
Use the provided context to answer the user's question.
If the context contains relevant information, start your answer with "Based on the document, ..."
If the context is insufficient, clearly state what you don't know."""),
                ("human", "Context: {context}\n\nQuestion: {input}")
            ])

            formatted_fallback = fallback_prompt.format_prompt(
                context=context,
                input=query
            ).to_messages()

            response = llm.invoke(formatted_fallback)
            fallback_output = response.content if hasattr(response, 'content') else str(response)

            return {"output": fallback_output}

    def generate_response(
        self,
        llm,
        query: str,
        context_chunks: List[Tuple],
        faiss_index: faiss.IndexHNSWFlat,
        document_chunks: List[Dict[str, Any]],
        embedding_model: SentenceTransformer,
        memory: ConversationBufferMemory,
        use_tavily: bool = False
    ) -> Dict[str, Any]:
        """
        Generate RAG response using agent or fallback methods.

        Args:
            llm: Language model
            query: User query
            context_chunks: Initial context chunks
            faiss_index: FAISS index
            document_chunks: Document chunks
            embedding_model: Embedding model
            memory: Conversation memory
            use_tavily: Whether to use web search

        Returns:
            Generated response
        """
        # Validate inputs
        if not query or not query.strip():
            return {"output": "Please provide a valid question."}

        # Create tools
        tools = self.create_agent_tools(
            faiss_index, document_chunks, embedding_model, use_tavily
        )

        # Prepare context
        context = prepare_context_from_chunks(context_chunks)

        # Without tools, answer directly from the prepared context
        if not tools:
            fallback_prompt = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful assistant that answers questions about documents. Use the provided context to answer the user's question."),
                ("human", "Context: {context}\n\nQuestion: {input}")
            ])
            try:
                formatted_prompt = fallback_prompt.format_prompt(
                    context=context,
                    input=query
                ).to_messages()
                response = llm.invoke(formatted_prompt)
                return {"output": response.content if hasattr(response, 'content') else str(response)}
            except Exception:
                return {"output": "I'm sorry, I encountered an error processing your request."}

        # Try agent execution
        try:
            return self.execute_agent(llm, tools, query, context, memory)

        except Exception:
            # Try fallback approach
            try:
                return self.fallback_response(llm, tools, query, context, use_tavily)
            except Exception:
                return {"output": "I'm sorry, I encountered an error processing your request. Please try again."}


# Global RAG service instance
rag_service = RAGService()
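
The service is driven end to end through generate_response. The following is a minimal usage sketch, not part of the commit: index, chunks, initial_chunks, and embedder are hypothetical names for objects built elsewhere (the session manager below constructs them), and the memory_key must match the chat_history placeholder in the agent prompt.

# Hypothetical wiring sketch -- index, chunks, initial_chunks, and embedder
# are assumed to be built elsewhere (e.g. by the session manager).
from langchain.memory import ConversationBufferMemory
from configs.config import Config
from services import create_llm_model, rag_service

memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
llm = create_llm_model(Config.DEFAULT_MODEL)

result = rag_service.generate_response(
    llm=llm,
    query="Summarize the uploaded document.",
    context_chunks=initial_chunks,  # e.g. output of retrieve_similar_chunks
    faiss_index=index,
    document_chunks=chunks,
    embedding_model=embedder,
    memory=memory,
    use_tavily=False,
)
print(result["output"])
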
services/session_service.py
@@ -0,0 +1,253 @@
"""
Session management service.

This module provides high-level session management operations.
"""

import uuid
from typing import Dict, Any, Tuple, Optional
from sentence_transformers import SentenceTransformer

from configs.config import Config, ErrorMessages
from services.llm_service import create_llm_model
from utils import (
    save_session_to_file,
    load_session_from_file,
    reconstruct_session_objects,
    cleanup_session_files,
    validate_session_data,
    session_exists,
    create_embeddings,
    build_faiss_index
)


class SessionManager:
    """Manager for session operations."""

    def __init__(self):
        """Initialize session manager."""
        self.active_sessions: Dict[str, Dict[str, Any]] = {}

    def create_session(
        self,
        file_path: str,
        file_name: str,
        chunks_with_metadata: list,
        model_name: str
    ) -> str:
        """
        Create a new session with processed document data.

        Args:
            file_path: Path to the uploaded file
            file_name: Original filename
            chunks_with_metadata: Processed document chunks
            model_name: LLM model name

        Returns:
            Session ID
        """
        session_id = str(uuid.uuid4())

        # Create embedding model and process chunks
        embedding_model = SentenceTransformer(Config.EMBEDDING_MODEL)
        embeddings, _ = create_embeddings(chunks_with_metadata, embedding_model)

        # Build FAISS index
        index = build_faiss_index(embeddings)

        # Create LLM
        llm = create_llm_model(model_name)

        # Create session data
        session_data = {
            "file_path": file_path,
            "file_name": file_name,
            "chunks": chunks_with_metadata,
            "model": embedding_model,
            "index": index,
            "llm": llm,
            "chat_history": []
        }

        # Save to memory and file
        self.active_sessions[session_id] = session_data
        save_session_to_file(session_id, session_data)

        return session_id

    def get_session(
        self,
        session_id: str,
        model_name: Optional[str] = None
    ) -> Tuple[Optional[Dict[str, Any]], bool]:
        """
        Retrieve session data, loading from file if necessary.

        Args:
            session_id: Session identifier
            model_name: LLM model name (for reconstruction)

        Returns:
            Tuple of (session_data, found)
        """
        if model_name is None:
            model_name = Config.DEFAULT_MODEL

        try:
            # Check if session is in memory
            if session_id in self.active_sessions:
                cached_session = self.active_sessions[session_id]

                # Ensure LLM is up-to-date
                if (cached_session.get("llm") is None or
                        (hasattr(cached_session["llm"], "model_name") and
                         cached_session["llm"].model_name != model_name)):
                    cached_session["llm"] = create_llm_model(model_name)

                # Ensure embedding model exists
                if cached_session.get("model") is None:
                    cached_session["model"] = SentenceTransformer(Config.EMBEDDING_MODEL)

                # Ensure FAISS index exists
                if cached_session.get("index") is None and cached_session.get("chunks"):
                    embeddings, _ = create_embeddings(
                        cached_session["chunks"],
                        cached_session["model"]
                    )
                    cached_session["index"] = build_faiss_index(embeddings)

                return cached_session, True

            # Try to load from file
            data, success = load_session_from_file(session_id)
            if not success:
                return None, False

            # Check if original PDF exists
            original_pdf_path = data.get("file_path")
            if not (data.get("chunks") and original_pdf_path and
                    session_exists(session_id)):
                print(f"Warning: Session data for {session_id} is incomplete or PDF missing.")
                cleanup_session_files(session_id)
                return None, False

            # Reconstruct session objects
            embedding_model = SentenceTransformer(Config.EMBEDDING_MODEL)
            full_session_data = reconstruct_session_objects(
                data, model_name, embedding_model
            )

            # Cache in memory
            self.active_sessions[session_id] = full_session_data

            return full_session_data, True

        except Exception as e:
            print(f"Error loading session {session_id}: {str(e)}")
            return None, False

    def save_session(self, session_id: str, session_data: Dict[str, Any]) -> bool:
        """
        Save session data to memory and file.

        Args:
            session_id: Session identifier
            session_data: Session data to save

        Returns:
            True if successful, False otherwise
        """
        # Update memory cache
        self.active_sessions[session_id] = session_data

        # Save to file
        return save_session_to_file(session_id, session_data)

    def remove_session(self, session_id: str) -> bool:
        """
        Remove session and associated files.

        Args:
            session_id: Session identifier

        Returns:
            True if successful, False otherwise
        """
        try:
            # Remove from memory
            if session_id in self.active_sessions:
                del self.active_sessions[session_id]

            # Clean up files
            return cleanup_session_files(session_id)

        except Exception as e:
            print(f"Error removing session {session_id}: {str(e)}")
            return False

    def clear_chat_history(self, session_id: str) -> bool:
        """
        Clear chat history for a session.

        Args:
            session_id: Session identifier

        Returns:
            True if successful, False otherwise
        """
        session_data, found = self.get_session(session_id)
        if not found:
            return False

        session_data["chat_history"] = []
        return self.save_session(session_id, session_data)

    def add_chat_entry(
        self,
        session_id: str,
        user_message: str,
        assistant_message: str
    ) -> bool:
        """
        Add a chat entry to session history.

        Args:
            session_id: Session identifier
            user_message: User's message
            assistant_message: Assistant's response

        Returns:
            True if successful, False otherwise
        """
        session_data, found = self.get_session(session_id)
        if not found:
            return False

        session_data["chat_history"].append({
            "user": user_message,
            "assistant": assistant_message
        })

        return self.save_session(session_id, session_data)

    def validate_session(self, session_id: str) -> Tuple[bool, list]:
        """
        Validate session data integrity.

        Args:
            session_id: Session identifier

        Returns:
            Tuple of (is_valid, missing_keys)
        """
        session_data, found = self.get_session(session_id)
        if not found:
            return False, ["session_not_found"]

        return validate_session_data(session_data)


# Global session manager instance
session_manager = SessionManager()
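
A sketch of the expected session lifecycle, not part of the commit: chunks_with_metadata stands in for the output of the PDF preprocessing step, and the calls assume the embedding model and a Groq API key are available.

# Hypothetical lifecycle sketch -- chunks_with_metadata comes from the
# preprocessing step, which lives outside this module.
from configs.config import Config
from services import session_manager

session_id = session_manager.create_session(
    file_path="uploads/demo.pdf",
    file_name="demo.pdf",
    chunks_with_metadata=chunks_with_metadata,
    model_name=Config.DEFAULT_MODEL,
)

session_data, found = session_manager.get_session(session_id)
if found:
    session_manager.add_chat_entry(session_id, "Hi", "Hello! Ask me about the PDF.")
session_manager.remove_session(session_id)
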
test_refactored.py
@@ -0,0 +1,176 @@
"""
Test script to verify the refactored code works correctly.

This script tests the main functionality to ensure backward compatibility
and proper operation of the refactored modules.
"""

import sys
import os
import tempfile
import traceback

# Add the current directory to Python path
sys.path.insert(0, '/workspaces/PDF-Insight-Beta')

def test_imports():
    """Test that all modules can be imported successfully."""
    print("Testing imports...")

    try:
        # Test config import
        from configs.config import Config, ModelConfig, ErrorMessages
        print("✓ Config module imported successfully")

        # Test models import
        from models.models import ChatRequest, UploadResponse
        print("✓ Models module imported successfully")

        # Test utils import
        from utils import estimate_tokens, process_pdf_file
        print("✓ Utils module imported successfully")

        # Test services import
        from services import create_llm_model, session_manager, rag_service
        print("✓ Services module imported successfully")

        # Test API import
        from api import upload_pdf_handler, chat_handler
        print("✓ API module imported successfully")

        # Test backward compatibility
        from preprocessing_refactored import model_selection, chunk_text, agentic_rag
        print("✓ Backward compatibility import successful")

        return True

    except Exception as e:
        print(f"✗ Import failed: {e}")
        traceback.print_exc()
        return False

def test_basic_functionality():
    """Test basic functionality of key components."""
    print("\nTesting basic functionality...")

    try:
        from configs.config import Config
        from utils.text_processing import estimate_tokens
        from services.llm_service import get_available_models

        # Test token estimation
        tokens = estimate_tokens("This is a test string")
        assert tokens > 0
        print(f"✓ Token estimation works: {tokens} tokens")

        # Test model listing
        models = get_available_models()
        assert len(models) > 0
        print(f"✓ Model listing works: {len(models)} models available")

        # Test config access
        assert Config.DEFAULT_CHUNK_SIZE > 0
        print(f"✓ Config access works: chunk size = {Config.DEFAULT_CHUNK_SIZE}")

        return True

    except Exception as e:
        print(f"✗ Basic functionality test failed: {e}")
        traceback.print_exc()
        return False

def test_backward_compatibility():
    """Test that original interfaces still work."""
    print("\nTesting backward compatibility...")

    try:
        # Test original preprocessing interface
        from preprocessing_refactored import model_selection, tools, estimate_tokens

        # These should work without errors
        assert callable(model_selection)
        assert isinstance(tools, list)
        assert callable(estimate_tokens)

        print("✓ Original preprocessing interface preserved")

        # Test that we can still access the original functions
        from preprocessing_refactored import (
            process_pdf_file, chunk_text, create_embeddings,
            build_faiss_index, retrieve_similar_chunks, agentic_rag
        )

        print("✓ All original functions accessible")

        return True

    except Exception as e:
        print(f"✗ Backward compatibility test failed: {e}")
        traceback.print_exc()
        return False

def test_app_creation():
    """Test that the FastAPI app can be created."""
    print("\nTesting app creation...")

    try:
        from app_refactored import create_app

        app = create_app()
        assert app is not None
        print("✓ FastAPI app created successfully")

        # Check that routes are properly defined
        routes = [route.path for route in app.routes]
        expected_routes = ["/", "/upload-pdf", "/chat", "/models"]

        for route in expected_routes:
            if route in routes:
                print(f"✓ Route {route} found")
            else:
                print(f"✗ Route {route} missing")
                return False

        return True

    except Exception as e:
        print(f"✗ App creation test failed: {e}")
        traceback.print_exc()
        return False

def main():
    """Run all tests."""
    print("=" * 50)
    print("Testing Refactored PDF Insight Beta")
    print("=" * 50)

    tests = [
        test_imports,
        test_basic_functionality,
        test_backward_compatibility,
        test_app_creation
    ]

    results = []
    for test in tests:
        results.append(test())

    print("\n" + "=" * 50)
    print("Test Results:")
    print("=" * 50)

    passed = sum(results)
    total = len(results)

    print(f"Tests passed: {passed}/{total}")

    if passed == total:
        print("✓ All tests passed! Refactoring successful.")
        return 0
    else:
        print("✗ Some tests failed. Please check the issues above.")
        return 1

if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)
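
The script is run directly, as in python test_refactored.py, and returns a nonzero exit code when any check fails, so it can gate a CI step. Note that the sys.path entry at the top hardcodes the /workspaces/PDF-Insight-Beta checkout location; in other environments, adjust that path or run from the repository root.
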
utils/__init__.py
@@ -0,0 +1,62 @@
"""
Utility modules initialization.

This module provides easy imports for all utility functions.
"""

from .text_processing import (
    estimate_tokens,
    process_pdf_file,
    chunk_text,
    create_embeddings,
    filter_relevant_chunks,
    prepare_context_from_chunks,
    validate_chunk_data
)

from .faiss_utils import (
    build_faiss_index,
    retrieve_similar_chunks,
    search_index_with_validation,
    get_index_stats
)

from .session_utils import (
    create_session_file_path,
    create_upload_file_path,
    prepare_pickle_safe_data,
    save_session_to_file,
    load_session_from_file,
    reconstruct_session_objects,
    cleanup_session_files,
    validate_session_data,
    session_exists
)

__all__ = [
    # Text processing
    "estimate_tokens",
    "process_pdf_file",
    "chunk_text",
    "create_embeddings",
    "filter_relevant_chunks",
    "prepare_context_from_chunks",
    "validate_chunk_data",

    # FAISS utilities
    "build_faiss_index",
    "retrieve_similar_chunks",
    "search_index_with_validation",
    "get_index_stats",

    # Session utilities
    "create_session_file_path",
    "create_upload_file_path",
    "prepare_pickle_safe_data",
    "save_session_to_file",
    "load_session_from_file",
    "reconstruct_session_objects",
    "cleanup_session_files",
    "validate_session_data",
    "session_exists"
]
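
Because the package re-exports its three submodules, call sites can depend on utils without knowing the submodule layout; both import paths below resolve to the same object.

# Flat import (what the services layer uses) vs. explicit submodule path.
from utils import build_faiss_index
from utils.faiss_utils import build_faiss_index as build_faiss_index_explicit

assert build_faiss_index is build_faiss_index_explicit
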
utils/session_utils.py
@@ -0,0 +1,219 @@
"""
Session management utilities.

This module provides utilities for session data persistence and management.
"""

import os
import pickle
import traceback
from typing import Dict, Any, Tuple, Optional, List

from configs.config import Config, ErrorMessages


def create_session_file_path(session_id: str) -> str:
    """
    Create the file path for a session pickle file.

    Args:
        session_id: Session identifier

    Returns:
        File path for the session data
    """
    return f"{Config.UPLOAD_DIR}/{session_id}_session.pkl"


def create_upload_file_path(session_id: str, filename: str) -> str:
    """
    Create the file path for an uploaded file.

    Args:
        session_id: Session identifier
        filename: Original filename

    Returns:
        File path for the uploaded file
    """
    return f"{Config.UPLOAD_DIR}/{session_id}_{filename}"


def prepare_pickle_safe_data(session_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Prepare session data for pickling by removing non-serializable objects.

    Args:
        session_data: Full session data

    Returns:
        Pickle-safe session data
    """
    return {
        "file_path": session_data.get("file_path"),
        "file_name": session_data.get("file_name"),
        "chunks": session_data.get("chunks"),  # Chunks with metadata (list of dicts)
        "chat_history": session_data.get("chat_history", [])
        # FAISS index, embedding model, and LLM model are not pickled
    }


def save_session_to_file(session_id: str, session_data: Dict[str, Any]) -> bool:
    """
    Save session data to pickle file.

    Args:
        session_id: Session identifier
        session_data: Session data to save

    Returns:
        True if successful, False otherwise
    """
    try:
        pickle_safe_data = prepare_pickle_safe_data(session_data)
        file_path = create_session_file_path(session_id)

        with open(file_path, "wb") as f:
            pickle.dump(pickle_safe_data, f)

        return True
    except Exception as e:
        print(f"Error saving session {session_id}: {str(e)}")
        return False


def load_session_from_file(session_id: str) -> Tuple[Optional[Dict[str, Any]], bool]:
    """
    Load session data from pickle file.

    Args:
        session_id: Session identifier

    Returns:
        Tuple of (session_data, success)
    """
    try:
        file_path = create_session_file_path(session_id)

        if not os.path.exists(file_path):
            return None, False

        with open(file_path, "rb") as f:
            data = pickle.load(f)

        return data, True
    except Exception as e:
        print(f"Error loading session {session_id}: {str(e)}")
        return None, False


def reconstruct_session_objects(
    session_data: Dict[str, Any],
    model_name: str,
    embedding_model
) -> Dict[str, Any]:
    """
    Reconstruct non-serializable objects in session data.

    Args:
        session_data: Basic session data from pickle
        model_name: LLM model name
        embedding_model: SentenceTransformer instance

    Returns:
        Complete session data with reconstructed objects
    """
    # Import here to avoid circular imports
    from langchain_groq import ChatGroq

    # Create LLM model
    llm = ChatGroq(
        model=model_name,
        api_key=os.getenv("GROQ_API_KEY"),
        temperature=Config.LLM_TEMPERATURE,
        max_tokens=Config.MAX_TOKENS
    )

    # Reconstruct embeddings and FAISS index
    if session_data.get("chunks"):
        # Import here to avoid circular imports
        from utils.text_processing import create_embeddings
        from utils.faiss_utils import build_faiss_index

        embeddings, _ = create_embeddings(session_data["chunks"], embedding_model)
        faiss_index = build_faiss_index(embeddings)
    else:
        embeddings, faiss_index = None, None

    return {
        **session_data,
        "model": embedding_model,
        "index": faiss_index,
        "llm": llm
    }


def cleanup_session_files(session_id: str) -> bool:
    """
    Clean up all files associated with a session.

    Args:
        session_id: Session identifier

    Returns:
        True if successful, False otherwise
    """
    try:
        session_file = create_session_file_path(session_id)

        # Load session data to get the uploaded file path
        if os.path.exists(session_file):
            try:
                with open(session_file, "rb") as f:
                    data = pickle.load(f)

                # Delete PDF file if it exists
                pdf_path = data.get("file_path")
                if pdf_path and os.path.exists(pdf_path):
                    os.remove(pdf_path)
            except Exception as e:
                print(f"Error reading session file for cleanup: {e}")

            # Remove session file (guarded by the existence check above,
            # so cleaning up an already-removed session is not an error)
            os.remove(session_file)

        return True
    except Exception as e:
        print(f"Error cleaning up session {session_id}: {str(e)}")
        return False


def validate_session_data(session_data: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """
    Validate session data integrity.

    Args:
        session_data: Session data to validate

    Returns:
        Tuple of (is_valid, missing_keys)
    """
    required_keys = ["index", "chunks", "model", "llm"]
    missing_keys = [key for key in required_keys if key not in session_data]

    return len(missing_keys) == 0, missing_keys


def session_exists(session_id: str) -> bool:
    """
    Check if a session exists.

    Args:
        session_id: Session identifier

    Returns:
        True if session exists, False otherwise
    """
    session_file = create_session_file_path(session_id)
    return os.path.exists(session_file)
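
Because only plain data survives pickling, a load is always followed by reconstruct_session_objects. A round-trip sketch, not part of the commit: chunks is a hypothetical list of chunk dicts, and the reconstruction step assumes GROQ_API_KEY is set since it instantiates ChatGroq.

# Hypothetical round trip -- heavy objects (index, models) are rebuilt,
# never unpickled.
from sentence_transformers import SentenceTransformer
from configs.config import Config
from utils.session_utils import (
    save_session_to_file,
    load_session_from_file,
    reconstruct_session_objects,
)

save_session_to_file("demo-session", {
    "file_path": "uploads/demo.pdf",
    "file_name": "demo.pdf",
    "chunks": chunks,
    "chat_history": [],
})

data, ok = load_session_from_file("demo-session")
if ok:
    embedder = SentenceTransformer(Config.EMBEDDING_MODEL)
    session = reconstruct_session_objects(data, Config.DEFAULT_MODEL, embedder)
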