File size: 5,651 Bytes
f75a23b
f394b25
 
f75a23b
f394b25
9a8092d
f394b25
f75a23b
 
1c5bd8e
f75a23b
e4d9325
a71a831
 
f75a23b
 
 
a71a831
 
f75a23b
1c5bd8e
499e72e
a71a831
f75a23b
 
 
 
 
 
 
 
 
a71a831
 
499e72e
828effe
1c5bd8e
afdc6ee
 
9a8092d
afdc6ee
1c5bd8e
 
 
 
 
befca65
 
 
 
 
 
 
 
 
 
1c5bd8e
 
befca65
1c5bd8e
befca65
1c5bd8e
e4d9325
1c5bd8e
 
12ddaba
1c5bd8e
 
e4d9325
1c5bd8e
 
e4d9325
1c5bd8e
 
 
befca65
f75a23b
 
 
 
 
9a8092d
 
 
 
 
 
 
 
 
 
 
 
 
f75a23b
 
befca65
 
 
 
 
 
 
 
9a8092d
afdc6ee
befca65
 
 
 
 
afdc6ee
befca65
afdc6ee
befca65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afdc6ee
befca65
 
afdc6ee
befca65
9a8092d
befca65
 
 
a71a831
55e3db0
f394b25
befca65
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import sys
import os
import pandas as pd
import json
import gradio as gr
from typing import List, Tuple
import hashlib
import shutil
import re
from datetime import datetime
import time

# Root of persistent storage; model weights, tool files, caches and reports
# all live underneath it.
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

model_cache_dir, tool_cache_dir, file_cache_dir, report_dir = (
    os.path.join(persistent_dir, sub_name)
    for sub_name in ("txagent_models", "tool_cache", "cache", "reports")
)

for directory in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir):
    os.makedirs(directory, exist_ok=True)

# Point Hugging Face / transformers downloads at the persistent model cache.
os.environ.update(HF_HOME=model_cache_dir, TRANSFORMERS_CACHE=model_cache_dir)

# Make the bundled ``src`` directory importable (it provides ``txagent``).
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
sys.path.insert(0, src_path)

from txagent.txagent import TxAgent

def file_hash(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hexadecimal MD5 digest of the file at *path*.

    The file is streamed in ``chunk_size``-byte pieces so large uploads are
    never loaded into memory all at once. In this module the digest is used
    as a content fingerprint for naming report files.

    Args:
        path: Path of the file to fingerprint.
        chunk_size: Bytes read per iteration; must be positive.

    Returns:
        The 32-character lowercase hex MD5 digest of the file contents.

    Raises:
        ValueError: if ``chunk_size`` is not positive (a zero-size read would
            silently hash nothing).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    digest = hashlib.md5()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()

def clean_response(text: str) -> str:
    """Normalise model output / record text for display and prompting.

    Steps:
      1. Round-trip through UTF-8. If the text contains lone surrogates, the
         strict decode fails and they are replaced with ``?`` instead.
      2. Drop bracketed spans (``[...]``, matched across newlines) and the
         standalone word ``None``.
      3. Collapse runs of three or more newlines into a single blank line.
      4. Remove every character outside a whitelist: word characters,
         whitespace, markdown markers ``# - *`` and ordinary punctuation.
         The whitelist keeps ``/ % = + < > ' ; ! ?`` so values such as
         ``BP 120/80``, ``HbA1c 7%``, ``<5 mg`` or ``Item = Yes`` survive
         intact.

    Args:
        text: Raw text to clean.

    Returns:
        The cleaned text with surrounding whitespace stripped.
    """
    try:
        text = text.encode('utf-8', 'surrogatepass').decode('utf-8')
    except UnicodeError:
        # Lone surrogates encode under 'surrogatepass' but fail strict decode.
        text = text.encode('utf-8', 'replace').decode('utf-8')
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    text = re.sub(r"\n{3,}", "\n\n", text)
    # Previously '/', '%', '=', '<', '>' etc. were stripped, corrupting
    # readings like "120/80" -> "12080" and the " = " used in Excel records.
    text = re.sub(r"[^\n#\-\*\w\s\.,:;!\?\(\)/%=+<>']+", "", text)
    return text.strip()

def parse_excel_as_whole_prompt(file_path: str) -> str:
    """Build one analysis prompt from the first sheet of an Excel workbook.

    Each row of the first sheet (header on row 0, blanks filled with "")
    becomes one cleaned record line of the form
    ``- <Form Name>: <Form Item> = <Item Response> (<Interview Date> by
    <Interviewer>)`` followed by its ``Description`` on the next line. All
    records are embedded in a fixed instruction template with markdown
    section headings for the model to fill in.

    Args:
        file_path: Path to the uploaded Excel file.

    Returns:
        The complete prompt string.

    Raises:
        ValueError: if the first sheet lacks any of the required columns.
    """
    required_columns = (
        "Form Name",
        "Form Item",
        "Item Response",
        "Interview Date",
        "Interviewer",
        "Description",
    )
    # The context manager closes the workbook handle as soon as it is parsed.
    with pd.ExcelFile(file_path) as xl:
        df = xl.parse(xl.sheet_names[0], header=0).fillna("")

    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        # Clearer than the bare KeyError("'Form Name'") row access would raise.
        raise ValueError(
            f"Excel file is missing required column(s): {', '.join(missing)}"
        )

    records = [
        clean_response(
            f"- {row['Form Name']}: {row['Form Item']} = {row['Item Response']} "
            f"({row['Interview Date']} by {row['Interviewer']})\n{row['Description']}"
        )
        for _, row in df.iterrows()
    ]
    record_text = "\n".join(records)
    prompt = f"""
Patient Complete History:

Instructions:
Based on the complete patient record below, identify any potential missed diagnoses, medication conflicts, incomplete assessments, and urgent follow-up needs. Provide a clinical summary under the markdown headings.

Patient History:
{record_text}

### Missed Diagnoses
- ...

### Medication Conflicts
- ...

### Incomplete Assessments
- ...

### Urgent Follow-up
- ...
"""
    return prompt

def init_agent():
    """Create and initialise the TxAgent used by the UI.

    On first run, copies the bundled ``data/new_tool.json`` into the
    persistent tool cache. An existing cached copy is left untouched. It then
    constructs the agent with fixed model names and settings and loads the
    models via ``init_model()``.

    Returns:
        The initialised ``TxAgent`` instance.

    Raises:
        FileNotFoundError: if no bundled tool file can be found to seed the
            cache (raised by ``shutil.copy``).
    """
    # Resolve the bundled tool file next to this script first (as ``src`` is
    # resolved at import time) so startup does not depend on the working
    # directory. Fall back to the CWD-relative path used previously.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    default_tool_path = os.path.join(script_dir, "data", "new_tool.json")
    if not os.path.exists(default_tool_path):
        default_tool_path = os.path.abspath("data/new_tool.json")

    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(default_tool_path, target_tool_path)

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100,
        additional_default_tools=[],
    )
    agent.init_model()
    return agent

