Update app.py

app.py CHANGED
@@ -505,35 +505,50 @@ def translate_en_to_ar(text):
-# Medical context prompt
-MEDICAL_PROMPT = """You are a medical doctor who provides accurate and reliable health information based on current medical knowledge.
-Only answer medical questions and provide information from reliable healthcare sources.
-If a question is not medical in nature, politely explain that you can only address health-related queries.
-Question: {question}
-Answer: """
-
-def generate_response(question, max_length=350):
-    tok = models['gen_tokenizer']
-    mod = models['gen_model']
-    # Prepare prompt
-    full_prompt = MEDICAL_PROMPT.format(question=question)
-
-    # Generate response
-    inputs = tok(full_prompt, return_tensors="pt", truncation=True)
-    outputs = mod.generate(
-        inputs.input_ids,
-        max_length=max_length,
-        num_beams=4,
-        temperature=0.7,
-        early_stopping=True
-    )
-    response = tok.decode(outputs[0], skip_special_tokens=True)
-    # Extract only the answer part
-    answer = response.split("Answer: ")[-1].strip()
-    return answer
+def get_completion(prompt: str, model: str = "gryphe/mythomax-l2-13b:free") -> str:
+    api_key = os.environ.get('OPENROUTER_API_KEY')
+    if not api_key:
+        raise HTTPException(status_code=500, detail="OPENROUTER_API_KEY not found in environment variables")
+
+    client = OpenAI(
+        base_url="https://openrouter.ai/api/v1",
+        api_key=api_key
+    )
+
+    if not prompt.strip():
+        raise HTTPException(status_code=400, detail="Please enter a question")
+
+    try:
+        completion = client.chat.completions.create(
+            extra_headers={
+                "HTTP-Referer": "https://huggingface.co/spaces/thechaiexperiment/phitrial",
+                "X-Title": "My Hugging Face Space"
+            },
+            model=model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": prompt
+                }
+            ]
+        )
+
+        if (completion and
+                hasattr(completion, 'choices') and
+                completion.choices and
+                hasattr(completion.choices[0], 'message') and
+                hasattr(completion.choices[0].message, 'content')):
+            return completion.choices[0].message.content
+        else:
+            raise HTTPException(status_code=500, detail="Received invalid response from API")
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
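The removed generate_response ran a local seq2seq model; the new get_completion delegates to OpenRouter through the OpenAI client. A minimal usage sketch, assuming app.py already imports os, OpenAI (from the openai package) and FastAPI's HTTPException, and that OPENROUTER_API_KEY is set as a Space secret (the import path, placeholder key, and prompt below are illustrative assumptions, not part of the diff):

# Sketch only: module path, placeholder key, and prompt are assumptions
import os
os.environ.setdefault("OPENROUTER_API_KEY", "sk-or-...")  # placeholder; set via Space secrets in practice

from app import get_completion  # assumes the Space's app.py is importable as `app`

answer = get_completion("What are the early symptoms of type 2 diabetes?")
print(answer)

Because failures surface as HTTPException, routes that call this helper (such as /ask below) can let FastAPI convert them into error responses directly.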
@@ -565,21 +580,41 @@ async def health_check():
-
-
-async def chat_endpoint(chat_query: ChatQuery):
+@app.post("/ask")
+async def chat(query: ChatQuery):
     try:
-
-        language_code
-
+        # Handle Arabic input
+        if query.language_code == 0:
+            # Translate question from Arabic to English
+            english_query = translate_ar_to_en(query.query)
+            if not english_query:
+                raise HTTPException(status_code=500, detail="Failed to translate question from Arabic to English")
+
+            # Get completion in English
+            english_response = get_completion(english_query)
+
+            # Translate response back to Arabic
+            arabic_response = translate_en_to_ar(english_response)
+            if not arabic_response:
+                raise HTTPException(status_code=500, detail="Failed to translate response to Arabic")
+
+            return {
+                "original_query": query.query,
+                "translated_query": english_query,
+                "response": arabic_response,
+                "response_in_english": english_response
+            }
+
+        # Handle English input
+        else:
+            response = get_completion(query.query)
+            return {
+                "query": query.query,
+                "response": response
+            }
+
+    except HTTPException as e:
+        raise e
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
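The new /ask route reads query.query and query.language_code (0 selects the Arabic path) from a ChatQuery body, but the model itself is defined elsewhere in app.py and is not part of this diff. A plausible shape, offered only as an assumption:

# Assumed request model for /ask; the real definition lives elsewhere in app.py
from pydantic import BaseModel

class ChatQuery(BaseModel):
    query: str           # the user's question
    language_code: int   # 0 = Arabic (translate in and out), otherwise English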
@@ -589,6 +624,8 @@ async def chat_endpoint(chat_query: ChatQuery):
+
+
 @app.post("/api/chat")
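Once the Space is running, the new endpoint can be exercised with a plain HTTP client; the base URL below uses the default Hugging Face Spaces port and is an assumption:

# Example request against /ask; base URL and port are assumptions
import requests

resp = requests.post(
    "http://localhost:7860/ask",
    json={"query": "ما هي أعراض فقر الدم؟", "language_code": 0},  # "What are the symptoms of anemia?"
)
data = resp.json()
# For Arabic input the body contains original_query, translated_query,
# response (Arabic) and response_in_english.
print(data["response"])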