Update src/txagent/txagent.py

src/txagent/txagent.py CHANGED (+109 −87)
@@ -14,6 +14,8 @@ from .toolrag import ToolRAGModel
 import torch
 import logging
 from difflib import SequenceMatcher
+import asyncio
+import threading
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
@@ -102,7 +104,6 @@ class TxAgent:
 
     def initialize_tools_prompt(self, call_agent, call_agent_level, message):
         picked_tools_prompt = []
-        # Only add Finish tool unless prompt explicitly requires Tool_RAG or CallAgent
         if "use external tools" not in message.lower():
             picked_tools_prompt = self.add_special_tools(picked_tools_prompt, call_agent=False)
         else:
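Note that the gate this hunk touches is a plain substring test, so any message containing the phrase, even negated, takes the tool branch. A standalone sketch of that behavior, not code from the repo:

```python
# Illustration of the substring gate in initialize_tools_prompt; standalone.
def wants_tools(message: str) -> bool:
    return "use external tools" in message.lower()

print(wants_tools("Summarize this record."))                # False -> Finish tool only
print(wants_tools("Please use external tools if needed."))  # True  -> special tools added
print(wants_tools("Do NOT use external tools."))            # True, despite the negation
```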
@@ -319,7 +320,6 @@ class TxAgent:
         if self.enable_checker:
             checker = ReasoningTraceChecker(message, conversation)
 
-        # Check if message contains clinical findings
         clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
         has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
 
@@ -355,7 +355,6 @@ class TxAgent:
                     logger.warning("Checker error: %s", wrong_info)
                     break
 
-            # Skip tool calls if clinical data is present
             tools = [] if has_clinical_data else picked_tools_prompt
             last_outputs = []
             last_outputs_str, token_overflow = self.llm_infer(
@@ -382,7 +381,6 @@ class TxAgent:
                 m['content'] for m in messages[-3:] if m['role'] == 'assistant'
             ][:2]
             forbidden_ids = [tokenizer.encode(msg, add_special_tokens=False) for msg in assistant_messages]
-            # Enhance deduplication with similarity check
             unique_sentences = []
             for msg in assistant_messages:
                 sentences = msg.split('. ')
@@ -397,7 +395,7 @@ class TxAgent:
                     if is_unique:
                         unique_sentences.append(s)
             forbidden_ids = [tokenizer.encode(s, add_special_tokens=False) for s in unique_sentences]
-            return [NoRepeatSentenceProcessor(forbidden_ids,
+            return [NoRepeatSentenceProcessor(forbidden_ids, 15)]  # Increased penalty
         return None
 
     def llm_infer(self, messages, temperature=0.1, tools=None, output_begin_string=None,
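The diff elides the body of the dedup loop (old lines 388–397), which computes `is_unique` for each sentence before it feeds `NoRepeatSentenceProcessor`. A minimal sketch of one plausible shape, using the `SequenceMatcher` import this commit's file already has; the 0.9 threshold is an assumption, not the repo's value:

```python
# Plausible sketch of the similarity-based sentence dedup; threshold assumed.
from difflib import SequenceMatcher

def dedup_sentences(messages, threshold=0.9):
    unique_sentences = []
    for msg in messages:
        for s in msg.split('. '):
            is_unique = all(
                SequenceMatcher(None, s, u).ratio() < threshold
                for u in unique_sentences
            )
            if is_unique:
                unique_sentences.append(s)
    return unique_sentences

print(dedup_sentences(["The dose is 5 mg. Review urgently.",
                       "The dose is 5mg. Review urgently."]))
```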
@@ -435,6 +433,28 @@ class TxAgent:
         return output, False
         return output
 
+    def run_quick_summary(self, message: str, temperature: float = 0.1, max_new_tokens: int = 256, max_token: int = 1024):
+        """Generate a fast, concise summary of potential missed diagnoses without tool calls"""
+        logger.debug("Starting quick summary for message: %s", message[:100])
+        prompt = """
+        Analyze the patient record excerpt for missed diagnoses, focusing ONLY on clinical findings such as symptoms, medications, or evaluation results. Provide a concise summary in ONE paragraph without headings or bullet points. ALWAYS treat medications or psychiatric evaluations as potential missed diagnoses, specifying their implications (e.g., 'use of Seroquel may indicate untreated psychosis'). Recommend urgent review for identified findings. Do NOT use external tools or repeat non-clinical data (e.g., name, date of birth). If no clinical findings are present, state 'No missed diagnoses identified' in ONE sentence.
+        Patient Record Excerpt:
+        {chunk}
+        """
+        conversation = self.set_system_prompt([], prompt.format(chunk=message))
+        conversation.append({"role": "user", "content": message})
+        output = self.llm_infer(
+            messages=conversation,
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            max_token=max_token,
+            tools=[]  # No tools
+        )
+        if '[FinalAnswer]' in output:
+            output = output.split('[FinalAnswer]')[-1].strip()
+        logger.debug("Quick summary output: %s", output[:100])
+        return output
+
     def run_self_agent(self, message: str, temperature: float, max_new_tokens: int, max_token: int):
         logger.debug("Starting self agent")
         conversation = self.set_system_prompt([], self.self_prompt)
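The new `run_quick_summary` makes a single tool-free `llm_infer` pass and strips everything before the last `[FinalAnswer]` marker. A hypothetical call site, assuming an already-constructed `TxAgent` instance named `agent`; the argument values simply mirror the defaults added above:

```python
# Hypothetical usage; `agent` is an initialized TxAgent instance.
summary = agent.run_quick_summary(
    "Patient on Seroquel 50 mg nightly; no psychiatric diagnosis on file.",
    temperature=0.1, max_new_tokens=256, max_token=1024)
print(summary)  # one paragraph, or "No missed diagnoses identified"
```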
@@ -565,30 +585,19 @@ Summarize the function responses in one sentence with all necessary information.
         logger.debug("Updated parameters: %s", updated_attributes)
         return updated_attributes
 
-    def
-        logger.debug("
-        if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
-            return
-
-        # Check if message contains clinical findings
-        clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
-        has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
-        call_agent = call_agent and not has_clinical_data  # Disable CallAgent for clinical data
-
+    async def run_background_report(self, message: str, history: list, temperature: float,
+                                    max_new_tokens: int, max_token: int, call_agent: bool,
+                                    conversation: gr.State, max_round: int, seed: int,
+                                    call_agent_level: int, report_path: str):
+        """Run detailed report generation in the background and save to file"""
+        logger.debug("Starting background report for message: %s", message[:100])
+        combined_response = ""
+        history_copy = history.copy()
+
         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
             call_agent, call_agent_level, message)
-        conversation = self.initialize_conversation(
-        history = []
+        conversation = self.initialize_conversation(message, conversation, history_copy)
+
         next_round = True
         current_round = 0
         enable_summary = False
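The background coroutine works on `history.copy()` so the chat history rendered by Gradio is not mutated from the worker thread, and it writes its output to `report_path`. How that path is produced is outside this diff; a hypothetical helper, with the naming scheme assumed for illustration:

```python
# Hypothetical report_path construction; not part of this commit.
import os
import time

def make_report_path(base_dir: str = "reports") -> str:
    os.makedirs(base_dir, exist_ok=True)
    return os.path.join(base_dir, f"detailed_report_{int(time.time())}.txt")
```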
@@ -603,24 +612,17 @@ Summarize the function responses in one sentence with all necessary information.
                 current_round += 1
                 last_outputs = []
                 if last_outputs:
-                    function_call_messages, picked_tools_prompt, special_tool_call,
+                    function_call_messages, picked_tools_prompt, special_tool_call, _ = yield from self.run_function_call_stream(
                         last_outputs, return_message=True, existing_tools_prompt=picked_tools_prompt,
                         message_for_call_agent=message, call_agent=call_agent,
-                        call_agent_level=call_agent_level, temperature=temperature
+                        call_agent_level=call_agent_level, temperature=temperature,
+                        return_gradio_history=False)
+
                     if special_tool_call == 'Finish':
-                        yield history
                         next_round = False
                         conversation.extend(function_call_messages)
-
-                    if special_tool_call in ['RequireClarification', 'DirectResponse']:
-                        last_msg = history[-1] if history else ChatMessage(role="assistant", content="Response needed.")
-                        history.append(ChatMessage(role="assistant", content=last_msg.content))
-                        yield history
-                        next_round = False
-                        return last_msg.content
+                        combined_response += function_call_messages[0]['content'] + "\n"
+                        break
 
                 if (self.enable_summary or token_overflow) and not call_agent:
                     enable_summary = True
@@ -629,10 +631,11 @@ Summarize the function responses in one sentence with all necessary information.
 
                 if function_call_messages:
                     conversation.extend(function_call_messages)
-
+                    combined_response += tool_result_format(function_call_messages) + "\n"
                 else:
                     next_round = False
-
+                    combined_response += ''.join(last_outputs).replace("</s>", "") + "\n"
+                    break
 
                 if self.enable_checker:
                     good_status, wrong_info = checker.check_conversation()
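`tool_result_format` is imported from the package's utilities and only called here. The stand-in below sketches the shape of what gets appended to `combined_response`; it is an assumption for illustration, not the repo's implementation:

```python
# Assumed stand-in for the imported tool_result_format; illustrative only.
def tool_result_format(function_call_messages: list) -> str:
    blocks = [m.get("content", "") for m in function_call_messages
              if m.get("role") == "tool"]
    return "\n".join(blocks)
```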
@@ -640,8 +643,7 @@ Summarize the function responses in one sentence with all necessary information.
                         logger.warning("Checker error: %s", wrong_info)
                         break
 
-
-                tools = [] if has_clinical_data else picked_tools_prompt
+                tools = picked_tools_prompt
                 last_outputs_str, token_overflow = self.llm_infer(
                     messages=conversation, temperature=temperature, tools=tools,
                     max_new_tokens=max_new_tokens, max_token=max_token, seed=seed, check_token_status=True)
@@ -650,54 +652,74 @@ Summarize the function responses in one sentence with all necessary information.
                     if self.force_finish:
                         last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                             conversation, temperature, max_new_tokens, max_token)
-                        history.append(ChatMessage(role="assistant", content=error_msg))
-                        yield history
-                        return error_msg
-
-                last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
-                for msg in history:
-                    if msg.metadata:
-                        msg.metadata['status'] = 'done'
-
-                if '[FinalAnswer]' in last_thought:
-                    parts = last_thought.split('[FinalAnswer]', 1)
-                    final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
-                    history.append(ChatMessage(role="assistant", content=final_thought.strip()))
-                    yield history
-                    history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
-                    yield history
-                else:
-                    history.append(ChatMessage(role="assistant", content=last_thought))
-                    yield history
+                        combined_response += last_outputs_str + "\n"
+                        break
+                    combined_response += "Token limit exceeded.\n"
+                    break
 
+                combined_response += last_outputs_str + "\n"
                 last_outputs.append(last_outputs_str)
 
             if next_round and self.force_finish:
                 last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                     conversation, temperature, max_new_tokens, max_token)
+                combined_response += last_outputs_str + "\n"
+
+            # Save report
+            try:
+                with open(report_path, "w", encoding="utf-8") as f:
+                    f.write(combined_response)
+                logger.info("Detailed report saved to %s", report_path)
+            except Exception as e:
+                logger.error("Failed to save report: %s", e)
 
         except Exception as e:
-            logger.error("
+            logger.error("Background report error: %s", e)
+            combined_response += f"Error: {e}\n"
+            with open(report_path, "w", encoding="utf-8") as f:
+                f.write(combined_response)
+
+        finally:
+            torch.cuda.empty_cache()
+            gc.collect()
+
+    def run_gradio_chat(self, message: str, history: list, temperature: float,
+                        max_new_tokens: int, max_token: int, call_agent: bool,
+                        conversation: gr.State, max_round: int = 3, seed: int = None,
+                        call_agent_level: int = 0, sub_agent_task: str = None,
+                        uploaded_files: list = None, report_path: str = None):
+        logger.debug("Chat started, message: %s", message[:100])
+        if not message or len(message.strip()) < 5:
+            yield "Please provide a valid message or upload files to analyze."
+            return
+
+        if message.startswith("[\U0001f9f0 Tool_RAG") or message.startswith("⚒️"):
+            return
+
+        clinical_keywords = ['medication', 'symptom', 'evaluation', 'diagnosis']
+        has_clinical_data = any(keyword in message.lower() for keyword in clinical_keywords)
+        call_agent = call_agent and not has_clinical_data
+
+        # Generate quick summary
+        quick_summary = self.run_quick_summary(
+            message, temperature=temperature, max_new_tokens=256, max_token=1024)
+        history.append(ChatMessage(role="assistant", content=f"**Quick Summary:**\n{quick_summary}"))
+        yield history
+
+        # Start background report generation
+        if report_path:
+            loop = asyncio.get_event_loop()
+            threading.Thread(
+                target=lambda: loop.run_until_complete(
+                    self.run_background_report(
+                        message, history, temperature, max_new_tokens, max_token, call_agent,
                        conversation, max_round, seed, call_agent_level, report_path
+                    )
+                ),
+                daemon=True
+            ).start()
+            history.append(ChatMessage(
+                role="assistant",
+                content="Generating detailed report in the background. Download will be available when ready."
+            ))
+            yield history
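Two caveats about the launch pattern added at the end of this hunk. First, as committed, `run_background_report` is declared `async def` but uses `yield from` (new line 615), which Python rejects inside async functions, so one of the two would need to change for the module to import. Second, `asyncio.get_event_loop()` plus `loop.run_until_complete` from a worker thread raises `RuntimeError` if that loop is already running in the main thread, which is typical under Gradio. A sketch of a launch that avoids sharing the caller's loop, assuming `run_background_report` is a genuine coroutine with the signature added above:

```python
# Sketch: give the worker thread its own event loop instead of reusing the
# caller's. asyncio.run creates a fresh loop, runs the coroutine, and closes
# the loop when it finishes.
import asyncio
import threading

def start_background_report(agent, *args):
    threading.Thread(
        target=lambda: asyncio.run(agent.run_background_report(*args)),
        daemon=True,
    ).start()
```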