Update src/txagent/txagent.py
src/txagent/txagent.py (+32 -21) CHANGED
@@ -264,7 +264,7 @@ class TxAgent:
     )
     call_result = self.run_multistep_agent(
         full_message, temperature=temperature,
-        max_new_tokens=1024, max_token=
+        max_new_tokens=1024, max_token=8192,
         call_agent=False, call_agent_level=call_agent_level)
     if call_result is None:
         call_result = "⚠️ No content returned from sub-agent."
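The recurring change in this commit is an explicit max_token=8192 context budget threaded into every generation call (the sub-agent call above, llm_infer, and the summarization helpers further down), alongside the existing check_token_status flag. As a rough illustration of what such a budget check amounts to (count_tokens and check_budget are hypothetical names, not functions from this file; a real version would use the model tokenizer):

    def count_tokens(text: str) -> int:
        # Stand-in for len(tokenizer.encode(text)); whitespace split is only for the sketch.
        return len(text.split())

    def check_budget(prompt: str, max_token: int = 8192, max_new_tokens: int = 1024):
        """Return (fits, token_overflow) for a prompt against a fixed context window."""
        used = count_tokens(prompt)
        token_overflow = used + max_new_tokens > max_token
        return not token_overflow, token_overflow

    fits, token_overflow = check_budget("Main assistant message: ...", max_token=8192)
    print(fits, token_overflow)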
@@ -286,13 +286,13 @@
     else:
         call_results.append({
             "role": "tool",
-            "content": json.dumps({"content": "
+            "content": json.dumps({"content": "No valid tool call detected; proceeding with analysis."})
         })

     revised_messages = [{
         "role": "assistant",
-        "content": message.strip(),
-        "tool_calls": json.dumps(function_call_json)
+        "content": message.strip() if message else "Processing...",
+        "tool_calls": json.dumps(function_call_json) if function_call_json else None
     }] + call_results
     logger.debug("Function call completed, returning %d messages", len(revised_messages))
     return revised_messages, existing_tools_prompt, special_tool_call
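The two conditional expressions above guard against a None or empty message (message.strip() raises AttributeError on None) and against serializing a missing tool call as the JSON string "null". A self-contained sketch of the same defensive construction (build_assistant_message is an illustrative helper, not part of the repo):

    import json

    def build_assistant_message(message, function_call_json):
        # message can be None/empty when the model emitted no free text;
        # function_call_json can be None when no tool call was parsed.
        return {
            "role": "assistant",
            "content": message.strip() if message else "Processing...",
            "tool_calls": json.dumps(function_call_json) if function_call_json else None,
        }

    print(build_assistant_message(None, None))
    print(build_assistant_message("  Dosage check done.  ", [{"name": "Finish", "arguments": {}}]))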
@@ -317,11 +317,11 @@
         logger.warning("No valid function call JSON extracted")
         call_results.append({
             "role": "tool",
-            "content": json.dumps({"content": "
+            "content": json.dumps({"content": "No tool call detected; continuing analysis."})
         })
         if return_gradio_history:
-            gradio_history.append(
-            yield
+            gradio_history.append({"role": "assistant", "content": "No specific tool call identified. Proceeding with medical record analysis."})
+            yield [{"role": "assistant", "content": "Processing..."}], existing_tools_prompt or [], special_tool_call, gradio_history
         return

     if isinstance(function_call_json, list):
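This branch is the fallback for when no tool-call JSON can be parsed from the model output: a placeholder tool message is recorded, the Gradio history gets a plain explanation, and the generator yields once and returns. A rough sketch of the kind of extraction that can come up empty (illustrative only; the repo's actual parser is more involved):

    import json
    import re

    def extract_tool_call(model_output: str):
        # Grab the first {...} span and try to parse it; return None on failure.
        match = re.search(r"\{.*\}", model_output, re.DOTALL)
        if not match:
            return None
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            return None

    print(extract_tool_call("I will just answer directly."))  # None -> fallback branch above
    print(extract_tool_call('{"name": "get_drug_info", "arguments": {"drug": "aspirin"}}'))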
@@ -352,9 +352,9 @@
         str(solution_plan)
     )
     sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
-    sub_result = yield from self.
+    sub_result = yield from self.run_gradio_chat(
         full_message, history=[], temperature=temperature,
-        max_new_tokens=1024, max_token=
+        max_new_tokens=1024, max_token=8192,
         call_agent=False, call_agent_level=call_agent_level,
         conversation=None,
         sub_agent_task=sub_agent_task)
@@ -375,12 +375,12 @@

     if return_gradio_history and function_call_json[i]["name"] != 'Finish':
         metadata = {"title": f"⚒️ {function_call_json[i]['name']}", "log": str(function_call_json[i]['arguments'])}
-        gradio_history.append(
+        gradio_history.append({"role": "assistant", "content": str(call_result), "metadata": metadata})

     revised_messages = [{
         "role": "assistant",
-        "content": message.strip(),
-        "tool_calls": json.dumps(function_call_json)
+        "content": message.strip() if message else "Processing...",
+        "tool_calls": json.dumps(function_call_json) if function_call_json else None
     }] + call_results

     if return_gradio_history:
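The entries appended to gradio_history here are messages-style chat dicts; the metadata dict with a "title" key is what lets a Gradio Chatbot in messages mode render the tool call as a titled, collapsible block (that rendering depends on the Gradio version and is an assumption here, not something the diff shows). Shape of one such entry, with made-up values:

    entry = {
        "role": "assistant",
        "content": "Found 3 interaction warnings for warfarin.",
        "metadata": {
            "title": "⚒️ get_drug_interactions",   # shown as the block title
            "log": "{'drug': 'warfarin'}",          # the raw arguments string
        },
    }
    print(entry["metadata"]["title"])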
@@ -523,7 +523,7 @@

     def llm_infer(self, messages, temperature=0.1, tools=None,
                   output_begin_string=None, max_new_tokens=2048,
-                  max_token=
+                  max_token=8192, skip_special_tokens=True,
                   model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):

         logger.debug("Running LLM inference with %d messages", len(messages))
@@ -565,6 +565,16 @@
         sampling_params=sampling_params,
     )
     output = output[0].outputs[0].text
+    # Deduplicate repetitive output
+    if output:
+        lines = output.split('\n')
+        seen = set()
+        deduped_lines = []
+        for line in lines:
+            if line.strip() and line not in seen:
+                seen.add(line)
+                deduped_lines.append(line)
+        output = '\n'.join(deduped_lines)
     logger.debug("LLM output: %s", output[:50])
     if check_token_status and max_token is not None:
         return output, token_overflow
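The new block drops exact repeats of whole lines from the model output, which suppresses degenerate looping but will also collapse lines that were legitimately identical (repeated list items, for example). The same logic as a reusable, order-preserving helper (dedupe_lines is illustrative, not a function in the repo):

    def dedupe_lines(text: str) -> str:
        # Keep the first occurrence of each non-blank line, in order.
        seen = set()
        kept = []
        for line in text.split("\n"):
            if line.strip() and line not in seen:
                seen.add(line)
                kept.append(line)
        return "\n".join(kept)

    print(dedupe_lines("Take 5 mg daily.\nTake 5 mg daily.\nMonitor INR weekly."))
    # -> Take 5 mg daily.
    #    Monitor INR weekly.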
@@ -719,7 +729,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
         function_response=function_response,
         temperature=0.1,
         max_new_tokens=1024,
-        max_token=
+        max_token=8192
     )

     input_list.insert(
@@ -748,7 +758,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
         function_response=function_response,
         temperature=0.1,
         max_new_tokens=1024,
-        max_token=
+        max_token=8192
     )

     tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
@@ -842,7 +852,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece

     if last_outputs:
         function_call_result = yield from self.run_function_call_stream(
-            last_outputs, return_message=True,
+            last_outputs[0], return_message=True,
             existing_tools_prompt=picked_tools_prompt,
             message_for_call_agent=message,
             call_agent=call_agent,
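last_outputs appears to be a list of generated strings, so passing last_outputs[0] hands run_function_call_stream the text itself rather than a one-element list. Where a call site cannot be sure which shape it holds, the usual defensive pattern looks like this (sketch only; first_output is not a helper in the repo):

    def first_output(last_outputs):
        # Accept either a plain string or a (possibly empty) list of strings.
        if isinstance(last_outputs, list):
            return last_outputs[0] if last_outputs else ""
        return last_outputs or ""

    print(first_output(["<tool_call>...</tool_call>"]))
    print(first_output("already a string"))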
@@ -851,17 +861,18 @@ Generate **one summarized sentence** about "function calls' responses" with nece

         if not function_call_result:
             logger.warning("Empty result from run_function_call_stream")
-            history.append({"role": "assistant", "content": "Error:
+            history.append({"role": "assistant", "content": "Error: Unable to process tool response. Continuing analysis."})
             yield history
-
+            last_outputs = []
+            continue

         function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = function_call_result

-        # Convert
+        # Convert history to dicts and deduplicate
         unique_history = []
         seen_contents = set()
         for msg in current_gradio_history:
-            content = msg.content
+            content = msg["content"] if isinstance(msg, dict) else msg.content
             if content not in seen_contents:
                 unique_history.append({"role": "assistant", "content": content})
                 seen_contents.add(content)
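current_gradio_history can evidently hold either plain dicts or ChatMessage-like objects, so the content lookup now handles both before deduplicating. The same normalization as a standalone example (the ChatMessage dataclass below is a stand-in, not Gradio's class):

    from dataclasses import dataclass

    @dataclass
    class ChatMessage:          # stand-in for an object-style history entry
        role: str
        content: str

    def message_content(msg):
        # Return the text whether msg is a dict or an attribute-style object.
        return msg["content"] if isinstance(msg, dict) else msg.content

    history = [
        ChatMessage("assistant", "Step 1 done."),
        {"role": "assistant", "content": "Step 1 done."},
        {"role": "assistant", "content": "Step 2 done."},
    ]

    seen, unique = set(), []
    for msg in history:
        text = message_content(msg)
        if text not in seen:
            seen.add(text)
            unique.append({"role": "assistant", "content": text})
    print(unique)   # the duplicate "Step 1 done." is dropped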
@@ -915,7 +926,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
         tools=picked_tools_prompt,
         skip_special_tokens=False,
         max_new_tokens=max_new_tokens,
-        max_token=
+        max_token=8192,
         seed=seed,
         check_token_status=True)

921 |
|
|
|
264 |
)
|
265 |
call_result = self.run_multistep_agent(
|
266 |
full_message, temperature=temperature,
|
267 |
+
max_new_tokens=1024, max_token=8192,
|
268 |
call_agent=False, call_agent_level=call_agent_level)
|
269 |
if call_result is None:
|
270 |
call_result = "⚠️ No content returned from sub-agent."
|
|
|
286 |
else:
|
287 |
call_results.append({
|
288 |
"role": "tool",
|
289 |
+
"content": json.dumps({"content": "No valid tool call detected; proceeding with analysis."})
|
290 |
})
|
291 |
|
292 |
revised_messages = [{
|
293 |
"role": "assistant",
|
294 |
+
"content": message.strip() if message else "Processing...",
|
295 |
+
"tool_calls": json.dumps(function_call_json) if function_call_json else None
|
296 |
}] + call_results
|
297 |
logger.debug("Function call completed, returning %d messages", len(revised_messages))
|
298 |
return revised_messages, existing_tools_prompt, special_tool_call
|
|
|
317 |
logger.warning("No valid function call JSON extracted")
|
318 |
call_results.append({
|
319 |
"role": "tool",
|
320 |
+
"content": json.dumps({"content": "No tool call detected; continuing analysis."})
|
321 |
})
|
322 |
if return_gradio_history:
|
323 |
+
gradio_history.append({"role": "assistant", "content": "No specific tool call identified. Proceeding with medical record analysis."})
|
324 |
+
yield [{"role": "assistant", "content": "Processing..."}], existing_tools_prompt or [], special_tool_call, gradio_history
|
325 |
return
|
326 |
|
327 |
if isinstance(function_call_json, list):
|
|
|
352 |
str(solution_plan)
|
353 |
)
|
354 |
sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
|
355 |
+
sub_result = yield from self.run EXEMPLARgradio_chat(
|
356 |
full_message, history=[], temperature=temperature,
|
357 |
+
max_new_tokens=1024, max_token=8192,
|
358 |
call_agent=False, call_agent_level=call_agent_level,
|
359 |
conversation=None,
|
360 |
sub_agent_task=sub_agent_task)
|
|
|
375 |
|
376 |
if return_gradio_history and function_call_json[i]["name"] != 'Finish':
|
377 |
metadata = {"title": f"⚒️ {function_call_json[i]['name']}", "log": str(function_call_json[i]['arguments'])}
|
378 |
+
gradio_history.append({"role": "assistant", "content": str(call_result), "metadata": metadata})
|
379 |
|
380 |
revised_messages = [{
|
381 |
"role": "assistant",
|
382 |
+
"content": message.strip() if message else "Processing...",
|
383 |
+
"tool_calls": json.dumps(function_call_json) if function_call_json else None
|
384 |
}] + call_results
|
385 |
|
386 |
if return_gradio_history:
|
|
|
523 |
|
524 |
def llm_infer(self, messages, temperature=0.1, tools=None,
|
525 |
output_begin_string=None, max_new_tokens=2048,
|
526 |
+
max_token=8192, skip_special_tokens=True,
|
527 |
model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
|
528 |
|
529 |
logger.debug("Running LLM inference with %d messages", len(messages))
|
|
|
565 |
sampling_params=sampling_params,
|
566 |
)
|
567 |
output = output[0].outputs[0].text
|
568 |
+
# Deduplicate repetitive output
|
569 |
+
if output:
|
570 |
+
lines = output.split('\n')
|
571 |
+
seen = set()
|
572 |
+
deduped_lines = []
|
573 |
+
for line in lines:
|
574 |
+
if line.strip() and line not in seen:
|
575 |
+
seen.add(line)
|
576 |
+
deduped_lines.append(line)
|
577 |
+
output = '\n'.join(deduped_lines)
|
578 |
logger.debug("LLM output: %s", output[:50])
|
579 |
if check_token_status and max_token is not None:
|
580 |
return output, token_overflow
|
|
|
729 |
function_response=function_response,
|
730 |
temperature=0.1,
|
731 |
max_new_tokens=1024,
|
732 |
+
max_token=8192
|
733 |
)
|
734 |
|
735 |
input_list.insert(
|
|
|
758 |
function_response=function_response,
|
759 |
temperature=0.1,
|
760 |
max_new_tokens=1024,
|
761 |
+
max_token=8192
|
762 |
)
|
763 |
|
764 |
tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
|
|
|
852 |
|
853 |
if last_outputs:
|
854 |
function_call_result = yield from self.run_function_call_stream(
|
855 |
+
last_outputs[0], return_message=True,
|
856 |
existing_tools_prompt=picked_tools_prompt,
|
857 |
message_for_call_agent=message,
|
858 |
call_agent=call_agent,
|
|
|
861 |
|
862 |
if not function_call_result:
|
863 |
logger.warning("Empty result from run_function_call_stream")
|
864 |
+
history.append({"role": "assistant", "content": "Error: Unable to process tool response. Continuing analysis."})
|
865 |
yield history
|
866 |
+
last_outputs = []
|
867 |
+
continue
|
868 |
|
869 |
function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = function_call_result
|
870 |
|
871 |
+
# Convert history to dicts and deduplicate
|
872 |
unique_history = []
|
873 |
seen_contents = set()
|
874 |
for msg in current_gradio_history:
|
875 |
+
content = msg["content"] if isinstance(msg, dict) else msg.content
|
876 |
if content not in seen_contents:
|
877 |
unique_history.append({"role": "assistant", "content": content})
|
878 |
seen_contents.add(content)
|
|
|
926 |
tools=picked_tools_prompt,
|
927 |
skip_special_tokens=False,
|
928 |
max_new_tokens=max_new_tokens,
|
929 |
+
max_token=8192,
|
930 |
seed=seed,
|
931 |
check_token_status=True)
|
932 |
|