Ali2206 committed on
Commit f18c2fd · verified · 1 Parent(s): be8f191

Update app.py

Files changed (1)
  1. app.py +195 -390
app.py CHANGED
@@ -1,404 +1,209 @@
- import sys
- import os
- import pandas as pd
- import pdfplumber
- import json
- import gradio as gr
- from typing import List, Tuple, Optional
  from concurrent.futures import ThreadPoolExecutor, as_completed
- import hashlib
- import shutil
- import re
- import psutil
- import subprocess
- import logging
- import torch
- import gc
- from diskcache import Cache
- import time
  import pyarrow as pa
- import pyarrow.parquet as pq
  import pyarrow.csv as pc
- import numpy as np
-
- # Configure logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- # Persistent directory
- persistent_dir = "/data/hf_cache"
- os.makedirs(persistent_dir, exist_ok=True)
-
- model_cache_dir = os.path.join(persistent_dir, "txagent_models")
- tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
- file_cache_dir = os.path.join(persistent_dir, "cache")
- report_dir = os.path.join(persistent_dir, "reports")
- vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
-
- for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
-     os.makedirs(directory, exist_ok=True)
-
- os.environ["HF_HOME"] = model_cache_dir
- os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
- os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
- os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
- current_dir = os.path.dirname(os.path.abspath(__file__))
- src_path = os.path.abspath(os.path.join(current_dir, "src"))
- sys.path.insert(0, src_path)
-
- from txagent.txagent import TxAgent
-
- # Initialize cache with 10GB limit
- cache = Cache(file_cache_dir, size_limit=10 * 1024**3)
-
- def sanitize_utf8(text: str) -> str:
-     return text.encode("utf-8", "ignore").decode("utf-8")

- def file_hash(path: str) -> str:
      with open(path, "rb") as f:
-         return hashlib.md5(f.read()).hexdigest()
-
- def extract_all_pages(file_path: str, progress_callback=None) -> str:
-     try:
-         with pdfplumber.open(file_path) as pdf:
-             total_pages = len(pdf.pages)
-             if total_pages == 0:
-                 return ""
-
-         batch_size = 10
-         batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
-         text_chunks = [""] * total_pages
-         processed_pages = 0
-
-         def extract_batch(start: int, end: int) -> List[tuple]:
-             results = []
-             with pdfplumber.open(file_path) as pdf:
-                 for page in pdf.pages[start:end]:
-                     page_num = start + pdf.pages.index(page)
-                     page_text = page.extract_text() or ""
-                     results.append((page_num, f"=== Page {page_num + 1} ===\n{page_text.strip()}"))
-             return results
-
-         with ThreadPoolExecutor(max_workers=6) as executor:
-             futures = [executor.submit(extract_batch, start, end) for start, end in batches]
-             for future in as_completed(futures):
-                 for page_num, text in future.result():
-                     text_chunks[page_num] = text
-                 processed_pages += batch_size
-                 if progress_callback:
-                     progress_callback(min(processed_pages, total_pages), total_pages)
-
-         return "\n\n".join(filter(None, text_chunks))
-     except Exception as e:
-         logger.error("PDF processing error: %s", e)
-         return f"PDF processing error: {str(e)}"
-
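Note that the removed extractor numbers pages with `pdf.pages.index(page)`, an O(n) search that can also misnumber identical page objects. A minimal corrected sketch of the same batch idea, using `enumerate` (hypothetical, not part of this commit):

```python
# Sketch only: number pages with enumerate instead of pdf.pages.index().
import pdfplumber

def extract_batch(file_path: str, start: int, end: int) -> list[tuple[int, str]]:
    results = []
    with pdfplumber.open(file_path) as pdf:
        for offset, page in enumerate(pdf.pages[start:end]):
            page_num = start + offset  # exact index, no list search
            text = page.extract_text() or ""
            results.append((page_num, f"=== Page {page_num + 1} ===\n{text.strip()}"))
    return results
```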
- def excel_to_arrow(file_path: str) -> pa.Table:
-     """Convert Excel file to Arrow table for faster processing"""
-     try:
-         # First try with openpyxl (faster for xlsx)
-         try:
-             df = pd.read_excel(file_path, engine='openpyxl', header=None, dtype=str)
-         except Exception:
-             # Fall back to xlrd if needed
-             df = pd.read_excel(file_path, engine='xlrd', header=None, dtype=str)
-
-         # Convert to Arrow table
-         table = pa.Table.from_pandas(df.fillna(""))
-         return table
-     except Exception as e:
-         logger.error(f"Error converting Excel to Arrow: {e}")
-         raise
-
- def csv_to_arrow(file_path: str) -> pa.Table:
-     """Convert CSV file to Arrow table for faster processing"""
-     try:
-         read_options = pc.ReadOptions(
-             encoding='utf-8',
-             invalid_row_handler=lambda x: None,
-             column_names=[str(i) for i in range(1000)]  # Generous column count
-         )
-         convert_options = pc.ConvertOptions(
-             strings_can_be_null=True,
-             quoted_strings_can_be_null=True,
-             include_columns=None
-         )
-         table = pc.read_csv(
-             file_path,
-             read_options=read_options,
-             convert_options=convert_options
-         )
-         return table
-     except Exception as e:
-         logger.error(f"Error converting CSV to Arrow: {e}")
-         raise
-
- def convert_file_to_json(file_path: str, file_type: str, progress_callback=None) -> str:
-     try:
-         file_h = file_hash(file_path)
-         cache_key = f"{file_h}_{file_type}"
-         if cache_key in cache:
-             return cache[cache_key]
-
-         if file_type == "pdf":
-             text = extract_all_pages(file_path, progress_callback)
-             result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
-         elif file_type in ["csv", "xls", "xlsx"]:
-             # Use Arrow for tabular data processing
-             start_time = time.time()
-
-             if file_type == "csv":
-                 table = csv_to_arrow(file_path)
-             else:  # Excel files
-                 table = excel_to_arrow(file_path)
-
-             # Convert to list of lists efficiently
-             content = []
-             for col in table.columns:
-                 content.append([str(x) if x is not None else "" for x in col.to_pylist()])
-
-             # Transpose to get rows
-             rows = list(map(list, zip(*content)))
-
-             logger.info(f"Processed {len(rows)} rows in {time.time()-start_time:.2f}s")
-             result = json.dumps({
-                 "filename": os.path.basename(file_path),
-                 "rows": rows,
-                 "arrow_processed": True  # Flag for optimized processing
-             })
-         else:
-             result = json.dumps({"error": f"Unsupported file type: {file_type}"})
-
-         cache[cache_key] = result
-         return result
-     except Exception as e:
-         logger.error("Error processing %s: %s", os.path.basename(file_path), e)
-         return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
-
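The converter above memoizes each conversion in `diskcache`, keyed by the file's MD5 plus its type, so re-uploading the same file skips parsing entirely. A minimal standalone sketch of that pattern (paths and names hypothetical):

```python
# Minimal sketch of md5-keyed memoization with diskcache (hypothetical names).
import hashlib
from diskcache import Cache

cache = Cache("/tmp/demo_cache", size_limit=10 * 1024**3)

def convert_cached(path: str, kind: str) -> str:
    with open(path, "rb") as f:
        key = f"{hashlib.md5(f.read()).hexdigest()}_{kind}"
    if key in cache:
        return cache[key]              # hit: skip re-parsing the file
    result = f"parsed {kind} payload"  # placeholder for the real conversion
    cache[key] = result
    return result
```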
- def log_system_usage(tag=""):
-     try:
-         cpu = psutil.cpu_percent(interval=1)
-         mem = psutil.virtual_memory()
-         logger.info("[%s] CPU: %.1f%% | RAM: %dMB / %dMB", tag, cpu, mem.used // (1024**2), mem.total // (1024**2))
-         result = subprocess.run(
-             ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
-             capture_output=True, text=True
-         )
-         if result.returncode == 0:
-             used, total, util = result.stdout.strip().split(", ")
-             logger.info("[%s] GPU: %sMB / %sMB | Utilization: %s%%", tag, used, total, util)
-     except Exception as e:
-         logger.error("[%s] GPU/CPU monitor failed: %s", tag, e)
-
- def clean_response(text: str) -> str:
-     text = sanitize_utf8(text)
-     # Remove unwanted patterns and tool call artifacts
-     text = re.sub(r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.", "", text, flags=re.DOTALL)
-     # Extract only missed diagnoses, ignoring other categories
-     diagnoses = []
-     lines = text.splitlines()
-     in_diagnoses_section = False
-     for line in lines:
-         line = line.strip()
-         if not line:
-             continue
-         if re.match(r"###\s*Missed Diagnoses", line):
-             in_diagnoses_section = True
-             continue
-         if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
-             in_diagnoses_section = False
-             continue
-         if in_diagnoses_section and re.match(r"-\s*.+", line):
-             diagnosis = re.sub(r"^\-\s*", "", line).strip()
-             if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
-                 diagnoses.append(diagnosis)
-     # Join diagnoses into a plain text paragraph
-     text = " ".join(diagnoses)
-     # Clean up extra whitespace and punctuation
-     text = re.sub(r"\s+", " ", text).strip()
-     text = re.sub(r"[^\w\s\.\,\(\)\-]", "", text)
-     return text if text else ""
-
- def summarize_findings(combined_response: str) -> str:
-     # Split response by chunk analyses
-     chunks = combined_response.split("--- Analysis for Chunk")
-     diagnoses = []
-     for chunk in chunks:
-         chunk = chunk.strip()
-         if not chunk or "No oversights identified" in chunk:
-             continue
-         # Extract missed diagnoses from chunk
-         lines = chunk.splitlines()
-         in_diagnoses_section = False
-         for line in lines:
-             line = line.strip()
-             if not line:
-                 continue
-             if re.match(r"###\s*Missed Diagnoses", line):
-                 in_diagnoses_section = True
-                 continue
-             if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
-                 in_diagnoses_section = False
-                 continue
-             if in_diagnoses_section and re.match(r"-\s*.+", line):
-                 diagnosis = re.sub(r"^\-\s*", "", line).strip()
-                 if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
-                     diagnoses.append(diagnosis)
-
-     # Remove duplicates while preserving order
-     seen = set()
-     unique_diagnoses = [d for d in diagnoses if not (d in seen or seen.add(d))]
-
-     if not unique_diagnoses:
-         return "No missed diagnoses were identified in the provided records."
-
-     # Combine into a single paragraph
-     summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
-     if len(unique_diagnoses) > 1:
-         summary += f", and {unique_diagnoses[-1]}"
-     elif len(unique_diagnoses) == 1:
-         summary = "Missed diagnoses include " + unique_diagnoses[0]
-     summary += ", all of which require urgent clinical review to prevent potential adverse outcomes."
-
-     return summary.strip()
-
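`summarize_findings` dedupes with the `seen.add` side-effect trick; the replacement code further down uses `dict.fromkeys`, which is equivalent but clearer. A quick illustration with invented inputs:

```python
# Both forms keep first occurrences in order; dict.fromkeys reads cleaner.
items = ["anemia", "hypertension", "anemia"]

seen = set()
v1 = [d for d in items if not (d in seen or seen.add(d))]

v2 = list(dict.fromkeys(items))

assert v1 == v2 == ["anemia", "hypertension"]
```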
- def init_agent():
-     logger.info("Initializing model...")
-     log_system_usage("Before Load")
-     default_tool_path = os.path.abspath("data/new_tool.json")
-     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
-     if not os.path.exists(target_tool_path):
-         shutil.copy(default_tool_path, target_tool_path)
-
      agent = TxAgent(
-         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
-         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-         tool_files_dict={"new_tool": target_tool_path},
-         force_finish=True,
-         enable_checker=False,
-         step_rag_num=4,
-         seed=100,
-         additional_default_tools=[],
      )
      agent.init_model()
-     log_system_usage("After Load")
-     logger.info("Agent Ready")
      return agent

- def create_ui(agent):
      with gr.Blocks(theme=gr.themes.Soft()) as demo:
-         gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
-         chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
-         final_summary = gr.Markdown(label="Summary of Missed Diagnoses")
-         file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
-         msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
-         send_btn = gr.Button("Analyze", variant="primary")
-         download_output = gr.File(label="Download Full Report")
-         progress_bar = gr.Progress()
-
-         prompt_template = """
- Analyze the patient record excerpt for missed diagnoses only. Provide a concise, evidence-based summary as a single paragraph without headings or bullet points. Include specific clinical findings (e.g., 'elevated blood pressure (160/95) on page 10'), their potential implications (e.g., 'may indicate untreated hypertension'), and a recommendation for urgent review. Do not include other oversight categories like medication conflicts. If no missed diagnoses are found, state 'No missed diagnoses identified' in a single sentence.
- Patient Record Excerpt (Chunk {0} of {1}):
- {chunk}
- """
-
-         def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
-             history.append({"role": "user", "content": message})
-             yield history, None, ""
-
-             extracted = ""
-             file_hash_value = ""
-             if files:
-                 def update_extraction_progress(current, total):
-                     progress(current / total, desc=f"Extracting text... Page {current}/{total}")
-                     return history, None, ""
-
-                 with ThreadPoolExecutor(max_workers=6) as executor:
-                     futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower(), update_extraction_progress) for f in files]
-                     results = [sanitize_utf8(f.result()) for f in as_completed(futures)]
-                     extracted = "\n".join(results)
-                 file_hash_value = file_hash(files[0].name) if files else ""
-
-             history.append({"role": "assistant", "content": "✅ Text extraction complete."})
-             yield history, None, ""
-
-             chunk_size = 6000
-             chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
-             combined_response = ""
-             batch_size = 2
-
-             try:
-                 for batch_idx in range(0, len(chunks), batch_size):
-                     batch_chunks = chunks[batch_idx:batch_idx + batch_size]
-                     batch_prompts = [prompt_template.format(i + 1, len(chunks), chunk=chunk[:4000]) for i, chunk in enumerate(batch_chunks)]
-                     batch_responses = []
-
-                     progress((batch_idx + 1) / len(chunks), desc=f"Analyzing chunks {batch_idx + 1}-{min(batch_idx + batch_size, len(chunks))}/{len(chunks)}")
-
-                     with ThreadPoolExecutor(max_workers=len(batch_chunks)) as executor:
-                         futures = [executor.submit(agent.run_gradio_chat, prompt, [], 0.2, 512, 2048, False, []) for prompt in batch_prompts]
-                         for future in as_completed(futures):
-                             chunk_response = ""
-                             for chunk_output in future.result():
-                                 if chunk_output is None:
-                                     continue
-                                 if isinstance(chunk_output, list):
-                                     for m in chunk_output:
-                                         if hasattr(m, 'content') and m.content:
-                                             cleaned = clean_response(m.content)
-                                             if cleaned:
-                                                 chunk_response += cleaned + " "
-                                 elif isinstance(chunk_output, str) and chunk_output.strip():
-                                     cleaned = clean_response(chunk_output)
-                                     if cleaned:
-                                         chunk_response += cleaned + " "
-                             batch_responses.append(chunk_response.strip())
-                             torch.cuda.empty_cache()
-                             gc.collect()
-
-                     for chunk_idx, chunk_response in enumerate(batch_responses, batch_idx + 1):
-                         if chunk_response:
-                             combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
-                         else:
-                             combined_response += f"--- Analysis for Chunk {chunk_idx} ---\nNo missed diagnoses identified.\n"
-                         history[-1] = {"role": "assistant", "content": combined_response.strip()}
-                         yield history, None, ""
-
-                 if combined_response.strip() and not all("No missed diagnoses identified" in chunk for chunk in combined_response.split("--- Analysis for Chunk")):
-                     history[-1]["content"] = combined_response.strip()
-                 else:
-                     history.append({"role": "assistant", "content": "No missed diagnoses identified in the provided records."})
-
-                 summary = summarize_findings(combined_response)
-                 report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
-                 if report_path:
-                     with open(report_path, "w", encoding="utf-8") as f:
-                         f.write(combined_response + "\n\n" + summary)
-                 yield history, report_path if report_path and os.path.exists(report_path) else None, summary
-
-             except Exception as e:
-                 logger.error("Analysis error: %s", e)
-                 history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
-                 yield history, None, f"Error occurred during analysis: {str(e)}"
-
-         send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
-         msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
      return demo

  if __name__ == "__main__":
-     try:
-         logger.info("Launching app...")
-         agent = init_agent()
-         demo = create_ui(agent)
-         demo.queue(api_open=False).launch(
-             server_name="0.0.0.0",
-             server_port=7860,
-             show_error=True,
-             allowed_paths=[report_dir],
-             share=False
-         )
-     finally:
-         if torch.distributed.is_initialized():
-             torch.distributed.destroy_process_group()
 
+ import os, sys, json, re, gc, time, hashlib, logging, shutil, subprocess, multiprocessing as mp
+ from typing import List
  from concurrent.futures import ThreadPoolExecutor, as_completed
+
+ import fitz  # ⇐ PyMuPDF
  import pyarrow as pa
  import pyarrow.csv as pc
+ import pyarrow.dataset as ds
+ import pandas as pd
+ import torch, gradio as gr, psutil, numpy as np
+ from diskcache import Cache

+ # ────────────────────────────── CONSTANTS ────────────────────────────── #
+ PERSIST = "/data/hf_cache"
+ MODEL_CACHE = os.path.join(PERSIST, "txagent_models")
+ TOOL_CACHE = os.path.join(PERSIST, "tool_cache")
+ FILE_CACHE = os.path.join(PERSIST, "preprocessed")
+ REPORT_DIR = os.path.join(PERSIST, "reports")
+ VLLM_CACHEDIR = os.path.join(PERSIST, "vllm_cache")
+
+ for d in (MODEL_CACHE, TOOL_CACHE, FILE_CACHE, REPORT_DIR, VLLM_CACHEDIR):
+     os.makedirs(d, exist_ok=True)
+
+ os.environ.update(
+     HF_HOME               = MODEL_CACHE,
+     TRANSFORMERS_CACHE    = MODEL_CACHE,
+     VLLM_CACHE_DIR        = VLLM_CACHEDIR,
+     TOKENIZERS_PARALLELISM= "false",
+     CUDA_LAUNCH_BLOCKING  = "1",
+ )
+
+ # put local `src/` first
+ ROOT = os.path.dirname(os.path.abspath(__file__))
+ sys.path.insert(0, os.path.join(ROOT, "src"))
+ from txagent.txagent import TxAgent  # noqa: E402
+
+ # ─────────────────────────────── LOGGING ─────────────────────────────── #
+ logging.basicConfig(format="%(asctime)s — %(levelname)s — %(message)s",
+                     level=logging.INFO)
+ log = logging.getLogger("app")
+
+ # ──────────────────────────────── CACHE ──────────────────────────────── #
+ cache = Cache(FILE_CACHE, size_limit=20 * 1024 ** 3)  # 20 GB
+
+ # ─────────────────────────────── HELPERS ─────────────────────────────── #
+ def md5(path: str) -> str:
+     h = hashlib.md5()
      with open(path, "rb") as f:
+         for chunk in iter(lambda: f.read(1 << 20), b""):
+             h.update(chunk)
+     return h.hexdigest()
+
+ # ——— PDF ——— #
+ def _extract_pg(args):
+     path, pg_no = args
+     with fitz.open(path) as doc:
+         page = doc.load_page(pg_no)
+         text = page.get_text("text")
+     return pg_no, f"=== Page {pg_no+1} ===\n{text.strip()}"
+
+ def pdf_to_txt(path: str, progress=None) -> str:
+     doc = fitz.open(path)
+     total = doc.page_count
+     with mp.Pool() as pool:
+         for pg_no, txt in pool.imap_unordered(_extract_pg, [(path, i) for i in range(total)]):
+             if progress: progress(pg_no+1, total)
+             cache.set((path, "pg", pg_no, os.path.getmtime(path)), txt)
+     pages = [cache[(path, "pg", i, os.path.getmtime(path))] for i in range(total)]
+     return "\n\n".join(pages)
+
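`_extract_pg` sits at module level so `mp.Pool` can pickle it, and each worker reopens the PDF by path because fitz document handles are not picklable. A hedged single-process fallback, useful when fork overhead outweighs the gain on small files (sketch only, same PyMuPDF calls as above):

```python
# Sequential fallback sketch; skips the pool and the per-page cache.
import fitz

def pdf_to_txt_serial(path: str) -> str:
    with fitz.open(path) as doc:
        pages = [f"=== Page {i + 1} ===\n{doc.load_page(i).get_text('text').strip()}"
                 for i in range(doc.page_count)]
    return "\n\n".join(pages)
```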
+ # ——— CSV/XLSX ——— #
+ def csv_to_arrow(path: str) -> pa.Table:
+     return pc.read_csv(path, read_options=pc.ReadOptions(block_size=1 << 24))  # 16 MiB blocks
+
+ def excel_to_arrow(path: str) -> pa.Table:
+     # openpyxl handles .xlsx; fall back to xlrd only for legacy .xls
+     df = pd.read_excel(path, engine="openpyxl" if path.endswith("x") else "xlrd", dtype=str)
+     return pa.Table.from_pandas(df.fillna(""))
+
+ def table_to_rows(tbl: pa.Table) -> List[List[str]]:
+     cols = [col.to_pylist() for col in tbl.columns]
+     return [list(r) for r in zip(*cols)]
+
+ def load_tabular(path: str) -> List[List[str]]:
+     key = (path, os.path.getmtime(path))
+     if key in cache:
+         return cache[key]
+     tbl = csv_to_arrow(path) if path.endswith("csv") else excel_to_arrow(path)
+     rows = table_to_rows(tbl)
+     cache[key] = rows
+     return rows
+
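Because the cache key pairs the path with `os.path.getmtime(path)`, editing a file invalidates its cached rows automatically, while re-uploading an unchanged file is a pure cache hit. Hypothetical usage (the path is invented):

```python
# Hypothetical usage of load_tabular; /tmp/labs.csv is a placeholder path.
rows = load_tabular("/tmp/labs.csv")        # first call parses via Arrow, then caches
rows_again = load_tabular("/tmp/labs.csv")  # same mtime -> served from diskcache
print(len(rows), "rows;", rows[0][:3] if rows else "empty")
```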
+ # ——— CLEANERS ——— #
+ def strip_tool_noise(txt: str) -> str:
+     txt = re.sub(r"\[.*?TOOL.*?]", "", txt, flags=re.S)
+     txt = re.sub(r"\s+", " ", txt).strip()
+     return txt
+
+ def summarize(findings: List[str]) -> str:
+     uniq = list(dict.fromkeys(findings))  # preserve order, dedupe
+     if not uniq:
+         return "No missed diagnoses identified."
+     if len(uniq) == 1:
+         return f"Missed diagnosis: {uniq[0]}."
+     return ("Missed diagnoses include " +
+             ", ".join(uniq[:-1]) +
+             f", and {uniq[-1]}. Please review urgently.")
+
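A quick illustration of the three branches of `summarize` (inputs invented):

```python
# Invented inputs exercising each branch of summarize().
assert summarize([]) == "No missed diagnoses identified."
assert summarize(["anemia"]) == "Missed diagnosis: anemia."
print(summarize(["anemia", "hypertension", "anemia"]))
# -> "Missed diagnoses include anemia, and hypertension. Please review urgently."
```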
+ # ——— MONITOR ——— #
+ def sys_usage(tag=""):
+     cpu = psutil.cpu_percent()
+     mem = psutil.virtual_memory()
+     log.info("[%s] CPU %.1f%% — RAM %.0f/%.0f GB",
+              tag, cpu, mem.used/1e9, mem.total/1e9)
+
+ # ─────────────────────────────── AGENT ──────────────────────────────── #
+ def init_agent() -> TxAgent:
+     sys_usage("before-load")
      agent = TxAgent(
+         model_name     = "mims-harvard/TxAgent-T1-Llama-3.1-8B",
+         rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
+         step_rag_num   = 4,
+         force_finish   = True,
+         enable_checker = False,
+         seed           = 42
      )
      agent.init_model()
+     sys_usage("after-load")
      return agent

+ AGENT = init_agent()
+
+ # ─────────────────────────────── GRADIO ─────────────────────────────── #
+ prompt_tpl = (
+     "Analyze the following excerpt (chunk {idx}/{tot}) and list **only** missed diagnoses "
+     "with clinical finding + implication in one sentence each.\n\n{chunk}"
+ )
+
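Unlike the old template's mixed positional/keyword `format` call, `prompt_tpl` uses named slots only; rendering the first of three chunks would look like this (chunk text invented):

```python
# Invented chunk text; shows the named-slot rendering of prompt_tpl.
print(prompt_tpl.format(idx=1, tot=3, chunk="BP 160/95 recorded on page 10 ..."))
```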
+ def analyze(user_msg, chat_hist, files, prog=gr.Progress()):
+     chat_hist.append({"role": "user", "content": user_msg})
+     yield chat_hist, None, ""
+
+     # ——— ingest files ——— #
+     extracted = ""
+     if files:
+         for f in files:
+             ext = f.name.lower().split(".")[-1]
+             if ext == "pdf":
+                 txt = pdf_to_txt(f.name,
+                                  progress=lambda cur, tot: prog(cur/tot, desc=f"PDF {cur}/{tot}"))
+                 extracted += txt + "\n"
+             elif ext in ("csv", "xls", "xlsx"):
+                 rows = load_tabular(f.name)
+                 extracted += "\n".join(",".join(r) for r in rows) + "\n"
+     chat_hist.append({"role": "assistant", "content": "✅ Files parsed"})
+     yield chat_hist, None, ""
+
+     # ——— chunk & batch ——— #
+     max_tokens = 6000
+     chunks = [extracted[i:i+max_tokens] for i in range(0, len(extracted), max_tokens)]
+     findings = []
+     for i in range(0, len(chunks), 4):  # batch of 4
+         batch = chunks[i:i+4]
+         prompts = [prompt_tpl.format(idx=i+j+1, tot=len(chunks), chunk=c[:4000])
+                    for j, c in enumerate(batch)]
+
+         with torch.inference_mode():
+             outs = [list(AGENT.run_gradio_chat(p, [], 0.2, 512, 2048, False, []))[-1]
+                     for p in prompts]
+
+         for out in outs:
+             if out and hasattr(out, "content"):
+                 clean = strip_tool_noise(out.content)
+                 if clean and "No missed" not in clean:
+                     findings.append(clean)
+
+         prog((i+len(batch))/len(chunks), desc=f"LLM {i+len(batch)}/{len(chunks)}")
+
+     summary = summarize(findings)
+     chat_hist.append({"role": "assistant", "content": summary})
+
+     # save full report
+     if files:
+         fn_hash = md5(files[0].name)
+         p = os.path.join(REPORT_DIR, f"{fn_hash}_report.txt")
+         with open(p, "w") as w:
+             w.write("\n".join(findings) + "\n\n" + summary)
+         yield chat_hist, p, summary
+     else:
+         yield chat_hist, None, summary
+
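`analyze` is a generator, so Gradio streams each `yield` as an interim update to the three bound outputs (chat, download file, summary). A minimal sketch of the same pattern (toy handler, not part of this app):

```python
# Toy generator handler: Gradio re-renders the output on every yield.
import time
import gradio as gr

def slow_echo(msg, hist):
    hist = hist + [{"role": "user", "content": msg}]
    yield hist                      # show the user turn immediately
    time.sleep(0.5)                 # stand-in for model latency
    hist = hist + [{"role": "assistant", "content": f"echo: {msg}"}]
    yield hist                      # final state

with gr.Blocks() as toy:
    chat = gr.Chatbot(type="messages")
    box = gr.Textbox()
    box.submit(slow_echo, [box, chat], chat)
```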
+ def ui():
      with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown("<h1 style='text-align:center'>🩺 Clinical Oversight Assistant</h1>")
+         chat   = gr.Chatbot(height=600, label="Detailed analysis", type="messages")
+         summ   = gr.Markdown(label="Summary of missed diagnoses")
+         files  = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
+         txtbox = gr.Textbox(placeholder="Ask about potential oversights…", show_label=False)
+         run    = gr.Button("Analyze", variant="primary")
+         dl     = gr.File(label="Download full report")
+
+         run.click(analyze, [txtbox, gr.State([]), files], [chat, dl, summ])
+         txtbox.submit(analyze, [txtbox, gr.State([]), files], [chat, dl, summ])
      return demo

  if __name__ == "__main__":
+     ui().queue(api_open=False).launch(
+         server_name="0.0.0.0", server_port=7860,
+         allowed_paths=[REPORT_DIR], show_error=True, share=False)