Ali2206 commited on
Commit
e669311
·
verified ·
1 Parent(s): 206cae1

Create utils.py

Browse files
Files changed (1) hide show
  1. src/txagent/utils.py +94 -0
src/txagent/utils.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import json
3
+ import hashlib
4
+ import torch
5
+ from typing import List
6
+
7
+
8
+ def get_md5(input_str):
9
+ # Create an MD5 hash object
10
+ md5_hash = hashlib.md5()
11
+ md5_hash.update(input_str.encode('utf-8'))
12
+ return md5_hash.hexdigest()
13
+
14
+
15
+ def tool_result_format(function_call_messages):
16
+ current_output = "\n\n<details>\n<summary> <strong>Verified Feedback from Tools</strong>, click to see details:</summary>\n\n"
17
+ for each_message in function_call_messages:
18
+ if each_message['role'] == 'tool':
19
+ try:
20
+ parsed = json.loads(each_message['content'])
21
+ tool_name = parsed.get("tool_name", "Unknown Tool")
22
+ tool_output = parsed.get("content", each_message['content'])
23
+ current_output += f"**🔧 Tool: {tool_name}**\n\n{tool_output}\n\n"
24
+ except Exception:
25
+ current_output += f"{each_message['content']}\n\n"
26
+ current_output += "</details>\n\n\n"
27
+ return current_output
28
+
29
+
30
+ class NoRepeatSentenceProcessor:
31
+ def __init__(self, forbidden_sequences: List[List[int]], allowed_prefix_length: int):
32
+ self.allowed_prefix_length = allowed_prefix_length
33
+ self.forbidden_prefix_dict = {}
34
+ for seq in forbidden_sequences:
35
+ if len(seq) > allowed_prefix_length:
36
+ prefix = tuple(seq[:allowed_prefix_length])
37
+ next_token = seq[allowed_prefix_length]
38
+ self.forbidden_prefix_dict.setdefault(prefix, set()).add(next_token)
39
+
40
+ def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
41
+ if len(token_ids) >= self.allowed_prefix_length:
42
+ prefix = tuple(token_ids[:self.allowed_prefix_length])
43
+ if prefix in self.forbidden_prefix_dict:
44
+ for token_id in self.forbidden_prefix_dict[prefix]:
45
+ logits[token_id] = -float("inf")
46
+ return logits
47
+
48
+
49
+ class ReasoningTraceChecker:
50
+ def __init__(self, question, conversation, init_index=None):
51
+ self.question = question.lower()
52
+ self.conversation = conversation
53
+ self.existing_thoughts = []
54
+ self.existing_actions = []
55
+ self.new_thoughts = []
56
+ self.new_actions = []
57
+ self.index = init_index if init_index is not None else 1
58
+
59
+ def check_conversation(self):
60
+ info = ''
61
+ current_index = self.index
62
+ for i in range(current_index, len(self.conversation)):
63
+ each = self.conversation[i]
64
+ self.index = i
65
+ if each['role'] == 'assistant':
66
+ thought = each['content']
67
+ actions = each['tool_calls']
68
+ good_status, current_info = self.check_repeat_thought(thought)
69
+ info += current_info
70
+ if not good_status:
71
+ return False, info
72
+ good_status, current_info = self.check_repeat_action(actions)
73
+ info += current_info
74
+ if not good_status:
75
+ return False, info
76
+ return True, info
77
+
78
+ def check_repeat_thought(self, thought):
79
+ if thought in self.existing_thoughts:
80
+ return False, "repeat_thought"
81
+ self.existing_thoughts.append(thought)
82
+ return True, ''
83
+
84
+ def check_repeat_action(self, actions):
85
+ if type(actions) != list:
86
+ actions = json.loads(actions)
87
+ for each_action in actions:
88
+ if 'call_id' in each_action:
89
+ del each_action['call_id']
90
+ each_action = json.dumps(each_action)
91
+ if each_action in self.existing_actions:
92
+ return False, "repeat_action"
93
+ self.existing_actions.append(each_action)
94
+ return True, ''