Ali2206 committed on
Commit 8a7f6db · verified · 1 Parent(s): e0669ce

Update app.py

Files changed (1):
  1. app.py (+203 -405)
app.py CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
import pdfplumber
import json
import gradio as gr
- from typing import List, Dict, Optional, Generator
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
@@ -16,22 +16,12 @@ import torch
import gc
from diskcache import Cache
import time
- from transformers import AutoTokenizer
- from functools import lru_cache
- import numpy as np
- from difflib import SequenceMatcher

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

- # Constants
- MAX_TOKENS = 1800
- BATCH_SIZE = 2
- MAX_WORKERS = 4
- CHUNK_SIZE = 10  # For PDF processing
-
- # Persistent directory setup
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

@@ -44,13 +34,11 @@ vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
    os.makedirs(directory, exist_ok=True)

- os.environ.update({
-     "HF_HOME": model_cache_dir,
-     "TRANSFORMERS_CACHE": model_cache_dir,
-     "VLLM_CACHE_DIR": vllm_cache_dir,
-     "TOKENIZERS_PARALLELISM": "false",
-     "CUDA_LAUNCH_BLOCKING": "1"
- })

current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
@@ -61,294 +49,174 @@ from txagent.txagent import TxAgent
# Initialize cache with 10GB limit
cache = Cache(file_cache_dir, size_limit=10 * 1024**3)

- # Initialize tokenizer for precise chunking (with caching)
- @lru_cache(maxsize=1)
- def get_tokenizer():
-     return AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
-
def sanitize_utf8(text: str) -> str:
-     """Optimized UTF-8 sanitization"""
    return text.encode("utf-8", "ignore").decode("utf-8")

def file_hash(path: str) -> str:
-     """Optimized file hashing with buffer reading"""
-     hash_md5 = hashlib.md5()
    with open(path, "rb") as f:
-         for chunk in iter(lambda: f.read(4096), b""):
-             hash_md5.update(chunk)
-     return hash_md5.hexdigest()
-
- def extract_pdf_page(page) -> str:
-     """Optimized single page extraction"""
-     try:
-         text = page.extract_text() or ""
-         return f"=== Page {page.page_number} ===\n{text.strip()}"
-     except Exception as e:
-         logger.warning(f"Error extracting page {page.page_number}: {str(e)}")
-         return ""

def extract_all_pages(file_path: str, progress_callback=None) -> str:
-     """Optimized PDF extraction with memory management"""
    try:
        with pdfplumber.open(file_path) as pdf:
            total_pages = len(pdf.pages)
            if total_pages == 0:
                return ""

-         # Process in chunks with memory cleanup
-         results = []
-         for chunk_start in range(0, total_pages, CHUNK_SIZE):
-             chunk_end = min(chunk_start + CHUNK_SIZE, total_pages)
-
            with pdfplumber.open(file_path) as pdf:
-                 with ThreadPoolExecutor(max_workers=min(CHUNK_SIZE, 4)) as executor:
-                     futures = [executor.submit(extract_pdf_page, pdf.pages[i])
-                                for i in range(chunk_start, chunk_end)]
-
-                     for future in as_completed(futures):
-                         results.append(future.result())
-
-                     if progress_callback:
-                         progress_callback(min(chunk_end, total_pages), total_pages)
-
-             # Explicit cleanup
-             del pdf
-             gc.collect()
-
-         return "\n\n".join(filter(None, results))
    except Exception as e:
-         logger.error(f"PDF processing error: {e}")
        return f"PDF processing error: {str(e)}"

- def excel_to_json(file_path: str) -> List[Dict]:
-     """Optimized Excel processing with chunking"""
-     try:
-         # Try fastest engines first
-         for engine in ['openpyxl', 'xlrd']:
-             try:
-                 df = pd.read_excel(
-                     file_path,
-                     engine=engine,
-                     header=None,
-                     dtype=str,
-                     na_filter=False
-                 )
-                 return [{
-                     "filename": os.path.basename(file_path),
-                     "rows": df.values.tolist(),
-                     "type": "excel"
-                 }]
-             except Exception:
-                 continue
-         raise Exception("No suitable Excel engine found")
-     except Exception as e:
-         logger.error(f"Excel processing error: {e}")
-         return [{"error": f"Excel processing error: {str(e)}"}]
-
- def csv_to_json(file_path: str) -> List[Dict]:
-     """Optimized CSV processing with chunking"""
    try:
