eikarna committed

Commit 86ee5b2 · 1 Parent(s): a468722

Fix: File Upload Session

app.py CHANGED
```diff
@@ -27,32 +27,15 @@ logger = logging.getLogger(__name__)
 def load_embedding_model():
     return SentenceTransformer('all-MiniLM-L6-v2')
 
-# Vector Store Class
+# Modified Vector Store Class
 class SimpleVectorStore:
-    def __init__(self, file_path):
-        self.file_path = file_path
+    def __init__(self):
         self.documents = []
         self.embeddings = []
-        self.load()
-
-    def load(self):
-        if os.path.exists(self.file_path):
-            with open(self.file_path, 'rb') as f:
-                data = pickle.load(f)
-                self.documents = data['documents']
-                self.embeddings = data['embeddings']
-
-    def save(self):
-        with open(self.file_path, 'wb') as f:
-            pickle.dump({
-                'documents': self.documents,
-                'embeddings': self.embeddings
-            }, f)
 
     def add_document(self, text: str, embedding: np.ndarray):
        self.documents.append(text)
        self.embeddings.append(embedding)
-        self.save()
 
     def search(self, query_embedding: np.ndarray, top_k: int = 3) -> List[str]:
         if not self.embeddings:
```
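Taken together, this hunk strips every trace of on-disk persistence out of `SimpleVectorStore` (`load`, `save`, `pickle`, `self.file_path`), leaving a purely in-memory store. A minimal sketch of the class as it stands after the commit; the body of `search` is not shown in the hunk, so the cosine-similarity ranking below is an assumption based only on the visible signature:

```python
import numpy as np
from typing import List

class SimpleVectorStore:
    """In-memory vector store; contents live only for the current session."""

    def __init__(self):
        self.documents = []
        self.embeddings = []

    def add_document(self, text: str, embedding: np.ndarray):
        self.documents.append(text)
        self.embeddings.append(embedding)

    def search(self, query_embedding: np.ndarray, top_k: int = 3) -> List[str]:
        # Assumed implementation: rank stored documents by cosine similarity.
        if not self.embeddings:
            return []
        matrix = np.vstack(self.embeddings)
        denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_embedding)
        scores = (matrix @ query_embedding) / np.clip(denom, 1e-10, None)
        best = np.argsort(scores)[::-1][:top_k]
        return [self.documents[i] for i in best]
```

Since nothing is written to disk anymore, uploaded documents only survive Streamlit's script reruns if the store is parked in session state, which is presumably what the commit title refers to. The later hunks reference it as `st.session_state.vector_store`, so the wiring is plausibly something like this (hypothetical; not shown in the diff):

```python
import streamlit as st

# Hypothetical initialization elsewhere in app.py: one store per session,
# so uploads persist across reruns without touching the filesystem.
if "vector_store" not in st.session_state:
    st.session_state.vector_store = SimpleVectorStore()
```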
```diff
@@ -144,20 +127,15 @@ def query(payload: Dict[str, Any], api_url: str) -> Optional[Dict[str, Any]]:
         logger.error(f"API request failed: {str(e)}")
         raise
 
+# Enhanced response validation
 def process_response(response: Dict[str, Any]) -> str:
-    """Process and clean up the model's response."""
     if not isinstance(response, list) or not response:
         raise ValueError("Invalid response format")
 
-    text = response[0]['generated_text']
-    cleanup_patterns = [
-        "Assistant:", "AI:", "</think>", "<think>",
-        "\n\nHuman:", "\n\nUser:"
-    ]
-    for pattern in cleanup_patterns:
-        text = text.replace(pattern, "").strip()
+    if 'generated_text' not in response[0]:
+        raise ValueError("Unexpected response structure")
 
-    return text
+    text = response[0]['generated_text'].strip()
 
 # Page configuration
 st.set_page_config(
```
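The rewritten `process_response` swaps the old pattern-scrubbing loop for structural checks: anything that is not a non-empty list whose first element carries a `generated_text` key is rejected up front. A sketch of the consolidated function with illustrative inputs; the `return` is not visible in the hunk, so returning the stripped text directly is an assumption (note also that the annotation says `Dict` while the function actually validates a list, an inconsistency carried over from the source):

```python
from typing import Any, Dict

def process_response(response: Dict[str, Any]) -> str:
    # Structural validation added by this commit.
    if not isinstance(response, list) or not response:
        raise ValueError("Invalid response format")
    if 'generated_text' not in response[0]:
        raise ValueError("Unexpected response structure")
    # Assumed: the stripped text is returned to the caller.
    return response[0]['generated_text'].strip()

# Typical success shape from a text-generation inference endpoint:
print(process_response([{"generated_text": "  Hello!  "}]))  # -> "Hello!"

# Failure shapes now raise early instead of surfacing as a KeyError later:
# process_response([])                            # ValueError: Invalid response format
# process_response([{"error": "model loading"}])  # ValueError: Unexpected response structure
```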
```diff
@@ -235,28 +213,31 @@ if prompt := st.chat_input("Type your message..."):
 
     try:
         with st.spinner("Generating response..."):
-            # Get relevant context from vector store
             embedding_model = load_embedding_model()
             query_embedding = embedding_model.encode(prompt)
             relevant_contexts = st.session_state.vector_store.search(query_embedding)
 
-            # Build prompt with retrieved context
-            context_text = "\n".join(relevant_contexts)
-
-            full_prompt = f"""{context_text}
-
-Human: {prompt}
-
-Assistant: Let me help you based on the provided context."""
-
+            # Dynamic context handling
+            context_text = "\n".join(relevant_contexts) if relevant_contexts else ""
+            system_msg = (
+                f"{system_message} Use the provided context to answer accurately."
+                if context_text
+                else system_message
+            )
+
+            # Format for DeepSeek model
+            full_prompt = f"""<|beginofutterance|>System: {system_msg}
+{context_text if context_text else ''}
+<|endofutterance|>
+<|beginofutterance|>User: {prompt}<|endofutterance|>
+<|beginofutterance|>Assistant:"""
 
             payload = {
                 "inputs": full_prompt,
                 "parameters": {
                     "max_new_tokens": max_tokens,
                     "temperature": temperature,
-                    "top_p": top_p
-                    "return_full_text": False
+                    "top_p": top_p
                 }
             }
 
```
```diff
@@ -279,4 +260,4 @@ Assistant: Let me help you based on the provided context."""
 
     except Exception as e:
         logger.error(f"Error: {str(e)}", exc_info=True)
-        st.error(f"Error: {str(e)}")
+        st.error(f"Error: {str(e)}")
```
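The prompt itself changes shape here: the free-form `Human:`/`Assistant:` transcript gives way to turns delimited by `<|beginofutterance|>`/`<|endofutterance|>` markers, the system message only advertises context when retrieval actually found some, and the parameters keep `top_p` while dropping `return_full_text`. A sketch of how the pieces fit together at request time; `system_message`, `max_tokens`, `temperature`, and `top_p` are assumed to come from the app's sidebar controls, which this diff does not show:

```python
def build_payload(system_message: str, relevant_contexts: list,
                  prompt: str, max_tokens: int, temperature: float,
                  top_p: float) -> dict:
    # Only mention context in the system turn when retrieval returned any.
    context_text = "\n".join(relevant_contexts) if relevant_contexts else ""
    system_msg = (
        f"{system_message} Use the provided context to answer accurately."
        if context_text
        else system_message
    )

    # Utterance-delimited prompt format introduced by this commit.
    full_prompt = (
        f"<|beginofutterance|>System: {system_msg}\n"
        f"{context_text}\n"
        f"<|endofutterance|>\n"
        f"<|beginofutterance|>User: {prompt}<|endofutterance|>\n"
        f"<|beginofutterance|>Assistant:"
    )
    return {
        "inputs": full_prompt,
        "parameters": {
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        },
    }
```

As rendered in this diff, the old parameters block had no comma after `"top_p": top_p`, so dropping `"return_full_text": False` also clears what looks like a syntax error in the previous revision.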
|