Ali2206 committed
Commit cf765da · verified · 1 Parent(s): abc0e81

Update app.py

Files changed (1): app.py (+31 −55)
app.py CHANGED
@@ -46,13 +46,12 @@ MEDICAL_KEYWORDS = {
     'conclusion', 'history', 'examination', 'progress', 'discharge'
 }
 TOKENIZER = "cl100k_base"
-MAX_MODEL_LEN = 2048  # Matches your model's actual limit
-TARGET_CHUNK_TOKENS = 1200  # Leaves room for prompt and response
-PROMPT_RESERVE = 300  # Tokens reserved for prompt structure
+MAX_MODEL_LEN = 2048
+TARGET_CHUNK_TOKENS = 1200
+PROMPT_RESERVE = 300
 MEDICAL_SECTION_HEADER = "=== MEDICAL SECTION ==="
 
 def log_system_usage(tag=""):
-    """Log system resource usage."""
     try:
         cpu = psutil.cpu_percent(interval=1)
         mem = psutil.virtual_memory()
@@ -68,24 +67,17 @@ def log_system_usage(tag=""):
         print(f"[{tag}] GPU/CPU monitor failed: {e}")
 
 def sanitize_utf8(text: str) -> str:
-    """Ensure text is UTF-8 clean."""
     return text.encode("utf-8", "ignore").decode("utf-8")
 
 def file_hash(path: str) -> str:
-    """Generate MD5 hash of file content."""
     with open(path, "rb") as f:
         return hashlib.md5(f.read()).hexdigest()
 
 def count_tokens(text: str) -> int:
-    """Count tokens using the same method as the model"""
     encoding = tiktoken.get_encoding(TOKENIZER)
     return len(encoding.encode(text))
 
 def extract_all_pages_with_token_count(file_path: str) -> Tuple[str, int, int]:
-    """
-    Extract all pages from PDF with token counting.
-    Returns (extracted_text, total_pages, total_tokens)
-    """
     try:
         text_chunks = []
         total_pages = 0
@@ -112,7 +104,6 @@ def extract_all_pages_with_token_count(file_path: str) -> Tuple[str, int, int]:
         return f"PDF processing error: {str(e)}", 0, 0
 
 def convert_file_to_json(file_path: str, file_type: str) -> str:
-    """Convert file to JSON format with caching and token counting."""
     try:
         h = file_hash(file_path)
         cache_path = os.path.join(file_cache_dir, f"{h}.json")
@@ -162,7 +153,6 @@ def convert_file_to_json(file_path: str, file_type: str) -> str:
         return json.dumps({"error": f"Error processing {os.path.basename(file_path)}: {str(e)}"})
 
 def clean_response(text: str) -> str:
-    """Clean and format the model response."""
     text = sanitize_utf8(text)
     text = re.sub(r"\[TOOL_CALLS\].*", "", text, flags=re.DOTALL)
     text = re.sub(r"\['get_[^\]]+\']\n?", "", text)
@@ -172,7 +162,6 @@ def clean_response(text: str) -> str:
     return text
 
 def format_final_report(analysis_results: List[str], filename: str) -> str:
-    """Combine all analysis chunks into a well-formatted final report."""
     report = []
     report.append(f"COMPREHENSIVE CLINICAL OVERSIGHT ANALYSIS")
     report.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
@@ -219,7 +208,6 @@ def format_final_report(analysis_results: List[str], filename: str) -> str:
     return "\n".join(report)
 
 def split_content_by_tokens(content: str, max_tokens: int = TARGET_CHUNK_TOKENS) -> List[str]:
-    """Split content into chunks that fit within token limits"""
     paragraphs = re.split(r"\n\s*\n", content)
     chunks = []
     current_chunk = []
@@ -252,7 +240,6 @@ def split_content_by_tokens(content: str, max_tokens: int = TARGET_CHUNK_TOKENS)
     return chunks
 
 def init_agent():
-    """Initialize the TxAgent with proper configuration."""
     print("🔁 Initializing model...")
     log_system_usage("Before Load")
 
@@ -277,23 +264,18 @@ def init_agent():
     return agent
 
 def analyze_complete_document(content: str, filename: str, agent: TxAgent, temperature: float = 0.3) -> str:
-    """Analyze complete document with strict token management"""
     chunks = split_content_by_tokens(content)
     analysis_results = []
 
     for i, chunk in enumerate(chunks):
         try:
-            # Ultra-minimal prompt to maximize content space
             base_prompt = "Analyze for:\n1. Critical\n2. Missed DX\n3. Med issues\n4. Gaps\n5. Follow-up\n\nContent:\n"
 
-            # Calculate available space for content
             prompt_tokens = count_tokens(base_prompt)
-            max_content_tokens = MAX_MODEL_LEN - prompt_tokens - 100  # Response buffer
+            max_content_tokens = MAX_MODEL_LEN - prompt_tokens - 100
 
-            # Ensure chunk fits
             chunk_tokens = count_tokens(chunk)
             if chunk_tokens > max_content_tokens:
-                # Find last paragraph that fits
                 adjusted_chunk = ""
                 tokens_used = 0
                 paragraphs = re.split(r"\n\s*\n", chunk)
