# CPS-Test-Mobile / app.py
import os, sys, json, re, gc, time, hashlib, logging, shutil, subprocess, multiprocessing as mp
from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed
import fitz  # PyMuPDF
import pyarrow as pa
import pyarrow.csv as pc
import pyarrow.dataset as ds
import pandas as pd
import torch, gradio as gr, psutil, numpy as np
from diskcache import Cache
# ────────────────────────────── CONSTANTS ────────────────────────────── #
PERSIST = "/data/hf_cache"
MODEL_CACHE = os.path.join(PERSIST, "txagent_models")
TOOL_CACHE = os.path.join(PERSIST, "tool_cache")
FILE_CACHE = os.path.join(PERSIST, "preprocessed")
REPORT_DIR = os.path.join(PERSIST, "reports")
VLLM_CACHEDIR = os.path.join(PERSIST, "vllm_cache")
for d in (MODEL_CACHE, TOOL_CACHE, FILE_CACHE, REPORT_DIR, VLLM_CACHEDIR):
    os.makedirs(d, exist_ok=True)
os.environ.update(
    HF_HOME                = MODEL_CACHE,
    TRANSFORMERS_CACHE     = MODEL_CACHE,
    VLLM_CACHE_DIR         = VLLM_CACHEDIR,
    TOKENIZERS_PARALLELISM = "false",
    CUDA_LAUNCH_BLOCKING   = "1",
)
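# CUDA_LAUNCH_BLOCKING=1 forces CUDA kernels to run synchronously, so errors
# surface at the offending call rather than at a later, unrelated one; it is
# a debugging aid and costs some throughput.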
# put local `src/` first
ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(ROOT, "src"))
from txagent.txagent import TxAgent # noqa: E402
# ─────────────────────────────── LOGGING ─────────────────────────────── #
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s",
                    level=logging.INFO)
log = logging.getLogger("app")
# ──────────────────────────────── CACHE ──────────────────────────────── #
cache = Cache(FILE_CACHE, size_limit=20 * 1024 ** 3)  # 20 GB on-disk cache
# ─────────────────────────────── HELPERS ─────────────────────────────── #
def md5(path: str) -> str:
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()
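# Streaming in 1 MiB blocks keeps memory flat for large uploads. The digest
# later names the saved report in analyze(), so re-running the same file
# overwrites a single report rather than accumulating copies.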
# --- PDF --- #
def _extract_pg(args):
    """Worker: extract one page. Reopens the document because fitz handles
    wrap C pointers and cannot be pickled across processes."""
    path, pg_no = args
    with fitz.open(path) as doc:
        text = doc.load_page(pg_no).get_text("text")
    return pg_no, f"=== Page {pg_no + 1} ===\n{text.strip()}"
def pdf_to_txt(path: str, progress=None) -> str:
    with fitz.open(path) as doc:  # open only to count pages, then release
        total = doc.page_count
    mtime = os.path.getmtime(path)  # stat once; keys stay stable mid-run
    with mp.Pool() as pool:
        for pg_no, txt in pool.imap_unordered(_extract_pg,
                                              [(path, i) for i in range(total)]):
            if progress:
                progress(pg_no + 1, total)
            cache.set((path, "pg", pg_no, mtime), txt)
    pages = [cache[(path, "pg", i, mtime)] for i in range(total)]
    return "\n\n".join(pages)
# --- CSV/XLSX --- #
def csv_to_arrow(path: str) -> pa.Table:
    return pc.read_csv(path, read_options=pc.ReadOptions(block_size=1 << 24))  # 16 MiB blocks
def excel_to_arrow(path: str) -> pa.Table:
    # openpyxl reads .xlsx; fall back to xlrd only for legacy .xls
    df = pd.read_excel(path, engine="openpyxl" if path.endswith("x") else "xlrd", dtype=str)
    return pa.Table.from_pandas(df.fillna(""))
def table_to_rows(tbl: pa.Table) -> List[List[str]]:
    cols = [col.to_pylist() for col in tbl.columns]
    return [list(r) for r in zip(*cols)]
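# zip(*cols) transposes column-major Arrow data into row-major lists, e.g.
# cols [["a1", "a2"], ["b1", "b2"]] -> rows [["a1", "b1"], ["a2", "b2"]].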
def load_tabular(path: str) -> List[List[str]]:
    key = (path, os.path.getmtime(path))
    if key in cache:
        return cache[key]
    tbl = csv_to_arrow(path) if path.endswith("csv") else excel_to_arrow(path)
    rows = table_to_rows(tbl)
    cache[key] = rows
    return rows
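# Keying on (path, mtime) invalidates the cache automatically: an edited file
# gets re-parsed, while an unchanged one is served from diskcache on re-upload.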
# --- CLEANERS --- #
def strip_tool_noise(txt: str) -> str:
    txt = re.sub(r"\[.*?TOOL.*?\]", "", txt, flags=re.S)  # drop [TOOL ...] blocks
    txt = re.sub(r"\s+", " ", txt).strip()                # collapse whitespace
    return txt
def summarize(findings: List[str]) -> str:
    uniq = list(dict.fromkeys(findings))  # dedupe, preserving order
    if not uniq:
        return "No missed diagnoses identified."
    if len(uniq) == 1:
        return f"Missed diagnosis: {uniq[0]}."
    return ("Missed diagnoses include " +
            ", ".join(uniq[:-1]) +
            f", and {uniq[-1]}. Please review urgently.")
# --- MONITOR --- #
def sys_usage(tag: str = ""):
    cpu = psutil.cpu_percent()
    mem = psutil.virtual_memory()
    log.info("[%s] CPU %.1f%% - RAM %.0f/%.0f GB",
             tag, cpu, mem.used / 1e9, mem.total / 1e9)
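# Emits lines like "[before-load] CPU 12.5% - RAM 3/16 GB" around model load.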
# ─────────────────────────────── AGENT ──────────────────────────────── #
def init_agent() -> TxAgent:
    sys_usage("before-load")
    agent = TxAgent(
        model_name     ="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name ="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        step_rag_num   =4,
        force_finish   =True,
        enable_checker =False,
        seed           =42,
    )
    agent.init_model()
    sys_usage("after-load")
    return agent
AGENT = init_agent()
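# Built at import time: the app blocks on weight download/load before Gradio
# starts serving, so the first boot on a fresh cache is slow by design.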
# ─────────────────────────────── GRADIO ─────────────────────────────── #
prompt_tpl = (
    "Analyze the following excerpt (chunk {idx}/{tot}) and list **only** missed diagnoses "
    "with clinical finding + implication in one sentence each.\n\n{chunk}"
)
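# Rendered example (hypothetical chunk text, abridged):
#   prompt_tpl.format(idx=1, tot=3, chunk="Hb 6.2 g/dL, no transfusion ordered")
#   -> "Analyze the following excerpt (chunk 1/3) and list **only** missed
#       diagnoses ... \n\nHb 6.2 g/dL, no transfusion ordered"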
def analyze(user_msg, chat_hist, files, prog=gr.Progress()):
    chat_hist.append({"role": "user", "content": user_msg})
    yield chat_hist, None, ""
    # --- ingest files --- #
    extracted = ""
    if files:
        for f in files:
            ext = f.name.lower().split(".")[-1]
            if ext == "pdf":
                txt = pdf_to_txt(f.name,
                                 progress=lambda cur, tot: prog(cur / tot, desc=f"PDF {cur}/{tot}"))
                extracted += txt + "\n"
            elif ext in ("csv", "xls", "xlsx"):
                rows = load_tabular(f.name)
                extracted += "\n".join(",".join(r) for r in rows) + "\n"
        chat_hist.append({"role": "assistant", "content": "✅ Files parsed"})
        yield chat_hist, None, ""
    # --- chunk & batch --- #
    max_tokens = 6000  # a character budget per chunk, not true LLM tokens
    chunks = [extracted[i:i + max_tokens] for i in range(0, len(extracted), max_tokens)]
    findings = []
    for i in range(0, len(chunks), 4):  # batches of 4 chunks
        batch = chunks[i:i + 4]
        prompts = [prompt_tpl.format(idx=i + j + 1, tot=len(chunks), chunk=c[:4000])
                   for j, c in enumerate(batch)]
        with torch.inference_mode():
            # positional args assumed to be (message, history, temperature,
            # max_new_tokens, max_token, call_agent, conversation); the last
            # streamed item carries the final answer
            outs = [list(AGENT.run_gradio_chat(p, [], 0.2, 512, 2048, False, []))[-1]
                    for p in prompts]
        for out in outs:
            if out and hasattr(out, "content"):
                clean = strip_tool_noise(out.content)
                if clean and "No missed" not in clean:
                    findings.append(clean)
        prog((i + len(batch)) / len(chunks), desc=f"LLM {i + len(batch)}/{len(chunks)}")
    summary = summarize(findings)
    chat_hist.append({"role": "assistant", "content": summary})
    # persist the full report, named by the first file's content hash
    if files:
        fn_hash = md5(files[0].name)
        p = os.path.join(REPORT_DIR, f"{fn_hash}_report.txt")
        with open(p, "w") as w:
            w.write("\n".join(findings) + "\n\n" + summary)
        yield chat_hist, p, summary
    else:
        yield chat_hist, None, summary
def ui():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("<h1 style='text-align:center'>🩺 Clinical Oversight Assistant</h1>")
        chat   = gr.Chatbot(height=600, label="Detailed analysis", type="messages")
        summ   = gr.Markdown(label="Summary of missed diagnoses")
        files  = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
        txtbox = gr.Textbox(placeholder="Ask about potential oversights…", show_label=False)
        run    = gr.Button("Analyze", variant="primary")
        dl     = gr.File(label="Download full report")
        state  = gr.State([])  # one shared chat history for both triggers
        run.click(analyze, [txtbox, state, files], [chat, dl, summ])
        txtbox.submit(analyze, [txtbox, state, files], [chat, dl, summ])
    return demo
if __name__ == "__main__":
    ui().queue(api_open=False).launch(
        server_name="0.0.0.0", server_port=7860,
        allowed_paths=[REPORT_DIR], show_error=True, share=False)
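# Local smoke test (assumes the HF weights are already cached under PERSIST):
#   $ python app.py
#   then browse to http://localhost:7860 and upload a PDF or CSV.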