-         chunks = []
-         for chunk in pd.read_csv(
-             file_path,
-             header=None,
-             dtype=str,
-             encoding_errors='replace',
-             on_bad_lines='skip',
-             chunksize=10000,
-             na_filter=False
-         ):
-             chunks.append(chunk)
-
-         df = pd.concat(chunks) if chunks else pd.DataFrame()
-         return [{
-             "filename": os.path.basename(file_path),
-             "rows": df.values.tolist(),
-             "type": "csv"
-         }]
-     except Exception as e:
-         logger.error(f"CSV processing error: {e}")
-         return [{"error": f"CSV processing error: {str(e)}"}]

- @lru_cache(maxsize=100)
- def process_file_cached(file_path: str, file_type: str) -> List[Dict]:
-     """Cached file processing with memory optimization"""
-     try:
        if file_type == "pdf":
-             text = extract_all_pages(file_path)
-             return [{
-                 "filename": os.path.basename(file_path),
-                 "content": text,
-                 "status": "initial",
-                 "type": "pdf"
-             }]
-         elif file_type in ["xls", "xlsx"]:
-             return excel_to_json(file_path)
        elif file_type == "csv":
-             return csv_to_json(file_path)
        else:
-             return [{"error": f"Unsupported file type: {file_type}"}]
    except Exception as e:
-         logger.error(f"Error processing {os.path.basename(file_path)}: {e}")
-         return [{"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"}]
-
- def tokenize_and_chunk(text: str, max_tokens: int = MAX_TOKENS) -> List[str]:
-     """Optimized tokenization and chunking"""
-     tokenizer = get_tokenizer()
-     tokens = tokenizer.encode(text, add_special_tokens=False)
-     return [
-         tokenizer.decode(tokens[i:i + max_tokens])
-         for i in range(0, len(tokens), max_tokens)
-     ]

def log_system_usage(tag=""):
-     """Optimized system monitoring"""
    try:
-         cpu = psutil.cpu_percent(interval=0.5)
        mem = psutil.virtual_memory()
-         logger.info(f"[{tag}] CPU: {cpu:.1f}% | RAM: {mem.used // (1024**2)}MB / {mem.total // (1024**2)}MB")
-
-         # GPU monitoring with timeout
-         try:
-             result = subprocess.run(
-                 ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
-                 capture_output=True,
-                 text=True,
-                 timeout=2
-             )
-             if result.returncode == 0:
-                 used, total, util = result.stdout.strip().split(", ")
-                 logger.info(f"[{tag}] GPU: {used}MB / {total}MB | Utilization: {util}%")
-         except subprocess.TimeoutExpired:
-             logger.warning(f"[{tag}] GPU monitoring timed out")
    except Exception as e:
-         logger.error(f"[{tag}] Monitor failed: {e}")

def clean_response(text: str) -> str:
-     """Enhanced response cleaning with aggressive deduplication"""
-     if not text:
-         return ""
-
-     # Pre-compiled regex patterns for cleaning
-     patterns = [
-         (re.compile(r"\[.*?\]|\bNone\b", re.IGNORECASE), ""),
-         (re.compile(r"(The patient record excerpt provides|Patient record excerpt contains).*?(John Doe|general information).*?\.", re.IGNORECASE), ""),
-         (re.compile(r"To (analyze|proceed).*?medications\.", re.IGNORECASE), ""),
-         (re.compile(r"Since the previous attempts.*?\.", re.IGNORECASE), ""),
-         (re.compile(r"I need to.*?results\.", re.IGNORECASE), ""),
-         (re.compile(r"(Therefore, )?(Retrieving|I will start by retrieving) tools.*?\.", re.IGNORECASE), ""),
-         (re.compile(r"This requires reviewing.*?\.", re.IGNORECASE), ""),
-         (re.compile(r"Given the context, it is important to review.*?\.", re.IGNORECASE), ""),
-         (re.compile(r"Final Analysis\s*", re.IGNORECASE), ""),
-         (re.compile(r"\s+"), " "),
-         (re.compile(r"[^\w\s\.\,\(\)\-]"), ""),
-         (re.compile(r"(No missed diagnoses identified\.)\s*\1+", re.IGNORECASE), r"\1"),
-     ]
-
-     for pattern, repl in patterns:
-         text = pattern.sub(repl, text)
-
-     # Deduplicate near-identical sentences using similarity threshold
-     sentences = text.split(". ")
-     unique_sentences = []
-     seen = set()
-
-     for s in sentences:
-         if not s:
-             continue
-         # Check similarity with existing sentences
-         is_unique = True
-         for seen_s in seen:
-             if SequenceMatcher(None, s.lower(), seen_s.lower()).ratio() > 0.9:
-                 is_unique = False
-                 break
-         if is_unique:
-             unique_sentences.append(s)
-             seen.add(s)
-
-     text = ". ".join(unique_sentences).strip()
-
-     return text if text else "No missed diagnoses identified."
-
- def summarize_findings(combined_response: str) -> str:
-     """Enhanced findings summarization for a single, concise paragraph"""
-     if not combined_response:
-         return "No missed diagnoses were identified in the provided records."
-
-     # Pre-compiled regex patterns
-     diagnosis_pattern = re.compile(r"-\s*(.+)$")
-     section_pattern = re.compile(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)")
-     no_issues_pattern = re.compile(r"No issues identified|No missed diagnoses identified", re.IGNORECASE)
-
    diagnoses = []
