Ali2206 committed on
Commit e12aa83 · verified · 1 Parent(s): 99e7b0d

Update app.py

Files changed (1)
  1. app.py +326 -249
app.py CHANGED
@@ -5,7 +5,7 @@ import pdfplumber
5
  import json
6
  import gradio as gr
7
  from typing import List, Dict, Optional, Generator
8
- from concurrent.futures import ProcessPoolExecutor, as_completed
9
  import hashlib
10
  import shutil
11
  import re
@@ -17,26 +17,20 @@ import gc
17
  from diskcache import Cache
18
  import time
19
  from transformers import AutoTokenizer
20
- import pyarrow as pa
21
- import pyarrow.csv as pc
22
- import pyarrow.parquet as pq
23
- from vllm import LLM, SamplingParams
24
- import asyncio
25
- import threading
26
 
27
  # Configure logging
28
  logging.basicConfig(level=logging.INFO)
29
  logger = logging.getLogger(__name__)
30
 
31
- # File handler for response logging
32
- response_log_file = os.path.join("/data/hf_cache", "response_log.txt")
33
- response_logger = logging.getLogger("ResponseLogger")
34
- response_handler = logging.FileHandler(response_log_file, mode="a")
35
- response_handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
36
- response_logger.addHandler(response_handler)
37
- response_logger.setLevel(logging.INFO)
38
 
39
- # Persistent directory
40
  persistent_dir = "/data/hf_cache"
41
  os.makedirs(persistent_dir, exist_ok=True)
42
 
@@ -49,113 +43,129 @@ vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
49
  for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
50
  os.makedirs(directory, exist_ok=True)
51
 
52
- os.environ["HF_HOME"] = model_cache_dir
53
- os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
54
- os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
55
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
56
- os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
57
-
58
- current_dir = os.path.dirname(os.path.abspath(__file__))
59
- src_path = os.path.abspath(os.path.join(current_dir, "src"))
60
- sys.path.insert(0, src_path)
61
-
62
- from txagent.txagent import TxAgent
63
 
64
  # Initialize cache with 10GB limit
65
  cache = Cache(file_cache_dir, size_limit=10 * 1024**3)
66
 
67
- # Initialize tokenizer for precise chunking
68
- tokenizer = AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
69
 
70
  def sanitize_utf8(text: str) -> str:
 
71
  return text.encode("utf-8", "ignore").decode("utf-8")
72
 
73
  def file_hash(path: str) -> str:
74
  with open(path, "rb") as f:
75
- return hashlib.md5(f.read()).hexdigest()
76
 
77
- def extract_all_pages(file_path: str, progress_callback=None) -> str:
78
- cache_key = f"pdf_{file_hash(file_path)}"
79
- if cache_key in cache:
80
- return cache[cache_key]
81
 
82
  try:
83
  with pdfplumber.open(file_path) as pdf:
84
  total_pages = len(pdf.pages)
85
  if total_pages == 0:
86
  return ""
87
 
88
- batch_size = 5
89
- batches = [(i, min(i + batch_size, total_pages)) for i in range(0, total_pages, batch_size)]
90
- text_chunks = [""] * total_pages
91
- processed_pages = 0
92
-
93
- def extract_batch(start: int, end: int) -> List[tuple]:
94
- results = []
95
  with pdfplumber.open(file_path) as pdf:
96
- for page in pdf.pages[start:end]:
97
- page_num = start + pdf.pages.index(page)
98
- page_text = page.extract_text_simple() or ""
99
- results.append((page_num, f"=== Page {page_num + 1} ===\n{page_text.strip()}"))
100
- return results
101
-
102
- with ProcessPoolExecutor(max_workers=4) as executor:
103
- futures = [executor.submit(extract_batch, start, end) for start, end in batches]
104
- for future in as_completed(futures):
105
- for page_num, text in future.result():
106
- text_chunks[page_num] = text
107
- processed_pages += batch_size
108
- if progress_callback:
109
- progress_callback(min(processed_pages, total_pages), total_pages)
110
-
111
- result = "\n\n".join(filter(None, text_chunks))
112
- cache[cache_key] = result
113
- return result
114
  except Exception as e:
115
- logger.error("PDF processing error: %s", e)
116
  return f"PDF processing error: {str(e)}"
117
 
118
  def excel_to_json(file_path: str) -> List[Dict]:
119
- cache_key = f"excel_{file_hash(file_path)}"
120
- if cache_key in cache:
121
- return cache[cache_key]
122
-
123
  try:
124
- table = pq.read_table(file_path)
125
- df = table.to_pandas(use_threads=True, split_blocks=True)
126
- content = df.where(pd.notnull(df), "").astype(str).values.tolist()
127
- result = [{
128
- "filename": os.path.basename(file_path),
129
- "rows": content,
130
- "type": "excel"
131
- }]
132
- cache[cache_key] = result
133
- return result
134
  except Exception as e:
135
- logger.error(f"Error processing Excel file: {e}")
136
- return [{"error": f"Error processing Excel file: {str(e)}"}]
137
 
138
  def csv_to_json(file_path: str) -> List[Dict]:
139
- cache_key = f"csv_{file_hash(file_path)}"
140
- if cache_key in cache:
141
- return cache[cache_key]
142
-
143
  try:
144
- table = pc.read_csv(file_path, parse_options=pc.ParseOptions(invalid_row_handler=lambda x: "skip"))
145
- df = table.to_pandas(use_threads=True, split_blocks=True)
146
- content = df.where(pd.notnull(df), "").astype(str).values.tolist()
147
- result = [{
148
  "filename": os.path.basename(file_path),
149
- "rows": content,
150
  "type": "csv"
151
  }]
152
- cache[cache_key] = result
153
- return result
154
  except Exception as e:
155
- logger.error(f"Error processing CSV file: {e}")
156
- return [{"error": f"Error processing CSV file: {str(e)}"}]
157
 
158
- def process_file(file_path: str, file_type: str) -> List[Dict]:
159
  try:
160
  if file_type == "pdf":
161
  text = extract_all_pages(file_path)
@@ -172,248 +182,315 @@ def process_file(file_path: str, file_type: str) -> List[Dict]:
172
  else:
173
  return [{"error": f"Unsupported file type: {file_type}"}]
174
  except Exception as e:
175
- logger.error("Error processing %s: %s", os.path.basename(file_path), e)
176
  return [{"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"}]
177
 
178
- def tokenize_and_chunk(text: str, max_tokens: int = 800) -> List[str]:
179
- cache_key = f"tokens_{hashlib.md5(text.encode()).hexdigest()}"
180
- if cache_key in cache:
181
- return cache[cache_key]
182
-
183
  tokens = tokenizer.encode(text, add_special_tokens=False)
184
- chunks = []
185
- for i in range(0, len(tokens), max_tokens):
186
- chunk_tokens = tokens[i:i + max_tokens]
187
- chunks.append(tokenizer.decode(chunk_tokens, skip_special_tokens=True))
188
- cache[cache_key] = chunks
189
- return chunks
190
 
191
  def log_system_usage(tag=""):
 
192
  try:
193
- cpu = psutil.cpu_percent(interval=0.1)
194
  mem = psutil.virtual_memory()
195
- logger.info("[%s] CPU: %.1f%% | RAM: %dMB / %dMB", tag, cpu, mem.used // (1024**2), mem.total // (1024**2))
196
- result = subprocess.run(
197
- ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
198
- capture_output=True, text=True
199
- )
200
- if result.returncode == 0:
201
- used, total, util = result.stdout.strip().split(", ")
202
- logger.info("[%s] GPU: %sMB / %sMB | Utilization: %s%%", tag, used, total, util)
203
  except Exception as e:
204
- logger.error("[%s] GPU/CPU monitor failed: %s", tag, e)
205
 
206
  def clean_response(text: str) -> str:
207
- text = sanitize_utf8(text)
208
- text = re.sub(r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.", "", text, flags=re.DOTALL)
209
  diagnoses = []
210
- lines = text.splitlines()
211
- in_diagnoses_section = False
212
- for line in lines:
213
  line = line.strip()
214
  if not line:
215
  continue
216
- if re.match(r"###\s*Missed Diagnoses", line):
217
- in_diagnoses_section = True
218
- continue
219
- if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
220
- in_diagnoses_section = False
221
- continue
222
- if in_diagnoses_section and re.match(r"-\s*.+", line):
223
- diagnosis = re.sub(r"^\-\s*", "", line).strip()
224
- if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
225
- diagnoses.append(diagnosis)
226
- text = " ".join(diagnoses)
227
- text = re.sub(r"\s+", " ", text).strip()
228
- text = re.sub(r"[^\w\s\.\,\(\)\-]", "", text)
229
- return text if text else ""
230
-
231
- def summarize_findings(combined_response: str) -> str:
232
- chunks = combined_response.split("--- Analysis for Chunk")
233
- diagnoses = []
234
- for chunk in chunks:
235
- chunk = chunk.strip()
236
- if not chunk or "No oversights identified" in chunk:
237
  continue
238
- lines = chunk.splitlines()
239
- in_diagnoses_section = False
240
- for line in lines:
241
- line = line.strip()
242
- if not line:
243
- continue
244
- if re.match(r"###\s*Missed Diagnoses", line):
245
- in_diagnoses_section = True
246
- continue
247
- if re.match(r"###\s*(Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", line):
248
- in_diagnoses_section = False
249
- continue
250
- if in_diagnoses_section and re.match(r"-\s*.+", line):
251
- diagnosis = re.sub(r"^\-\s*", "", line).strip()
252
- if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
253
  diagnoses.append(diagnosis)
254
-
255
  seen = set()
256
  unique_diagnoses = [d for d in diagnoses if not (d in seen or seen.add(d))]
257
 
258
- if not unique_diagnoses:
259
- return "No missed diagnoses were identified in the provided records."
260
-
261
  summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
262
- if len(unique_diagnoses) > 1:
263
- summary += f", and {unique_diagnoses[-1]}"
264
- elif len(unique_diagnoses) == 1:
265
- summary = "Missed diagnoses include " + unique_diagnoses[0]
266
  summary += ", all of which require urgent clinical review to prevent potential adverse outcomes."
267
 
268
- return summary.strip()
269
 
 
270
  def init_agent():
 
271
  logger.info("Initializing model...")
272
  log_system_usage("Before Load")
273
  default_tool_path = os.path.abspath("data/new_tool.json")
274
  target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
275
  if not os.path.exists(target_tool_path):
276
  shutil.copy(default_tool_path, target_tool_path)
277
 
278
- llm = LLM(
279
- model="mims-harvard/TxAgent-T1-Llama-3.1-8B",
280
- gpu_memory_utilization=0.8,
281
- max_model_len=2048,
282
- tensor_parallel_size=1,
283
- )
284
- sampling_params = SamplingParams(
285
- temperature=0.2,
286
- max_tokens=256,
287
- stop=["</s>", "[INST]"],
288
  )
289
  log_system_usage("After Load")
290
  logger.info("Agent Ready")
291
- return llm, sampling_params
292
 
293
- async def create_ui(llm, sampling_params):
294
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
295
- gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
296
- chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
297
- final_summary = gr.Markdown(label="Summary of Missed Diagnoses")
298
- file_upload = gr.File(file_types=["pdf", "csv", "xls", "xlsx"], file_count="multiple")
299
- msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
300
- send_btn = gr.Button("Analyze", variant="primary")
301
- download_output = gr.File(label="Download Full Report")
302
- progress_bar = gr.Progress()
303
-
304
- prompt_template = """
305
  Analyze the patient record excerpt for missed diagnoses only. Provide a concise, evidence-based summary as a single paragraph without headings or bullet points. Include specific clinical findings (e.g., 'elevated blood pressure (160/95) on page 10'), their potential implications (e.g., 'may indicate untreated hypertension'), and a recommendation for urgent review. Do not include other oversight categories like medication conflicts. If no missed diagnoses are found, state 'No missed diagnoses identified' in a single sentence.
306
  Patient Record Excerpt (Chunk {0} of {1}):
307
  {chunk}
308
  """
309
 
310
- def log_response_partial(text: str):
311
- response_logger.info(text)
312
 
313
- async def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
 
314
  history.append({"role": "user", "content": message})
315
  yield history, None, ""
316
 
 
317
  extracted = []
318
  file_hash_value = ""
319
 
320
  if files:
321
- with ProcessPoolExecutor(max_workers=4) as executor:
322
- futures = []
323
- for f in files:
324
- file_type = f.name.split(".")[-1].lower()
325
- futures.append(executor.submit(
326
- process_file,
327
- f.name,
328
- file_type
329
- ))
330
 
331
- for future in as_completed(futures):
332
- try:
333
- extracted.extend(future.result())
334
- except Exception as e:
335
- logger.error(f"File processing error: {e}")
336
- extracted.append({"error": f"Error processing file: {str(e)}"})
337
 
338
  file_hash_value = file_hash(files[0].name) if files else ""
339
  history.append({"role": "assistant", "content": "✅ File processing complete"})
340
  yield history, None, ""
341
 
342
- text_content = "\n".join(json.dumps(item) for item in extracted)
343
  chunks = tokenize_and_chunk(text_content)
344
  combined_response = ""
345
- batch_size = 1
346
-
347
  try:
348
- for batch_idx in range(0, len(chunks), batch_size):
349
- batch_chunks = chunks[batch_idx:batch_idx + batch_size]
350
  batch_prompts = [
351
- prompt_template.format(
352
  batch_idx + i + 1,
353
  len(chunks),
354
- chunk=chunk[:800]
355
  )
356
  for i, chunk in enumerate(batch_chunks)
357
  ]
358
 
359
- progress((batch_idx) / len(chunks),
360
- desc=f"Analyzing batch {(batch_idx // batch_size) + 1}/{(len(chunks) + batch_size - 1) // batch_size}")
361
 
362
- with torch.no_grad():
363
- for prompt in batch_prompts:
364
  chunk_response = ""
365
- current_response = ""
366
- stream = llm.generate([prompt], sampling_params, use_tqdm=False)
367
- for output in stream:
368
- for request_output in output:
369
- new_text = request_output.outputs[0].text[len(current_response):]
370
- if new_text:
371
- current_response += new_text
372
- cleaned = clean_response(current_response)
373
- if cleaned and cleaned != chunk_response:
374
- chunk_response = cleaned
375
- history[-1] = {"role": "assistant", "content": chunk_response}
376
- threading.Thread(target=log_response_partial, args=(chunk_response,)).start()
377
- yield history, None, ""
378
- await asyncio.sleep(0.01)
379
-
380
- if chunk_response:
381
- combined_response += f"--- Analysis for Chunk {batch_idx + 1} ---\n{chunk_response}\n"
382
-
383
- torch.cuda.empty_cache()
384
- gc.collect()
385
-
386
  summary = summarize_findings(combined_response)
387
- report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
388
- if report_path:
389
- with open(report_path, "w", encoding="utf-8") as f:
390
- f.write(combined_response + "\n\n" + summary)
391
- threading.Thread(target=log_response_partial, args=(summary,)).start()
392
 
393
- yield history, report_path if report_path and os.path.exists(report_path) else None, summary
394
 
395
  except Exception as e:
396
- logger.error("Analysis error: %s", e)
397
  history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
398
- threading.Thread(target=log_response_partial, args=(f"Error: {str(e)}",)).start()
399
  yield history, None, f"Error occurred during analysis: {str(e)}"
400
-
401
- send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
402
- msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
403
  return demo
404
 
405
  if __name__ == "__main__":
406
  try:
407
- logger.info("Launching app...")
408
- llm, sampling_params = init_agent()
409
- demo = asyncio.run(create_ui(llm, sampling_params))
410
- demo.queue(api_open=False).launch(
411
  server_name="0.0.0.0",
412
  server_port=7860,
413
  show_error=True,
414
  allowed_paths=[report_dir],
415
  share=False
416
  )
417
  finally:
418
  if torch.distributed.is_initialized():
419
  torch.distributed.destroy_process_group()
 
5
  import json
6
  import gradio as gr
7
  from typing import List, Dict, Optional, Generator
8
+ from concurrent.futures import ThreadPoolExecutor, as_completed
9
  import hashlib
10
  import shutil
11
  import re
 
17
  from diskcache import Cache
18
  import time
19
  from transformers import AutoTokenizer
20
+ from functools import lru_cache
21
+ import numpy as np
22
 
23
  # Configure logging
24
  logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger(__name__)
26
 
27
+ # Constants
28
+ MAX_TOKENS = 1800
29
+ BATCH_SIZE = 2
30
+ MAX_WORKERS = 4
31
+ CHUNK_SIZE = 10 # For PDF processing
32
 
33
+ # Persistent directory setup
34
  persistent_dir = "/data/hf_cache"
35
  os.makedirs(persistent_dir, exist_ok=True)
36
 
 
43
  for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
44
  os.makedirs(directory, exist_ok=True)
45
 
46
+ os.environ.update({
47
+ "HF_HOME": model_cache_dir,
48
+ "TRANSFORMERS_CACHE": model_cache_dir,
49
+ "VLLM_CACHE_DIR": vllm_cache_dir,
50
+ "TOKENIZERS_PARALLELISM": "false",
51
+ "CUDA_LAUNCH_BLOCKING": "1"
52
+ })
53
 
54
  # Initialize cache with 10GB limit
55
  cache = Cache(file_cache_dir, size_limit=10 * 1024**3)
56
 
57
+ # Initialize tokenizer for precise chunking (with caching)
58
+ @lru_cache(maxsize=1)
59
+ def get_tokenizer():
60
+ return AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
61
 
62
  def sanitize_utf8(text: str) -> str:
63
+ """Optimized UTF-8 sanitization"""
64
  return text.encode("utf-8", "ignore").decode("utf-8")
65
 
66
  def file_hash(path: str) -> str:
67
+ """Optimized file hashing with buffer reading"""
68
+ hash_md5 = hashlib.md5()
69
  with open(path, "rb") as f:
70
+ for chunk in iter(lambda: f.read(4096), b""):
71
+ hash_md5.update(chunk)
72
+ return hash_md5.hexdigest()
73
 
74
+ def extract_pdf_page(page) -> str:
75
+ """Optimized single page extraction"""
76
+ try:
77
+ text = page.extract_text() or ""
78
+ return f"=== Page {page.page_number} ===\n{text.strip()}"
79
+ except Exception as e:
80
+ logger.warning(f"Error extracting page {page.page_number}: {str(e)}")
81
+ return ""
82
 
83
+ def extract_all_pages(file_path: str, progress_callback=None) -> str:
84
+ """Optimized PDF extraction with memory management"""
85
  try:
86
  with pdfplumber.open(file_path) as pdf:
87
  total_pages = len(pdf.pages)
88
  if total_pages == 0:
89
  return ""
90
 
91
+ # Process in chunks with memory cleanup
92
+ results = []
93
+ for chunk_start in range(0, total_pages, CHUNK_SIZE):
94
+ chunk_end = min(chunk_start + CHUNK_SIZE, total_pages)
95
+
96
  with pdfplumber.open(file_path) as pdf:
97
+ with ThreadPoolExecutor(max_workers=min(CHUNK_SIZE, 4)) as executor:
98
+ futures = [executor.submit(extract_pdf_page, pdf.pages[i])
99
+ for i in range(chunk_start, chunk_end)]
100
+
101
+ for future in as_completed(futures):
102
+ results.append(future.result())
103
+
104
+ if progress_callback:
105
+ progress_callback(min(chunk_end, total_pages), total_pages)
106
+
107
+ # Explicit cleanup
108
+ del pdf
109
+ gc.collect()
110
+
111
+ return "\n\n".join(filter(None, results))
112
  except Exception as e:
113
+ logger.error(f"PDF processing error: {e}")
114
  return f"PDF processing error: {str(e)}"
115
 
116
  def excel_to_json(file_path: str) -> List[Dict]:
117
+ """Optimized Excel processing with chunking"""
118
  try:
119
+ # Try fastest engines first
120
+ for engine in ['openpyxl', 'xlrd']:
121
+ try:
122
+ df = pd.read_excel(
123
+ file_path,
124
+ engine=engine,
125
+ header=None,
126
+ dtype=str,
127
+ na_filter=False
128
+ )
129
+ return [{
130
+ "filename": os.path.basename(file_path),
131
+ "rows": df.values.tolist(),
132
+ "type": "excel"
133
+ }]
134
+ except Exception:
135
+ continue
136
+ raise Exception("No suitable Excel engine found")
137
  except Exception as e:
138
+ logger.error(f"Excel processing error: {e}")
139
+ return [{"error": f"Excel processing error: {str(e)}"}]
140
 
141
  def csv_to_json(file_path: str) -> List[Dict]:
142
+ """Optimized CSV processing with chunking"""
143
  try:
144
+ chunks = []
145
+ for chunk in pd.read_csv(
146
+ file_path,
147
+ header=None,
148
+ dtype=str,
149
+ encoding_errors='replace',
150
+ on_bad_lines='skip',
151
+ chunksize=10000,
152
+ na_filter=False
153
+ ):
154
+ chunks.append(chunk)
155
+
156
+ df = pd.concat(chunks) if chunks else pd.DataFrame()
157
+ return [{
158
  "filename": os.path.basename(file_path),
159
+ "rows": df.values.tolist(),
160
  "type": "csv"
161
  }]
162
  except Exception as e:
163
+ logger.error(f"CSV processing error: {e}")
164
+ return [{"error": f"CSV processing error: {str(e)}"}]
165
 
166
+ @lru_cache(maxsize=100)
167
+ def process_file_cached(file_path: str, file_type: str) -> List[Dict]:
168
+ """Cached file processing with memory optimization"""
169
  try:
170
  if file_type == "pdf":
171
  text = extract_all_pages(file_path)
 
182
  else:
183
  return [{"error": f"Unsupported file type: {file_type}"}]
184
  except Exception as e:
185
+ logger.error(f"Error processing {os.path.basename(file_path)}: {e}")
186
  return [{"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"}]
187
 
188
+ def tokenize_and_chunk(text: str, max_tokens: int = MAX_TOKENS) -> List[str]:
189
+ """Optimized tokenization and chunking"""
190
+ tokenizer = get_tokenizer()
191
  tokens = tokenizer.encode(text, add_special_tokens=False)
192
+ return [
193
+ tokenizer.decode(tokens[i:i + max_tokens])
194
+ for i in range(0, len(tokens), max_tokens)
195
+ ]
196
 
197
  def log_system_usage(tag=""):
198
+ """Optimized system monitoring"""
199
  try:
200
+ cpu = psutil.cpu_percent(interval=0.5)
201
  mem = psutil.virtual_memory()
202
+ logger.info(f"[{tag}] CPU: {cpu:.1f}% | RAM: {mem.used // (1024**2)}MB / {mem.total // (1024**2)}MB")
203
+
204
+ # GPU monitoring with timeout
205
+ try:
206
+ result = subprocess.run(
207
+ ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
208
+ capture_output=True,
209
+ text=True,
210
+ timeout=2
211
+ )
212
+ if result.returncode == 0:
213
+ used, total, util = result.stdout.strip().split(", ")
214
+ logger.info(f"[{tag}] GPU: {used}MB / {total}MB | Utilization: {util}%")
215
+ except subprocess.TimeoutExpired:
216
+ logger.warning(f"[{tag}] GPU monitoring timed out")
217
  except Exception as e:
218
+ logger.error(f"[{tag}] Monitor failed: {e}")
219
 
220
  def clean_response(text: str) -> str:
221
+ """Optimized response cleaning with regex compilation"""
222
+ if not text:
223
+ return ""
224
+
225
+ # Pre-compiled regex patterns
226
+ patterns = [
227
+ (re.compile(r"\[.*?\]|\bNone\b"), ""),
228
+ (re.compile(r"To analyze the patient record excerpt.*?medications\."), ""),
229
+ (re.compile(r"Since the previous attempts.*?\."), ""),
230
+ (re.compile(r"I need to.*?medications\."), ""),
231
+ (re.compile(r"Retrieving tools.*?\."), ""),
232
+ (re.compile(r"\s+"), " "),
233
+ (re.compile(r"[^\w\s\.\,\(\)\-]"), "")
234
+ ]
235
+
236
+ for pattern, repl in patterns:
237
+ text = pattern.sub(repl, text)
238
+
239
+ return text.strip()
240
+
241
+ def summarize_findings(combined_response: str) -> str:
242
+ """Optimized findings summarization"""
243
+ if not combined_response:
244
+ return "No missed diagnoses were identified in the provided records."
245
+
246
+ # Pre-compiled regex patterns
247
+ diagnosis_pattern = re.compile(r"-\s*(.+)$")
248
+ section_pattern = re.compile(r"###\s*(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)")
249
+ no_issues_pattern = re.compile(r"No issues identified", re.IGNORECASE)
250
+
251
  diagnoses = []
252
+ current_section = None
253
+
254
+ for line in combined_response.splitlines():
255
  line = line.strip()
256
  if not line:
257
  continue
258
+
259
+ # Check section headers
260
+ section_match = section_pattern.match(line)
261
+ if section_match:
262
+ current_section = "diagnoses" if section_match.group(1) == "Missed Diagnoses" else None
263
  continue
264
+
265
+ # Only process diagnosis lines in the correct section
266
+ if current_section == "diagnoses":
267
+ diagnosis_match = diagnosis_pattern.match(line)
268
+ if diagnosis_match and not no_issues_pattern.search(line):
269
+ diagnosis = diagnosis_match.group(1).strip()
270
+ if diagnosis:
271
  diagnoses.append(diagnosis)
272
+
273
+ if not diagnoses:
274
+ return "No missed diagnoses were identified in the provided records."
275
+
276
+ # Remove duplicates while preserving order
277
  seen = set()
278
  unique_diagnoses = [d for d in diagnoses if not (d in seen or seen.add(d))]
279
 
280
+ if len(unique_diagnoses) == 1:
281
+ return f"Missed diagnoses include {unique_diagnoses[0]}"
282
+
283
  summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
284
+ summary += f", and {unique_diagnoses[-1]}" if len(unique_diagnoses) > 1 else ""
285
  summary += ", all of which require urgent clinical review to prevent potential adverse outcomes."
286
 
287
+ return summary
288
 
289
+ @lru_cache(maxsize=1)
290
  def init_agent():
291
+ """Cached agent initialization with memory optimization"""
292
  logger.info("Initializing model...")
293
  log_system_usage("Before Load")
294
+
295
+ # Tool setup
296
  default_tool_path = os.path.abspath("data/new_tool.json")
297
  target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
298
  if not os.path.exists(target_tool_path):
299
  shutil.copy(default_tool_path, target_tool_path)
300
 
301
+ # Initialize with optimized settings
302
+ agent = TxAgent(
303
+ model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
304
+ rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
305
+ tool_files_dict={"new_tool": target_tool_path},
306
+ force_finish=True,
307
+ enable_checker=False,
308
+ step_rag_num=4,
309
+ seed=100,
310
+ additional_default_tools=[],
311
  )
312
+ agent.init_model()
313
+
314
  log_system_usage("After Load")
315
  logger.info("Agent Ready")
316
+ return agent
317
 
318
+ def create_ui(agent):
319
+ """Optimized UI creation with pre-compiled templates"""
320
+ PROMPT_TEMPLATE = """
321
  Analyze the patient record excerpt for missed diagnoses only. Provide a concise, evidence-based summary as a single paragraph without headings or bullet points. Include specific clinical findings (e.g., 'elevated blood pressure (160/95) on page 10'), their potential implications (e.g., 'may indicate untreated hypertension'), and a recommendation for urgent review. Do not include other oversight categories like medication conflicts. If no missed diagnoses are found, state 'No missed diagnoses identified' in a single sentence.
322
  Patient Record Excerpt (Chunk {0} of {1}):
323
  {chunk}
324
  """
325
 
326
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
327
+ gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
328
+
329
+ with gr.Row():
330
+ with gr.Column(scale=3):
331
+ chatbot = gr.Chatbot(label="Detailed Analysis", height=600)
332
+ msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
333
+ send_btn = gr.Button("Analyze", variant="primary")
334
+ file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
335
+
336
+ with gr.Column(scale=1):
337
+ final_summary = gr.Markdown(label="Summary of Missed Diagnoses")
338
+ download_output = gr.File(label="Download Full Report")
339
+ progress_bar = gr.Progress()
340
 
341
+ def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
342
+ """Optimized analysis pipeline with memory management"""
343
  history.append({"role": "user", "content": message})
344
  yield history, None, ""
345
 
346
+ # Process files with caching
347
  extracted = []
348
  file_hash_value = ""
349
 
350
  if files:
351
+ # Use cached results when possible
352
+ for f in files:
353
+ file_type = f.name.split(".")[-1].lower()
354
+ cache_key = f"{file_hash(f.name)}_{file_type}"
355
 
356
+ if cache_key in cache:
357
+ extracted.extend(cache[cache_key])
358
+ else:
359
+ result = process_file_cached(f.name, file_type)
360
+ cache[cache_key] = result
361
+ extracted.extend(result)
362
 
363
  file_hash_value = file_hash(files[0].name) if files else ""
364
  history.append({"role": "assistant", "content": "✅ File processing complete"})
365
  yield history, None, ""
366
 
367
+ # Convert to text with memory efficiency
368
+ text_content = "\n".join(json.dumps(item, ensure_ascii=False) for item in extracted)
369
+ del extracted
370
+ gc.collect()
371
+
372
+ # Tokenize and chunk
373
  chunks = tokenize_and_chunk(text_content)
374
+ del text_content
375
+ gc.collect()
376
+
377
  combined_response = ""
378
+ report_path = None
379
+
380
  try:
381
+ # Process in optimized batches
382
+ for batch_idx in range(0, len(chunks), BATCH_SIZE):
383
+ batch_chunks = chunks[batch_idx:batch_idx + BATCH_SIZE]
384
+
385
+ # Prepare prompts
386
  batch_prompts = [
387
+ PROMPT_TEMPLATE.format(
388
  batch_idx + i + 1,
389
  len(chunks),
390
+ chunk=chunk[:1800] # Conservative size
391
  )
392
  for i, chunk in enumerate(batch_chunks)
393
  ]
394
 
395
+ progress(batch_idx / len(chunks),
396
+ desc=f"Analyzing batch {(batch_idx // BATCH_SIZE) + 1}/{(len(chunks) + BATCH_SIZE - 1) // BATCH_SIZE}")
397
 
398
+ # Process batch
399
+ with ThreadPoolExecutor(max_workers=min(BATCH_SIZE, MAX_WORKERS)) as executor:
400
+ futures = {
401
+ executor.submit(
402
+ agent.run_gradio_chat,
403
+ prompt, [], 0.2, 512, 2048, False, []
404
+ ): idx
405
+ for idx, prompt in enumerate(batch_prompts)
406
+ }
407
+
408
+ for future in as_completed(futures):
409
+ chunk_idx = futures[future]
410
  chunk_response = ""
411
+
412
+ try:
413
+ for chunk_output in future.result():
414
+ if isinstance(chunk_output, (list, str)):
415
+ content = ""
416
+ if isinstance(chunk_output, list):
417
+ content = " ".join(
418
+ clean_response(m.content)
419
+ for m in chunk_output
420
+ if hasattr(m, 'content') and m.content
421
+ )
422
+ elif isinstance(chunk_output, str):
423
+ content = clean_response(chunk_output)
424
+
425
+ if content:
426
+ chunk_response += content + " "
427
+
428
+ if chunk_response:
429
+ combined_response += f"--- Analysis for Chunk {batch_idx + chunk_idx + 1} ---\n{chunk_response.strip()}\n"
430
+ history[-1] = {"role": "assistant", "content": combined_response.strip()}
431
+ yield history, None, ""
432
+ finally:
433
+ # Ensure cleanup
434
+ del future
435
+ torch.cuda.empty_cache()
436
+ gc.collect()
437
+
438
+ # Generate final outputs
439
  summary = summarize_findings(combined_response)
440
 
441
+ if file_hash_value:
442
+ report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt")
443
+ try:
444
+ with open(report_path, "w", encoding="utf-8") as f:
445
+ f.write(combined_response + "\n\n" + summary)
446
+ except Exception as e:
447
+ logger.error(f"Report save failed: {e}")
448
+ report_path = None
449
+
450
+ yield history, report_path, summary
451
 
452
  except Exception as e:
453
+ logger.error(f"Analysis error: {e}")
454
  history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
 
455
  yield history, None, f"Error occurred during analysis: {str(e)}"
456
+ finally:
457
+ # Final cleanup
458
+ torch.cuda.empty_cache()
459
+ gc.collect()
460
+
461
+ # Event handlers
462
+ send_btn.click(
463
+ analyze,
464
+ inputs=[msg_input, gr.State([]), file_upload],
465
+ outputs=[chatbot, download_output, final_summary]
466
+ )
467
+ msg_input.submit(
468
+ analyze,
469
+ inputs=[msg_input, gr.State([]), file_upload],
470
+ outputs=[chatbot, download_output, final_summary]
471
+ )
472
+
473
  return demo
474
 
475
  if __name__ == "__main__":
476
  try:
477
+ logger.info("Launching optimized app...")
478
+ agent = init_agent()
479
+ demo = create_ui(agent)
480
+ demo.queue(
481
+ api_open=False,
482
+ max_size=20,
483
+ concurrency_count=4
484
+ ).launch(
485
  server_name="0.0.0.0",
486
  server_port=7860,
487
  show_error=True,
488
  allowed_paths=[report_dir],
489
  share=False
490
  )
491
+ except Exception as e:
492
+ logger.error(f"Fatal error: {e}")
493
+ raise
494
  finally:
495
  if torch.distributed.is_initialized():
496
  torch.distributed.destroy_process_group()
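
For reference, the two helper patterns this commit switches to, buffered MD5 hashing and fixed-size token windows, can be exercised outside the app. A minimal standalone sketch, assuming only hashlib and the transformers tokenizer named in the diff; the example file name is hypothetical:

import hashlib
from transformers import AutoTokenizer

def file_hash(path: str, buf_size: int = 4096) -> str:
    # Stream the file through MD5 in fixed-size buffers instead of reading it into memory at once.
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(buf_size), b""):
            md5.update(block)
    return md5.hexdigest()

def tokenize_and_chunk(text: str, max_tokens: int = 1800) -> list:
    # Split text into windows of at most max_tokens tokens, mirroring the chunking in the updated app.py.
    tokenizer = AutoTokenizer.from_pretrained("mims-harvard/TxAgent-T1-Llama-3.1-8B")
    tokens = tokenizer.encode(text, add_special_tokens=False)
    return [tokenizer.decode(tokens[i:i + max_tokens]) for i in range(0, len(tokens), max_tokens)]

if __name__ == "__main__":
    print(file_hash("example.pdf"))  # hypothetical input file
    print(len(tokenize_and_chunk("some extracted report text")))

Loading the tokenizer inside tokenize_and_chunk keeps the sketch self-contained; the committed code instead caches it behind an lru_cache-wrapped get_tokenizer().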