Update src/txagent/txagent.py
src/txagent/txagent.py  +39 −75
@@ -12,18 +12,17 @@ from tooluniverse import ToolUniverse
 from gradio import ChatMessage
 from .toolrag import ToolRAGModel
 import torch
-# near the top of txagent.py
 import logging
+
 logger = logging.getLogger(__name__)
-logging.basicConfig(level=logging.
+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
 from .utils import NoRepeatSentenceProcessor, ReasoningTraceChecker, tool_result_format
 
-
 class TxAgent:
     def __init__(self, model_name,
                  rag_model_name,
-                 tool_files_dict=None,
+                 tool_files_dict=None,
                  enable_finish=True,
                  enable_rag=True,
                  enable_summary=False,
@@ -47,10 +46,11 @@ class TxAgent:
         self.model = None
         self.rag_model = ToolRAGModel(rag_model_name)
         self.tooluniverse = None
-
-
+        self.prompt_multi_step = ("You are a helpful assistant that solves problems through detailed, step-by-step reasoning "
+                                  "and actions based on your reasoning. Provide comprehensive and clinically precise responses, "
+                                  "including specific diagnoses, tools, and actionable recommendations when analyzing medical data.")
         self.self_prompt = "Strictly follow the instruction."
-        self.chat_prompt = "You are helpful assistant to chat with the user."
+        self.chat_prompt = "You are a helpful assistant to chat with the user."
         self.enable_finish = enable_finish
         self.enable_rag = enable_rag
         self.enable_summary = enable_summary
@@ -145,7 +145,7 @@ class TxAgent:
                  existing_tools_prompt=[],
                  rag_num=5,
                  return_call_result=False):
-        extra_factor = 30
+        extra_factor = 30
         if picked_tool_names is None:
             assert picked_tool_names is not None or message is not None
             picked_tool_names = self.rag_infer(
@@ -270,7 +270,6 @@ class TxAgent:
                 "tool_calls": json.dumps(function_call_json)
             }] + call_results
 
-        # Yield the final result.
         return revised_messages, existing_tools_prompt, special_tool_call
 
     def run_function_call_stream(self, fcall_str,
@@ -364,11 +363,10 @@ class TxAgent:
         else:
             return revised_messages, existing_tools_prompt, special_tool_call
 
-
     def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
-        if conversation[-1]['role'] == '
+        if conversation[-1]['role'] == 'assistant':
             conversation.append(
-                {'role': 'tool', 'content': 'Errors
+                {'role': 'tool', 'content': 'Errors happened during the function call, please come up with the final answer with the current information.'})
         finish_tools_prompt = self.add_finish_tools([])
 
         last_outputs_str = self.llm_infer(messages=conversation,
@@ -387,15 +385,6 @@ class TxAgent:
                         max_round: int = 20,
                         call_agent=False,
                         call_agent_level=0) -> str:
-        """
-        Generate a streaming response using the llama3-8b model.
-        Args:
-            message (str): The input message.
-            temperature (float): The temperature for generating the response.
-            max_new_tokens (int): The maximum number of new tokens to generate.
-        Returns:
-            str: The generated response.
-        """
         print("\033[1;32;40mstart\033[0m")
         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
             call_agent, call_agent_level, message)
@@ -454,7 +443,6 @@ class TxAgent:
                 if self.enable_checker:
                     good_status, wrong_info = checker.check_conversation()
                     if not good_status:
-                        next_round = False
                         print(
                             "Internal error in reasoning: " + wrong_info)
                         break
@@ -489,7 +477,6 @@ class TxAgent:
             return None
 
     def build_logits_processor(self, messages, llm):
-        # Use the tokenizer from the LLM instance.
         tokenizer = llm.get_tokenizer()
         if self.avoid_repeat and len(messages) > 2:
             assistant_messages = []
@@ -516,7 +503,6 @@ class TxAgent:
         sampling_params = SamplingParams(
             temperature=temperature,
             max_tokens=max_new_tokens,
-
             seed=seed if seed is not None else self.seed,
         )
 
@@ -527,18 +513,23 @@ class TxAgent:
 
         if check_token_status and max_token is not None:
             token_overflow = False
-
-
-            if
-
+            input_tokens = self.tokenizer.encode(prompt, return_tensors="pt")[0]
+            num_input_tokens = len(input_tokens)
+            if num_input_tokens > max_token:
+                logger.info(f"Number of input tokens before inference: {num_input_tokens}")
+                logger.info("The number of tokens exceeds the maximum limit!!!!")
+                max_prompt_tokens = max_token - max_new_tokens - 100
+                if max_prompt_tokens > 0:
+                    truncated_input = self.tokenizer.decode(input_tokens[:max_prompt_tokens])
+                    prompt = truncated_input
+                    logger.info(f"Prompt truncated to {len(self.tokenizer.encode(prompt, return_tensors='pt')[0])} tokens")
+                    token_overflow = True
+                else:
+                    logger.warning("Max prompt tokens too small, cannot truncate effectively")
                 torch.cuda.empty_cache()
                 gc.collect()
-                print("Number of input tokens before inference:",
-                      num_input_tokens)
-                logger.info(
-                    "The number of tokens exceeds the maximum limit!!!!")
-                token_overflow = True
                 return None, token_overflow
+
         output = model.generate(
             prompt,
             sampling_params=sampling_params,
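The guard added to `llm_infer` above counts the prompt's tokens and, on overflow, decodes only the first `max_token - max_new_tokens - 100` tokens back into text before the branch ends with `return None, token_overflow`. The counting-and-truncation step can be read as a small standalone helper; the sketch below is illustrative only, with hypothetical names (`truncate_prompt`, `reserve`) and any Hugging Face tokenizer assumed.

```python
def truncate_prompt(prompt, tokenizer, max_token, max_new_tokens, reserve=100):
    """Hypothetical helper mirroring the overflow guard sketched in the hunk above.

    Returns (possibly truncated prompt, overflow flag).
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")[0]
    if len(input_ids) <= max_token:
        return prompt, False

    # Leave room for the generation budget plus a small safety margin, as in the diff.
    max_prompt_tokens = max_token - max_new_tokens - reserve
    if max_prompt_tokens <= 0:
        # Nothing sensible can be kept; the caller has to handle this case.
        return prompt, True

    truncated = tokenizer.decode(input_ids[:max_prompt_tokens], skip_special_tokens=True)
    return truncated, True

# Usage (tokenizer name is illustrative only):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   prompt, overflow = truncate_prompt(long_prompt, tok, max_token=2048, max_new_tokens=512)
```

Note that in the hunk itself the overflow branch still returns `None, token_overflow` after assigning the truncated text back to `prompt`, so generation does not proceed from the shortened prompt within this call.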
@@ -641,17 +632,6 @@ Generate **one summarized sentence** about "function calls' responses" with nece
         return output
 
     def function_result_summary(self, input_list, status, enable_summary):
-        """
-        Processes the input list, extracting information from sequences of 'user', 'tool', 'assistant' roles.
-        Supports 'length' and 'step' modes, and skips the last 'k' groups.
-        Parameters:
-            input_list (list): A list of dictionaries containing role and other information.
-            summary_skip_last_k (int): Number of groups to skip from the end. Defaults to 0.
-            summary_context_length (int): The context length threshold for the 'length' mode.
-            last_processed_index (tuple or int): The last processed index.
-        Returns:
-            list: A list of extracted information from valid sequences.
-        """
         if 'tool_call_step' not in status:
             status['tool_call_step'] = 0
 
@@ -748,15 +728,11 @@ Generate **one summarized sentence** about "function calls' responses" with nece
 
         return status
 
-    # Following are Gradio related functions
-
-    # General update method that accepts any new arguments through kwargs
     def update_parameters(self, **kwargs):
         for key, value in kwargs.items():
             if hasattr(self, key):
                 setattr(self, key, value)
 
-        # Return the updated attributes
         updated_attributes = {key: value for key,
                               value in kwargs.items() if hasattr(self, key)}
         return updated_attributes
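`update_parameters` above applies keyword arguments only to attributes that already exist on the instance and returns the subset that matched. A toy usage sketch (the class and values are illustrative, not part of the diff):

```python
class Configurable:
    """Stand-in for the update_parameters pattern kept in the hunk above."""

    def __init__(self):
        self.temperature = 0.3
        self.enable_rag = True

    def update_parameters(self, **kwargs):
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
        # Report only the keys that matched existing attributes.
        return {key: value for key, value in kwargs.items() if hasattr(self, key)}


cfg = Configurable()
print(cfg.update_parameters(temperature=0.7, unknown_flag=True))
# {'temperature': 0.7} -- unknown_flag is silently ignored
```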
@@ -795,7 +771,6 @@ Generate **one summarized sentence** about "function calls' responses" with nece
             return ""
 
         outputs = []
-        outputs_str = ''
         last_outputs = []
 
         picked_tools_prompt, call_agent_level = self.initialize_tools_prompt(
@@ -867,7 +842,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
                 if self.enable_checker:
                     good_status, wrong_info = checker.check_conversation()
                     if not good_status:
-
+                        logger.warning(f"Checker flagged reasoning error: {wrong_info}")
                         break
 
                 last_outputs = []
@@ -884,18 +859,11 @@ Generate **one summarized sentence** about "function calls' responses" with nece
                 logger.debug(f"llm_infer output: {last_outputs_str[:100] if last_outputs_str else None}, token_overflow: {token_overflow}")
 
                 if last_outputs_str is None:
-                    logger.warning("llm_infer returned None due to token overflow")
-
-
-
-
-                        yield history
-                        return last_outputs_str
-                    else:
-                        error_msg = "Token limit exceeded. Please reduce input size or increase max_token."
-                        history.append(ChatMessage(role="assistant", content=error_msg))
-                        yield history
-                        return error_msg
+                    logger.warning("llm_infer returned None, likely due to token overflow")
+                    error_msg = "Error: Unable to generate response due to token limit. Please reduce input size."
+                    history.append(ChatMessage(role="assistant", content=error_msg))
+                    yield history
+                    return error_msg
 
                 last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
 
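The replacement error path above follows the generator protocol used throughout `run_gradio_chat`: append an assistant message to `history`, `yield` the updated history so the UI can refresh, then `return`, which ends the generator (the returned string only surfaces as `StopIteration.value`). A stripped-down sketch without Gradio, using plain dicts in place of `ChatMessage` and a hypothetical word count standing in for the token check:

```python
def chat_stream(prompt, word_limit=10):
    history = []
    if len(prompt.split()) > word_limit:    # stand-in for "last_outputs_str is None"
        error_msg = "Error: Unable to generate response due to token limit. Please reduce input size."
        history.append({"role": "assistant", "content": error_msg})
        yield history                        # the UI sees the error message
        return error_msg                     # ends the stream for the caller
    history.append({"role": "assistant", "content": f"Echo: {prompt}"})
    yield history


for h in chat_stream("word " * 50):
    print(h[-1]["content"])
# prints the token-limit error once; iteration then stops
```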
@@ -905,14 +873,12 @@ Generate **one summarized sentence** about "function calls' responses" with nece
 
                 if '[FinalAnswer]' in last_thought:
                     parts = last_thought.split('[FinalAnswer]', 1)
-                    if len(parts) == 2
-                        final_thought, final_answer = parts
-                    else:
-                        final_thought, final_answer = last_thought, ""
+                    final_thought, final_answer = parts if len(parts) == 2 else (last_thought, "")
                     history.append(ChatMessage(role="assistant", content=final_thought.strip()))
                     yield history
                     history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
                     yield history
+                    next_round = False
                 else:
                     history.append(ChatMessage(role="assistant", content=last_thought))
                     yield history
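The one-liner introduced above is a defensive parse of the `[FinalAnswer]` marker: `split('[FinalAnswer]', 1)` yields two parts when the marker is present, and the conditional expression falls back to treating the whole text as the thought otherwise. A small illustration with made-up strings:

```python
def split_final_answer(text):
    """Split model output on the [FinalAnswer] marker, tolerating its absence."""
    parts = text.split('[FinalAnswer]', 1)
    final_thought, final_answer = parts if len(parts) == 2 else (text, "")
    return final_thought.strip(), final_answer.strip()


print(split_final_answer("Weighing the options... [FinalAnswer] Start with 5 mg daily."))
# ('Weighing the options...', 'Start with 5 mg daily.')
print(split_final_answer("No marker in this output"))
# ('No marker in this output', '')
```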
@@ -920,15 +886,13 @@ Generate **one summarized sentence** about "function calls' responses" with nece
                 last_outputs.append(last_outputs_str)
 
             if next_round:
+                logger.info("Max rounds reached, forcing finish")
                 if self.force_finish:
                     last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
                         conversation, temperature, max_new_tokens, max_token)
                     if '[FinalAnswer]' in last_outputs_str:
                         parts = last_outputs_str.split('[FinalAnswer]', 1)
-                        if len(parts) == 2
-                            final_thought, final_answer = parts
-                        else:
-                            final_thought, final_answer = last_outputs_str, ""
+                        final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
                         history.append(ChatMessage(role="assistant", content=final_thought.strip()))
                         yield history
                         history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))
@@ -937,7 +901,10 @@ Generate **one summarized sentence** about "function calls' responses" with nece
                         history.append(ChatMessage(role="assistant", content=last_outputs_str.strip()))
                         yield history
                 else:
-
+                    error_msg = "The number of reasoning rounds exceeded the limit."
+                    history.append(ChatMessage(role="assistant", content=error_msg))
+                    yield history
+                    return error_msg
 
             except Exception as e:
                 logger.error(f"Exception in run_gradio_chat: {e}", exc_info=True)
@@ -949,10 +916,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
                     conversation, temperature, max_new_tokens, max_token)
                 if '[FinalAnswer]' in last_outputs_str:
                     parts = last_outputs_str.split('[FinalAnswer]', 1)
-                    if len(parts) == 2
-                        final_thought, final_answer = parts
-                    else:
-                        final_thought, final_answer = last_outputs_str, ""
+                    final_thought, final_answer = parts if len(parts) == 2 else (last_outputs_str, "")
                     history.append(ChatMessage(role="assistant", content=final_thought.strip()))
                     yield history
                     history.append(ChatMessage(role="assistant", content="**🧠 Final Analysis:**\n" + final_answer.strip()))