import os, sys, json, re, gc, time, hashlib, logging, shutil, subprocess, multiprocessing as mp
from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed

import fitz  # PyMuPDF
import pyarrow as pa
import pyarrow.csv as pc
import pyarrow.dataset as ds
import pandas as pd
import torch, gradio as gr, psutil, numpy as np
from diskcache import Cache

# ────────────────────────────── CONSTANTS ────────────────────────────── #
PERSIST       = "/data/hf_cache"
MODEL_CACHE   = os.path.join(PERSIST, "txagent_models")
TOOL_CACHE    = os.path.join(PERSIST, "tool_cache")
FILE_CACHE    = os.path.join(PERSIST, "preprocessed")
REPORT_DIR    = os.path.join(PERSIST, "reports")
VLLM_CACHEDIR = os.path.join(PERSIST, "vllm_cache")

for d in (MODEL_CACHE, TOOL_CACHE, FILE_CACHE, REPORT_DIR, VLLM_CACHEDIR):
    os.makedirs(d, exist_ok=True)

# Cache locations must be exported before transformers/vllm are imported
# (TxAgent pulls both in below).
os.environ.update(
    HF_HOME                = MODEL_CACHE,
    TRANSFORMERS_CACHE     = MODEL_CACHE,
    VLLM_CACHE_DIR         = VLLM_CACHEDIR,
    TOKENIZERS_PARALLELISM = "false",
    CUDA_LAUNCH_BLOCKING   = "1",
)

# Put the local `src/` checkout ahead of any installed txagent package.
ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(ROOT, "src"))
from txagent.txagent import TxAgent  # noqa: E402

# ─────────────────────────────── LOGGING ─────────────────────────────── #
logging.basicConfig(format="%(asctime)s — %(levelname)s — %(message)s",
                    level=logging.INFO)
log = logging.getLogger("app")

# ──────────────────────────────── CACHE ──────────────────────────────── #
cache = Cache(FILE_CACHE, size_limit=20 * 1024 ** 3)  # 20 GB

# ─────────────────────────────── HELPERS ─────────────────────────────── #
def md5(path: str) -> str:
    """Hash the file in 1 MiB chunks so large uploads never sit in RAM."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

# ——— PDF ——— #
def _extract_pg(args):
    """Pool worker: re-open the document per page, since fitz handles don't pickle."""
    path, pg_no = args
    with fitz.open(path) as doc:
        text = doc.load_page(pg_no).get_text("text")
    return pg_no, f"=== Page {pg_no+1} ===\n{text.strip()}"

def pdf_to_txt(path: str, progress=None) -> str:
    with fitz.open(path) as doc:
        total = doc.page_count
    mtime = os.path.getmtime(path)
    # NB: relies on the default fork start method; under "spawn", each worker
    # would re-import this module and repeat the module-level model load.
    with mp.Pool() as pool:
        for pg_no, txt in pool.imap_unordered(_extract_pg,
                                              [(path, i) for i in range(total)]):
            if progress:
                progress(pg_no + 1, total)
            cache.set((path, "pg", pg_no, mtime), txt)
    # Pages arrive out of order; reassemble them in order from the cache.
    pages = [cache[(path, "pg", i, mtime)] for i in range(total)]
    return "\n\n".join(pages)

# ——— CSV/XLSX ——— #
def csv_to_arrow(path: str) -> pa.Table:
    return pc.read_csv(path, read_options=pc.ReadOptions(block_size=1 << 24))  # 16 MiB blocks

def excel_to_arrow(path: str) -> pa.Table:
    # openpyxl handles .xlsx; fall back to xlrd only for legacy .xls files.
    df = pd.read_excel(path,
                       engine="openpyxl" if path.endswith("x") else "xlrd",
                       dtype=str)
    return pa.Table.from_pandas(df.fillna(""))

def table_to_rows(tbl: pa.Table) -> List[List[str]]:
    # Arrow may infer int/float/null columns from CSV; stringify every cell so
    # downstream ",".join() calls can't fail on non-string values.
    cols = [col.to_pylist() for col in tbl.columns]
    return [["" if v is None else str(v) for v in row] for row in zip(*cols)]

def load_tabular(path: str) -> List[List[str]]:
    key = (path, os.path.getmtime(path))
    if key in cache:
        return cache[key]
    tbl = csv_to_arrow(path) if path.endswith(".csv") else excel_to_arrow(path)
    rows = table_to_rows(tbl)
    cache[key] = rows
    return rows
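
# Illustrative helper (not part of the original file): a minimal sketch of how
# the cache-backed loaders above compose into a single "file → text" entry
# point. The name `extract_any` and the dispatch-by-extension are assumptions
# that mirror the ingest branch of `analyze` further down.
def extract_any(path: str) -> str:
    """Return plain text for a PDF, CSV, or XLS(X) file via the cached loaders."""
    ext = path.lower().rsplit(".", 1)[-1]
    if ext == "pdf":
        return pdf_to_txt(path)
    if ext in ("csv", "xls", "xlsx"):
        return "\n".join(",".join(r) for r in load_tabular(path))
    raise ValueError(f"Unsupported file type: .{ext}")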
return ("Missed diagnoses include " + ", ".join(uniq[:-1]) + f", and {uniq[-1]}. Please review urgently.") # ——— MONITOR ——— # def sys_usage(tag=""): cpu = psutil.cpu_percent() mem = psutil.virtual_memory() log.info("[%s] CPU %.1f%% — RAM %.0f/%.0f GB", tag, cpu, mem.used/1e9, mem.total/1e9) # ─────────────────────────────── AGENT ──────────────────────────────── # def init_agent() -> TxAgent: sys_usage("before‑load") agent = TxAgent( model_name ="mims-harvard/TxAgent-T1-Llama-3.1-8B", rag_model_name ="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B", step_rag_num =4, force_finish =True, enable_checker =False, seed =42 ) agent.init_model() sys_usage("after‑load") return agent AGENT = init_agent() # ─────────────────────────────── GRADIO ─────────────────────────────── # prompt_tpl = ( "Analyze the following excerpt (chunk {idx}/{tot}) and list **only** missed diagnoses " "with clinical finding + implication in one sentence each.\n\n{chunk}" ) def analyze(user_msg, chat_hist, files, prog=gr.Progress()): chat_hist.append({"role":"user", "content":user_msg}) yield chat_hist, None, "" # ——— ingest files ——— # extracted = "" if files: for f in files: ext = f.name.lower().split(".")[-1] if ext == "pdf": txt = pdf_to_txt(f.name, progress=lambda cur, tot: prog(cur/tot, desc=f"PDF {cur}/{tot}")) extracted += txt + "\n" elif ext in ("csv", "xls", "xlsx"): rows = load_tabular(f.name) extracted += "\n".join(",".join(r) for r in rows) + "\n" chat_hist.append({"role":"assistant", "content":"✅ Files parsed"}) yield chat_hist, None, "" # ——— chunk & batch ——— # max_tokens = 6000 chunks = [extracted[i:i+max_tokens] for i in range(0, len(extracted), max_tokens)] findings = [] for i in range(0, len(chunks), 4): # batch of 4 batch = chunks[i:i+4] prompts = [prompt_tpl.format(idx=i+j+1, tot=len(chunks), chunk=c[:4000]) for j,c in enumerate(batch)] with torch.inference_mode(): outs = [list(AGENT.run_gradio_chat(p, [], 0.2, 512, 2048, False, []))[-1] for p in prompts] for out in outs: if out and hasattr(out, "content"): clean = strip_tool_noise(out.content) if clean and "No missed" not in clean: findings.append(clean) prog((i+len(batch))/len(chunks), desc=f"LLM {i+len(batch)}/{len(chunks)}") summary = summarize(findings) chat_hist.append({"role":"assistant", "content":summary}) # save full if files: fn_hash = md5(files[0].name) p = os.path.join(REPORT_DIR, f"{fn_hash}_report.txt") with open(p, "w") as w: w.write("\n".join(findings) + "\n\n" + summary) yield chat_hist, p, summary else: yield chat_hist, None, summary def ui(): with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown("