import os, sys, re, hashlib, logging, multiprocessing as mp
from typing import List

import fitz                                     # PyMuPDF
import pyarrow as pa
import pyarrow.csv as pc
import pandas as pd
import torch, gradio as gr, psutil
from diskcache import Cache

# ──────────────────────────────  CONSTANTS  ────────────────────────────── #
PERSIST = "/data/hf_cache"
MODEL_CACHE   = os.path.join(PERSIST, "txagent_models")
TOOL_CACHE    = os.path.join(PERSIST, "tool_cache")
FILE_CACHE    = os.path.join(PERSIST, "preprocessed")
REPORT_DIR    = os.path.join(PERSIST, "reports")
VLLM_CACHEDIR = os.path.join(PERSIST, "vllm_cache")

for d in (MODEL_CACHE, TOOL_CACHE, FILE_CACHE, REPORT_DIR, VLLM_CACHEDIR):
    os.makedirs(d, exist_ok=True)

os.environ.update(
    HF_HOME               = MODEL_CACHE,
    TRANSFORMERS_CACHE    = MODEL_CACHE,
    VLLM_CACHE_DIR        = VLLM_CACHEDIR,
    TOKENIZERS_PARALLELISM= "false",
    CUDA_LAUNCH_BLOCKING  = "1",
)

# put local `src/` first
ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(ROOT, "src"))
from txagent.txagent import TxAgent                        # noqa: E402

# ───────────────────────────────  LOGGING  ─────────────────────────────── #
logging.basicConfig(format="%(asctime)s — %(levelname)s — %(message)s",
                    level=logging.INFO)
log = logging.getLogger("app")

# ────────────────────────────────  CACHE  ──────────────────────────────── #
cache = Cache(FILE_CACHE, size_limit=20 * 1024 ** 3)       # 20 GiB

# ───────────────────────────────  HELPERS  ─────────────────────────────── #
def md5(path: str) -> str:
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()
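
# Note: md5() streams the file in 1 MiB chunks, so memory stays flat even for
# large uploads. For reference, an empty file hashes to
# "d41d8cd98f00b204e9800998ecf8427e" (the MD5 of zero bytes).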

# ——— PDF ——— #
def _extract_pg(args):
    path, pg_no = args
    with fitz.open(path) as doc:
        page = doc.load_page(pg_no)
        text = page.get_text("text")
    return pg_no, f"=== Page {pg_no+1} ===\n{text.strip()}"

def pdf_to_txt(path: str, progress=None) -> str:
    with fitz.open(path) as doc:                 # close promptly; workers reopen per page
        total = doc.page_count
    mtime = os.path.getmtime(path)               # snapshot once so all cache keys agree
    with mp.Pool() as pool:
        for pg_no, txt in pool.imap_unordered(_extract_pg, [(path, i) for i in range(total)]):
            if progress:
                progress(pg_no + 1, total)
            cache.set((path, "pg", pg_no, mtime), txt)
    pages = [cache[(path, "pg", i, mtime)] for i in range(total)]
    return "\n\n".join(pages)
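
# Sketch of the cache-key scheme (hypothetical values): the file's mtime is part
# of each key, so re-saving the PDF automatically invalidates stale page text:
#   ("report.pdf", "pg", 0, 1700000000.0)   # original file
#   ("report.pdf", "pg", 0, 1700000600.0)   # file modified -> fresh extraction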

# ——— CSV/XLSX ——— #
def csv_to_arrow(path: str) -> pa.Table:
    return pc.read_csv(path, read_options=pc.ReadOptions(block_size=1 << 24))  # 16 MiB blocks

def excel_to_arrow(path: str) -> pa.Table:
    # openpyxl handles .xlsx; legacy .xls still needs xlrd
    engine = "openpyxl" if path.lower().endswith(".xlsx") else "xlrd"
    df = pd.read_excel(path, engine=engine, dtype=str)
    return pa.Table.from_pandas(df.fillna(""))

def table_to_rows(tbl: pa.Table) -> List[List[str]]:
    cols = [col.to_pylist() for col in tbl.columns]
    return [list(r) for r in zip(*cols)]
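
# Example of the column-major -> row-major pivot (hypothetical 2x2 table):
#   columns [["1", "2"], ["a", "b"]]  ->  rows [["1", "a"], ["2", "b"]]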

def load_tabular(path: str) -> List[List[str]]:
    key = (path, os.path.getmtime(path))
    if key in cache:
        return cache[key]
    tbl = csv_to_arrow(path) if path.lower().endswith(".csv") else excel_to_arrow(path)
    rows = table_to_rows(tbl)
    cache[key] = rows
    return rows

# ——— CLEANERS ——— #
def strip_tool_noise(txt: str) -> str:
    txt = re.sub(r"\[.*?TOOL.*?]", "", txt, flags=re.S)
    txt = re.sub(r"\s+", " ", txt).strip()
    return txt
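
# Example with hypothetical tool markup: bracketed tool blocks are removed and
# whitespace collapsed:
#   strip_tool_noise("[TOOL_CALLS ...] Anemia  likely") -> "Anemia likely"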

def summarize(findings: List[str]) -> str:
    uniq = list(dict.fromkeys(findings))      # preserve order, dedupe
    if not uniq:
        return "No missed diagnoses identified."
    if len(uniq) == 1:
        return f"Missed diagnosis: {uniq[0]}."
    return ("Missed diagnoses include " +
            ", ".join(uniq[:-1]) +
            f", and {uniq[-1]}. Please review urgently.")
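
# Expected behaviour (hypothetical findings):
#   summarize([])                   -> "No missed diagnoses identified."
#   summarize(["possible PE"])      -> "Missed diagnosis: possible PE."
#   summarize(["anemia", "sepsis"]) -> "Missed diagnoses include anemia, and sepsis.
#                                       Please review urgently."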

# ——— MONITOR ——— #
def sys_usage(tag=""):
    cpu = psutil.cpu_percent()
    mem = psutil.virtual_memory()
    log.info("[%s] CPU %.1f%% — RAM %.0f/%.0f GB",
             tag, cpu, mem.used / 1e9, mem.total / 1e9)
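
# Emits log lines of the form (numbers hypothetical):
#   2024-01-01 12:00:00,123 — INFO — [before-load] CPU 12.5% — RAM 8/64 GB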

# ───────────────────────────────  AGENT  ──────────────────────────────── #
def init_agent() -> TxAgent:
    sys_usage("before-load")
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        step_rag_num=4,
        force_finish=True,
        enable_checker=False,
        seed=42,
    )
    agent.init_model()
    sys_usage("after-load")
    return agent

AGENT = init_agent()

# ───────────────────────────────  GRADIO  ─────────────────────────────── #
prompt_tpl = (
    "Analyze the following excerpt (chunk {idx}/{tot}) and list **only** missed diagnoses "
    "with clinical finding + implication in one sentence each.\n\n{chunk}"
)
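
# e.g. prompt_tpl.format(idx=1, tot=3, chunk="...") renders as
#   "Analyze the following excerpt (chunk 1/3) and list **only** missed diagnoses ..."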

def analyze(user_msg, chat_hist, files, prog=gr.Progress()):
    chat_hist.append({"role":"user", "content":user_msg})
    yield chat_hist, None, ""

    # ——— ingest files ——— #
    extracted = ""
    if files:
        for f in files:
            ext = f.name.lower().split(".")[-1]
            if ext == "pdf":
                txt = pdf_to_txt(f.name,
                        progress=lambda cur, tot: prog(cur/tot, desc=f"PDF {cur}/{tot}"))
                extracted += txt + "\n"
            elif ext in ("csv", "xls", "xlsx"):
                rows = load_tabular(f.name)
                extracted += "\n".join(",".join(r) for r in rows) + "\n"
        chat_hist.append({"role":"assistant", "content":"✅ Files parsed"})
        yield chat_hist, None, ""

    # ——— chunk & batch ——— #
    chunk_chars = 4000                 # character-based split (not tokens); each chunk reaches the model whole
    chunks = [extracted[i:i + chunk_chars] for i in range(0, len(extracted), chunk_chars)]
    findings = []
    batch_size = 4
    for i in range(0, len(chunks), batch_size):
        batch = chunks[i:i + batch_size]
        prompts = [prompt_tpl.format(idx=i + j + 1, tot=len(chunks), chunk=c)
                   for j, c in enumerate(batch)]

        with torch.inference_mode():
            outs = []
            for p in prompts:
                msgs = list(AGENT.run_gradio_chat(p, [], 0.2, 512, 2048, False, []))
                outs.append(msgs[-1] if msgs else None)    # guard: generator may yield nothing

        for out in outs:
            if out and hasattr(out, "content"):
                clean = strip_tool_noise(out.content)
                if clean and "No missed" not in clean:
                    findings.append(clean)

        prog((i+len(batch))/len(chunks), desc=f"LLM {i+len(batch)}/{len(chunks)}")

    summary = summarize(findings)
    chat_hist.append({"role":"assistant", "content":summary})

    # persist the full report for download
    if files:
        fn_hash = md5(files[0].name)
        p = os.path.join(REPORT_DIR, f"{fn_hash}_report.txt")
        with open(p, "w") as w:
            w.write("\n".join(findings) + "\n\n" + summary)
        yield chat_hist, p, summary
    else:
        yield chat_hist, None, summary

def ui():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align:center'>🩺 Clinical Oversight Assistant</h1>")
        chat   = gr.Chatbot(height=600, label="Detailed analysis", type="messages")
        summ   = gr.Markdown(label="Summary of missed diagnoses")
        files  = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
        txtbox = gr.Textbox(placeholder="Ask about potential oversights…", show_label=False)
        run    = gr.Button("Analyze", variant="primary")
        dl     = gr.File(label="Download full report")
        state  = gr.State([])          # one shared chat history for both triggers

        run.click(analyze, [txtbox, state, files], [chat, dl, summ])
        txtbox.submit(analyze, [txtbox, state, files], [chat, dl, summ])
    return demo

if __name__ == "__main__":
    ui().queue(api_open=False).launch(
        server_name="0.0.0.0", server_port=7860,
        allowed_paths=[REPORT_DIR], show_error=True, share=False)