-     current_section = None
-
-     for line in combined_response.splitlines():
        line = line.strip()
        if not line:
            continue
-
-         # Check section headers
-         section_match = section_pattern.match(line)
-         if section_match:
-             current_section = "diagnoses" if section_match.group(1) == "Missed Diagnoses" else None
            continue
-
-         # Process diagnosis lines in the correct section
-         if current_section == "diagnoses":
-             diagnosis_match = diagnosis_pattern.match(line)
-             if diagnosis_match and not no_issues_pattern.search(line):
-                 diagnosis = diagnosis_match.group(1).strip()
-                 if diagnosis:
-                     diagnoses.append(diagnosis)
-
-     # Extract findings from non-sectioned text
-     medication_pattern = re.compile(r"medications includ(?:e|ing|ed) ([^\.]+)", re.IGNORECASE)
-     evaluation_pattern = re.compile(r"psychiatric evaluation.*?mention of ([^\.]+)", re.IGNORECASE)
-
-     for line in combined_response.splitlines():
-         line = line.strip()
-         if not line or no_issues_pattern.search(line):
            continue
-
-         med_match = medication_pattern.search(line)
-         if med_match:
-             meds = med_match.group(1).strip()
-             diagnoses.append(f"use of medications ({meds}), suggesting an undiagnosed psychiatric condition requiring urgent review")
-
-         eval_match = evaluation_pattern.search(line)
-         if eval_match:
-             details = eval_match.group(1).strip()
-             diagnoses.append(f"psychiatric evaluation noting {details}, indicating a potential missed psychiatric diagnosis requiring urgent review")
-
-     if not diagnoses:
-         return "No missed diagnoses were identified in the provided records."
-
    # Remove duplicates while preserving order
    seen = set()
    unique_diagnoses = [d for d in diagnoses if not (d in seen or seen.add(d))]

-     # Create a single paragraph
-     summary = "The patient record indicates missed diagnoses including "
-     summary += ", ".join(unique_diagnoses[:-1])
-     summary += f", and {unique_diagnoses[-1]}" if len(unique_diagnoses) > 1 else unique_diagnoses[0]
-     summary += ". These findings suggest potential oversights in the patient's medical evaluation and require urgent clinical review to prevent adverse outcomes."

-     return summary

- @lru_cache(maxsize=1)
def init_agent():
-     """Cached agent initialization with memory optimization"""
    logger.info("Initializing model...")
    log_system_usage("Before Load")
-
-     # Tool setup
    default_tool_path = os.path.abspath("data/new_tool.json")
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(default_tool_path, target_tool_path)

-     # Initialize with optimized settings
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
@@ -360,192 +228,122 @@ def init_agent():
        additional_default_tools=[],
    )
    agent.init_model()
-
    log_system_usage("After Load")
    logger.info("Agent Ready")
    return agent

