Update app.py
app.py
CHANGED
@@ -1,21 +1,10 @@
-
-import os
-import pandas as pd
-import pdfplumber
-import json
-import gradio as gr
-from typing import List
+import sys, os, json, gradio as gr, pandas as pd, pdfplumber, hashlib, shutil, re, time
 from concurrent.futures import ThreadPoolExecutor, as_completed
-import hashlib
-import shutil
-import time
 from threading import Thread
-import re
-import tempfile

-# Setup
+# Setup
 current_dir = os.path.dirname(os.path.abspath(__file__))
-src_path = os.path.
+src_path = os.path.join(current_dir, "src")
 sys.path.insert(0, src_path)

 base_dir = "/data"
@@ -28,9 +17,10 @@ vllm_cache_dir = os.path.join(base_dir, "vllm_cache")
 for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
     os.makedirs(d, exist_ok=True)

+# Hugging Face & Transformers cache
 os.environ.update({
-    "TRANSFORMERS_CACHE": model_cache_dir,
     "HF_HOME": model_cache_dir,
+    "TRANSFORMERS_CACHE": model_cache_dir,
     "VLLM_CACHE_DIR": vllm_cache_dir,
     "TOKENIZERS_PARALLELISM": "false",
     "CUDA_LAUNCH_BLOCKING": "1"
@@ -38,38 +28,31 @@ os.environ.update({

 from txagent.txagent import TxAgent

-MEDICAL_KEYWORDS = {
-
-    'allergies', 'summary', 'impression', 'findings', 'recommendations'
-}
-
-def sanitize_utf8(text: str) -> str:
-    return text.encode("utf-8", "ignore").decode("utf-8")
+MEDICAL_KEYWORDS = {'diagnosis', 'assessment', 'plan', 'results', 'medications',
+                    'allergies', 'summary', 'impression', 'findings', 'recommendations'}

-def
-
-        return hashlib.md5(f.read()).hexdigest()
+def sanitize_utf8(text): return text.encode("utf-8", "ignore").decode("utf-8")
+def file_hash(path): return hashlib.md5(open(path, "rb").read()).hexdigest()

-def extract_priority_pages(file_path
+def extract_priority_pages(file_path, max_pages=20):
     try:
-        text_chunks = []
         with pdfplumber.open(file_path) as pdf:
+            pages = []
             for i, page in enumerate(pdf.pages[:3]):
-
+                pages.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
             for i, page in enumerate(pdf.pages[3:max_pages], start=4):
-
-                if any(re.search(rf'\b{kw}\b',
-
-
+                text = page.extract_text() or ""
+                if any(re.search(rf'\b{kw}\b', text.lower()) for kw in MEDICAL_KEYWORDS):
+                    pages.append(f"=== Page {i} ===\n{text.strip()}")
+        return "\n\n".join(pages)
     except Exception as e:
         return f"PDF processing error: {str(e)}"

-def convert_file_to_json(file_path
+def convert_file_to_json(file_path, file_type):
     try:
         h = file_hash(file_path)
         cache_path = os.path.join(file_cache_dir, f"{h}.json")
-        if os.path.exists(cache_path):
-            return open(cache_path, "r", encoding="utf-8").read()
+        if os.path.exists(cache_path): return open(cache_path, "r", encoding="utf-8").read()

         if file_type == "pdf":
             text = extract_priority_pages(file_path)
@@ -77,39 +60,32 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
             Thread(target=full_pdf_processing, args=(file_path, h)).start()
         elif file_type == "csv":
             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str, skip_blank_lines=False, on_bad_lines="skip")
-
-            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").astype(str).values.tolist()})
         elif file_type in ["xls", "xlsx"]:
             try:
                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
             except:
                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
-
-            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+            result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").astype(str).values.tolist()})
         else:
             return json.dumps({"error": f"Unsupported file type: {file_type}"})

-        with open(cache_path, "w", encoding="utf-8") as f:
-            f.write(result)
+        with open(cache_path, "w", encoding="utf-8") as f: f.write(result)
         return result
-
     except Exception as e:
         return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})

-def full_pdf_processing(file_path
+def full_pdf_processing(file_path, file_hash_value):
     try:
-        cache_path = os.path.join(file_cache_dir, f"{
-        if os.path.exists(cache_path):
-            return
+        cache_path = os.path.join(file_cache_dir, f"{file_hash_value}_full.json")
+        if os.path.exists(cache_path): return
         with pdfplumber.open(file_path) as pdf:
             full_text = "\n".join([f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}" for i, page in enumerate(pdf.pages)])
         result = json.dumps({"filename": os.path.basename(file_path), "content": full_text, "status": "complete"})
-        with open(cache_path, "w", encoding="utf-8") as f:
-
-        with open(os.path.join(report_dir, f"{file_hash}_report.txt"), "w", encoding="utf-8") as out:
-            out.write(full_text)
+        with open(cache_path, "w", encoding="utf-8") as f: f.write(result)
+        with open(os.path.join(report_dir, f"{file_hash_value}_report.txt"), "w", encoding="utf-8") as out: out.write(full_text)
     except Exception as e:
-        print(
+        print("PDF processing error:", e)

 def init_agent():
     default_tool_path = os.path.abspath("data/new_tool.json")
@@ -124,36 +100,37 @@ def init_agent():
         force_finish=True,
         enable_checker=True,
         step_rag_num=8,
-        seed=100
-        additional_default_tools=[],
+        seed=100
     )
     agent.init_model()
     return agent

-
+# Lazy load agent only on first use
+agent_container = {"agent": None}
+def get_agent():
+    if agent_container["agent"] is None:
+        agent_container["agent"] = init_agent()
+    return agent_container["agent"]
+
+def create_ui(get_agent_func):
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-        gr.Markdown(""
-        <h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>
-        <h3 style='text-align: center;'>Identify potential oversights in patient care</h3>
-        """)
+        gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1><h3 style='text-align: center;'>Identify potential oversights in patient care</h3>")

         chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
-        file_upload = gr.File(
-        msg_input = gr.Textbox(placeholder="Ask about potential oversights..."
+        file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
+        msg_input = gr.Textbox(placeholder="Ask about potential oversights...")
         send_btn = gr.Button("Analyze", variant="primary")
-
-        download_output = gr.File(label="Download
+        state = gr.State([])
+        download_output = gr.File(label="Download Report")

-        def
+        def analyze(message, history, conversation, files):
             try:
-                extracted_data = ""
-
-
-
-                with ThreadPoolExecutor(max_workers=4) as executor:
-                    futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files if hasattr(f, 'name')]
+                extracted_data, file_hash_value = "", ""
+                if files:
+                    with ThreadPoolExecutor(max_workers=4) as pool:
+                        futures = [pool.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
                         extracted_data = "\n".join([sanitize_utf8(f.result()) for f in as_completed(futures)])
-                file_hash_value = file_hash(files[0].name)
+                    file_hash_value = file_hash(files[0].name)

                 prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
 1. List potential missed diagnoses
@@ -165,8 +142,8 @@ Medical Records:\n{extracted_data[:15000]}

 ### Potential Oversights:\n"""

-
-                for chunk in
+                final_response = ""
+                for chunk in get_agent_func().run_gradio_chat(
                     message=prompt,
                     history=[],
                     temperature=0.2,
@@ -176,52 +153,31 @@ Medical Records:\n{extracted_data[:15000]}
                     conversation=conversation
                 ):
                     if isinstance(chunk, str):
-
+                        final_response += chunk
                     elif isinstance(chunk, list):
-
+                        final_response += "".join([c.content for c in chunk if hasattr(c, "content")])

-                cleaned =
+                cleaned = final_response.replace("[TOOL_CALLS]", "").strip()
                 if not cleaned:
-                    cleaned = "No
-
-                updated_history = history + [
-                    {"role": "user", "content": message},
-                    {"role": "assistant", "content": cleaned}
-                ]
+                    cleaned = "No oversights found. Consider further review."

-
-                if file_hash_value:
-                    possible_report = os.path.join(report_dir, f"{file_hash_value}_report.txt")
-                    if os.path.exists(possible_report):
-                        report_path = possible_report
+                updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": cleaned}]

+                report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value and os.path.exists(os.path.join(report_dir, f"{file_hash_value}_report.txt")) else None
                 yield updated_history, report_path
-
             except Exception as e:
-                updated_history = history + [{"role": "user", "content": message},
-                                              {"role": "assistant", "content": f"❌ Analysis failed: {str(e)}"}]
+                updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": f"❌ Error: {str(e)}"}]
                 yield updated_history, None

-        inputs
-
-        send_btn.click(analyze_potential_oversights, inputs=inputs, outputs=outputs)
-        msg_input.submit(analyze_potential_oversights, inputs=inputs, outputs=outputs)
-
-        gr.Examples([
-            ["What might have been missed in this patient's treatment?"],
-            ["Are there any medication conflicts in these records?"],
-            ["What abnormal results require follow-up?"]
-        ], inputs=msg_input)
+        send_btn.click(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, download_output])
+        msg_input.submit(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, download_output])

     return demo

 if __name__ == "__main__":
-    print("Initializing medical analysis agent...")
-    agent = init_agent()
-
     print("Launching interface...")
-
-
+    ui = create_ui(get_agent)
+    ui.queue(api_open=False).launch(
         server_name="0.0.0.0",
         server_port=7860,
         show_error=True,
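
For orientation, a minimal usage sketch (not part of the commit) of the cache-keyed converter and hash helper from the new app.py, exercised outside the Gradio UI. It assumes app.py is importable in the current environment (the /data cache directories exist and txagent is installed); "records.csv" is a hypothetical sample file used only for illustration.

# Rough sketch, not part of the commit. Assumes app.py imports cleanly
# (/data exists, txagent installed); "records.csv" is a hypothetical file.
import json
from app import convert_file_to_json, file_hash

path = "records.csv"
payload = json.loads(convert_file_to_json(path, path.rsplit(".", 1)[-1].lower()))

# The first call parses the file and writes <file_cache_dir>/<md5>.json;
# later calls with the same file contents return the cached JSON string.
print("cache key:", file_hash(path))
print(payload.get("filename"), "rows:", len(payload.get("rows", [])))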