Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import sys
|
2 |
import os
|
3 |
-
import
|
4 |
import pdfplumber
|
5 |
import json
|
6 |
import gradio as gr
|
@@ -14,8 +14,10 @@ import subprocess
|
|
14 |
import logging
|
15 |
import torch
|
16 |
import gc
|
17 |
-
from
|
18 |
import time
|
|
|
|
|
19 |
|
20 |
# Configure logging
|
21 |
logging.basicConfig(level=logging.INFO)
|
@@ -47,7 +49,7 @@ sys.path.insert(0, src_path)
|
|
47 |
from txagent.txagent import TxAgent
|
48 |
|
49 |
# Initialize cache with 10GB limit
|
50 |
-
cache =
|
51 |
|
52 |
def sanitize_utf8(text: str) -> str:
|
53 |
return text.encode("utf-8", "ignore").decode("utf-8")
|
@@ -91,10 +93,10 @@ def extract_all_pages(file_path: str, progress_callback=None) -> str:
|
|
91 |
logger.error("PDF processing error: %s", e)
|
92 |
return f"PDF processing error: {str(e)}"
|
93 |
|
94 |
-
def convert_file_to_json(file_path: str, file_type: str, progress_callback=None) -> str:
|
95 |
-
try
|
96 |
-
|
97 |
-
cache_key = f"{
|
98 |
if cache_key in cache:
|
99 |
return cache[cache_key]
|
100 |
|
@@ -102,17 +104,23 @@ def convert_file_to_json(file_path: str, file_type: str, progress_callback=None)
|
|
102 |
text = extract_all_pages(file_path, progress_callback)
|
103 |
result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
|
104 |
elif file_type == "csv":
|
105 |
-
df =
|
106 |
-
|
107 |
-
content = df.fillna("").astype(str).values.tolist()
|
108 |
result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
|
109 |
elif file_type in ["xls", "xlsx"]:
|
110 |
-
|
111 |
-
df =
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
else:
|
117 |
result = json.dumps({"error": f"Unsupported file type: {file_type}"})
|
118 |
|
@@ -139,9 +147,7 @@ def log_system_usage(tag=""):
|
|
139 |
|
140 |
def clean_response(text: str) -> str:
|
141 |
text = sanitize_utf8(text)
|
142 |
-
# Remove unwanted patterns and tool call artifacts
|
143 |
text = re.sub(r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.", "", text, flags=re.DOTALL)
|
144 |
-
# Extract only missed diagnoses, ignoring other categories
|
145 |
diagnoses = []
|
146 |
lines = text.splitlines()
|
147 |
in_diagnoses_section = False
|
@@ -159,22 +165,18 @@ def clean_response(text: str) -> str:
|
|
159 |
diagnosis = re.sub(r"^\-\s*", "", line).strip()
|
160 |
if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
|
161 |
diagnoses.append(diagnosis)
|
162 |
-
# Join diagnoses into a plain text paragraph
|
163 |
text = " ".join(diagnoses)
|
164 |
-
# Clean up extra whitespace and punctuation
|
165 |
text = re.sub(r"\s+", " ", text).strip()
|
166 |
text = re.sub(r"[^\w\s\.\,\(\)\-]", "", text)
|
167 |
return text if text else ""
|
168 |
|
169 |
def summarize_findings(combined_response: str) -> str:
|
170 |
-
# Split response by chunk analyses
|
171 |
chunks = combined_response.split("--- Analysis for Chunk")
|
172 |
diagnoses = []
|
173 |
for chunk in chunks:
|
174 |
chunk = chunk.strip()
|
175 |
if not chunk or "No oversights identified" in chunk:
|
176 |
continue
|
177 |
-
# Extract missed diagnoses from chunk
|
178 |
lines = chunk.splitlines()
|
179 |
in_diagnoses_section = False
|
180 |
for line in lines:
|
@@ -191,22 +193,16 @@ def summarize_findings(combined_response: str) -> str:
|
|
191 |
diagnosis = re.sub(r"^\-\s*", "", line).strip()
|
192 |
if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
|
193 |
diagnoses.append(diagnosis)
|
194 |
-
|
195 |
-
# Remove duplicates while preserving order
|
196 |
seen = set()
|
197 |
unique_diagnoses = [d for d in diagnoses if not (d in seen or seen.add(d))]
|
198 |
-
|
199 |
if not unique_diagnoses:
|
200 |
return "No missed diagnoses were identified in the provided records."
|
201 |
-
|
202 |
-
# Combine into a single paragraph
|
203 |
summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
|
204 |
if len(unique_diagnoses) > 1:
|
205 |
summary += f", and {unique_diagnoses[-1]}"
|
206 |
elif len(unique_diagnoses) == 1:
|
207 |
summary = "Missed diagnoses include " + unique_diagnoses[0]
|
208 |
summary += ", all of which require urgent clinical review to prevent potential adverse outcomes."
|
209 |
-
|
210 |
return summary.strip()
|
211 |
|
212 |
def init_agent():
|
@@ -232,7 +228,7 @@ def init_agent():
|
|
232 |
logger.info("Agent Ready")
|
233 |
return agent
|
234 |
|
235 |
-
def create_ui(agent):
|
236 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
237 |
gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
|
238 |
chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
|
@@ -249,7 +245,7 @@ Patient Record Excerpt (Chunk {0} of {1}):
|
|
249 |
{chunk}
|
250 |
"""
|
251 |
|
252 |
-
def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
|
253 |
history.append({"role": "user", "content": message})
|
254 |
yield history, None, ""
|
255 |
|
@@ -260,11 +256,10 @@ Patient Record Excerpt (Chunk {0} of {1}):
|
|
260 |
progress(current / total, desc=f"Extracting text... Page {current}/{total}")
|
261 |
return history, None, ""
|
262 |
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
file_hash_value = file_hash(files[0].name) if files else ""
|
268 |
|
269 |
history.append({"role": "assistant", "content": "✅ Text extraction complete."})
|
270 |
yield history, None, ""
|
@@ -319,8 +314,8 @@ Patient Record Excerpt (Chunk {0} of {1}):
|
|
319 |
summary = summarize_findings(combined_response)
|
320 |
report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
|
321 |
if report_path:
|
322 |
-
with open(report_path, "w", encoding="utf-8") as f:
|
323 |
-
f.write(combined_response + "\n\n" + summary)
|
324 |
yield history, report_path if report_path and os.path.exists(report_path) else None, summary
|
325 |
|
326 |
except Exception as e:
|
@@ -336,7 +331,7 @@ if __name__ == "__main__":
|
|
336 |
try:
|
337 |
logger.info("Launching app...")
|
338 |
agent = init_agent()
|
339 |
-
demo = create_ui(agent)
|
340 |
demo.queue(api_open=False).launch(
|
341 |
server_name="0.0.0.0",
|
342 |
server_port=7860,
|
|
|
1 |
import sys
|
2 |
import os
|
3 |
+
import polars as pl
|
4 |
import pdfplumber
|
5 |
import json
|
6 |
import gradio as gr
|
|
|
14 |
import logging
|
15 |
import torch
|
16 |
import gc
|
17 |
+
from cachetools import LFUCache
|
18 |
import time
|
19 |
+
import asyncio
|
20 |
+
import aiofiles
|
21 |
|
22 |
# Configure logging
|
23 |
logging.basicConfig(level=logging.INFO)
|
|
|
49 |
from txagent.txagent import TxAgent
|
50 |
|
51 |
# Initialize in-memory LFU cache (bounded by entry count, not by size in bytes)
|
52 |
+
cache = LFUCache(maxsize=1000) # Adjust maxsize based on memory constraints
|
53 |
|
54 |
def sanitize_utf8(text: str) -> str:
|
55 |
return text.encode("utf-8", "ignore").decode("utf-8")
|
|
|
93 |
logger.error("PDF processing error: %s", e)
|
94 |
return f"PDF processing error: {str(e)}"
|
95 |
|
96 |
+
async def convert_file_to_json(file_path: str, file_type: str, progress_callback=None) -> str:
|
97 |
+
try:
|
98 |
+
|
99 |
+
cache_key = f"{os.path.basename(file_path)}_{file_type}"
|
100 |
if cache_key in cache:
|
101 |
return cache[cache_key]
|
102 |
|
|
|
104 |
text = extract_all_pages(file_path, progress_callback)
|
105 |
result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
|
106 |
elif file_type == "csv":
|
107 |
+
df = pl.read_csv(file_path, encoding="utf8-lossy", has_header=False, infer_schema_length=0)
|
108 |
+
content = df.fill_null("").to_dicts()
|
|
|
109 |
result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
|
110 |
elif file_type in ["xls", "xlsx"]:
|
111 |
+
async def stream_excel_to_json():
|
112 |
+
df = pl.read_excel(file_path, read_csv_options={"infer_schema_length": 0})
|
113 |
+
chunk_size = 1000
|
114 |
+
rows = []
|
115 |
+
for i in range(0, len(df), chunk_size):
|
116 |
+
chunk = df[i:i + chunk_size].fill_null("").to_dicts()
|
117 |
+
rows.extend(chunk)
|
118 |
+
if progress_callback:
|
119 |
+
progress_callback(min(i + chunk_size, len(df)), len(df))
|
120 |
+
await asyncio.sleep(0) # Yield control to event loop
|
121 |
+
return json.dumps({"filename": os.path.basename(file_path), "rows": rows})
|
122 |
+
|
123 |
+
result = await stream_excel_to_json()
|
124 |
else:
|
125 |
result = json.dumps({"error": f"Unsupported file type: {file_type}"})
|
126 |
|
|
|
147 |
|
148 |
def clean_response(text: str) -> str:
|
149 |
text = sanitize_utf8(text)
|
|
|
150 |
text = re.sub(r"\[.*?\]|\bNone\b|To analyze the patient record excerpt.*?medications\.|Since the previous attempts.*?\.|I need to.*?medications\.|Retrieving tools.*?\.", "", text, flags=re.DOTALL)
|
|
|
151 |
diagnoses = []
|
152 |
lines = text.splitlines()
|
153 |
in_diagnoses_section = False
|
|
|
165 |
diagnosis = re.sub(r"^\-\s*", "", line).strip()
|
166 |
if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
|
167 |
diagnoses.append(diagnosis)
|
|
|
168 |
text = " ".join(diagnoses)
|
|
|
169 |
text = re.sub(r"\s+", " ", text).strip()
|
170 |
text = re.sub(r"[^\w\s\.\,\(\)\-]", "", text)
|
171 |
return text if text else ""
|
172 |
|
173 |
def summarize_findings(combined_response: str) -> str:
|
|
|
174 |
chunks = combined_response.split("--- Analysis for Chunk")
|
175 |
diagnoses = []
|
176 |
for chunk in chunks:
|
177 |
chunk = chunk.strip()
|
178 |
if not chunk or "No oversights identified" in chunk:
|
179 |
continue
|
|
|
180 |
lines = chunk.splitlines()
|
181 |
in_diagnoses_section = False
|
182 |
for line in lines:
|
|
|
193 |
diagnosis = re.sub(r"^\-\s*", "", line).strip()
|
194 |
if diagnosis and not re.match(r"No issues identified", diagnosis, re.IGNORECASE):
|
195 |
diagnoses.append(diagnosis)
|
|
|
|
|
196 |
seen = set()
|
197 |
unique_diagnoses = [d for d in diagnoses if not (d in seen or seen.add(d))]
|
|
|
198 |
if not unique_diagnoses:
|
199 |
return "No missed diagnoses were identified in the provided records."
|
|
|
|
|
200 |
summary = "Missed diagnoses include " + ", ".join(unique_diagnoses[:-1])
|
201 |
if len(unique_diagnoses) > 1:
|
202 |
summary += f", and {unique_diagnoses[-1]}"
|
203 |
elif len(unique_diagnoses) == 1:
|
204 |
summary = "Missed diagnoses include " + unique_diagnoses[0]
|
205 |
summary += ", all of which require urgent clinical review to prevent potential adverse outcomes."
|
|
|
206 |
return summary.strip()
|
207 |
|
208 |
def init_agent():
|
|
|
228 |
logger.info("Agent Ready")
|
229 |
return agent
|
230 |
|
231 |
+
async def create_ui(agent):
|
232 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
233 |
gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
|
234 |
chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
|
|
|
245 |
{chunk}
|
246 |
"""
|
247 |
|
248 |
+
async def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
|
249 |
history.append({"role": "user", "content": message})
|
250 |
yield history, None, ""
|
251 |
|
|
|
256 |
progress(current / total, desc=f"Extracting text... Page {current}/{total}")
|
257 |
return history, None, ""
|
258 |
|
259 |
+
tasks = [convert_file_to_json(f.name, f.name.split(".")[-1].lower(), update_extraction_progress) for f in files]
|
260 |
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
261 |
+
extracted = "\n".join([sanitize_utf8(r) for r in results if isinstance(r, str)])
|
262 |
+
file_hash_value = file_hash(files[0].name) if files else ""
|
|
|
263 |
|
264 |
history.append({"role": "assistant", "content": "✅ Text extraction complete."})
|
265 |
yield history, None, ""
|
|
|
314 |
summary = summarize_findings(combined_response)
|
315 |
report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
|
316 |
if report_path:
|
317 |
+
async with aiofiles.open(report_path, "w", encoding="utf-8") as f:
|
318 |
+
await f.write(combined_response + "\n\n" + summary)
|
319 |
yield history, report_path if report_path and os.path.exists(report_path) else None, summary
|
320 |
|
321 |
except Exception as e:
|
|
|
331 |
try:
|
332 |
logger.info("Launching app...")
|
333 |
agent = init_agent()
|
334 |
+
demo = asyncio.run(create_ui(agent))
|
335 |
demo.queue(api_open=False).launch(
|
336 |
server_name="0.0.0.0",
|
337 |
server_port=7860,
|