Update app.py
app.py
CHANGED
@@ -1,155 +1,424 @@
-import os
-PROMPT_MAX = 512
-GPU_UTIL = 0.90  # leave a little headroom
-
-PERSIST = "/data/hf_cache"
-MODEL_CACHE = os.path.join(PERSIST, "txagent_models")
-TOOL_CACHE = os.path.join(PERSIST, "tool_cache")
-FILE_CACHE = os.path.join(PERSIST, "preprocessed")
-REPORT_DIR = os.path.join(PERSIST, "reports")
-for d in (MODEL_CACHE, TOOL_CACHE, FILE_CACHE, REPORT_DIR):
-    os.makedirs(d, exist_ok=True)
-
-os.environ.update(
-    HF_HOME = MODEL_CACHE,
-    TRANSFORMERS_CACHE = MODEL_CACHE,
-    VLLM_CACHE_DIR = os.path.join(PERSIST, "vllm_cache"),
-    TOKENIZERS_PARALLELISM = "false",
-)
-
-ROOT = os.path.dirname(os.path.abspath(__file__))
-sys.path.insert(0, os.path.join(ROOT, "src"))
-
-from txagent.txagent import TxAgent  # noqa: E402
-
-logging.basicConfig(
-    level = logging.INFO,
-    format="%(asctime)s %(levelname)s %(name)s - %(message)s")
-log = logging.getLogger("app")
-
-cache = Cache(FILE_CACHE, size_limit=20 * 1024**3)  # 20 GB
-
-# ---------- GPU / CPU helpers ----------
-def _gpu_ok() -> bool:
-    return torch.cuda.is_available() and torch.cuda.device_count() > 0
-
-def _sys(tag=""):
-    cpu = psutil.cpu_percent()
-    ram = psutil.virtual_memory()
-    log.info("[%s] CPU %.1f%% - RAM %.1f / %.1f GB",
-             tag, cpu, ram.used/1e9, ram.total/1e9)
-
-# ---------- AGENT LOADER ----------
-def _init_vllm() -> TxAgent:
-    from vllm import LLM  # local import avoids import-time CUDA checks
-    agent = TxAgent(
-        model_name = MODEL_NAME,
-        rag_model_name = RAG_MODEL,
-        step_rag_num = 4,
-        force_finish = True,
-        enable_checker = False,
-        seed = 42,
-    )
-    # monkey-patch TxAgent.load_models to use enforced kwargs
-    def _load():
-        agent.model = LLM(
-            model = MODEL_NAME,
-            dtype = "half",
-            gpu_memory_utilization = GPU_UTIL,
-            enforce_eager = True,  # avoids CUDAGraph crashes
-        )
-    agent.load_models = _load  # type: ignore
-    agent.init_model()
-    return agent
-
-        torch_dtype = (torch.float16 if _gpu_ok() else torch.float32),
-        device_map = ("auto" if _gpu_ok() else None),
-    )
-    return pipeline("text-generation", model=mdl, tokenizer=tok,
-                    max_new_tokens=PROMPT_MAX, device=0 if _gpu_ok() else -1)
-
-    agent = TxAgent(dummy=True)  # bare object; we'll store pipe on it
-    agent.generator = pipe
-    _sys("after-load")
-    return agent
-
-            h.update(chunk)
-    return h.hexdigest()
-
-        gr.Markdown("<h1 style='text-align:center'>🩺 Clinical Oversight Assistant</h1>")
+import sys
+import os
+import pandas as pd
+import pdfplumber
+import json
+import gradio as gr
+from typing import List, Tuple, Optional, Generator
 from concurrent.futures import ThreadPoolExecutor, as_completed
+import hashlib
+import shutil
+import re
+import psutil
+import subprocess
+import logging
+import torch
+import gc
 from diskcache import Cache
+import time
+import pyarrow as pa
+import pyarrow.parquet as pq
+import pyarrow.csv as pc
+import numpy as np
+from functools import partial
+from itertools import islice
+import io

+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

+# Persistent directory
+persistent_dir = "/data/hf_cache"
+os.makedirs(persistent_dir, exist_ok=True)

+model_cache_dir = os.path.join(persistent_dir, "txagent_models")
+tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
+file_cache_dir = os.path.join(persistent_dir, "cache")
+report_dir = os.path.join(persistent_dir, "reports")
+vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")

+for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
+    os.makedirs(directory, exist_ok=True)
+
+os.environ["HF_HOME"] = model_cache_dir
+os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
+os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+src_path = os.path.abspath(os.path.join(current_dir, "src"))
+sys.path.insert(0, src_path)
+
+from txagent.txagent import TxAgent
+
+# Initialize cache with 10GB limit
+cache = Cache(file_cache_dir, size_limit=10 * 1024**3)
+
+def sanitize_utf8(text: str) -> str:
+    return text.encode("utf-8", "ignore").decode("utf-8")
+
+def file_hash(path: str) -> str:
+    with open(path, "rb") as f:
+        return hashlib.md5(f.read()).hexdigest()
+
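+# PDF pages are extracted in parallel, 10-page batches at a time; results are written
+# back into text_chunks by absolute page index so page order is preserved.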
+def extract_all_pages(file_path: str, progress_callback=None) -> str:
+    try:
+        with pdfplumber.open(file_path) as pdf:
+            total_pages = len(pdf.pages)
+            if total_pages == 0:
+                return ""
+
+        batch_size = 10
+        batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
+        text_chunks = [""] * total_pages
+        processed_pages = 0
+
+        def extract_batch(start: int, end: int) -> List[tuple]:
+            results = []
+            with pdfplumber.open(file_path) as pdf:
+                # enumerate from `start` so page numbers stay absolute within the document
+                for page_num, page in enumerate(pdf.pages[start:end], start=start):
+                    page_text = page.extract_text() or ""
+                    results.append((page_num, f"=== Page {page_num + 1} ===\n{page_text.strip()}"))
+            return results
+
+        with ThreadPoolExecutor(max_workers=6) as executor:
+            futures = [executor.submit(extract_batch, start, end) for start, end in batches]
+            for future in as_completed(futures):
+                for page_num, text in future.result():
+                    text_chunks[page_num] = text
+                processed_pages += batch_size
+                if progress_callback:
+                    progress_callback(min(processed_pages, total_pages), total_pages)
+
+        return "\n\n".join(filter(None, text_chunks))
+    except Exception as e:
+        logger.error("PDF processing error: %s", e)
+        return f"PDF processing error: {str(e)}"
+
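+# The tabular helpers below emit NDJSON: one JSON object per spreadsheet/CSV row,
+# each terminated by a newline, so rows can be consumed as they are produced.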
+def excel_to_ndjson(file_path: str) -> Generator[str, None, None]:
+    """Stream Excel file as NDJSON for maximum performance"""
     try:
+        # Read each sheet via the openpyxl engine and emit one NDJSON line per row
+        with pd.ExcelFile(file_path, engine='openpyxl') as xls:
+            for sheet_name in xls.sheet_names:
+                sheet_df = pd.read_excel(xls, sheet_name=sheet_name, header=None, dtype=str)
+                for _, row in sheet_df.iterrows():
+                    yield json.dumps({
+                        "sheet": sheet_name,
+                        "row": row.fillna("").astype(str).tolist()
+                    }) + "\n"
     except Exception as e:
+        logger.error(f"Error streaming Excel: {e}")
+        raise

+def csv_to_ndjson(file_path: str) -> Generator[str, None, None]:
+    """Stream CSV file as NDJSON for maximum performance"""
+    try:
+        for chunk in pd.read_csv(
+            file_path,
+            header=None,
+            dtype=str,
+            chunksize=1000,
+            encoding_errors='replace',
+            on_bad_lines='skip'
+        ):
+            for _, row in chunk.iterrows():
+                yield json.dumps({
+                    "row": row.fillna("").astype(str).tolist()
+                }) + "\n"
+    except Exception as e:
+        logger.error(f"Error streaming CSV: {e}")
+        raise
+
+def stream_file_to_json(file_path: str, file_type: str) -> Generator[str, None, None]:
+    """Stream file content as JSON chunks"""
+    try:
+        if file_type == "pdf":
+            text = extract_all_pages(file_path)
+            yield json.dumps({
+                "filename": os.path.basename(file_path),
+                "content": text,
+                "status": "initial"
+            })
+        elif file_type in ["csv", "xls", "xlsx"]:
+            # Stream the file content
+            yield json.dumps({
+                "filename": os.path.basename(file_path),
+                "streaming": True,
+                "type": file_type
+            })
+
+            if file_type == "csv":
+                stream_gen = csv_to_ndjson(file_path)
+            else:
+                stream_gen = excel_to_ndjson(file_path)
+
+            for chunk in stream_gen:
+                yield chunk
+        else:
+            yield json.dumps({"error": f"Unsupported file type: {file_type}"})
+    except Exception as e:
+        logger.error("Error processing %s: %s", os.path.basename(file_path), e)
+        yield json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})

+def log_system_usage(tag=""):
+    try:
+        cpu = psutil.cpu_percent(interval=1)
+        mem = psutil.virtual_memory()
+        logger.info("[%s] CPU: %.1f%% | RAM: %dMB / %dMB", tag, cpu, mem.used // (1024**2), mem.total // (1024**2))
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
+            capture_output=True, text=True
+        )
+        if result.returncode == 0:
+            used, total, util = result.stdout.strip().split(", ")
+            logger.info("[%s] GPU: %sMB / %sMB | Utilization: %s%%", tag, used, total, util)
+    except Exception as e:
+        logger.error("[%s] GPU/CPU monitor failed: %s", tag, e)

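+# clean_response() strips tool-call noise from the model output and keeps only the
+# bullet items found under a "### Missed Diagnoses" heading.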
+def clean_response(text: str) -> str:
+    text = sanitize_utf8(text)
+    text = re.sub(r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.", "", text, flags=re.DOTALL)
+    diagnoses = []
+    lines = text.splitlines()
+    in_diagnoses_section = False
+    for line in lines:
+        line = line.strip()
+        if not line:
+            continue
+        if re.match(r"###\s*Missed Diagnoses", line):
+            in_diagnoses_section = True
+            continue
+        if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
+            in_diagnoses_section = False
+            continue
+        if in_diagnoses_section and re.match(r"-\s*.+", line):
+            diagnosis = re.sub(r"^\-\s*", "", line).strip()
+            if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
+                diagnoses.append(diagnosis)
+    text = " ".join(diagnoses)
+    text = re.sub(r"\s+", " ", text).strip()
+    text = re.sub(r"[^\w\s\.\,\(\)\-]", "", text)
+    return text if text else ""

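+# summarize_findings() walks the per-chunk analyses, de-duplicates the extracted
+# diagnoses, and folds them into a single summary sentence.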
+def summarize_findings(combined_response: str) -> str:
+    chunks = combined_response.split("--- Analysis for Chunk")
+    diagnoses = []
+    for chunk in chunks:
+        chunk = chunk.strip()
+        if not chunk or "No oversights identified" in chunk:
+            continue
+        lines = chunk.splitlines()
+        in_diagnoses_section = False
+        for line in lines:
+            line = line.strip()
+            if not line:
+                continue
+            if re.match(r"###\s*Missed Diagnoses", line):
+                in_diagnoses_section = True
+                continue
+            if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
+                in_diagnoses_section = False
+                continue
+            if in_diagnoses_section and re.match(r"-\s*.+", line):
+                diagnosis = re.sub(r"^\-\s*", "", line).strip()
+                if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
+                    diagnoses.append(diagnosis)

+    seen = set()
+    unique_diagnoses = [d for d in diagnoses if not (d in seen or seen.add(d))]
+
+    if not unique_diagnoses:
+        return "No missed diagnoses were identified in the provided records."

+    summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
+    if len(unique_diagnoses) > 1:
+        summary += f", and {unique_diagnoses[-1]}"
+    elif len(unique_diagnoses) == 1:
+        summary = "Missed diagnoses include " + unique_diagnoses[0]
+    summary += ", all of which require urgent clinical review to prevent potential adverse outcomes."
+
+    return summary.strip()

+def init_agent():
+    logger.info("Initializing model...")
+    log_system_usage("Before Load")
+    default_tool_path = os.path.abspath("data/new_tool.json")
+    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
+    if not os.path.exists(target_tool_path):
+        shutil.copy(default_tool_path, target_tool_path)

+    agent = TxAgent(
+        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
+        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
+        tool_files_dict={"new_tool": target_tool_path},
+        force_finish=True,
+        enable_checker=False,
+        step_rag_num=4,
+        seed=100,
+        additional_default_tools=[],
+    )
+    agent.init_model()
+    log_system_usage("After Load")
+    logger.info("Agent Ready")
+    return agent

+def batched(iterable, n):
+    """Batch data into lists of length n. The last batch may be shorter."""
+    it = iter(iterable)
+    while True:
+        batch = list(islice(it, n))
+        if not batch:
+            return
+        yield batch

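+# batched() example: list(batched([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]]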
+def create_ui(agent):
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+        gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
+        chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
+        final_summary = gr.Markdown(label="Summary of Missed Diagnoses")
+        file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
+        msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
+        send_btn = gr.Button("Analyze", variant="primary")
+        download_output = gr.File(label="Download Full Report")
+        progress_bar = gr.Progress()
+
+        prompt_template = """
+Analyze the patient record excerpt for missed diagnoses only. Provide a concise, evidence-based summary as a single paragraph without headings or bullet points. Include specific clinical findings (e.g., 'elevated blood pressure (160/95) on page 10'), their potential implications (e.g., 'may indicate untreated hypertension'), and a recommendation for urgent review. Do not include other oversight categories like medication conflicts. If no missed diagnoses are found, state 'No missed diagnoses identified' in a single sentence.
+Patient Record Excerpt (Chunk {0} of {1}):
+{chunk}
+"""
+
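+        # analyze() is a generator: each yield streams an updated chat transcript,
+        # report file, and summary back to the UI.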
+        def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
+            history.append({"role": "user", "content": message})
+            yield history, None, ""
+
+            extracted = []
+            file_hash_value = ""
+
+            if files:
+                # Process files in parallel with streaming
+                with ThreadPoolExecutor(max_workers=4) as executor:
+                    futures = []
+                    for f in files:
+                        file_type = f.name.split(".")[-1].lower()
+                        # default argument binds this file's extension (avoids late binding in the lambda)
+                        futures.append(executor.submit(
+                            lambda f, ft=file_type: list(stream_file_to_json(f.name, ft)),
+                            f
+                        ))
+
+                    for future in as_completed(futures):
+                        try:
+                            extracted.extend(future.result())
+                        except Exception as e:
+                            logger.error(f"File processing error: {e}")
+                            extracted.append(json.dumps({
+                                "error": f"Error processing file: {str(e)}"
+                            }))
+
+                file_hash_value = file_hash(files[0].name) if files else ""
+                history.append({"role": "assistant", "content": "✅ File processing complete"})
+                yield history, None, ""
+
+            # Process chunks in parallel with dynamic batching
+            chunk_size = 8000  # Larger chunks reduce overhead
+            combined_response = ""
+
+            try:
+                # Convert extracted data to text chunks
+                text_content = "\n".join(extracted)
+                chunks = [text_content[i:i+chunk_size] for i in range(0, len(text_content), chunk_size)]
+
+                # Process chunks in parallel batches
+                batch_size = 4  # Optimal for most GPUs
+                total_chunks = len(chunks)
+
+                for batch_idx, batch_chunks in enumerate(batched(chunks, batch_size)):
+                    batch_prompts = [
+                        prompt_template.format(
+                            batch_idx * batch_size + i + 1,
+                            total_chunks,
+                            chunk=chunk[:6000]  # Slightly larger context
+                        )
+                        for i, chunk in enumerate(batch_chunks)
+                    ]
+
+                    progress((batch_idx * batch_size) / total_chunks,
+                             desc=f"Analyzing batch {batch_idx + 1}/{(total_chunks + batch_size - 1) // batch_size}")
+
+                    # Process batch in parallel
+                    with ThreadPoolExecutor(max_workers=len(batch_prompts)) as executor:
+                        future_to_prompt = {
+                            executor.submit(
+                                agent.run_gradio_chat,
+                                prompt, [], 0.2, 512, 2048, False, []
+                            ): prompt
+                            for prompt in batch_prompts
+                        }
+
+                        for future in as_completed(future_to_prompt):
+                            chunk_response = ""
+                            for chunk_output in future.result():
+                                if chunk_output is None:
+                                    continue
+                                if isinstance(chunk_output, list):
+                                    for m in chunk_output:
+                                        if hasattr(m, 'content') and m.content:
+                                            cleaned = clean_response(m.content)
+                                            if cleaned:
+                                                chunk_response += cleaned + " "
+                                elif isinstance(chunk_output, str) and chunk_output.strip():
+                                    cleaned = clean_response(chunk_output)
+                                    if cleaned:
+                                        chunk_response += cleaned + " "
+
+                            combined_response += f"--- Analysis for Chunk {batch_idx * batch_size + 1} ---\n{chunk_response.strip()}\n"
+                            history[-1] = {"role": "assistant", "content": combined_response.strip()}
+                            yield history, None, ""
+
+                    # Clean up memory
+                    torch.cuda.empty_cache()
+                    gc.collect()
+
+                # Generate final summary
+                summary = summarize_findings(combined_response)
+                report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
+                if report_path:
+                    with open(report_path, "w", encoding="utf-8") as f:
+                        f.write(combined_response + "\n\n" + summary)
+
+                yield history, report_path if report_path and os.path.exists(report_path) else None, summary
+
+            except Exception as e:
+                logger.error("Analysis error: %s", e)
+                history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
+                yield history, None, f"Error occurred during analysis: {str(e)}"
+
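+        # Both the Analyze button and pressing Enter in the textbox run the same streaming handler.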
+        send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
+        msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
     return demo

 if __name__ == "__main__":
+    try:
+        logger.info("Launching app...")
+        agent = init_agent()
+        demo = create_ui(agent)
+        demo.queue(api_open=False).launch(
+            server_name="0.0.0.0",
+            server_port=7860,
+            show_error=True,
+            allowed_paths=[report_dir],
+            share=False
+        )
+    finally:
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()