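"""Clinical Oversight Assistant.

Gradio app that ingests PDF/CSV/XLSX clinical records, extracts and chunks
their text, batches the chunks through TxAgent to flag potentially missed
diagnoses, and writes a downloadable plain-text report.
"""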
import os, sys, json, re, gc, time, hashlib, logging, shutil, subprocess, multiprocessing as mp
from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed
import fitz  # PyMuPDF
import pyarrow as pa
import pyarrow.csv as pc
import pyarrow.dataset as ds
import pandas as pd
import torch, gradio as gr, psutil, numpy as np
from diskcache import Cache
# ────────────────────────────── CONSTANTS ────────────────────────────── #
PERSIST = "/data/hf_cache"
MODEL_CACHE = os.path.join(PERSIST, "txagent_models")
TOOL_CACHE = os.path.join(PERSIST, "tool_cache")
FILE_CACHE = os.path.join(PERSIST, "preprocessed")
REPORT_DIR = os.path.join(PERSIST, "reports")
VLLM_CACHEDIR = os.path.join(PERSIST, "vllm_cache")
for d in (MODEL_CACHE, TOOL_CACHE, FILE_CACHE, REPORT_DIR, VLLM_CACHEDIR):
    os.makedirs(d, exist_ok=True)
os.environ.update(
    HF_HOME=MODEL_CACHE,
    TRANSFORMERS_CACHE=MODEL_CACHE,
    VLLM_CACHE_DIR=VLLM_CACHEDIR,
    TOKENIZERS_PARALLELISM="false",
    CUDA_LAUNCH_BLOCKING="1",
)
# put local `src/` first
ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(ROOT, "src"))
from txagent.txagent import TxAgent  # noqa: E402
# ─────────────────────────────── LOGGING ─────────────────────────────── #
logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s",
                    level=logging.INFO)
log = logging.getLogger("app")
# ──────────────────────────────── CACHE ──────────────────────────────── #
cache = Cache(FILE_CACHE, size_limit=20 * 1024 ** 3)  # 20 GB
# ─────────────────────────────── HELPERS ─────────────────────────────── #
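# md5() streams the file in 1 MiB chunks, so hashing a large upload never
# loads the whole file into memory.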
def md5(path: str) -> str:
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()
# ─── PDF ─── #
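# Each worker re-opens the document itself: fitz document handles are not
# picklable, so only (path, page_no) tuples cross the process boundary.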
def _extract_pg(args):
    path, pg_no = args
    with fitz.open(path) as doc:
        page = doc.load_page(pg_no)
        text = page.get_text("text")
    return pg_no, f"=== Page {pg_no+1} ===\n{text.strip()}"
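# Pages are extracted in parallel and cached under (path, page, mtime), so
# re-analyzing an unchanged file skips PDF parsing entirely.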
def pdf_to_txt(path: str, progress=None) -> str:
    mtime = os.path.getmtime(path)  # hoisted so every cache key sees one mtime
    with fitz.open(path) as doc:    # close the handle before forking workers
        total = doc.page_count
    with mp.Pool() as pool:
        for pg_no, txt in pool.imap_unordered(_extract_pg, [(path, i) for i in range(total)]):
            if progress:
                progress(pg_no + 1, total)
            cache.set((path, "pg", pg_no, mtime), txt)
    pages = [cache[(path, "pg", i, mtime)] for i in range(total)]
    return "\n\n".join(pages)
# ─── CSV/XLSX ─── #
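# The 16 MiB Arrow block size trades memory for fewer read batches on large CSVs.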
def csv_to_arrow(path: str) -> pa.Table:
    return pc.read_csv(path, read_options=pc.ReadOptions(block_size=1 << 24))  # 16 MiB
def excel_to_arrow(path: str) -> pa.Table:
    # .xlsx is read by openpyxl; xlrd only handles the legacy .xls format
    df = pd.read_excel(path, engine="openpyxl" if path.endswith(".xlsx") else "xlrd", dtype=str)
    return pa.Table.from_pandas(df.fillna(""))
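# Arrow tables are column-major; zip(*cols) transposes them into the
# row-major lists the rest of the app expects.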
def table_to_rows(tbl: pa.Table) -> List[List[str]]:
    cols = [col.to_pylist() for col in tbl.columns]
    # CSV columns may be typed (int/float); stringify so rows can be ",".join()ed later
    return [["" if v is None else str(v) for v in r] for r in zip(*cols)]
def load_tabular(path: str) -> List[List[str]]:
    key = (path, os.path.getmtime(path))
    if key in cache:
        return cache[key]
    tbl = csv_to_arrow(path) if path.endswith(".csv") else excel_to_arrow(path)
    rows = table_to_rows(tbl)
    cache[key] = rows
    return rows
# ─── CLEANERS ─── #
def strip_tool_noise(txt: str) -> str:
txt = re.sub(r"\[.*?TOOL.*?]", "", txt, flags=re.S)
txt = re.sub(r"\s+", " ", txt).strip()
return txt
def summarize(findings: List[str]) -> str:
    uniq = list(dict.fromkeys(findings))  # preserve order, dedupe
    if not uniq:
        return "No missed diagnoses identified."
    if len(uniq) == 1:
        return f"Missed diagnosis: {uniq[0]}."
    return ("Missed diagnoses include " +
            ", ".join(uniq[:-1]) +
            f", and {uniq[-1]}. Please review urgently.")
# ─── MONITOR ─── #
def sys_usage(tag=""):
    cpu = psutil.cpu_percent()
    mem = psutil.virtual_memory()
    log.info("[%s] CPU %.1f%% | RAM %.0f/%.0f GB",
             tag, cpu, mem.used / 1e9, mem.total / 1e9)
# ─────────────────────────────── AGENT ──────────────────────────────── #
def init_agent() -> TxAgent:
sys_usage("beforeβload")
agent = TxAgent(
model_name ="mims-harvard/TxAgent-T1-Llama-3.1-8B",
rag_model_name ="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
step_rag_num =4,
force_finish =True,
enable_checker =False,
seed =42
)
agent.init_model()
sys_usage("afterβload")
return agent
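# Loaded once at import time. Note this assumes fork-style multiprocessing
# (the Linux default): under "spawn", each mp.Pool worker would re-import
# this module and try to load the model again.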
AGENT = init_agent()
# ─────────────────────────────── GRADIO ─────────────────────────────── #
prompt_tpl = (
"Analyze the following excerpt (chunkΒ {idx}/{tot}) and list **only** missed diagnoses "
"with clinical finding + implication in one sentence each.\n\n{chunk}"
)
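# analyze() is a generator: every `yield (chat, file, summary)` is streamed
# straight to the UI, so intermediate states (files parsed, per-batch
# progress) appear live.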
def analyze(user_msg, chat_hist, files, prog=gr.Progress()):
chat_hist.append({"role":"user", "content":user_msg})
yield chat_hist, None, ""
# βββ ingest files βββ #
extracted = ""
if files:
for f in files:
ext = f.name.lower().split(".")[-1]
if ext == "pdf":
txt = pdf_to_txt(f.name,
progress=lambda cur, tot: prog(cur/tot, desc=f"PDF {cur}/{tot}"))
extracted += txt + "\n"
elif ext in ("csv", "xls", "xlsx"):
rows = load_tabular(f.name)
extracted += "\n".join(",".join(r) for r in rows) + "\n"
chat_hist.append({"role":"assistant", "content":"β
Β Files parsed"})
yield chat_hist, None, ""
# βββ chunk & batch βββ #
max_tokens = 6000
chunks = [extracted[i:i+max_tokens] for i in range(0, len(extracted), max_tokens)]
findings = []
for i in range(0, len(chunks), 4): # batch of 4
batch = chunks[i:i+4]
prompts = [prompt_tpl.format(idx=i+j+1, tot=len(chunks), chunk=c[:4000])
for j,c in enumerate(batch)]
with torch.inference_mode():
outs = [list(AGENT.run_gradio_chat(p, [], 0.2, 512, 2048, False, []))[-1]
for p in prompts]
for out in outs:
if out and hasattr(out, "content"):
clean = strip_tool_noise(out.content)
if clean and "No missed" not in clean:
findings.append(clean)
prog((i+len(batch))/len(chunks), desc=f"LLM {i+len(batch)}/{len(chunks)}")
summary = summarize(findings)
chat_hist.append({"role":"assistant", "content":summary})
# save full
if files:
fn_hash = md5(files[0].name)
p = os.path.join(REPORT_DIR, f"{fn_hash}_report.txt")
with open(p, "w") as w:
w.write("\n".join(findings) + "\n\n" + summary)
yield chat_hist, p, summary
else:
yield chat_hist, None, summary
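# One gr.State is shared by both triggers so the button and the textbox
# submit read the same history object.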
def ui():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align:center'>🩺 Clinical Oversight Assistant</h1>")
        chat = gr.Chatbot(height=600, label="Detailed analysis", type="messages")
        summ = gr.Markdown(label="Summary of missed diagnoses")
        files = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
        txtbox = gr.Textbox(placeholder="Ask about potential oversights…", show_label=False)
        run = gr.Button("Analyze", variant="primary")
        dl = gr.File(label="Download full report")
        state = gr.State([])  # shared chat history for both event triggers
        run.click(analyze, [txtbox, state, files], [chat, dl, summ])
        txtbox.submit(analyze, [txtbox, state, files], [chat, dl, summ])
    return demo
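# allowed_paths lets Gradio serve files from REPORT_DIR through the
# download widget; everything else on disk stays inaccessible.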
if __name__ == "__main__":
    ui().queue(api_open=False).launch(
        server_name="0.0.0.0", server_port=7860,
        allowed_paths=[REPORT_DIR], show_error=True, share=False)