Ali2206 committed (verified)
Commit d88209d · 1 Parent(s): 8a7f6db

Update app.py

Files changed (1):
  1. app.py +389 -205
app.py CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
4
  import pdfplumber
5
  import json
6
  import gradio as gr
7
- from typing import List
8
  from concurrent.futures import ThreadPoolExecutor, as_completed
9
  import hashlib
10
  import shutil
@@ -16,12 +16,22 @@ import torch
16
  import gc
17
  from diskcache import Cache
18
  import time
19
 
20
  # Configure logging
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
24
- # Persistent directory
25
  persistent_dir = "/data/hf_cache"
26
  os.makedirs(persistent_dir, exist_ok=True)
27
 
@@ -34,11 +44,13 @@ vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
34
  for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
35
  os.makedirs(directory, exist_ok=True)
36
 
37
- os.environ["HF_HOME"] = model_cache_dir
38
- os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
39
- os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
40
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
41
- os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
42
 
43
  current_dir = os.path.dirname(os.path.abspath(__file__))
44
  src_path = os.path.abspath(os.path.join(current_dir, "src"))
@@ -49,169 +61,275 @@ from txagent.txagent import TxAgent
49
  # Initialize cache with 10GB limit
50
  cache = Cache(file_cache_dir, size_limit=10 * 1024**3)
51
 
52
  def sanitize_utf8(text: str) -> str:
 
53
  return text.encode("utf-8", "ignore").decode("utf-8")
54
 
55
  def file_hash(path: str) -> str:
56
  with open(path, "rb") as f:
57
- return hashlib.md5(f.read()).hexdigest()
58
 
59
  def extract_all_pages(file_path: str, progress_callback=None) -> str:
 
60
  try:
61
  with pdfplumber.open(file_path) as pdf:
62
  total_pages = len(pdf.pages)
63
  if total_pages == 0:
64
  return ""
65
 
66
- batch_size = 10
67
- batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
68
- text_chunks = [""] * total_pages
69
- processed_pages = 0
70
-
71
- def extract_batch(start: int, end: int) -> List[tuple]:
72
- results = []
73
  with pdfplumber.open(file_path) as pdf:
74
- for page in pdf.pages[start:end]:
75
- page_num = start + pdf.pages.index(page)
76
- page_text = page.extract_text() or ""
77
- results.append((page_num, f"=== Page {page_num + 1} ===\n{page_text.strip()}"))
78
- return results
79
-
80
- with ThreadPoolExecutor(max_workers=6) as executor:
81
- futures = [executor.submit(extract_batch, start, end) for start, end in batches]
82
- for future in as_completed(futures):
83
- for page_num, text in future.result():
84
- text_chunks[page_num] = text
85
- processed_pages += batch_size
86
- if progress_callback:
87
- progress_callback(min(processed_pages, total_pages), total_pages)
88
-
89
- return "\n\n".join(filter(None, text_chunks))
90
  except Exception as e:
91
- logger.error("PDF processing error: %s", e)
92
  return f"PDF processing error: {str(e)}"
93
 
94
- def convert_file_to_json(file_path: str, file_type: str, progress_callback=None) -> str:
95
  try:
96
- file_h = file_hash(file_path)
97
- cache_key = f"{file_h}_{file_type}"
98
- if cache_key in cache:
99
- return cache[cache_key]
100
 
101
  if file_type == "pdf":
102
- text = extract_all_pages(file_path, progress_callback)
103
- result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
104
- elif file_type == "csv":
105
- df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
106
- skip_blank_lines=False, on_bad_lines="skip")
107
- content = df.fillna("").astype(str).values.tolist()
108
- result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
109
  elif file_type in ["xls", "xlsx"]:
110
- try:
111
- df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
112
- except Exception:
113
- df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
114
- content = df.fillna("").astype(str).values.tolist()
115
- result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
116
  else:
117
- result = json.dumps({"error": f"Unsupported file type: {file_type}"})
118
-
119
- cache[cache_key] = result
120
- return result
121
  except Exception as e:
122
- logger.error("Error processing %s: %s", os.path.basename(file_path), e)
123
- return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
124
 
125
  def log_system_usage(tag=""):
 
126
  try:
127
- cpu = psutil.cpu_percent(interval=1)
128
  mem = psutil.virtual_memory()
