Hugging Face Space — commit view: "Update app.py" (file changed: app.py).
Space status: Sleeping.
@@ -91,6 +91,8 @@ def load_models():
|
|
91 |
model_name = "M4-ai/Orca-2.0-Tau-1.8B"
|
92 |
models['llm_tokenizer'] = AutoTokenizer.from_pretrained(model_name)
|
93 |
models['llm_model'] = AutoModelForCausalLM.from_pretrained(model_name)
|
|
|
|
|
94 |
print("Models loaded successfully")
|
95 |
return True
|
96 |
except Exception as e:
|
@@ -501,6 +503,49 @@ def translate_en_to_ar(text):
|
|
501 |
print(f"Error during English to Arabic translation: {e}")
|
502 |
return None
|
503 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
504 |
@app.get("/")
|
505 |
async def root():
|
506 |
return {"message": "Welcome to the FastAPI application! Use the /health endpoint to check health, and /api/query for processing queries."}
|
@@ -516,6 +561,38 @@ async def health_check():
|
|
516 |
}
|
517 |
return status
|
518 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
519 |
@app.post("/api/chat")
|
520 |
async def chat_endpoint(chat_query: ChatQuery):
|
521 |
try:
|
|
|
91 |
model_name = "M4-ai/Orca-2.0-Tau-1.8B"
|
92 |
models['llm_tokenizer'] = AutoTokenizer.from_pretrained(model_name)
|
93 |
models['llm_model'] = AutoModelForCausalLM.from_pretrained(model_name)
|
94 |
+
models['gen_tokenizer'] = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-1.7B-Instruct")
|
95 |
+
models['gen_model'] = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-1.7B-Instruct")
|
96 |
print("Models loaded successfully")
|
97 |
return True
|
98 |
except Exception as e:
|
|
|
503 |
print(f"Error during English to Arabic translation: {e}")
|
504 |
return None
|
505 |
|
506 |
+
|
507 |
+
|
508 |
+
# Medical context prompt: prepended to every user question before generation
# so the LLM keeps a medical-assistant persona and declines off-topic queries.
# The trailing 'Answer: ' marker is what generate_response() splits on to
# strip the prompt scaffolding from the decoded output.
MEDICAL_PROMPT = """You are a medical doctor who provides accurate and reliable health information based on current medical knowledge.
Only answer medical questions and provide information from reliable healthcare sources.
If a question is not medical in nature, politely explain that you can only address health-related queries.
Question: {question}
Answer: """
|
514 |
+
|
515 |
+
def generate_response(question, max_length=350):
    """Generate an English answer to a medical question with the local LLM.

    Args:
        question: User question text (assumed English — Arabic queries are
            translated by the caller before reaching this function).
        max_length: Upper bound on the TOTAL sequence length in tokens
            (prompt + generated answer), passed straight to ``generate()``.

    Returns:
        The generated answer text with the prompt scaffolding stripped.
    """
    tok = models['gen_tokenizer']
    mod = models['gen_model']
    # Wrap the raw question in the medical persona prompt.
    full_prompt = MEDICAL_PROMPT.format(question=question)

    # Tokenize once; forward the attention mask explicitly so generate()
    # does not have to infer it (silences a transformers warning).
    inputs = tok(full_prompt, return_tensors="pt", truncation=True)
    outputs = mod.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=max_length,
        num_beams=4,
        # BUG FIX: `temperature=0.7` was removed — beam search runs with
        # do_sample=False, so temperature was silently ignored, and recent
        # transformers versions warn/error when sampling flags are set
        # without do_sample=True. Output is unchanged.
        early_stopping=True
    )
    response = tok.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after the final "Answer: " marker; split()[-1]
    # degrades gracefully to the whole response if the marker is absent.
    answer = response.split("Answer: ")[-1].strip()
    return answer
|
534 |
+
|
535 |
+
|
536 |
+
|
537 |
+
|
538 |
+
|
539 |
+
|
540 |
+
|
541 |
+
|
542 |
+
|
543 |
+
|
544 |
+
|
545 |
+
|
546 |
+
|
547 |
+
|
548 |
+
|
549 |
@app.get("/")
async def root():
    """Landing endpoint: directs clients to /health and /api/query."""
    welcome = {
        "message": "Welcome to the FastAPI application! Use the /health endpoint to check health, and /api/query for processing queries."
    }
    return welcome
|
|
|
561 |
}
|
562 |
return status
|
563 |
|
564 |
+
|
565 |
+
|
566 |
+
|
567 |
+
|
568 |
+
|
569 |
+
@app.post("/api/ask")
async def ask_endpoint(chat_query: ChatQuery):
    """Answer a medical question, round-tripping through translation for Arabic.

    ``language_code == 0`` marks an Arabic query: it is translated to English
    before generation and the answer is translated back afterwards.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    # NOTE: renamed from `chat_endpoint` — the /api/chat handler below uses
    # that same name, so the second `def` silently shadowed this one at
    # module level. Routes are registered by path, so callers are unaffected.
    try:
        query_text = chat_query.query
        language_code = chat_query.language_code
        if language_code == 0:
            query_text = translate_ar_to_en(query_text)
            if not query_text:
                # translate_ar_to_en returns None on failure — bail out early
                # rather than generating from an empty prompt.
                return Response(answer="", error="Translation failed")
        # Generate response (English)
        answer = generate_response(query_text)
        # Translate back to Arabic if needed.
        # BUG FIX: original read `query.language_code`, but no `query` name
        # exists in this scope (the parameter is `chat_query`) — it raised
        # NameError, which the except clause turned into a 500 on every call.
        if language_code == 0:
            answer = translate_en_to_ar(answer)
            if not answer:
                return Response(answer="", error="Translation failed")

        return Response(answer=answer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
587 |
+
|
588 |
+
|
589 |
+
|
590 |
+
|
591 |
+
|
592 |
+
|
593 |
+
|
594 |
+
|
595 |
+
|
596 |
@app.post("/api/chat")
|
597 |
async def chat_endpoint(chat_query: ChatQuery):
|
598 |
try:
|