def create_ui(agent):
    """Build the Gradio interface for whole-history analysis of an Excel upload.

    Args:
        agent: Initialised agent whose ``run_gradio_chat`` generator produces
            the analysis text.

    Returns:
        The configured ``gr.Blocks`` app; the caller queues and launches it.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # "\U0001F3E5" is the hospital emoji as a single code point. The old
        # "\ud83c\udfe5" escape is a UTF-16 surrogate pair, which in a Python
        # str is two lone surrogates that cannot be encoded as UTF-8.
        gr.Markdown("<h1 style='text-align: center;'>\U0001F3E5 Full Medical History Analyzer</h1>")
        chatbot = gr.Chatbot(label="Summary Output", height=600)
        file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"], file_count="single")
        msg_input = gr.Textbox(label="Optional Message", placeholder="Add context or instructions...", lines=2)
        send_btn = gr.Button("Analyze")
        download_output = gr.File(label="Download Report")

        def analyze(message: str, chat_history: List[Tuple[str, str]], file):
            """Run one analysis, yielding ``(chat_history, report_path)`` pairs.

            The first yield shows a progress placeholder. The final yield shows
            either the cleaned model output plus the saved report path, or an
            error message plus ``None``.

            Raises:
                gr.Error: if no file was uploaded.
            """
            if not file:
                raise gr.Error("Please upload an Excel file.")
            # Accept a file object exposing ``.name`` (as originally assumed)
            # or a plain path string.
            file_path = getattr(file, "name", file)

            placeholder = (None, "⏳ Analyzing full patient history...")
            new_history = list(chat_history or []) + [(message, None), placeholder]
            yield new_history, None

            try:
                prompt = parse_excel_as_whole_prompt(file_path)
                extra = (message or "").strip()
                if extra:
                    # The textbox asks for context/instructions, so forward it
                    # to the model instead of only echoing it in the chat.
                    prompt += f"\nAdditional Instructions:\n{extra}\n"

                pieces = []
                for result in agent.run_gradio_chat(
                    message=prompt,
                    history=[],
                    temperature=0.2,
                    max_new_tokens=2048,
                    max_token=4096,
                    call_agent=False,
                    conversation=[],
                ):
                    # NOTE(review): if run_gradio_chat yields the cumulative
                    # message list on every step, this collects repeated
                    # content. Confirm against TxAgent before relying on it.
                    if isinstance(result, list):
                        for r in result:
                            if hasattr(r, 'content') and r.content:
                                pieces.append(clean_response(r.content))
                    elif isinstance(result, str):
                        pieces.append(clean_response(result))
                full_output = "\n".join(pieces).strip()

                new_history[-1] = (None, full_output)
                report_path = os.path.join(report_dir, f"{file_hash(file_path)}_final_report.txt")
                with open(report_path, "w", encoding="utf-8") as f:
                    f.write(full_output)
                yield new_history, report_path
            except Exception as e:
                error_entry = (None, f"❌ Error during analysis: {str(e)}")
                # Replace the progress placeholder so it does not linger. If
                # the output was already shown, keep it and append the error.
                if new_history[-1] == placeholder:
                    new_history[-1] = error_entry
                else:
                    new_history.append(error_entry)
                yield new_history, None

        send_btn.click(analyze, inputs=[msg_input, chatbot, file_upload], outputs=[chatbot, download_output])
        msg_input.submit(analyze, inputs=[msg_input, chatbot, file_upload], outputs=[chatbot, download_output])
    return demo

if __name__ == "__main__":
    # Launch settings: listen on all interfaces at port 7860, no public share
    # link. report_dir is where analyze() writes reports, so Gradio is
    # allowed to serve files from it.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        allowed_paths=[report_dir],
        share=False,
    )
    agent = init_agent()
    demo = create_ui(agent)
    demo.queue(api_open=False).launch(**launch_options)