Jatin Mehra committed
Commit a193f24 · 1 Parent(s): f50be30

Refactor app.py to implement FastAPI for PDF processing, session management, and chat functionality

Files changed (1)
  1. app.py +256 -186
app.py CHANGED
@@ -1,196 +1,266 @@
 import os
-import tempfile
-import json
-import streamlit as st
-from streamlit_chat import message
-from preprocessing import Model
-from io import BytesIO
+import dotenv
 import pickle
+import uuid
+from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks, Request
+from fastapi.responses import JSONResponse
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
+from pydantic import BaseModel
+import uvicorn
+from preprocessing import (
+    model_selection,
+    process_pdf_file,
+    chunk_text,
+    create_embeddings,
+    build_faiss_index,
+    retrieve_similar_chunks,
+    agentic_rag,
+    tools,
+    memory
+)
+from sentence_transformers import SentenceTransformer
+import shutil
+import traceback
+
+# Load environment variables
+dotenv.load_dotenv()
 
-# Home Page Setup
-st.set_page_config(
-    page_title="PDF Insight Pro",
-    page_icon="📄",
-    layout="centered",
+# Initialize FastAPI app
+app = FastAPI(title="PDF Insight Beta", description="Agentic RAG for PDF documents")
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
 )
 
-# Custom CSS for a more polished look
-st.markdown("""
-    <style>
-        .main {
-            background-color: #f5f5f5;
-        }
-        .stButton button {
-            background-color: #4CAF50;
-            color: white;
-            border-radius: 8px;
-        }
-        .stTextInput input {
-            border-radius: 8px;
-            padding: 10px;
-        }
-        .stFileUploader input {
-            border-radius: 8px;
+# Create upload directory if it doesn't exist
+UPLOAD_DIR = "uploads"
+if not os.path.exists(UPLOAD_DIR):
+    os.makedirs(UPLOAD_DIR)
+
+# Store active sessions
+sessions = {}
+
+# Define model for chat request
+class ChatRequest(BaseModel):
+    session_id: str
+    query: str
+    use_search: bool = False
+    model_name: str = "meta-llama/llama-4-scout-17b-16e-instruct"
+
+class SessionRequest(BaseModel):
+    session_id: str
+
+# Function to save session data
+def save_session(session_id, data):
+    sessions[session_id] = data
+
+    # Create a copy of data that is safe to pickle
+    pickle_safe_data = {
+        "file_path": data.get("file_path"),
+        "file_name": data.get("file_name"),
+        "chunks": data.get("chunks"),
+        "chat_history": data.get("chat_history", [])
+    }
+
+    # Persist to disk
+    with open(f"{UPLOAD_DIR}/{session_id}_session.pkl", "wb") as f:
+        pickle.dump(pickle_safe_data, f)
+
+# Function to load session data
+def load_session(session_id, model_name="meta-llama/llama-4-scout-17b-16e-instruct"):
+    try:
+        # Check if session is already in memory
+        if session_id in sessions:
+            return sessions[session_id], True
+
+        # Try to load from disk
+        file_path = f"{UPLOAD_DIR}/{session_id}_session.pkl"
+        if os.path.exists(file_path):
+            with open(file_path, "rb") as f:
+                data = pickle.load(f)
+
+            # Recreate non-pickled objects
+            if data.get("chunks") and data.get("file_path") and os.path.exists(data["file_path"]):
+                # Recreate model, embeddings and index
+                model = SentenceTransformer('all-MiniLM-L6-v2')
+                embeddings = create_embeddings(data["chunks"], model)
+                index = build_faiss_index(embeddings)
+
+                # Recreate LLM
+                llm = model_selection(model_name)
+
+                # Reconstruct full session data
+                data["model"] = model
+                data["index"] = index
+                data["llm"] = llm
+
+                # Store in memory
+                sessions[session_id] = data
+                return data, True
+
+        return None, False
+    except Exception as e:
+        print(f"Error loading session: {str(e)}")
+        return None, False
+
+# Mount static files (we'll create these later)
+app.mount("/static", StaticFiles(directory="static"), name="static")
+
+# Route for the home page
+@app.get("/")
+async def read_root():
+    return {"status": "ok", "message": "PDF Insight Beta API is running"}
+
+# Route to upload a PDF file
+@app.post("/upload-pdf")
+async def upload_pdf(
+    file: UploadFile = File(...),
+    model_name: str = Form("meta-llama/llama-4-scout-17b-16e-instruct")
+):
+    # Generate a unique session ID
+    session_id = str(uuid.uuid4())
+    file_path = None
+
+    try:
+        # Save the uploaded file
+        file_path = f"{UPLOAD_DIR}/{session_id}_{file.filename}"
+        with open(file_path, "wb") as buffer:
+            shutil.copyfileobj(file.file, buffer)
+
+        # Check if API keys are set
+        if not os.getenv("GROQ_API_KEY"):
+            raise ValueError("GROQ_API_KEY is not set in the environment variables")
+
+        # Process the PDF
+        text = process_pdf_file(file_path)
+        chunks = chunk_text(text, max_length=1500)
+
+        # Create embeddings
+        model = SentenceTransformer('all-MiniLM-L6-v2')
+        embeddings = create_embeddings(chunks, model)
+        index = build_faiss_index(embeddings)
+
+        # Initialize LLM
+        llm = model_selection(model_name)
+
+        # Save session data
+        session_data = {
+            "file_path": file_path,
+            "file_name": file.filename,
+            "chunks": chunks,
+            "model": model,
+            "index": index,
+            "llm": llm,
+            "chat_history": []
         }
-        .stMarkdown h1 {
-            color: #4CAF50;
+        save_session(session_id, session_data)
+
+        return {"status": "success", "session_id": session_id, "message": f"Processed {file.filename}"}
+
+    except Exception as e:
+        # Clean up on error
+        if file_path and os.path.exists(file_path):
+            os.remove(file_path)
+
+        error_msg = str(e)
+        stack_trace = traceback.format_exc()
+        print(f"Error processing PDF: {error_msg}")
+        print(f"Stack trace: {stack_trace}")
+
+        return JSONResponse(
+            status_code=400,
+            content={
+                "status": "error",
+                "detail": error_msg,
+                "type": type(e).__name__
+            }
+        )
+
+# Route to chat with the document
+@app.post("/chat")
+async def chat(request: ChatRequest):
+    # Try to load session if not in memory
+    session, found = load_session(request.session_id, model_name=request.model_name)
+    if not found:
+        raise HTTPException(status_code=404, detail="Session not found. Please upload a document first.")
+
+    try:
+        # Retrieve similar chunks
+        similar_chunks = retrieve_similar_chunks(
+            request.query,
+            session["index"],
+            session["chunks"],
+            session["model"],
+            k=3
+        )
+        context = "\n".join([chunk for chunk, _ in similar_chunks])
+
+        # Generate response using agentic_rag
+        response = agentic_rag(
+            session["llm"],
+            tools,
+            query=request.query,
+            context=context,
+            Use_Tavily=request.use_search
+        )
+
+        # Update chat history
+        session["chat_history"].append({"user": request.query, "assistant": response["output"]})
+        save_session(request.session_id, session)
+
+        return {
+            "status": "success",
+            "answer": response["output"],
+            "context_used": [{"text": chunk, "score": float(score)} for chunk, score in similar_chunks]
         }
-    </style>
-""", unsafe_allow_html=True)
-
-# Custom title and header
-st.title("📄 PDF Insight Pro")
-st.subheader("Empower Your Documents with AI-Driven Insights")
-
-def display_messages():
-    """
-    Displays the chat messages in the Streamlit UI.
-    """
-    st.subheader("🗨️ Conversation")
-    st.markdown("---")
-    for i, (msg, is_user) in enumerate(st.session_state["messages"]):
-        message(msg, is_user=is_user, key=str(i))
-    st.session_state["process_input_spinner"] = st.empty()
-
-def process_user_input():
-    """
-    Processes the user input by generating a response from the assistant.
-    """
-    if st.session_state["user_input"] and len(st.session_state["user_input"].strip()) > 0:
-        user_input = st.session_state["user_input"].strip()
-        with st.session_state["process_input_spinner"], st.spinner("Analyzing..."):
-            agent_response = st.session_state["assistant"].get_response(
-                user_input,
-                st.session_state["temperature"],
-                st.session_state["max_tokens"],
-                st.session_state["model"]
-            )
-
-        st.session_state["messages"].append((user_input, True))
-        st.session_state["messages"].append((agent_response, False))
-        st.session_state["user_input"] = ""
-
-        # Save chat history temporarily on local storage
-        with open("chat_history.pkl", "wb") as f:
-            pickle.dump(st.session_state["messages"], f)
-
-def process_file():
-    """
-    Processes the uploaded PDF file and appends its content to the context.
-    """
-    for file in st.session_state["file_uploader"]:
-        with tempfile.NamedTemporaryFile(delete=False) as tf:
-            tf.write(file.getbuffer())
-            file_path = tf.name
-
-        with st.session_state["process_file_spinner"], st.spinner(f"Processing {file.name}..."):
-            try:
-                st.session_state["assistant"].add_to_context(file_path)
-            except Exception as e:
-                st.error(f"Failed to process file {file.name}: {str(e)}")
-        os.remove(file_path)
-
-def download_chat_history():
-    """
-    Allows users to download chat history in HTML or JSON format.
-    """
-    # Convert messages to JSON format
-    chat_data = [{"role": "user" if is_user else "assistant", "content": msg} for msg, is_user in st.session_state["messages"]]
-
-    # Download as JSON
-    json_data = json.dumps(chat_data, indent=4)
-    st.download_button(
-        label="💾 Download Chat History as JSON",
-        data=json_data,
-        file_name="chat_history.json",
-        mime="application/json"
-    )
-
-    # Download as HTML
-    html_data = "<html><body><h1>Chat History</h1><ul>"
-    for entry in chat_data:
-        role = "User" if entry["role"] == "user" else "Assistant"
-        html_data += f"<li><strong>{role}:</strong> {entry['content']}</li>"
-    html_data += "</ul></body></html>"
-    st.download_button(
-        label="💾 Download Chat History as HTML",
-        data=html_data,
-        file_name="chat_history.html",
-        mime="text/html"
-    )
-
-def main_page():
-    """
-    Main function to set up the Streamlit UI and handle user interactions.
-    """
-    # Initialize session state variables
-    if "messages" not in st.session_state:
-        st.session_state["messages"] = []
-
-    if "assistant" not in st.session_state:
-        st.session_state["assistant"] = Model()
-
-    if "user_input" not in st.session_state:
-        st.session_state["user_input"] = ""
-
-    if "temperature" not in st.session_state:
-        st.session_state["temperature"] = 0.5
-
-    if "max_tokens" not in st.session_state:
-        st.session_state["max_tokens"] = 550
-
-    if "model" not in st.session_state:
-        st.session_state["model"] = "llama-3.1-8b-instant"
-
-    # File uploader
-    st.subheader("📤 Upload Your PDF Documents")
-    st.file_uploader(
-        "Choose PDF files to analyze",
-        type=["pdf"],
-        key="file_uploader",
-        on_change=process_file,
-        accept_multiple_files=True,
-    )
-
-    st.session_state["process_file_spinner"] = st.empty()
-
-    # Document management section
-    if st.session_state["assistant"].contexts:
-        st.subheader("🗂️ Manage Uploaded Documents")
-        for i, context in enumerate(st.session_state["assistant"].contexts):
-            st.text_area(f"Document {i+1} Context", context[:500] + "..." if len(context) > 500 else context, height=100)
-            if st.button(f"Remove Document {i+1}"):
-                st.session_state["assistant"].remove_from_context(i)
-
-    # Model settings
-    with st.expander("⚙️ Customize AI Settings", expanded=True):
-        st.slider("Sampling Temperature", min_value=0.0, max_value=1.0, step=0.1, key="temperature", help="Higher values make output more random.")
-        st.slider("Max Tokens", min_value=750, max_value=5000, step=50, key="max_tokens", help="Limits the length of the response.")
-        st.selectbox("Choose AI Model", ["llama-3.1-8b-instant", "llama3-70b-8192", "gemma-7b-it"], key="model")
-
-    # Display messages and input box
-    display_messages()
-    st.text_input("Type your query and hit Enter", key="user_input", on_change=process_user_input, placeholder="Ask something about your documents...")
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
+
+# Route to get chat history
+@app.post("/chat-history")
+async def get_chat_history(request: SessionRequest):
+    # Try to load session if not in memory
+    session, found = load_session(request.session_id)
+    if not found:
+        raise HTTPException(status_code=404, detail="Session not found")
+
+    return {
+        "status": "success",
+        "history": session.get("chat_history", [])
+    }
+
+# Route to clear chat history
+@app.post("/clear-history")
+async def clear_history(request: SessionRequest):
+    # Try to load session if not in memory
+    session, found = load_session(request.session_id)
+    if not found:
+        raise HTTPException(status_code=404, detail="Session not found")
 
-    # Download chat history section
-    st.subheader("💾 Download Chat History")
-    download_chat_history()
-
-    # Developer info and bug report
-    st.subheader("🐞 Bug Report")
-    st.markdown("""
-    If you encounter any bugs or issues while using the app, please send a bug report to the developer. You can include a screenshot (optional) to help identify the problem.\n
-    """)
-    st.subheader("💡 Suggestions")
-    st.markdown("""
-    Suggestions to improve the app's UI and user interface are also welcome. Feel free to reach out to the developer with your suggestions.\n
-    """)
-    st.subheader("👨‍💻 Developer Info")
-    st.markdown("""
-    **Developer**: Jatin Mehra\n
-    **Email**: [email protected]\n
-    **Mobile**: 9910364780\n
-    """)
+    session["chat_history"] = []
+    save_session(request.session_id, session)
 
+    return {"status": "success", "message": "Chat history cleared"}
+
+# Route to list available models
+@app.get("/models")
+async def get_models():
+    # You can expand this list as needed
+    models = [
+        {"id": "meta-llama/llama-4-scout-17b-16e-instruct", "name": "Llama 4 Scout 17B"},
+        {"id": "llama-3.1-8b-instant", "name": "Llama 3.1 8B Instant"},
+        {"id": "llama-3.3-70b-versatile", "name": "Llama 3.3 70B Versatile"},
+    ]
+    return {"models": models}
+
+# Run the application if this file is executed directly
 if __name__ == "__main__":
-    main_page()
+    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
+
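
For anyone trying out the refactored API, here is a minimal client-side sketch of the new endpoint flow. It is an illustration, not part of the commit: it assumes the server was started locally with the uvicorn line above, that a placeholder file named sample.pdf exists in the working directory, and that the requests library is installed. The endpoint paths and payload shapes are taken from the diff.

import requests

BASE = "http://localhost:8000"  # assumes the dev server started via uvicorn as above

# 1. Upload a PDF; the server responds with a session_id for follow-up calls.
with open("sample.pdf", "rb") as f:  # sample.pdf is a placeholder name
    resp = requests.post(
        f"{BASE}/upload-pdf",
        files={"file": ("sample.pdf", f, "application/pdf")},
        data={"model_name": "meta-llama/llama-4-scout-17b-16e-instruct"},
    )
resp.raise_for_status()
session_id = resp.json()["session_id"]

# 2. Ask a question about the document; use_search toggles the Tavily path in agentic_rag.
resp = requests.post(
    f"{BASE}/chat",
    json={"session_id": session_id, "query": "Summarize the document.", "use_search": False},
)
print(resp.json()["answer"])

# 3. Inspect, then clear, the per-session chat history.
history = requests.post(f"{BASE}/chat-history", json={"session_id": session_id}).json()["history"]
requests.post(f"{BASE}/clear-history", json={"session_id": session_id})

Note that save_session also persists each session to uploads/<session_id>_session.pkl, so a chat can resume after a server restart as long as the uploaded PDF is still on disk, since load_session rebuilds the embeddings, FAISS index, and LLM from the pickled chunks.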