Ali2206 committed
Commit 5f7a1a1 · verified · 1 Parent(s): 13560d3

Update app.py

Files changed (1)
  1. app.py +176 -248
app.py CHANGED
@@ -1,314 +1,242 @@
- 
- import json
  import gradio as gr
- from typing import List, Optional
- from concurrent.futures import ThreadPoolExecutor, as_completed
  import hashlib
- import shutil
  import time
- from functools import lru_cache
- from threading import Thread
  import re
- import tempfile
- 
- # Environment setup
- current_dir = os.path.dirname(os.path.abspath(__file__))
- src_path = os.path.abspath(os.path.join(current_dir, "src"))
- sys.path.insert(0, src_path)
- 
- # Cache directories
- base_dir = "/data"
- os.makedirs(base_dir, exist_ok=True)
- model_cache_dir = os.path.join(base_dir, "txagent_models")
- tool_cache_dir = os.path.join(base_dir, "tool_cache")
- file_cache_dir = os.path.join(base_dir, "cache")
- report_dir = "/data/reports"
- vllm_cache_dir = os.path.join(base_dir, "vllm_cache")
- 
- os.makedirs(model_cache_dir, exist_ok=True)
- os.makedirs(tool_cache_dir, exist_ok=True)
- os.makedirs(file_cache_dir, exist_ok=True)
- os.makedirs(report_dir, exist_ok=True)
- os.makedirs(vllm_cache_dir, exist_ok=True)

  os.environ.update({
-     "TRANSFORMERS_CACHE": model_cache_dir,
-     "HF_HOME": model_cache_dir,
-     "VLLM_CACHE_DIR": vllm_cache_dir,
      "TOKENIZERS_PARALLELISM": "false",
      "CUDA_LAUNCH_BLOCKING": "1"
  })

- from txagent.txagent import TxAgent
 
- MEDICAL_KEYWORDS = {
-     'diagnosis', 'assessment', 'plan', 'results', 'medications',
-     'allergies', 'summary', 'impression', 'findings', 'recommendations'
- }
 
- def sanitize_utf8(text: str) -> str:
-     return text.encode("utf-8", "ignore").decode("utf-8")
 
  def file_hash(path: str) -> str:
      with open(path, "rb") as f:
          return hashlib.md5(f.read()).hexdigest()
 
- def extract_priority_pages(file_path: str, max_pages: int = 20) -> str:
      try:
-         text_chunks = []
          with pdfplumber.open(file_path) as pdf:
-             for i, page in enumerate(pdf.pages[:3]):
-                 text_chunks.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
-             for i, page in enumerate(pdf.pages[3:max_pages], start=4):
-                 page_text = page.extract_text() or ""
-                 if any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
-                     text_chunks.append(f"=== Page {i} ===\n{page_text.strip()}")
-         return "\n\n".join(text_chunks)
      except Exception as e:
          return f"PDF processing error: {str(e)}"
 
- def convert_file_to_json(file_path: str, file_type: str) -> str:
      try:
          h = file_hash(file_path)
-         cache_path = os.path.join(file_cache_dir, f"{h}.json")
          if os.path.exists(cache_path):
-             return open(cache_path, "r", encoding="utf-8").read()
- 
          if file_type == "pdf":
-             text = extract_priority_pages(file_path)
-             result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
-             Thread(target=full_pdf_processing, args=(file_path, h)).start()
- 
          elif file_type == "csv":
-             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str, skip_blank_lines=False, on_bad_lines="skip")
-             content = df.fillna("").astype(str).values.tolist()
-             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
- 
          elif file_type in ["xls", "xlsx"]:
-             try:
-                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
-             except:
-                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
-             content = df.fillna("").astype(str).values.tolist()
-             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
- 
          else:
              return json.dumps({"error": f"Unsupported file type: {file_type}"})
 
          with open(cache_path, "w", encoding="utf-8") as f:
              f.write(result)
          return result
- 
-     except Exception as e:
-         return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
- 
- def full_pdf_processing(file_path: str, file_hash: str):
-     try:
-         cache_path = os.path.join(file_cache_dir, f"{file_hash}_full.json")
-         if os.path.exists(cache_path):
-             return
-         with pdfplumber.open(file_path) as pdf:
-             full_text = "\n".join([f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}" for i, page in enumerate(pdf.pages)])
-         result = json.dumps({"filename": os.path.basename(file_path), "content": full_text, "status": "complete"})
-         with open(cache_path, "w", encoding="utf-8") as f:
-             f.write(result)
-         with open(os.path.join(report_dir, f"{file_hash}_report.txt"), "w", encoding="utf-8") as out:
-             out.write(full_text)
      except Exception as e:
-         print(f"Background processing failed: {str(e)}")
- 
- def init_agent():
-     default_tool_path = os.path.abspath("data/new_tool.json")
-     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
-     if not os.path.exists(target_tool_path):
-         shutil.copy(default_tool_path, target_tool_path)
- 
-     agent = TxAgent(
-         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
-         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-         tool_files_dict={"new_tool": target_tool_path},
-         force_finish=True,
-         enable_checker=True,
-         step_rag_num=8,
-         seed=100,
-         additional_default_tools=[],
-     )
-     agent.init_model()
-     return agent
 
  def format_response(response: str) -> str:
-     """Clean and format the response for display"""
-     # Remove all tool call artifacts
      response = response.replace("[TOOL_CALLS]", "").strip()
- 
-     # Remove duplicate sections if they exist
      if "Based on the medical records provided" in response:
          parts = response.split("Based on the medical records provided")
-         if len(parts) > 1:
-             response = "Based on the medical records provided" + parts[-1]
 
-     # Format sections with Markdown
-     formatted = response.replace("1. **Missed Diagnoses**:", "### 🔍 Missed Diagnoses")
-     formatted = formatted.replace("2. **Medication Conflicts**:", "\n### 💊 Medication Conflicts")
-     formatted = formatted.replace("3. **Incomplete Assessments**:", "\n### 📋 Incomplete Assessments")
-     formatted = formatted.replace("4. **Abnormal Results Needing Follow-up**:", "\n### ⚠️ Abnormal Results Needing Follow-up")
-     formatted = formatted.replace("Overall, the patient's medical records", "\n### 📝 Overall Assessment")
 
-     return formatted
 
- def analyze_potential_oversights(message: str, history: list, conversation: list, files: list):
-     start_time = time.time()
      try:
-         # Initial loading message
-         history = history + [
-             {"role": "user", "content": message},
-             {"role": "assistant", "content": "⏳ Analyzing records for potential oversights..."}
-         ]
          yield history, None
- 
-         # Process uploaded files
          extracted_data = ""
-         file_hash_value = ""
-         if files and isinstance(files, list):
              with ThreadPoolExecutor(max_workers=4) as executor:
-                 futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower())
                             for f in files if hasattr(f, 'name')]
-                 extracted_data = "\n".join([sanitize_utf8(f.result()) for f in as_completed(futures)])
-                 file_hash_value = file_hash(files[0].name) if files else ""
- 
-         # Prepare the analysis prompt
-         analysis_prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
- 1. List potential missed diagnoses
- 2. Flag any medication conflicts
- 3. Note incomplete assessments
- 4. Highlight abnormal results needing follow-up
- 
- Medical Records:\n{extracted_data[:15000]}
 
- ### Potential Oversights:\n"""
 
-         # Process the response from the agent
-         full_response = ""
          for chunk in agent.run_gradio_chat(
-             message=analysis_prompt,
              history=[],
              temperature=0.2,
-             max_new_tokens=1024,
-             max_token=4096,
-             call_agent=False,
-             conversation=conversation
          ):
              if isinstance(chunk, str):
-                 full_response += chunk
              elif isinstance(chunk, list):
-                 full_response += "".join([c.content for c in chunk if hasattr(c, 'content')])
- 
-             # Format and display the partial response
-             formatted = format_response(full_response)
              if formatted.strip():
-                 history = history[:-1] + [{"role": "assistant", "content": formatted}]
                  yield history, None
- 
-         # Final formatting and cleanup
-         final_output = format_response(full_response)
-         if not final_output.strip():
-             final_output = "No clear oversights identified. Recommend comprehensive review."
- 
-         # Prepare report download if available
-         report_path = None
-         if file_hash_value:
-             possible_report = os.path.join(report_dir, f"{file_hash_value}_report.txt")
-             if os.path.exists(possible_report):
-                 report_path = possible_report
- 
-         # Update history with final response
-         history = history[:-1] + [{"role": "assistant", "content": final_output}]
-         yield history, report_path
- 
-     except Exception as e:
-         history.append({"role": "assistant", "content": f"❌ Analysis failed: {str(e)}"})
          yield history, None
- 
- def create_ui(agent: TxAgent):
-     with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 900px !important}") as demo:
-         gr.Markdown("""
-         <div style='text-align: center;'>
-             <h1>🩺 Clinical Oversight Assistant</h1>
-             <h3>Identify potential oversights in patient care</h3>
-             <p>Upload medical records to analyze for missed diagnoses, medication conflicts, and other potential issues.</p>
-         </div>
-         """)
- 
-         with gr.Row():
-             with gr.Column(scale=2):
-                 file_upload = gr.File(
-                     label="Upload Medical Records",
-                     file_types=[".pdf", ".csv", ".xls", ".xlsx"],
-                     file_count="multiple",
-                     height=100
-                 )
-                 msg_input = gr.Textbox(
-                     placeholder="Ask about potential oversights...",
-                     show_label=False,
-                     lines=3,
-                     max_lines=6
-                 )
-                 send_btn = gr.Button("Analyze", variant="primary", size="lg")
- 
-                 gr.Examples(
-                     examples=[
-                         ["What might have been missed in this patient's treatment?"],
-                         ["Are there any medication conflicts in these records?"],
-                         ["What abnormal results require follow-up?"],
-                         ["Identify any incomplete assessments in these records"]
-                     ],
-                     inputs=msg_input,
-                     label="Example Queries"
-                 )
- 
-             with gr.Column(scale=3):
-                 chatbot = gr.Chatbot(
-                     label="Analysis Results",
-                     height=600,
-                     show_copy_button=True,
-                     avatar_images=(
-                         "assets/user.png",
-                         "assets/doctor.png"
-                     )
-                 )
-                 download_output = gr.File(
-                     label="Download Full Report",
-                     visible=False
-                 )
- 
-         conversation_state = gr.State([])
- 
-         inputs = [msg_input, chatbot, conversation_state, file_upload]
-         outputs = [chatbot, download_output]
 
-         send_btn.click(
-             analyze_potential_oversights,
-             inputs=inputs,
-             outputs=outputs
-         )
-         msg_input.submit(
-             analyze_potential_oversights,
-             inputs=inputs,
-             outputs=outputs
-         )
 
-         return demo
 
  if __name__ == "__main__":
-     print("Initializing medical analysis agent...")
-     agent = init_agent()
- 
-     print("Launching interface...")
-     demo = create_ui(agent)
-     demo.queue().launch(
          server_name="0.0.0.0",
          server_port=7860,
-         show_error=True,
-         allowed_paths=["/data/reports"],
-         share=False
      )
 
+ import sys
+ import os
  import gradio as gr
+ from typing import List
  import hashlib
  import time
+ import json
  import re
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from threading import Thread
+ import pandas as pd
+ import pdfplumber
 
+ # Optimized environment setup
  os.environ.update({
+     "HF_HOME": "/data/hf_cache",
+     "VLLM_CACHE_DIR": "/data/vllm_cache",
      "TOKENIZERS_PARALLELISM": "false",
      "CUDA_LAUNCH_BLOCKING": "1"
  })
 
+ # Create cache directories if they don't exist
+ os.makedirs("/data/hf_cache", exist_ok=True)
+ os.makedirs("/data/tool_cache", exist_ok=True)
+ os.makedirs("/data/file_cache", exist_ok=True)
+ os.makedirs("/data/reports", exist_ok=True)
+ os.makedirs("/data/vllm_cache", exist_ok=True)
+ 
+ # Lazy loading of heavy dependencies
+ def lazy_load_agent():
+     from txagent.txagent import TxAgent
+ 
+     # Initialize agent with optimized settings
+     agent = TxAgent(
+         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
+         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
+         tool_files_dict={"new_tool": "/data/tool_cache/new_tool.json"},
+         force_finish=True,
+         enable_checker=True,
+         step_rag_num=8,
+         seed=100,
+         additional_default_tools=[],
+     )
+     agent.init_model()
+     return agent
 
+ # Pre-load the agent in a separate thread
+ agent = None
+ def preload_agent():
+     global agent
+     agent = lazy_load_agent()
 
+ Thread(target=preload_agent).start()
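+ # Requests that arrive before the model finishes loading will busy-wait in
+ # analyze_files() until `agent` is set.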
 
+ # File processing functions
  def file_hash(path: str) -> str:
      with open(path, "rb") as f:
          return hashlib.md5(f.read()).hexdigest()
 
+ def extract_priority_pages(file_path: str, max_pages: int = 10) -> str:
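+     # Reads at most the first `max_pages` pages; later pages are skipped.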
      try:
          with pdfplumber.open(file_path) as pdf:
+             return "\n\n".join(
+                 f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}"
+                 for i, page in enumerate(pdf.pages[:max_pages])
+             )
      except Exception as e:
          return f"PDF processing error: {str(e)}"
 
+ def process_file(file_path: str, file_type: str) -> str:
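+     # Converted output is cached on disk, keyed by the MD5 of the file
+     # bytes, so re-uploading an identical file skips re-parsing.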
      try:
          h = file_hash(file_path)
+         cache_path = f"/data/file_cache/{h}.json"
+ 
          if os.path.exists(cache_path):
+             with open(cache_path, "r", encoding="utf-8") as f:
+                 return f.read()
+ 
          if file_type == "pdf":
+             content = extract_priority_pages(file_path)
+             result = json.dumps({"filename": os.path.basename(file_path), "content": content})
          elif file_type == "csv":
+             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str)
+             result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").values.tolist()})
          elif file_type in ["xls", "xlsx"]:
+             df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
+             result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna("").values.tolist()})
          else:
              return json.dumps({"error": f"Unsupported file type: {file_type}"})
 
          with open(cache_path, "w", encoding="utf-8") as f:
              f.write(result)
          return result
+ 
      except Exception as e:
+         return json.dumps({"error": str(e)})
 
  def format_response(response: str) -> str:
      response = response.replace("[TOOL_CALLS]", "").strip()
      if "Based on the medical records provided" in response:
          parts = response.split("Based on the medical records provided")
+         response = "Based on the medical records provided" + parts[-1]
+ 
+     replacements = {
+         "1. **Missed Diagnoses**:": "### 🔍 Missed Diagnoses",
+         "2. **Medication Conflicts**:": "\n### 💊 Medication Conflicts",
+         "3. **Incomplete Assessments**:": "\n### 📋 Incomplete Assessments",
+         "4. **Abnormal Results Needing Follow-up**:": "\n### ⚠️ Abnormal Results Needing Follow-up",
+         "Overall, the patient's medical records": "\n### 📝 Overall Assessment"
+     }
+ 
+     for old, new in replacements.items():
+         response = response.replace(old, new)
+ 
+     return response
 
+ def analyze_files(message: str, history: List, files: List):
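+     # Generator: yields (chat_history, report_file) pairs so Gradio can
+     # stream partial results; the report slot is always None here.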
 
      try:
+         # Wait for agent to load if not ready
+         while agent is None:
+             time.sleep(0.1)
+ 
+         # Append user message to history in correct format
+         history.append([message, None])
          yield history, None
+ 
+         # Process files in parallel
          extracted_data = ""
+         if files:
              with ThreadPoolExecutor(max_workers=4) as executor:
+                 futures = [executor.submit(process_file, f.name, f.name.split(".")[-1].lower())
                             for f in files if hasattr(f, 'name')]
+                 extracted_data = "\n".join(f.result() for f in as_completed(futures))
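+ 
+         # Only the first 10,000 characters of the extracted text are sent,
+         # to bound the prompt size.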
+         prompt = f"""Review these medical records:
+ {extracted_data[:10000]}
 
+ Identify:
+ 1. Potential missed diagnoses
+ 2. Medication conflicts
+ 3. Incomplete assessments
+ 4. Abnormal results needing follow-up
 
+ Analysis:"""
+ 
+         response = ""
          for chunk in agent.run_gradio_chat(
+             message=prompt,
              history=[],
              temperature=0.2,
+             max_new_tokens=800,
+             max_token=3000
          ):
              if isinstance(chunk, str):
+                 response += chunk
              elif isinstance(chunk, list):
+                 response += "".join(getattr(c, 'content', '') for c in chunk)
+ 
+             formatted = format_response(response)
              if formatted.strip():
+                 history[-1][1] = formatted
                  yield history, None
+ 
+         final_output = format_response(response) or "No clear oversights identified."
+         history[-1][1] = final_output
          yield history, None
 
+     except Exception as e:
+         history[-1][1] = f"❌ Error: {str(e)}"
+         yield history, None
 
+ # Create optimized UI with better layout
+ with gr.Blocks(title="Clinical Oversight Assistant", css="""
+ .gradio-container {
+     max-width: 1200px !important;
+     margin: auto;
+ }
+ .container {
+     max-width: 1200px !important;
+ }
+ .chatbot {
+     min-height: 500px;
+ }
+ """) as demo:
+     gr.Markdown("""
+     <div style='text-align: center; margin-bottom: 20px;'>
+         <h1 style='margin-bottom: 10px;'>🩺 Clinical Oversight Assistant</h1>
+         <p>Upload medical records to analyze for potential oversights in patient care</p>
+     </div>
+     """)
+ 
+     with gr.Row():
+         with gr.Column(scale=1, min_width=400):
+             file_upload = gr.File(
+                 label="Upload Medical Records",
+                 file_types=[".pdf", ".csv", ".xls", ".xlsx"],
+                 file_count="multiple",
+                 height=100
+             )
+             query = gr.Textbox(
+                 label="Your Query",
+                 placeholder="Ask about potential oversights...",
+                 lines=3
+             )
+             submit = gr.Button("Analyze", variant="primary")
+ 
+             gr.Examples(
+                 examples=[
+                     ["What potential diagnoses might have been missed?"],
+                     ["Are there any medication conflicts I should be aware of?"],
+                     ["What assessments appear incomplete in these records?"]
+                 ],
+                 inputs=query,
+                 label="Example Queries"
+             )
+ 
+         with gr.Column(scale=2, min_width=600):
+             chatbot = gr.Chatbot(
+                 label="Analysis Results",
+                 height=600,
+                 bubble_full_width=False,
+                 show_copy_button=True
+             )
+ 
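+     # The hidden gr.File fills the second output slot that analyze_files
+     # yields (always None here), without showing a download box.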
+     submit.click(
+         analyze_files,
+         inputs=[query, chatbot, file_upload],
+         outputs=[chatbot, gr.File(visible=False)]
+     )
+ 
+     query.submit(
+         analyze_files,
+         inputs=[query, chatbot, file_upload],
+         outputs=[chatbot, gr.File(visible=False)]
+     )
 
  if __name__ == "__main__":
+     demo.queue(concurrency_count=1).launch(
          server_name="0.0.0.0",
          server_port=7860,
+         show_error=True
      )