Ali2206 committed
Commit cc93544 · verified · 1 Parent(s): cf765da

Update app.py

Files changed (1)
  1. app.py +137 -370
app.py CHANGED
@@ -27,12 +27,14 @@ vllm_cache_dir = os.path.join(persistent_dir, "vllm_cache")
 for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir, vllm_cache_dir]:
     os.makedirs(directory, exist_ok=True)
 
+# Environment variables
 os.environ["HF_HOME"] = model_cache_dir
 os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 os.environ["VLLM_CACHE_DIR"] = vllm_cache_dir
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
 
+# Add src to path
 current_dir = os.path.dirname(os.path.abspath(__file__))
 src_path = os.path.abspath(os.path.join(current_dir, "src"))
 sys.path.insert(0, src_path)
@@ -46,11 +48,14 @@ MEDICAL_KEYWORDS = {
     'conclusion', 'history', 'examination', 'progress', 'discharge'
 }
 TOKENIZER = "cl100k_base"
-MAX_MODEL_LEN = 2048
+# Increase max model length to support larger contexts
+MAX_MODEL_LEN = 4096
+# Default chunk target tokens
 TARGET_CHUNK_TOKENS = 1200
-PROMPT_RESERVE = 300
+PROMPT_RESERVE = 100
 MEDICAL_SECTION_HEADER = "=== MEDICAL SECTION ==="
 
+
 def log_system_usage(tag=""):
     try:
         cpu = psutil.cpu_percent(interval=1)
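Note: the two constants changed here set the per-chunk token budget: each chunk gets the model window minus the prompt and a small reserve for generation. A quick sketch of the arithmetic with this commit's values (the printed number depends on what cl100k_base reports for the prompt):

    # Sketch of the chunk budget implied by MAX_MODEL_LEN and PROMPT_RESERVE.
    import tiktoken

    MAX_MODEL_LEN = 4096
    PROMPT_RESERVE = 100
    enc = tiktoken.get_encoding("cl100k_base")
    base_prompt = "Analyze for:\n1. Critical\n2. Missed DX\n3. Med issues\n4. Gaps\n5. Follow-up\n\nContent:\n"
    prompt_toks = len(enc.encode(base_prompt))
    print(MAX_MODEL_LEN - prompt_toks - PROMPT_RESERVE)  # tokens of document text per chunk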
@@ -66,6 +71,7 @@ def log_system_usage(tag=""):
     except Exception as e:
         print(f"[{tag}] GPU/CPU monitor failed: {e}")
 
+
 def sanitize_utf8(text: str) -> str:
     return text.encode("utf-8", "ignore").decode("utf-8")
 
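Note: sanitize_utf8 (unchanged context above) drops anything that cannot survive a UTF-8 round trip, such as the lone surrogates that occasionally come out of damaged PDF text. A minimal check:

    # errors="ignore" silently skips unencodable code points instead of raising.
    def sanitize_utf8(text: str) -> str:
        return text.encode("utf-8", "ignore").decode("utf-8")

    print(sanitize_utf8("ok \ud800 bad"))  # -> "ok  bad" (lone surrogate removed)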
@@ -77,41 +83,33 @@ def count_tokens(text: str) -> int:
     encoding = tiktoken.get_encoding(TOKENIZER)
     return len(encoding.encode(text))
 
+
 def extract_all_pages_with_token_count(file_path: str) -> Tuple[str, int, int]:
     try:
         text_chunks = []
         total_pages = 0
         total_tokens = 0
-
         with pdfplumber.open(file_path) as pdf:
             total_pages = len(pdf.pages)
-
             for i, page in enumerate(pdf.pages):
                 page_text = page.extract_text() or ""
                 lower_text = page_text.lower()
-
-                if any(re.search(rf'\b{kw}\b', lower_text) for kw in MEDICAL_KEYWORDS):
-                    section_header = f"\n{MEDICAL_SECTION_HEADER} (Page {i+1})\n"
-                    text_chunks.append(section_header + page_text.strip())
-                    total_tokens += count_tokens(section_header)
-                else:
-                    text_chunks.append(f"\n=== Page {i+1} ===\n{page_text.strip()}")
-
-                total_tokens += count_tokens(page_text)
-
+                header = f"\n{MEDICAL_SECTION_HEADER} (Page {i+1})\n" if any(
+                    re.search(rf'\b{kw}\b', lower_text) for kw in MEDICAL_KEYWORDS
+                ) else f"\n=== Page {i+1} ===\n"
+                text_chunks.append(header + page_text.strip())
+                total_tokens += count_tokens(header) + count_tokens(page_text)
         return "\n".join(text_chunks), total_pages, total_tokens
     except Exception as e:
         return f"PDF processing error: {str(e)}", 0, 0
 
+
 def convert_file_to_json(file_path: str, file_type: str) -> str:
     try:
         h = file_hash(file_path)
         cache_path = os.path.join(file_cache_dir, f"{h}.json")
-
         if os.path.exists(cache_path):
-            with open(cache_path, "r", encoding="utf-8") as f:
-                return f.read()
-
+            return open(cache_path, "r", encoding="utf-8").read()
         if file_type == "pdf":
             text, total_pages, total_tokens = extract_all_pages_with_token_count(file_path)
             result = json.dumps({
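Note: the consolidated loop above tags a page as a medical section when any keyword matches on a word boundary. A self-contained sketch with three sample keywords (the full set is MEDICAL_KEYWORDS at the top of app.py):

    import re

    KEYWORDS = {'diagnosis', 'medication', 'discharge'}  # illustrative subset
    page_text = "Discharge summary: diagnosis of hypertension."
    flagged = any(re.search(rf'\b{kw}\b', page_text.lower()) for kw in KEYWORDS)
    print(flagged)  # True -> page gets the "=== MEDICAL SECTION ===" header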
@@ -123,10 +121,12 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
             })
         elif file_type == "csv":
             chunks = []
-            for chunk in pd.read_csv(file_path, encoding_errors="replace", header=None, dtype=str,
-                                     skip_blank_lines=False, on_bad_lines="skip", chunksize=1000):
+            for chunk in pd.read_csv(
+                file_path, encoding_errors="replace", header=None, dtype=str,
+                skip_blank_lines=False, on_bad_lines="skip", chunksize=1000
+            ):
                 chunks.append(chunk.fillna("").astype(str).values.tolist())
-            content = [item for sublist in chunks for item in sublist]
+            content = [item for sub in chunks for item in sub]
             result = json.dumps({
                 "filename": os.path.basename(file_path),
                 "rows": content,
@@ -135,7 +135,7 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
         elif file_type in ["xls", "xlsx"]:
             try:
                 df = pd.read_excel(file_path, engine="openpyxl", header=None, dtype=str)
-            except Exception:
+            except:
                 df = pd.read_excel(file_path, engine="xlrd", header=None, dtype=str)
             content = df.fillna("").astype(str).values.tolist()
             result = json.dumps({
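Note: the engine fallback exists because openpyxl reads only .xlsx, while legacy .xls needs xlrd (and xlrd 2.0+ reads nothing else). The bare except: introduced here also swallows KeyboardInterrupt; except Exception:, as before, is the safer spelling. Equivalent sketch:

    import pandas as pd

    def read_any_excel(path: str) -> pd.DataFrame:
        try:
            return pd.read_excel(path, engine="openpyxl", header=None, dtype=str)  # .xlsx
        except Exception:
            return pd.read_excel(path, engine="xlrd", header=None, dtype=str)  # legacy .xls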
@@ -145,109 +145,91 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
             })
         else:
             result = json.dumps({"error": f"Unsupported file type: {file_type}"})
-
         with open(cache_path, "w", encoding="utf-8") as f:
             f.write(result)
         return result
     except Exception as e:
         return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
 
+
 def clean_response(text: str) -> str:
     text = sanitize_utf8(text)
-    text = re.sub(r"\[TOOL_CALLS\].*", "", text, flags=re.DOTALL)
-    text = re.sub(r"\['get_[^\]]+\']\n?", "", text)
-    text = re.sub(r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?", "", text, flags=re.DOTALL)
-    text = re.sub(r"To analyze the medical records for clinical oversights.*?begin by reviewing.*?\n", "", text, flags=re.DOTALL)
-    text = re.sub(r"\n{3,}", "\n\n", text).strip()
-    return text
+    patterns = [
+        r"\[TOOL_CALLS\].*", r"\['get_[^\]]+\']\n?", r"\{'meta':\s*\{.*?\}\s*,\s*'results':\s*\[.*?\]\}\n?",
+        r"To analyze the medical records for clinical oversights.*?\n"  # remove generic prompt
+    ]
+    for pat in patterns:
+        text = re.sub(pat, "", text, flags=re.DOTALL)
+    return re.sub(r"\n{3,}", "\n\n", text).strip()
+
 
 def format_final_report(analysis_results: List[str], filename: str) -> str:
-    report = []
-    report.append(f"COMPREHENSIVE CLINICAL OVERSIGHT ANALYSIS")
-    report.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
-    report.append(f"File: {filename}")
-    report.append("=" * 80)
-
-    sections = {
-        "CRITICAL FINDINGS": [],
-        "MISSED DIAGNOSES": [],
-        "MEDICATION ISSUES": [],
-        "ASSESSMENT GAPS": [],
-        "FOLLOW-UP RECOMMENDATIONS": []
-    }
-
-    for result in analysis_results:
-        for section in sections:
-            section_match = re.search(
-                rf"{re.escape(section)}:?\s*\n([^*]+?)(?=\n\*|\n\n|$)",
-                result,
-                re.IGNORECASE | re.DOTALL
+    report = [
+        "COMPREHENSIVE CLINICAL OVERSIGHT ANALYSIS",
+        f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
+        f"File: {filename}",
+        "=" * 80
+    ]
+    sections = {s: [] for s in [
+        "CRITICAL FINDINGS", "MISSED DIAGNOSES", "MEDICATION ISSUES",
+        "ASSESSMENT GAPS", "FOLLOW-UP RECOMMENDATIONS"
+    ]}
+    for res in analysis_results:
+        for sec in sections:
+            m = re.search(
+                rf"{re.escape(sec)}:?\s*\n(.+?)(?=\n\*|\n\n|$)",
+                res, re.IGNORECASE | re.DOTALL
             )
-            if section_match:
-                content = section_match.group(1).strip()
-                if content and content not in sections[section]:
-                    sections[section].append(content)
-
+            if m:
+                content = m.group(1).strip()
+                if content and content not in sections[sec]:
+                    sections[sec].append(content)
     if sections["CRITICAL FINDINGS"]:
         report.append("\n🚨 **CRITICAL FINDINGS** 🚨")
-        for content in sections["CRITICAL FINDINGS"]:
-            report.append(f"\n{content}")
-
-    for section, contents in sections.items():
-        if section != "CRITICAL FINDINGS" and contents:
-            report.append(f"\n**{section.upper()}**")
-            for content in contents:
-                report.append(f"\n{content}")
-
+        report.extend(f"\n{c}" for c in sections["CRITICAL FINDINGS"])
+    for sec, conts in sections.items():
+        if sec != "CRITICAL FINDINGS" and conts:
+            report.append(f"\n**{sec}**")
+            report.extend(f"\n{c}" for c in conts)
     if not any(sections.values()):
         report.append("\nNo significant clinical oversights identified.")
-
-    report.append("\n" + "=" * 80)
+    report.append("\n" + "="*80)
     report.append("END OF REPORT")
-
     return "\n".join(report)
 
-def split_content_by_tokens(content: str, max_tokens: int = TARGET_CHUNK_TOKENS) -> List[str]:
+
+def split_content_by_tokens(content: str, max_tokens: int) -> List[str]:
     paragraphs = re.split(r"\n\s*\n", content)
-    chunks = []
-    current_chunk = []
-    current_tokens = 0
-
+    chunks, current, curr_toks = [], [], 0
     for para in paragraphs:
-        para_tokens = count_tokens(para)
-        if para_tokens > max_tokens:
-            sentences = re.split(r'(?<=[.!?])\s+', para)
-            for sent in sentences:
-                sent_tokens = count_tokens(sent)
-                if current_tokens + sent_tokens > max_tokens:
-                    chunks.append("\n\n".join(current_chunk))
-                    current_chunk = [sent]
-                    current_tokens = sent_tokens
+        toks = count_tokens(para)
+        if toks > max_tokens:
+            for sent in re.split(r'(?<=[.!?])\s+', para):
+                sent_toks = count_tokens(sent)
+                if curr_toks + sent_toks > max_tokens:
+                    chunks.append("\n\n".join(current))
+                    current, curr_toks = [sent], sent_toks
                 else:
-                    current_chunk.append(sent)
-                    current_tokens += sent_tokens
-        elif current_tokens + para_tokens > max_tokens:
-            chunks.append("\n\n".join(current_chunk))
-            current_chunk = [para]
-            current_tokens = para_tokens
+                    current.append(sent)
+                    curr_toks += sent_toks
+        elif curr_toks + toks > max_tokens:
+            chunks.append("\n\n".join(current))
+            current, curr_toks = [para], toks
         else:
-            current_chunk.append(para)
-            current_tokens += para_tokens
-
-    if current_chunk:
-        chunks.append("\n\n".join(current_chunk))
-
+            current.append(para)
+            curr_toks += toks
+    if current:
+        chunks.append("\n\n".join(current))
     return chunks
 
+
 def init_agent():
     print("🔁 Initializing model...")
     log_system_usage("Before Load")
-
     default_tool_path = os.path.abspath("data/new_tool.json")
     target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
     if not os.path.exists(target_tool_path):
         shutil.copy(default_tool_path, target_tool_path)
-
     agent = TxAgent(
         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
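Note: split_content_by_tokens now takes its budget from the caller instead of defaulting to TARGET_CHUNK_TOKENS. Two edge cases survive the rewrite: a sentence longer than max_tokens still ends up as one oversized chunk, and if the very first sentence examined is already over budget, "\n\n".join(current) runs on an empty list and emits an empty leading chunk. Usage sketch (run in app.py's namespace, where count_tokens is defined):

    text = ("History of present illness...\n\n"
            "Medications on admission...\n\n"
            "Discharge plan and follow-up...")
    chunks = split_content_by_tokens(text, max_tokens=20)
    print([count_tokens(c) for c in chunks])  # each chunk stays near the budget, per the caveats above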
@@ -256,293 +238,89 @@ def init_agent():
         enable_checker=True,
         step_rag_num=2,
         seed=100,
-        additional_default_tools=[],
+        additional_default_tools=[]
     )
-    agent.init_model()
+    agent.init_model(max_model_len=MAX_MODEL_LEN)
     log_system_usage("After Load")
     print("✅ Agent Ready")
     return agent
 
+
 def analyze_complete_document(content: str, filename: str, agent: TxAgent, temperature: float = 0.3) -> str:
-    chunks = split_content_by_tokens(content)
-    analysis_results = []
-
+    base_prompt = (
+        "Analyze for:\n1. Critical\n2. Missed DX\n3. Med issues\n4. Gaps\n5. Follow-up\n\nContent:\n"
+    )
+    prompt_toks = count_tokens(base_prompt)
+    max_chunk_toks = MAX_MODEL_LEN - prompt_toks - PROMPT_RESERVE
+    chunks = split_content_by_tokens(content, max_chunk_toks)
+    results = []
     for i, chunk in enumerate(chunks):
         try:
-            base_prompt = "Analyze for:\n1. Critical\n2. Missed DX\n3. Med issues\n4. Gaps\n5. Follow-up\n\nContent:\n"
-
-            prompt_tokens = count_tokens(base_prompt)
-            max_content_tokens = MAX_MODEL_LEN - prompt_tokens - 100
-
-            chunk_tokens = count_tokens(chunk)
-            if chunk_tokens > max_content_tokens:
-                adjusted_chunk = ""
-                tokens_used = 0
-                paragraphs = re.split(r"\n\s*\n", chunk)
-
-                for para in paragraphs:
-                    para_tokens = count_tokens(para)
-                    if tokens_used + para_tokens <= max_content_tokens:
-                        adjusted_chunk += "\n\n" + para
-                        tokens_used += para_tokens
-                    else:
-                        break
-
-                if not adjusted_chunk:
-                    sentences = re.split(r'(?<=[.!?])\s+', chunk)
-                    for sent in sentences:
-                        sent_tokens = count_tokens(sent)
-                        if tokens_used + sent_tokens <= max_content_tokens:
-                            adjusted_chunk += " " + sent
-                            tokens_used += sent_tokens
-                        else:
-                            break
-
-                chunk = adjusted_chunk.strip()
-
             prompt = base_prompt + chunk
-
             response = ""
-            for output in agent.run_gradio_chat(
+            for out in agent.run_gradio_chat(
                 message=prompt,
                 history=[],
                 temperature=temperature,
                 max_new_tokens=300,
                 max_token=MAX_MODEL_LEN,
                 call_agent=False,
-                conversation=[],
+                conversation=[]
             ):
-                if output:
-                    if isinstance(output, list):
-                        for m in output:
-                            if hasattr(m, 'content'):
-                                response += clean_response(m.content)
-                    elif isinstance(output, str):
-                        response += clean_response(output)
-
+                if out:
+                    if isinstance(out, list):
+                        for m in out:
+                            response += clean_response(m.content if hasattr(m, 'content') else str(m))
+                    else:
+                        response += clean_response(str(out))
             if response:
-                analysis_results.append(response)
+                results.append(response)
         except Exception as e:
-            print(f"Error processing chunk {i}: {str(e)}")
-            continue
-
-    return format_final_report(analysis_results, filename)
+            print(f"Error processing chunk {i}: {e}")
+    return format_final_report(results, filename)
+
 
 def create_ui(agent):
-    with gr.Blocks(
-        theme=gr.themes.Soft(
-            primary_hue="indigo",
-            secondary_hue="blue",
-            neutral_hue="slate",
-            spacing_size="md",
-            radius_size="md"
-        ),
-        title="Clinical Oversight Assistant",
-        css="""
-        .report-box {
-            border: 1px solid #e0e0e0;
-            border-radius: 8px;
-            padding: 16px;
-            background-color: #f9f9f9;
-        }
-        .file-upload {
-            background-color: #f5f7fa;
-            padding: 16px;
-            border-radius: 8px;
-        }
-        .analysis-btn {
-            width: 100%;
-        }
-        .critical-finding {
-            color: #d32f2f;
-            font-weight: bold;
-        }
-        .dataframe-container {
-            height: 600px;
-            overflow-y: auto;
-        }
-        """
-    ) as demo:
+    with gr.Blocks(title="Clinical Oversight Assistant") as demo:
         gr.Markdown("""
-        <div style='text-align: center; margin-bottom: 20px;'>
-            <h1 style='color: #2b3a67; margin-bottom: 8px;'>🩺 Clinical Oversight Assistant</h1>
-            <p style='color: #5a6a8a; font-size: 16px;'>
-                Analyze medical records for potential oversights and generate comprehensive reports
-            </p>
-        </div>
+        # 🩺 Clinical Oversight Assistant
+        Analyze medical records for potential oversights and generate comprehensive reports
         """)
-
-        with gr.Row(equal_height=False):
-            with gr.Column(scale=1, min_width=400):
-                with gr.Group(elem_classes="file-upload"):
-                    file_upload = gr.File(
-                        file_types=[".pdf", ".csv", ".xls", ".xlsx"],
-                        file_count="multiple",
-                        label="Upload Medical Records",
-                        elem_id="file-upload"
-                    )
-                    with gr.Row():
-                        clear_btn = gr.Button("Clear All", size="sm")
-                        send_btn = gr.Button(
-                            "Analyze Documents",
-                            variant="primary",
-                            elem_classes="analysis-btn"
-                        )
-
-                with gr.Accordion("Additional Options", open=False):
-                    msg_input = gr.Textbox(
-                        placeholder="Enter specific focus areas or questions...",
-                        label="Analysis Focus",
-                        lines=3
-                    )
-                    temperature = gr.Slider(
-                        minimum=0.1,
-                        maximum=1.0,
-                        value=0.3,
-                        step=0.1,
-                        label="Analysis Strictness"
-                    )
-
-                status = gr.Textbox(
-                    label="Processing Status",
-                    interactive=False,
-                    visible=True
-                )
-
-            with gr.Column(scale=2, min_width=600):
-                with gr.Tabs():
-                    with gr.TabItem("Analysis Report", id="report"):
-                        report_output = gr.Textbox(
-                            label="Clinical Oversight Findings",
-                            lines=25,
-                            max_lines=50,
-                            interactive=False,
-                            elem_classes="report-box"
-                        )
-
-                    with gr.TabItem("Raw Data Preview", id="preview"):
-                        with gr.Column(elem_classes="dataframe-container"):
-                            data_preview = gr.Dataframe(
-                                headers=["Page", "Content"],
-                                datatype=["str", "str"],
-                                interactive=False
-                            )
-
-        with gr.Row():
-            download_output = gr.File(
-                label="Download Full Report",
-                visible=True,
-                interactive=False
-            )
-            gr.Button("Save to EHR", visible=False)
-
-        def analyze(files: List, message: str, temp: float):
+        with gr.Row():
+            with gr.Column():
+                file_upload = gr.File(label="Upload Medical Records", file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
+                msg_input = gr.Textbox(label="Analysis Focus (optional)")
+                temperature = gr.Slider(0.1, 1.0, value=0.3, label="Analysis Strictness")
+                send_btn = gr.Button("Analyze Documents", variant="primary")
+                clear_btn = gr.Button("Clear All")
+                status = gr.Textbox(label="Status", interactive=False)
+            with gr.Column():
+                report_output = gr.Textbox(label="Report", lines=20, interactive=False)
+                data_preview = gr.Dataframe(headers=["File", "Snippet"], interactive=False)
+                download_output = gr.File(label="Download Report")
+        def analyze(files, msg, temp):
             if not files:
-                return (
-                    {"value": "", "visible": True},
-                    None,
-                    {"value": "⚠️ Please upload at least one file to analyze.", "visible": True},
-                    {"value": None, "visible": True}
-                )
-
-            yield (
-                {"value": "", "visible": True},
-                None,
-                {"value": "⏳ Processing documents...", "visible": True},
-                {"value": None, "visible": True}
-            )
-
-            file_contents = []
-            filenames = []
-            preview_data = []
-
-            with ThreadPoolExecutor(max_workers=4) as executor:
-                futures = []
-                for f in files:
-                    file_path = f.name
-                    futures.append(executor.submit(
-                        convert_file_to_json,
-                        file_path,
-                        os.path.splitext(file_path)[1][1:].lower()
-                    ))
-                    filenames.append(os.path.basename(file_path))
-
-                results = []
-                for future in as_completed(futures):
-                    result = sanitize_utf8(future.result())
-                    try:
-                        data = json.loads(result)
-                        results.append(data)
-                        if "content" in data:
-                            preview_data.append([data["filename"], data["content"][:500] + "..."])
-                    except Exception as e:
-                        print(f"Error processing result: {e}")
-                        continue
-
-            yield (
-                {"value": "", "visible": True},
-                None,
-                {"value": f"🔍 Analyzing {len(files)} documents...", "visible": True},
-                {"value": preview_data[:20], "visible": True}
-            )
-
-            try:
-                combined_content = "\n".join([
-                    item.get("content", "") if isinstance(item, dict) and "content" in item
-                    else str(item.get("rows", "")) if isinstance(item, dict)
-                    else str(item)
-                    for item in results
-                ])
-
-                full_report = analyze_complete_document(
-                    combined_content,
-                    " + ".join(filenames),
-                    agent,
-                    temperature=temp
-                )
-
-                file_hash_value = hashlib.md5(combined_content.encode()).hexdigest()
-                report_path = os.path.join(report_dir, f"{file_hash_value}_report.txt")
-                with open(report_path, "w", encoding="utf-8") as f:
-                    f.write(full_report)
-
-                yield (
-                    {"value": full_report, "visible": True},
-                    report_path if os.path.exists(report_path) else None,
-                    {"value": "✅ Analysis complete!", "visible": True},
-                    {"value": preview_data[:20], "visible": True}
-                )
-
-            except Exception as e:
-                error_msg = f"❌ Error during analysis: {str(e)}"
-                print(error_msg)
-                yield (
-                    {"value": "", "visible": True},
-                    None,
-                    {"value": error_msg, "visible": True},
-                    {"value": None, "visible": True}
-                )
-
-        send_btn.click(
-            fn=analyze,
-            inputs=[file_upload, msg_input, temperature],
-            outputs=[report_output, download_output, status, data_preview],
-            api_name="analyze"
-        )
-
-        clear_btn.click(
-            fn=lambda: (
-                None,
-                None,
-                "",
-                None,
-                {"value": 0.3},
-                {"value": ""}
-            ),
-            inputs=None,
-            outputs=[file_upload, download_output, status, data_preview, temperature, msg_input]
-        )
-
+                yield "", None, "⚠️ Please upload files.", None
+                return
+            yield "", None, "⏳ Processing...", None
+            # convert files
+            previews = []
+            contents = []
+            for f in files:
+                res = json.loads(sanitize_utf8(convert_file_to_json(f.name, os.path.splitext(f.name)[1][1:].lower())))
+                if "content" in res:
+                    previews.append([res["filename"], res["content"][:200] + "..."])
+                    contents.append(res["content"])
+            yield "", None, f"🔍 Analyzing {len(contents)} docs...", previews
+            combined = "\n".join(contents)
+            report = analyze_complete_document(combined, "+".join([os.path.basename(f.name) for f in files]), agent, temp)
+            file_hash_val = hashlib.md5(combined.encode()).hexdigest()
+            path = os.path.join(report_dir, f"{file_hash_val}_report.txt")
+            with open(path, "w") as rd:
+                rd.write(report)
+            yield report, path, "✅ Analysis complete!", previews
+        send_btn.click(analyze, [file_upload, msg_input, temperature], [report_output, download_output, status, data_preview])
+        clear_btn.click(lambda: (None, None, "", None), None, [report_output, download_output, status, data_preview])
     return demo
 
 if __name__ == "__main__":
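Note: analyze is now a generator, so every yield pushes an interim (report, download, status, preview) update to the four outputs; Gradio streams generator handlers when the app runs behind .queue(), which the launch block below keeps. A toy standalone example of the pattern (not the app's real handler):

    import time
    import gradio as gr

    def steps(n):
        for i in range(int(n)):
            yield f"step {i + 1} of {int(n)}"  # interim status update
            time.sleep(0.2)
        yield "done"  # final value shown in the textbox

    with gr.Blocks() as demo:
        count = gr.Number(value=3, label="steps")
        status = gr.Textbox(label="status")
        gr.Button("Run").click(steps, count, status)
    demo.queue()  # generators require the queue on older Gradio releases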
@@ -550,18 +328,7 @@ if __name__ == "__main__":
     try:
         import tiktoken
     except ImportError:
-        print("Installing tiktoken...")
-        subprocess.run([sys.executable, "-m", "pip", "install", "tiktoken"])
-
+        subprocess.run([sys.executable, "-m", "pip", "install", "tiktoken"])
     agent = init_agent()
     demo = create_ui(agent)
-    demo.queue(
-        api_open=False,
-        max_size=20
-    ).launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        show_error=True,
-        allowed_paths=[report_dir],
-        share=False
-    )
+    demo.queue(api_open=False, max_size=20).launch(server_name="0.0.0.0", server_port=7860, show_error=True, share=False)
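Note: the compacted launch call drops allowed_paths=[report_dir] from the old configuration. On Gradio versions that restrict file serving to known directories, reports written under the persistent report_dir may stop being downloadable; if that bites, the argument can be restored without losing the one-liner:

    # Sketch: same compact launch with the dropped argument restored.
    demo.queue(api_open=False, max_size=20).launch(
        server_name="0.0.0.0", server_port=7860, show_error=True,
        share=False, allowed_paths=[report_dir],
    )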