def create_ui(agent):
-     """Optimized UI creation with pre-compiled templates"""
-     PROMPT_TEMPLATE = """
- Analyze the patient record excerpt for missed diagnoses, focusing ONLY on clinical findings such as symptoms, medications, or evaluation results provided in the excerpt. Provide a concise, evidence-based summary in ONE paragraph without headings, bullet points, or repeating non-clinical data (e.g., name, date of birth, allergies). Include specific findings (e.g., 'elevated blood pressure (160/95)'), their implications (e.g., 'may indicate untreated hypertension'), and recommend urgent review. Treat medications or psychiatric evaluations as potential missed diagnoses. Do NOT use external tools, retrieve additional data, or summarize non-clinical information. If no clinical findings are present, state 'No missed diagnoses identified' in ONE sentence. Ignore other oversight categories (e.g., medication conflicts).
Patient Record Excerpt (Chunk {0} of {1}):
{chunk}
"""

-     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-         gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
-
-         with gr.Row():
-             with gr.Column(scale=3):
-                 chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
-                 msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
-                 send_btn = gr.Button("Analyze", variant="primary")
-                 file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
-
-             with gr.Column(scale=1):
-                 final_summary = gr.Markdown(label="Summary of Missed Diagnoses")
-                 download_output = gr.File(label="Download Full Report")
-                 progress_bar = gr.Progress()
-
        def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
-             """Optimized analysis pipeline with memory management"""
            history.append({"role": "user", "content": message})
            yield history, None, ""

-             # Process files with caching
-             extracted = []
            file_hash_value = ""
-
            if files:
-                 for f in files:
-                     file_type = f.name.split(".")[-1].lower()
-                     cache_key = f"{file_hash(f.name)}_{file_type}"
-
-                     if cache_key in cache:
-                         extracted.extend(cache[cache_key])
-                     else:
-                         result = process_file_cached(f.name, file_type)
-                         cache[key] = result
-                         extracted.extend(result)
-
-                 file_hash_value = file_hash(files[0].name) if files else ""
-                 history.append({"role": "assistant", "content": "✅ File processing complete"})
-                 yield history, None, ""

-             # Convert to text with memory efficiency
-             text_content = "\n".join(json.dumps(item, ensure_ascii=False) for item in extracted)
-             del extracted
-             gc.collect()

-             # Tokenize and chunk
-             chunks = tokenize_and_chunk(text_content)
-             del text_content
-             gc.collect()
-
            combined_response = ""
-             report_path = None
-             seen_responses = set()  # Track unique responses to avoid repetition
-
            try:
-                 # Process in optimized batches
-                 for batch_idx in range(0, len(chunks), BATCH_SIZE):
-                     batch_chunks = chunks[batch_idx:batch_idx + BATCH_SIZE]
-
-                     # Prepare prompts
-                     batch_prompts = [
-                         PROMPT_TEMPLATE.format(
-                             batch_idx + i + 1,
-                             len(chunks),
-                             chunk=chunk[:1800]  # Conservative size
-                         )
-                         for i, chunk in enumerate(batch_chunks)
-                     ]
-
-                     progress(batch_idx / len(chunks),
-                              desc=f"Analyzing batch {(batch_idx // BATCH_SIZE) + 1}/{(len(chunks) + BATCH_SIZE - 1) // BATCH_SIZE}")
-
-                     # Process batch
-                     with ThreadPoolExecutor(max_workers=min(BATCH_SIZE, MAX_WORKERS)) as executor:
-                         futures = {
-                             executor.submit(
-                                 agent.run_gradio_chat,
-                                 prompt, [], 0.2, 512, 2048, False, []
-                             ): idx
-                             for idx, prompt in enumerate(batch_prompts)
-                         }
-
                        for future in as_completed(futures):
-                             chunk_idx = futures[future]
                            chunk_response = ""
-
-                             try:
-                                 for chunk_output in future.result():
-                                     if isinstance(chunk_output, (list, str)):
-                                         content = ""
-                                         if isinstance(chunk_output, list):
-                                             content = " ".join(
-                                                 clean_response(m.content)
-                                                 for m in chunk_output
-                                                 if hasattr(m, 'content') and m.content
-                                             )
-                                         elif isinstance(chunk_output, str):
-                                             content = clean_response(chunk_output)
-
-                                         if content and content != "No missed diagnoses identified.":
-                                             # Check for near-duplicate responses
-                                             is_unique = True
-                                             for seen_response in seen_responses:
-                                                 if SequenceMatcher(None, content.lower(), seen_response.lower()).ratio() > 0.9:
-                                                     is_unique = False
-                                                     break
-                                             if is_unique:
-                                                 chunk_response += content + " "
-                                                 seen_responses.add(content)
-
-                                 if chunk_response:
-                                     combined_response += f"--- Analysis for Chunk {batch_idx + chunk_idx + 1} ---\n{chunk_response.strip()}\n"
-                                     history[-1] = {"role": "assistant", "content": combined_response.strip()}
-                                     yield history, None, ""
-                             finally:
-                                 del future
-                                 torch.cuda.empty_cache()
-                                 gc.collect()
-
-                 # Generate final outputs
                summary = summarize_findings(combined_response)
