File size: 7,093 Bytes
cf5094d
 
782e103
dae38a2
1da2cfd
dae38a2
cf5094d
e24be23
 
 
 
65a2e99
f05e804
dae38a2
65a2e99
 
e24be23
cf5094d
1da2cfd
 
1fa1ea5
f05e804
1da2cfd
 
 
dae38a2
cf5094d
 
 
dae38a2
cf5094d
 
dae38a2
1fa1ea5
cf5094d
 
dae38a2
1fa1ea5
1da2cfd
 
1fa1ea5
1da2cfd
1fa1ea5
1da2cfd
1fa1ea5
cf5094d
1fa1ea5
 
1da2cfd
 
e24be23
1fa1ea5
dae38a2
 
e24be23
1fa1ea5
dae38a2
1da2cfd
 
1ebbef1
1da2cfd
 
cf5094d
 
dae38a2
cf5094d
 
dae38a2
 
 
1fa1ea5
dae38a2
 
cf5094d
1da2cfd
cf5094d
1da2cfd
cf5094d
1fa1ea5
1da2cfd
cf5094d
 
 
e24be23
 
 
 
cf5094d
e24be23
 
cf5094d
 
e24be23
 
 
 
1fa1ea5
e24be23
 
 
 
1fa1ea5
 
 
 
 
 
cf5094d
d14e134
cf5094d
 
 
1fa1ea5
d14e134
1fa1ea5
d14e134
1fa1ea5
d14e134
cf5094d
1fa1ea5
cf5094d
1fa1ea5
cf5094d
 
1da2cfd
cf5094d
 
 
 
 
b90a0eb
cf5094d
 
1fa1ea5
cf5094d
 
 
1fa1ea5
782e103
cf5094d
d14e134
782e103
1bb8be7
cf5094d
 
dae38a2
e24be23
 
cf5094d
1fa1ea5
e24be23
 
 
e778114
 
782e103
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# ✅ Fully optimized app.py for Hugging Face Space with persistent 150GB storage

import sys, os, json, gradio as gr, pandas as pd, pdfplumber, hashlib, shutil, re, time
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Thread

# Everything the app writes goes under /data, the persistent storage volume on
# the Hugging Face Space, so model weights and caches survive restarts.
base_dir = "/data"
(model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir) = [
    os.path.join(base_dir, sub_dir)
    for sub_dir in ("txagent_models", "tool_cache", "cache", "reports", "vllm_cache")
]
for d in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir):
    os.makedirs(d, exist_ok=True)

# Point the Hugging Face and vLLM caches at the persistent directories.
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging switch that makes
# CUDA kernel launches synchronous (slower inference); confirm it is still intended.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Absolute paths to the model directories inside the persistent cache. They are
# passed to TxAgent as model names, presumably so weights load from local disk
# rather than being fetched — confirm against TxAgent's loader.
LOCAL_TXAGENT_PATH, LOCAL_RAG_PATH = (
    os.path.join(model_cache_dir, "mims-harvard", repo_name)
    for repo_name in ("TxAgent-T1-Llama-3.1-8B", "ToolRAG-T1-GTE-Qwen2-1.5B")
)

# Make the bundled ``src`` directory importable regardless of the working
# directory; this must run before the txagent import below.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
from txagent.txagent import TxAgent  # noqa: E402

def file_hash(path):
    """Return the hex MD5 digest of the file at *path*.

    In this file the digest serves as a content-addressed cache key, not as a
    security measure. The file is read in fixed-size chunks inside a ``with``
    block, so the handle is always closed and large uploads are never held in
    memory all at once.
    """
    digest = hashlib.md5()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()
def sanitize_utf8(text):
    """Return *text* with every character that cannot be UTF-8 encoded (e.g. lone surrogates) dropped."""
    encoded = text.encode("utf-8", errors="ignore")
    return encoded.decode("utf-8")
# Lower-case section words that extract_priority_pages looks for when deciding
# whether to keep a PDF page beyond the first three.
MEDICAL_KEYWORDS = {
    "assessment",
    "diagnosis",
    "findings",
    "medications",
    "plan",
    "results",
    "summary",
}

