Ali2206 committed
Commit 9737311 · verified · 1 Parent(s): c0b2cb7

Update src/txagent/txagent.py

Files changed (1):
  src/txagent/txagent.py  +31 -12
src/txagent/txagent.py CHANGED
@@ -73,7 +73,7 @@ class TxAgent:
             return f"The model {model_name} is already loaded."
         self.model_name = model_name
 
-        self.model = LLM(model=self.model_name, dtype="float16", max_model_len=1024, gpu_memory_utilization=0.8)
+        self.model = LLM(model=self.model_name, dtype="float16", max_model_len=2048, gpu_memory_utilization=0.8)
         self.chat_template = Template(self.model.get_tokenizer().chat_template)
         self.tokenizer = self.model.get_tokenizer()
         logger.info("Model %s loaded successfully", self.model_name)
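For context: in vLLM, max_model_len caps prompt plus completion tokens, and the KV cache for that window comes out of the gpu_memory_utilization budget, so raising it from 1024 to 2048 trades memory headroom for longer reasoning chains. A minimal sketch of checking a prompt against the new window before generating; the model id below is a placeholder, and only LLM, SamplingParams, generate, and get_tokenizer from vLLM are assumed:

```python
from vllm import LLM, SamplingParams

MAX_MODEL_LEN = 2048   # must match the value passed to LLM(...)
MAX_NEW_TOKENS = 512   # completion cap used throughout this file

# Placeholder model id; substitute the checkpoint TxAgent actually loads.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", dtype="float16",
          max_model_len=MAX_MODEL_LEN, gpu_memory_utilization=0.8)
tokenizer = llm.get_tokenizer()

prompt = "List known interactions between warfarin and ibuprofen."
n_prompt = len(tokenizer.encode(prompt))
# Leave room for the completion inside the context window.
assert n_prompt + MAX_NEW_TOKENS <= MAX_MODEL_LEN, "prompt too long"

out = llm.generate([prompt], SamplingParams(max_tokens=MAX_NEW_TOKENS))
print(out[0].outputs[0].text)
```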
@@ -176,8 +176,14 @@ class TxAgent:
                           call_agent=False,
                           call_agent_level=None,
                           temperature=None):
-        function_call_json, message = self.tooluniverse.extract_function_call_json(
-            fcall_str, return_message=return_message, verbose=False)
+        try:
+            function_call_json, message = self.tooluniverse.extract_function_call_json(
+                fcall_str, return_message=return_message, verbose=False)
+        except Exception as e:
+            logger.error("Tool call parsing failed: %s", e)
+            function_call_json = []
+            message = fcall_str
+
         call_results = []
         special_tool_call = ''
         if function_call_json:
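The try/except added here makes parsing failures non-fatal: a malformed tool-call string yields an empty function_call_json, which routes into the "Invalid or no function call detected." branch further down. A standalone sketch of the same fallback pattern; extract_function_call_json belongs to ToolUniverse and its internals are not shown in this diff, so plain json.loads stands in for it:

```python
import json
import logging

logger = logging.getLogger("txagent")

def parse_tool_call(fcall_str):
    """Never raise on malformed model output; mirrors the diff's fallback."""
    try:
        # json.loads stands in for tooluniverse.extract_function_call_json.
        parsed = json.loads(fcall_str)
        calls = parsed if isinstance(parsed, list) else [parsed]
        return calls, fcall_str
    except Exception as e:
        logger.error("Tool call parsing failed: %s", e)
        return [], fcall_str  # empty list -> "Invalid or no function call detected."

print(parse_tool_call('[{"name": "get_drug_info", "arguments": {}}]'))
print(parse_tool_call("not valid json"))  # ([], 'not valid json')
```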
@@ -197,7 +203,7 @@ class TxAgent:
                 )
                 call_result = self.run_multistep_agent(
                     full_message, temperature=temperature,
-                    max_new_tokens=512, max_token=1024,
+                    max_new_tokens=512, max_token=2048,
                     call_agent=False, call_agent_level=call_agent_level)
                 if call_result is None:
                     call_result = "⚠️ No content returned from sub-agent."
@@ -217,7 +223,7 @@ class TxAgent:
             else:
                 call_results.append({
                     "role": "tool",
-                    "content": json.dumps({"content": "Invalid function call format."})
+                    "content": json.dumps({"content": "Invalid or no function call detected."})
                 })
 
         revised_messages = [{
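Note the double encoding in this branch: the tool message's content field holds a JSON-encoded string, so a consumer must json.loads it again to reach the inner text (whether downstream code decodes it is not visible in these hunks). A small sketch:

```python
import json

msg = {
    "role": "tool",
    "content": json.dumps({"content": "Invalid or no function call detected."}),
}
print(msg["content"])                         # '{"content": "Invalid or no function call detected."}'
print(json.loads(msg["content"])["content"])  # Invalid or no function call detected.
```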
@@ -235,8 +241,14 @@ class TxAgent:
                               call_agent_level=None,
                               temperature=None,
                               return_gradio_history=True):
-        function_call_json, message = self.tooluniverse.extract_function_call_json(
-            fcall_str, return_message=return_message, verbose=False)
+        try:
+            function_call_json, message = self.tooluniverse.extract_function_call_json(
+                fcall_str, return_message=return_message, verbose=False)
+        except Exception as e:
+            logger.error("Tool call parsing failed: %s", e)
+            function_call_json = []
+            message = fcall_str
+
         call_results = []
         special_tool_call = ''
         if return_gradio_history:
@@ -264,7 +276,7 @@ class TxAgent:
                 sub_agent_task = "Sub TxAgent plan: " + str(solution_plan)
                 call_result = yield from self.run_gradio_chat(
                     full_message, history=[], temperature=temperature,
-                    max_new_tokens=512, max_token=1024,
+                    max_new_tokens=512, max_token=2048,
                     call_agent=False, call_agent_level=call_agent_level,
                     conversation=None, sub_agent_task=sub_agent_task)
                 if call_result is not None and isinstance(call_result, str):
@@ -287,7 +299,7 @@ class TxAgent:
             else:
                 call_results.append({
                     "role": "tool",
-                    "content": json.dumps({"content": "Invalid function call format."})
+                    "content": json.dumps({"content": "Invalid or no function call detected."})
                 })
 
         revised_messages = [{
@@ -300,6 +312,13 @@ class TxAgent:
         return revised_messages, existing_tools_prompt, special_tool_call
 
     def get_answer_based_on_unfinished_reasoning(self, conversation, temperature, max_new_tokens, max_token, outputs=None):
+        # Truncate conversation to fit within max_token
+        tokenized = self.tokenizer.encode(json.dumps(conversation), add_special_tokens=False)
+        if len(tokenized) > max_token - 100:
+            logger.warning("Truncating conversation to fit max_token=%d", max_token)
+            while len(tokenized) > max_token - 100 and len(conversation) > 1:
+                conversation.pop(1)  # Keep system prompt and latest message
+                tokenized = self.tokenizer.encode(json.dumps(conversation), add_special_tokens=False)
         if conversation[-1]['role'] == 'assistant':
             conversation.append(
                 {'role': 'tool', 'content': 'Errors occurred during function call; provide final answer with current information.'})
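The truncation added above drops the oldest non-system turn (index 1) until the serialized conversation fits max_token - 100; the 100-token margin also absorbs the gap between json.dumps length and what the chat template actually renders. A standalone sketch of the same loop, using a Hugging Face tokenizer as a stand-in for self.tokenizer:

```python
import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works for the demo

def truncate_conversation(conversation, max_token, margin=100):
    tokens = tokenizer.encode(json.dumps(conversation), add_special_tokens=False)
    while len(tokens) > max_token - margin and len(conversation) > 1:
        conversation.pop(1)  # index 0 is the system prompt; keep it
        tokens = tokenizer.encode(json.dumps(conversation), add_special_tokens=False)
    return conversation

convo = [{"role": "system", "content": "You are TxAgent."}]
convo += [{"role": "user", "content": f"question {i} " * 50} for i in range(20)]
print(len(truncate_conversation(convo, max_token=512)))  # most turns dropped
```

One edge this loop shares with the diff: the len(conversation) > 1 guard means a single oversized remaining message still exits the loop over budget, so callers need a hard cap downstream.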
@@ -401,7 +420,7 @@ class TxAgent:
 
     def llm_infer(self, messages, temperature=0.1, tools=None,
                   output_begin_string=None, max_new_tokens=512,
-                  max_token=1024, skip_special_tokens=True,
+                  max_token=2048, skip_special_tokens=True,
                   model=None, tokenizer=None, terminators=None,
                   seed=None, check_token_status=False):
         if model is None:
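With this default, max_token now equals the max_model_len configured at load time, so the prompt budget and the context window coincide. A quick arithmetic check of the headroom that implies (the exact enforcement inside llm_infer is not visible in these hunks):

```python
MAX_MODEL_LEN = 2048   # vLLM context window (load_models)
MAX_NEW_TOKENS = 512   # completion cap passed alongside max_token
MAX_TOKEN = 2048       # new default budget in llm_infer

# If max_token bounds the prompt alone, a prompt trimmed to max_token - 100
# (1948 tokens) leaves only 100 tokens of window, less than max_new_tokens:
headroom = MAX_MODEL_LEN - (MAX_TOKEN - 100)
print(headroom, "<", MAX_NEW_TOKENS)  # 100 < 512 -> completions may be cut short
```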
@@ -550,7 +569,7 @@ Summarize the function calls' responses in one sentence with all necessary information.
                 function_response=function_response,
                 temperature=0.1,
                 max_new_tokens=512,
-                max_token=1024)
+                max_token=2048)
             input_list.insert(last_call_idx + 1, {'role': 'tool', 'content': result_summary})
             status['summarized_index'] = last_call_idx + 2
             idx += 1
@@ -572,7 +591,7 @@ Summarize the function calls' responses in one sentence with all necessary information.
                 function_response=function_response,
                 temperature=0.1,
                 max_new_tokens=512,
-                max_token=1024)
+                max_token=2048)
             tool_calls = json.loads(input_list[last_call_idx]['tool_calls'])
             for tool_call in tool_calls:
                 del tool_call['call_id']
 