Ali2206 commited on
Commit
2416301
·
verified ·
1 Parent(s): b33bf6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -91
app.py CHANGED
@@ -46,9 +46,9 @@ MEDICAL_KEYWORDS = {
46
  'conclusion', 'history', 'examination', 'progress', 'discharge'
47
  }
48
  TOKENIZER = "cl100k_base"
49
- MAX_MODEL_LEN = 2048
50
- TARGET_CHUNK_TOKENS = 1200
51
- PROMPT_RESERVE = 300
52
  MEDICAL_SECTION_HEADER = "=== MEDICAL SECTION ==="
53
 
54
  def log_system_usage(tag=""):
@@ -251,20 +251,49 @@ def split_content_by_tokens(content: str, max_tokens: int = TARGET_CHUNK_TOKENS)
251
 
252
  return chunks
253
 
254
- def analyze_complete_document(content: str, filename: str, agent: TxAgent) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  """Analyze complete document with strict token management"""
256
  chunks = split_content_by_tokens(content)
257
  analysis_results = []
258
 
259
  for i, chunk in enumerate(chunks):
260
  try:
 
261
  base_prompt = "Analyze for:\n1. Critical\n2. Missed DX\n3. Med issues\n4. Gaps\n5. Follow-up\n\nContent:\n"
262
 
 
263
  prompt_tokens = count_tokens(base_prompt)
264
- max_content_tokens = MAX_MODEL_LEN - prompt_tokens - 100
265
 
 
266
  chunk_tokens = count_tokens(chunk)
267
  if chunk_tokens > max_content_tokens:
 
268
  adjusted_chunk = ""
269
  tokens_used = 0
270
  paragraphs = re.split(r"\n\s*\n", chunk)
@@ -278,6 +307,7 @@ def analyze_complete_document(content: str, filename: str, agent: TxAgent) -> st
278
  break
279
 
280
  if not adjusted_chunk:
 
281
  sentences = re.split(r'(?<=[.!?])\s+', chunk)
282
  for sent in sentences:
283
  sent_tokens = count_tokens(sent)
