Ali2206 committed on
Commit 1ba0100 · verified · 1 Parent(s): f640ef8

Update app.py

Files changed (1): app.py +63 -117
app.py CHANGED
@@ -1,13 +1,11 @@
  import sys
  import os
- import pandas as pd
  import pdfplumber
  import json
  import gradio as gr
  from typing import List
  from concurrent.futures import ThreadPoolExecutor, as_completed
  import hashlib
- import shutil
  import re
  import psutil
  import subprocess
@@ -17,12 +15,11 @@ persistent_dir = "/data/hf_cache"
  os.makedirs(persistent_dir, exist_ok=True)
  
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
- tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
  file_cache_dir = os.path.join(persistent_dir, "cache")
  report_dir = os.path.join(persistent_dir, "reports")
  vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
  
- for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
+ for directory in [model_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
      os.makedirs(directory, exist_ok=True)
  
  os.environ["HF_HOME"] = model_cache_dir
@@ -47,15 +44,23 @@ def file_hash(path: str) -> str:
      with open(path, "rb") as f:
          return hashlib.md5(f.read()).hexdigest()
  
- def extract_priority_pages(file_path: str) -> str:
+ def extract_priority_pages(file_path: str, max_chars: int = 6000) -> str:
      try:
          text_chunks = []
+         total_chars = 0
          with pdfplumber.open(file_path) as pdf:
              for i, page in enumerate(pdf.pages):
                  page_text = page.extract_text() or ""
                  if i < 3 or any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in MEDICAL_KEYWORDS):
-                     text_chunks.append(f"=== Page {i+1} ===\n{page_text.strip()}")
-         return "\n\n".join(text_chunks)
+                     page_chunk = f"=== Page {i+1} ===\n{page_text.strip()}\n"
+                     if total_chars + len(page_chunk) <= max_chars:
+                         text_chunks.append(page_chunk)
+                         total_chars += len(page_chunk)
+                     else:
+                         remaining = max_chars - total_chars
+                         text_chunks.append(page_chunk[:remaining])
+                         break
+         return "".join(text_chunks).strip()
      except Exception as e:
          return f"PDF processing error: {str(e)}"
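A quick way to sanity-check the new character budget in `extract_priority_pages` is to run the same arithmetic on plain strings. The sketch below is illustrative only: `budget_pages` and its test data are hypothetical stand-ins for pdfplumber pages; the truncation logic mirrors the `+` lines above.

```python
# Minimal sketch of the max_chars budget added to extract_priority_pages.
# Plain strings stand in for pdfplumber pages; budget_pages is hypothetical.

def budget_pages(pages, max_chars=6000):
    text_chunks, total_chars = [], 0
    for i, page_text in enumerate(pages):
        page_chunk = f"=== Page {i+1} ===\n{page_text.strip()}\n"
        if total_chars + len(page_chunk) <= max_chars:
            text_chunks.append(page_chunk)
            total_chars += len(page_chunk)
        else:
            # Keep only what still fits in the budget, then stop reading pages.
            text_chunks.append(page_chunk[:max_chars - total_chars])
            break
    return "".join(text_chunks).strip()

assert len(budget_pages(["x" * 4000, "y" * 4000])) <= 6000
```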
@@ -70,18 +75,6 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
      if file_type == "pdf":
          text = extract_priority_pages(file_path)
          result = json.dumps({"filename": os.path.basename(file_path), "content": text, "status": "initial"})
-     elif file_type == "csv":
-         df = pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
-                          skip_blank_lines=False, on_bad_lines="skip")
-         content = df.fillna("").astype(str).values.tolist()
-         result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
-     elif file_type in ["xls", "xlsx"]:
-         try:
-             df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
-         except Exception:
-             df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
-         content = df.fillna("").astype(str).values.tolist()
-         result = json.dumps({"filename": os.path.basename(file_path), "rows": content})
      else:
          result = json.dumps({"error": f"Unsupported file type: {file_type}"})
      with open(cache_path, "w", encoding="utf-8") as f:
@@ -107,34 +100,25 @@ def log_system_usage(tag=""):
  
  def clean_response(text: str) -> str:
      text = sanitize_utf8(text)
-     # Remove tool calls, JSON data, and repetitive phrases
      text = re.sub(r"\[TOOL_CALLS\].*", "", text, flags=re.DOTALL)
-     text = re.sub(r"\['get_[^\]]+\']\n?", "", text)  # Remove tool names
-     text = re.sub(r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?", "", text, flags=re.DOTALL)  # Remove JSON
-     text = re.sub(r"To analyze the medical records for clinical oversights.*?begin by reviewing.*?\n", "", text, flags=re.DOTALL)
+     text = re.sub(r"\['get_[^\]]+\']\n?", "", text)
+     text = re.sub(r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?", "", text, flags=re.DOTALL)
+     text = re.sub(r"(?i)(to analyze|based on|will start|no (drug|clinical|information)).*?\n", "", text, flags=re.DOTALL)
      text = re.sub(r"\n{3,}", "\n\n", text).strip()
-     # Only keep text under analysis headings or relevant content
-     if not re.search(r"(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", text):
+     if not re.search(r"(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)", text, re.IGNORECASE):
          return ""
      return text
  
  def init_agent():
      print("🔁 Initializing model...")
      log_system_usage("Before Load")
-     default_tool_path = os.path.abspath("data/new_tool.json")
-     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
-     if not os.path.exists(target_tool_path):
-         shutil.copy(default_tool_path, target_tool_path)
- 
      agent = TxAgent(
          model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
          rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-         tool_files_dict={"new_tool": target_tool_path},
          force_finish=True,
          enable_checker=True,
-         step_rag_num=2,
+         step_rag_num=1,
          seed=100,
-         additional_default_tools=[],
      )
      agent.init_model()
      log_system_usage("After Load")
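The reworked `clean_response` keeps model output only if it mentions one of the four report headings, now matched case-insensitively; everything else is dropped before it reaches the chat stream. A self-contained check of that gate (the `passes_gate` helper is illustrative, not part of app.py):

```python
import re

# The heading gate from clean_response: text lacking all four section
# headings is discarded, so preamble and boilerplate never reach the UI.
HEADINGS = r"(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)"

def passes_gate(text: str) -> bool:
    return bool(re.search(HEADINGS, text, re.IGNORECASE))

print(passes_gate("**missed diagnoses**: anemia workup never ordered"))  # True
print(passes_gate("To analyze the records, I will begin by..."))         # False
```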
@@ -145,14 +129,13 @@ def create_ui(agent):
      with gr.Blocks(theme=gr.themes.Soft()) as demo:
          gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
          chatbot = gr.Chatbot(label="Analysis", height=600, type="messages")
-         file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
+         file_upload = gr.File(file_types=[".pdf"], file_count="multiple")
          msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
          send_btn = gr.Button("Analyze", variant="primary")
-         download_output = gr.File(label="Download Full Report")
+         download_output = gr.File(label="Download Report")
  
          def analyze(message: str, history: List[dict], files: List):
              history.append({"role": "user", "content": message})
-             history.append({"role": "assistant", "content": "⏳ Analyzing records for potential oversights..."})
              yield history, None
  
              extracted = ""
@@ -164,101 +147,64 @@ def create_ui(agent):
              extracted = "\n".join(results)
              file_hash_value = file_hash(files[0].name) if files else ""
  
-             # Split extracted text into chunks of ~4,000 characters
-             chunk_size = 4000
-             chunks = [extracted[i:i + chunk_size] for i in range(0, len(extracted), chunk_size)]
-             combined_response = ""
- 
-             prompt_template = f"""
- Analyze the medical records for clinical oversights. Provide a concise, evidence-based summary under these headings:
- 
- 1. **Missed Diagnoses**:
-    - Identify inconsistencies in history, symptoms, or tests.
-    - Consider psychiatric, neurological, infectious, autoimmune, genetic conditions, family history, trauma, and developmental factors.
- 
- 2. **Medication Conflicts**:
-    - Check for contraindications, interactions, or unjustified off-label use.
-    - Assess if medications worsen diagnoses or cause adverse effects.
- 
- 3. **Incomplete Assessments**:
-    - Note missing or superficial cognitive, psychiatric, social, or family assessments.
-    - Highlight gaps in medical history, substance use, or lab/imaging documentation.
- 
- 4. **Urgent Follow-up**:
-    - Flag abnormal lab results, imaging, behaviors, or legal history needing immediate reassessment or referral.
- 
- Medical Records (Chunk {0} of {1}):
- {{chunk}}
- 
- Begin analysis:
+             prompt = f"""
+ Analyze the medical records and list potential doctor oversights under these headings only, with brief details:
+ 
+ **Missed Diagnoses**: Inconsistencies or unaddressed conditions.
+ **Medication Conflicts**: Contraindications or risky prescriptions.
+ **Incomplete Assessments**: Missing or shallow evaluations.
+ **Urgent Follow-up**: Issues needing immediate attention.
+ 
+ Records:
+ {extracted[:6000]}
+ 
+ Respond concisely.
  """
  
              try:
-                 if history and history[-1]["content"].startswith(""):
-                     history.pop()
- 
-                 # Process each chunk and stream cleaned results
-                 for chunk_idx, chunk in enumerate(chunks, 1):
-                     # Update UI with progress
-                     history.append({"role": "assistant", "content": f"🔄 Processing Chunk {chunk_idx} of {len(chunks)}..."})
-                     yield history, None
- 
-                     prompt = prompt_template.format(chunk_idx, len(chunks), chunk=chunk)
-                     chunk_response = ""
-                     for chunk_output in agent.run_gradio_chat(
-                         message=prompt,
-                         history=[],
-                         temperature=0.2,
-                         max_new_tokens=1024,
-                         max_token=4096,
-                         call_agent=False,
-                         conversation=[],
-                     ):
-                         if chunk_output is None:
-                             continue
-                         if isinstance(chunk_output, list):
-                             for m in chunk_output:
-                                 if hasattr(m, 'content') and m.content:
-                                     cleaned = clean_response(m.content)
-                                     if cleaned:
-                                         chunk_response += cleaned + "\n"
-                                         # Stream partial response to UI
-                                         if history[-1]["content"].startswith("🔄"):
-                                             history[-1] = {"role": "assistant", "content": f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"}
-                                         else:
-                                             history[-1]["content"] = f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"
-                                         yield history, None
-                         elif isinstance(chunk_output, str) and chunk_output.strip():
-                             cleaned = clean_response(chunk_output)
-                             if cleaned:
-                                 chunk_response += cleaned + "\n"
-                                 # Stream partial response to UI
-                                 if history[-1]["content"].startswith("🔄"):
-                                     history[-1] = {"role": "assistant", "content": f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"}
-                                 else:
-                                     history[-1]["content"] = f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response.strip()}"
-                                 yield history, None
- 
-                     # Append completed chunk response to combined response
-                     if chunk_response:
-                         combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
- 
-                 # Finalize UI with complete response
-                 if combined_response:
-                     history[-1]["content"] = combined_response.strip()
-                 else:
-                     history.append({"role": "assistant", "content": "No oversights identified."})
+                 history.append({"role": "assistant", "content": "🔄 Analyzing..."})
+                 yield history, None
+ 
+                 response = ""
+                 for output in agent.run_gradio_chat(
+                     message=prompt,
+                     history=[],
+                     temperature=0.1,
+                     max_new_tokens=512,
+                     max_token=4096,
+                     call_agent=False,
+                     conversation=[],
+                 ):
+                     if output is None:
+                         continue
+                     if isinstance(output, list):
+                         for m in output:
+                             if hasattr(m, 'content') and m.content:
+                                 cleaned = clean_response(m.content)
+                                 if cleaned:
+                                     response += cleaned + "\n"
+                                     history[-1]["content"] = response.strip()
+                                     yield history, None
+                     elif isinstance(output, str) and output.strip():
+                         cleaned = clean_response(output)
+                         if cleaned:
+                             response += cleaned + "\n"
+                             history[-1]["content"] = response.strip()
+                             yield history, None
+ 
+                 if not response:
+                     history[-1]["content"] = "No oversights identified."
+                     yield history, None
  
-                 # Generate report file with cleaned response
                  report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt") if file_hash_value else None
-                 if report_path:
+                 if report_path and response:
                      with open(report_path, "w", encoding="utf-8") as f:
-                         f.write(combined_response)
+                         f.write(response.strip())
                  yield history, report_path if report_path and os.path.exists(report_path) else None
  
              except Exception as e:
                  print("🚨 ERROR:", e)
-                 history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
+                 history[-1]["content"] = f"❌ Error: {str(e)}"
                  yield history, None
  
          send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output])
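The rewritten `analyze` handler drops the per-chunk loop and streams one response, overwriting the `🔄 Analyzing...` placeholder in place. Below is a minimal sketch of that consumption pattern with a stub standing in for `agent.run_gradio_chat`; the stub and the simplified `clean_response` are hypothetical, and only the skip-`None`, clean, and in-place-update flow mirrors the diff.

```python
import re

def clean_response(text):
    # Simplified stand-in for app.py's clean_response (heading gate only).
    gate = r"(Missed Diagnoses|Medication Conflicts|Incomplete Assessments|Urgent Follow-up)"
    return text.strip() if re.search(gate, text, re.IGNORECASE) else ""

def fake_run_gradio_chat(**kwargs):
    # Hypothetical stub: yields None, filtered boilerplate, then real content.
    yield None
    yield "To analyze the medical records, I will start by..."
    yield "**Missed Diagnoses**: recurring low hemoglobin never followed up."

history = [{"role": "assistant", "content": "🔄 Analyzing..."}]
response = ""
for output in fake_run_gradio_chat(message="...", history=[]):
    if output is None:
        continue
    if isinstance(output, str) and output.strip():
        cleaned = clean_response(output)
        if cleaned:
            response += cleaned + "\n"
            history[-1]["content"] = response.strip()  # placeholder overwritten in place

print(history[-1]["content"])
```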
 