Ali2206 committed
Commit abc4511 · verified · 1 Parent(s): 6b734c9

Update app.py

Files changed (1)
  1. app.py +223 -155
app.py CHANGED
@@ -1,192 +1,260 @@
  import sys
  import os
  import gradio as gr
  import hashlib
  import time
- import json
- from concurrent.futures import ThreadPoolExecutor, as_completed
- import pandas as pd
- import pdfplumber

- # Set up environment
  os.environ.update({
-     "HF_HOME": "/data/hf_cache",
-     "TOKENIZERS_PARALLELISM": "false"
  })

- # Create cache directories
- os.makedirs("/data/hf_cache", exist_ok=True)
- os.makedirs("/data/file_cache", exist_ok=True)
- os.makedirs("/data/reports", exist_ok=True)
-
- # Import TxAgent after setting up environment
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
  from txagent.txagent import TxAgent

- # Initialize agent with error handling
- try:
-     agent = TxAgent(
-         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
-         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-         tool_files_dict={"new_tool": "/data/tool_cache/new_tool.json"},
-         force_finish=True,
-         enable_checker=True,
-         step_rag_num=8,
-         seed=100
-     )
-     agent.init_model()
- except Exception as e:
-     print(f"Failed to initialize agent: {str(e)}")
-     agent = None

  def file_hash(path: str) -> str:
      with open(path, "rb") as f:
          return hashlib.md5(f.read()).hexdigest()

- def extract_text_from_pdf(file_path: str, max_pages: int = 10) -> str:
      try:
          with pdfplumber.open(file_path) as pdf:
-             return "\n".join(
-                 f"Page {i+1}:\n{(page.extract_text() or '').strip()}\n"
-                 for i, page in enumerate(pdf.pages[:max_pages])
-             )
      except Exception as e:
-         return f"PDF error: {str(e)}"

- def process_file(file_path: str, file_type: str) -> str:
      try:
-         cache_path = f"/data/file_cache/{file_hash(file_path)}.json"
          if os.path.exists(cache_path):
-             with open(cache_path, "r") as f:
-                 return f.read()
-
          if file_type == "pdf":
-             content = extract_text_from_pdf(file_path)
          elif file_type == "csv":
-             df = pd.read_csv(file_path, header=None, dtype=str, on_bad_lines="skip")
-             content = df.fillna("").to_string()
          elif file_type in ["xls", "xlsx"]:
-             df = pd.read_excel(file_path, header=None, dtype=str)
-             content = df.fillna("").to_string()
          else:
-             return json.dumps({"error": "Unsupported file type"})

-         result = json.dumps({"filename": os.path.basename(file_path), "content": content})
-         with open(cache_path, "w") as f:
              f.write(result)
          return result
      except Exception as e:
-         return json.dumps({"error": str(e)})
-
- def format_response(response: str) -> str:
-     response = response.replace("[TOOL_CALLS]", "").strip()
-     sections = {
-         "1. **Missed Diagnoses**:": "🔍 Missed Diagnoses",
-         "2. **Medication Conflicts**:": "💊 Medication Conflicts",
-         "3. **Incomplete Assessments**:": "📋 Incomplete Assessments",
-         "4. **Abnormal Results Needing Follow-up**:": "⚠️ Abnormal Results"
-     }
-     for old, new in sections.items():
-         response = response.replace(old, f"\n### {new}\n")
-     return response
-
- def analyze(message: str, history: list, files: list):
-     if agent is None:
-         yield history + [(message, "Agent initialization failed. Please try again later.")], None
-         return
-
-     history.append((message, None))
-     yield history, None
-
      try:
-         extracted_data = ""
-         if files:
-             with ThreadPoolExecutor() as executor:
-                 futures = [executor.submit(process_file, f.name, f.name.split(".")[-1])
-                            for f in files if hasattr(f, 'name')]
-                 extracted_data = "\n".join(f.result() for f in as_completed(futures))
-
-         prompt = f"""Review these medical records:
- {extracted_data[:10000]}
-
- Identify potential issues:
- 1. Missed diagnoses
- 2. Medication conflicts
- 3. Incomplete assessments
- 4. Abnormal results needing follow-up
-
- Analysis:"""
-
-         response = ""
-         for chunk in agent.run_gradio_chat(
-             message=prompt,
-             history=[],
-             temperature=0.2,
-             max_new_tokens=800
-         ):
-             if isinstance(chunk, str):
-                 response += chunk
-             elif isinstance(chunk, list):
-                 response += "".join(getattr(c, 'content', '') for c in chunk)
-
-             history[-1] = (message, format_response(response))
-             yield history, None
-
-         history[-1] = (message, format_response(response))
-         yield history, None
-
      except Exception as e:
-         history[-1] = (message, f" Error: {str(e)}")
-         yield history, None
-
- # Create the interface
- with gr.Blocks(
-     title="Clinical Oversight Assistant",
-     css="""
-     .gradio-container {
-         max-width: 1000px;
-         margin: auto;
-     }
-     .chatbot {
-         min-height: 500px;
-     }
-     """
- ) as demo:
-     gr.Markdown("# 🩺 Clinical Oversight Assistant")
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             files = gr.File(
-                 label="Upload Medical Records",
-                 file_types=[".pdf", ".csv", ".xlsx"],
-                 file_count="multiple"
-             )
-             query = gr.Textbox(
-                 label="Your Query",
-                 placeholder="Ask about potential oversights..."
-             )
-             submit = gr.Button("Analyze", variant="primary")
-
-         with gr.Column(scale=2):
-             chatbot = gr.Chatbot(
-                 label="Analysis Results",
-                 show_copy_button=True
-             )
-
-     submit.click(
-         analyze,
-         inputs=[query, chatbot, files],
-         outputs=[chatbot, gr.File(visible=False)]
-     )
-     query.submit(
-         analyze,
-         inputs=[query, chatbot, files],
-         outputs=[chatbot, gr.File(visible=False)]
      )

  if __name__ == "__main__":
-     demo.launch(
          server_name="0.0.0.0",
          server_port=7860,
-         show_error=True
-     )

  import sys
  import os
+ import pandas as pd
+ import pdfplumber
+ import json
  import gradio as gr
+ from typing import List, Optional
+ from concurrent.futures import ThreadPoolExecutor, as_completed
  import hashlib
+ import shutil
  import time
+ from functools import lru_cache
+ from threading import Thread
+ import re
+ import tempfile
+
+ # Environment setup
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ src_path = os.path.abspath(os.path.join(current_dir, "src"))
+ sys.path.insert(0, src_path)
+
+ # Cache directories
+ base_dir = "/data"
+ os.makedirs(base_dir, exist_ok=True)
+ model_cache_dir = os.path.join(base_dir, "txagent_models")
+ tool_cache_dir = os.path.join(base_dir, "tool_cache")
+ file_cache_dir = os.path.join(base_dir, "cache")
+ report_dir = "/data/reports"
+ vllm_cache_dir = os.path.join(base_dir, "vllm_cache")
+
+ os.makedirs(model_cache_dir, exist_ok=True)
+ os.makedirs(tool_cache_dir, exist_ok=True)
+ os.makedirs(file_cache_dir, exist_ok=True)
+ os.makedirs(report_dir, exist_ok=True)
+ os.makedirs(vllm_cache_dir, exist_ok=True)

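+ # The cache layout above keeps models, tool files, and parsed uploads under
+ # /data; the environment variables below point HF and vLLM at those locations.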
  os.environ.update({
+     "TRANSFORMERS_CACHE": model_cache_dir,
+     "HF_HOME": model_cache_dir,
+     "VLLM_CACHE_DIR": vllm_cache_dir,
+     "TOKENIZERS_PARALLELISM": "false",
+     "CUDA_LAUNCH_BLOCKING": "1"
  })

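+ # TOKENIZERS_PARALLELISM=false silences tokenizer fork warnings;
+ # CUDA_LAUNCH_BLOCKING=1 makes CUDA errors surface at the offending call.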
  from txagent.txagent import TxAgent

+ MEDICAL_KEYWORDS = {
+     'diagnosis', 'assessment', 'plan', 'results', 'medications',
+     'allergies', 'summary', 'impression', 'findings', 'recommendations'
+ }
+
+ def sanitize_utf8(text: str) -> str:
+     return text.encode("utf-8", "ignore").decode("utf-8")

  def file_hash(path: str) -> str:
      with open(path, "rb") as f:
          return hashlib.md5(f.read()).hexdigest()

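+ # The first three pages are always extracted; pages 4..max_pages are kept only
+ # if they mention a MEDICAL_KEYWORDS term, keeping prompts small but relevant.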
+ def extract_priority_pages(file_path: str, max_pages: int = 20) -> str:
      try:
+         text_chunks = []
          with pdfplumber.open(file_path) as pdf:
+             for i, page in enumerate(pdf.pages[:3]):
+                 text_chunks.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
+             for i, page in enumerate(pdf.pages[3:max_pages], start=4):
+                 page_text = page.extract_text() or ""
+                 if any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
+                     text_chunks.append(f"=== Page {i} ===\n{page_text.strip()}")
+         return "\n\n".join(text_chunks)
      except Exception as e:
+         return f"PDF processing error: {str(e)}"

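+ # Converts an upload to a JSON string, cached by MD5 content hash. PDFs get a
+ # fast priority-page extract now; a background Thread fills in the full text.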
+ def convert_file_to_json(file_path: str, file_type: str) -> str:
      try:
+         h = file_hash(file_path)
+         cache_path = os.path.join(file_cache_dir, f"{h}.json")
          if os.path.exists(cache_path):
+             return open(cache_path, "r", encoding="utf-8").read()
+
          if file_type == "pdf":
+             text = extract_priority_pages(file_path)
+             result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
+             Thread(target=full_pdf_processing, args=(file_path, h)).start()
+
          elif file_type == "csv":
+             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str, skip_blank_lines=False, on_bad_lines="skip")
+             content = df.fillna("").astype(str).values.tolist()
+             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+
          elif file_type in ["xls", "xlsx"]:
+             try:
+                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
+             except Exception:
+                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
+             content = df.fillna("").astype(str).values.tolist()
+             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+
          else:
+             return json.dumps({"error": f"Unsupported file type: {file_type}"})

+         with open(cache_path, "w", encoding="utf-8") as f:
              f.write(result)
          return result
+
      except Exception as e:
+         return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
+
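+ # Background pass over the whole PDF: caches the complete text as JSON and
+ # writes a plain-text report that the UI offers for download.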
+ def full_pdf_processing(file_path: str, file_hash_value: str):
      try:
+         cache_path = os.path.join(file_cache_dir, f"{file_hash_value}_full.json")
+         if os.path.exists(cache_path):
+             return
+         with pdfplumber.open(file_path) as pdf:
+             full_text = "\n".join([f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}" for i, page in enumerate(pdf.pages)])
+         result = json.dumps({"filename": os.path.basename(file_path), "content": full_text, "status": "complete"})
+         with open(cache_path, "w", encoding="utf-8") as f:
+             f.write(result)
+         with open(os.path.join(report_dir, f"{file_hash_value}_report.txt"), "w", encoding="utf-8") as out:
+             out.write(full_text)
      except Exception as e:
+         print(f"Background processing failed: {str(e)}")
+
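+ # Copies the bundled tool manifest into the persistent tool cache on first
+ # run, then builds the TxAgent with its RAG model and loads the weights.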
+ def init_agent():
+     default_tool_path = os.path.abspath("data/new_tool.json")
+     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
+     if not os.path.exists(target_tool_path):
+         shutil.copy(default_tool_path, target_tool_path)
+
+     agent = TxAgent(
+         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
+         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
+         tool_files_dict={"new_tool": target_tool_path},
+         force_finish=True,
+         enable_checker=True,
+         step_rag_num=8,
+         seed=100,
+         additional_default_tools=[],
      )
+     agent.init_model()
+     return agent
+
+ def create_ui(agent: TxAgent):
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown("""
+         <h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>
+         <h3 style='text-align: center;'>Identify potential oversights in patient care</h3>
+         """)
+
+         chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
+         file_upload = gr.File(label="Upload Medical Records", file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
+         msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
+         send_btn = gr.Button("Analyze", variant="primary")
+         conversation_state = gr.State([])
+         download_output = gr.File(label="Download Full Report")
+
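+         # Generator callback: yields (chat messages, report file) pairs so the
+         # UI can stream partial analysis while the agent is still generating.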
+         def analyze_potential_oversights(message: str, history: list, conversation: list, files: list):
+             start_time = time.time()
+             try:
+                 # Add initial user and temporary assistant messages to update UI immediately
+                 history = history + [
+                     {"role": "user", "content": message},
+                     {"role": "assistant", "content": "⏳ Analyzing records for potential oversights..."}
+                 ]
+                 yield history, None
+
+                 extracted_data = ""
+                 file_hash_value = ""
+                 if files and isinstance(files, list):
+                     with ThreadPoolExecutor(max_workers=4) as executor:
+                         futures = [
+                             executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower())
+                             for f in files if hasattr(f, 'name')
+                         ]
+                         extracted_data = "\n".join([sanitize_utf8(f.result()) for f in as_completed(futures)])
+                     file_hash_value = file_hash(files[0].name) if hasattr(files[0], 'name') else ""
+
+                 # Truncate extracted data to reduce overall token count (tune the character limit as needed)
+                 max_extracted_chars = 12000
+                 truncated_data = extracted_data[:max_extracted_chars]
+
+                 analysis_prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
+ 1. List potential missed diagnoses
+ 2. Flag any medication conflicts
+ 3. Note incomplete assessments
+ 4. Highlight abnormal results needing follow-up
+
+ Medical Records:
+ {truncated_data}
+
+ ### Potential Oversights:
+ """
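+                 # Chunks arrive either as plain strings (partial text) or as
+                 # lists of message objects whose .content is concatenated.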
+                 response = ""
+                 try:
+                     # Stream the agent responses; skip any None chunks
+                     for chunk in agent.run_gradio_chat(
+                         message=analysis_prompt,
+                         history=[],
+                         temperature=0.2,
+                         max_new_tokens=1024,
+                         max_token=4096,
+                         call_agent=False,
+                         conversation=conversation
+                     ):
+                         if chunk is None:
+                             continue
+                         if isinstance(chunk, str):
+                             response += chunk
+                         elif isinstance(chunk, list):
+                             response += "".join([c.content for c in chunk if hasattr(c, 'content')])
+                         # Yield partial response updates
+                         cleaned = response.replace("[TOOL_CALLS]", "").strip()
+                         yield history[:-1] + [{"role": "assistant", "content": cleaned}], None
+                 except Exception as agent_error:
+                     history.append({"role": "assistant", "content": f"❌ Analysis failed during processing: {str(agent_error)}"})
+                     yield history, None
+                     return
+
+                 final_output = response.replace("[TOOL_CALLS]", "").strip()
+                 if not final_output:
+                     final_output = "No clear oversights identified. Recommend comprehensive review."
+
+                 report_path = None
+                 if file_hash_value:
+                     possible_report = os.path.join(report_dir, f"{file_hash_value}_report.txt")
+                     if os.path.exists(possible_report):
+                         report_path = possible_report
+
+                 history = history[:-1] + [{"role": "assistant", "content": final_output}]
+                 yield history, report_path
+
+             except Exception as e:
+                 history.append({"role": "assistant", "content": f"❌ Analysis failed: {str(e)}"})
+                 yield history, None
+
+         inputs = [msg_input, chatbot, conversation_state, file_upload]
+         outputs = [chatbot, download_output]
+         send_btn.click(analyze_potential_oversights, inputs=inputs, outputs=outputs)
+         msg_input.submit(analyze_potential_oversights, inputs=inputs, outputs=outputs)
+
+         gr.Examples([
+             ["What might have been missed in this patient's treatment?"],
+             ["Are there any medication conflicts in these records?"],
+             ["What abnormal results require follow-up?"]
+         ], inputs=msg_input)
+
+     return demo

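+ # Agent initialization happens before the UI launches so startup failures
+ # appear in the logs rather than on the first request.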
  if __name__ == "__main__":
+     print("Initializing medical analysis agent...")
+     agent = init_agent()
+
+     print("Launching interface...")
+     demo = create_ui(agent)
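+     # api_open=False keeps the queue endpoints private; allowed_paths lets
+     # Gradio serve the generated reports under /data/reports for download.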
+     demo.queue(api_open=False).launch(
          server_name="0.0.0.0",
          server_port=7860,
+         show_error=True,
+         allowed_paths=["/data/reports"],
+         share=False
+     )