Ali2206 committed
Commit 7323cb6 · verified · 1 Parent(s): 4b24a59

Update app.py

Files changed (1)
app.py +254 -88
app.py CHANGED
@@ -1,155 +1,321 @@
- # ✅ Fully optimized app.py for Hugging Face Space with persistent 150GB storage
-
  from concurrent.futures import ThreadPoolExecutor, as_completed
  from threading import Thread
 
- # Use /data for persistent HF storage
  base_dir = "/data"
  model_cache_dir = os.path.join(base_dir, "txagent_models")
  tool_cache_dir = os.path.join(base_dir, "tool_cache")
  file_cache_dir = os.path.join(base_dir, "cache")
- report_dir = os.path.join(base_dir, "reports")
  vllm_cache_dir = os.path.join(base_dir, "vllm_cache")
 
- for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
-     os.makedirs(d, exist_ok=True)
 
- # Set persistent HF + VLLM cache
  os.environ.update({
-     "HF_HOME": model_cache_dir,
      "TRANSFORMERS_CACHE": model_cache_dir,
      "VLLM_CACHE_DIR": vllm_cache_dir,
      "TOKENIZERS_PARALLELISM": "false",
      "CUDA_LAUNCH_BLOCKING": "1"
  })
 
- # Force local loading only
- LOCAL_TXAGENT_PATH = os.path.join(model_cache_dir, "mims-harvard", "TxAgent-T1-Llama-3.1-8B")
- LOCAL_RAG_PATH = os.path.join(model_cache_dir, "mims-harvard", "ToolRAG-T1-GTE-Qwen2-1.5B")
-
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
  from txagent.txagent import TxAgent
 
- def file_hash(path): return hashlib.md5(open(path, "rb").read()).hexdigest()
- def sanitize_utf8(text): return text.encode("utf-8", "ignore").decode("utf-8")
- MEDICAL_KEYWORDS = {"diagnosis", "assessment", "plan", "results", "medications", "summary", "findings"}
 
- def extract_priority_pages(file_path, max_pages=20):
      try:
          with pdfplumber.open(file_path) as pdf:
-             pages = []
              for i, page in enumerate(pdf.pages[:3]):
-                 pages.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
              for i, page in enumerate(pdf.pages[3:max_pages], start=4):
-                 text = page.extract_text() or ""
-                 if any(re.search(rf'\b{kw}\b', text.lower()) for kw in MEDICAL_KEYWORDS):
-                     pages.append(f"=== Page {i} ===\n{text.strip()}")
-         return "\n\n".join(pages)
      except Exception as e:
          return f"PDF processing error: {str(e)}"
 
- def convert_file_to_json(file_path, file_type):
      try:
          h = file_hash(file_path)
          cache_path = os.path.join(file_cache_dir, f"{h}.json")
-         if os.path.exists(cache_path): return open(cache_path, "r", encoding="utf-8").read()
 
          if file_type == "pdf":
              text = extract_priority_pages(file_path)
              result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
              Thread(target=full_pdf_processing, args=(file_path, h)).start()
          elif file_type == "csv":
-             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str)
-             result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna('').astype(str).values.tolist()})
          elif file_type in ["xls", "xlsx"]:
-             df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
-             result = json.dumps({"filename": os.path.basename(file_path), "rows": df.fillna('').astype(str).values.tolist()})
          else:
              return json.dumps({"error": f"Unsupported file type: {file_type}"})
 
-         with open(cache_path, "w", encoding="utf-8") as f: f.write(result)
          return result
      except Exception as e:
-         return json.dumps({"error": str(e)})
 
- def full_pdf_processing(file_path, h):
      try:
-         cache_path = os.path.join(file_cache_dir, f"{h}_full.json")
-         if os.path.exists(cache_path): return
          with pdfplumber.open(file_path) as pdf:
-             full_text = "\n".join([f"=== Page {i+1} ===\n{(p.extract_text() or '').strip()}" for i, p in enumerate(pdf.pages)])
-             with open(cache_path, "w", encoding="utf-8") as f: f.write(json.dumps({"content": full_text}))
-     except: pass
 
  def init_agent():
      target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
      if not os.path.exists(target_tool_path):
-         shutil.copy(os.path.abspath("data/new_tool.json"), target_tool_path)
 
      agent = TxAgent(
-         model_name=LOCAL_TXAGENT_PATH,
-         rag_model_name=LOCAL_RAG_PATH,
          tool_files_dict={"new_tool": target_tool_path},
          force_finish=True,
          enable_checker=True,
          step_rag_num=8,
-         seed=100
      )
      agent.init_model()
      return agent
 
- agent_container = {"agent": None}
- def get_agent():
-     if agent_container["agent"] is None:
-         agent_container["agent"] = init_agent()
-     return agent_container["agent"]
-
- def create_ui():
-     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-         gr.Markdown("""<h1 style='text-align:center;'>🩺 Clinical Oversight Assistant</h1>""")
-         chatbot = gr.Chatbot(label="Analysis", height=600)
-         msg_input = gr.Textbox(placeholder="Ask a question about the patient...")
-         file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
-         send_btn = gr.Button("Analyze", variant="primary")
-         state = gr.State([])
-
-         def analyze(message, history, conversation, files):
-             try:
-                 extracted, hval = "", ""
-                 if files:
-                     with ThreadPoolExecutor(max_workers=3) as pool:
-                         futures = [pool.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower()) for f in files]
-                         extracted = "\n".join([sanitize_utf8(f.result()) for f in as_completed(futures)])
-                     hval = file_hash(files[0].name)
-
-                 prompt = f"""Review these medical records and identify exactly what might have been missed:
- 1. Missed diagnoses
- 2. Medication conflicts
- 3. Incomplete assessments
- 4. Abnormal results needing follow-up
-
- Medical Records:\n{extracted[:15000]}
- """
-                 final_response = ""
-                 for chunk in get_agent().run_gradio_chat(prompt, history=[], temperature=0.2, max_new_tokens=1024, max_token=4096, call_agent=False, conversation=conversation):
-                     if isinstance(chunk, str): final_response += chunk
-                     elif isinstance(chunk, list): final_response += "".join([c.content for c in chunk if hasattr(c, 'content')])
-                 cleaned = final_response.replace("[TOOL_CALLS]", "").strip()
-                 updated_history = history + [[message, cleaned]]
-                 return updated_history, None
-             except Exception as e:
-                 return history + [[message, f"❌ Error: {str(e)}"]], None
-
-         send_btn.click(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, gr.File()])
-         msg_input.submit(analyze, inputs=[msg_input, chatbot, state, file_upload], outputs=[chatbot, gr.File()])
      return demo
 
  if __name__ == "__main__":
-     ui = create_ui()
-     ui.queue(api_open=False).launch(
          server_name="0.0.0.0",
          server_port=7860,
          show_error=True,
          allowed_paths=["/data/reports"],
          share=False
-     )
 
+ import sys
+ import os
+ import pandas as pd
+ import pdfplumber
+ import json
+ import gradio as gr
+ from typing import List, Optional
  from concurrent.futures import ThreadPoolExecutor, as_completed
+ import hashlib
+ import shutil
+ import time
+ from functools import lru_cache
  from threading import Thread
+ import re
+ import tempfile
+
+ # Environment setup
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ src_path = os.path.abspath(os.path.join(current_dir, "src"))
+ sys.path.insert(0, src_path)
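+ # The bundled "src" directory is prepended to sys.path so the local txagent
+ # package can be imported below without a pip install.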
 
+ # Cache directories
  base_dir = "/data"
+ os.makedirs(base_dir, exist_ok=True)
  model_cache_dir = os.path.join(base_dir, "txagent_models")
  tool_cache_dir = os.path.join(base_dir, "tool_cache")
  file_cache_dir = os.path.join(base_dir, "cache")
+ report_dir = "/data/reports"
  vllm_cache_dir = os.path.join(base_dir, "vllm_cache")
 
+ os.makedirs(model_cache_dir, exist_ok=True)
+ os.makedirs(tool_cache_dir, exist_ok=True)
+ os.makedirs(file_cache_dir, exist_ok=True)
+ os.makedirs(report_dir, exist_ok=True)
+ os.makedirs(vllm_cache_dir, exist_ok=True)
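+ # NOTE: /data survives restarts only on Spaces with persistent storage enabled;
+ # without it, these caches are rebuilt from scratch on every restart.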
 
  os.environ.update({
      "TRANSFORMERS_CACHE": model_cache_dir,
+     "HF_HOME": model_cache_dir,
      "VLLM_CACHE_DIR": vllm_cache_dir,
      "TOKENIZERS_PARALLELISM": "false",
      "CUDA_LAUNCH_BLOCKING": "1"
  })
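+ # NOTE: recent transformers releases treat TRANSFORMERS_CACHE as deprecated in
+ # favor of HF_HOME; setting both keeps older and newer versions pointed at /data.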
 
  from txagent.txagent import TxAgent
 
+ MEDICAL_KEYWORDS = {
+     'diagnosis', 'assessment', 'plan', 'results', 'medications',
+     'allergies', 'summary', 'impression', 'findings', 'recommendations'
+ }
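+ # Pages beyond the first three are kept only if their text mentions one of
+ # these section terms (see extract_priority_pages below).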
 
+ def sanitize_utf8(text: str) -> str:
+     return text.encode("utf-8", "ignore").decode("utf-8")
+
+ def file_hash(path: str) -> str:
+     with open(path, "rb") as f:
+         return hashlib.md5(f.read()).hexdigest()
+
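+ # MD5 here is only a cache key for deduplicating uploads, not a security measure.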
+ def extract_priority_pages(file_path: str, max_pages: int = 20) -> str:
      try:
+         text_chunks = []
          with pdfplumber.open(file_path) as pdf:
              for i, page in enumerate(pdf.pages[:3]):
+                 text_chunks.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
              for i, page in enumerate(pdf.pages[3:max_pages], start=4):
+                 page_text = page.extract_text() or ""
+                 if any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
+                     text_chunks.append(f"=== Page {i} ===\n{page_text.strip()}")
+         return "\n\n".join(text_chunks)
      except Exception as e:
          return f"PDF processing error: {str(e)}"
 
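+ # Example (hypothetical file): extract_priority_pages("/tmp/visit_notes.pdf")
+ # returns "=== Page N ===" blocks for pages 1-3 unconditionally, plus any page
+ # up to max_pages whose text matches a MEDICAL_KEYWORDS entry.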
+ def convert_file_to_json(file_path: str, file_type: str) -> str:
      try:
          h = file_hash(file_path)
          cache_path = os.path.join(file_cache_dir, f"{h}.json")
+         if os.path.exists(cache_path):
+             with open(cache_path, "r", encoding="utf-8") as f:
+                 return f.read()
 
          if file_type == "pdf":
              text = extract_priority_pages(file_path)
              result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
              Thread(target=full_pdf_processing, args=(file_path, h)).start()
+
          elif file_type == "csv":
+             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str, skip_blank_lines=False, on_bad_lines="skip")
+             content = df.fillna("").astype(str).values.tolist()
+             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+
          elif file_type in ["xls", "xlsx"]:
+             try:
+                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
+             except Exception:
+                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
+             content = df.fillna("").astype(str).values.tolist()
+             result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+
          else:
              return json.dumps({"error": f"Unsupported file type: {file_type}"})
 
+         with open(cache_path, "w", encoding="utf-8") as f:
+             f.write(result)
          return result
+
      except Exception as e:
+         return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
 
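+ # Design note: the first call returns a fast keyword-page extraction and starts
+ # full_pdf_processing in a background thread; later calls for the same file are
+ # served from the JSON cache keyed by the MD5 hash.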
+ def full_pdf_processing(file_path: str, pdf_hash: str):
      try:
+         cache_path = os.path.join(file_cache_dir, f"{pdf_hash}_full.json")
+         if os.path.exists(cache_path):
+             return
          with pdfplumber.open(file_path) as pdf:
+             full_text = "\n".join([f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}" for i, page in enumerate(pdf.pages)])
+         result = json.dumps({"filename": os.path.basename(file_path), "content": full_text, "status": "complete"})
+         with open(cache_path, "w", encoding="utf-8") as f:
+             f.write(result)
+         with open(os.path.join(report_dir, f"{pdf_hash}_report.txt"), "w", encoding="utf-8") as out:
+             out.write(full_text)
+     except Exception as e:
+         print(f"Background processing failed: {str(e)}")
 
  def init_agent():
+     default_tool_path = os.path.abspath("data/new_tool.json")
      target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
      if not os.path.exists(target_tool_path):
+         shutil.copy(default_tool_path, target_tool_path)
 
      agent = TxAgent(
+         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
+         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
          tool_files_dict={"new_tool": target_tool_path},
          force_finish=True,
          enable_checker=True,
          step_rag_num=8,
+         seed=100,
+         additional_default_tools=[],
      )
      agent.init_model()
      return agent
 
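+ # The hub model IDs above resolve through HF_HOME (/data/txagent_models), so
+ # the weights are downloaded once and reused from the persistent volume.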
+ def format_response(response: str) -> str:
+     """Clean and format the response for display"""
+     # Remove all tool call artifacts
+     response = response.replace("[TOOL_CALLS]", "").strip()
+
+     # Remove duplicate sections if they exist
+     if "Based on the medical records provided" in response:
+         parts = response.split("Based on the medical records provided")
+         if len(parts) > 1:
+             response = "Based on the medical records provided" + parts[-1]
+
+     # Format sections with Markdown
+     formatted = response.replace("1. **Missed Diagnoses**:", "### 🔍 Missed Diagnoses")
+     formatted = formatted.replace("2. **Medication Conflicts**:", "\n### 💊 Medication Conflicts")
+     formatted = formatted.replace("3. **Incomplete Assessments**:", "\n### 📋 Incomplete Assessments")
+     formatted = formatted.replace("4. **Abnormal Results Needing Follow-up**:", "\n### ⚠️ Abnormal Results Needing Follow-up")
+     formatted = formatted.replace("Overall, the patient's medical records", "\n### 📝 Overall Assessment")
+
+     return formatted
+
+ def analyze_potential_oversights(message: str, history: list, conversation: list, files: list):
+     start_time = time.time()
+     try:
+         # Initial loading message
+         history = history + [
+             {"role": "user", "content": message},
+             {"role": "assistant", "content": "⏳ Analyzing records for potential oversights..."}
+         ]
+         yield history, None
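+         # Yielding makes this function a generator: Gradio streams each yield
+         # to the UI, so partial results appear while the agent is still running.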
+
+         # Process uploaded files
+         extracted_data = ""
+         file_hash_value = ""
+         if files and isinstance(files, list):
+             with ThreadPoolExecutor(max_workers=4) as executor:
+                 futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower())
+                            for f in files if hasattr(f, 'name')]
+                 extracted_data = "\n".join([sanitize_utf8(f.result()) for f in as_completed(futures)])
+             file_hash_value = file_hash(files[0].name) if files else ""
+
+         # Prepare the analysis prompt
+         analysis_prompt = f"""Review these medical records and identify EXACTLY what might have been missed:
+ 1. List potential missed diagnoses
+ 2. Flag any medication conflicts
+ 3. Note incomplete assessments
+ 4. Highlight abnormal results needing follow-up
+
+ Medical Records:\n{extracted_data[:15000]}
+
+ ### Potential Oversights:\n"""
+
+         # Process the response from the agent
+         full_response = ""
+         for chunk in agent.run_gradio_chat(
+             message=analysis_prompt,
+             history=[],
+             temperature=0.2,
+             max_new_tokens=1024,
+             max_token=4096,
+             call_agent=False,
+             conversation=conversation
+         ):
+             if isinstance(chunk, str):
+                 full_response += chunk
+             elif isinstance(chunk, list):
+                 full_response += "".join([c.content for c in chunk if hasattr(c, 'content')])
+
+             # Format and display the partial response
+             formatted = format_response(full_response)
+             if formatted.strip():
+                 history = history[:-1] + [{"role": "assistant", "content": formatted}]
+                 yield history, None
+
+         # Final formatting and cleanup
+         final_output = format_response(full_response)
+         if not final_output.strip():
+             final_output = "No clear oversights identified. Recommend comprehensive review."
+
+         # Prepare report download if available
+         report_path = None
+         if file_hash_value:
+             possible_report = os.path.join(report_dir, f"{file_hash_value}_report.txt")
+             if os.path.exists(possible_report):
+                 report_path = possible_report
+
+         # Update history with final response
+         history = history[:-1] + [{"role": "assistant", "content": final_output}]
+         yield history, report_path
+
+     except Exception as e:
+         history.append({"role": "assistant", "content": f"❌ Analysis failed: {str(e)}"})
+         yield history, None
+
+ def create_ui(agent: TxAgent):
+     with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 900px !important}") as demo:
+         gr.Markdown("""
+         <div style='text-align: center;'>
+             <h1>🩺 Clinical Oversight Assistant</h1>
+             <h3>Identify potential oversights in patient care</h3>
+             <p>Upload medical records to analyze for missed diagnoses, medication conflicts, and other potential issues.</p>
+         </div>
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=2):
+                 file_upload = gr.File(
+                     label="Upload Medical Records",
+                     file_types=[".pdf", ".csv", ".xls", ".xlsx"],
+                     file_count="multiple",
+                     height=100
+                 )
+                 msg_input = gr.Textbox(
+                     placeholder="Ask about potential oversights...",
+                     show_label=False,
+                     lines=3,
+                     max_lines=6
+                 )
+                 send_btn = gr.Button("Analyze", variant="primary", size="lg")
+
+                 gr.Examples(
+                     examples=[
+                         ["What might have been missed in this patient's treatment?"],
+                         ["Are there any medication conflicts in these records?"],
+                         ["What abnormal results require follow-up?"],
+                         ["Identify any incomplete assessments in these records"]
+                     ],
+                     inputs=msg_input,
+                     label="Example Queries"
+                 )
+
+             with gr.Column(scale=3):
+                 chatbot = gr.Chatbot(
+                     label="Analysis Results",
+                     height=600,
+                     bubble_full_width=False,
+                     show_copy_button=True,
+                     avatar_images=(
+                         "assets/user.png",
+                         "assets/doctor.png"
+                     )
+                 )
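+                 # NOTE: the dict-style {"role": ..., "content": ...} messages
+                 # yielded above assume a Gradio version whose Chatbot accepts
+                 # the "messages" format (type="messages" in Gradio 4+); older
+                 # releases expect [user, assistant] pairs instead.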
+                 download_output = gr.File(
+                     label="Download Full Report",
+                     visible=False
+                 )
+
+         conversation_state = gr.State([])
+
+         inputs = [msg_input, chatbot, conversation_state, file_upload]
+         outputs = [chatbot, download_output]
+
+         send_btn.click(
+             analyze_potential_oversights,
+             inputs=inputs,
+             outputs=outputs
+         )
+         msg_input.submit(
+             analyze_potential_oversights,
+             inputs=inputs,
+             outputs=outputs
+         )
+
      return demo
 
  if __name__ == "__main__":
+     print("Initializing medical analysis agent...")
+     agent = init_agent()
+
+     print("Launching interface...")
+     demo = create_ui(agent)
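+     # NOTE: concurrency_count is the Gradio 3.x queue() argument; Gradio 4
+     # renamed it to default_concurrency_limit, so match the pinned version.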
+     demo.queue(
+         concurrency_count=3,
+         api_open=False
+     ).launch(
          server_name="0.0.0.0",
          server_port=7860,
          show_error=True,
          allowed_paths=["/data/reports"],
          share=False
+     )