Update app.py
Browse files
app.py CHANGED
@@ -1,261 +1,275 @@
 import os
 import pandas as pd
 import pdfplumber
-import multiprocessing
 import gradio as gr
-from typing import List
-from concurrent.futures import ThreadPoolExecutor
 import hashlib
-import logging
-from functools import partial
-import re
 import time
 
-# Suppress pdfplumber logging
-logging.getLogger("pdfplumber").setLevel(logging.ERROR)
-
-# Persistent directories
 persistent_dir = "/data/hf_cache"
 os.makedirs(persistent_dir, exist_ok=True)
 file_cache_dir = os.path.join(persistent_dir, "cache")
 report_dir = os.path.join(persistent_dir, "reports")
-for directory in [file_cache_dir, report_dir]:
     os.makedirs(directory, exist_ok=True)
 
 def sanitize_utf8(text: str) -> str:
-    """Sanitize text to handle UTF-8 encoding issues."""
     return text.encode("utf-8", "ignore").decode("utf-8")
 
 def file_hash(path: str) -> str:
-    """Generate MD5 hash of a file."""
     with open(path, "rb") as f:
         return hashlib.md5(f.read()).hexdigest()
 
-def extract_page_range(file_path: str, start_page: int, end_page: int) -> str:
-    """Extract text from a range of PDF pages."""
     try:
         text_chunks = []
-        with pdfplumber.open(file_path) as pdf:
-            for page in pdf.pages[start_page:end_page]:
-                page_text = page.extract_text() or ""
-                text_chunks.append(page_text.strip())
-        return "\n".join(text_chunks)
-    except Exception:
-        return ""
-
-def extract_all_pages(file_path: str, progress_callback=None) -> str:
-    """Extract text from all pages of a PDF using parallel processing."""
-    try:
         with pdfplumber.open(file_path) as pdf:
             total_pages = len(pdf.pages)
-
-
-
-
-
-
-        pages_per_process = max(1, total_pages // num_processes)
-
-        # Create page ranges for parallel processing
-        ranges = [(i * pages_per_process, min((i + 1) * pages_per_process, total_pages))
-                  for i in range(num_processes)]
-        if ranges[-1][1] != total_pages:
-            ranges[-1] = (ranges[-1][0], total_pages)
-
-        # Process page ranges in parallel
-        with multiprocessing.Pool(processes=num_processes) as pool:
-            extract_func = partial(extract_page_range, file_path)
-            results = []
-            for idx, result in enumerate(pool.starmap(extract_func, ranges)):
-                results.append(result)
                 if progress_callback:
-                    processed_pages = min((idx + 1) * pages_per_process, total_pages)
                     progress_callback(processed_pages, total_pages)
-
-        return "\n".join(results)
-    except Exception:
-        return ""
 
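Several deleted lines above (the `num_processes` computation among them) did not survive extraction, so the removed extractor is shown with gaps. A minimal sketch of the removed range-based parallel extraction, assuming `num_processes` came from `multiprocessing.cpu_count()`:

```python
import multiprocessing
from functools import partial

import pdfplumber

def extract_page_range(file_path: str, start_page: int, end_page: int) -> str:
    """Extract text from a contiguous range of PDF pages."""
    try:
        with pdfplumber.open(file_path) as pdf:
            return "\n".join((page.extract_text() or "").strip()
                             for page in pdf.pages[start_page:end_page])
    except Exception:
        return ""

def extract_all_pages(file_path: str) -> str:
    """Split the PDF into page ranges and extract them in parallel workers."""
    with pdfplumber.open(file_path) as pdf:
        total_pages = len(pdf.pages)
    if total_pages == 0:
        return ""
    num_processes = min(multiprocessing.cpu_count(), total_pages)  # assumption: this line was lost from the diff
    pages_per_process = max(1, total_pages // num_processes)
    ranges = [(i * pages_per_process, min((i + 1) * pages_per_process, total_pages))
              for i in range(num_processes)]
    if ranges[-1][1] != total_pages:
        ranges[-1] = (ranges[-1][0], total_pages)
    with multiprocessing.Pool(processes=num_processes) as pool:
        # starmap unpacks each (start, end) tuple into extract_page_range(file_path, start, end)
        results = pool.starmap(partial(extract_page_range, file_path), ranges)
    return "\n".join(results)
```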
-def convert_file_to_text(file_path: str, file_type: str, progress_callback=None) -> str:
-    """Convert supported file types to text, caching results."""
     try:
         h = file_hash(file_path)
-        cache_path = os.path.join(file_cache_dir, f"{h}.txt")
         if os.path.exists(cache_path):
             with open(cache_path, "r", encoding="utf-8") as f:
                 return f.read()
 
         if file_type == "pdf":
-            text = extract_all_pages(file_path, progress_callback)
         elif file_type == "csv":
             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
-                             skip_blank_lines=False)
-
         elif file_type in ["xls", "xlsx"]:
-
-
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    item_pattern = re.compile(r"^- .+$", re.MULTILINE)
-    current_section = None
-    for line in raw_response.splitlines():
-        line = line.strip()
-        if not line:
-            continue
-        if section_pattern.match(line):
-            current_section = line[:-1]
-        elif current_section and item_pattern.match(line):
-            sections[current_section].append(line)
-
-    return sections
-
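The body of the removed `parse_analysis_response` is only partially visible above: its `sections` dict and `section_pattern` fell in lines lost from the extraction. A sketch consistent with the surviving loop, assuming headings end in a colon and match the four section keys used elsewhere in the old file:

```python
import re

def parse_analysis_response(raw_response: str) -> dict:
    """Group '- item' bullets under the 'Heading:' line that precedes them."""
    sections = {"Missed Diagnoses": [], "Medication Conflicts": [],
                "Incomplete Assessments": [], "Urgent Follow-up": []}
    section_pattern = re.compile(r"^[A-Za-z][A-Za-z -]*:$")  # assumption: original pattern lost in the diff
    item_pattern = re.compile(r"^- .+$", re.MULTILINE)
    current_section = None
    for line in raw_response.splitlines():
        line = line.strip()
        if not line:
            continue
        if section_pattern.match(line):
            current_section = line[:-1]
        # membership guard keeps an unknown heading from raising KeyError
        elif current_section in sections and item_pattern.match(line):
            sections[current_section].append(line)
    return sections
```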
-def analyze_medical_records(extracted_text: str, progress_callback=None) -> str:
-    """Analyze medical records and return a generalized structured response."""
-    # Split text into chunks to handle large inputs
-    chunk_size = 10000
-    chunks = [extracted_text[i:i + chunk_size] for i in range(0, len(extracted_text), chunk_size)]
-
-    # Generalized analysis template (replace with model or rule-based logic)
-    raw_response_template = """
-Missed Diagnoses:
-- Chronic conditions potentially missed due to inconsistent monitoring of vital signs or symptoms. This may occur when patient visits are infrequent or records lack longitudinal tracking, leading to undetected trends. Undiagnosed conditions can progress, increasing risks of complications like organ damage. Recommended action: Implement regular screening protocols and trend analysis for key indicators (e.g., blood pressure, glucose levels).
-- Risk factors for hereditary or lifestyle-related diseases not screened despite documented family history or patient demographics. Screening oversights often stem from time constraints or lack of standardized protocols. Delayed diagnosis may lead to preventable disease progression. Recommended action: Establish routine risk assessments based on family history and clinical guidelines.
-
-Medication Conflicts:
-- Potential interactions from polypharmacy or untracked over-the-counter medications. Conflicts may arise when multiple prescribers are involved or patients self-medicate, increasing risks of adverse events like bleeding or toxicity. Recommended action: Conduct comprehensive medication reconciliation at each visit and educate patients on reporting all medications.
-
-Incomplete Assessments:
-- Symptoms reported but not fully evaluated due to incomplete documentation or failure to follow clinical guidelines. This can occur in busy clinical settings where time limits prioritize acute issues over thorough investigation. Unaddressed symptoms may mask serious conditions, delaying treatment. Recommended action: Standardize symptom evaluation protocols and ensure adequate time for comprehensive assessments.
-
-Urgent Follow-up:
-- Critical findings requiring specialist referral or additional testing delayed due to communication gaps or scheduling issues. Delays often result from fragmented care coordination or underestimation of findings' severity. Untreated critical issues can lead to rapid deterioration. Recommended action: Establish clear referral pathways and prioritize urgent findings with defined timelines.
-"""
-
-    # Aggregate findings across chunks
-    all_sections = {
-        "Missed Diagnoses": set(),
-        "Medication Conflicts": set(),
-        "Incomplete Assessments": set(),
-        "Urgent Follow-up": set()
-    }
-
-    for chunk_idx, chunk in enumerate(chunks, 1):
-        # Simulate analysis per chunk (replace with real logic)
-        raw_response = raw_response_template
-        parsed = parse_analysis_response(raw_response)
-        for section, items in parsed.items():
-            all_sections[section].update(items)
-        if progress_callback:
-            progress_callback(chunk_idx, len(chunks))
-
-    # Format generalized response
-    response = ["### Clinical Oversight Analysis\n"]
-    response.append("This analysis reviews patient records to identify common reasons for potential oversights that could impact clinical outcomes. Findings highlight systemic or procedural gaps, associated risks, and actionable recommendations applicable across various patient records.\n")
-    has_findings = False
-    for section, items in all_sections.items():
-        response.append(f"#### {section}")
-        if items:
-            response.extend(sorted(items))
-            has_findings = True
-        else:
-            response.append("- No issues identified in this category.")
-        response.append("")
-
-    response.append("### Summary")
-    if has_findings:
-        summary = ("The analysis identified common procedural and systemic gaps that may lead to oversights in diagnosis, medication management, assessments, and follow-up care. These gaps, such as inconsistent monitoring, incomplete documentation, or communication delays, pose risks of disease progression, adverse events, or delayed treatment. Recommended actions include standardizing screening and assessment protocols, improving medication reconciliation, and establishing clear referral pathways. Implementing these measures can enhance patient safety and care quality across diverse clinical scenarios.")
-    else:
-        summary = ("No significant oversights were identified in the provided records. Current practices appear aligned with general clinical standards. To maintain care quality, continue regular monitoring, ensure comprehensive documentation, and adhere to guideline-based screening and follow-up protocols.")
-    response.append(summary)
-
-    return "\n".join(response)
-
-def create_ui():
-    """Create Gradio UI for clinical oversight analysis."""
-    def analyze(message: str, history: List[dict], files: List):
-        """Handle analysis with animated progress updates."""
-        history.append({"role": "user", "content": message})
-        yield history, None
-
-        extracted_text = ""
-        file_hash_value = ""
-        if files:
-            # Progress callback for extraction
-            total_pages = 0
-            processed_pages = 0
-            def update_extraction_progress(current, total):
-                nonlocal processed_pages, total_pages
-                processed_pages = current
-                total_pages = total
-                animation = ["📄", "📖", "✂️", "📄"][(int(time.time() * 2) % 4)]
-                history[-1] = {"role": "assistant", "content": f"Extracting text... {animation} Page {processed_pages}/{total_pages}"}
-                return history, None
-
-            with ThreadPoolExecutor(max_workers=4) as executor:
-                futures = [executor.submit(convert_file_to_text, f.name, f.name.split(".")[-1].lower(), update_extraction_progress) for f in files]
-                results = [f.result() for f in futures]
-                extracted_text = "\n".join(sanitize_utf8(r) for r in results if r)
-            file_hash_value = file_hash(files[0].name) if files else ""
-
-        history.append({"role": "assistant", "content": "✅ Text extraction complete."})
-        yield history, None
-
-        report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
-
-        try:
-            # Progress callback for analysis
-            total_chunks = 0
-            processed_chunks = 0
-            def update_analysis_progress(current, total):
-                nonlocal processed_chunks, total_chunks
-                processed_chunks = current
-                total_chunks = total
-                animation = ["🔍", "📊", "🧠", "🔍"][(int(time.time() * 2) % 4)]
-                history[-1] = {"role": "assistant", "content": f"Analyzing records... {animation} Chunk {processed_chunks}/{total_chunks}"}
-                return history, None
-
-            history.append({"role": "assistant", "content": "Analyzing records... 🔍"})
-            yield history, None
-            response = analyze_medical_records(extracted_text, update_analysis_progress)
-
-            history.pop()  # Remove "Analyzing..."
-            history.append({"role": "assistant", "content": response})
-            if report_path:
-                with open(report_path, "w", encoding="utf-8") as f:
-                    f.write(response)
-            yield history, report_path if report_path and os.path.exists(report_path) else None
-        except Exception as e:
-            history.pop()  # Remove "Analyzing..."
-            history.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})
-            yield history, None
 
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
         gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
         chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
         file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
         msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
         send_btn = gr.Button("Analyze", variant="primary")
-        download_output = gr.File(label="Download Report")
 
         send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])
         msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])
@@ -263,14 +277,12 @@ def create_ui():
 
 if __name__ == "__main__":
     print("🚀 Launching app...")
-
-
-
-
-
-
-
-
-
-    except Exception as e:
-        print(f"Failed to launch app: {str(e)}")
+import sys
 import os
 import pandas as pd
 import pdfplumber
+import json
 import gradio as gr
+from typing import List
+from concurrent.futures import ThreadPoolExecutor, as_completed
 import hashlib
+import shutil
+import re
+import psutil
+import subprocess
 import time
 
+# Persistent directory
 persistent_dir = "/data/hf_cache"
 os.makedirs(persistent_dir, exist_ok=True)
+
+model_cache_dir = os.path.join(persistent_dir, "txagent_models")
+tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
 file_cache_dir = os.path.join(persistent_dir, "cache")
 report_dir = os.path.join(persistent_dir, "reports")
+vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
+
+for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
     os.makedirs(directory, exist_ok=True)
 
+os.environ["HF_HOME"] = model_cache_dir
+os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
+os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+src_path = os.path.abspath(os.path.join(current_dir, "src"))
+sys.path.insert(0, src_path)
+
+from txagent.txagent import TxAgent
+
+MEDICAL_KEYWORDS = {'diagnosis', 'assessment', 'plan', 'results', 'medications',
+                    'allergies', 'summary', 'impression', 'findings', 'recommendations'}
+
 def sanitize_utf8(text: str) -> str:
     return text.encode("utf-8", "ignore").decode("utf-8")
 
 def file_hash(path: str) -> str:
     with open(path, "rb") as f:
         return hashlib.md5(f.read()).hexdigest()
 
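`file_hash` reads each upload fully into memory before hashing. That is not a correctness problem, but for large files a chunked digest produces the same hex string with constant memory; a drop-in variant as a sketch:

```python
import hashlib

def file_hash_chunked(path: str, block_size: int = 1 << 20) -> str:
    """MD5 of a file, read in 1 MiB blocks instead of all at once."""
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(block_size), b""):
            md5.update(block)
    return md5.hexdigest()
```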
+def extract_priority_pages(file_path: str, progress_callback=None) -> str:
     try:
         text_chunks = []
         with pdfplumber.open(file_path) as pdf:
             total_pages = len(pdf.pages)
+            processed_pages = 0
+            for i, page in enumerate(pdf.pages):
+                page_text = page.extract_text() or ""
+                if i < 3 or any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
+                    text_chunks.append(f"=== Page {i+1} ===\n{page_text.strip()}")
+                processed_pages += 1
                 if progress_callback:
                     progress_callback(processed_pages, total_pages)
+        return "\n\n".join(text_chunks)
+    except Exception as e:
+        return f"PDF processing error: {str(e)}"
 
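The keyword gate above always keeps the first three pages, then keeps later pages only when a whole-word medical keyword occurs; the page text is lowercased before matching, so the check is effectively case-insensitive. A quick standalone check of that predicate:

```python
import re

MEDICAL_KEYWORDS = {'diagnosis', 'assessment', 'plan', 'results', 'medications',
                    'allergies', 'summary', 'impression', 'findings', 'recommendations'}

def page_is_priority(index: int, page_text: str) -> bool:
    return index < 3 or any(re.search(rf'\b{kw}\b', page_text.lower())
                            for kw in MEDICAL_KEYWORDS)

assert page_is_priority(0, "cover sheet")                      # first three pages always kept
assert page_is_priority(7, "Assessment and Plan: stable")      # whole-word keyword hit
assert not page_is_priority(7, "billing codes and addresses")  # no keyword, page dropped
```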
+def convert_file_to_json(file_path: str, file_type: str, progress_callback=None) -> str:
     try:
         h = file_hash(file_path)
+        cache_path = os.path.join(file_cache_dir, f"{h}.json")
         if os.path.exists(cache_path):
             with open(cache_path, "r", encoding="utf-8") as f:
                 return f.read()
 
         if file_type == "pdf":
+            text = extract_priority_pages(file_path, progress_callback)
+            result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
         elif file_type == "csv":
             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
+                             skip_blank_lines=False, on_bad_lines="skip")
+            content = df.fillna("").astype(str).values.tolist()
+            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
         elif file_type in ["xls", "xlsx"]:
+            try:
+                df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
+            except Exception:
+                df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
+            content = df.fillna("").astype(str).values.tolist()
+            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
         else:
+            result = json.dumps({"error": f"Unsupported file type: {file_type}"})
+        with open(cache_path, "w", encoding="utf-8") as f:
+            f.write(result)
+        return result
+    except Exception as e:
+        return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
+
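The cached payloads come in three shapes, keyed by file type; downstream code treats them as opaque strings that get joined into the prompt. Illustrative values (filenames hypothetical):

```python
import json

# PDF: prioritized page text plus a status marker
json.dumps({"filename": "visit_notes.pdf", "content": "=== Page 1 ===\n...", "status": "initial"})
# CSV / Excel: every cell stringified, row by row
json.dumps({"filename": "labs.csv", "rows": [["2024-01-02", "A1C", "7.1"]]})
# Anything else, or a processing failure: an error payload
json.dumps({"error": "Unsupported file type: docx"})
```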
+def log_system_usage(tag=""):
+    try:
+        cpu = psutil.cpu_percent(interval=1)
+        mem = psutil.virtual_memory()
+        print(f"[{tag}] CPU: {cpu}% | RAM: {mem.used // (1024**2)}MB / {mem.total // (1024**2)}MB")
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
+            capture_output=True, text=True
+        )
+        if result.returncode == 0:
+            used, total, util = result.stdout.strip().split(", ")
+            print(f"[{tag}] GPU: {used}MB / {total}MB | Utilization: {util}%")
+    except Exception as e:
+        print(f"[{tag}] GPU/CPU monitor failed: {e}")
 
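`nvidia-smi` with `--format=csv,nounits,noheader` prints one `used, total, util` line per GPU, so the three-way unpack above assumes a single-GPU box; on a multi-GPU machine `stdout.strip().split(", ")` would span lines and raise ValueError. Parsing only the first line is the defensive variant:

```python
# Sample single-GPU output of the nvidia-smi query used above
stdout = "3120, 24576, 41\n"
used, total, util = stdout.strip().splitlines()[0].split(", ")
print(f"GPU: {used}MB / {total}MB | Utilization: {util}%")
# -> GPU: 3120MB / 24576MB | Utilization: 41%
```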
+def clean_response(text: str) -> str:
+    text = sanitize_utf8(text)
+    text = re.sub(r"\[TOOL_CALLS\].*", "", text, flags=re.DOTALL)
+    text = re.sub(r"\n{3,}", "\n\n", text).strip()
+    return text
+
+def init_agent():
+    print("🚀 Initializing model...")
+    log_system_usage("Before Load")
+    default_tool_path = os.path.abspath("data/new_tool.json")
+    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
+    if not os.path.exists(target_tool_path):
+        shutil.copy(default_tool_path, target_tool_path)
+
+    agent = TxAgent(
+        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
+        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
+        tool_files_dict={"new_tool": target_tool_path},
+        force_finish=True,
+        enable_checker=True,
+        step_rag_num=4,
+        seed=100,
+        additional_default_tools=[],
+    )
+    agent.init_model()
+    log_system_usage("After Load")
+    print("✅ Agent Ready")
+    return agent
+
+def create_ui(agent):
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
         gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
         chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
         file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
         msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
         send_btn = gr.Button("Analyze", variant="primary")
+        download_output = gr.File(label="Download Full Report")
+
+        def analyze(message: str, history: List[dict], files: List):
+            history.append({"role": "user", "content": message})
+            history.append({"role": "assistant", "content": "⏳ Extracting text from files..."})
+            yield history, None
+
+            extracted = ""
+            file_hash_value = ""
+            if files:
+                # Progress callback for extraction
+                total_pages = 0
+                processed_pages = 0
+                def update_extraction_progress(current, total):
+                    nonlocal processed_pages, total_pages
+                    processed_pages = current
+                    total_pages = total
+                    animation = ["📄", "📖", "✂️", "📄"][(int(time.time() * 2) % 4)]
+                    history[-1] = {"role": "assistant", "content": f"Extracting text... {animation} Page {processed_pages}/{total_pages}"}
+                    return history, None
+
+                with ThreadPoolExecutor(max_workers=6) as executor:
+                    futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower(), update_extraction_progress) for f in files]
+                    results = [sanitize_utf8(f.result()) for f in as_completed(futures)]
+                    extracted = "\n".join(results)
+                file_hash_value = file_hash(files[0].name) if files else ""
+
+            history.pop()  # Remove extraction message
+            history.append({"role": "assistant", "content": "✅ Text extraction complete."})
+            yield history, None
+
+            # Split extracted text into chunks of ~6,000 characters
+            chunk_size = 6000
+            chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
+            combined_response = ""
+
+            prompt_template = """
+You are a medical analysis assistant. Analyze the following patient record excerpt for clinical oversights. Provide a concise, evidence-based summary in markdown format under these headings: Missed Diagnoses, Medication Conflicts, Incomplete Assessments, and Urgent Follow-up. For each finding, include:
+- Clinical context (why the issue was missed or relevant details from the record).
+- Potential risks if unaddressed (e.g., disease progression, adverse events).
+- Actionable recommendations (e.g., tests, referrals, medication adjustments).
+If no issues are found in a section, state "No issues identified." Ensure the output is specific to the provided text, formatted as markdown with bullet points under each heading, and avoids generic or static responses.
+
+Patient Record Excerpt (Chunk {0} of {1}):
+{chunk}
+
+### Missed Diagnoses
+- ...
+
+### Medication Conflicts
+- ...
+
+### Incomplete Assessments
+- ...
+
+### Urgent Follow-up
+- ...
+"""
+
+            try:
+                # Process each chunk and stream results in real-time
+                for chunk_idx, chunk in enumerate(chunks, 1):
+                    # Update UI with chunk progress
+                    animation = ["🔍", "📊", "🧠", "🔍"][(int(time.time() * 2) % 4)]
+                    history.append({"role": "assistant", "content": f"Analyzing records... {animation} Chunk {chunk_idx}/{len(chunks)}"})
+                    yield history, None
+
+                    prompt = prompt_template.format(chunk_idx, len(chunks), chunk=chunk[:4000])  # Truncate to avoid token limits
+                    chunk_response = ""
+                    for chunk_output in agent.run_gradio_chat(
+                        message=prompt,
+                        history=[],
+                        temperature=0.2,
+                        max_new_tokens=1024,
+                        max_token=4096,
+                        call_agent=False,
+                        conversation=[],
+                    ):
+                        if chunk_output is None:
+                            continue
+                        if isinstance(chunk_output, list):
+                            for m in chunk_output:
+                                if hasattr(m, 'content') and m.content:
+                                    cleaned = clean_response(m.content)
+                                    if cleaned:
+                                        chunk_response += cleaned + "\n"
+                                        # Update UI with partial response
+                                        if history[-1]["content"].startswith("Analyzing"):
+                                            history[-1] = {"role": "assistant", "content": f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"}
+                                        else:
+                                            history[-1]["content"] = f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"
+                                        yield history, None
+                        elif isinstance(chunk_output, str) and chunk_output.strip():
+                            cleaned = clean_response(chunk_output)
+                            if cleaned:
+                                chunk_response += cleaned + "\n"
+                                # Update UI with partial response
+                                if history[-1]["content"].startswith("Analyzing"):
+                                    history[-1] = {"role": "assistant", "content": f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"}
+                                else:
+                                    history[-1]["content"] = f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"
+                                yield history, None
+
+                    # Append completed chunk response to combined response
+                    combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
+
+                # Finalize UI with complete response
+                if combined_response:
+                    history[-1]["content"] = combined_response.strip()
+                else:
+                    history.append({"role": "assistant", "content": "No oversights identified."})
+
+                # Generate report file
+                report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
+                if report_path:
+                    with open(report_path, "w", encoding="utf-8") as f:
+                        f.write(combined_response)
+                yield history, report_path if report_path and os.path.exists(report_path) else None
+
+            except Exception as e:
+                print("🚨 ERROR:", e)
+                history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
+                yield history, None
 
         send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])
         msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])
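One detail worth noting in `analyze`: the prompt template mixes positional (`{0}`, `{1}`) and named (`{chunk}`) replacement fields, which a single `str.format` call accepts, and each chunk is truncated to 4,000 characters before substitution. In miniature:

```python
template = "Patient Record Excerpt (Chunk {0} of {1}):\n{chunk}"
chunk = "BP 150/95 recorded twice; no follow-up documented. " * 200
print(template.format(2, 5, chunk=chunk[:4000]))  # positional and keyword fields in one call
```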
 
 if __name__ == "__main__":
     print("🚀 Launching app...")
+    agent = init_agent()
+    demo = create_ui(agent)
+    demo.queue(api_open=False).launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        show_error=True,
+        allowed_paths=[report_dir],
+        share=False
+    )
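Because `analyze` is a generator, Gradio streams every yielded `(history, file)` pair to the chatbot and download components as it arrives, which is what makes the per-chunk partial responses render live. Outside Gradio the same flow can be driven by hand; a sketch, assuming `analyze` were factored out of the `create_ui` closure:

```python
# Drive the generator manually: each yield is one UI update.
for history, report in analyze("Flag potential oversights", [], files=[]):
    print(history[-1]["content"][:80])
    if report:
        print("report written to", report)
```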