-
-                 if file_hash_value:
-                     report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt")
-                     try:
-                         with open(report_path, "w", encoding="utf-8") as f:
-                             f.write(combined_response + "\n\n" + summary)
-                     except Exception as e:
-                         logger.error(f"Report save failed: {e}")
-                         report_path = None
-
-                 yield history, report_path, summary

            except Exception as e:
-                 logger.error(f"Analysis error: {e}")
                history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
                yield history, None, f"Error occurred during analysis: {str(e)}"
-             finally:
-                 torch.cuda.empty_cache()
-                 gc.collect()
-
-         # Event handlers
-         send_btn.click(
-             analyze,
-             inputs=[msg_input, gr.State([]), file_upload],
-             outputs=[chatbot, download_output, final_summary]
-         )
-         msg_input.submit(
-             analyze,
-             inputs=[msg_input, gr.State([]), file_upload],
-             outputs=[chatbot, download_output, final_summary]
-         )
-
    return demo

if __name__ == "__main__":
    try:
-         logger.info("Launching optimized app...")
        agent = init_agent()
        demo = create_ui(agent)
-         demo.queue(
-             api_open=False,
-             max_size=20
-         ).launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            allowed_paths=[report_dir],
            share=False
        )
-     except Exception as e:
-         logger.error(f"Fatal error: {e}")
-         raise
    finally:
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
 
import pdfplumber
import json
import gradio as gr
+ from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
import shutil
import gc
from diskcache import Cache
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

+ # Persistent directory
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

 
for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
    os.makedirs(directory, exist_ok=True)

+ os.environ["HF_HOME"] = model_cache_dir
+ os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
+ os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.abspath(os.path.join(current_dir, "src"))
 
# Initialize cache with 10GB limit
cache = Cache(file_cache_dir, size_limit=10 * 1024**3)

def sanitize_utf8(text: str) -> str:
    return text.encode("utf-8", "ignore").decode("utf-8")

def file_hash(path: str) -> str:
    with open(path, "rb") as f:
+         return hashlib.md5(f.read()).hexdigest()

def extract_all_pages(file_path: str, progress_callback=None) -> str:
    try:
        with pdfplumber.open(file_path) as pdf:
            total_pages = len(pdf.pages)
            if total_pages == 0:
                return ""

+         batch_size = 10
+         batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
+         text_chunks = [""] * total_pages
+         processed_pages = 0
+
+         def extract_batch(start: int, end: int) -> List[tuple]:
+             results = []
            with pdfplumber.open(file_path) as pdf:
+                 for page in pdf.pages[start:end]:
+                     page_num = start + pdf.pages.index(page)
+                     page_text = page.extract_text() or ""
+                     results.append((page_num, f"=== Page {page_num + 1} ===\n{page_text.strip()}"))
+             return results
+
+         with ThreadPoolExecutor(max_workers=6) as executor:
+             futures = [executor.submit(extract_batch, start, end) for start, end in batches]
+             for future in as_completed(futures):
+                 for page_num, text in future.result():
+                     text_chunks[page_num] = text
+                 processed_pages += batch_size
+                 if progress_callback:
+                     progress_callback(min(processed_pages, total_pages), total_pages)
+
+         return "\n\n".join(filter(None, text_chunks))
    except Exception as e:
+         logger.error("PDF processing error: %s", e)
        return f"PDF processing error: {str(e)}"

+ def convert_file_to_json(file_path: str, file_type: str, progress_callback=None) -> str:
    try:
+         file_h = file_hash(file_path)
+         cache_key = f"{file_h}_{file_type}"
+         if cache_key in cache:
+             return cache[cache_key]

        if file_type == "pdf":
+             text = extract_all_pages(file_path, progress_callback)
+             result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
        elif file_type == "csv":
+             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
+                              skip_blank_lines=False, on_bad_lines="skip")
+             content = df.fillna("").astype(str).values.tolist()
+             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+         elif file_type in ["xls", "xlsx"]:
+             try:
+                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
+             except Exception:
+                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
+             content = df.fillna("").astype(str).values.tolist()
+             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
        else:
