Ali2206 committed · Commit 9ec5ec4 · verified · 1 Parent(s): 9d7bfdd

Update app.py
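Fixes the broken import block (sys, os, pandas, pdfplumber, json, gradio, typing, and the concurrent.futures names were missing or malformed), removes the thread-based background agent loading and the background full-PDF processing pass, initializes TxAgent synchronously at startup and passes it into create_ui(), adds a psutil/nvidia-smi system-usage logger around model load, and reflows the gr.Examples block.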

Files changed (1)
  1. app.py +61 -77
app.py CHANGED
@@ -1,16 +1,21 @@
-import ThreadPoolExecutor, as_completed
+import sys
+import os
+import pandas as pd
+import pdfplumber
+import json
+import gradio as gr
+from typing import List
+from concurrent.futures import ThreadPoolExecutor, as_completed
 import hashlib
 import shutil
 import time
-from threading import Thread, Lock
 import re
-import tempfile
-import threading
+import psutil
+import subprocess
 
 # ---------------------------------------------------------------------------------------
-# Setup persistent directories for Hugging Face Spaces
+# Persistent directory for Hugging Face Spaces
 # ---------------------------------------------------------------------------------------
-# Use a persistent cache directory (adjust the path as needed based on your HF Space settings)
 persistent_dir = "/data/hf_cache"
 os.makedirs(persistent_dir, exist_ok=True)
 
@@ -23,25 +28,23 @@ vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
 for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
     os.makedirs(directory, exist_ok=True)
 
-# Set environment variables so that model and transformers caches point to persistent storage.
 os.environ["HF_HOME"] = model_cache_dir
 os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
-# Append the local source path if needed
+# ---------------------------------------------------------------------------------------
+# Add src to path
+# ---------------------------------------------------------------------------------------
 current_dir = os.path.dirname(os.path.abspath(__file__))
 src_path = os.path.abspath(os.path.join(current_dir, "src"))
 sys.path.insert(0, src_path)
 
-# ---------------------------------------------------------------------------------------
-# Import the TxAgent from your tool package
-# ---------------------------------------------------------------------------------------
 from txagent.txagent import TxAgent
 
 # ---------------------------------------------------------------------------------------
-# Define constants and helper functions
+# Helper functions
 # ---------------------------------------------------------------------------------------
 MEDICAL_KEYWORDS = {
     'diagnosis', 'assessment', 'plan', 'results', 'medications',
@@ -59,11 +62,9 @@ def extract_priority_pages(file_path: str, max_pages: int = 20) -> str:
     try:
         text_chunks = []
         with pdfplumber.open(file_path) as pdf:
-            # Process first three pages always
             for i, page in enumerate(pdf.pages[:3]):
                 text = page.extract_text() or ""
                 text_chunks.append(f"=== Page {i+1} ===\n{text.strip()}")
-            # Process subsequent pages only if they contain key medical keywords
             for i, page in enumerate(pdf.pages[3:max_pages], start=4):
                 page_text = page.extract_text() or ""
                 if any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
@@ -83,7 +84,6 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
         if file_type == "pdf":
             text = extract_priority_pages(file_path)
             result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
-            Thread(target=full_pdf_processing, args=(file_path, h)).start()
         elif file_type == "csv":
             df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
                              skip_blank_lines=False, on_bad_lines="skip")
@@ -104,34 +104,34 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
     except Exception as e:
         return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
 
-def full_pdf_processing(file_path: str, file_hash_value: str):
+def log_system_usage(tag=""):
     try:
-        cache_path = os.path.join(file_cache_dir, f"{file_hash_value}_full.json")
-        if os.path.exists(cache_path):
-            return
-        with pdfplumber.open(file_path) as pdf:
-            full_text = "\n".join([f"=== Page {i+1} ===\n{(page.extract_text() or '').strip()}"
-                                   for i, page in enumerate(pdf.pages)])
-        result = json.dumps({"filename": os.path.basename(file_path), "content": full_text, "status": "complete"})
-        with open(cache_path, "w", encoding="utf-8") as f:
-            f.write(result)
-        with open(os.path.join(report_dir, f"{file_hash_value}_report.txt"), "w", encoding="utf-8") as out:
-            out.write(full_text)
+        cpu_percent = psutil.cpu_percent(interval=1)
+        mem = psutil.virtual_memory()
+        print(f"[{tag}] 🧠 CPU: {cpu_percent}% | RAM: {mem.used // (1024**2)}MB / {mem.total // (1024**2)}MB")
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu", "--format=csv,nounits,noheader"],
+            capture_output=True,
+            text=True,
+        )
+        if result.returncode == 0:
+            mem_used, mem_total, util = result.stdout.strip().split(", ")
+            print(f"[{tag}] ⚡ GPU: {mem_used}MB / {mem_total}MB | Utilization: {util}%")
+        else:
+            print(f"[{tag}] ⚡ GPU info not available.")
     except Exception as e:
-        print(f"Background processing failed: {str(e)}")
-
-# ---------------------------------------------------------------------------------------
-# Global agent variable and thread-safe lock for background model loading
-# ---------------------------------------------------------------------------------------
-agent = None
-agent_lock = Lock()
+        print(f"[{tag}] ⚠️ Failed to log system usage: {e}")
 
 def init_agent():
+    print("🔁 Initializing TxAgent...")
+    log_system_usage("Before Model Load")
+
     default_tool_path = os.path.abspath("data/new_tool.json")
     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
     if not os.path.exists(target_tool_path):
         shutil.copy(default_tool_path, target_tool_path)
-    new_agent = TxAgent(
+
+    agent = TxAgent(
         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
         tool_files_dict={"new_tool": target_tool_path},
@@ -141,24 +141,21 @@ def init_agent():
         seed=100,
         additional_default_tools=[],
     )
-    new_agent.init_model()
-    return new_agent
+    agent.init_model()
+    log_system_usage("After Model Load")
 
-def load_agent_in_background():
-    global agent
-    with agent_lock:
-        if agent is None:
-            print("Initializing agent in background (this may take a while)...")
-            agent = init_agent()
-            print("Agent initialization complete.")
+    print("✅ TxAgent is ready.")
+    print("📦 Cached model files:")
+    for root, _, files in os.walk(model_cache_dir):
+        for file in files:
+            print(os.path.join(root, file))
 
-# Start background agent loading at startup
-threading.Thread(target=load_agent_in_background, daemon=True).start()
+    return agent
 
 # ---------------------------------------------------------------------------------------
-# Define the Gradio UI
+# Gradio UI
 # ---------------------------------------------------------------------------------------
-def create_ui():
+def create_ui(agent):
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
         gr.Markdown("""
         <h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>
@@ -173,20 +170,10 @@ def create_ui():
         download_output = gr.File(label="Download Full Report")
 
         def analyze_potential_oversights(message: str, history: list, files: list):
-            global agent
-            # Append user and interim assistant message
-            history = history + [
-                {"role": "user", "content": message},
-                {"role": "assistant", "content": "⏳ Analyzing records for potential oversights..."}
-            ]
+            history = history + [{"role": "user", "content": message},
+                                 {"role": "assistant", "content": "⏳ Analyzing records for potential oversights..."}]
             yield history, None
 
-            if agent is None:
-                history.append({"role": "assistant",
-                                "content": "🕒 The model is still loading. Please wait a moment and try again."})
-                yield history, None
-                return
-
             extracted_data = ""
             file_hash_value = ""
             if files and isinstance(files, list):
@@ -195,13 +182,10 @@
                     executor.submit(convert_file_to_json, f.name, f.name.split(".")[-1].lower())
                     for f in files if hasattr(f, 'name')
                 ]
-                results = []
-                for future in as_completed(futures):
-                    results.append(sanitize_utf8(future.result()))
+                results = [sanitize_utf8(f.result()) for f in as_completed(futures)]
                 extracted_data = "\n".join(results)
                 file_hash_value = file_hash(files[0].name) if hasattr(files[0], 'name') else ""
 
-            # Truncate extracted data to avoid token overflow
             max_extracted_chars = 12000
             truncated_data = extracted_data[:max_extracted_chars]
 
@@ -216,10 +200,8 @@ Medical Records:
 
 ### Potential Oversights:
 """
-
             response = ""
             try:
-                # Stream agent responses and update the last message in the conversation with each chunk.
                 for chunk in agent.run_gradio_chat(
                     message=analysis_prompt,
                     history=[],
@@ -229,18 +211,16 @@ Medical Records:
                     call_agent=False,
                     conversation=[]
                 ):
-                    if chunk is None:
-                        continue
+                    if chunk is None: continue
                     if isinstance(chunk, str):
                         response += chunk
                     elif isinstance(chunk, list):
                         response += "".join([c.content for c in chunk if hasattr(c, 'content')])
                     cleaned = response.replace("[TOOL_CALLS]", "").strip()
-                    # Update the assistant message (last item in history) with the latest accumulated answer
                     history[-1] = {"role": "assistant", "content": cleaned}
                     yield history, None
             except Exception as agent_error:
-                history[-1] = {"role": "assistant", "content": f"❌ Analysis failed during processing: {str(agent_error)}"}
+                history[-1] = {"role": "assistant", "content": f"❌ Analysis failed: {str(agent_error)}"}
                 yield history, None
                 return
 
@@ -248,7 +228,6 @@ Medical Records:
             if not final_output:
                 final_output = "No clear oversights identified. Recommend comprehensive review."
 
-            # Update the assistant's message with the final output
             history[-1] = {"role": "assistant", "content": final_output}
 
             report_path = None
@@ -265,15 +244,20 @@ Medical Records:
         msg_input.submit(analyze_potential_oversights,
                          inputs=[msg_input, gr.State([]), file_upload],
                          outputs=[chatbot, download_output])
-        gr.Examples([["What might have been missed in this patient's treatment?"],
-                     ["Are there any medication conflicts in these records?"],
-                     ["What abnormal results require follow-up?"]],
-                    inputs=msg_input)
+        gr.Examples([
+            ["What might have been missed in this patient's treatment?"],
+            ["Are there any medication conflicts in these records?"],
+            ["What abnormal results require follow-up?"]],
+            inputs=msg_input)
     return demo
 
+# ---------------------------------------------------------------------------------------
+# Launch
+# ---------------------------------------------------------------------------------------
 if __name__ == "__main__":
-    print("Launching interface...")
-    demo = create_ui()
+    print("🚀 Starting TxAgent App...")
+    agent = init_agent()
+    demo = create_ui(agent)
     demo.queue(api_open=False).launch(
         server_name="0.0.0.0",
         server_port=7860,
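For reference, the resource logging introduced here can be smoke-tested on its own. Below is a minimal sketch of the same pattern (illustrative names, not the committed helper; it assumes psutil is installed and degrades gracefully when nvidia-smi is absent):

import subprocess
import psutil

def report_usage(tag: str = "") -> None:
    # CPU/RAM via psutil, mirroring the commit's log_system_usage().
    mem = psutil.virtual_memory()
    print(f"[{tag}] CPU: {psutil.cpu_percent(interval=1)}% | "
          f"RAM: {mem.used // (1024**2)}MB / {mem.total // (1024**2)}MB")
    try:
        # GPU stats only when nvidia-smi is on PATH; harmless on CPU-only machines.
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used,memory.total,utilization.gpu",
             "--format=csv,nounits,noheader"],
            capture_output=True, text=True,
        )
        if out.returncode == 0:
            print(f"[{tag}] GPU (used MB, total MB, util %): {out.stdout.strip()}")
    except FileNotFoundError:
        print(f"[{tag}] nvidia-smi not found")

report_usage("smoke test")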
 
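The updated analyze_potential_oversights path calls sanitize_utf8() and file_hash(), which are defined elsewhere in app.py and do not appear in these hunks. Plausible shapes inferred from their call sites (assumptions, not the committed implementations):

import hashlib

def sanitize_utf8(text: str) -> str:
    # Assumed: drop any bytes that cannot round-trip through UTF-8.
    return text.encode("utf-8", "ignore").decode("utf-8")

def file_hash(path: str) -> str:
    # Assumed: content hash keying the per-file cache and report filenames.
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()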
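One behavioral note on the rewritten result gathering: concurrent.futures.as_completed yields futures in completion order, not submission order, so extracted_data can interleave files differently from run to run; this matches the explicit loop it replaces. A minimal demonstration:

import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def slow_upper(s: str) -> str:
    time.sleep(random.random() / 10)  # simulate uneven per-file parse times
    return s.upper()

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(slow_upper, s) for s in ["a", "b", "c"]]
    print([f.result() for f in as_completed(futures)])  # order may vary per run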