Ali2206 committed on
Commit 12ddaba · verified · 1 Parent(s): 63d0c23

Update app.py

Files changed (1)
  1. app.py +389 -148
app.py CHANGED
@@ -27,14 +27,12 @@ vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
 for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
     os.makedirs(directory, exist_ok=True)
 
-# Environment variables
 os.environ["HF_HOME"] = model_cache_dir
 os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
-# Add src to path
 current_dir = os.path.dirname(os.path.abspath(__file__))
 src_path = os.path.abspath(os.path.join(current_dir, "src"))
 sys.path.insert(0, src_path)
@@ -48,14 +46,11 @@ MEDICAL_KEYWORDS = {
     'conclusion', 'history', 'examination', 'progress', 'discharge'
 }
 TOKENIZER = "cl100k_base"
-# Increase max model length to support larger contexts
-MAX_MODEL_LEN = 4096
-# Default chunk target tokens
-TARGET_CHUNK_TOKENS = 1200
-PROMPT_RESERVE = 100
+MAX_MODEL_LEN = 2048
+TARGET_CHUNK_TOKENS = 1000  # Reduced from 1200 to be more conservative
+PROMPT_RESERVE = 400  # Increased buffer for prompt + response
 MEDICAL_SECTION_HEADER = "=== MEDICAL SECTION ==="
 
-
 def log_system_usage(tag=""):
     try:
         cpu = psutil.cpu_percent(interval=1)
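
Note: the three constants above define the token budget every request must respect: instructions, document content, and the model's reply all have to fit inside `MAX_MODEL_LEN`. A minimal sketch of that arithmetic (not part of the commit; assumes `tiktoken` is installed, mirroring the `TOKENIZER = "cl100k_base"` constant):

```python
# Sketch only: how the budget constants above interact.
import tiktoken

MAX_MODEL_LEN = 2048        # hard context limit
TARGET_CHUNK_TOKENS = 1000  # preferred chunk size
PROMPT_RESERVE = 400        # buffer held back for instructions + reply

enc = tiktoken.get_encoding("cl100k_base")
chunk = "Patient reports intermittent chest pain. " * 400  # deliberately oversized
tokens = enc.encode(chunk)
print(f"raw chunk: {len(tokens)} tokens")

budget = MAX_MODEL_LEN - PROMPT_RESERVE        # room left once the reserve is held back
clipped = enc.decode(tokens[:min(budget, TARGET_CHUNK_TOKENS)])
print(f"clipped chunk: ~{len(enc.encode(clipped))} tokens")
```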
@@ -71,47 +66,52 @@ def log_system_usage(tag=""):
     except Exception as e:
         print(f"[{tag}] GPU/CPU monitor failed: {e}")
 
-
 def sanitize_utf8(text: str) -> str:
     return text.encode("utf-8", "ignore").decode("utf-8")
 
-
 def file_hash(path: str) -> str:
     with open(path, "rb") as f:
         return hashlib.md5(f.read()).hexdigest()
 
-
 def count_tokens(text: str) -> int:
     encoding = tiktoken.get_encoding(TOKENIZER)
     return len(encoding.encode(text))
 
-
 def extract_all_pages_with_token_count(file_path: str) -> Tuple[str, int, int]:
     try:
         text_chunks = []
         total_pages = 0
         total_tokens = 0
+
         with pdfplumber.open(file_path) as pdf:
             total_pages = len(pdf.pages)
+
             for i, page in enumerate(pdf.pages):
                 page_text = page.extract_text() or ""
                 lower_text = page_text.lower()
-                header = f"\n{MEDICAL_SECTION_HEADER} (Page {i+1})\n" if any(
-                    re.search(rf'\b{kw}\b', lower_text) for kw in MEDICAL_KEYWORDS
-                ) else f"\n=== Page {i+1} ===\n"
-                text_chunks.append(header + page_text.strip())
-                total_tokens += count_tokens(header) + count_tokens(page_text)
+
+                if any(re.search(rf'\b{kw}\b', lower_text) for kw in MEDICAL_KEYWORDS):
+                    section_header = f"\n{MEDICAL_SECTION_HEADER} (Page {i+1})\n"
+                    text_chunks.append(section_header + page_text.strip())
+                    total_tokens += count_tokens(section_header)
+                else:
+                    text_chunks.append(f"\n=== Page {i+1} ===\n{page_text.strip()}")
+
+                total_tokens += count_tokens(page_text)
+
         return "\n".join(text_chunks), total_pages, total_tokens
     except Exception as e:
         return f"PDF processing error: {str(e)}", 0, 0
 
-
 def convert_file_to_json(file_path: str, file_type: str) -> str:
     try:
         h = file_hash(file_path)
         cache_path = os.path.join(file_cache_dir, f"{h}.json")
+
         if os.path.exists(cache_path):
-            return open(cache_path, "r", encoding="utf-8").read()
+            with open(cache_path, "r", encoding="utf-8") as f:
+                return f.read()
+
         if file_type == "pdf":
             text, total_pages, total_tokens = extract_all_pages_with_token_count(file_path)
             result = json.dumps({
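
Note: the caching above keys converted files by an md5 of their raw bytes, so re-uploading an identical file skips the pdfplumber pass entirely. The pattern in isolation (a sketch, not part of the commit; `cache_dir` and the placeholder payload are hypothetical):

```python
# Sketch only: md5-keyed file cache, stdlib only.
import hashlib
import json
import os

def cached_result(path: str, cache_dir: str) -> str:
    os.makedirs(cache_dir, exist_ok=True)
    with open(path, "rb") as f:
        key = hashlib.md5(f.read()).hexdigest()        # cache key = hash of file bytes
    cache_path = os.path.join(cache_dir, f"{key}.json")
    if os.path.exists(cache_path):                     # hit: reuse stored JSON
        with open(cache_path, "r", encoding="utf-8") as f:
            return f.read()
    result = json.dumps({"filename": os.path.basename(path)})  # stand-in for real extraction
    with open(cache_path, "w", encoding="utf-8") as f:          # miss: compute, then store
        f.write(result)
    return result
```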
@@ -123,12 +123,10 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
             })
         elif file_type == "csv":
             chunks = []
-            for chunk in pd.read_csv(
-                file_path, encoding_errors="replace", header=None, dtype=str,
-                skip_blank_lines=False, on_bad_lines="skip", chunksize=1000
-            ):
+            for chunk in pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
+                                     skip_blank_lines=False, on_bad_lines="skip", chunksize=1000):
                 chunks.append(chunk.fillna("").astype(str).values.tolist())
-            content = [item for sub in chunks for item in sub]
+            content = [item for sublist in chunks for item in sublist]
             result = json.dumps({
                 "filename": os.path.basename(file_path),
                 "rows": content,
@@ -137,9 +135,9 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
         elif file_type in ["xls", "xlsx"]:
             try:
                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
-            except:
+            except Exception:
                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
-            content = df.fillna("" ).astype(str).values.tolist()
+            content = df.fillna("").astype(str).values.tolist()
             result = json.dumps({
                 "filename": os.path.basename(file_path),
                 "rows": content,
@@ -147,97 +145,129 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
             })
         else:
             result = json.dumps({"error": f"Unsupported file type: {file_type}"})
+
         with open(cache_path, "w", encoding="utf-8") as f:
             f.write(result)
         return result
     except Exception as e:
         return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
 
-
 def clean_response(text: str) -> str:
     text = sanitize_utf8(text)
-    patterns = [
-        r"\[TOOL_CALLS\].*",
-        r"\['get_[^\]]+\']\n?",
-        r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?",
-        r"To analyze the medical records for clinical oversights.*?\n"
-    ]
-    for pat in patterns:
-        text = re.sub(pat, "", text, flags=re.DOTALL)
-    return re.sub(r"\n{3,}", "\n\n", text).strip()
-
+    text = re.sub(r"\[TOOL_CALLS\].*", "", text, flags=re.DOTALL)
+    text = re.sub(r"\['get_[^\]]+\']\n?", "", text)
+    text = re.sub(r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?", "", text, flags=re.DOTALL)
+    text = re.sub(r"To analyze the medical records for clinical oversights.*?begin by reviewing.*?\n", "", text, flags=re.DOTALL)
+    text = re.sub(r"\n{3,}", "\n\n", text).strip()
+    return text
 
 def format_final_report(analysis_results: List[str], filename: str) -> str:
-    report = [
-        "COMPREHENSIVE CLINICAL OVERSIGHT ANALYSIS",
-        f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
-        f"File: {filename}",
-        "=" * 80
-    ]
-    sections = {s: [] for s in [
-        "CRITICAL FINDINGS", "MISSED DIAGNOSES", "MEDICATION ISSUES",
-        "ASSESSMENT GAPS", "FOLLOW-UP RECOMMENDATIONS"
-    ]}
-    for res in analysis_results:
-        for sec in sections:
-            m = re.search(
-                rf"{re.escape(sec)}:?\s*\n(.+?)(?=\n\*|\n\n|$)",
-                res, re.IGNORECASE | re.DOTALL
+    report = []
+    report.append(f"COMPREHENSIVE CLINICAL OVERSIGHT ANALYSIS")
+    report.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+    report.append(f"File: {filename}")
+    report.append("=" * 80)
+
+    sections = {
+        "CRITICAL FINDINGS": [],
+        "MISSED DIAGNOSES": [],
+        "MEDICATION ISSUES": [],
+        "ASSESSMENT GAPS": [],
+        "FOLLOW-UP RECOMMENDATIONS": []
+    }
+
+    for result in analysis_results:
+        for section in sections:
+            section_match = re.search(
+                rf"{re.escape(section)}:?\s*\n([^*]+?)(?=\n\*|\n\n|$)",
+                result,
+                re.IGNORECASE | re.DOTALL
             )
-            if m:
-                content = m.group(1).strip()
-                if content and content not in sections[sec]:
-                    sections[sec].append(content)
+            if section_match:
+                content = section_match.group(1).strip()
+                if content and content not in sections[section]:
+                    sections[section].append(content)
+
     if sections["CRITICAL FINDINGS"]:
         report.append("\n🚨 **CRITICAL FINDINGS** 🚨")
-        report.extend(f"\n{c}" for c in sections["CRITICAL FINDINGS"])
-    for sec, conts in sections.items():
-        if sec != "CRITICAL FINDINGS" and conts:
-            report.append(f"\n**{sec}**")
-            report.extend(f"\n{c}" for c in conts)
+        for content in sections["CRITICAL FINDINGS"]:
+            report.append(f"\n{content}")
+
+    for section, contents in sections.items():
+        if section != "CRITICAL FINDINGS" and contents:
+            report.append(f"\n**{section.upper()}**")
+            for content in contents:
+                report.append(f"\n{content}")
+
     if not any(sections.values()):
         report.append("\nNo significant clinical oversights identified.")
-    report.append("\n" + "="*80)
+
+    report.append("\n" + "=" * 80)
     report.append("END OF REPORT")
+
     return "\n".join(report)
 
-def split_content_by_tokens(content: str, max_tokens: int) -> List[str]:
+def split_content_by_tokens(content: str, max_tokens: int = TARGET_CHUNK_TOKENS) -> List[str]:
+    """More conservative splitting that ensures we stay well under token limits"""
     paragraphs = re.split(r"\n\s*\n", content)
-    chunks, current, curr_toks = [], [], 0
+    chunks = []
+    current_chunk = []
+    current_tokens = 0
+
     for para in paragraphs:
-        toks = count_tokens(para)
-        if toks > max_tokens:
-            for sent in re.split(r'(?<=[.!?])\s+', para):
-                sent_toks = count_tokens(sent)
-                if curr_toks + sent_toks > max_tokens:
-                    chunks.append("\n\n".join(current))
-                    current, curr_toks = [sent], sent_toks
+        para_tokens = count_tokens(para)
+
+        # If paragraph is too big, split into sentences
+        if para_tokens > max_tokens * 0.8:  # Don't allow paragraphs that take up most of the chunk
+            sentences = re.split(r'(?<=[.!?])\s+', para)
+            for sent in sentences:
+                sent_tokens = count_tokens(sent)
+                if current_tokens + sent_tokens > max_tokens * 0.9:  # Leave 10% buffer
+                    if current_chunk:  # Only add if we have content
+                        chunks.append("\n\n".join(current_chunk))
+                    current_chunk = []
+                    current_tokens = 0
+
+                    # If single sentence is too long, split into words
+                    if sent_tokens > max_tokens * 0.8:
+                        words = sent.split()
+                        for word in words:
+                            word_tokens = count_tokens(word)
+                            if current_tokens + word_tokens > max_tokens * 0.9:
+                                if current_chunk:
+                                    chunks.append("\n\n".join(current_chunk))
+                                current_chunk = []
+                                current_tokens = 0
+                            current_chunk.append(word)
+                            current_tokens += word_tokens
+                    else:
+                        current_chunk.append(sent)
+                        current_tokens += sent_tokens
                 else:
-                    current.append(sent)
-                    curr_toks += sent_toks
-        elif curr_toks + toks > max_tokens:
-            chunks.append("\n\n".join(current))
-            current, curr_toks = [para], toks
+                    current_chunk.append(sent)
+                    current_tokens += sent_tokens
+        elif current_tokens + para_tokens > max_tokens * 0.9:
+            chunks.append("\n\n".join(current_chunk))
+            current_chunk = [para]
+            current_tokens = para_tokens
         else:
-            current.append(para)
-            curr_toks += toks
-    if current:
-        chunks.append("\n\n".join(current))
+            current_chunk.append(para)
+            current_tokens += para_tokens
+
+    if current_chunk:
+        chunks.append("\n\n".join(current_chunk))
+
     return chunks
 
-
 def init_agent():
     print("🔁 Initializing model...")
     log_system_usage("Before Load")
+
     default_tool_path = os.path.abspath("data/new_tool.json")
     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
     if not os.path.exists(target_tool_path):
         shutil.copy(default_tool_path, target_tool_path)
+
     agent = TxAgent(
         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
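
Note: the rewritten `split_content_by_tokens` now defaults to `TARGET_CHUNK_TOKENS`, flushes the running chunk once adding the next piece would exceed 90% of the limit, and degrades from paragraphs to sentences to single words. A usage sketch (not part of the commit; assumes `app.py` is importable, which also runs its module-level setup):

```python
# Sketch only: exercising the new chunker.
from app import TARGET_CHUNK_TOKENS, count_tokens, split_content_by_tokens

sample = "\n\n".join(
    f"Paragraph {i}: patient stable, labs reviewed, plan unchanged. " * 30
    for i in range(40)
)
chunks = split_content_by_tokens(sample)   # max_tokens defaults to TARGET_CHUNK_TOKENS
sizes = [count_tokens(c) for c in chunks]
# Flushes trigger at 90% of max_tokens, so sizes should sit near or below 900,
# though the "\n\n" join separators can nudge a chunk slightly higher.
print(len(chunks), "chunks; largest =", max(sizes), "tokens")
```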
@@ -246,88 +276,294 @@ def init_agent():
         enable_checker=True,
         step_rag_num=2,
         seed=100,
-        additional_default_tools=[]
+        additional_default_tools=[],
     )
     agent.init_model()
     log_system_usage("After Load")
     print("✅ Agent Ready")
     return agent
 
-
 def analyze_complete_document(content: str, filename: str, agent: TxAgent, temperature: float = 0.3) -> str:
-    base_prompt = (
-        "Analyze for:\n1. Critical\n2. Missed DX\n3. Med issues\n4. Gaps\n5. Follow-up\n\nContent:\n"
-    )
-    prompt_toks = count_tokens(base_prompt)
-    max_chunk_toks = MAX_MODEL_LEN - prompt_toks - PROMPT_RESERVE
-    chunks = split_content_by_tokens(content, max_chunk_toks)
-    results = []
+    """Analyze complete document with strict token management"""
+    chunks = split_content_by_tokens(content)
+    analysis_results = []
+
     for i, chunk in enumerate(chunks):
         try:
+            # Minimal prompt template
+            base_prompt = """Analyze this medical content for:
+1. Critical findings needing immediate attention
+2. Potential missed diagnoses
+3. Medication issues
+4. Assessment gaps
+5. Follow-up recommendations
+
+Content:\n"""
+
+            # Calculate available space
+            prompt_tokens = count_tokens(base_prompt)
+            max_content_tokens = MAX_MODEL_LEN - prompt_tokens - 300  # 300 tokens for response
+
+            # Ensure chunk fits
+            chunk_tokens = count_tokens(chunk)
+            if chunk_tokens > max_content_tokens:
+                # If still too big after splitting, truncate
+                encoding = tiktoken.get_encoding(TOKENIZER)
+                tokens = encoding.encode(chunk)
+                chunk = encoding.decode(tokens[:max_content_tokens])
+                print(f"Warning: Truncated chunk {i} from {chunk_tokens} to {max_content_tokens} tokens")
+
             prompt = base_prompt + chunk
+
+            # Final verification
+            total_tokens = count_tokens(prompt)
+            if total_tokens > MAX_MODEL_LEN - 200:
+                encoding = tiktoken.get_encoding(TOKENIZER)
+                tokens = encoding.encode(prompt)
+                prompt = encoding.decode(tokens[:MAX_MODEL_LEN - 200])
+                print(f"Warning: Truncated final prompt from {total_tokens} tokens")
+
             response = ""
-            for out in agent.run_gradio_chat(
+            for output in agent.run_gradio_chat(
                 message=prompt,
                 history=[],
                 temperature=temperature,
-                max_new_tokens=300,
+                max_new_tokens=200,  # Conservative response length
                 max_token=MAX_MODEL_LEN,
                 call_agent=False,
-                conversation=[]
+                conversation=[],
             ):
-                if out:
-                    if isinstance(out, list):
-                        for m in out:
-                            response += clean_response(m.content if hasattr(m, 'content') else str(m))
-                    else:
-                        response += clean_response(str(out))
+                if output:
+                    if isinstance(output, list):
+                        for m in output:
+                            if hasattr(m, 'content'):
+                                response += clean_response(m.content)
+                    elif isinstance(output, str):
+                        response += clean_response(output)
+
             if response:
-                results.append(response)
+                analysis_results.append(response)
         except Exception as e:
-            print(f"Error processing chunk {i}: {e}")
-    return format_final_report(results, filename)
-
+            print(f"Error processing chunk {i}: {str(e)}")
+            continue
+
+    return format_final_report(analysis_results, filename)
 
 def create_ui(agent):
-    with gr.Blocks(title="Clinical Oversight Assistant") as demo:
+    with gr.Blocks(
+        theme=gr.themes.Soft(
+            primary_hue="indigo",
+            secondary_hue="blue",
+            neutral_hue="slate",
+            spacing_size="md",
+            radius_size="md"
+        ),
+        title="Clinical Oversight Assistant",
+        css="""
+        .report-box {
+            border: 1px solid #e0e0e0;
+            border-radius: 8px;
+            padding: 16px;
+            background-color: #f9f9f9;
+        }
+        .file-upload {
+            background-color: #f5f7fa;
+            padding: 16px;
+            border-radius: 8px;
+        }
+        .analysis-btn {
+            width: 100%;
+        }
+        .critical-finding {
+            color: #d32f2f;
+            font-weight: bold;
+        }
+        .dataframe-container {
+            height: 600px;
+            overflow-y: auto;
+        }
+        """
+    ) as demo:
         gr.Markdown("""
-        # 🩺 Clinical Oversight Assistant
-        Analyze medical records for potential oversights and generate comprehensive reports
+        <div style='text-align: center; margin-bottom: 20px;'>
+            <h1 style='color: #2b3a67; margin-bottom: 8px;'>🩺 Clinical Oversight Assistant</h1>
+            <p style='color: #5a6a8a; font-size: 16px;'>
+                Analyze medical records for potential oversights and generate comprehensive reports
+            </p>
+        </div>
         """)
-        with gr.Row():
-            with gr.Column():
-                file_upload = gr.File(label="Upload Medical Records", file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
-                msg_input = gr.Textbox(label="Analysis Focus (optional)")
-                temperature = gr.Slider(0.1, 1.0, value=0.3, label="Analysis Strictness")
-                send_btn = gr.Button("Analyze Documents", variant="primary")
-                clear_btn = gr.Button("Clear All")
-                status = gr.Textbox(label="Status", interactive=False)
-            with gr.Column():
-                report_output = gr.Textbox(label="Report", lines=20, interactive=False)
-                data_preview = gr.Dataframe(headers=["File", "Snippet"], interactive=False)
-                download_output = gr.File(label="Download Report")
-        def analyze(files, msg, temp):
+
+        with gr.Row(equal_height=False):
+            with gr.Column(scale=1, min_width=400):
+                with gr.Group(elem_classes="file-upload"):
+                    file_upload = gr.File(
+                        file_types=[".pdf", ".csv", ".xls", ".xlsx"],
+                        file_count="multiple",
+                        label="Upload Medical Records",
+                        elem_id="file-upload"
+                    )
+                    with gr.Row():
+                        clear_btn = gr.Button("Clear All", size="sm")
+                        send_btn = gr.Button(
+                            "Analyze Documents",
+                            variant="primary",
+                            elem_classes="analysis-btn"
+                        )
+
+                with gr.Accordion("Additional Options", open=False):
+                    msg_input = gr.Textbox(
+                        placeholder="Enter specific focus areas or questions...",
+                        label="Analysis Focus",
+                        lines=3
+                    )
+                    temperature = gr.Slider(
+                        minimum=0.1,
+                        maximum=1.0,
+                        value=0.3,
+                        step=0.1,
+                        label="Analysis Strictness"
+                    )
+
+                status = gr.Textbox(
+                    label="Processing Status",
+                    interactive=False,
+                    visible=True
+                )
+
+            with gr.Column(scale=2, min_width=600):
+                with gr.Tabs():
+                    with gr.TabItem("Analysis Report", id="report"):
+                        report_output = gr.Textbox(
+                            label="Clinical Oversight Findings",
+                            lines=25,
+                            max_lines=50,
+                            interactive=False,
+                            elem_classes="report-box"
+                        )
+
+                    with gr.TabItem("Raw Data Preview", id="preview"):
+                        with gr.Column(elem_classes="dataframe-container"):
+                            data_preview = gr.Dataframe(
+                                headers=["Page", "Content"],
+                                datatype=["str", "str"],
+                                interactive=False
+                            )
+
+                with gr.Row():
+                    download_output = gr.File(
+                        label="Download Full Report",
+                        visible=True,
+                        interactive=False
+                    )
+                    gr.Button("Save to EHR", visible=False)
+
+        def analyze(files: List, message: str, temp: float):
             if not files:
-                yield "", None, "⚠️ Please upload files.", None
-                return
-            yield "", None, "⏳ Processing...", None
-            previews = []
-            contents = []
-            for f in files:
-                res = json.loads(sanitize_utf8(convert_file_to_json(f.name, os.path.splitext(f.name)[1][1:].lower())))
-                if "content" in res:
-                    previews.append([res["filename"], res["content"][:200] + "..."])
-                    contents.append(res["content"])
-            yield "", None, f"🔍 Analyzing {len(contents)} docs...", previews
-            combined = "\n".join(contents)
-            report = analyze_complete_document(combined, "+".join([os.path.basename(f.name) for f in files]), agent, temp)
-            file_hash_val = hashlib.md5(combined.encode()).hexdigest()
-            path = os.path.join(report_dir, f"{file_hash_val}_report.txt")
-            with open(path, "w", encoding="utf-8") as rd:
-                rd.write(report)
-            yield report, path, "✅ Analysis complete!", previews
-        send_btn.click(analyze, [file_upload, msg_input, temperature], [report_output, download_output, status, data_preview])
-        clear_btn.click(lambda: (None, None, "", None), None, [report_output, download_output, status, data_preview])
+                return (
+                    {"value": "", "visible": True},
+                    None,
+                    {"value": "⚠️ Please upload at least one file to analyze.", "visible": True},
+                    {"value": None, "visible": True}
+                )
+
+            yield (
+                {"value": "", "visible": True},
+                None,
+                {"value": " Processing documents...", "visible": True},
+                {"value": None, "visible": True}
+            )
+
+            file_contents = []
+            filenames = []
+            preview_data = []
+
+            with ThreadPoolExecutor(max_workers=4) as executor:
+                futures = []
+                for f in files:
+                    file_path = f.name
+                    futures.append(executor.submit(
+                        convert_file_to_json,
+                        file_path,
+                        os.path.splitext(file_path)[1][1:].lower()
+                    ))
+                    filenames.append(os.path.basename(file_path))
+
+                results = []
+                for future in as_completed(futures):
+                    result = sanitize_utf8(future.result())
+                    try:
+                        data = json.loads(result)
+                        results.append(data)
+                        if "content" in data:
+                            preview_data.append([data["filename"], data["content"][:500] + "..."])
+                    except Exception as e:
+                        print(f"Error processing result: {e}")
+                        continue
+
+            yield (
+                {"value": "", "visible": True},
+                None,
+                {"value": f"🔍 Analyzing {len(files)} documents...", "visible": True},
+                {"value": preview_data[:20], "visible": True}
+            )
+
+            try:
+                combined_content = "\n".join([
+                    item.get("content", "") if isinstance(item, dict) and "content" in item
+                    else str(item.get("rows", "")) if isinstance(item, dict)
+                    else str(item)
+                    for item in results
+                ])
+
+                full_report = analyze_complete_document(
+                    combined_content,
+                    " + ".join(filenames),
+                    agent,
+                    temperature=temp
+                )
+
+                file_hash_value = hashlib.md5(combined_content.encode()).hexdigest()
+                report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt")
+                with open(report_path, "w", encoding="utf-8") as f:
+                    f.write(full_report)
+
+                yield (
+                    {"value": full_report, "visible": True},
+                    report_path if os.path.exists(report_path) else None,
+                    {"value": "✅ Analysis complete!", "visible": True},
+                    {"value": preview_data[:20], "visible": True}
+                )
+
+            except Exception as e:
+                error_msg = f"❌ Error during analysis: {str(e)}"
+                print(error_msg)
+                yield (
+                    {"value": "", "visible": True},
+                    None,
+                    {"value": error_msg, "visible": True},
+                    {"value": None, "visible": True}
+                )
+
+        send_btn.click(
+            fn=analyze,
+            inputs=[file_upload, msg_input, temperature],
+            outputs=[report_output, download_output, status, data_preview],
+            api_name="analyze"
+        )
+
+        clear_btn.click(
+            fn=lambda: (
+                None,
+                None,
+                "",
+                None,
+                {"value": 0.3},
+                {"value": ""}
+            ),
+            inputs=None,
+            outputs=[file_upload, download_output, status, data_preview, temperature, msg_input]
+        )
+
     return demo
 
 if __name__ == "__main__":
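
Note: `analyze_complete_document` now applies the same encode/slice/decode truncation twice, once per chunk and once on the assembled prompt. Pulled out into a helper, the idea looks like this (a sketch, not part of the commit; `truncate_to_tokens` is a hypothetical name):

```python
# Sketch only: token-level truncation as used in the new analysis loop.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

def truncate_to_tokens(text: str, limit: int) -> str:
    tokens = enc.encode(text)
    if len(tokens) <= limit:
        return text
    return enc.decode(tokens[:limit])  # hard cut; may end mid-sentence, last-resort safety

prompt = "Content:\n" + "glucose 7.2 mmol/L, BP 142/90. " * 300
safe = truncate_to_tokens(prompt, 2048 - 200)  # mirrors MAX_MODEL_LEN - 200 in the commit
print(len(enc.encode(safe)), "tokens after truncation")
```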
@@ -335,13 +571,18 @@ if __name__ == "__main__":
     try:
         import tiktoken
     except ImportError:
+        print("Installing tiktoken...")
         subprocess.run([sys.executable, "-m", "pip", "install", "tiktoken"])
+
     agent = init_agent()
     demo = create_ui(agent)
-    demo.queue(api_open=False, max_size=20).launch(
+    demo.queue(
+        api_open=False,
+        max_size=20
+    ).launch(
         server_name="0.0.0.0",
         server_port=7860,
         show_error=True,
-        share=False,
-        allowed_paths=[report_dir]
-    )
+        allowed_paths=[report_dir],
+        share=False
+    )
 