Ali2206 committed
Commit 1da2cfd · verified · Parent: 3b1f183

Update app.py

Files changed (1): app.py (+164 -130)
app.py CHANGED
@@ -10,15 +10,17 @@ import hashlib
 import shutil
 import time
 from functools import lru_cache
+from threading import Thread
+import re
 
-# Environment and path setup
+# Environment setup
 current_dir = os.path.dirname(os.path.abspath(__file__))
 src_path = os.path.abspath(os.path.join(current_dir, "src"))
-print(f"Adding to path: {src_path}")
 sys.path.insert(0, src_path)
 
-# Configure cache directories
+# Cache directories
 base_dir = "/data"
+os.makedirs(base_dir, exist_ok=True)
 model_cache_dir = os.path.join(base_dir, "txagent_models")
 tool_cache_dir = os.path.join(base_dir, "tool_cache")
 file_cache_dir = os.path.join(base_dir, "cache")
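One ordering constraint worth noting: the environment block in the next hunk sits above `from txagent.txagent import TxAgent` because huggingface_hub (pulled in by transformers) reads variables like `HF_HOME` once at import time. A minimal reproduction of that constraint, assuming a recent huggingface_hub:

```python
# Sketch: set cache env vars before importing the libraries that read them.
import os
os.environ["HF_HOME"] = "/data/txagent_models"  # set first...

import huggingface_hub  # ...then import; constants are computed on import
print(huggingface_hub.constants.HF_HOME)  # -> /data/txagent_models
```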
@@ -27,13 +29,21 @@ os.makedirs(model_cache_dir, exist_ok=True)
 os.makedirs(tool_cache_dir, exist_ok=True)
 os.makedirs(file_cache_dir, exist_ok=True)
 
-os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
-os.environ["HF_HOME"] = model_cache_dir
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+os.environ.update({
+    "TRANSFORMERS_CACHE": model_cache_dir,
+    "HF_HOME": model_cache_dir,
+    "TOKENIZERS_PARALLELISM": "false",
+    "CUDA_LAUNCH_BLOCKING": "1"
+})
 
 from txagent.txagent import TxAgent
 
+# Medical keywords for priority detection
+MEDICAL_KEYWORDS = {
+    'diagnosis', 'assessment', 'plan', 'results', 'medications',
+    'allergies', 'summary', 'impression', 'findings', 'recommendations'
+}
+
 def sanitize_utf8(text: str) -> str:
     return text.encode("utf-8", "ignore").decode("utf-8")
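The `MEDICAL_KEYWORDS` set feeds the page-prioritization logic added in the next hunk; the `\b`-bounded regex there matches whole words only. A standalone sketch of that screen (sample strings invented for illustration):

```python
# Sketch of the whole-word keyword screen used by extract_priority_pages.
# MEDICAL_KEYWORDS mirrors the set added above; sample text is invented.
import re

MEDICAL_KEYWORDS = {
    'diagnosis', 'assessment', 'plan', 'results', 'medications',
    'allergies', 'summary', 'impression', 'findings', 'recommendations'
}

def page_is_relevant(page_text: str) -> bool:
    lowered = page_text.lower()
    # \b requires whole words: 'plan' matches, 'planned' does not
    return any(re.search(rf'\b{kw}\b', lowered) for kw in MEDICAL_KEYWORDS)

print(page_is_relevant("Assessment and Plan: start metformin"))  # True
print(page_is_relevant("Administrative cover sheet"))            # False
```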
 
@@ -41,193 +51,217 @@ def file_hash(path: str) -> str:
     with open(path, "rb") as f:
         return hashlib.md5(f.read()).hexdigest()
 
-@lru_cache(maxsize=100)
-def get_cached_response(prompt: str, file_hash: str) -> Optional[str]:
-    return None
+def extract_priority_pages(file_path: str, max_pages: int = 20) -> str:
+    """Fast extraction of first pages and medically relevant sections"""
+    try:
+        text_chunks = []
+        with pdfplumber.open(file_path) as pdf:
+            # Always process first 3 pages
+            for i, page in enumerate(pdf.pages[:3]):
+                text_chunks.append(f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}")
+
+            # Scan subsequent pages for medical keywords
+            for i, page in enumerate(pdf.pages[3:max_pages], start=4):
+                page_text = page.extract_text() or ""
+                if any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
+                    text_chunks.append(f"=== Page {i} ===\n{page_text.strip()}")
+
+        return "\n\n".join(text_chunks)
+    except Exception as e:
+        return f"PDF processing error: {str(e)}"
 
 def convert_file_to_json(file_path: str, file_type: str) -> str:
+    """Optimized file conversion with medical focus"""
     try:
         h = file_hash(file_path)
         cache_path = os.path.join(file_cache_dir, f"{h}.json")
 
         if os.path.exists(cache_path):
             return open(cache_path, "r", encoding="utf-8").read()
 
-        if file_type == "csv":
-            df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str, skip_blank_lines=False, on_bad_lines="skip")
+        if file_type == "pdf":
+            # Fast initial processing
+            text = extract_priority_pages(file_path)
+            result = json.dumps({
+                "filename": os.path.basename(file_path),
+                "content": text,
+                "status": "initial"
+            })
+
+            # Start background full processing
+            Thread(target=full_pdf_processing, args=(file_path, h)).start()
+
+        elif file_type == "csv":
+            df = pd.read_csv(file_path, encoding_errors="replace", header=None,
+                             dtype=str, skip_blank_lines=False, on_bad_lines="skip")
+            content = df.fillna("").astype(str).values.tolist()
+            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+
         elif file_type in ["xls", "xlsx"]:
             try:
                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
             except:
                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
-        elif file_type == "pdf":
-            with pdfplumber.open(file_path) as pdf:
-                text = "\n".join([page.extract_text() or "" for page in pdf.pages])
-            result = json.dumps({"filename": os.path.basename(file_path), "content": text.strip()})
-            with open(cache_path, "w", encoding="utf-8") as f:
-                f.write(result)
-            return result
+            content = df.fillna("").astype(str).values.tolist()
+            result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
+
         else:
             return json.dumps({"error": f"Unsupported file type: {file_type}"})
 
-        if df is None or df.empty:
-            return json.dumps({"warning": f"No data extracted from: {file_path}"})
-
-        df = df.fillna("")
-        content = df.astype(str).values.tolist()
-        result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
         with open(cache_path, "w", encoding="utf-8") as f:
             f.write(result)
         return result
+
     except Exception as e:
-        return json.dumps({"error": f"Error reading {os.path.basename(file_path)}: {str(e)}"})
-
-def convert_files_to_json_parallel(uploaded_files: list) -> str:
-    extracted_text = []
-    with ThreadPoolExecutor(max_workers=4) as executor:
-        futures = []
-        for file in uploaded_files:
-            if not hasattr(file, 'name'):
-                continue
-            path = file.name
-            ext = path.split(".")[-1].lower()
-            futures.append(executor.submit(convert_file_to_json, path, ext))
-
-        for future in as_completed(futures):
-            extracted_text.append(sanitize_utf8(future.result()))
-    return "\n".join(extracted_text)
+        return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
+
+def full_pdf_processing(file_path: str, file_hash: str):
+    """Background full PDF processing"""
+    try:
+        cache_path = os.path.join(file_cache_dir, f"{file_hash}_full.json")
+        if os.path.exists(cache_path):
+            return
+
+        with pdfplumber.open(file_path) as pdf:
+            full_text = "\n".join([f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}"
+                                   for i, page in enumerate(pdf.pages)])
+
+        result = json.dumps({
+            "filename": os.path.basename(file_path),
+            "content": full_text,
+            "status": "complete"
+        })
+
+        with open(cache_path, "w", encoding="utf-8") as f:
+            f.write(result)
+    except Exception as e:
+        print(f"Background processing failed: {str(e)}")
 
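Two cache tiers now exist per file: the fast `{hash}.json` written synchronously and the `{hash}_full.json` written by the background thread, which nothing in this commit reads back yet. Note also that `full_pdf_processing`'s `file_hash` parameter shadows the `file_hash()` helper defined above (harmless here, since the function never hashes, but worth renaming). A hypothetical consumer that prefers the complete tier (a sketch, not part of the commit):

```python
# Hypothetical lookup for the two-tier PDF cache: prefer the complete
# "_full" entry, fall back to the fast "initial" one. file_cache_dir and
# the file-naming scheme follow the diff; this helper is illustrative.
import json
import os

def read_cached_extraction(file_cache_dir: str, h: str):
    for suffix in ("_full", ""):  # complete tier first, then initial
        path = os.path.join(file_cache_dir, f"{h}{suffix}.json")
        if os.path.exists(path):
            with open(path, "r", encoding="utf-8") as fp:
                return json.load(fp)
    return None  # nothing cached yet
```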
 def init_agent():
+    """Initialize TxAgent with medical analysis focus"""
     default_tool_path = os.path.abspath("data/new_tool.json")
     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
+
     if not os.path.exists(target_tool_path):
         shutil.copy(default_tool_path, target_tool_path)
 
-    model_name = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
-    rag_model_name = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
-
     agent = TxAgent(
-        model_name=model_name,
-        rag_model_name=rag_model_name,
+        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
+        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
         tool_files_dict={"new_tool": target_tool_path},
         force_finish=True,
         enable_checker=True,
         step_rag_num=8,
         seed=100,
-        additional_default_tools=[]
+        additional_default_tools=[],
+        device_map="auto"
     )
     agent.init_model()
     return agent
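Whether TxAgent forwards the new `device_map="auto"` argument to its underlying model loader is not visible from this diff, so it is worth verifying against the txagent source. Separately, `init_agent()` loads both models each time it is called; if it were ever invoked more than once per process, a memo guard would avoid the reload (a sketch assuming the `init_agent` above; not part of the commit):

```python
# Sketch: cache the heavyweight agent so repeated calls reuse one instance.
from functools import lru_cache  # already imported at the top of app.py

@lru_cache(maxsize=1)
def get_agent():
    return init_agent()  # loads the LLM and RAG model only once
```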
 def create_ui(agent: TxAgent):
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
-        gr.Markdown("<h1 style='text-align: center;'>📋 CPS: Clinical Patient Support System</h1>")
+        gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
+        gr.Markdown("<h3 style='text-align: center;'>Identify potential oversights in patient care</h3>")
 
-        chatbot = gr.Chatbot(label="CPS Assistant", height=600, type="messages")
+        chatbot = gr.Chatbot(label="Analysis", height=600)
         file_upload = gr.File(
-            label="Upload Medical File",
-            file_types=[".pdf", ".txt", ".docx", ".jpg", ".png", ".csv", ".xls", ".xlsx"],
+            label="Upload Medical Records",
+            file_types=[".pdf", ".csv", ".xls", ".xlsx"],
             file_count="multiple"
         )
-        message_input = gr.Textbox(placeholder="Ask a biomedical question or just upload the files...", show_label=False)
-        send_button = gr.Button("Send", variant="primary")
+        msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
+        send_btn = gr.Button("Analyze", variant="primary")
         conversation_state = gr.State([])
 
-        def handle_chat(message: str, history: list, conversation: list, uploaded_files: list, progress=gr.Progress()):
+        def analyze_potential_oversights(message: str, history: list, conversation: list, files: list):
             start_time = time.time()
             try:
-                history.append({"role": "user", "content": message})
-                history.append({"role": "assistant", "content": "⏳ Processing your request..."})
+                history.append((message, "Analyzing records for potential oversights..."))
                 yield history
-
-                file_process_time = time.time()
-                extracted_text = ""
-                if uploaded_files and isinstance(uploaded_files, list):
-                    extracted_text = convert_files_to_json_parallel(uploaded_files)
-                    print(f"File processing took: {time.time() - file_process_time:.2f}s")
-
-                context = (
-                    "You are an expert clinical AI assistant. Review this patient's history, "
-                    "medications, and notes, and ONLY provide a final answer summarizing "
-                    "what the doctor might have missed."
-                )
-                chunked_prompt = f"{context}\n\n--- Patient Record ---\n{extracted_text}\n\n[Final Analysis]"
-
-                model_start = time.time()
-                generator = agent.run_gradio_chat(
-                    message=chunked_prompt,
+
+                # Process files
+                extracted_data = ""
+                if files:
+                    with ThreadPoolExecutor(max_workers=4) as executor:
+                        futures = [executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower())
+                                   for f in files if hasattr(f, 'name')]
+                        extracted_data = "\n".join([sanitize_utf8(f.result()) for f in as_completed(futures)])
+
+                # Medical oversight analysis prompt
+                analysis_prompt = """Review these medical records and identify EXACTLY what might have been missed:
+1. List potential missed diagnoses
+2. Flag any medication conflicts
+3. Note incomplete assessments
+4. Highlight abnormal results needing follow-up
+
+Medical Records:
+{records}
+
+Provide ONLY the potential oversights in this format:
+
+### Potential Oversights:
+1. [Missed diagnosis] - [Evidence from records]
+2. [Medication issue] - [Supporting data]
+3. [Assessment gap] - [Relevant findings]""".format(records=extracted_data[:15000])  # Limit input size
+
+                # Generate analysis
+                response = []
+                for chunk in agent.run_gradio_chat(
+                    message=analysis_prompt,
                     history=[],
-                    temperature=0.3,
-                    max_new_tokens=768,
+                    temperature=0.2,  # More deterministic
+                    max_new_tokens=1024,
                     max_token=4096,
                     call_agent=False,
-                    conversation=conversation,
-                    uploaded_files=uploaded_files,
-                    max_round=10
-                )
-
-                final_response = ""
-                for update in generator:
-                    if not update:
-                        continue
-                    if isinstance(update, list):
-                        for msg in update:
-                            if hasattr(msg, 'content'):
-                                final_response += msg.content
-                    elif isinstance(update, str):
-                        final_response += update
-
-                cleaned = final_response.strip().replace("[TOOL_CALLS]", "")
-                history[-1] = {"role": "assistant", "content": cleaned or "❌ No response."}
-                yield history
-
-                print("Final model response:\n", final_response)
-                history[-1] = {"role": "assistant", "content": final_response.strip() or "❌ No response."}
-                print(f"Model processing took: {time.time() - model_start:.2f}s")
+                    conversation=conversation
+                ):
+                    if isinstance(chunk, str):
+                        response.append(chunk)
+                    elif isinstance(chunk, list):
+                        response.extend([c.content for c in chunk if hasattr(c, 'content')])
+
+                    if len(response) % 3 == 0:  # Update every 3 chunks
+                        history[-1] = (message, "".join(response).strip())
+                        yield history
+
+                # Finalize output
+                final_output = "".join(response).strip()
+                if not final_output:
+                    final_output = "No clear oversights identified. Recommend comprehensive review."
+
+                # Format as bullet points if not already
+                if not final_output.startswith(("1.", "-", "*", "#")):
+                    final_output = "• " + final_output.replace("\n", "\n• ")
+
+                history[-1] = (message, f"### Potential Clinical Oversights:\n{final_output}")
+                print(f"Analysis completed in {time.time() - start_time:.2f}s")
                 yield history
 
-            except Exception as chat_error:
-                print(f"Chat handling error: {chat_error}")
-                history[-1] = {"role": "assistant", "content": "❌ An error occurred while processing your request."}
+            except Exception as e:
+                history.append((message, f"Analysis failed: {str(e)}"))
                 yield history
-            finally:
-                print(f"Total request time: {time.time() - start_time:.2f}s")
 
-        inputs = [message_input, chatbot, conversation_state, file_upload]
-        send_button.click(fn=handle_chat, inputs=inputs, outputs=chatbot)
-        message_input.submit(fn=handle_chat, inputs=inputs, outputs=chatbot)
+        # UI event handlers
+        inputs = [msg_input, chatbot, conversation_state, file_upload]
+        send_btn.click(analyze_potential_oversights, inputs=inputs, outputs=chatbot)
+        msg_input.submit(analyze_potential_oversights, inputs=inputs, outputs=chatbot)
 
         gr.Examples([
-            ["Upload your medical form and ask what the doctor might've missed."],
-            ["This patient was treated with antibiotics for UTI. What else should we check?"],
-            ["Is there anything abnormal in the attached blood work report?"]
-        ], inputs=message_input)
+            ["What might have been missed in this patient's treatment?"],
+            ["Are there any medication conflicts in these records?"],
+            ["What abnormal results require follow-up?"]
+        ], inputs=msg_input)
 
     return demo
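Dropping `type="messages"` from `gr.Chatbot` matches the handler's switch from role dicts to `(user, bot)` tuples, so the two stay consistent. (One quirk: `len(response) % 3 == 0` is also true while `response` is still empty, so the first iterations can repaint an empty bubble.) Stripped of the medical logic, the streaming shape of `analyze_potential_oversights` is mutate-the-last-tuple-and-yield; a self-contained sketch with invented chunk strings:

```python
# Sketch of the streaming pattern: keep one (user, bot) tuple per turn and
# yield the whole history so the Chatbot re-renders on every update.
def stream_reply(message: str, history: list):
    history.append((message, ""))  # placeholder assistant turn
    parts = []
    for token in ["Reviewing ", "records", "..."]:  # stands in for model chunks
        parts.append(token)
        history[-1] = (message, "".join(parts))
        yield history  # each yield repaints the chat

for snapshot in stream_reply("What was missed?", []):
    print(snapshot[-1][1])
```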
 if __name__ == "__main__":
-    print("Initializing agent...")
+    print("Initializing medical analysis agent...")
     agent = init_agent()
-
-    print("Performing warm-up call...")
-    try:
-        warm_up = agent.run_gradio_chat(
-            message="Warm up",
-            history=[],
-            temperature=0.1,
-            max_new_tokens=10,
-            max_token=100,
-            call_agent=False,
-            conversation=[]
-        )
-        for _ in warm_up:
-            pass
-    except:
-        pass
-
+
     print("Launching interface...")
     demo = create_ui(agent)
-    demo.queue().launch(
+    demo.queue(concurrency_count=2).launch(
         server_name="0.0.0.0",
         server_port=7860,
         show_error=True,
 
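`queue(concurrency_count=2)` is the Gradio 3.x spelling; Gradio 4 removed that argument in favor of `default_concurrency_limit`, so this line breaks if the Space's Gradio pin is ever bumped. A version-tolerant sketch, assuming `demo` from `create_ui` and only these two major versions:

```python
# Sketch: version-tolerant queue setup for the launch block above.
import gradio as gr

if gr.__version__.startswith("3."):
    demo.queue(concurrency_count=2)           # Gradio 3.x API
else:
    demo.queue(default_concurrency_limit=2)   # Gradio 4.x replacement
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
```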