Ali2206 committed on
Commit 8e4e12d · verified · 1 Parent(s): f412a81

Update src/txagent/txagent.py

Files changed (1)
  1. src/txagent/txagent.py +32 -21
src/txagent/txagent.py CHANGED
@@ -264,7 +264,7 @@ class TxAgent:
             )
             call_result = self.run_multistep_agent(
                 full_message, temperature=temperature,
-                max_new_tokens=1024, max_token=99999,
+                max_new_tokens=1024, max_token=8192,
                 call_agent=False, call_agent_level=call_agent_level)
             if call_result is None:
                 call_result = "⚠️ No content returned from sub-agent."
@@ -286,13 +286,13 @@ class TxAgent:
         else:
             call_results.append({
                 "role": "tool",
-                "content": json.dumps({"content": "Not a valid function call, please check the function call format."})
+                "content": json.dumps({"content": "No valid tool call detected; proceeding with analysis."})
             })
 
         revised_messages = [{
             "role": "assistant",
-            "content": message.strip(),
-            "tool_calls": json.dumps(function_call_json)
+            "content": message.strip() if message else "Processing...",
+            "tool_calls": json.dumps(function_call_json) if function_call_json else None
         }] + call_results
         logger.debug("Function call completed, returning %d messages", len(revised_messages))
         return revised_messages, existing_tools_prompt, special_tool_call
@@ -317,11 +317,11 @@ class TxAgent:
             logger.warning("No valid function call JSON extracted")
             call_results.append({
                 "role": "tool",
-                "content": json.dumps({"content": "Invalid function call format."})
+                "content": json.dumps({"content": "No tool call detected; continuing analysis."})
             })
             if return_gradio_history:
-                gradio_history.append(ChatMessage(role="assistant", content="Error: Invalid tool call format."))
-                yield call_results, existing_tools_prompt or [], special_tool_call, gradio_history
+                gradio_history.append({"role": "assistant", "content": "No specific tool call identified. Proceeding with medical record analysis."})
+                yield [{"role": "assistant", "content": "Processing..."}], existing_tools_prompt or [], special_tool_call, gradio_history
             return
 
         if isinstance(function_call_json, list):
@@ -352,9 +352,9 @@ class TxAgent:
                 str(solution_plan)
             )
             sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
-            sub_result = yield from self.run_gradio_chat(
+            sub_result = yield from self.run_gradio_chat(
                 full_message, history=[], temperature=temperature,
-                max_new_tokens=1024, max_token=99999,
+                max_new_tokens=1024, max_token=8192,
                 call_agent=False, call_agent_level=call_agent_level,
                 conversation=None,
                 sub_agent_task=sub_agent_task)
@@ -375,12 +375,12 @@ class TxAgent:
 
             if return_gradio_history and function_call_json[i]["name"] != 'Finish':
                 metadata = {"title": f"⚒️ {function_call_json[i]['name']}", "log": str(function_call_json[i]['arguments'])}
-                gradio_history.append(ChatMessage(role="assistant", content=str(call_result), metadata=metadata))
+                gradio_history.append({"role": "assistant", "content": str(call_result), "metadata": metadata})
 
         revised_messages = [{
             "role": "assistant",
-            "content": message.strip(),
-            "tool_calls": json.dumps(function_call_json)
+            "content": message.strip() if message else "Processing...",
+            "tool_calls": json.dumps(function_call_json) if function_call_json else None
         }] + call_results
 
         if return_gradio_history:
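Note on the Gradio history change above: entries are now plain dicts instead of gradio.ChatMessage objects. A minimal sketch of the message shape this produces, assuming the app feeds these dicts to a messages-format Chatbot (the helper name below is illustrative, not part of the commit):

def tool_step_message(name: str, arguments: dict, result: str) -> dict:
    # Same keys as the dict built in the hunk above; the "metadata" title lets
    # a messages-format Gradio Chatbot render the entry as a labelled tool step.
    return {
        "role": "assistant",
        "content": str(result),
        "metadata": {"title": f"⚒️ {name}", "log": str(arguments)},
    }

Plain dicts keep the history JSON-serializable and avoid importing gradio.ChatMessage in this module.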
@@ -523,7 +523,7 @@ class TxAgent:
 
     def llm_infer(self, messages, temperature=0.1, tools=None,
                   output_begin_string=None, max_new_tokens=2048,
-                  max_token=None, skip_special_tokens=True,
+                  max_token=8192, skip_special_tokens=True,
                   model=None, tokenizer=None, terminators=None, seed=None, check_token_status=False):
 
         logger.debug("Running LLM inference with %d messages", len(messages))
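The llm_infer default for max_token changes from None to 8192, in line with the hard-coded 8192 budgets elsewhere in this commit. How the budget is enforced inside llm_infer is not shown here; purely as a hypothetical illustration (the helper and the tokenizer argument are assumptions, not the repository's code), a prompt-budget check might look like:

def check_token_budget(tokenizer, messages, max_token=8192, max_new_tokens=2048):
    # Hypothetical sketch: count prompt tokens and flag overflow against the
    # max_token budget before generating; not taken from txagent.py.
    prompt = "\n".join(str(m.get("content") or "") for m in messages)
    prompt_tokens = len(tokenizer.encode(prompt))
    token_overflow = prompt_tokens + max_new_tokens > max_token
    return prompt_tokens, token_overflow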
@@ -565,6 +565,16 @@ class TxAgent:
             sampling_params=sampling_params,
         )
         output = output[0].outputs[0].text
+        # Deduplicate repetitive output
+        if output:
+            lines = output.split('\n')
+            seen = set()
+            deduped_lines = []
+            for line in lines:
+                if line.strip() and line not in seen:
+                    seen.add(line)
+                    deduped_lines.append(line)
+            output = '\n'.join(deduped_lines)
         logger.debug("LLM output: %s", output[:50])
         if check_token_status and max_token is not None:
             return output, token_overflow
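The block added above drops repeated lines from the raw model output before it is logged and returned. The same logic as a standalone helper, shown only as a sketch for clarity (note it also removes blank lines and any legitimately repeated line, keeping the first occurrence):

def dedupe_lines(text: str) -> str:
    # Keep the first occurrence of each non-blank line, preserving order.
    seen = set()
    kept = []
    for line in text.split('\n'):
        if line.strip() and line not in seen:
            seen.add(line)
            kept.append(line)
    return '\n'.join(kept)

# Example: "Aspirin 81 mg\nAspirin 81 mg\n\nMonitor INR" -> "Aspirin 81 mg\nMonitor INR"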
@@ -719,7 +729,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
             function_response=function_response,
             temperature=0.1,
             max_new_tokens=1024,
-            max_token=99999
+            max_token=8192
         )
 
         input_list.insert(
@@ -748,7 +758,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
             function_response=function_response,
             temperature=0.1,
             max_new_tokens=1024,
-            max_token=99999
+            max_token=8192
         )
 
         tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
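This call site still does json.loads(input_list[last_call_idx]['tool_calls']). Because the revised_messages changes above can now store tool_calls as None, a defensive load avoids a TypeError on that path; a sketch, not part of this commit:

import json

def load_tool_calls(entry: dict) -> list:
    # Treat a missing or None tool_calls field as "no calls" rather than
    # passing None to json.loads, which raises TypeError.
    raw = entry.get("tool_calls")
    return json.loads(raw) if raw else []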
@@ -842,7 +852,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
 
             if last_outputs:
                 function_call_result = yield from self.run_function_call_stream(
-                    last_outputs, return_message=True,
+                    last_outputs[0], return_message=True,
                     existing_tools_prompt=picked_tools_prompt,
                     message_for_call_agent=message,
                     call_agent=call_agent,
@@ -851,17 +861,18 @@ Generate **one summarized sentence** about "function calls' responses" with nece
 
                 if not function_call_result:
                     logger.warning("Empty result from run_function_call_stream")
-                    history.append({"role": "assistant", "content": "Error: Tool call processing failed."})
+                    history.append({"role": "assistant", "content": "Error: Unable to process tool response. Continuing analysis."})
                     yield history
-                    return "Error: Tool call processing failed."
+                    last_outputs = []
+                    continue
 
                 function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = function_call_result
 
-                # Convert ChatMessage to dicts and deduplicate
+                # Convert history to dicts and deduplicate
                 unique_history = []
                 seen_contents = set()
                 for msg in current_gradio_history:
-                    content = msg.content
+                    content = msg["content"] if isinstance(msg, dict) else msg.content
                     if content not in seen_contents:
                         unique_history.append({"role": "assistant", "content": content})
                         seen_contents.add(content)
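The deduplication loop now accepts both plain dicts and ChatMessage-like objects in current_gradio_history. The same normalization as a small helper, shown only as a sketch:

def unique_assistant_messages(history):
    # Accept {"content": ...} dicts or objects with a .content attribute and
    # keep the first occurrence of each distinct content string.
    unique, seen = [], set()
    for msg in history:
        content = msg["content"] if isinstance(msg, dict) else msg.content
        if content not in seen:
            unique.append({"role": "assistant", "content": content})
            seen.add(content)
    return unique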
@@ -915,7 +926,7 @@ Generate **one summarized sentence** about "function calls' responses" with nece
                 tools=picked_tools_prompt,
                 skip_special_tokens=False,
                 max_new_tokens=max_new_tokens,
-                max_token=max_token,
+                max_token=8192,
                 seed=seed,
                 check_token_status=True)