def extract_priority_pages(file_path, max_pages=20):
    """Extract the most relevant text from a PDF for a quick first pass.

    The first three pages are always kept. Pages 4 through ``max_pages``
    (1-based) are kept only when their text contains one of
    ``MEDICAL_KEYWORDS`` as a whole word, case-insensitively. Every kept page
    is prefixed with an ``=== Page N ===`` header using its 1-based number.

    Args:
        file_path: Path to the PDF file.
        max_pages: Last 1-based page number that is scanned for keywords.

    Returns:
        The kept pages joined by blank lines, or a string starting with
        ``"PDF processing error: "`` if the PDF cannot be read.
    """
    # One raw-string pattern so ``\b`` reaches the regex engine as a word
    # boundary. A non-raw-escaped ``\\b`` in a raw string would instead match
    # a literal backslash followed by "b".
    keyword_re = re.compile(
        r"\b(?:" + "|".join(re.escape(kw) for kw in sorted(MEDICAL_KEYWORDS)) + r")\b",
        re.IGNORECASE,
    )
    try:
        with pdfplumber.open(file_path) as pdf:
            pages = [
                f"=== Page {i} ===\n{(page.extract_text() or '').strip()}"
                for i, page in enumerate(pdf.pages[:3], start=1)
            ]
            for i, page in enumerate(pdf.pages[3:max_pages], start=4):
                text = page.extract_text() or ""
                if keyword_re.search(text):
                    pages.append(f"=== Page {i} ===\n{text.strip()}")
            return "\n\n".join(pages)
    except Exception as e:
        return f"PDF processing error: {str(e)}"

def convert_file_to_json(file_path, file_type):
    """Convert an uploaded file to a JSON string, caching the result by content hash.

    Results are stored as ``<md5>.json`` in ``file_cache_dir``, and a cached
    copy is returned as-is on later calls with identical file content.

    Args:
        file_path: Path of the uploaded file.
        file_type: Lower-case extension without the dot: "pdf", "csv",
            "xls" or "xlsx".

    Returns:
        A JSON string:
        - For PDFs: ``{"filename", "content", "status": "initial"}`` holding
          only the priority pages. A background thread is also started to
          build a full-text cache.
        - For CSV/Excel: ``{"filename", "rows"}`` with every cell as a
          string and blanks as "".
        - ``{"error": message}`` (not cached) for an unsupported type or if
          anything here raises.
    """
    try:
        h = file_hash(file_path)
        cache_path = os.path.join(file_cache_dir, f"{h}.json")
        if os.path.exists(cache_path):
            with open(cache_path, "r", encoding="utf-8") as f:
                return f.read()

        filename = os.path.basename(file_path)
        if file_type == "pdf":
            text = extract_priority_pages(file_path)
            result = json.dumps({"filename": filename, "content": text, "status": "initial"})
            # Full extraction is slow, so it runs in the background while the
            # priority pages are returned immediately.
            Thread(target=full_pdf_processing, args=(file_path, h)).start()
        elif file_type in ("csv", "xls", "xlsx"):
            if file_type == "csv":
                df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str)
            else:
                # openpyxl only reads .xlsx. For legacy .xls, let pandas choose
                # its default engine instead of failing unconditionally.
                engine = "openpyxl" if file_type == "xlsx" else None
                df = pd.read_excel(file_path, engine=engine, header=None, dtype=str)
            rows = df.fillna("").astype(str).values.tolist()
            result = json.dumps({"filename": filename, "rows": rows})
        else:
            return json.dumps({"error": f"Unsupported file type: {file_type}"})

        with open(cache_path, "w", encoding="utf-8") as f:
            f.write(result)
        return result
    except Exception as e:
        return json.dumps({"error": str(e)})

def full_pdf_processing(file_path, h):
    """Extract the text of every page of a PDF into ``<h>_full.json`` (best effort).

    Started on a background ``Thread`` by ``convert_file_to_json``. The file
    is skipped if the cache already exists. Each page is prefixed with an
    ``=== Page N ===`` header (1-based), and the text is stored as
    ``{"content": ...}``.

    Failures never propagate out of the thread; they are reported on stderr.

    Args:
        file_path: Path to the PDF file.
        h: Content hash used to name the cache file.
    """
    import tempfile  # local: only this background task needs it

    cache_path = os.path.join(file_cache_dir, f"{h}_full.json")
    if os.path.exists(cache_path):
        return
    tmp_path = None
    try:
        with pdfplumber.open(file_path) as pdf:
            full_text = "\n".join(
                f"=== Page {i} ===\n{(p.extract_text() or '').strip()}"
                for i, p in enumerate(pdf.pages, start=1)
            )
        # Write to a unique temp file and rename it into place. An interrupted
        # write then never leaves a truncated cache file that the exists()
        # check above would accept as complete. A unique name also keeps two
        # threads processing the same file from clobbering each other.
        fd, tmp_path = tempfile.mkstemp(dir=file_cache_dir, suffix=".tmp")
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(json.dumps({"content": full_text}))
        os.replace(tmp_path, cache_path)
        tmp_path = None
    except Exception as e:
        # Best effort: don't crash the background thread, but leave a trace.
        print(f"Full PDF processing failed for {file_path}: {e}", file=sys.stderr)
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass

