Commit · 504482b
1 Parent(s): fbd9256
Update app.py
app.py
CHANGED
@@ -18,7 +18,7 @@ import torch
 import pandas as pd
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file  # Import Safetensors loader
-from typing import Dict, Optional
+from typing import List, Dict, Optional
 
 # Initialize FastAPI app
 app = FastAPI()

@@ -40,6 +40,22 @@ class QueryRequest(BaseModel):
     query: str
     language_code: int = 0
 
+class MedicalProfile(BaseModel):
+    chronic_conditions: List[str]
+    symptoms: List[str]
+    food_restrictions: List[str]
+    mental_conditions: List[str]
+    daily_symptoms: List[str]
+
+class ChatQuery(BaseModel):
+    query: str
+    conversation_id: str
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+    timestamp: str
+
 def init_nltk():
     """Initialize NLTK resources"""
     try:

@@ -258,69 +274,89 @@ async def health_check():
     }
     return status
 
-@app.post("/api/
-async def
-    """Main query processing endpoint"""
+@app.post("/api/chat")
+async def chat_endpoint(chat_query: ChatQuery):
     try:
-        query_text =
-        }
-
-        doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
-        doc_texts = [text for text in doc_texts if text.strip()]
-
-        if not doc_texts:
-            return {
-                'answer': 'Unable to retrieve relevant documents. Please try again.',
-                'success': True
-            }
-
-        rerank_scores = rerank_documents(query_text, doc_texts)
-        ranked_texts = [text for _, text in sorted(zip(rerank_scores, doc_texts), reverse=True)]
-
-        context = " ".join(ranked_texts[:3])
-        answer = generate_answer(query_text, context)
-
-        if language_code == 0:
-            answer = translate_text(answer, 'en_to_ar')
-
-        return {
-            'answer': answer,
-            'reranked_documents': ranked_texts,
-            'success': True
-        }
+        query_text = chat_query.query
+        query_embedding = models['embedding'].encode([query_text])
+        relevant_docs = query_embeddings(query_embedding)
+
+        doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
+        doc_texts = [text for text in doc_texts if text.strip()]
+
+        rerank_scores = rerank_documents(query_text, doc_texts)
+        ranked_texts = [text for _, text in sorted(zip(rerank_scores, doc_texts), reverse=True)]
+
+        context = " ".join(ranked_texts[:3])
+        answer = generate_answer(query_text, context)
+
+        return {
+            "response": answer,
+            "conversation_id": chat_query.conversation_id,
+            "success": True
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/api/resources")
+async def resources_endpoint(profile: MedicalProfile):
+    try:
+        context = f"""
+        Medical conditions: {', '.join(profile.chronic_conditions)}
+        Current symptoms: {', '.join(profile.daily_symptoms)}
+        Restrictions: {', '.join(profile.food_restrictions)}
+        Mental health: {', '.join(profile.mental_conditions)}
+        """
+
+        query_embedding = models['embedding'].encode([context])
+        relevant_docs = query_embeddings(query_embedding)
+        doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
+        doc_texts = [text for text in doc_texts if text.strip()]
+
+        rerank_scores = rerank_documents(context, doc_texts)
+        ranked_docs = sorted(zip(relevant_docs, rerank_scores, doc_texts), key=lambda x: x[1], reverse=True)
+
+        resources = []
+        for (doc_id, _), score, text in ranked_docs[:10]:
+            doc_info = data['df'][data['df']['id'] == doc_id].iloc[0]
+            resources.append({
+                "id": doc_id,
+                "title": doc_info['title'],
+                "content": text[:200],
+                "score": float(score)
+            })
+
+        return {"resources": resources, "success": True}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
 
+@app.post("/api/recipes")
+async def recipes_endpoint(profile: MedicalProfile):
+    try:
+        recipe_query = f"Recipes and meals suitable for someone with: {', '.join(profile.chronic_conditions + profile.food_restrictions)}"
+
+        query_embedding = models['embedding'].encode([recipe_query])
+        relevant_docs = query_embeddings(query_embedding)
+        doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
+        doc_texts = [text for text in doc_texts if text.strip()]
+
+        rerank_scores = rerank_documents(recipe_query, doc_texts)
+        ranked_docs = sorted(zip(relevant_docs, rerank_scores, doc_texts), key=lambda x: x[1], reverse=True)
+
+        recipes = []
+        for (doc_id, _), score, text in ranked_docs[:10]:
+            doc_info = data['df'][data['df']['id'] == doc_id].iloc[0]
+            if 'recipe' in text.lower() or 'meal' in text.lower():
+                recipes.append({
+                    "id": doc_id,
+                    "title": doc_info['title'],
+                    "content": text[:200],
+                    "score": float(score)
+                })
+
+        return {"recipes": recipes[:5], "success": True}
     except Exception as e:
-        raise HTTPException(
-            status_code=500,
-            detail=str(e)
-        )
+        raise HTTPException(status_code=500, detail=str(e))
 
 # Initialize application
 print("Initializing application...")
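For reference, a minimal client-side sketch of how the endpoints added in this commit might be called. The base URL and the example profile values are placeholders (assumptions, not part of the commit); the request bodies follow the ChatQuery and MedicalProfile models above, and the response keys follow the dicts returned by the new handlers.

# Hypothetical client sketch for the new /api/chat, /api/resources and /api/recipes routes.
# BASE_URL is an assumption; replace it with the running Space's URL.
import requests

BASE_URL = "http://localhost:7860"  # placeholder

# /api/chat expects a ChatQuery body: {"query": ..., "conversation_id": ...}
chat = requests.post(
    f"{BASE_URL}/api/chat",
    json={"query": "What should I eat to help manage hypertension?", "conversation_id": "demo-1"},
).json()
print(chat["response"], chat["success"])

# /api/resources and /api/recipes both expect a MedicalProfile body.
profile = {
    "chronic_conditions": ["hypertension"],
    "symptoms": ["headache"],
    "food_restrictions": ["low sodium"],
    "mental_conditions": [],
    "daily_symptoms": ["fatigue"],
}
resources = requests.post(f"{BASE_URL}/api/resources", json=profile).json()["resources"]
recipes = requests.post(f"{BASE_URL}/api/recipes", json=profile).json()["recipes"]
print(len(resources), "resources;", len(recipes), "recipes")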