129
- logger.info("[%s] CPU: %.1f%% | RAM: %dMB / %dMB", tag, cpu, mem.used // (1024**2), mem.total // (1024**2))
130
- result = subprocess.run(
131
- ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
132
- capture_output=True, text=True
133
- )
134
- if result.returncode == 0:
135
- used, total, util = result.stdout.strip().split(", ")
136
- logger.info("[%s] GPU: %sMB / %sMB | Utilization: %s%%", tag, used, total, util)
137
  except Exception as e:
138
- logger.error("[%s] GPU/CPU monitor failed: %s", tag, e)
139
 
140
  def clean_response(text: str) -> str:
141
- text = sanitize_utf8(text)
142
- # Remove unwanted patterns and tool call artifacts
143
- text = re.sub(r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.", "", text, flags=re.DOTALL)
144
- # Extract only missed diagnoses, ignoring other categories
145
- diagnoses = []
146
- lines = text.splitlines()
147
- in_diagnoses_section = False
148
- for line in lines:
149
- line = line.strip()
150
- if not line:
151
- continue
152
- if re.match(r"###\s*Missed Diagnoses", line):
153
- in_diagnoses_section = True
154
- continue
155
- if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
156
- in_diagnoses_section = False
157
  continue
158
- if in_diagnoses_section and re.match(r"-\s*.+", line):
159
- diagnosis = re.sub(r"^\-\s*", "", line).strip()
160
- if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
161
- diagnoses.append(diagnosis)
162
- # Join diagnoses into a plain text paragraph
163
- text = " ".join(diagnoses)
164
- # Clean up extra whitespace and punctuation
165
- text = re.sub(r"\s+", " ", text).strip()
166
- text = re.sub(r"[^\w\s\.\,\(\)\-]", "", text)
167
- return text if text else ""
168
 
169
  def summarize_findings(combined_response: str) -> str:
170
- # Split response by chunk analyses
171
- chunks = combined_response.split("--- Analysis for Chunk")
172
  diagnoses = []
173
- for chunk in chunks:
174
- chunk = chunk.strip()
175
- if not chunk or "No oversights identified" in chunk:
176
  continue
177
- # Extract missed diagnoses from chunk
178
- lines = chunk.splitlines()
179
- in_diagnoses_section = False
180
- for line in lines:
181
- line = line.strip()
182
- if not line:
183
- continue
184
- if re.match(r"###\s*Missed Diagnoses", line):
185
- in_diagnoses_section = True
186
- continue
187
- if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
188
- in_diagnoses_section = False
189
- continue
190
- if in_diagnoses_section and re.match(r"-\s*.+", line):
191
- diagnosis = re.sub(r"^\-\s*", "", line).strip()
192
- if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
193
  diagnoses.append(diagnosis)
194
-
195
- # Remove duplicates while preserving order
196
  seen = set()
197
  unique_diagnoses = [d for d in diagnoses if not (d in seen or seen.add(d))]
198
 
199
- if not unique_diagnoses:
200
- return "No missed diagnoses were identified in the provided records."
201
-
202
- # Combine into a single paragraph
203
- summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
204
- if len(unique_diagnoses) > 1:
205
- summary += f", and {unique_diagnoses[-1]}"
206
- elif len(unique_diagnoses) == 1:
207
- summary = "Missed diagnoses include " + unique_diagnoses[0]
208
- summary += ", all of which require urgent clinical review to prevent potential adverse outcomes."
209
 
210
- return summary.strip()
211
 
 
212
  def init_agent():
 
213
  logger.info("Initializing model...")
214
  log_system_usage("Before Load")
 
215
  default_tool_path = os.path.abspath("data/new_tool.json")
216
  target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
217
  if not os.path.exists(target_tool_path):
@@ -228,122 +346,188 @@ def init_agent():
228
  additional_default_tools=[],
229
  )
230
  agent.init_model()
 
231
  log_system_usage("After Load")
232
  logger.info("Agent Ready")
233
  return agent
234
 
235
  def create_ui(agent):
236
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
237
- gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
238
- chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
239
- final_summary = gr.Markdown(label="Summary of Missed Diagnoses")
240
- file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
241
- msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
242
- send_btn = gr.Button("Analyze", variant="primary")
243
- download_output = gr.File(label="Download Full Report")
244
- progress_bar = gr.Progress()
245
-
246
- prompt_template = """
247
- Analyze the patient record excerpt for missed diagnoses only. Provide a concise, evidence-based summary as a single paragraph without headings or bullet points. Include specific clinical findings (e.g., 'elevated blood pressure (160/95) on page 10'), their potential implications (e.g., 'may indicate untreated hypertension'), and a recommendation for urgent review. Do not include other oversight categories like medication conflicts. If no missed diagnoses are found, state 'No missed diagnoses identified' in a single sentence.
248
  Patient Record Excerpt (Chunk {0} of {1}):
249
  {chunk}
250
  """
251
 
252
  def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
 
253
  history.append({"role": "user", "content": message})
254
  yield history, None, ""
255
 
256
- extracted = ""
257
  file_hash_value = ""
 
258
  if files:
259
- def update_extraction_progress(current, total):
260
- progress(current / total, desc=f"Extracting text... Page {current}/{total}")
261
- return history, None, ""
262
-
263
- with ThreadPoolExecutor(max_workers=6) as executor:
264
- futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower(), update_extraction_progress) for f in files]
265
- results = [sanitize_utf8(f.result()) for f in as_completed(futures)]
266
- extracted = "\n".join(results)
267
- file_hash_value = file_hash(files[0].name) if files else ""
268
-
269
- history.append({"role": "assistant", "content": "✅ Text extraction complete."})
270
  yield history, None, ""
