eikarna committed
Commit 33009e3 · Parent: d9760ae

Revert to non-RAG

Files changed (1)
  1. app.py.bak +264 -0
app.py.bak ADDED
@@ -0,0 +1,264 @@
+ import streamlit as st
+ import requests
+ import logging
+ import time
+ from typing import Dict, Any, Optional, List
+ import os
+ import tempfile
+ from PIL import Image
+ import pytesseract
+ import fitz  # PyMuPDF
+ from sentence_transformers import SentenceTransformer
+ import numpy as np
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ # Initialize SBERT model for embeddings
+ @st.cache_resource
+ def load_embedding_model():
+     return SentenceTransformer('all-MiniLM-L6-v2')
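+ # Note: all-MiniLM-L6-v2 encodes text into 384-dimensional vectors; encode()
+ # accepts a string or a list of strings, e.g.
+ #   vec = load_embedding_model().encode("hello")  # np.ndarray, shape (384,)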
+
+ # Simple in-memory vector store
+ class SimpleVectorStore:
+     def __init__(self):
+         self.documents = []
+         self.embeddings = []
+
+     def add_document(self, text: str, embedding: np.ndarray):
+         self.documents.append(text)
+         self.embeddings.append(embedding)
+
+     def search(self, query_embedding: np.ndarray, top_k: int = 3) -> List[str]:
+         if not self.embeddings:
+             return []
+
+         similarities = np.dot(self.embeddings, query_embedding)
+         top_indices = np.argsort(similarities)[-top_k:][::-1]
+         return [self.documents[i] for i in top_indices]
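+ # Note: dot-product scores only rank by cosine similarity if the embeddings
+ # are unit-normalized (e.g. encode(..., normalize_embeddings=True)); with raw
+ # embeddings, longer vectors are favored. A hypothetical usage sketch:
+ #   store = SimpleVectorStore()
+ #   store.add_document("cats purr", model.encode("cats purr"))
+ #   store.search(model.encode("why do cats purr?"), top_k=1)  # -> ["cats purr"]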
+
+ # Document processing functions
+ def process_text(text: str) -> List[str]:
+     """Split text into sentence-sized chunks."""
+     # Naive sentence split; avoid doubling the final period
+     chunks = text.split('. ')
+     return [chunk if chunk.endswith('.') else chunk + '.' for chunk in chunks if chunk]
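+ # Example: process_text("A. B. C.") -> ["A.", "B.", "C."]. Note the split also
+ # fires inside abbreviations such as "e.g. ", a known limitation of this chunker.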
+
+ def process_image(image) -> str:
+     """Extract text from image using OCR."""
+     try:
+         text = pytesseract.image_to_string(image)
+         return text
+     except Exception as e:
+         logger.error(f"Error processing image: {str(e)}")
+         return ""
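+ # Note: pytesseract is only a wrapper; the tesseract-ocr binary must be
+ # installed on the host (on Hugging Face Spaces, e.g. via packages.txt).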
+
+ def process_pdf(pdf_file) -> str:
+     """Extract text from PDF."""
+     try:
+         with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
+             tmp_file.write(pdf_file.read())
+             tmp_file.flush()
+
+         doc = fitz.open(tmp_file.name)
+         text = ""
+         for page in doc:
+             text += page.get_text()
+         doc.close()
+         os.unlink(tmp_file.name)
+         return text
+     except Exception as e:
+         logger.error(f"Error processing PDF: {str(e)}")
+         return ""
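+ # Design note: PyMuPDF can also read from memory, which would avoid the temp
+ # file entirely: doc = fitz.open(stream=pdf_file.read(), filetype="pdf")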
+
+ # Initialize session state
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+ if "request_timestamps" not in st.session_state:
+     st.session_state.request_timestamps = []
+ if "vector_store" not in st.session_state:
+     st.session_state.vector_store = SimpleVectorStore()
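+ # Note: st.session_state is scoped to one browser session, so the chat
+ # history and vector store reset whenever the page is reloaded.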
+
+ # Rate limiting configuration
+ RATE_LIMIT_PERIOD = 60  # seconds (sliding window)
+ MAX_REQUESTS_PER_PERIOD = 30
+
+ def check_rate_limit() -> bool:
+     """Check if we're within rate limits."""
+     current_time = time.time()
+     st.session_state.request_timestamps = [
+         ts for ts in st.session_state.request_timestamps
+         if current_time - ts < RATE_LIMIT_PERIOD
+     ]
+
+     if len(st.session_state.request_timestamps) >= MAX_REQUESTS_PER_PERIOD:
+         return False
+
+     st.session_state.request_timestamps.append(current_time)
+     return True
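+ # Example: with the defaults, the 31st call inside any 60 s window returns
+ # False and is not recorded, so rejected calls do not extend the window.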
+
+ def query(payload: Dict[str, Any], api_url: str) -> Optional[Dict[str, Any]]:
+     """Query the Hugging Face API with error handling and rate limiting."""
+     if not check_rate_limit():
+         raise Exception(f"Rate limit exceeded. Please wait {RATE_LIMIT_PERIOD} seconds.")
+
+     try:
+         headers = {"Authorization": f"Bearer {st.secrets['HF_TOKEN']}"}
+         response = requests.post(api_url, headers=headers, json=payload, timeout=30)
+
+         if response.status_code == 429:
+             raise Exception("Too many requests. Please try again later.")
+
+         response.raise_for_status()
+         # Log at debug level only; the request headers carry the bearer token
+         # and must never be printed.
+         logger.debug("POST %s -> %s", response.request.url, response.status_code)
+         return response.json()
+     except requests.exceptions.RequestException as e:
+         logger.error(f"API request failed: {str(e)}")
+         raise
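+ # Note: for text-generation models the serverless Inference API usually
+ # responds with a list like [{"generated_text": "..."}]; process_response
+ # below assumes exactly that shape.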
+
+ # Response validation: unwrap [{"generated_text": ...}] from the API reply
+ def process_response(response: List[Dict[str, Any]]) -> str:
+     if not isinstance(response, list) or not response:
+         raise ValueError("Invalid response format")
+
+     if 'generated_text' not in response[0]:
+         raise ValueError("Unexpected response structure")
+
+     text = response[0]['generated_text'].strip()
+     return text
+
+ # Page configuration
+ st.set_page_config(
+     page_title="RAG-Enabled DeepSeek Chatbot",
+     page_icon="🤖",
+     layout="wide"
+ )
+
+ # Sidebar configuration
+ with st.sidebar:
+     st.header("Model Configuration")
+     st.markdown("[Get HuggingFace Token](https://huggingface.co/settings/tokens)")
+
+     model_options = [
+         "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+     ]
+     selected_model = st.selectbox("Select Model", model_options, index=0)
+
+     system_message = st.text_area(
+         "System Message",
+         value="You are a friendly chatbot with RAG capabilities. Use the provided context to answer questions accurately. If the context doesn't contain relevant information, say so.",
+         height=100
+     )
+
+     max_tokens = st.slider("Max Tokens", 10, 4000, 100)
+     temperature = st.slider("Temperature", 0.1, 4.0, 0.3)
+     top_p = st.slider("Top-p", 0.1, 1.0, 0.6)
+
+     # File upload section
+     st.header("Upload Knowledge Base")
+     uploaded_files = st.file_uploader(
+         "Upload files (PDF, Images, Text)",
+         type=['pdf', 'png', 'jpg', 'jpeg', 'txt'],
+         accept_multiple_files=True
+     )
+
+ # Process uploaded files
+ if uploaded_files:
+     embedding_model = load_embedding_model()
+
+     for file in uploaded_files:
+         try:
+             if file.type == "application/pdf":
+                 text = process_pdf(file)
+             elif file.type.startswith("image/"):
+                 image = Image.open(file)
+                 text = process_image(image)
+             else:  # text files
+                 text = file.getvalue().decode()
+
+             chunks = process_text(text)
+             for chunk in chunks:
+                 embedding = embedding_model.encode(chunk)
+                 st.session_state.vector_store.add_document(chunk, embedding)
+
+             st.sidebar.success(f"Successfully processed {file.name}")
+         except Exception as e:
+             st.sidebar.error(f"Error processing {file.name}: {str(e)}")
+
+ # Main chat interface
+ st.title("🤖 RAG-Enabled DeepSeek Chatbot")
+ st.caption("Upload documents in the sidebar to enhance the chatbot's knowledge")
+
+ # Display chat history
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ # Handle user input
+ if prompt := st.chat_input("Type your message..."):
+     # Display user message
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     with st.chat_message("user"):
+         st.markdown(prompt)
+
+     try:
+         with st.spinner("Generating response..."):
+             embedding_model = load_embedding_model()
+             query_embedding = embedding_model.encode(prompt)
+             relevant_contexts = st.session_state.vector_store.search(query_embedding)
+
+             # Dynamic context handling
+             context_text = "\n".join(relevant_contexts) if relevant_contexts else ""
+             system_msg = (
+                 f"{system_message} Use the provided context to answer accurately."
+                 if context_text
+                 else system_message
+             )
+
+             # Format for DeepSeek model
+             full_prompt = f"""<|beginofutterance|>System: {system_msg}
+ {context_text}
+ <|endofutterance|>
+ <|beginofutterance|>User: {prompt}<|endofutterance|>
+ <|beginofutterance|>Assistant:"""
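+             # Note: the <|beginofutterance|>/<|endofutterance|> markers are an
+             # assumed chat format; DeepSeek-R1-Distill checkpoints normally ship
+             # their own chat template, so this raw prompt may need adjusting.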
+
+             payload = {
+                 "inputs": full_prompt,
+                 "parameters": {
+                     "max_new_tokens": max_tokens,
+                     "temperature": temperature,
+                     "top_p": top_p,
+                     "return_full_text": False
+                 }
+             }
+
+             api_url = f"https://api-inference.huggingface.co/models/{selected_model}"
+
+             # Get and process response
+             output = query(payload, api_url)
+             if output:
+                 response_text = process_response(output)
+
+                 # Display assistant response
+                 with st.chat_message("assistant"):
+                     st.markdown(response_text)
+
+                 # Update chat history
+                 st.session_state.messages.append({
+                     "role": "assistant",
+                     "content": response_text
+                 })
+
+     except Exception as e:
+         logger.error(f"Error: {str(e)}", exc_info=True)
+         st.error(f"Error: {str(e)}")