eikarna committed
Commit 86ee5b2 · 1 Parent(s): a468722

Fix: File Upload Session

Files changed (1)
  1. app.py +22 -41
app.py CHANGED
@@ -27,32 +27,15 @@ logger = logging.getLogger(__name__)
 def load_embedding_model():
     return SentenceTransformer('all-MiniLM-L6-v2')
 
-# Vector store class
+# Modified Vector Store Class
 class SimpleVectorStore:
-    def __init__(self, file_path: str = "vector_store.pkl"):
-        self.file_path = file_path
+    def __init__(self):
         self.documents = []
         self.embeddings = []
-        self.load()
-
-    def load(self):
-        if os.path.exists(self.file_path):
-            with open(self.file_path, 'rb') as f:
-                data = pickle.load(f)
-                self.documents = data['documents']
-                self.embeddings = data['embeddings']
-
-    def save(self):
-        with open(self.file_path, 'wb') as f:
-            pickle.dump({
-                'documents': self.documents,
-                'embeddings': self.embeddings
-            }, f)
 
     def add_document(self, text: str, embedding: np.ndarray):
         self.documents.append(text)
         self.embeddings.append(embedding)
-        self.save()
 
     def search(self, query_embedding: np.ndarray, top_k: int = 3) -> List[str]:
         if not self.embeddings:
@@ -144,20 +127,15 @@ def query(payload: Dict[str, Any], api_url: str) -> Optional[Dict[str, Any]]:
         logger.error(f"API request failed: {str(e)}")
         raise
 
+# Enhanced response validation
 def process_response(response: Dict[str, Any]) -> str:
-    """Process and clean up the model's response."""
     if not isinstance(response, list) or not response:
         raise ValueError("Invalid response format")
 
-    text = response[0]['generated_text'].strip()
-    cleanup_patterns = [
-        "Assistant:", "AI:", "</think>", "<think>",
-        "\n\nHuman:", "\n\nUser:"
-    ]
-    for pattern in cleanup_patterns:
-        text = text.replace(pattern, "").strip()
+    if 'generated_text' not in response[0]:
+        raise ValueError("Unexpected response structure")
 
-    return text
+    text = response[0]['generated_text'].strip()
 
 # Page configuration
 st.set_page_config(
@@ -235,28 +213,31 @@ if prompt := st.chat_input("Type your message..."):
 
     try:
         with st.spinner("Generating response..."):
-            # Get relevant context from vector store
             embedding_model = load_embedding_model()
             query_embedding = embedding_model.encode(prompt)
             relevant_contexts = st.session_state.vector_store.search(query_embedding)
 
-            # Prepare context-enhanced prompt
-            context_text = "\n".join(relevant_contexts)
-            full_prompt = f"""Context information:
-{context_text}
-
-System: {system_message}
-
-User: {prompt}
-Assistant: Let me help you based on the provided context."""
+            # Dynamic context handling
+            context_text = "\n".join(relevant_contexts) if relevant_contexts else ""
+            system_msg = (
+                f"{system_message} Use the provided context to answer accurately."
+                if context_text
+                else system_message
+            )
+
+            # Format for DeepSeek model
+            full_prompt = f"""<|beginofutterance|>System: {system_msg}
+{context_text if context_text else ''}
+<|endofutterance|>
+<|beginofutterance|>User: {prompt}<|endofutterance|>
+<|beginofutterance|>Assistant:"""
 
             payload = {
                 "inputs": full_prompt,
                 "parameters": {
                     "max_new_tokens": max_tokens,
                     "temperature": temperature,
-                    "top_p": top_p,
-                    "return_full_text": False
+                    "top_p": top_p
                 }
            }
 
@@ -279,4 +260,4 @@ Assistant: Let me help you based on the provided context."""
 
     except Exception as e:
         logger.error(f"Error: {str(e)}", exc_info=True)
-        st.error(f"Error: {str(e)}")
+        st.error(f"Error: {str(e)}")
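Why removing the pickle persistence relates to the upload-session problem: the deleted load()/save() pair kept a single vector_store.pkl on disk, which was presumably shared across all sessions of the Space, so documents indexed in one session could reappear in another. With the store now purely in memory, each Streamlit session needs its own instance; the hunks reference st.session_state.vector_store but do not show where it is created, so the sketch below is an assumption about that wiring. The uploader widget, the one-chunk ingestion, and the cosine-similarity body of search() are illustrative rather than copied from the commit.

# Minimal sketch, assuming the store is kept per-session in st.session_state.
# Only the SimpleVectorStore shape and the st.session_state.vector_store name
# come from the commit; everything else here is hypothetical glue code.
from typing import List

import numpy as np
import streamlit as st
from sentence_transformers import SentenceTransformer


class SimpleVectorStore:
    def __init__(self):
        self.documents: List[str] = []
        self.embeddings: List[np.ndarray] = []

    def add_document(self, text: str, embedding: np.ndarray):
        self.documents.append(text)
        self.embeddings.append(embedding)

    def search(self, query_embedding: np.ndarray, top_k: int = 3) -> List[str]:
        # Illustrative cosine-similarity ranking; the real search() body is
        # untouched by this commit and not shown in the hunks above.
        if not self.embeddings:
            return []
        matrix = np.vstack(self.embeddings)
        scores = matrix @ query_embedding / (
            np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_embedding) + 1e-10
        )
        top = np.argsort(scores)[::-1][:top_k]
        return [self.documents[i] for i in top]


# One store per browser session instead of a shared pickle on disk.
if "vector_store" not in st.session_state:
    st.session_state.vector_store = SimpleVectorStore()

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Hypothetical upload handler feeding the session-local store.
uploaded = st.file_uploader("Upload a text file")
if uploaded is not None:
    text = uploaded.read().decode("utf-8", errors="ignore")
    st.session_state.vector_store.add_document(text, embedding_model.encode(text))

Because the instance lives in st.session_state, the index resets when the session ends and does not leak into other users' sessions, which appears to be the behavior the "Fix: File Upload Session" title targets.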