@@ -307,7 +289,6 @@ def analyze_complete_document(content: str, filename: str, agent: TxAgent, tempe
                     break
 
             if not adjusted_chunk:
-                # If even one paragraph is too big, split sentences
                 sentences = re.split(r'(?<=[.!?])\s+', chunk)
                 for sent in sentences:
                     sent_tokens = count_tokens(sent)
@@ -326,7 +307,7 @@ def analyze_complete_document(content: str, filename: str, agent: TxAgent, tempe
                 message=prompt,
                 history=[],
                 temperature=temperature,
-                max_new_tokens=300,  # Keep responses very concise
+                max_new_tokens=300,
                 max_token=MAX_MODEL_LEN,
                 call_agent=False,
                 conversation=[],
@@ -348,7 +329,6 @@ def analyze_complete_document(content: str, filename: str, agent: TxAgent, tempe
     return format_final_report(analysis_results, filename)
 
 def create_ui(agent):
-    """Create the Gradio interface with enhanced design."""
     with gr.Blocks(
         theme=gr.themes.Soft(
             primary_hue="indigo",
@@ -383,7 +363,6 @@ def create_ui(agent):
         }
         """
     ) as demo:
-        # Header Section
        gr.Markdown("""
        <div style='text-align: center; margin-bottom: 20px;'>
            <h1 style='color: #2b3a67; margin-bottom: 8px;'>🩺 Clinical Oversight Assistant</h1>
@@ -394,7 +373,6 @@ def create_ui(agent):
        """)
 
        with gr.Row(equal_height=False):
-            # Left Column - Inputs
            with gr.Column(scale=1, min_width=400):
                with gr.Group(elem_classes="file-upload"):
                    file_upload = gr.File(
@@ -431,7 +409,6 @@ def create_ui(agent):
                    visible=True
                )
 
-            # Right Column - Outputs
            with gr.Column(scale=2, min_width=600):
                with gr.Tabs():
                    with gr.TabItem("Analysis Report", id="report"):
@@ -459,24 +436,22 @@ def create_ui(agent):
                    )
                    gr.Button("Save to EHR", visible=False)
 
-        # Analysis function with UI updates
        def analyze(files: List, message: str, temp: float):
            if not files:
                return (
-                    {"value": "", "visible": True},  # report_output
-                    {"value": None, "visible": False},  # download_output
-                    {"value": "⚠️ Please upload at least one file to analyze.", "visible": True},  # status
-                    {"value": None, "visible": True}  # data_preview
+                    {"value": "", "visible": True},
+                    None,
+                    {"value": "⚠️ Please upload at least one file to analyze.", "visible": True},
+                    {"value": None, "visible": True}
                )
 
            yield (
                {"value": "", "visible": True},
-                {"value": None, "visible": False},
+                None,
                {"value": "⏳ Processing documents...", "visible": True},
                {"value": None, "visible": True}
            )
 
-            # Process files
            file_contents = []
            filenames = []
            preview_data = []
@@ -484,36 +459,39 @@ def create_ui(agent):
            with ThreadPoolExecutor(max_workers=4) as executor:
                futures = []
                for f in files:
+                    file_path = f.name
                    futures.append(executor.submit(
                        convert_file_to_json,
-                        f.name,
-                        f.name.split(".")[-1].lower()
+                        file_path,
+                        os.path.splitext(file_path)[1][1:].lower()
                    ))
-                    filenames.append(os.path.basename(f.name))
+                    filenames.append(os.path.basename(file_path))
 
                results = []
                for future in as_completed(futures):
                    result = sanitize_utf8(future.result())
                    try:
                        data = json.loads(result)
-                        results.append(result)
+                        results.append(data)
                        if "content" in data:
                            preview_data.append([data["filename"], data["content"][:500] + "..."])
-                    except:
-                        pass
+                    except Exception as e:
+                        print(f"Error processing result: {e}")
+                        continue
 
            yield (
                {"value": "", "visible": True},
-                {"value": None, "visible": False},
+                None,
                {"value": f"🔍 Analyzing {len(files)} documents...", "visible": True},
                {"value": preview_data[:20], "visible": True}
            )
 
            try:
                combined_content = "\n".join([
-                    json.loads(fc).get("content", "") if "content" in json.loads(fc)
-                    else str(json.loads(fc).get("rows", ""))
-                    for fc in results
+                    item.get("content", "") if isinstance(item, dict) and "content" in item
+                    else str(item.get("rows", "")) if isinstance(item, dict)
+                    else str(item)
+                    for item in results
                ])
 
                full_report = analyze_complete_document(
@@ -530,7 +508,7 @@ def create_ui(agent):
 
                yield (
                    {"value": full_report, "visible": True},
-                    {"value": report_path if os.path.exists(report_path) else None, "visible": True},
+                    report_path if os.path.exists(report_path) else None,
                    {"value": "✅ Analysis complete!", "visible": True},
                    {"value": preview_data[:20], "visible": True}
                )
@@ -540,12 +518,11 @@ def create_ui(agent):
                print(error_msg)
                yield (
                    {"value": "", "visible": True},
-                    {"value": None, "visible": False},
+                    None,
                    {"value": error_msg, "visible": True},
                    {"value": None, "visible": True}
                )
 
-        # Event handlers
        send_btn.click(
            fn=analyze,
            inputs=[file_upload, msg_input, temperature],
@@ -555,12 +532,12 @@ def create_ui(agent):
 
        clear_btn.click(
            fn=lambda: (
-                None,  # file_upload
-                None,  # download_output
-                "",  # status
-                None,  # data_preview
-                {"value": 0.3},  # temperature
-                {"value": ""}  # msg_input
+                None,
+                None,
+                "",
+                None,
+                {"value": 0.3},
+                {"value": ""}
            ),
            inputs=None,
            outputs=[file_upload, download_output, status, data_preview, temperature, msg_input]
@@ -570,7 +547,6 @@ def create_ui(agent):
 
 if __name__ == "__main__":
     print("🚀 Launching app...")
-    # Install tiktoken if not available
     try:
         import tiktoken
     except ImportError:
 