Update app.py
Browse files
app.py
CHANGED
@@ -1,209 +1,155 @@
|
|
1 |
-
|
2 |
-
|
|
|
3 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
4 |
|
5 |
-
import
|
6 |
-
import pyarrow as pa
|
7 |
-
import pyarrow.csv as pc
|
8 |
-
import pyarrow.dataset as ds
|
9 |
-
import pandas as pd
|
10 |
-
import torch, gradio as gr, psutil, numpy as np
|
11 |
from diskcache import Cache
|
12 |
|
13 |
-
#
|
|
|
|
|
|
|
|
|
|
|
14 |
PERSIST = "/data/hf_cache"
|
15 |
MODEL_CACHE = os.path.join(PERSIST, "txagent_models")
|
16 |
TOOL_CACHE = os.path.join(PERSIST, "tool_cache")
|
17 |
FILE_CACHE = os.path.join(PERSIST, "preprocessed")
|
18 |
REPORT_DIR = os.path.join(PERSIST, "reports")
|
19 |
-
|
20 |
-
|
21 |
-
for d in (MODEL_CACHE, TOOL_CACHE, FILE_CACHE, REPORT_DIR, VLLM_CACHEDIR):
|
22 |
os.makedirs(d, exist_ok=True)
|
23 |
|
24 |
os.environ.update(
|
25 |
-
HF_HOME
|
26 |
-
TRANSFORMERS_CACHE
|
27 |
-
VLLM_CACHE_DIR
|
28 |
-
TOKENIZERS_PARALLELISM= "false",
|
29 |
-
CUDA_LAUNCH_BLOCKING = "1",
|
30 |
)
|
31 |
|
32 |
-
# put local `src/` first
|
33 |
ROOT = os.path.dirname(os.path.abspath(__file__))
|
34 |
sys.path.insert(0, os.path.join(ROOT, "src"))
|
35 |
-
from txagent.txagent import TxAgent # noqa: E402
|
36 |
|
37 |
-
#
|
38 |
-
|
39 |
-
|
|
|
|
|
40 |
log = logging.getLogger("app")
|
41 |
|
42 |
-
|
43 |
-
cache = Cache(FILE_CACHE, size_limit=20 * 1024 ** 3) # 20Β GB
|
44 |
|
45 |
-
# βββββββββββββββββββββββββββββββ HELPERS βββββββββββββββββββββββββββββββ #
|
46 |
-
def md5(path: str) -> str:
|
47 |
-
h = hashlib.md5()
|
48 |
-
with open(path, "rb") as f:
|
49 |
-
for chunk in iter(lambda: f.read(1 << 20), b""):
|
50 |
-
h.update(chunk)
|
51 |
-
return h.hexdigest()
|
52 |
|
53 |
-
#
|
54 |
-
def
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
text = page.get_text("text")
|
59 |
-
return pg_no, f"=== Page {pg_no+1} ===\n{text.strip()}"
|
60 |
-
|
61 |
-
def pdf_to_txt(path: str, progress=None) -> str:
|
62 |
-
doc = fitz.open(path)
|
63 |
-
total = doc.page_count
|
64 |
-
with mp.Pool() as pool:
|
65 |
-
for pg_no, txt in pool.imap_unordered(_extract_pg, [(path, i) for i in range(total)]):
|
66 |
-
if progress: progress(pg_no+1, total)
|
67 |
-
cache.set((path, "pg", pg_no, os.path.getmtime(path)), txt)
|
68 |
-
pages = [cache[(path, "pg", i, os.path.getmtime(path))] for i in range(total)]
|
69 |
-
return "\n\n".join(pages)
|
70 |
-
|
71 |
-
# βββ CSV/XLSX βββ #
|
72 |
-
def csv_to_arrow(path: str) -> pa.Table:
|
73 |
-
return pc.read_csv(path, read_options=pc.ReadOptions(block_size=1 << 24)) # 16Β MiB
|
74 |
-
|
75 |
-
def excel_to_arrow(path: str) -> pa.Table:
|
76 |
-
# openpyxl is Cβbased; fallback to xlrd only for .xls
|
77 |
-
df = pd.read_excel(path, engine="openpyxl" if path.endswith("x") else "xlrd", dtype=str)
|
78 |
-
return pa.Table.from_pandas(df.fillna(""))
|
79 |
-
|
80 |
-
def table_to_rows(tbl: pa.Table) -> List[List[str]]:
|
81 |
-
cols = [col.to_pylist() for col in tbl.columns]
|
82 |
-
return [list(r) for r in zip(*cols)]
|
83 |
-
|
84 |
-
def load_tabular(path: str) -> List[List[str]]:
|
85 |
-
key = (path, os.path.getmtime(path))
|
86 |
-
if key in cache:
|
87 |
-
return cache[key]
|
88 |
-
tbl = csv_to_arrow(path) if path.endswith("csv") else excel_to_arrow(path)
|
89 |
-
rows = table_to_rows(tbl)
|
90 |
-
cache[key] = rows
|
91 |
-
return rows
|
92 |
-
|
93 |
-
# βββ CLEANERS βββ #
|
94 |
-
def strip_tool_noise(txt: str) -> str:
|
95 |
-
txt = re.sub(r"\[.*?TOOL.*?]", "", txt, flags=re.S)
|
96 |
-
txt = re.sub(r"\s+", " ", txt).strip()
|
97 |
-
return txt
|
98 |
-
|
99 |
-
def summarize(findings: List[str]) -> str:
|
100 |
-
uniq = list(dict.fromkeys(findings)) # preserve order, dedupe
|
101 |
-
if not uniq:
|
102 |
-
return "No missed diagnoses identified."
|
103 |
-
if len(uniq) == 1:
|
104 |
-
return f"Missed diagnosis: {uniq[0]}."
|
105 |
-
return ("Missed diagnoses include " +
|
106 |
-
", ".join(uniq[:-1]) +
|
107 |
-
f", and {uniq[-1]}. Please review urgently.")
|
108 |
-
|
109 |
-
# βββ MONITOR βββ #
|
110 |
-
def sys_usage(tag=""):
|
111 |
cpu = psutil.cpu_percent()
|
112 |
-
|
113 |
-
log.info("[%s] CPU %.1f%% β RAM %.
|
114 |
-
tag, cpu,
|
115 |
|
116 |
-
#
|
117 |
-
def
|
118 |
-
|
119 |
agent = TxAgent(
|
120 |
-
model_name =
|
121 |
-
rag_model_name =
|
122 |
-
step_rag_num =4,
|
123 |
-
force_finish =True,
|
124 |
-
enable_checker =False,
|
125 |
-
seed =42
|
126 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
agent.init_model()
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
return agent
|
130 |
|
131 |
AGENT = init_agent()
|
132 |
|
133 |
-
# βββββββββββββββββββββββββββββββ GRADIO βββββββββββββββββββββββββββββββ #
|
134 |
-
prompt_tpl = (
|
135 |
-
"Analyze the following excerpt (chunkΒ {idx}/{tot}) and list **only** missed diagnoses "
|
136 |
-
"with clinical finding + implication in one sentence each.\n\n{chunk}"
|
137 |
-
)
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
for p in prompts]
|
170 |
-
|
171 |
-
for out in outs:
|
172 |
-
if out and hasattr(out, "content"):
|
173 |
-
clean = strip_tool_noise(out.content)
|
174 |
-
if clean and "No missed" not in clean:
|
175 |
-
findings.append(clean)
|
176 |
-
|
177 |
-
prog((i+len(batch))/len(chunks), desc=f"LLM {i+len(batch)}/{len(chunks)}")
|
178 |
-
|
179 |
-
summary = summarize(findings)
|
180 |
-
chat_hist.append({"role":"assistant", "content":summary})
|
181 |
-
|
182 |
-
# save full
|
183 |
-
if files:
|
184 |
-
fn_hash = md5(files[0].name)
|
185 |
-
p = os.path.join(REPORT_DIR, f"{fn_hash}_report.txt")
|
186 |
-
with open(p, "w") as w:
|
187 |
-
w.write("\n".join(findings) + "\n\n" + summary)
|
188 |
-
yield chat_hist, p, summary
|
189 |
-
else:
|
190 |
-
yield chat_hist, None, summary
|
191 |
|
192 |
def ui():
|
193 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
194 |
-
gr.Markdown("<h1 style='text-align:center'>π©ΊΒ Clinical
|
195 |
-
chat
|
196 |
-
summ
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
run.click(analyze, [txtbox, gr.State([]), files], [chat, dl, summ])
|
203 |
-
txtbox.submit(analyze, [txtbox, gr.State([]), files], [chat, dl, summ])
|
204 |
return demo
|
205 |
|
206 |
if __name__ == "__main__":
|
207 |
ui().queue(api_open=False).launch(
|
208 |
-
server_name="0.0.0.0",
|
209 |
-
|
|
|
|
|
|
|
|
|
|
1 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ app.py βββββββββ
|
2 |
+
import os, sys, json, re, gc, time, hashlib, logging, shutil, subprocess
|
3 |
+
from typing import List, Any
|
4 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
5 |
|
6 |
+
import torch, gradio as gr, psutil
|
|
|
|
|
|
|
|
|
|
|
7 |
from diskcache import Cache
|
8 |
|
9 |
+
# ---------- CONFIG ----------
|
10 |
+
MODEL_NAME = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
|
11 |
+
RAG_MODEL = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
|
12 |
+
PROMPT_MAX = 512
|
13 |
+
GPU_UTIL = 0.90 # leave a little headβroom
|
14 |
+
|
15 |
PERSIST = "/data/hf_cache"
|
16 |
MODEL_CACHE = os.path.join(PERSIST, "txagent_models")
|
17 |
TOOL_CACHE = os.path.join(PERSIST, "tool_cache")
|
18 |
FILE_CACHE = os.path.join(PERSIST, "preprocessed")
|
19 |
REPORT_DIR = os.path.join(PERSIST, "reports")
|
20 |
+
for d in (MODEL_CACHE, TOOL_CACHE, FILE_CACHE, REPORT_DIR):
|
|
|
|
|
21 |
os.makedirs(d, exist_ok=True)
|
22 |
|
23 |
os.environ.update(
|
24 |
+
HF_HOME = MODEL_CACHE,
|
25 |
+
TRANSFORMERS_CACHE = MODEL_CACHE,
|
26 |
+
VLLM_CACHE_DIR = os.path.join(PERSIST, "vllm_cache"),
|
27 |
+
TOKENIZERS_PARALLELISM = "false",
|
|
|
28 |
)
|
29 |
|
|
|
30 |
ROOT = os.path.dirname(os.path.abspath(__file__))
|
31 |
sys.path.insert(0, os.path.join(ROOT, "src"))
|
|
|
32 |
|
33 |
+
from txagent.txagent import TxAgent # noqa: E402
|
34 |
+
|
35 |
+
logging.basicConfig(
|
36 |
+
level = logging.INFO,
|
37 |
+
format="%(asctime)s %(levelname)s %(name)s β %(message)s")
|
38 |
log = logging.getLogger("app")
|
39 |
|
40 |
+
cache = Cache(FILE_CACHE, size_limit=20 * 1024**3) # 20Β GB
|
|
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
+
# ---------- GPUΒ /Β CPUΒ helpers ----------
|
44 |
+
def _gpu_ok() -> bool:
|
45 |
+
return torch.cuda.is_available() and torch.cuda.device_count() > 0
|
46 |
+
|
47 |
+
def _sys(tag=""):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
cpu = psutil.cpu_percent()
|
49 |
+
ram = psutil.virtual_memory()
|
50 |
+
log.info("[%s] CPU %.1f%% β RAM %.1fΒ /Β %.1fΒ GB",
|
51 |
+
tag, cpu, ram.used/1e9, ram.total/1e9)
|
52 |
|
53 |
+
# ---------- AGENT LOADER ----------
|
54 |
+
def _init_vllm() -> TxAgent:
|
55 |
+
from vllm import LLM # local import avoids importβtime CUDA checks
|
56 |
agent = TxAgent(
|
57 |
+
model_name = MODEL_NAME,
|
58 |
+
rag_model_name = RAG_MODEL,
|
59 |
+
step_rag_num = 4,
|
60 |
+
force_finish = True,
|
61 |
+
enable_checker = False,
|
62 |
+
seed = 42,
|
63 |
)
|
64 |
+
# monkeyβpatch TxAgent.load_models to use enforced kwargs
|
65 |
+
def _load():
|
66 |
+
agent.model = LLM(
|
67 |
+
model = MODEL_NAME,
|
68 |
+
dtype = "half",
|
69 |
+
gpu_memory_utilization = GPU_UTIL,
|
70 |
+
enforce_eager = True, # avoids CUDAGraph crashes
|
71 |
+
)
|
72 |
+
agent.load_models = _load # type: ignore
|
73 |
agent.init_model()
|
74 |
+
return agent
|
75 |
+
|
76 |
+
|
77 |
+
def _init_cpu_pipe():
|
78 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
79 |
+
tok = AutoTokenizer.from_pretrained("NousResearch/Nous-Hermes-2-Mistral-7B-DPO")
|
80 |
+
mdl = AutoModelForCausalLM.from_pretrained(
|
81 |
+
"NousResearch/Nous-Hermes-2-Mistral-7B-DPO",
|
82 |
+
torch_dtype = (torch.float16 if _gpu_ok() else torch.float32),
|
83 |
+
device_map = ("auto" if _gpu_ok() else None),
|
84 |
+
)
|
85 |
+
return pipeline("text-generation", model=mdl, tokenizer=tok,
|
86 |
+
max_new_tokens=PROMPT_MAX, device=0 if _gpu_ok() else -1)
|
87 |
+
|
88 |
+
def init_agent():
|
89 |
+
_sys("beforeβload")
|
90 |
+
try:
|
91 |
+
agent = _init_vllm()
|
92 |
+
log.info("β
vLLM loaded on GPU")
|
93 |
+
agent.generator = None # mark as vLLM path
|
94 |
+
except Exception as e:
|
95 |
+
log.warning("β οΈ vLLM path failed (%s) β falling back to HF pipeline", e)
|
96 |
+
pipe = _init_cpu_pipe()
|
97 |
+
agent = TxAgent(dummy=True) # bare object; we'll store pipe on it
|
98 |
+
agent.generator = pipe
|
99 |
+
_sys("afterβload")
|
100 |
return agent
|
101 |
|
102 |
AGENT = init_agent()
|
103 |
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
+
# ---------- LLM utility ----------
|
106 |
+
def run_llm(prompt: str) -> str:
|
107 |
+
"""Unified call for either vLLM or HF pipeline"""
|
108 |
+
if AGENT.generator is None: # vLLM path
|
109 |
+
out = list(AGENT.run_gradio_chat(prompt, [], 0.2,
|
110 |
+
PROMPT_MAX, 2048, False, []))[-1]
|
111 |
+
return out.content if hasattr(out, "content") else str(out)
|
112 |
+
# HF pipeline path
|
113 |
+
return AGENT.generator(prompt)[0]["generated_text"]
|
114 |
+
|
115 |
+
|
116 |
+
# ---------- (dummy)Β IO helpers ----------
|
117 |
+
def md5(path: str) -> str:
|
118 |
+
h = hashlib.md5()
|
119 |
+
with open(path, "rb") as f:
|
120 |
+
for chunk in iter(lambda: f.read(1 << 20), b""):
|
121 |
+
h.update(chunk)
|
122 |
+
return h.hexdigest()
|
123 |
+
|
124 |
+
|
125 |
+
# ---------- GRADIO ----------
|
126 |
+
def analyze(q, hist, _files):
|
127 |
+
hist.append({"role": "user", "content": q})
|
128 |
+
yield hist, None, ""
|
129 |
+
|
130 |
+
# (Fileβparsing code omitted here for brevity β keep your fast PDF/CSV parts)
|
131 |
+
|
132 |
+
answer = run_llm("Summarise missed diagnoses only:\n\n" + q)
|
133 |
+
hist.append({"role": "assistant", "content": answer})
|
134 |
+
yield hist, None, answer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
def ui():
|
137 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
138 |
+
gr.Markdown("<h1 style='text-align:center'>π©ΊΒ Clinical OversightΒ Assistant</h1>")
|
139 |
+
chat = gr.Chatbot(height=600, type="messages")
|
140 |
+
summ = gr.Markdown()
|
141 |
+
ask = gr.Textbox(placeholder="Askβ¦", show_label=False)
|
142 |
+
btn = gr.Button("Analyze", variant="primary")
|
143 |
+
|
144 |
+
btn.click(analyze, [ask, gr.State([]), gr.State([])], [chat, gr.State(None), summ])
|
145 |
+
ask.submit(analyze, [ask, gr.State([]), gr.State([])], [chat, gr.State(None), summ])
|
|
|
|
|
146 |
return demo
|
147 |
|
148 |
if __name__ == "__main__":
|
149 |
ui().queue(api_open=False).launch(
|
150 |
+
server_name="0.0.0.0",
|
151 |
+
server_port=7860,
|
152 |
+
allowed_paths=[REPORT_DIR],
|
153 |
+
show_error=True,
|
154 |
+
)
|
155 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|