Update src/txagent/txagent.py
src/txagent/txagent.py  CHANGED  (+32 -21)
@@ -264,7 +264,7 @@ class TxAgent:
             )
             call_result = self.run_multistep_agent(
                 full_message, temperature=temperature,
-                max_new_tokens=1024, max_token=
+                max_new_tokens=1024, max_token=8192,
                 call_agent=False, call_agent_level=call_agent_level)
             if call_result is None:
                 call_result = "⚠️ No content returned from sub-agent."
@@ -286,13 +286,13 @@ class TxAgent:
         else:
             call_results.append({
                 "role": "tool",
-                "content": json.dumps({"content": "
+                "content": json.dumps({"content": "No valid tool call detected; proceeding with analysis."})
             })
 
         revised_messages = [{
             "role": "assistant",
-            "content": message.strip(),
-            "tool_calls": json.dumps(function_call_json)
+            "content": message.strip() if message else "Processing...",
+            "tool_calls": json.dumps(function_call_json) if function_call_json else None
         }] + call_results
         logger.debug("Function call completed, returning %d messages", len(revised_messages))
         return revised_messages, existing_tools_prompt, special_tool_call
@@ -317,11 +317,11 @@ class TxAgent:
             logger.warning("No valid function call JSON extracted")
             call_results.append({
                 "role": "tool",
-                "content": json.dumps({"content": "
+                "content": json.dumps({"content": "No tool call detected; continuing analysis."})
             })
             if return_gradio_history:
-                gradio_history.append(
-            yield
+                gradio_history.append({"role": "assistant", "content": "No specific tool call identified. Proceeding with medical record analysis."})
+            yield [{"role": "assistant", "content": "Processing..."}], existing_tools_prompt or [], special_tool_call, gradio_history
             return
 
         if isinstance(function_call_json, list):
@@ -352,9 +352,9 @@ class TxAgent:
                     str(solution_plan)
                 )
                 sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
-                sub_result = yield from self.
+                sub_result = yield from self.run_gradio_chat(
                     full_message, history=[], temperature=temperature,
-                    max_new_tokens=1024, max_token=
+                    max_new_tokens=1024, max_token=8192,
                     call_agent=False, call_agent_level=call_agent_level,
                     conversation=None,
                     sub_agent_task=sub_agent_task)
@@ -375,12 +375,12 @@ class TxAgent:
 
             if return_gradio_history and function_call_json[i]["name"] != 'Finish':
                 metadata = {"title": f"⚒️ {function_call_json[i]['name']}", "log": str(function_call_json[i]['arguments'])}
-                gradio_history.append(
+                gradio_history.append({"role": "assistant", "content": str(call_result), "metadata": metadata})
 
         revised_messages = [{
             "role": "assistant",
-            "content": message.strip(),
-            "tool_calls": json.dumps(function_call_json)
+            "content": message.strip() if message else "Processing...",
+            "tool_calls": json.dumps(function_call_json) if function_call_json else None
         }] + call_results
 
         if return_gradio_history:
@@ -523,7 +523,7 @@ class TxAgent:
 
     def llm_infer(self, messages, temperature=0.1, tools=None,
                   output_begin_string=None, max_new_tokens=2048,
-                  max_token=
+                  max_token=8192, skip_special_tokens=True,
                   model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
 
         logger.debug("Running LLM inference with %d messages", len(messages))
@@ -565,6 +565,16 @@ class TxAgent:
             sampling_params=sampling_params,
         )
         output = output[0].outputs[0].text
+        # Deduplicate repetitive output
+        if output:
+            lines = output.split('\n')
+            seen = set()
+            deduped_lines = []
+            for line in lines:
+                if line.strip() and line not in seen:
+                    seen.add(line)
+                    deduped_lines.append(line)
+            output = '\n'.join(deduped_lines)
         logger.debug("LLM output: %s", output[:50])
         if check_token_status and max_token is not None:
             return output, token_overflow
@@ -719,7 +729,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
             function_response=function_response,
             temperature=0.1,
             max_new_tokens=1024,
-            max_token=
+            max_token=8192
         )
 
         input_list.insert(
@@ -748,7 +758,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
             function_response=function_response,
             temperature=0.1,
             max_new_tokens=1024,
-            max_token=
+            max_token=8192
         )
 
         tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
@@ -842,7 +852,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
 
             if last_outputs:
                 function_call_result = yield from self.run_function_call_stream(
-                    last_outputs, return_message=True,
+                    last_outputs[0], return_message=True,
                     existing_tools_prompt=picked_tools_prompt,
                     message_for_call_agent=message,
                     call_agent=call_agent,
@@ -851,17 +861,18 @@ Generate **one summarized sentence** about "function calls' responses" with nece
 
                 if not function_call_result:
                     logger.warning("Empty result from run_function_call_stream")
-                    history.append({"role": "assistant", "content": "Error:
+                    history.append({"role": "assistant", "content": "Error: Unable to process tool response. Continuing analysis."})
                     yield history
-
+                    last_outputs = []
+                    continue
 
                 function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = function_call_result
 
-                # Convert
+                # Convert history to dicts and deduplicate
                 unique_history = []
                 seen_contents = set()
                 for msg in current_gradio_history:
-                    content = msg.content
+                    content = msg["content"] if isinstance(msg, dict) else msg.content
                     if content not in seen_contents:
                         unique_history.append({"role": "assistant", "content": content})
                         seen_contents.add(content)
@@ -915,7 +926,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
             tools=picked_tools_prompt,
             skip_special_tokens=False,
             max_new_tokens=max_new_tokens,
-            max_token=
+            max_token=8192,
             seed=seed,
             check_token_status=True)
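A few reviewer notes on the changed hot spots, with minimal sketches. First, `llm_infer`: with `check_token_status=True` and a non-None `max_token`, it now returns `(output, token_overflow)` rather than a bare string. A hedged caller sketch, assuming an already-instantiated `TxAgent` passed in as `agent` (the `infer_with_budget` wrapper name is ours, not part of the commit):

```python
from typing import Dict, List, Tuple


def infer_with_budget(agent, messages: List[Dict[str, str]]) -> Tuple[str, bool]:
    """Call agent.llm_infer with the new token budget and overflow check.

    `agent` is assumed to be an instantiated TxAgent; only the llm_infer
    signature shown in this diff is relied upon.
    """
    result = agent.llm_infer(
        messages,
        temperature=0.1,
        max_new_tokens=1024,
        max_token=8192,           # default budget introduced by this commit
        check_token_status=True,  # request the (output, token_overflow) form
    )
    output, token_overflow = result
    return output, token_overflow
```

Callers that pass `check_token_status=False` still get a single string back, so call sites should branch on which form they requested.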
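Second, the block added after `output = output[0].outputs[0].text` filters repeated lines out of the raw model output. A standalone sketch of the same logic, convenient for testing it in isolation (the `dedupe_lines` helper name is ours, not part of the commit):

```python
def dedupe_lines(text: str) -> str:
    """Drop blank lines and any repeated line, keeping first occurrences in order."""
    seen = set()
    kept = []
    for line in text.split("\n"):
        if line.strip() and line not in seen:
            seen.add(line)
            kept.append(line)
    return "\n".join(kept)


if __name__ == "__main__":
    sample = "Step 1: check interactions\nStep 1: check interactions\n\nStep 2: summarize"
    print(dedupe_lines(sample))
    # -> "Step 1: check interactions\nStep 2: summarize"
```

Note the trade-off: because every repeated line is dropped regardless of distance, legitimately repeated lines (for example identical list items far apart) are also collapsed; the commit accepts that in exchange for suppressing degenerate looping output.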
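Third, the gradio-history dedup in `run_gradio_chat` now tolerates both plain dicts and ChatMessage-like objects (`msg["content"]` vs `msg.content`). A self-contained sketch of that normalization pattern (the `normalize_history` helper and the `Msg` stand-in class are illustrative, not from the commit):

```python
def normalize_history(entries):
    """Reduce mixed history entries to assistant dicts, first occurrence of each content wins."""
    unique, seen = [], set()
    for msg in entries:
        # entries may be plain dicts or objects exposing a .content attribute
        content = msg["content"] if isinstance(msg, dict) else msg.content
        if content not in seen:
            unique.append({"role": "assistant", "content": content})
            seen.add(content)
    return unique


if __name__ == "__main__":
    class Msg:  # stand-in for a ChatMessage-like object
        def __init__(self, content):
            self.content = content

    mixed = [{"role": "assistant", "content": "hit A"}, Msg("hit A"), Msg("hit B")]
    print(normalize_history(mixed))
    # -> [{'role': 'assistant', 'content': 'hit A'}, {'role': 'assistant', 'content': 'hit B'}]
```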