271
 
272
- chunk_size = 6000
273
- chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
274
- combined_response = ""
275
- batch_size = 2
276
 
 
277
  try:
278
- for batch_idx in range(0, len(chunks), batch_size):
279
- batch_chunks = chunks[batch_idx:batch_idx + batch_size]
280
- batch_prompts = [prompt_template.format(i + 1, len(chunks), chunk=chunk[:4000]) for i, chunk in enumerate(batch_chunks)]
281
- batch_responses = []
282
-
283
- progress((batch_idx + 1) / len(chunks), desc=f"Analyzing chunks {batch_idx + 1}-{min(batch_idx + batch_size, len(chunks))}/{len(chunks)}")
284
-
285
- with ThreadPoolExecutor(max_workers=len(batch_chunks)) as executor:
286
- futures = [executor.submit(agent.run_gradio_chat, prompt, [], 0.2, 512, 2048, False, []) for prompt in batch_prompts]
287
- for future in as_completed(futures):
288
- chunk_response = ""
289
- for chunk_output in future.result():
290
- if chunk_output is None:
291
- continue
292
- if isinstance(chunk_output, list):
293
- for m in chunk_output:
294
- if hasattr(m, 'content') and m.content:
295
- cleaned = clean_response(m.content)
296
- if cleaned:
297
- chunk_response += cleaned + " "
298
- elif isinstance(chunk_output, str) and chunk_output.strip():
299
- cleaned = clean_response(chunk_output)
300
- if cleaned:
301
- chunk_response += cleaned + " "
302
- batch_responses.append(chunk_response.strip())
303
- torch.cuda.empty_cache()
304
- gc.collect()
305
-
306
- for chunk_idx, chunk_response in enumerate(batch_responses, batch_idx + 1):
307
- if chunk_response:
308
- combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
309
- else:
310
- combined_response += f"--- Analysis for Chunk {chunk_idx} ---\nNo missed diagnoses identified.\n"
311
- history[-1] = {"role": "assistant", "content": combined_response.strip()}
312
- yield history, None, ""
313
-
314
- if combined_response.strip() and not all("No missed diagnoses identified" in chunk for chunk in combined_response.split("--- Analysis for Chunk")):
315
- history[-1]["content"] = combined_response.strip()
316
- else:
317
- history.append({"role": "assistant", "content": "No missed diagnoses identified in the provided records."})
318
 
319
  summary = summarize_findings(combined_response)
320
- report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
321
- if report_path:
322
- with open(report_path, "w", encoding="utf-8") as f:
323
- f.write(combined_response + "\n\n" + summary)
324
- yield history, report_path if report_path and os.path.exists(report_path) else None, summary
325
 
326
  except Exception as e:
327
- logger.error("Analysis error: %s", e)
328
  history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
329
  yield history, None, f"Error occurred during analysis: {str(e)}"
330
-
331
- send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
332
- msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
333
  return demo
334
 
335
  if __name__ == "__main__":
336
  try:
337
- logger.info("Launching app...")
338
  agent = init_agent()
339
  demo = create_ui(agent)
