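# app.py (CPS-Test-Mobile)
# Gradio app: upload a patient-history Excel export, flatten it into a single
# prompt, and ask TxAgent for a clinical summary (missed diagnoses, medication
# conflicts, incomplete assessments, urgent follow-up needs). The generated
# report is saved under /data/hf_cache/reports and offered for download.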
import sys
import os
import pandas as pd
import json
import gradio as gr
from typing import List, Tuple
import hashlib
import shutil
import re
from datetime import datetime
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
    os.makedirs(directory, exist_ok=True)
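
# Route Hugging Face downloads to the persistent volume. HF_HOME is the current
# cache variable; TRANSFORMERS_CACHE is kept for older transformers releases.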
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)
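# The txagent sources are expected under ./src, so that directory is put on
# sys.path before the import below.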
from txagent.txagent import TxAgent
def file_hash(path: str) -> str:
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()
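
# Strip model-output noise: bracketed fragments, literal "None" tokens, runs of
# blank lines, and characters outside a small markdown-friendly set.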
def clean_response(text: str) -> str:
    try:
        text = text.encode('utf-8', 'surrogatepass').decode('utf-8')
    except UnicodeError:
        text = text.encode('utf-8', 'replace').decode('utf-8')
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    text = re.sub(r"\n{3,}", "\n\n", text)
    text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
    return text.strip()
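
# Flatten the first sheet of the uploaded workbook into one prompt. Assumes the
# sheet has the columns: Form Name, Form Item, Item Response, Interview Date,
# Interviewer, Description.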
def parse_excel_as_whole_prompt(file_path: str) -> str:
    xl = pd.ExcelFile(file_path)
    df = xl.parse(xl.sheet_names[0], header=0).fillna("")
    records = []
    for _, row in df.iterrows():
        record = f"- {row['Form Name']}: {row['Form Item']} = {row['Item Response']} ({row['Interview Date']} by {row['Interviewer']})\n{row['Description']}"
        records.append(clean_response(record))
    record_text = "\n".join(records)
    prompt = f"""
Patient Complete History:
Instructions:
Based on the complete patient record below, identify any potential missed diagnoses, medication conflicts, incomplete assessments, and urgent follow-up needs. Provide a clinical summary under the markdown headings.
Patient History:
{record_text}
### Missed Diagnoses
- ...
### Medication Conflicts
- ...
### Incomplete Assessments
- ...
### Urgent Follow-up
- ...
"""
    return prompt
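
# Copy the bundled tool manifest (data/new_tool.json) into the persistent tool
# cache, then build a TxAgent instance backed by the ToolRAG retrieval model.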
def init_agent():
    default_tool_path = os.path.abspath("data/new_tool.json")
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(default_tool_path, target_tool_path)
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100,
        additional_default_tools=[],
    )
    agent.init_model()
    return agent
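
# Build the Gradio UI: an Excel upload, an optional free-text message, a chat
# panel for the streamed summary, and a download slot for the saved report.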
def create_ui(agent):
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align: center;'>🏥 Full Medical History Analyzer</h1>")
        chatbot = gr.Chatbot(label="Summary Output", height=600)
        file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"], file_count="single")
        msg_input = gr.Textbox(label="Optional Message", placeholder="Add context or instructions...", lines=2)
        send_btn = gr.Button("Analyze")
        download_output = gr.File(label="Download Report")
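
        # Generator handler: the first yield shows a progress message, later
        # yields stream the accumulated TxAgent output and finally the report path.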
        def analyze(message: str, chat_history: List[Tuple[str, str]], file) -> Tuple[List[Tuple[str, str]], str]:
            if not file:
                raise gr.Error("Please upload an Excel file.")
            new_history = chat_history + [(message, None)]
            new_history.append((None, "⏳ Analyzing full patient history..."))
            yield new_history, None
            try:
                prompt = parse_excel_as_whole_prompt(file.name)
                full_output = ""
                for result in agent.run_gradio_chat(
                    message=prompt,
                    history=[],
                    temperature=0.2,
                    max_new_tokens=2048,
                    max_token=4096,
                    call_agent=False,
                    conversation=[],
                ):
                    if isinstance(result, list):
                        for r in result:
                            if hasattr(r, 'content') and r.content:
                                full_output += clean_response(r.content) + "\n"
                    elif isinstance(result, str):
                        full_output += clean_response(result) + "\n"
                new_history[-1] = (None, full_output.strip())
                report_path = os.path.join(report_dir, f"{file_hash(file.name)}_final_report.txt")
                with open(report_path, "w", encoding="utf-8") as f:
                    f.write(full_output.strip())
                yield new_history, report_path
            except Exception as e:
                new_history.append((None, f"❌ Error during analysis: {str(e)}"))
                yield new_history, None
        send_btn.click(analyze, inputs=[msg_input, chatbot, file_upload], outputs=[chatbot, download_output])
        msg_input.submit(analyze, inputs=[msg_input, chatbot, file_upload], outputs=[chatbot, download_output])
    return demo
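
# Entry point: load the models once, then serve on 0.0.0.0:7860. allowed_paths
# lets Gradio serve files from report_dir so the download link works.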
if __name__ == "__main__":
    agent = init_agent()
    demo = create_ui(agent)
    demo.queue(api_open=False).launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        allowed_paths=[report_dir],
        share=False
    )