+             result = json.dumps({"error": f"Unsupported file type: {file_type}"})
+
+         cache[cache_key] = result
+         return result
    except Exception as e:
+         logger.error("Error processing %s: %s", os.path.basename(file_path), e)
+         return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})

def log_system_usage(tag=""):
    try:
+         cpu = psutil.cpu_percent(interval=1)
        mem = psutil.virtual_memory()
+         logger.info("[%s] CPU: %.1f%% | RAM: %dMB / %dMB", tag, cpu, mem.used // (1024**2), mem.total // (1024**2))
+         result = subprocess.run(
+             ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
+             capture_output=True, text=True
+         )
+         if result.returncode == 0:
+             used, total, util = result.stdout.strip().split(", ")
+             logger.info("[%s] GPU: %sMB / %sMB | Utilization: %s%%", tag, used, total, util)
    except Exception as e:
+         logger.error("[%s] GPU/CPU monitor failed: %s", tag, e)

def clean_response(text: str) -> str:
+     text = sanitize_utf8(text)
+     # Remove unwanted patterns and tool call artifacts
+     text = re.sub(r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.", "", text, flags=re.DOTALL)
+     # Extract only missed diagnoses, ignoring other categories
    diagnoses = []
+     lines = text.splitlines()
+     in_diagnoses_section = False
+     for line in lines:
        line = line.strip()
        if not line:
            continue
+         if re.match(r"###\s*Missed Diagnoses", line):
+             in_diagnoses_section = True
            continue
+         if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
+             in_diagnoses_section = False
            continue
+         if in_diagnoses_section and re.match(r"-\s*.+", line):
+             diagnosis = re.sub(r"^\-\s*", "", line).strip()
+             if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
+                 diagnoses.append(diagnosis)
+     # Join diagnoses into a plain text paragraph
+     text = " ".join(diagnoses)
+     # Clean up extra whitespace and punctuation
+     text = re.sub(r"\s+", " ", text).strip()
+     text = re.sub(r"[^\w\s\.\,\(\)\-]", "", text)
+     return text if text else ""
+
+ def summarize_findings(combined_response: str) -> str:
+     # Split response by chunk analyses
+     chunks = combined_response.split("--- Analysis for Chunk")
+     diagnoses = []
+     for chunk in chunks:
+         chunk = chunk.strip()
+         if not chunk or "No oversights identified" in chunk:
+             continue
+         # Extract missed diagnoses from chunk
+         lines = chunk.splitlines()
+         in_diagnoses_section = False
+         for line in lines:
+             line = line.strip()
+             if not line:
+                 continue
+             if re.match(r"###\s*Missed Diagnoses", line):
+                 in_diagnoses_section = True
+                 continue
+             if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
+                 in_diagnoses_section = False
+                 continue
+             if in_diagnoses_section and re.match(r"-\s*.+", line):
+                 diagnosis = re.sub(r"^\-\s*", "", line).strip()
+                 if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
+                     diagnoses.append(diagnosis)
+
    # Remove duplicates while preserving order
    seen = set()
    unique_diagnoses = [d for d in diagnoses if not (d in seen or seen.add(d))]

+     if not unique_diagnoses:
+         return "No missed diagnoses were identified in the provided records."
+
+     # Combine into a single paragraph
+     summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
+     if len(unique_diagnoses) > 1:
+         summary += f", and {unique_diagnoses[-1]}"
+     elif len(unique_diagnoses) == 1:
+         summary = "Missed diagnoses include " + unique_diagnoses[0]
+     summary += ", all of which require urgent clinical review to prevent potential adverse outcomes."

+     return summary.strip()

def init_agent():
    logger.info("Initializing model...")
    log_system_usage("Before Load")
    default_tool_path = os.path.abspath("data/new_tool.json")
    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(target_tool_path):
        shutil.copy(default_tool_path, target_tool_path)

    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",

        additional_default_tools=[],
    )
    agent.init_model()
    log_system_usage("After Load")
    logger.info("Agent Ready")
    return agent

def create_ui(agent):
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
+         chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
+         final_summary = gr.Markdown(label="Summary of Missed Diagnoses")
+         file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
+         msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
+         send_btn = gr.Button("Analyze", variant="primary")
+         download_output = gr.File(label="Download Full Report")
+         progress_bar = gr.Progress()
+
+         prompt_template = """
+ Analyze the patient record excerpt for missed diagnoses only. Provide a concise, evidence-based summary as a single paragraph without headings or bullet points. Include specific clinical findings (e.g., 'elevated blood pressure (160/95) on page 10'), their potential implications (e.g., 'may indicate untreated hypertension'), and a recommendation for urgent review. Do not include other oversight categories like medication conflicts. If no missed diagnoses are found, state 'No missed diagnoses identified' in a single sentence.
Patient Record Excerpt (Chunk {0} of {1}):
{chunk}
"""

        def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
            history.append({"role": "user", "content": message})
            yield history, None, ""

+             extracted = ""
            file_hash_value = ""
            if files:
+                 def update_extraction_progress(current, total):
+                     progress(current / total, desc=f"Extracting text... Page {current}/{total}")
+                     return history, None, ""
+
+                 with ThreadPoolExecutor(max_workers=6) as executor:
+                     futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower(), update_extraction_progress) for f in files]
+                     results = [sanitize_utf8(f.result()) for f in as_completed(futures)]
+                     extracted = "\n".join(results)
+                 file_hash_value = file_hash(files[0].name) if files else ""