340
- demo.queue(api_open=False).launch(
341
  server_name="0.0.0.0",
342
  server_port=7860,
343
  show_error=True,
344
  allowed_paths=[report_dir],
345
  share=False
346
  )
347
  finally:
348
  if torch.distributed.is_initialized():
349
  torch.distributed.destroy_process_group()
 
4
  import pdfplumber
5
  import json
6
  import gradio as gr
7
+ from typing import List, Dict, Optional, Generator
8
  from concurrent.futures import ThreadPoolExecutor, as_completed
9
  import hashlib
10
  import shutil
 
16
  import gc
17
  from diskcache import Cache
18
  import time
19
+ from transformers import AutoTokenizer
20
+ from functools import lru_cache
21
+ import numpy as np
22
+ from difflib import SequenceMatcher
23
 
24
  # Configure logging
25
  logging.basicConfig(level=logging.INFO)
26
  logger = logging.getLogger(__name__)
27
 
28
+ # Constants
29
+ MAX_TOKENS = 1800
30
+ BATCH_SIZE = 2
31
+ MAX_WORKERS = 4
32
+ CHUNK_SIZE = 10 # For PDF processing
33
+
34
+ # Persistent directory setup
35
  persistent_dir = "/data/hf_cache"
36
  os.makedirs(persistent_dir, exist_ok=True)
37
 
 
44
  for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
45
  os.makedirs(directory, exist_ok=True)
46
 
47
+ os.environ.update({
48
+ "HF_HOME": model_cache_dir,
49
+ "TRANSFORMERS_CACHE": model_cache_dir,
50
+ "VLLM_CACHE_DIR": vllm_cache_dir,
51
+ "TOKENIZERS_PARALLELISM": "false",
52
+ "CUDA_LAUNCH_BLOCKING": "1"
53
+ })
54
 
55
  current_dir = os.path.dirname(os.path.abspath(__file__))
56
  src_path = os.path.abspath(os.path.join(current_dir, "src"))
 
61
  # Initialize cache with 10GB limit
62
  cache = Cache(file_cache_dir, size_limit=10 * 1024**3)
63
 
64
+ # Initialize tokenizer for precise chunking (with caching)
65
+ @lru_cache(maxsize=1)
66
+ def get_tokenizer():
67
+ return AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
68
+
69
  def sanitize_utf8(text: str) -> str:
70
+ """Optimized UTF-8 sanitization"""
71
  return text.encode("utf-8", "ignore").decode("utf-8")
72
 
73
  def file_hash(path: str) -> str:
74
+ """Optimized file hashing with buffer reading"""
75
+ hash_md5 = hashlib.md5()
76
  with open(path, "rb") as f:
77
+ for chunk in iter(lambda: f.read(4096), b""):
78
+ hash_md5.update(chunk)
79
+ return hash_md5.hexdigest()
80
+
81
+ def extract_pdf_page(page) -> str:
82
+ """Optimized single page extraction"""
83
+ try:
84
+ text = page.extract_text() or ""
85
+ return f"=== Page {page.page_number} ===\n{text.strip()}"
86
+ except Exception as e:
87
+ logger.warning(f"Error extracting page {page.page_number}: {str(e)}")
88
+ return ""
89
 
90
  def extract_all_pages(file_path: str, progress_callback=None) -> str:
91
+ """Optimized PDF extraction with memory management"""
92
  try:
93
  with pdfplumber.open(file_path) as pdf:
94
  total_pages = len(pdf.pages)
95
  if total_pages == 0:
96
  return ""
97
 
98
+ results = []
99
+ for chunk_start in range(0, total_pages, CHUNK_SIZE):
100
+ chunk_end = min(chunk_start + CHUNK_SIZE, total_pages)
101
+
102
  with pdfplumber.open(file_path) as pdf:
103
+ with ThreadPoolExecutor(max_workers=min(CHUNK_SIZE, 4)) as executor:
104
+ futures = [executor.submit(extract_pdf_page, pdf.pages[i])
105
+ for i in range(chunk_start, chunk_end)]
106
+
107
+ for future in as_completed(futures):
108
+ results.append(future.result())
109
+
110
+ if progress_callback:
111
+ progress_callback(min(chunk_end, total_pages), total_pages)
112
+
113
+ del pdf
114
+ gc.collect()
115
+
116
+ return "\n\n".join(filter(None, results))
117
  except Exception as e:
118
+ logger.error(f"PDF processing error: {e}")
119
  return f"PDF processing error: {str(e)}"
