import sys
import os
import hashlib
import re
import shutil
from typing import List

import pandas as pd
import gradio as gr

# All caches and generated reports live under one persistent directory so
# they survive restarts (e.g. the /data mount on Hugging Face Spaces).
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")

for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
    os.makedirs(directory, exist_ok=True)

# Point Hugging Face model downloads at the persistent cache.
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir

# Make the bundled src/ package importable before importing TxAgent.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

from txagent.txagent import TxAgent


def file_hash(path: str) -> str:
    """Return the MD5 hex digest of a file, used as a cache/report key."""
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()


def clean_response(text: str) -> str:
    """Normalize model output: drop bracketed tool chatter and the word
    "None", collapse runs of blank lines, and strip symbols outside a
    small whitelist of markdown-friendly characters."""
    text = text.encode("utf-8", "ignore").decode("utf-8")
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    text = re.sub(r"\n{3,}", "\n\n", text)
    text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
    return text.strip()
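

# Illustrative only and never called: a made-up sample of the raw model
# output clean_response is meant to scrub. Note that the word "None" is
# removed wherever it appears, even mid-sentence.
def _clean_response_example() -> str:
    raw = "[TOOL_CALL: review_meds]\nFindings:\n\n\n\nNone noted. ✔"
    # Returns "Findings:\n\n noted." - the bracketed span, "None", the
    # extra blank lines, and the check mark are all stripped.
    return clean_response(raw)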


def parse_excel_to_prompts(file_path: str) -> List[str]:
    """Read the first sheet of the workbook and build one analysis prompt
    per unique Booking Number."""
    xl = pd.ExcelFile(file_path)
    df = xl.parse(xl.sheet_names[0], header=0).fillna("")
    prompts = []
    for booking, group in df.groupby("Booking Number"):
        records = []
        for _, row in group.iterrows():
            records.append(
                f"- {row['Form Name']}: {row['Form Item']} = {row['Item Response']} "
                f"({row['Interview Date']} by {row['Interviewer']})\n{row['Description']}"
            )
        record_text = "\n".join(records)
        prompt = f"""
Patient Booking Number: {booking}

Instructions:
Analyze the following patient case for missed diagnoses, medication conflicts, incomplete assessments, and any urgent follow-up needed. Summarize under the markdown headings.

Data:
{record_text}

### Missed Diagnoses
- ...

### Medication Conflicts
- ...

### Incomplete Assessments
- ...

### Urgent Follow-up
- ...
"""
        prompts.append(prompt)
    return prompts
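

# Minimal sketch of the column layout parse_excel_to_prompts expects; not
# called anywhere in the app, kept only as executable documentation.
# Assumptions: openpyxl is installed (pandas needs it to write .xlsx), and
# the path and row values here are hypothetical.
def _write_sample_workbook(path: str = "sample_cases.xlsx") -> List[str]:
    rows = [
        {
            "Booking Number": "BK-001",
            "Form Name": "Intake",
            "Form Item": "Allergies",
            "Item Response": "Penicillin",
            "Interview Date": "2024-01-02",
            "Interviewer": "RN Smith",
            "Description": "Reported rash after a prior course.",
        },
        {
            "Booking Number": "BK-001",
            "Form Name": "Medication Review",
            "Form Item": "Current Meds",
            "Item Response": "Amoxicillin",
            "Interview Date": "2024-01-03",
            "Interviewer": "Dr. Lee",
            "Description": "Prescribed at an urgent care visit.",
        },
    ]
    pd.DataFrame(rows).to_excel(path, index=False)
    return parse_excel_to_prompts(path)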


def init_agent():
    """Copy the bundled tool file into the persistent cache (first run
    only) and construct the TxAgent instance."""
    default_tool_path = os.path.abspath("data/new_tool.json")
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(default_tool_path, target_tool_path)
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100,
        additional_default_tools=[],
    )
    agent.init_model()
    return agent


def create_ui(agent):
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align: center;'>🧺 Clinical Oversight Assistant (Excel Optimized)</h1>")
        chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
        file_upload = gr.File(file_types=[".xlsx"], file_count="single")
        msg_input = gr.Textbox(placeholder="Ask about patient history...", show_label=False)
        send_btn = gr.Button("Analyze", variant="primary")
        download_output = gr.File(label="Download Full Report")

        def analyze(message: str, history: List[dict], file) -> tuple:
            history.append({"role": "user", "content": message})
            if file is None:
                history.append({"role": "assistant", "content": "❌ Please upload an Excel file first."})
                yield history, None
                return
            history.append({"role": "assistant", "content": "⏳ Processing Excel data..."})
            yield history, None

            prompts = parse_excel_to_prompts(file.name)
            full_output = ""

            # Run the agent once per booking and stream each result back to
            # the chat as it completes.
            for idx, prompt in enumerate(prompts, 1):
                chunk_output = ""
                for result in agent.run_gradio_chat(
                    message=prompt,
                    history=[],
                    temperature=0.2,
                    max_new_tokens=1024,
                    max_token=4096,
                    call_agent=False,
                    conversation=[],
                ):
                    if isinstance(result, list):
                        for r in result:
                            if hasattr(r, "content") and r.content:
                                chunk_output += clean_response(r.content) + "\n"
                    elif isinstance(result, str):
                        chunk_output += clean_response(result) + "\n"
                if chunk_output:
                    output = f"--- Booking {idx} ---\n{chunk_output.strip()}\n"
                    history.append({"role": "assistant", "content": output})
                    full_output += output + "\n"
                    yield history, None

            # Persist the combined report, keyed by the upload's MD5 hash,
            # so it can be offered for download.
            file_hash_value = file_hash(file.name)
            report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt")
            with open(report_path, "w", encoding="utf-8") as f:
                f.write(full_output)
            yield history, report_path if os.path.exists(report_path) else None

        send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])
        msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])
    return demo


if __name__ == "__main__":
    agent = init_agent()
    demo = create_ui(agent)
    demo.queue(api_open=False).launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        allowed_paths=[report_dir],
        share=False,
    )
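
# To run locally (assumptions: this file is saved as app.py, the txagent
# package is importable from src/, and a GPU with enough memory for the
# 8B model is available):
#   python app.py
# then open http://localhost:7860 in a browser.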