DiTy committed
Commit db48e80 · verified · 1 Parent(s): bb80064

Upload qwen2_tool_parser.py

Files changed (1)
  1. qwen2_tool_parser.py +122 -0
qwen2_tool_parser.py ADDED
@@ -0,0 +1,122 @@
+ import json
+ import re
+ from typing import Dict, List, Sequence, Union
+
+ import partial_json_parser
+ from partial_json_parser.core.options import Allow
+
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                               DeltaFunctionCall, DeltaMessage,
+                                               DeltaToolCall,
+                                               ExtractedToolCallInformation,
+                                               FunctionCall, ToolCall)
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+     ToolParser, ToolParserManager)
+ from vllm.logger import init_logger
+ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+ from vllm.utils import random_uuid
+
+ logger = init_logger(__name__)
+
+
+ @ToolParserManager.register_module("qwen2")
+ class Qwen2ToolParser(ToolParser):
+
+     def __init__(self, tokenizer: AnyTokenizer):
+         super().__init__(tokenizer)
+
+         if isinstance(self.model_tokenizer, MistralTokenizer):
+             logger.error(
+                 "Detected Mistral tokenizer when using a Qwen2.5 model")
+             self.model_tokenizer = self.model_tokenizer.tokenizer
+
+         self.current_tool_name_sent: bool = False
+         self.prev_tool_call_arr: List[Dict] = []
+         self.current_tool_id: int = -1
+         self.streamed_args_for_tool: List[str] = [
+         ]  # map what has been streamed for each tool so far to a list
+
+         self.tool_call_start_token: str = "<tool_call>"
+         self.tool_call_end_token: str = "</tool_call>"
+
+         self.tool_call_regex = re.compile(
+             r"<tool_call>(.*?)</tool_call>", re.DOTALL)
+         self.scratch_pad_regex = re.compile(
+             r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
+
+         if not self.model_tokenizer:
+             raise ValueError(
+                 "The model tokenizer must be passed to the ToolParser "
+                 "constructor during construction.")
+         self.tool_call_start_token_id = self.vocab.get(
+             self.tool_call_start_token)
+         self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+         if (self.tool_call_start_token_id is None
+                 or self.tool_call_end_token_id is None):
+             raise RuntimeError(
+                 "Qwen2.5 Tool parser could not locate tool call start/end "
+                 "tokens in the tokenizer!")
+
+     def extract_tool_calls(
+         self,
+         model_output: str,
+         request: ChatCompletionRequest,
+     ) -> ExtractedToolCallInformation:
+
+         # sanity check; avoid unnecessary processing
+         if self.tool_call_start_token not in model_output:
+             return ExtractedToolCallInformation(tools_called=False,
+                                                 tool_calls=[],
+                                                 content=model_output)
+
+         else:
+
+             try:
+                 # find all tool calls between "<tool_call>" and
+                 # "</tool_call>" tags in the model output
+                 function_call_strs = (
+                     self.tool_call_regex.findall(model_output))
+
+                 # load the JSON, and then use it to build the Function and
+                 # Tool Call
+                 raw_function_calls = json.loads(function_call_strs[0])
+
+                 tool_calls = [
+                     ToolCall(
+                         type="function",
+                         function=FunctionCall(
+                             name=function_call["tool_name"],
+                             # function call args are JSON but as a string
+                             arguments=json.dumps(function_call["parameters"], ensure_ascii=False)
+                         )
+                     )
+                     for function_call in raw_function_calls
+                 ]
+
+                 content = model_output[:model_output.
+                                        find(self.tool_call_start_token)]
+                 return ExtractedToolCallInformation(
+                     tools_called=True,
+                     tool_calls=tool_calls,
+                     content=content if content else None)
+
+             except Exception:
+                 logger.exception(
+                     "Error in extracting tool call from response.")
+                 return ExtractedToolCallInformation(tools_called=False,
+                                                     tool_calls=[],
+                                                     content=model_output)
+
+     # for streamed parsing
+     def extract_tool_calls_streaming(
+         self,
+         previous_text: str,
+         current_text: str,
+         delta_text: str,
+         previous_token_ids: Sequence[int],
+         current_token_ids: Sequence[int],
+         delta_token_ids: Sequence[int],
+         request: ChatCompletionRequest,
+     ) -> Union[DeltaMessage, None]:
+
+         pass
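
For reference, below is a minimal standalone sketch of the output format this parser expects and of the two extraction steps extract_tool_calls() performs (regex match on the <tool_call> block, then json.loads on its contents). The sample model output and the get_weather tool are illustrative assumptions, not part of the upload; the real parser additionally wraps each entry in FunctionCall/ToolCall objects and returns an ExtractedToolCallInformation.

import json
import re

# same pattern as Qwen2ToolParser.tool_call_regex
tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)

# hypothetical model output: free-form content followed by one <tool_call>
# block holding a JSON list of {"tool_name", "parameters"} objects
model_output = (
    "Let me check the weather.\n"
    "<tool_call>"
    '[{"tool_name": "get_weather", "parameters": {"city": "Berlin"}}]'
    "</tool_call>"
)

# step 1: grab the first <tool_call>...</tool_call> block
function_call_strs = tool_call_regex.findall(model_output)
# step 2: parse it as JSON; each entry becomes one function call
raw_function_calls = json.loads(function_call_strs[0])

for call in raw_function_calls:
    # arguments stay a JSON string, serialized with ensure_ascii=False
    # just as the parser does
    print(call["tool_name"],
          json.dumps(call["parameters"], ensure_ascii=False))
# prints: get_weather {"city": "Berlin"}

Everything before the first <tool_call> tag is returned as regular content, and ensure_ascii=False keeps non-ASCII argument values readable. Because the class is registered under the name "qwen2" via ToolParserManager.register_module, a vLLM build that supports loading custom tool parsers should be able to select it by that name through --tool-call-parser qwen2, assuming your vLLM version exposes a plugin mechanism for external parser files.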