Ali2206 commited on
Commit
6ecf798
·
verified ·
1 Parent(s): 095514b

Update src/txagent/utils.py

Browse files
Files changed (1) hide show
  1. src/txagent/utils.py +114 -117
src/txagent/utils.py CHANGED
@@ -1,117 +1,114 @@
1
- import sys
2
- import json
3
- import hashlib
4
- import torch
5
- from typing import List
6
-
7
-
8
- def get_md5(input_str):
9
- # Create an MD5 hash object
10
- md5_hash = hashlib.md5()
11
-
12
- # Encode the string and update the hash object
13
- md5_hash.update(input_str.encode('utf-8'))
14
-
15
- # Return the hexadecimal MD5 digest
16
- return md5_hash.hexdigest()
17
-
18
-
19
- def tool_result_format(function_call_messages):
20
- current_output = "\n\n<details>\n<summary> <strong>Verfied Feedback from Tools</strong>, click to see details:</summary>\n\n"
21
- for each_message in function_call_messages:
22
- if each_message['role'] == 'tool':
23
- current_output += f"{each_message['content']}\n\n"
24
- current_output += "</details>\n\n\n"
25
- return current_output
26
-
27
-
28
- class NoRepeatSentenceProcessor:
29
- def __init__(self, forbidden_sequences: List[List[int]], allowed_prefix_length: int):
30
- """
31
- Args:
32
- forbidden_sequences (List[List[int]]): A list of token ID sequences corresponding to forbidden sentences.
33
- allowed_prefix_length (int): The number k such that if the generated tokens match the first k tokens
34
- of a forbidden sequence, then the candidate token that would extend the match is blocked.
35
- """
36
- self.allowed_prefix_length = allowed_prefix_length
37
- # Build a lookup dictionary: key is a tuple of the first k tokens, value is a set of tokens to block.
38
- self.forbidden_prefix_dict = {}
39
- for seq in forbidden_sequences:
40
- if len(seq) > allowed_prefix_length:
41
- prefix = tuple(seq[:allowed_prefix_length])
42
- next_token = seq[allowed_prefix_length]
43
- self.forbidden_prefix_dict.setdefault(
44
- prefix, set()).add(next_token)
45
-
46
- def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
47
- """
48
- Modifies the logits to block tokens that would extend a forbidden sentence.
49
-
50
- Args:
51
- token_ids (List[int]): List of token IDs generated so far.
52
- logits (torch.Tensor): Logits tensor for the next token (shape: [vocab_size]).
53
-
54
- Returns:
55
- torch.Tensor: Modified logits.
56
- """
57
- if len(token_ids) >= self.allowed_prefix_length:
58
- prefix = tuple(token_ids[:self.allowed_prefix_length])
59
- if prefix in self.forbidden_prefix_dict:
60
- for token_id in self.forbidden_prefix_dict[prefix]:
61
- logits[token_id] = -float("inf")
62
- return logits
63
-
64
-
65
- class ReasoningTraceChecker:
66
- def __init__(self, question, conversation, init_index=None):
67
- self.question = question
68
- self.conversation = conversation
69
- self.existing_thoughts = []
70
- self.existing_actions = []
71
- if init_index is not None:
72
- self.index = init_index
73
- else:
74
- self.index = 1
75
- self.question = self.question.lower()
76
- self.new_thoughts = []
77
- self.new_actions = []
78
-
79
- def check_conversation(self):
80
- info = ''
81
- current_index = self.index
82
- for i in range(current_index, len(self.conversation)):
83
- each = self.conversation[i]
84
- self.index = i
85
- if each['role'] == 'assistant':
86
- print(each)
87
- thought = each['content']
88
- actions = each['tool_calls']
89
-
90
- good_status, current_info = self.check_repeat_thought(thought)
91
- info += current_info
92
- if not good_status:
93
- return False, info
94
-
95
- good_status, current_info = self.check_repeat_action(actions)
96
- info += current_info
97
- if not good_status:
98
- return False, info
99
- return True, info
100
-
101
- def check_repeat_thought(self, thought):
102
- if thought in self.existing_thoughts:
103
- return False, "repeat_thought"
104
- self.existing_thoughts.append(thought)
105
- return True, ''
106
-
107
- def check_repeat_action(self, actions):
108
- if type(actions) != list:
109
- actions = json.loads(actions)
110
- for each_action in actions:
111
- if 'call_id' in each_action:
112
- del each_action['call_id']
113
- each_action = json.dumps(each_action)
114
- if each_action in self.existing_actions:
115
- return False, "repeat_action"
116
- self.existing_actions.append(each_action)
117
- return True, ''
 
1
+ import sys
2
+ import json
3
+ import hashlib
4
+ import torch
5
+ from typing import List
6
+ from gradio import ChatMessage # Ensure this is present
7
+
8
def get_md5(input_str):
    """Return the hexadecimal MD5 digest of *input_str* (UTF-8 encoded).

    Used as a cheap content fingerprint, not for security purposes.
    """
    return hashlib.md5(input_str.encode('utf-8')).hexdigest()