+             history.append({"role": "assistant", "content": "✅ Text extraction complete."})
+             yield history, None, ""

+             chunk_size = 6000
+             chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
            combined_response = ""
+             batch_size = 2
+
            try:
+                 for batch_idx in range(0, len(chunks), batch_size):
+                     batch_chunks = chunks[batch_idx:batch_idx + batch_size]
+                     batch_prompts = [prompt_template.format(i + 1, len(chunks), chunk=chunk[:4000]) for i, chunk in enumerate(batch_chunks)]
+                     batch_responses = []
+
+                     progress((batch_idx + 1) / len(chunks), desc=f"Analyzing chunks {batch_idx + 1}-{min(batch_idx + batch_size, len(chunks))}/{len(chunks)}")
+
+                     with ThreadPoolExecutor(max_workers=len(batch_chunks)) as executor:
+                         futures = [executor.submit(agent.run_gradio_chat, prompt, [], 0.2, 512, 2048, False, []) for prompt in batch_prompts]
                        for future in as_completed(futures):
                            chunk_response = ""
+                             for chunk_output in future.result():
+                                 if chunk_output is None:
+                                     continue
+                                 if isinstance(chunk_output, list):
+                                     for m in chunk_output:
+                                         if hasattr(m, 'content') and m.content:
+                                             cleaned = clean_response(m.content)
+                                             if cleaned:
+                                                 chunk_response += cleaned + " "
+                                 elif isinstance(chunk_output, str) and chunk_output.strip():
+                                     cleaned = clean_response(chunk_output)
+                                     if cleaned:
+                                         chunk_response += cleaned + " "
+                             batch_responses.append(chunk_response.strip())
+                             torch.cuda.empty_cache()
+                             gc.collect()
+
+                     for chunk_idx, chunk_response in enumerate(batch_responses, batch_idx + 1):
+                         if chunk_response:
+                             combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
+                         else:
+                             combined_response += f"--- Analysis for Chunk {chunk_idx} ---\nNo missed diagnoses identified.\n"
+                         history[-1] = {"role": "assistant", "content": combined_response.strip()}
+                         yield history, None, ""
+
+                 if combined_response.strip() and not all("No missed diagnoses identified" in chunk for chunk in combined_response.split("--- Analysis for Chunk")):
+                     history[-1]["content"] = combined_response.strip()
+                 else:
+                     history.append({"role": "assistant", "content": "No missed diagnoses identified in the provided records."})
+
                summary = summarize_findings(combined_response)
+                 report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
+                 if report_path:
+                     with open(report_path, "w", encoding="utf-8") as f:
+                         f.write(combined_response + "\n\n" + summary)
+                 yield history, report_path if report_path and os.path.exists(report_path) else None, summary

            except Exception as e:
+                 logger.error("Analysis error: %s", e)
                history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
                yield history, None, f"Error occurred during analysis: {str(e)}"
+
+         send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
+         msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])

    return demo

if __name__ == "__main__":
    try:
+         logger.info("Launching app...")
        agent = init_agent()
        demo = create_ui(agent)
+         demo.queue(api_open=False).launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            allowed_paths=[report_dir],
            share=False
        )
    finally:
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()