@@ -295,8 +325,8 @@ def analyze_complete_document(content: str, filename: str, agent: TxAgent) -> st
295
  for output in agent.run_gradio_chat(
296
  message=prompt,
297
  history=[],
298
- temperature=0.1,
299
- max_new_tokens=300,
300
  max_token=MAX_MODEL_LEN,
301
  call_agent=False,
302
  conversation=[],
@@ -317,78 +347,137 @@ def analyze_complete_document(content: str, filename: str, agent: TxAgent) -> st
317
 
318
  return format_final_report(analysis_results, filename)
319
 
320
- def init_agent():
321
- """Initialize the TxAgent with proper configuration."""
322
- print("🔁 Initializing model...")
323
- log_system_usage("Before Load")
324
-
325
- default_tool_path = os.path.abspath("data/new_tool.json")
326
- target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
327
- if not os.path.exists(target_tool_path):
328
- shutil.copy(default_tool_path, target_tool_path)
329
-
330
- agent = TxAgent(
331
- model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
332
- rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
333
- tool_files_dict={"new_tool": target_tool_path},
334
- force_finish=True,
335
- enable_checker=True,
336
- step_rag_num=2,
337
- seed=100,
338
- additional_default_tools=[],
339
- )
340
- agent.init_model()
341
- log_system_usage("After Load")
342
- print("✅ Agent Ready")
343
- return agent
344
-
345
  def create_ui(agent):
346
- """Create the Gradio interface."""
347
- with gr.Blocks(theme=gr.themes.Soft(), title="Clinical Oversight Assistant") as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  gr.Markdown("""
349
- <h1 style='text-align: center;'>🩺 Comprehensive Clinical Oversight Assistant</h1>
350
- <p style='text-align: center;'>Analyze complete medical records for potential oversights</p>
 
 
 
 
351
  """)
352
-
353
- with gr.Row():
354
- with gr.Column(scale=3):
355
- file_upload = gr.File(
356
- file_types=[".pdf", ".csv", ".xls", ".xlsx"],
357
- file_count="multiple",
358
- label="Upload Medical Records"
359
- )
360
- msg_input = gr.Textbox(
361
- placeholder="Optional: Add specific focus areas or questions...",
362
- label="Analysis Focus"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  with gr.Row():
365
- send_btn = gr.Button("Analyze Complete Documents", variant="primary")
366
- clear_btn = gr.Button("Clear")
367
- status = gr.Textbox(label="Status", interactive=False)
368
-
369
- with gr.Column(scale=7):
370
- report_output = gr.Textbox(
371
- label="Clinical Oversight Report",
372
- lines=20,
373
- max_lines=50,
374
- interactive=False
375
- )
376
- download_output = gr.File(
377
- label="Download Full Report",
378
- visible=False
379
- )
380
-
381
- def analyze(files: List, message: str):
382
- """Process files and generate analysis."""
383
  if not files:
384
- yield "", None, "⚠️ Please upload at least one file to analyze."
 
 
 
 
 
385
  return
386
 
387
- yield "", None, "⏳ Processing documents (this may take several minutes for large files)..."
 
 
 
 
 
 
388
 
 
389
  file_contents = []
390
  filenames = []
391
- total_tokens = 0
392
 
393
  with ThreadPoolExecutor(max_workers=4) as executor:
394
  futures = []
@@ -403,30 +492,34 @@ def create_ui(agent):
403
  results = []
404
  for future in as_completed(futures):
405
  result = sanitize_utf8(future.result())
406
- results.append(result)
407
  try:
408
  data = json.loads(result)
409
- if "total_tokens" in data:
410
- total_tokens += data["total_tokens"]
 
411
  except:
412
  pass
413
-
414
- file_contents = results
415
 
416
- combined_filename = " + ".join(filenames)
417
- combined_content = "\n".join([
418
- json.loads(fc).get("content", "") if "content" in json.loads(fc)
419
- else str(json.loads(fc).get("rows", ""))
420
- for fc in file_contents
421
- ])
422
-
423
- yield "", None, f"🔍 Analyzing content ({total_tokens//1000}k tokens)..."
424
 
425
  try:
 
 
 
 
 
 
426
  full_report = analyze_complete_document(
427
  combined_content,
428
- combined_filename,
429
- agent
 
430
  )
431
 
432
  file_hash_value = hashlib.md5(combined_content.encode()).hexdigest()
@@ -434,30 +527,46 @@ def create_ui(agent):
434
  with open(report_path, "w", encoding="utf-8") as f:
435
  f.write(full_report)
436
 
437
- yield full_report, report_path if os.path.exists(report_path) else None, "✅ Analysis complete!"
 
 
 
 
 
438
 
439
  except Exception as e:
440
  error_msg = f"❌ Error during analysis: {str(e)}"
441
  print(error_msg)
442
- yield "", None, error_msg
443
-
 
 
 
 
 
 
444
  send_btn.click(
445
  fn=analyze,
446
- inputs=[file_upload, msg_input],
447
- outputs=[report_output, download_output, status],
448
  api_name="analyze"
449
  )
450
 
451
  clear_btn.click(
452
- fn=lambda: ("", None, ""),
 
 
 
 
453
  inputs=None,
454
- outputs=[report_output, download_output, status]
455
  )
456
-
457
  return demo
458
 
459
  if __name__ == "__main__":
460
  print("🚀 Launching app...")
 
461
  try:
462
  import tiktoken
463
  except ImportError:
 
46
  'conclusion', 'history', 'examination', 'progress', 'discharge'
47
  }
48
  TOKENIZER = "cl100k_base"
49
+ MAX_MODEL_LEN = 2048 # Matches your model's actual limit
50
+ TARGET_CHUNK_TOKENS = 1200 # Leaves room for prompt and response
51
+ PROMPT_RESERVE = 300 # Tokens reserved for prompt structure
52
  MEDICAL_SECTION_HEADER = "=== MEDICAL SECTION ==="
53
 
54
  def log_system_usage(tag=""):
 
251
 
252
  return chunks
253
 
254
+ def init_agent():
255
+ """Initialize the TxAgent with proper configuration."""
256
+ print("🔁 Initializing model...")
257
+ log_system_usage("Before Load")
258
+
259
+ default_tool_path = os.path.abspath("data/new_tool.json")
260
+ target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
261
+ if not os.path.exists(target_tool_path):
262
+ shutil.copy(default_tool_path, target_tool_path)
263
+
264
+ agent = TxAgent(
265
+ model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
266
+ rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
267
+ tool_files_dict={"new_tool": target_tool_path},
268
+ force_finish=True,
269
+ enable_checker=True,
270
+ step_rag_num=2,
271
+ seed=100,
272
+ additional_default_tools=[],
273
+ )
274
+ agent.init_model()
275
+ log_system_usage("After Load")
276
+ print("✅ Agent Ready")
277
+ return agent
278
+
279
+ def analyze_complete_document(content: str, filename: str, agent: TxAgent, temperature: float = 0.3) -> str:
280
  """Analyze complete document with strict token management"""
281
  chunks = split_content_by_tokens(content)
282
  analysis_results = []
283
 
284
  for i, chunk in enumerate(chunks):
285
  try:
286
+ # Ultra-minimal prompt to maximize content space
287
  base_prompt = "Analyze for:\n1. Critical\n2. Missed DX\n3. Med issues\n4. Gaps\n5. Follow-up\n\nContent:\n"
288
 
289
+ # Calculate available space for content
290
  prompt_tokens = count_tokens(base_prompt)
291
+ max_content_tokens = MAX_MODEL_LEN - prompt_tokens - 100 # Response buffer
292
 
293
+ # Ensure chunk fits
294
  chunk_tokens = count_tokens(chunk)
295
  if chunk_tokens > max_content_tokens:
296
+ # Find last paragraph that fits
297
  adjusted_chunk = ""
298
  tokens_used = 0
299
  paragraphs = re.split(r"\n\s*\n", chunk)
 
307
  break
308
 
309
  if not adjusted_chunk:
310
+ # If even one paragraph is too big, split sentences
311
  sentences = re.split(r'(?<=[.!?])\s+', chunk)
312
  for sent in sentences:
313
  sent_tokens = count_tokens(sent)
 
325
  for output in agent.run_gradio_chat(
326
  message=prompt,
327
  history=[],
328
+ temperature=temperature,
329
+ max_new_tokens=300, # Keep responses very concise
330
  max_token=MAX_MODEL_LEN,
331
  call_agent=False,
332
  conversation=[],
 
347
 
348
  return format_final_report(analysis_results, filename)
349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  def create_ui(agent):
351
+ """Create the Gradio interface with enhanced design."""
352
+ with gr.Blocks(
353
+ theme=gr.themes.Soft(
354
+ primary_hue="indigo",
355
+ secondary_hue="blue",
356
+ neutral_hue="slate",
357
+ spacing_size="md",
358
+ radius_size="md"
359
+ ),
360
+ title="Clinical Oversight Assistant",
361
+ css="""
362
+ .report-box {
363
+ border: 1px solid #e0e0e0;
364
+ border-radius: 8px;
365
+ padding: 16px;
366
+ background-color: #f9f9f9;
367
+ }
368
+ .file-upload {
369
+ background-color: #f5f7fa;
370
+ padding: 16px;
371
+ border-radius: 8px;
372
+ }
373
+ .analysis-btn {
374
+ width: 100%;
375
+ }
376
+ .critical-finding {
377
+ color: #d32f2f;
378
+ font-weight: bold;
379
+ }
380
+ """
381
+ ) as demo:
382
+ # Header Section
383
  gr.Markdown("""
384
+ <div style='text-align: center; margin-bottom: 20px;'>
385
+ <h1 style='color: #2b3a67; margin-bottom: 8px;'>🩺 Clinical Oversight Assistant</h1>
386
+ <p style='color: #5a6a8a; font-size: 16px;'>
387
+ Analyze medical records for potential oversights and generate comprehensive reports
388
+ </p>
389
+ </div>
390
  """)
391
+
392
+ with gr.Row(equal_height=False):
393
+ # Left Column - Inputs
394
+ with gr.Column(scale=1, min_width=400):
395
+ with gr.Group(label="Document Upload", elem_classes="file-upload"):
396
+ file_upload = gr.File(
397
+ file_types=[".pdf", ".csv", ".xls", ".xlsx"],
398
+ file_count="multiple",
399
+ label="Upload Medical Records",
400
+ elem_id="file-upload"
401
+ )
402
+ with gr.Row():
403
+ clear_btn = gr.Button("Clear All", size="sm")
404
+ send_btn = gr.Button(
405
+ "Analyze Documents",
406
+ variant="primary",
407
+ elem_classes="analysis-btn"
408
+ )
409
+
410
+ with gr.Accordion("Additional Options", open=False):
411
+ msg_input = gr.Textbox(
412
+ placeholder="Enter specific focus areas or questions...",
413
+ label="Analysis Focus",
414
+ lines=3
415
+ )
416
+ temperature = gr.Slider(
417
+ minimum=0.1,
418
+ maximum=1.0,
419
+ value=0.3,
420
+ step=0.1,
421
+ label="Analysis Strictness"
422
+ )
423
+
424
+ status = gr.Textbox(
425
+ label="Processing Status",
426
+ interactive=False,
427
+ visible=True
428
  )
429
+
430
+ # Right Column - Outputs
431
+ with gr.Column(scale=2, min_width=600):
432
+ with gr.Tabs():
433
+ with gr.TabItem("Analysis Report", id="report"):
434
+ report_output = gr.Textbox(
435
+ label="Clinical Oversight Findings",
436
+ lines=25,
437
+ max_lines=50,
438
+ interactive=False,
439
+ elem_classes="report-box"
440
+ )
441
+
442
+ with gr.TabItem("Raw Data Preview", id="preview"):
443
+ data_preview = gr.Dataframe(
444
+ headers=["Page", "Content"],
445
+ datatype=["str", "str"],
446
+ interactive=False,
447
+ height=600
448
+ )
449
+
450
  with gr.Row():
451
+ download_output = gr.File(
452
+ label="Download Full Report",
453
+ visible=True,
454
+ interactive=False
455
+ )
456
+ gr.Button("Save to EHR", visible=False)
457
+
458
+ # Analysis function with UI updates
459
+ def analyze(files: List, message: str, temp: float):
 
 
 
 
 
 
 
 
 
460
  if not files:
461
+ yield (
462
+ gr.Textbox.update(value="", visible=True),
463
+ gr.File.update(value=None, visible=False),
464
+ gr.Textbox.update(value="⚠️ Please upload at least one file to analyze.", visible=True),
465
+ gr.Dataframe.update(value=None, visible=True)
466
+ )
467
  return
468
 
469
+ # Update UI for processing state
470
+ yield (
471
+ gr.Textbox.update(value="", visible=True),
472
+ gr.File.update(value=None, visible=False),
473
+ gr.Textbox.update(value="⏳ Processing documents...", visible=True),
474
+ gr.Dataframe.update(value=None, visible=True)
475
+ )
476
 
477
+ # Process files
478
  file_contents = []
479
  filenames = []
480
+ preview_data = []
481
 
482
  with ThreadPoolExecutor(max_workers=4) as executor:
483
  futures = []
 
492
  results = []
493
  for future in as_completed(futures):
494
  result = sanitize_utf8(future.result())
 
495
  try:
496
  data = json.loads(result)
497
+ results.append(result)
498
+ if "content" in data:
499
+ preview_data.append([data["filename"], data["content"][:500] + "..."])
500
  except:
501
  pass
 
 
502
 
503
+ # Update UI for analysis state
504
+ yield (
505
+ gr.Textbox.update(value="", visible=True),
506
+ gr.File.update(value=None, visible=False),
507
+ gr.Textbox.update(value=f"🔍 Analyzing {len(files)} documents...", visible=True),
508
+ gr.Dataframe.update(value=preview_data[:20], visible=True)
509
+ )
 
510
 
511
  try:
512
+ combined_content = "\n".join([
513
+ json.loads(fc).get("content", "") if "content" in json.loads(fc)
514
+ else str(json.loads(fc).get("rows", ""))
515
+ for fc in results
516
+ ])
517
+
518
  full_report = analyze_complete_document(
519
  combined_content,
520
+ " + ".join(filenames),
521
+ agent,
522
+ temperature=temp
523
  )
524
 
525
  file_hash_value = hashlib.md5(combined_content.encode()).hexdigest()
 
527
  with open(report_path, "w", encoding="utf-8") as f:
528
  f.write(full_report)
529
 
530
+ yield (
531
+ gr.Textbox.update(value=full_report, visible=True),
532
+ gr.File.update(value=report_path if os.path.exists(report_path) else None, visible=True),
533
+ gr.Textbox.update(value="✅ Analysis complete!", visible=True),
534
+ gr.Dataframe.update(value=preview_data[:20], visible=True)
535
+ )
536
 
537
  except Exception as e:
538
  error_msg = f"❌ Error during analysis: {str(e)}"
539
  print(error_msg)
540
+ yield (
541
+ gr.Textbox.update(value="", visible=True),
542
+ gr.File.update(value=None, visible=False),
543
+ gr.Textbox.update(value=error_msg, visible=True),
544
+ gr.Dataframe.update(value=None, visible=True)
545
+ )
546
+
547
+ # Event handlers
548
  send_btn.click(
549
  fn=analyze,
550
+ inputs=[file_upload, msg_input, temperature],
551
+ outputs=[report_output, download_output, status, data_preview],
552
  api_name="analyze"
553
  )
554
 
555
  clear_btn.click(
556
+ fn=lambda: (
557
+ None, None, "", None,
558
+ gr.Slider.update(value=0.3),
559
+ gr.Textbox.update(value="")
560
+ ),
561
  inputs=None,
562
+ outputs=[file_upload, download_output, status, data_preview, temperature, msg_input]
563
  )
564
+
565
  return demo
566
 
567
  if __name__ == "__main__":
568
  print("🚀 Launching app...")
569
+ # Install tiktoken if not available
570
  try:
571
  import tiktoken
572
  except ImportError: