import os, sys, re, hashlib, logging, multiprocessing as mp
from typing import List

import fitz  # PyMuPDF
import pyarrow as pa
import pyarrow.csv as pc
import pandas as pd
import torch, gradio as gr, psutil
from diskcache import Cache
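
# Third-party dependencies: PyMuPDF (fitz), pyarrow, pandas, torch, gradio,
# psutil, and diskcache; everything else above is stdlib.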

PERSIST = "/data/hf_cache"
MODEL_CACHE = os.path.join(PERSIST, "txagent_models")
TOOL_CACHE = os.path.join(PERSIST, "tool_cache")
FILE_CACHE = os.path.join(PERSIST, "preprocessed")
REPORT_DIR = os.path.join(PERSIST, "reports")
VLLM_CACHEDIR = os.path.join(PERSIST, "vllm_cache")

for d in (MODEL_CACHE, TOOL_CACHE, FILE_CACHE, REPORT_DIR, VLLM_CACHEDIR):
    os.makedirs(d, exist_ok=True)

os.environ.update(
    HF_HOME=MODEL_CACHE,
    TRANSFORMERS_CACHE=MODEL_CACHE,
    VLLM_CACHE_DIR=VLLM_CACHEDIR,
    TOKENIZERS_PARALLELISM="false",
    CUDA_LAUNCH_BLOCKING="1",
)
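
# HF_HOME / TRANSFORMERS_CACHE keep Hugging Face downloads on the persistent
# volume so weights survive restarts; CUDA_LAUNCH_BLOCKING=1 makes CUDA errors
# surface at the failing call (slower, but far easier to debug).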

ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(ROOT, "src"))
from txagent.txagent import TxAgent

logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s",
                    level=logging.INFO)
log = logging.getLogger("app")

cache = Cache(FILE_CACHE, size_limit=20 * 1024 ** 3)  # 20 GB on-disk cache
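
# diskcache is a SQLite-backed, process-safe store: extracted pages and parsed
# tables keyed below survive restarts, capped at 20 GB.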

def md5(path: str) -> str:
    """Hash the file in 1 MiB chunks so large uploads never sit in RAM."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()
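
# md5(files[0].name) is used below to derive stable report filenames, so
# re-analyzing the same upload overwrites its report instead of piling up copies.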

def _extract_pg(args):
    """Pool worker: open the PDF in this process and pull one page's text."""
    path, pg_no = args
    with fitz.open(path) as doc:
        page = doc.load_page(pg_no)
        text = page.get_text("text")
    return pg_no, f"=== Page {pg_no + 1} ===\n{text.strip()}"


def pdf_to_txt(path: str, progress=None) -> str:
    """Extract every page in parallel, caching each one keyed by file mtime."""
    with fitz.open(path) as doc:
        total = doc.page_count
    mtime = os.path.getmtime(path)  # hoisted so set/get keys match even if the file changes mid-run
    with mp.Pool() as pool:
        for pg_no, txt in pool.imap_unordered(_extract_pg,
                                              [(path, i) for i in range(total)]):
            if progress:
                progress(pg_no + 1, total)
            cache.set((path, "pg", pg_no, mtime), txt)
    pages = [cache[(path, "pg", i, mtime)] for i in range(total)]
    return "\n\n".join(pages)
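
# Pages are cached individually under (path, "pg", page_no, mtime): a second run
# on an unchanged file is pure cache hits, and touching the file invalidates
# every page key at once. Note the mp.Pool relies on Linux's fork start method;
# under spawn, each worker would re-import this module and reload the model.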

def csv_to_arrow(path: str) -> pa.Table:
    # 16 MiB blocks let Arrow parse large CSVs with multi-threaded reads.
    return pc.read_csv(path, read_options=pc.ReadOptions(block_size=1 << 24))


def excel_to_arrow(path: str) -> pa.Table:
    # .xlsx is handled by openpyxl; legacy .xls needs xlrd.
    engine = "openpyxl" if path.endswith("x") else "xlrd"
    df = pd.read_excel(path, engine=engine, dtype=str)
    return pa.Table.from_pandas(df.fillna(""))


def table_to_rows(tbl: pa.Table) -> List[List[str]]:
    # Transpose Arrow's column-major layout into plain row lists.
    cols = [col.to_pylist() for col in tbl.columns]
    return [list(r) for r in zip(*cols)]


def load_tabular(path: str) -> List[List[str]]:
    key = (path, os.path.getmtime(path))
    if key in cache:
        return cache[key]
    tbl = csv_to_arrow(path) if path.endswith("csv") else excel_to_arrow(path)
    rows = table_to_rows(tbl)
    cache[key] = rows
    return rows
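
# Sketch (optional, not wired in): for CSVs too large to materialize as row
# lists, pyarrow.dataset can stream record batches instead:
#
#   import pyarrow.dataset as ds
#   for batch in ds.dataset(path, format="csv").to_batches():
#       ...  # each batch is a pa.RecordBatch; process incrementally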

def strip_tool_noise(txt: str) -> str:
    txt = re.sub(r"\[.*?TOOL.*?]", "", txt, flags=re.S)  # drop [TOOL ...] traces
    txt = re.sub(r"\s+", " ", txt).strip()
    return txt


def summarize(findings: List[str]) -> str:
    uniq = list(dict.fromkeys(findings))  # dedupe while preserving order
    if not uniq:
        return "No missed diagnoses identified."
    if len(uniq) == 1:
        return f"Missed diagnosis: {uniq[0]}."
    return ("Missed diagnoses include " +
            ", ".join(uniq[:-1]) +
            f", and {uniq[-1]}. Please review urgently.")
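
# e.g. summarize(["anemia", "anemia", "acute kidney injury"]) ->
#   "Missed diagnoses include anemia, and acute kidney injury. Please review urgently."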

def sys_usage(tag=""):
    cpu = psutil.cpu_percent()
    mem = psutil.virtual_memory()
    log.info("[%s] CPU %.1f%% | RAM %.0f/%.0f GB",
             tag, cpu, mem.used / 1e9, mem.total / 1e9)


def init_agent() -> TxAgent:
    sys_usage("before-load")
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        step_rag_num=4,
        force_finish=True,
        enable_checker=False,
        seed=42,
    )
    agent.init_model()
    sys_usage("after-load")
    return agent


AGENT = init_agent()
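
# Loaded once at import time, so every Gradio request reuses the same weights
# instead of paying the multi-GB model load per call.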

prompt_tpl = (
    "Analyze the following excerpt (chunk {idx}/{tot}) and list **only** missed diagnoses "
    "with clinical finding + implication in one sentence each.\n\n{chunk}"
)
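
# Rendered example (hypothetical values): prompt_tpl.format(idx=1, tot=3,
# chunk="Hb 6.2 g/dL, discharged without anemia workup") yields a prompt that
# constrains the model to one-sentence finding + implication pairs.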

def analyze(user_msg, chat_hist, files, prog=gr.Progress()):
    chat_hist.append({"role": "user", "content": user_msg})
    yield chat_hist, None, ""

    # Pull raw text out of every uploaded file.
    extracted = ""
    if files:
        for f in files:
            ext = f.name.lower().split(".")[-1]
            if ext == "pdf":
                txt = pdf_to_txt(f.name,
                                 progress=lambda cur, tot: prog(cur / tot, desc=f"PDF {cur}/{tot}"))
                extracted += txt + "\n"
            elif ext in ("csv", "xls", "xlsx"):
                rows = load_tabular(f.name)
                extracted += "\n".join(",".join(r) for r in rows) + "\n"
        chat_hist.append({"role": "assistant", "content": "✅ Files parsed"})
        yield chat_hist, None, ""

    # Chunk by characters (a rough proxy for tokens) and query in batches of 4.
    max_chars = 6000
    chunks = [extracted[i:i + max_chars] for i in range(0, len(extracted), max_chars)]
    findings = []
    for i in range(0, len(chunks), 4):
        batch = chunks[i:i + 4]
        prompts = [prompt_tpl.format(idx=i + j + 1, tot=len(chunks), chunk=c[:4000])
                   for j, c in enumerate(batch)]

        with torch.inference_mode():
            outs = [list(AGENT.run_gradio_chat(p, [], 0.2, 512, 2048, False, []))[-1]
                    for p in prompts]

        for out in outs:
            if out and hasattr(out, "content"):
                clean = strip_tool_noise(out.content)
                if clean and "No missed" not in clean:
                    findings.append(clean)

        prog((i + len(batch)) / len(chunks), desc=f"LLM {i + len(batch)}/{len(chunks)}")

    summary = summarize(findings)
    chat_hist.append({"role": "assistant", "content": summary})

    # Write a downloadable report named after the first file's content hash.
    if files:
        fn_hash = md5(files[0].name)
        p = os.path.join(REPORT_DIR, f"{fn_hash}_report.txt")
        with open(p, "w") as w:
            w.write("\n".join(findings) + "\n\n" + summary)
        yield chat_hist, p, summary
    else:
        yield chat_hist, None, summary
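
# Headless usage sketch (hypothetical FakeFile stands in for Gradio's file
# object; outside a Gradio event the default Progress may need stubbing):
#
#   class FakeFile:
#       def __init__(self, name): self.name = name
#
#   for hist, report_path, summary in analyze("Check this chart", [],
#                                             [FakeFile("chart.pdf")]):
#       pass  # the final yield carries the report path and the summary text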

def ui():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align:center'>🩺 Clinical Oversight Assistant</h1>")
        chat = gr.Chatbot(height=600, label="Detailed analysis", type="messages")
        summ = gr.Markdown(label="Summary of missed diagnoses")
        files = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
        txtbox = gr.Textbox(placeholder="Ask about potential oversights…", show_label=False)
        run = gr.Button("Analyze", variant="primary")
        dl = gr.File(label="Download full report")

        # One shared State: the click and submit triggers append to the same
        # history instead of each creating its own empty list.
        hist = gr.State([])
        run.click(analyze, [txtbox, hist, files], [chat, dl, summ])
        txtbox.submit(analyze, [txtbox, hist, files], [chat, dl, summ])
    return demo


if __name__ == "__main__":
    ui().queue(api_open=False).launch(
        server_name="0.0.0.0", server_port=7860,
        allowed_paths=[REPORT_DIR], show_error=True, share=False)