def init_agent():
    """Create and initialise the TxAgent instance used for analysis.

    On first run, copies the tool file ``data/new_tool.json`` into the
    persistent tool cache. The source file is expected beside this script.
    It then builds TxAgent from the local TxAgent and ToolRAG model paths and
    calls ``init_model()`` on it.

    Returns:
        The initialised ``TxAgent``.
    """
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        # Resolve the bundled file relative to this script (the same way the
        # ``src`` directory is located) rather than the current working
        # directory, so launching from elsewhere still works.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        shutil.copy(os.path.join(script_dir, "data", "new_tool.json"), target_tool_path)

    agent = TxAgent(
        model_name=LOCAL_TXAGENT_PATH,
        rag_model_name=LOCAL_RAG_PATH,
        tool_files_dict={"new_tool": target_tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=8,
        seed=100,
    )
    agent.init_model()
    return agent

# Holds the lazily created TxAgent so the model is loaded at most once per process.
agent_container = {"agent": None}


def get_agent():
    """Return the shared TxAgent, creating it via init_agent() on first use."""
    existing = agent_container["agent"]
    if existing is not None:
        return existing
    agent_container["agent"] = init_agent()
    return agent_container["agent"]

def _uploaded_path(upload):
    """Return the filesystem path of a Gradio upload (tempfile-like object or plain path string)."""
    return getattr(upload, "name", upload)


def _file_extension(path):
    """Return the lower-cased extension of *path* without the leading dot ("" if there is none)."""
    return os.path.splitext(path)[1].lstrip(".").lower()


def _extract_uploads(files):
    """Convert all uploads to JSON text concurrently, keeping upload order; "" when there are none."""
    if not files:
        return ""
    paths = [_uploaded_path(f) for f in files]
    with ThreadPoolExecutor(max_workers=3) as pool:
        # Executor.map yields results in input order, unlike as_completed.
        results = list(pool.map(lambda p: convert_file_to_json(p, _file_extension(p)), paths))
    return "\n".join(sanitize_utf8(r) for r in results)


def _build_prompt(message, extracted):
    """Compose the review prompt from the extracted records plus the user's question, if any."""
    # Records are truncated to the first 15000 characters.
    prompt = f"""Review these medical records and identify exactly what might have been missed:
1. Missed diagnoses
2. Medication conflicts
3. Incomplete assessments
4. Abnormal results needing follow-up

Medical Records:\n{extracted[:15000]}
"""
    question = (message or "").strip()
    if question:
        prompt += f"\nClinician question: {question}\n"
    return prompt


def _collect_response(prompt, conversation):
    """Run the agent on *prompt* and return its streamed output as one cleaned string."""
    parts = []
    for chunk in get_agent().run_gradio_chat(
        prompt,
        history=[],
        temperature=0.2,
        max_new_tokens=1024,
        max_token=4096,
        call_agent=False,
        conversation=conversation,
    ):
        # The agent yields either plain text or a list of message objects.
        if isinstance(chunk, str):
            parts.append(chunk)
        elif isinstance(chunk, list):
            parts.append("".join(c.content for c in chunk if hasattr(c, "content")))
    return "".join(parts).replace("[TOOL_CALLS]", "").strip()


def create_ui():
    """Build the Gradio Blocks interface for the Clinical Oversight Assistant.

    Returns:
        The ``gr.Blocks`` app, ready for ``queue().launch()``.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""<h1 style='text-align:center;'>🩺 Clinical Oversight Assistant</h1>""")
        chatbot = gr.Chatbot(label="Analysis", height=600)
        msg_input = gr.Textbox(placeholder="Ask a question about the patient...")
        file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
        send_btn = gr.Button("Analyze", variant="primary")
        state = gr.State([])
        # One output component shared by both triggers. Creating gr.File()
        # inline in each event call would add a separate, always-empty widget
        # per event to the layout.
        file_output = gr.File()

        def analyze(message, history, conversation, files):
            """Analyse the uploaded records and append the answer to the chat history.

            The user's message is included in the prompt as the clinician's
            question. Returns ``(updated_history, None)``; on failure, the
            error text becomes the chat reply.
            """
            history = history or []
            try:
                extracted = _extract_uploads(files)
                cleaned = _collect_response(_build_prompt(message, extracted), conversation)
                return history + [[message, cleaned]], None
            except Exception as e:
                return history + [[message, f"❌ Error: {str(e)}"]], None

        inputs = [msg_input, chatbot, state, file_upload]
        outputs = [chatbot, file_output]
        send_btn.click(analyze, inputs=inputs, outputs=outputs)
        msg_input.submit(analyze, inputs=inputs, outputs=outputs)
    return demo

if __name__ == "__main__":
    ui = create_ui()
    # Serve on all interfaces on port 7860 with request queuing enabled.
    ui.queue(api_open=False).launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        # Reuse report_dir so the served path cannot drift from where reports
        # are stored if base_dir ever changes.
        allowed_paths=[report_dir],
        share=False,
    )