Ali2206 committed (verified)
Commit 8547f5e · 1 Parent(s): 543491f

Update app.py

Files changed (1):
  1. app.py +173 -68

app.py CHANGED
@@ -277,15 +277,57 @@ def init_agent():
     return agent

 def create_ui(agent):
-    with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    with gr.Blocks(theme=gr.themes.Soft(), title="Clinical Oversight Assistant") as demo:
         gr.Markdown("<h1 style='text-align: center;'>🩺 Clinical Oversight Assistant</h1>")
-        chatbot = gr.Chatbot(label="Detailed Analysis", height=600, type="messages")
-        final_summary = gr.Markdown(label="Summary of Missed Diagnoses")
-        file_upload = gr.File(file_types=[".pdf", ".csv", ".xls", ".xlsx"], file_count="multiple")
-        msg_input = gr.Textbox(placeholder="Ask about potential oversights...", show_label=False)
-        send_btn = gr.Button("Analyze", variant="primary")
-        download_output = gr.File(label="Download Full Report")
-        progress_bar = gr.Progress()
+
+        with gr.Row():
+            with gr.Column(scale=3):
+                chatbot = gr.Chatbot(
+                    label="Analysis Conversation",
+                    height=600,
+                    bubble_full_width=False,
+                    show_copy_button=True,
+                    avatar_images=(
+                        "assets/user.png",
+                        "assets/assistant.png"
+                    )
+                )
+            with gr.Column(scale=1):
+                final_summary = gr.Markdown(
+                    label="Summary of Findings",
+                    value="### Summary will appear here\nAfter analysis completes"
+                )
+                download_output = gr.File(
+                    label="Download Full Report",
+                    visible=False
+                )
+
+        with gr.Row():
+            file_upload = gr.File(
+                file_types=[".pdf", ".csv", ".xls", ".xlsx"],
+                file_count="multiple",
+                label="Upload Patient Records"
+            )
+
+        with gr.Row():
+            msg_input = gr.Textbox(
+                placeholder="Ask about potential oversights...",
+                show_label=False,
+                container=False,
+                scale=7,
+                autofocus=True
+            )
+            send_btn = gr.Button(
+                "Analyze",
+                variant="primary",
+                scale=1,
+                min_width=100
+            )
+
+        progress_bar = gr.Progress(
+            label="Processing Progress",
+            visible=False
+        )

         prompt_template = """
 Analyze the patient record excerpt for missed diagnoses only. Provide a concise, evidence-based summary as a single paragraph without headings or bullet points. Include specific clinical findings (e.g., 'elevated blood pressure (160/95) on page 10'), their potential implications (e.g., 'may indicate untreated hypertension'), and a recommendation for urgent review. Do not include other oversight categories like medication conflicts. If no missed diagnoses are found, state 'No missed diagnoses identified' in a single sentence.
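Note on the new `progress_bar` component: in the Gradio releases I am familiar with (3.x/4.x), `gr.Progress` is a progress *tracker*, not a layout component, and its constructor does not take `label` or `visible` arguments. The documented pattern is to accept a `gr.Progress()` default parameter in the event handler and call it to drive the bar. A minimal sketch of that pattern, with an illustrative handler name:

```python
import time
import gradio as gr

def run_analysis(message, progress=gr.Progress()):
    # Gradio injects a tracker for the gr.Progress() default; calling it
    # drives the progress bar shown on the event's output component.
    steps = 5
    for i in range(steps):
        progress((i + 1) / steps, desc=f"Analyzing chunk {i + 1}/{steps}")
        time.sleep(0.2)
    return f"Done: {message}"

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    msg_input = gr.Textbox(placeholder="Ask about potential oversights...")
    result = gr.Textbox(label="Result")
    gr.Button("Analyze", variant="primary").click(run_analysis, inputs=msg_input, outputs=result)

demo.launch()
```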
@@ -293,9 +335,37 @@ Patient Record Excerpt (Chunk {0} of {1}):
 {chunk}
 """

-        def analyze(message: str, history: List[dict], files: List, progress=gr.Progress()):
+        def process_response_stream(prompt: str, history: List[dict]) -> Generator[dict, None, None]:
+            """Process a single prompt and stream the response"""
+            full_response = ""
+            for chunk_output in agent.run_gradio_chat(prompt, [], 0.2, 512, 2048, False, []):
+                if chunk_output is None:
+                    continue
+
+                if isinstance(chunk_output, list):
+                    for m in chunk_output:
+                        if hasattr(m, 'content') and m.content:
+                            cleaned = clean_response(m.content)
+                            if cleaned:
+                                full_response += cleaned + " "
+                                yield {"role": "assistant", "content": full_response}
+                elif isinstance(chunk_output, str) and chunk_output.strip():
+                    cleaned = clean_response(chunk_output)
+                    if cleaned:
+                        full_response += cleaned + " "
+                        yield {"role": "assistant", "content": full_response}
+
+            return full_response
+
+        def analyze(message: str, history: List[dict], files: List) -> Generator[dict, None, None]:
+            # Start with user message
             history.append({"role": "user", "content": message})
-            yield history, None, ""
+            yield {
+                "chatbot": history,
+                "download_output": None,
+                "final_summary": "",
+                "progress_bar": gr.Progress(visible=True)
+            }

             extracted = []
             file_hash_value = ""
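A note on the dict-style yields introduced here: as far as I know, Gradio maps the values yielded by a generator handler onto `outputs=[...]` either positionally (a tuple in output order) or via a dict keyed by the component *objects*, not by string names such as `"chatbot"`. A minimal sketch of the positional form, assuming a recent gradio 4.x and a reduced set of components for illustration:

```python
import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages", height=400)
    final_summary = gr.Markdown()
    download_output = gr.File(label="Download Full Report")
    msg_input = gr.Textbox(placeholder="Ask about potential oversights...")

    def analyze_sketch(message, history):
        # One value per wired output, in the same order as outputs=[...]
        history = (history or []) + [{"role": "user", "content": message}]
        yield history, None, ""                                  # show the user turn immediately
        history = history + [{"role": "assistant", "content": "✅ File processing complete"}]
        yield history, None, ""                                  # intermediate refresh
        yield history, None, "No missed diagnoses identified."   # final refresh with the summary

    msg_input.submit(
        analyze_sketch,
        inputs=[msg_input, chatbot],
        outputs=[chatbot, download_output, final_summary],
    )

demo.launch()
```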
@@ -306,11 +376,7 @@ Patient Record Excerpt (Chunk {0} of {1}):
                 futures = []
                 for f in files:
                     file_type = f.name.split(".")[-1].lower()
-                    futures.append(executor.submit(
-                        process_file,
-                        f.name,
-                        file_type
-                    ))
+                    futures.append(executor.submit(process_file, f.name, file_type))

                 for future in as_completed(futures):
                     try:
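For reference, this hunk only collapses the multi-line `executor.submit(...)` call into one line; the submit/`as_completed` pattern itself is unchanged. A self-contained sketch of that pattern, with a stand-in `process_file` since the real parser is defined elsewhere in app.py:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def process_file(path: str, file_type: str) -> dict:
    # Stand-in for the app's real extractor (PDF/CSV/XLSX parsing).
    return {"file": path, "type": file_type, "text": "..."}

def extract_all(paths: list[str]) -> list[dict]:
    extracted = []
    with ThreadPoolExecutor(max_workers=4) as executor:
        # Submit one job per file, then collect results as they finish.
        futures = [executor.submit(process_file, p, p.rsplit(".", 1)[-1].lower()) for p in paths]
        for future in as_completed(futures):
            try:
                extracted.append(future.result())
            except Exception as exc:  # keep going if one file fails
                extracted.append({"error": str(exc)})
    return extracted

print(extract_all(["records.pdf", "labs.xlsx"]))
```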
@@ -321,7 +387,12 @@ Patient Record Excerpt (Chunk {0} of {1}):

             file_hash_value = file_hash(files[0].name) if files else ""
             history.append({"role": "assistant", "content": "✅ File processing complete"})
-            yield history, None, ""
+            yield {
+                "chatbot": history,
+                "download_output": None,
+                "final_summary": "",
+                "progress_bar": gr.Progress(0.2, visible=True, label="Processing files")
+            }

             # Convert extracted data to JSON text
             text_content = "\n".join(json.dumps(item) for item in extracted)
@@ -329,56 +400,45 @@ Patient Record Excerpt (Chunk {0} of {1}):
             # Tokenize and chunk the content properly
             chunks = tokenize_and_chunk(text_content)
             combined_response = ""
-            batch_size = 2  # Reduced batch size to prevent token overflow

             try:
-                for batch_idx in range(0, len(chunks), batch_size):
-                    batch_chunks = chunks[batch_idx:batch_idx + batch_size]
-                    batch_prompts = [
-                        prompt_template.format(
-                            batch_idx + i + 1,
-                            len(chunks),
-                            chunk=chunk[:1800]  # Conservative chunk size
-                        )
-                        for i, chunk in enumerate(batch_chunks)
-                    ]
+                for chunk_idx, chunk in enumerate(chunks, 1):
+                    prompt = prompt_template.format(chunk_idx, len(chunks), chunk=chunk[:1800])

-                    progress((batch_idx) / len(chunks),
-                             desc=f"Analyzing batch {(batch_idx // batch_size) + 1}/{(len(chunks) + batch_size - 1) // batch_size}")
+                    # Create a placeholder message
+                    history.append({"role": "assistant", "content": ""})
+                    yield {
+                        "chatbot": history,
+                        "download_output": None,
+                        "final_summary": "",
+                        "progress_bar": gr.Progress(
+                            0.2 + (chunk_idx/len(chunks))*0.7,
+                            visible=True,
+                            label=f"Analyzing chunk {chunk_idx}/{len(chunks)}"
+                        )
+                    }

-                    # Process batch in parallel
-                    with ThreadPoolExecutor(max_workers=len(batch_prompts)) as executor:
-                        future_to_prompt = {
-                            executor.submit(
-                                agent.run_gradio_chat,
-                                prompt, [], 0.2, 512, 2048, False, []
-                            ): prompt
-                            for prompt in batch_prompts
+                    # Process and stream the response
+                    chunk_response = ""
+                    for update in process_response_stream(prompt, history):
+                        # Update the last message with streaming content
+                        history[-1] = update
+                        chunk_response = update["content"]
+                        yield {
+                            "chatbot": history,
+                            "download_output": None,
+                            "final_summary": "",
+                            "progress_bar": gr.Progress(
+                                0.2 + (chunk_idx/len(chunks))*0.7,
+                                visible=True
+                            )
                         }
-
-                        for future in as_completed(future_to_prompt):
-                            chunk_response = ""
-                            for chunk_output in future.result():
-                                if chunk_output is None:
-                                    continue
-                                if isinstance(chunk_output, list):
-                                    for m in chunk_output:
-                                        if hasattr(m, 'content') and m.content:
-                                            cleaned = clean_response(m.content)
-                                            if cleaned:
-                                                chunk_response += cleaned + " "
-                                elif isinstance(chunk_output, str) and chunk_output.strip():
-                                    cleaned = clean_response(chunk_output)
-                                    if cleaned:
-                                        chunk_response += cleaned + " "
-
-                            combined_response += f"--- Analysis for Chunk {batch_idx + 1} ---\n{chunk_response.strip()}\n"
-                            history[-1] = {"role": "assistant", "content": combined_response.strip()}
-                            yield history, None, ""
-
-                            # Clean up memory
-                            torch.cuda.empty_cache()
-                            gc.collect()
+
+                    combined_response += f"--- Analysis for Chunk {chunk_idx} ---\n{chunk_response}\n"
+
+                    # Clean up memory
+                    torch.cuda.empty_cache()
+                    gc.collect()

                 # Generate final summary
                 summary = summarize_findings(combined_response)
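This hunk replaces the batched `ThreadPoolExecutor` fan-out with a sequential per-chunk loop that re-yields partial text as it streams in. One Python detail worth noting: `process_response_stream` ends with `return full_response`, but a `return` inside a generator only surfaces as `StopIteration.value`; a plain `for` loop never sees it, which is presumably why the caller keeps the last yielded `update["content"]` instead. A small sketch of both ways to read such a generator:

```python
def stream_words(text: str):
    acc = ""
    for word in text.split():
        acc += word + " "
        yield acc.strip()      # partial result after each word
    return acc.strip()         # only visible via StopIteration.value

# 1) The pattern the new analyze() uses: keep the last yielded value.
last = ""
for partial in stream_words("no missed diagnoses identified"):
    last = partial
print(last)

# 2) Reading the generator's return value explicitly.
gen = stream_words("no missed diagnoses identified")
try:
    while True:
        next(gen)
except StopIteration as stop:
    print(stop.value)
```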
@@ -387,15 +447,53 @@ Patient Record Excerpt (Chunk {0} of {1}):
                 with open(report_path, "w", encoding="utf-8") as f:
                     f.write(combined_response + "\n\n" + summary)

-                yield history, report_path if report_path and os.path.exists(report_path) else None, summary
+                yield {
+                    "chatbot": history,
+                    "download_output": gr.File(report_path) if report_path and os.path.exists(report_path) else None,
+                    "final_summary": summary,
+                    "progress_bar": gr.Progress(1.0, visible=False)
+                }

             except Exception as e:
                 logger.error("Analysis error: %s", e)
                 history.append({"role": "assistant", "content": f"❌ Error occurred: {str(e)}"})
-                yield history, None, f"Error occurred during analysis: {str(e)}"
-
-        send_btn.click(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
-        msg_input.submit(analyze, inputs=[msg_input, gr.State([]), file_upload], outputs=[chatbot, download_output, final_summary])
+                yield {
+                    "chatbot": history,
+                    "download_output": None,
+                    "final_summary": f"Error occurred during analysis: {str(e)}",
+                    "progress_bar": gr.Progress(visible=False)
+                }
+
+        def clear_and_start():
+            return {
+                "chatbot": [],
+                "download_output": None,
+                "final_summary": "",
+                "msg_input": "",
+                "file_upload": None
+            }
+
+        # Event handlers
+        send_btn.click(
+            analyze,
+            inputs=[msg_input, chatbot, file_upload],
+            outputs=[chatbot, download_output, final_summary, progress_bar],
+            show_progress="hidden"
+        )
+
+        msg_input.submit(
+            analyze,
+            inputs=[msg_input, chatbot, file_upload],
+            outputs=[chatbot, download_output, final_summary, progress_bar],
+            show_progress="hidden"
+        )
+
+        demo.load(
+            clear_and_start,
+            outputs=[chatbot, download_output, final_summary, msg_input, file_upload],
+            queue=False
+        )
+
     return demo

 if __name__ == "__main__":
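The new wiring passes `chatbot` itself as an input so `analyze` receives the live history, but nothing clears `msg_input` after a submission. A common pattern in the Gradio versions I know (3.40+/4.x) is to chain a `.then()` that resets the textbox once the analysis event completes; a minimal sketch under that assumption, with a trivial stand-in handler:

```python
import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    msg_input = gr.Textbox(placeholder="Ask about potential oversights...")
    send_btn = gr.Button("Analyze", variant="primary")

    def analyze(message, history):
        # Stand-in handler: append the user turn and a canned assistant reply.
        return (history or []) + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": "✅ File processing complete"},
        ]

    send_btn.click(analyze, inputs=[msg_input, chatbot], outputs=chatbot).then(
        lambda: "", outputs=msg_input  # clear the textbox after the event finishes
    )
    msg_input.submit(analyze, inputs=[msg_input, chatbot], outputs=chatbot).then(
        lambda: "", outputs=msg_input
    )

demo.launch()
```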
@@ -403,13 +501,20 @@ if __name__ == "__main__":
         logger.info("Launching app...")
         agent = init_agent()
         demo = create_ui(agent)
-        demo.queue(api_open=False).launch(
+        demo.queue(
+            api_open=False,
+            max_size=20
+        ).launch(
             server_name="0.0.0.0",
             server_port=7860,
             show_error=True,
             allowed_paths=[report_dir],
-            share=False
+            share=False,
+            favicon_path="assets/favicon.ico"
         )
+    except Exception as e:
+        logger.error(f"Failed to launch app: {e}")
+        raise
     finally:
         if torch.distributed.is_initialized():
             torch.distributed.destroy_process_group()
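One launch-time consideration: `allowed_paths=[report_dir]` only permits Gradio to serve files that live under an existing directory, and the new `favicon_path`/avatar images are read from an `assets/` folder referenced elsewhere in the diff. A minimal sketch of preparing those paths before launch; the directory layout here is illustrative, not taken from the commit:

```python
import os
import gradio as gr

# Illustrative setup: make sure the directories referenced at launch actually exist.
report_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "reports")
os.makedirs(report_dir, exist_ok=True)
os.makedirs("assets", exist_ok=True)  # favicon.ico / avatar images are expected here

with gr.Blocks() as demo:  # stand-in for create_ui(agent)
    gr.Markdown("Clinical Oversight Assistant")

demo.queue(api_open=False, max_size=20).launch(
    server_name="0.0.0.0",
    server_port=7860,
    show_error=True,
    allowed_paths=[report_dir],  # files under this directory may be served for download
    share=False,
)
```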
 