120
 
121
+ def excel_to_json(file_path: str) -> List[Dict]:
122
+ """Optimized Excel processing with chunking"""
123
+ try:
124
+ for engine in ['openpyxl', 'xlrd']:
125
+ try:
126
+ df = pd.read_excel(
127
+ file_path,
128
+ engine=engine,
129
+ header=None,
130
+ dtype=str,
131
+ na_filter=False
132
+ )
133
+ return [{
134
+ "filename": os.path.basename(file_path),
135
+ "rows": df.values.tolist(),
136
+ "type": "excel"
137
+ }]
138
+ except Exception:
139
+ continue
140
+ raise Exception("No suitable Excel engine found")
141
+ except Exception as e:
142
+ logger.error(f"Excel processing error: {e}")
143
+ return [{"error": f"Excel processing error: {str(e)}"}]
144
+
145
+ def csv_to_json(file_path: str) -> List[Dict]:
146
+ """Optimized CSV processing with chunking"""
147
  try:
148
+ chunks = []
149
+ for chunk in pd.read_csv(
150
+ file_path,
151
+ header=None,
152
+ dtype=str,
153
+ encoding_errors='replace',
154
+ on_bad_lines='skip',
155
+ chunksize=10000,
156
+ na_filter=False
157
+ ):
158
+ chunks.append(chunk)
159
+
160
+ df = pd.concat(chunks) if chunks else pd.DataFrame()
161
+ return [{
162
+ "filename": os.path.basename(file_path),
163
+ "rows": df.values.tolist(),
164
+ "type": "csv"
165
+ }]
166
+ except Exception as e:
167
+ logger.error(f"CSV processing error: {e}")
168
+ return [{"error": f"CSV processing error: {str(e)}"}]
169
 
170
+ @lru_cache(maxsize=100)
171
+ def process_file_cached(file_path: str, file_type: str) -> List[Dict]:
172
+ """Cached file processing with memory optimization"""
173
+ try:
174
  if file_type == "pdf":
175
+ text = extract_all_pages(file_path)
176
+ return [{
177
+ "filename": os.path.basename(file_path),
178
+ "content": text,
179
+ "status": "initial",
180
+ "type": "pdf"
181
+ }]
182
  elif file_type in ["xls", "xlsx"]:
183
+ return excel_to_json(file_path)
184
+ elif file_type == "csv":
185
+ return csv_to_json(file_path)
186
  else:
187
+ return [{"error": f"Unsupported file type: {file_type}"}]
188
  except Exception as e:
189
+ logger.error(f"Error processing {os.path.basename(file_path)}: {e}")
190
+ return [{"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"}]
191
+
192
+ def tokenize_and_chunk(text: str, max_tokens: int = MAX_TOKENS) -> List[str]:
193
+ """Optimized tokenization and chunking"""
194
+ tokenizer = get_tokenizer()
195
+ tokens = tokenizer.encode(text, add_special_tokens=False)
196
+ return [
197
+ tokenizer.decode(tokens[i:i + max_tokens])
198
+ for i in range(0, len(tokens), max_tokens)
199
+ ]
200
 
201
  def log_system_usage(tag=""):
202
+ """Optimized system monitoring"""
203
  try:
204
+ cpu = psutil.cpu_percent(interval=0.5)
205
  mem = psutil.virtual_memory()
206
+ logger.info(f"[{tag}] CPU: {cpu:.1f}% | RAM: {mem.used // (1024**2)}MB / {mem.total // (1024**2)}MB")
207
+
208
+ try:
209
+ result = subprocess.run(
210
+ ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
211
+ capture_output=True,
212
+ text=True,
213
+ timeout=2
214
+ )
215
+ if result.returncode == 0:
216
+ used, total, util = result.stdout.strip().split(", ")
217
+ logger.info(f"[{tag}] GPU: {used}MB / {total}MB | Utilization: {util}%")
218
+ except subprocess.TimeoutExpired:
219
+ logger.warning(f"[{tag}] GPU monitoring timed out")
220
  except Exception as e:
221
+ logger.error(f"[{tag}] Monitor failed: {e}")
222
 
223
  def clean_response(text: str) -> str:
