Ali2206 committed (verified)
Commit be8f191 · Parent: 92c6be9

Update app.py

Files changed (1):
  1. app.py (+96, -35)
app.py CHANGED
@@ -1,10 +1,10 @@
 import sys
 import os
-import polars as pl
+import pandas as pd
 import pdfplumber
 import json
 import gradio as gr
-from typing import List
+from typing import List, Tuple, Optional
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import hashlib
 import shutil
@@ -14,10 +14,12 @@ import subprocess
 import logging
 import torch
 import gc
-from cachetools import LFUCache
+from diskcache import Cache
 import time
-import asyncio
-import aiofiles
+import pyarrow as pa
+import pyarrow.parquet as pq
+import pyarrow.csv as pc
+import numpy as np
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -49,7 +51,7 @@ sys.path.insert(0, src_path)
 from txagent.txagent import TxAgent
 
 # Initialize cache with 10GB limit
-cache = LFUCache(maxsize=1000) # Adjust maxsize based on memory constraints
+cache = Cache(file_cache_dir, size_limit=10 * 1024**3)
 
 def sanitize_utf8(text: str) -> str:
     return text.encode("utf-8", "ignore").decode("utf-8")
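
Note: unlike the in-memory LFUCache it replaces, diskcache.Cache persists entries on disk and evicts them once size_limit is exceeded, so cached extractions can survive restarts (file_cache_dir is presumably defined in an unchanged part of app.py). A minimal sketch of the same dict-style usage, with a hypothetical cache directory:

    from diskcache import Cache

    cache = Cache("cache_dir", size_limit=10 * 1024**3)  # hypothetical path, 10 GB cap

    key = "example_report_pdf"            # hypothetical key
    if key not in cache:
        cache[key] = "expensive result"   # written through to disk; evicted once over the size cap
    print(cache[key])
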
@@ -93,33 +95,79 @@ def extract_all_pages(file_path: str, progress_callback=None) -> str:
         logger.error("PDF processing error: %s", e)
         return f"PDF processing error: {str(e)}"
 
-async def convert_file_to_json(file_path: str, file_type: str, progress_callback=None) -> str:
+def excel_to_arrow(file_path: str) -> pa.Table:
+    """Convert Excel file to Arrow table for faster processing"""
     try:
-        cache_key = f"{os.path.basename(file_path)}_{file_type}"
+        # First try with openpyxl (faster for xlsx)
+        try:
+            df = pd.read_excel(file_path, engine='openpyxl', header=None, dtype=str)
+        except Exception:
+            # Fall back to xlrd if needed
+            df = pd.read_excel(file_path, engine='xlrd', header=None, dtype=str)
+
+        # Convert to Arrow table
+        table = pa.Table.from_pandas(df.fillna(""))
+        return table
+    except Exception as e:
+        logger.error(f"Error converting Excel to Arrow: {e}")
+        raise
+
+def csv_to_arrow(file_path: str) -> pa.Table:
+    """Convert CSV file to Arrow table for faster processing"""
+    try:
+        read_options = pc.ReadOptions(
+            encoding='utf-8',
+            invalid_row_handler=lambda x: None,
+            column_names=[str(i) for i in range(1000)] # Generous column count
+        )
+        convert_options = pc.ConvertOptions(
+            strings_can_be_null=True,
+            quoted_strings_can_be_null=True,
+            include_columns=None
+        )
+        table = pc.read_csv(
+            file_path,
+            read_options=read_options,
+            convert_options=convert_options
+        )
+        return table
+    except Exception as e:
+        logger.error(f"Error converting CSV to Arrow: {e}")
+        raise
+
+def convert_file_to_json(file_path: str, file_type: str, progress_callback=None) -> str:
+    try:
+        file_h = file_hash(file_path)
+        cache_key = f"{file_h}_{file_type}"
         if cache_key in cache:
             return cache[cache_key]
 
         if file_type == "pdf":
            text = extract_all_pages(file_path, progress_callback)
            result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
-        elif file_type == "csv":
-            df = pl.read_csv(file_path, encoding="utf8-lossy", has_header=False, infer_schema_length=0)
-            content = df.fill_null("").to_dicts()
-            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
-        elif file_type in ["xls", "xlsx"]:
-            async def stream_excel_to_json():
-                df = pl.read_excel(file_path, read_csv_options={"infer_schema_length": 0})
-                chunk_size = 1000
-                rows = []
-                for i in range(0, len(df), chunk_size):
-                    chunk = df[i:i + chunk_size].fill_null("").to_dicts()
-                    rows.extend(chunk)
-                    if progress_callback:
-                        progress_callback(min(i + chunk_size, len(df)), len(df))
-                    await asyncio.sleep(0) # Yield control to event loop
-                return json.dumps({"filename": os.path.basename(file_path), "rows": rows})
-
-            result = await stream_excel_to_json()
+        elif file_type in ["csv", "xls", "xlsx"]:
+            # Use Arrow for tabular data processing
+            start_time = time.time()
+
+            if file_type == "csv":
+                table = csv_to_arrow(file_path)
+            else: # Excel files
+                table = excel_to_arrow(file_path)
+
+            # Convert to list of lists efficiently
+            content = []
+            for col in table.columns:
+                content.append([str(x) if x is not None else "" for x in col.to_pylist()])
+
+            # Transpose to get rows
+            rows = list(map(list, zip(*content)))
+
+            logger.info(f"Processed {len(rows)} rows in {time.time()-start_time:.2f}s")
+            result = json.dumps({
+                "filename": os.path.basename(file_path),
+                "rows": rows,
+                "arrow_processed": True # Flag for optimized processing
+            })
         else:
             result = json.dumps({"error": f"Unsupported file type: {file_type}"})
 
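
Note: in released pyarrow versions, invalid_row_handler is a ParseOptions argument (the handler is expected to return "skip" or "error"), not a ReadOptions argument, so csv_to_arrow as committed would likely raise a TypeError; pre-allocating 1000 column_names also pins the expected width. The openpyxl-to-xlrd fallback in excel_to_arrow only helps for legacy .xls inputs, since xlrd 2.x no longer reads .xlsx. A sketch of the presumed intent, with the CSV options on their usual classes (an assumption, not the committed code):

    import pyarrow as pa
    import pyarrow.csv as pc

    def csv_to_arrow_sketch(file_path: str) -> pa.Table:
        # Headerless read: let pyarrow auto-generate column names instead of pre-allocating them.
        read_options = pc.ReadOptions(autogenerate_column_names=True, encoding="utf-8")
        # Skip malformed rows instead of aborting the whole read.
        parse_options = pc.ParseOptions(invalid_row_handler=lambda row: "skip")
        convert_options = pc.ConvertOptions(strings_can_be_null=True, quoted_strings_can_be_null=True)
        return pc.read_csv(file_path, read_options=read_options,
                           parse_options=parse_options, convert_options=convert_options)
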
@@ -146,7 +194,9 @@ def log_system_usage(tag=""):
 
 def clean_response(text: str) -> str:
     text = sanitize_utf8(text)
+    # Remove unwanted patterns and tool call artifacts
     text = re.sub(r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.", "", text, flags=re.DOTALL)
+    # Extract only missed diagnoses, ignoring other categories
     diagnoses = []
     lines = text.splitlines()
     in_diagnoses_section = False
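
Note: the boilerplate-stripping pattern in clean_response is long enough that naming and compiling it once at module level mainly helps readability (re caches compiled patterns, so the speed difference is minor). A sketch reusing the same alternation:

    import re

    # Hypothetical module-level constant; same alternation as used in clean_response.
    BOILERPLATE_RE = re.compile(
        r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\."
        r"|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.",
        flags=re.DOTALL,
    )

    def strip_boilerplate(text: str) -> str:
        return BOILERPLATE_RE.sub("", text)
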
@@ -164,18 +214,22 @@ def clean_response(text: str) -> str:
             diagnosis = re.sub(r"^\-\s*", "", line).strip()
             if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
                 diagnoses.append(diagnosis)
+    # Join diagnoses into a plain text paragraph
     text = " ".join(diagnoses)
+    # Clean up extra whitespace and punctuation
     text = re.sub(r"\s+", " ", text).strip()
     text = re.sub(r"[^\w\s\.\,\(\)\-]", "", text)
     return text if text else ""
 
 def summarize_findings(combined_response: str) -> str:
+    # Split response by chunk analyses
     chunks = combined_response.split("--- Analysis for Chunk")
     diagnoses = []
     for chunk in chunks:
         chunk = chunk.strip()
         if not chunk or "No oversights identified" in chunk:
             continue
+        # Extract missed diagnoses from chunk
         lines = chunk.splitlines()
         in_diagnoses_section = False
         for line in lines:
@@ -192,16 +246,22 @@ def summarize_findings(combined_response: str) -> str:
                 diagnosis = re.sub(r"^\-\s*", "", line).strip()
                 if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
                     diagnoses.append(diagnosis)
+
+    # Remove duplicates while preserving order
     seen = set()
     unique_diagnoses = [d for d in diagnoses if not (d in seen or seen.add(d))]
+
     if not unique_diagnoses:
         return "No missed diagnoses were identified in the provided records."
+
+    # Combine into a single paragraph
     summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
     if len(unique_diagnoses) > 1:
         summary += f", and {unique_diagnoses[-1]}"
     elif len(unique_diagnoses) == 1:
         summary = "Missed diagnoses include " + unique_diagnoses[0]
     summary += ", all of which require urgent clinical review to prevent potential adverse outcomes."
+
     return summary.strip()
 
 def init_agent():
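
Note: the seen-set comprehension keeps the first occurrence of each diagnosis because set.add returns None; an equivalent one-liner, since dicts preserve insertion order in Python 3.7+:

    unique_diagnoses = list(dict.fromkeys(diagnoses))  # de-duplicate, keep first-seen order
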
@@ -227,7 +287,7 @@ def init_agent():
     logger.info("Agent Ready")
     return agent
 
-async def create_ui(agent):
+def create_ui(agent):
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
         gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
         chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
@@ -244,7 +304,7 @@ Patient Record Excerpt (Chunk {0} of {1}):
 {chunk}
 """
 
-        async def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
+        def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
             history.append({"role": "user", "content": message})
             yield history, None, ""
 
@@ -255,10 +315,11 @@ Patient Record Excerpt (Chunk {0} of {1}):
                 progress(current / total, desc=f"Extracting text... Page {current}/{total}")
                 return history, None, ""
 
-            tasks = [convert_file_to_json(f.name, f.name.split(".")[-1].lower(), update_extraction_progress) for f in files]
-            results = await asyncio.gather(*tasks, return_exceptions=True)
-            extracted = "\n".join([sanitize_utf8(r) for r in results if isinstance(r, str)])
-            file_hash_value = file_hash(files[0].name) if files else ""
+            with ThreadPoolExecutor(max_workers=6) as executor:
+                futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower(), update_extraction_progress) for f in files]
+                results = [sanitize_utf8(f.result()) for f in as_completed(futures)]
+                extracted = "\n".join(results)
+                file_hash_value = file_hash(files[0].name) if files else ""
 
             history.append({"role": "assistant", "content": "✅ Text extraction complete."})
             yield history, None, ""
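
Note: as_completed yields futures in completion order, so the concatenated extracted text may not follow the order of files. If per-file order matters, executor.map preserves it; a minimal sketch, assuming the app's convert_file_to_json and sanitize_utf8 are in scope:

    from concurrent.futures import ThreadPoolExecutor

    def extract_in_order(files, progress_cb):
        # executor.map returns results in input order, unlike as_completed.
        with ThreadPoolExecutor(max_workers=6) as executor:
            results = executor.map(
                lambda f: convert_file_to_json(f.name, f.name.split(".")[-1].lower(), progress_cb),
                files,
            )
            return "\n".join(sanitize_utf8(r) for r in results)
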
@@ -313,8 +374,8 @@ Patient Record Excerpt (Chunk {0} of {1}):
                 summary = summarize_findings(combined_response)
                 report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
                 if report_path:
-                    async with aiofiles.open(report_path, "w", encoding="utf-8") as f:
-                        await f.write(combined_response + "\n\n" + summary)
+                    with open(report_path, "w", encoding="utf-8") as f:
+                        f.write(combined_response + "\n\n" + summary)
                 yield history, report_path if report_path and os.path.exists(report_path) else None, summary
 
             except Exception as e:
@@ -330,7 +391,7 @@ if __name__ == "__main__":
     try:
         logger.info("Launching app...")
         agent = init_agent()
-        demo = asyncio.run(create_ui(agent))
+        demo = create_ui(agent)
         demo.queue(api_open=False).launch(
             server_name="0.0.0.0",
             server_port=7860,
 