eikarna committed
Commit 86ee5b2 · 1 Parent(s): a468722
Fix: File Upload Session

app.py CHANGED

@@ -27,32 +27,15 @@ logger = logging.getLogger(__name__)
 def load_embedding_model():
     return SentenceTransformer('all-MiniLM-L6-v2')
 
-# Vector
+# Modified Vector Store Class
 class SimpleVectorStore:
-    def __init__(self, file_path):
-        self.file_path = file_path
+    def __init__(self):
         self.documents = []
         self.embeddings = []
-        self.load()
-
-    def load(self):
-        if os.path.exists(self.file_path):
-            with open(self.file_path, 'rb') as f:
-                data = pickle.load(f)
-                self.documents = data['documents']
-                self.embeddings = data['embeddings']
-
-    def save(self):
-        with open(self.file_path, 'wb') as f:
-            pickle.dump({
-                'documents': self.documents,
-                'embeddings': self.embeddings
-            }, f)
 
     def add_document(self, text: str, embedding: np.ndarray):
         self.documents.append(text)
         self.embeddings.append(embedding)
-        self.save()
 
     def search(self, query_embedding: np.ndarray, top_k: int = 3) -> List[str]:
         if not self.embeddings:
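
Dropping load()/save() and the pickle file behind them is the substance of this fix: the store now lives only in memory for the current session, so a stale pickle on disk can no longer resurrect documents across uploads or sessions. The diff does not show where the store is created, but the st.session_state.vector_store reference in the chat hunk below implies session-scoped wiring roughly like this sketch (the existence check is an assumption):

import streamlit as st

# Assumed initialization, not shown in this diff: one in-memory
# SimpleVectorStore (the class above) per browser session, so uploads
# survive Streamlit reruns without ever touching disk.
if "vector_store" not in st.session_state:
    st.session_state.vector_store = SimpleVectorStore()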

@@ -144,20 +127,15 @@ def query(payload: Dict[str, Any], api_url: str) -> Optional[Dict[str, Any]]:
         logger.error(f"API request failed: {str(e)}")
         raise
 
+# Enhanced response validation
 def process_response(response: Dict[str, Any]) -> str:
-    """Process and clean up the model's response."""
     if not isinstance(response, list) or not response:
         raise ValueError("Invalid response format")
 
-    text = response[0].get('generated_text', '')
-    cleanup_patterns = [
-        "Assistant:", "AI:", "</think>", "<think>",
-        "\n\nHuman:", "\n\nUser:"
-    ]
-    for pattern in cleanup_patterns:
-        text = text.replace(pattern, "").strip()
+    if 'generated_text' not in response[0]:
+        raise ValueError("Unexpected response structure")
 
-    return text
+    text = response[0]['generated_text'].strip()
 
 # Page configuration
 st.set_page_config(
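
Where the old body stripped role markers ("Assistant:", "<think>", and so on) out of the generated text, the new one only validates the response shape and strips whitespace. Assembled from the lines above, the function presumably reads as follows (the trailing return is assumed, since the hunk cuts off at the assignment; note the annotation still says Dict[str, Any] although the body expects a list):

from typing import Any, Dict

def process_response(response: Dict[str, Any]) -> str:
    # Hugging Face text-generation responses arrive as a list of dicts,
    # each holding the model output under 'generated_text'.
    if not isinstance(response, list) or not response:
        raise ValueError("Invalid response format")
    if 'generated_text' not in response[0]:
        raise ValueError("Unexpected response structure")
    return response[0]['generated_text'].strip()  # assumed final line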
@@ -235,28 +213,31 @@ if prompt := st.chat_input("Type your message..."):
 
     try:
         with st.spinner("Generating response..."):
-            # Get relevant context from vector store
             embedding_model = load_embedding_model()
             query_embedding = embedding_model.encode(prompt)
             relevant_contexts = st.session_state.vector_store.search(query_embedding)
 
-            #
-            context_text = "\n".join(relevant_contexts)
-            …
+            # Dynamic context handling
+            context_text = "\n".join(relevant_contexts) if relevant_contexts else ""
+            system_msg = (
+                f"{system_message} Use the provided context to answer accurately."
+                if context_text
+                else system_message
+            )
+
+            # Format for DeepSeek model
+            full_prompt = f"""<|beginofutterance|>System: {system_msg}
+{context_text if context_text else ''}
+<|endofutterance|>
+<|beginofutterance|>User: {prompt}<|endofutterance|>
+<|beginofutterance|>Assistant:"""
 
             payload = {
                 "inputs": full_prompt,
                 "parameters": {
                     "max_new_tokens": max_tokens,
                     "temperature": temperature,
-                    "top_p": top_p
-                    "return_full_text": False
+                    "top_p": top_p
                 }
             }
 
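
With one retrieved chunk, the new utterance-delimited template renders as in this illustrative sketch (sample values only):

# Illustrative only: the template above filled with sample values.
system_msg = "You are a helpful assistant. Use the provided context to answer accurately."
context_text = "Refunds are processed within 14 days of a return."
prompt = "How long do refunds take?"

full_prompt = f"""<|beginofutterance|>System: {system_msg}
{context_text if context_text else ''}
<|endofutterance|>
<|beginofutterance|>User: {prompt}<|endofutterance|>
<|beginofutterance|>Assistant:"""
print(full_prompt)

One side effect worth noting: with "return_full_text": False gone, the endpoint falls back to its default, and many text-generation backends then echo the prompt inside generated_text, which process_response would pass through untouched.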
@@ -279,4 +260,4 @@ Assistant: Let me help you based on the provided context."""
 
     except Exception as e:
         logger.error(f"Error: {str(e)}", exc_info=True)
-        st.error(f"Error: {str(e)}")
+        st.error(f"Error: {str(e)}")
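
End to end, the flow after this commit is: embed the prompt, search the session's in-memory store, build the utterance-delimited prompt, and hand the payload to query(), whose signature appears in the second hunk's header. Its body is not part of this diff; a minimal sketch of the usual requests-based shape it plausibly wraps, with URL and token as placeholders:

import requests
from typing import Any, Dict, Optional

def query(payload: Dict[str, Any], api_url: str) -> Optional[Dict[str, Any]]:
    # Plausible implementation only; the real body is outside this diff.
    headers = {"Authorization": "Bearer <HF_API_TOKEN>"}  # placeholder token
    response = requests.post(api_url, headers=headers, json=payload, timeout=60)
    response.raise_for_status()  # surface HTTP errors to the caller's try/except
    return response.json()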