224
+ """Enhanced response cleaning with aggressive deduplication"""
225
+ if not text:
226
+ return ""
227
+
228
+ patterns = [
229
+ (re.compile(r"\[.*?\]|\bNone\b", re.IGNORECASE), ""),
230
+ (re.compile(r"(The patient record excerpt provides|Patient record excerpt contains).*?(John Doe|general information).*?\.", re.IGNORECASE), ""),
231
+ (re.compile(r"To (analyze|proceed).*?medications\.", re.IGNORECASE), ""),
232
+ (re.compile(r"Since the previous attempts.*?\.", re.IGNORECASE), ""),
233
+ (re.compile(r"I need to.*?results\.", re.IGNORECASE), ""),
234
+ (re.compile(r"(Therefore, )?(Retrieving|I will start by retrieving) tools.*?\.", re.IGNORECASE), ""),
235
+ (re.compile(r"This requires reviewing.*?\.", re.IGNORECASE), ""),
236
+ (re.compile(r"Given the context, it is important to review.*?\.", re.IGNORECASE), ""),
237
+ (re.compile(r"Final Analysis\s*", re.IGNORECASE), ""),
238
+ (re.compile(r"Therefore, no missed diagnoses can be identified.*?\.", re.IGNORECASE), ""),
239
+ (re.compile(r"\s+"), " "),
240
+ (re.compile(r"[^\w\s\.\,\(\)\-]"), ""),
241
+ (re.compile(r"(No missed diagnoses identified\.)\s*\1+", re.IGNORECASE), r"\1"),
242
+ ]
243
+
244
+ for pattern, repl in patterns:
245
+ text = pattern.sub(repl, text)
246
+
247
+ sentences = text.split(". ")
248
+ unique_sentences = []
249
+ seen = set()
250
+
251
+ for s in sentences:
252
+ if not s:
253
  continue
254
+ is_unique = True
255
+ for seen_s in seen:
256
+ if SequenceMatcher(None, s.lower(), seen_s.lower()).ratio() > 0.9:
257
+ is_unique = False
258
+ break
259
+ if is_unique:
260
+ unique_sentences.append(s)
261
+ seen.add(s)
262
+
263
+ text = ". ".join(unique_sentences).strip()
264
+
265
+ return text if text else "No missed diagnoses identified."
266
 
267
  def summarize_findings(combined_response: str) -> str:
268
+ """Enhanced findings summarization for a single, concise paragraph"""
269
+ if not combined_response:
270
+ return "No missed diagnoses were identified in the provided records."
271
+
272
+ diagnosis_pattern = re.compile(r"-\s*(.+)$")
273
+ section_pattern = re.compile(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)")
274
+ no_issues_pattern = re.compile(r"No issues identified|No missed diagnoses identified", re.IGNORECASE)
275
+
276
  diagnoses = []
277
+ current_section = None
278
+
279
+ for line in combined_response.splitlines():
280
+ line = line.strip()
281
+ if not line:
282
  continue
283
+
284
+ section_match = section_pattern.match(line)
285
+ if section_match:
286
+ current_section = "diagnoses" if section_match.group(1) == "Missed Diagnoses" else None
287
+ continue
288
+
289
+ if current_section == "diagnoses":
290
+ diagnosis_match = diagnosis_pattern.match(line)
291
+ if diagnosis_match and not no_issues_pattern.search(line):
292
+ diagnosis = diagnosis_match.group(1).strip()
293
+ if diagnosis:
294
  diagnoses.append(diagnosis)
295
+
296
+ medication_pattern = re.compile(r"medications includ(?:e|ing|ed) ([^\.]+)", re.IGNORECASE)
297
+ evaluation_pattern = re.compile(r"psychiatric evaluation.*?mention of ([^\.]+)", re.IGNORECASE)
298
+
299
+ for line in combined_response.splitlines():
300
+ line = line.strip()
301
+ if not line or no_issues_pattern.search(line):
302
+ continue
303
+
304
+ med_match = medication_pattern.search(line)
305
+ if med_match:
306
+ meds = med_match.group(1).strip()
307
+ diagnoses.append(f"use of medications ({meds}), suggesting an undiagnosed psychiatric condition requiring urgent review")
308
+
309
+ eval_match = evaluation_pattern.search(line)
310
+ if eval_match:
311
+ details = eval_match.group(1).strip()
312
+ diagnoses.append(f"psychiatric evaluation noting {details}, indicating a potential missed psychiatric diagnosis requiring urgent review")
313
+
314
+ if not diagnoses:
315
+ return "No missed diagnoses were identified in the provided records."
316
+
317
  seen = set()