13
+
14
+
15
def tool_result_format(function_call_messages):
    """
    Format tool outputs as a list of ChatMessage objects with metadata
    so the UI can display tool names and details cleanly.

    Args:
        function_call_messages: Iterable of message dicts with at least
            'role' and 'content' keys. Only messages with role == 'tool'
            are formatted; others are skipped.

    Returns:
        list[ChatMessage]: One assistant-role ChatMessage per tool message,
        with the tool name in metadata['title'] and the (pretty-printed)
        payload in metadata['log'].
    """
    formatted_messages = []

    for each_message in function_call_messages:
        if each_message['role'] != 'tool':
            continue

        raw_content = each_message['content']
        # Narrow parse: json.loads raises ValueError (JSONDecodeError) on bad
        # JSON and TypeError on non-string input; anything else should surface.
        try:
            data = json.loads(raw_content)
        except (TypeError, ValueError):
            data = None

        if isinstance(data, dict):
            tool_name = data.get("tool_name", "Tool Result")
            tool_output = data.get("content", "")
            log = data
        else:
            # Covers both unparseable content and valid JSON that is not an
            # object (e.g. a list) — the previous code hit AttributeError on
            # .get for the latter and mislabeled it via a broad except.
            tool_name = "Tool Result"
            tool_output = str(raw_content)
            log = {"error": "Malformed tool output", "raw": tool_output}

        formatted_messages.append(ChatMessage(
            role="assistant",
            content=tool_output,
            metadata={
                "title": f"⚒️ {tool_name}",
                "log": json.dumps(log, indent=2)
            }
        ))

    return formatted_messages
45
+
46
+
47
class NoRepeatSentenceProcessor:
    """Logits processor that blocks exact continuations of forbidden sentences.

    If the generated tokens start with the first ``allowed_prefix_length``
    tokens of a forbidden sequence, the token that would extend that match
    is assigned -inf probability.
    """

    def __init__(self, forbidden_sequences: List[List[int]], allowed_prefix_length: int):
        """
        Args:
            forbidden_sequences: Token-ID sequences of forbidden sentences.
            allowed_prefix_length: Number k of leading tokens that may match
                before the (k+1)-th token of the forbidden sequence is blocked.
        """
        self.allowed_prefix_length = allowed_prefix_length
        # Lookup table: first-k-token prefix -> set of banned next tokens.
        blocked = {}
        for sequence in forbidden_sequences:
            if len(sequence) <= allowed_prefix_length:
                continue  # too short to have a continuation token to block
            key = tuple(sequence[:allowed_prefix_length])
            blocked.setdefault(key, set()).add(sequence[allowed_prefix_length])
        self.forbidden_prefix_dict = blocked

    def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        """Zero out (set to -inf) logits of tokens extending a forbidden prefix.

        Args:
            token_ids: Token IDs generated so far.
            logits: Next-token logits, shape [vocab_size]; modified in place.

        Returns:
            The (possibly modified) logits tensor.
        """
        if len(token_ids) < self.allowed_prefix_length:
            return logits
        banned = self.forbidden_prefix_dict.get(
            tuple(token_ids[:self.allowed_prefix_length]))
        if banned:
            for banned_token in banned:
                logits[banned_token] = -float("inf")
        return logits
64
+
65
+
66
class ReasoningTraceChecker:
    """Detects repeated thoughts or repeated tool calls in an agent trace.

    The checker scans assistant turns in ``conversation`` starting at
    ``index`` and reports the first exact repeat of a thought (message
    content) or of an action (tool call, ignoring ``call_id``).
    """

    def __init__(self, question, conversation, init_index=None):
        """
        Args:
            question: The user question (stored lowercased).
            conversation: List of message dicts ('role', 'content', and for
                assistant turns 'tool_calls').
            init_index: First conversation index to check; defaults to 1
                (skipping the initial user message).
        """
        self.question = question.lower()
        self.conversation = conversation
        self.existing_thoughts = []
        self.existing_actions = []
        self.new_thoughts = []
        self.new_actions = []
        self.index = init_index if init_index is not None else 1

    def check_conversation(self):
        """Scan assistant turns from self.index onward for repeats.

        Returns:
            (bool, str): (True, info) if no repeats; (False, info) where info
            ends with 'repeat_thought' or 'repeat_action' on the first repeat.
        """
        info = ''
        for i in range(self.index, len(self.conversation)):
            each = self.conversation[i]
            self.index = i  # remember progress so later calls resume here
            if each['role'] != 'assistant':
                continue
            thought = each['content']
            actions = each['tool_calls']

            good_status, current_info = self.check_repeat_thought(thought)
            info += current_info
            if not good_status:
                return False, info

            good_status, current_info = self.check_repeat_action(actions)
            info += current_info
            if not good_status:
                return False, info
        return True, info

    def check_repeat_thought(self, thought):
        """Return (False, 'repeat_thought') if this exact thought was seen."""
        if thought in self.existing_thoughts:
            return False, "repeat_thought"
        self.existing_thoughts.append(thought)
        return True, ''

    def check_repeat_action(self, actions):
        """Return (False, 'repeat_action') if any action was already issued.

        Actions are compared ignoring 'call_id', via a canonical
        (sort_keys) JSON serialization so key order cannot hide a repeat.
        The caller's action dicts are NOT mutated (the previous version
        deleted 'call_id' in place, corrupting the conversation history).
        """
        if actions is None:
            # Assistant turn with no tool calls — nothing to repeat.
            return True, ''
        if not isinstance(actions, list):
            actions = json.loads(actions)
        for each_action in actions:
            normalized = {k: v for k, v in each_action.items() if k != 'call_id'}
            action_key = json.dumps(normalized, sort_keys=True)
            if action_key in self.existing_actions:
                return False, "repeat_action"
            self.existing_actions.append(action_key)
        return True, ''