Ali2206 committed
Commit 5bfcdc0 · verified · 1 Parent(s): d88209d

Update src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +109 -87
src/txagent/txagent.py CHANGED
@@ -14,6 +14,8 @@ from .toolrag import ToolRAGModel
 import torch
 import logging
 from difflib import SequenceMatcher
+import asyncio
+import threading
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
@@ -102,7 +104,6 @@ class TxAgent:
 
     def initialize_tools_prompt(self, call_agent, call_agent_level, message):
         picked_tools_prompt = []
-        # Only add Finish tool unless prompt explicitly requires Tool_RAG or CallAgent
         if "use external tools" not in message.lower():
             picked_tools_prompt = self.add_special_tools(picked_tools_prompt, call_agent=False)
         else:
@@ -319,7 +320,6 @@
         if self.enable_checker:
             checker = ReasoningTraceChecker(message, conversation)
 
-        # Check if message contains clinical findings
         clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
         has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
 
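Annotation: the clinical gate above is a plain keyword heuristic. A self-contained sketch of the same check, with the keyword list copied from the diff and the wrapper function being a hypothetical name for illustration:

# Sketch of the keyword gate this commit relies on (keywords copied from the diff).
clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']

def has_clinical_findings(message: str) -> bool:
    # Case-insensitive substring match; deliberately simple, so e.g.
    # "premedication" also triggers it.
    lowered = message.lower()
    return any(keyword in lowered for keyword in clinical_keywords)

assert has_clinical_findings("Current Medication: Seroquel")
assert not has_clinical_findings("Name and date of birth only")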
@@ -355,7 +355,6 @@
                     logger.warning("Checker error: %s", wrong_info)
                     break
 
-            # Skip tool calls if clinical data is present
             tools = [] if has_clinical_data else picked_tools_prompt
             last_outputs = []
             last_outputs_str, token_overflow = self.llm_infer(
@@ -382,7 +381,6 @@
                 m['content'] for m in messages[-3:] if m['role'] == 'assistant'
             ][:2]
             forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
-            # Enhance deduplication with similarity check
             unique_sentences = []
             for msg in assistant_messages:
                 sentences = msg.split('. ')
@@ -397,7 +395,7 @@
                     if is_unique:
                         unique_sentences.append(s)
             forbidden_ids = [tokenizer.encode(s, add_special_tokens=False) for s in unique_sentences]
-            return [NoRepeatSentenceProcessor(forbidden_ids, 10)]  # Increased penalty
+            return [NoRepeatSentenceProcessor(forbidden_ids, 15)]  # Increased penalty
         return None
 
     def llm_infer(self, messages, temperature=0.1, tools=None, output_begin_string=None,
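Annotation: NoRepeatSentenceProcessor itself is not shown in this diff; only its penalty argument changes from 10 to 15. A minimal sketch of what such a processor could look like, assuming a transformers-style __call__(input_ids, scores) logits-processor interface and that the second constructor argument is the penalty; the body is hypothetical, not the repo's implementation:

class NoRepeatSentenceProcessor:
    """Penalize tokens that would complete an already-seen sentence."""

    def __init__(self, forbidden_token_ids, penalty=15.0):
        # forbidden_token_ids: list of token-id sequences, one per banned sentence
        self.forbidden_token_ids = [seq for seq in forbidden_token_ids if seq]
        self.penalty = float(penalty)

    def __call__(self, input_ids, scores):
        # input_ids: token ids generated so far; scores: next-token logits
        generated = list(input_ids)
        for seq in self.forbidden_token_ids:
            prefix = seq[:-1]
            # If generation currently ends with all but the last token of a
            # banned sentence, push down the logit of its completing token.
            if prefix and generated[-len(prefix):] == prefix:
                scores[seq[-1]] -= self.penalty
        return scores

Raising the penalty from 10 to 15 makes the completing token proportionally less likely without hard-masking it.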
@@ -435,6 +433,28 @@
             return output, False
         return output
 
+    def run_quick_summary(self, message: str, temperature: float = 0.1, max_new_tokens: int = 256, max_token: int = 1024):
+        """Generate a fast, concise summary of potential missed diagnoses without tool calls"""
+        logger.debug("Starting quick summary for message: %s", message[:100])
+        prompt = """
+Analyze the patient record excerpt for missed diagnoses, focusing ONLY on clinical findings such as symptoms, medications, or evaluation results. Provide a concise summary in ONE paragraph without headings or bullet points. ALWAYS treat medications or psychiatric evaluations as potential missed diagnoses, specifying their implications (e.g., 'use of Seroquel may indicate untreated psychosis'). Recommend urgent review for identified findings. Do NOT use external tools or repeat non-clinical data (e.g., name, date of birth). If no clinical findings are present, state 'No missed diagnoses identified' in ONE sentence.
+Patient Record Excerpt:
+{chunk}
+"""
+        conversation = self.set_system_prompt([], prompt.format(chunk=message))
+        conversation.append({"role": "user", "content": message})
+        output = self.llm_infer(
+            messages=conversation,
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            max_token=max_token,
+            tools=[]  # No tools
+        )
+        if '[FinalAnswer]' in output:
+            output = output.split('[FinalAnswer]')[-1].strip()
+        logger.debug("Quick summary output: %s", output[:100])
+        return output
+
     def run_self_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
         logger.debug("Starting self agent")
         conversation = self.set_system_prompt([], self.self_prompt)
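Annotation: a usage sketch for the new run_quick_summary method; the constructor call is elided and the argument values are illustrative, not prescribed by the diff:

# Hypothetical caller; assumes an already-constructed TxAgent instance.
agent = TxAgent(...)  # constructor arguments omitted; see the repo's app entry point
summary = agent.run_quick_summary(
    "Patient on Seroquel 100 mg nightly; reports auditory hallucinations.",
    temperature=0.1,
    max_new_tokens=256,
    max_token=1024,
)
print(summary)  # one paragraph, produced with tools=[] so no tool calls run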
@@ -565,30 +585,19 @@ Summarize the function responses in one sentence with all necessary information.
         logger.debug("Updated parameters: %s", updated_attributes)
         return updated_attributes
 
-    def run_gradio_chat(self, message: str, history: list, temperature: float,
-                        max_new_tokens: int, max_token: int, call_agent: bool,
-                        conversation: gr.State, max_round: int = 3, seed: int = None,
-                        call_agent_level: int = 0, sub_agent_task: str = None,
-                        uploaded_files: list = None):
-        logger.debug("Chat started, message: %s", message[:100])
-        if not message or len(message.strip()) < 5:
-            yield "Please provide a valid message or upload files to analyze."
-            return
-
-        if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
-            return
-
-        # Check if message contains clinical findings
-        clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
-        has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
-        call_agent = call_agent and not has_clinical_data  # Disable CallAgent for clinical data
-
+    async def run_background_report(self, message: str, history: list, temperature: float,
+                                    max_new_tokens: int, max_token: int, call_agent: bool,
+                                    conversation: gr.State, max_round: int, seed: int,
+                                    call_agent_level: int, report_path: str):
+        """Run detailed report generation in the background and save to file"""
+        logger.debug("Starting background report for message: %s", message[:100])
+        combined_response = ""
+        history_copy = history.copy()
+
         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
             call_agent, call_agent_level, message)
-        conversation = self.initialize_conversation(
-            message, conversation, history)
-        history = []
-
+        conversation = self.initialize_conversation(message, conversation, history_copy)
+
         next_round = True
         current_round = 0
         enable_summary = False
@@ -603,24 +612,17 @@ Summarize the function responses in one sentence with all necessary information.
                 current_round += 1
                 last_outputs = []
                 if last_outputs:
-                    function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
+                    function_call_messages, picked_tools_prompt, special_tool_call, _ = yield from self.run_function_call_stream(
                         last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
                         message_for_call_agent=message, call_agent=call_agent,
-                        call_agent_level=call_agent_level, temperature=temperature)
-                    history.extend(current_gradio_history)
-
+                        call_agent_level=call_agent_level, temperature=temperature,
+                        return_gradio_history=False)
+
                     if special_tool_call == 'Finish':
-                        yield history
                         next_round = False
                         conversation.extend(function_call_messages)
-                        return function_call_messages[0]['content']
-
-                    if special_tool_call in ['RequireClarification', 'DirectResponse']:
-                        last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
-                        history.append(ChatMessage(role="assistant", content=last_msg.content))
-                        yield history
-                        next_round = False
-                        return last_msg.content
+                        combined_response += function_call_messages[0]['content'] + "\n"
+                        break
 
                     if (self.enable_summary or token_overflow) and not call_agent:
                         enable_summary = True
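Annotation: the `yield from` assignment above uses a standard Python idiom: a generator's return value becomes the value of the `yield from` expression, while everything it yields streams through to the outer consumer. A minimal, self-contained illustration (all names hypothetical):

def stream_tool_call():
    # Chunks yielded here stream straight through to the outer consumer.
    yield "intermediate chunk"
    # The return value is captured by `yield from` in the caller.
    return (["tool message"], None)

def agent_loop():
    function_call_messages, special_tool_call = yield from stream_tool_call()
    yield f"captured: {function_call_messages}, {special_tool_call}"

for chunk in agent_loop():
    print(chunk)
# prints:
#   intermediate chunk
#   captured: ['tool message'], None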
@@ -629,10 +631,11 @@ Summarize the function responses in one sentence with all necessary information.
 
                 if function_call_messages:
                     conversation.extend(function_call_messages)
-                    yield history
+                    combined_response += tool_result_format(function_call_messages) + "\n"
                 else:
                     next_round = False
-                    return ''.join(last_outputs).replace("</s>", "")
+                    combined_response += ''.join(last_outputs).replace("</s>", "") + "\n"
+                    break
 
                 if self.enable_checker:
                     good_status, wrong_info = checker.check_conversation()
@@ -640,8 +643,7 @@ Summarize the function responses in one sentence with all necessary information.
                         logger.warning("Checker error: %s", wrong_info)
                         break
 
-                # Skip tool calls if clinical data is present
-                tools = [] if has_clinical_data else picked_tools_prompt
+                tools = picked_tools_prompt
                 last_outputs_str, token_overflow = self.llm_infer(
                     messages=conversation, temperature=temperature, tools=tools,
                     max_new_tokens=max_new_tokens, max_token=max_token, seed=seed, check_token_status=True)
@@ -650,54 +652,74 @@ Summarize the function responses in one sentence with all necessary information.
                     if self.force_finish:
                         last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                             conversation, temperature, max_new_tokens, max_token)
-                        history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
-                        yield history
-                        return last_outputs_str
-                    error_msg = "Token limit exceeded."
-                    history.append(ChatMessage(role="assistant", content=error_msg))
-                    yield history
-                    return error_msg
-
-                last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
-                for msg in history:
-                    if msg.metadata:
-                        msg.metadata['status'] = 'done'
-
-                if '[FinalAnswer]' in last_thought:
-                    parts = last_thought.split('[FinalAnswer]', 1)
-                    final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
-                    history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-                    yield history
-                    history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-                    yield history
-                else:
-                    history.append(ChatMessage(role="assistant", content=last_thought))
-                    yield history
+                        combined_response += last_outputs_str + "\n"
+                        break
+                    combined_response += "Token limit exceeded.\n"
+                    break
 
+                combined_response += last_outputs_str + "\n"
                 last_outputs.append(last_outputs_str)
 
             if next_round and self.force_finish:
                 last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                     conversation, temperature, max_new_tokens, max_token)
-                parts = last_outputs_str.split('[FinalAnswer]', 1)
-                final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
-                history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-                yield history
-                history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-                yield history
+                combined_response += last_outputs_str + "\n"
+
+            # Save report
+            try:
+                with open(report_path, "w", encoding="utf-8") as f:
+                    f.write(combined_response)
+                logger.info("Detailed report saved to %s", report_path)
+            except Exception as e:
+                logger.error("Failed to save report: %s", e)
 
         except Exception as e:
-            logger.error("Exception in run_gradio_chat: %s", e, exc_info=True)
-            error_msg = f"Error: {e}"
-            history.append(ChatMessage(role="assistant", content=error_msg))
-            yield history
-            if self.force_finish:
-                last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
-                    conversation, temperature, max_new_tokens, max_token)
-                parts = last_outputs_str.split('[FinalAnswer]', 1)
-                final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
-                history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-                yield history
-                history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-                yield history
-            return error_msg
+            logger.error("Background report error: %s", e)
+            combined_response += f"Error: {e}\n"
+            with open(report_path, "w", encoding="utf-8") as f:
+                f.write(combined_response)
+
+        finally:
+            torch.cuda.empty_cache()
+            gc.collect()
+
+    def run_gradio_chat(self, message: str, history: list, temperature: float,
+                        max_new_tokens: int, max_token: int, call_agent: bool,
+                        conversation: gr.State, max_round: int = 3, seed: int = None,
+                        call_agent_level: int = 0, sub_agent_task: str = None,
+                        uploaded_files: list = None, report_path: str = None):
+        logger.debug("Chat started, message: %s", message[:100])
+        if not message or len(message.strip()) < 5:
+            yield "Please provide a valid message or upload files to analyze."
+            return
+
+        if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
+            return
+
+        clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
+        has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
+        call_agent = call_agent and not has_clinical_data
+
+        # Generate quick summary
+        quick_summary = self.run_quick_summary(
+            message, temperature=temperature, max_new_tokens=256, max_token=1024)
+        history.append(ChatMessage(role="assistant", content=f"**Quick Summary:**\n{quick_summary}"))
+        yield history
+
+        # Start background report generation
+        if report_path:
+            loop = asyncio.get_event_loop()
+            threading.Thread(
+                target=lambda: loop.run_until_complete(
+                    self.run_background_report(
+                        message, history, temperature, max_new_tokens, max_token, call_agent,
+                        conversation, max_round, seed, call_agent_level, report_path
+                    )
+                ),
+                daemon=True
+            ).start()
+            history.append(ChatMessage(
                role="assistant",
+                content="Generating detailed report in the background. Download will be available when ready."
+            ))
+            yield history
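Annotation on the thread-plus-event-loop launch in the final hunk: handing the caller's loop into a worker thread and calling run_until_complete on it raises a RuntimeError if that loop is already running, so a common variant gives the worker thread its own loop via asyncio.run. A minimal sketch of that fire-and-forget pattern, assuming the background job is an ordinary coroutine; the report builder below is a placeholder, not the repo's implementation:

import asyncio
import threading

async def build_report(message: str, report_path: str) -> None:
    # Placeholder for the real report generation.
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(f"Report for: {message}\n")

def start_background_report(message: str, report_path: str) -> threading.Thread:
    # asyncio.run creates and tears down a fresh event loop inside the worker
    # thread, so nothing here touches the caller's (possibly running) loop.
    thread = threading.Thread(
        target=lambda: asyncio.run(build_report(message, report_path)),
        daemon=True,  # don't block interpreter shutdown
    )
    thread.start()
    return thread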