318
  unique_diagnoses = [d for d in diagnoses if not (d in seen or seen.add(d))]
319
 
320
+ summary = "The patient record indicates missed diagnoses including "
321
+ summary += ", ".join(unique_diagnoses[:-1])
322
+ summary += f", and {unique_diagnoses[-1]}" if len(unique_diagnoses) > 1 else unique_diagnoses[0]
323
+ summary += ". These findings suggest potential oversights in the patient's medical evaluation and require urgent clinical review to prevent adverse outcomes."
324
 
325
+ return summary
326
 
327
+ @lru_cache(maxsize=1)
328
  def init_agent():
329
+ """Cached agent initialization with memory optimization"""
330
  logger.info("Initializing model...")
331
  log_system_usage("Before Load")
332
+
333
  default_tool_path = os.path.abspath("data/new_tool.json")
334
  target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
335
  if not os.path.exists(target_tool_path):
 
346
  additional_default_tools=[],
347
  )
348
  agent.init_model()
349
+
350
  log_system_usage("After Load")
351
  logger.info("Agent Ready")
352
  return agent
353
 
354
  def create_ui(agent):
355
+ """Optimized UI creation with pre-compiled templates"""
356
+ PROMPT_TEMPLATE = """
357
+ Analyze the patient record excerpt for missed diagnoses, focusing ONLY on clinical findings such as symptoms, medications, or evaluation results provided in the excerpt. Provide a detailed, evidence-based analysis using all available tools (e.g., Tool_RAG, CallAgent) to identify potential oversights. Include specific findings (e.g., 'elevated blood pressure (160/95)'), their implications (e.g., 'may indicate untreated hypertension'), and recommend urgent review. Treat medications or psychiatric evaluations as potential missed diagnoses. Do NOT repeat non-clinical information (e.g., name, date of birth, allergies). If no clinical findings are present, state 'No missed diagnoses identified' in ONE sentence. Ignore other oversight categories (e.g., medication conflicts).
358
  Patient Record Excerpt (Chunk {0} of {1}):
359
  {chunk}
360
  """
361
 
362
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
363
+ gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
364
+
365
+ with gr.Row():
366
+ with gr.Column(scale=3):
367
+ chatbot = gr.Chatbot(label="Analysis Summary", height=600, type="messages")
368
+ msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
369
+ send_btn = gr.Button("Analyze", variant="primary")
370
+ file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
371
+
372
+ with gr.Column(scale=1):
373
+ final_summary = gr.Markdown(label="Missed Diagnoses Summary")
374
+ download_output = gr.File(label="Download Detailed Report")
375
+ progress_bar = gr.Progress()
376
+
377
  def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
378
+ """Optimized analysis pipeline with quick summary and background report"""
379
  history.append({"role": "user", "content": message})
380
  yield history, None, ""
381
 
382
+ extracted = []
383
  file_hash_value = ""
384
+
385
  if files:
386
+ for f in files:
387
+ file_type = f.name.split(".")[-1].lower()
388
+ cache_key = f"{file_hash(f.name)}_{file_type}"
389
+
390
+ if cache_key in cache:
391
+ extracted.extend(cache[cache_key])
392
+ else:
393
+ result = process_file_cached(f.name, file_type)
394
+ cache[cache_key] = result
395
+ extracted.extend(result)
396
+
397
+ file_hash_value = file_hash(files[0].name) if files else ""
398
+ history.append({"role": "assistant", "content": "✅ File processing complete"})
399
  yield history, None, ""
400
 
401
+ text_content = "\n".join(json.dumps(item, ensure_ascii=False) for item in extracted)
402
+ del extracted
403
+ gc.collect()
 
404
 
405
+ chunks = tokenize_and_chunk(text_content)
406
+ del text_content
407
+ gc.collect()
408
+
409
+ combined_response = ""
410
+ report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
411
+ seen_responses = set()
412
+
413
  try:
414
+ for batch_idx in range(0, len(chunks), BATCH_SIZE):
415
+ batch_chunks = chunks[batch_idx:batch_idx + BATCH_SIZE]
416
+
417
+ batch_prompts = [
418
+ PROMPT_TEMPLATE.format(
419
+ batch_idx + i + 1,
420
+ len(chunks),
421
+ chunk=chunk[:1800]
422
+ )
423
+ for i, chunk in enumerate(batch_chunks)
424
+ ]
425
+
426
+ progress(batch_idx / len(chunks),
427
+ desc=f"Processing batch {(batch_idx // BATCH_SIZE) + 1}/{(len(chunks) + BATCH_SIZE - 1) // BATCH_SIZE}")
428
+
429
+ with ThreadPoolExecutor(max_workers=min(BATCH_SIZE, MAX_WORKERS)) as executor:
430
+ quick_futures = {
431
+ executor.submit(
432
+ agent.run_quick_summary,
433
+ chunk, 0.2, 256, 1024
434
+ ): idx
435
+ for idx, chunk in enumerate(batch_chunks)
436
+ }
437
+
438
+ for future in as_completed(quick_futures):
439
+ chunk_idx = quick_futures[future]
440
+ try:
441
+ quick_response = clean_response(future.result())
442
+ if quick_response and quick_response != "No missed diagnoses identified.":
443
+ is_unique = True
444
+ for seen_response in seen_responses:
445
+ if SequenceMatcher(None, quick_response.lower(), seen_response.lower()).ratio() > 0.9:
446
+ is_unique = False
447
+ break
448
+ if is_unique:
449
+ combined_response += f"--- Quick Analysis for Chunk {batch_idx + chunk_idx + 1} ---\n{quick_response}\n"
450
+ seen_responses.add(quick_response)
451
+ history[-1] = {"role": "assistant", "content": combined_response.strip()}
452
+ yield history, None, ""
453
+ finally:
454
+ del future
455
+ torch.cuda.empty_cache()
456
+ gc.collect()
457
+
458
+ # Start background detailed analysis
459
+ with ThreadPoolExecutor(max_workers=min(BATCH_SIZE, MAX_WORKERS)) as executor:
460
+ detailed_futures = {
461
+ executor.submit(
462
+ agent.run_gradio_chat,
463
+ prompt, [], 0.2, 512, 2048, False, None, 3, None, 0, None, report_path
464
+ ): idx
465
+ for idx, prompt in enumerate(batch_prompts)
466
+ }
467
+
468
+ for future in as_completed(detailed_futures):
469
+ chunk_idx = detailed_futures[future]
470
+ try:
471
+ for chunk_output in future.result():
472
+ if isinstance(chunk_output, list):
473
+ for msg in chunk_output:
474
+ if isinstance(msg, ChatMessage) and msg.content:
475
+ combined_response += clean_response(msg.content) + "\n"
476
+ history[-1] = {"role": "assistant", "content": combined_response.strip()}
477
+ yield history, report_path, ""
478
+ finally:
479
+ del future
480
+ torch.cuda.empty_cache()
481
+ gc.collect()
482
 
483
  summary = summarize_findings(combined_response)
484
+
485
+ if report_path and os.path.exists(report_path):
486
+ history.append({"role": "assistant", "content": "Detailed report ready for download."})
487
+ yield history, report_path, summary
488
+ else:
489
+ history.append({"role": "assistant", "content": "Detailed report still processing."})
490
+ yield history, None, summary
491
 
492
  except Exception as e:
493
+ logger.error(f"Analysis error: {e}")
494
  history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
495
  yield history, None, f"Error occurred during analysis: {str(e)}"
496
+ finally:
497
+ torch.cuda.empty_cache()
498
+ gc.collect()
499
+
500
+ send_btn.click(
501
+ analyze,
502
+ inputs=[msg_input, gr.State([]), file_upload],
503
+ outputs=[chatbot, download_output, final_summary]
504
+ )
505
+ msg_input.submit(
506
+ analyze,
507
+ inputs=[msg_input, gr.State([]), file_upload],
508
+ outputs=[chatbot, download_output, final_summary]
509
+ )
510
+
511
  return demo
512
 
513
  if __name__ == "__main__":
514
  try:
515
+ logger.info("Launching optimized app...")
516
  agent = init_agent()
517
  demo = create_ui(agent)
518
+ demo.queue(
519
+ api_open=False,
520
+ max_size=20
521
+ ).launch(
522
  server_name="0.0.0.0",
523
  server_port=7860,
524
  show_error=True,
525
  allowed_paths=[report_dir],
526
  share=False
527
  )
528
+ except Exception as e:
529
+ logger.error(f"Fatal error: {e}")
530
+ raise
531
  finally:
532
  if torch.distributed.is_initialized():
533
  torch.distributed.destroy_process_group()