jeffreymeetkai commited on
Commit
38f413c
·
verified ·
1 Parent(s): dea9d91

Update tokenization_functionary.py

Browse files
Files changed (1) hide show
  1. tokenization_functionary.py +57 -369
tokenization_functionary.py CHANGED
@@ -1,6 +1,7 @@
1
  # Copyright (c) 2024, MeetKai Inc. All rights reserved.
2
 
3
  from copy import deepcopy
 
4
  import json
5
  from typing import Any, Dict, List, Literal, Optional, Union
6
 
@@ -14,382 +15,69 @@ from transformers.utils import TensorType, logging
14
 
15
 
16
  logger = logging.get_logger(__name__)
17
- SYSTEM_PROMPT = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
18
- CODE_INTERPRETER_SYSTEM_PROMPT = """When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files."""
19
 
20
- class Function(BaseModel):
21
- name: str
22
- description: Optional[str] = Field(default="")
23
- parameters: Optional[dict] = None
24
-
 
25
 
26
- class Tool(BaseModel):
27
- type: Literal["function", "code_interpreter"]
28
- function: Optional[Function] = None
29
-
30
- @model_validator(mode="after")
31
- def check_type_function_matches(self) -> Self:
32
- if self.type == "function":
33
- assert self.function is not None, '"function" must contain function description when `"type": "function"`'
34
- else:
35
- assert self.function is None, '"function" must not be provided when `"type": "code_interpreter"`'
36
- return self
37
-
38
-
39
- def convert_data_type(param_type: str) -> str:
40
- """convert data_type to typescript data type
41
-
42
- Args:
43
- param_type (str): param_type
44
-
45
- Returns:
46
- str: param type in typescript
47
- """
48
- if param_type == "integer" or param_type == "float":
49
- return "number"
50
- return param_type
51
-
52
-
53
- def get_param_type(param: Dict) -> str:
54
- """get param_type of parameter
55
-
56
- Args:
57
- param (Dict): param dict in properties
58
-
59
- Returns:
60
- str: _description_
61
- """
62
- param_type = "any"
63
- if "type" in param:
64
- raw_param_type = param["type"]
65
- if type(raw_param_type) is list:
66
- param_type = " | ".join(raw_param_type)
67
- else:
68
- param_type = raw_param_type
69
-
70
- else: # in many cases, the json schema contains: oneOf instead of "type"
71
- if "oneOf" in param:
72
- one_of_types = []
73
- for item in param["oneOf"]:
74
- if "type" in item:
75
- one_of_types.append(convert_data_type(item["type"]))
76
- one_of_types = list(set(one_of_types))
77
- param_type = " | ".join(one_of_types)
78
- return convert_data_type(param_type)
79
-
80
-
81
- def get_format_param(param: Dict) -> Optional[str]:
82
- """Get "format" from param. There are cases where format is not directly in param but in oneOf
83
-
84
- Args:
85
- param (Dict): _description_
86
-
87
- Returns:
88
- Optional[str]: _description_
89
- """
90
- if "format" in param:
91
- return param["format"]
92
- if "oneOf" in param:
93
- formats = []
94
- for item in param["oneOf"]:
95
- if "format" in item:
96
- formats.append(item["format"])
97
- if len(formats) > 0:
98
- return " or ".join(formats)
99
- return None
100
-
101
-
102
- def get_param_info(param: Dict) -> Optional[str]:
103
- """get additional information about parameter such as: format, default value, min, max, ...
104
-
105
- Args:
106
- param (Dict): _description_
107
-
108
- Returns:
109
- Optional[str]: _description_
110
- """
111
- param_type = param.get("type", "any")
112
- info_list = []
113
- if "description" in param:
114
- desc = param["description"]
115
- if not desc.endswith("."):
116
- desc += "."
117
- info_list.append(desc)
118
-
119
- if "default" in param:
120
- default_value = param["default"]
121
- if param_type == "string":
122
- default_value = f'"{default_value}"' # if string --> add ""
123
- info_list.append(f"Default={default_value}.")
124
-
125
- format_param = get_format_param(param)
126
- if format_param is not None:
127
- info_list.append("Format=" + format_param)
128
-
129
- for field, field_name in [
130
- ("maximum", "Maximum"),
131
- ("minimum", "Minimum"),
132
- ("maxLength", "Maximum length"),
133
- ("minLength", "Minimum length"),
134
- ]:
135
- if field in param:
136
- info_list.append(f"{field_name}=" + str(param[field]))
137
-
138
- if len(info_list) > 0:
139
- result = "// " + " ".join(info_list)
140
- result = result.replace("\n", " ")
141
- return result
142
- return None
143
-
144
-
145
- def append_new_param_info(
146
- info_list: List[str],
147
- param_declaration: str,
148
- comment_info: Optional[str],
149
- examples_info: List,
150
- depth: int,
151
- ):
152
- """Append a new parameter with comment to the info_list
153
-
154
- Args:
155
- info_lines (List[str]): current info_list
156
- param_declaration (str): param: type
157
- comment_info (Optional[str]): information of comment
158
- examples_info (List): information of examples given
159
- depth (int): level of nested param
160
- """
161
- offset = ""
162
- if depth >= 1:
163
- offset = "".join([" " for _ in range(depth)])
164
- if comment_info is not None:
165
- # if depth == 0: # format: //comment\nparam: type
166
- info_list.append(f"{offset}{comment_info}")
167
- if len(examples_info) > 0:
168
- for example in examples_info:
169
- info_list.append(f"{offset}{example}")
170
- info_list.append(f"{offset}{param_declaration}")
171
- # else: # format: param: type // comment
172
- # info_list.append(f"{offset}{param_declaration} {comment_info}")
173
- else:
174
- info_list.append(f"{offset}{param_declaration}")
175
-
176
-
177
- def get_examples_info(param_name: str, examples: List) -> List:
178
- """get information about examples provided
179
-
180
- Args:
181
- param_name (str): _description_
182
- examples (List): _description_
183
 
184
- Returns:
185
- List: _description_
186
- """
187
- examples_list = [f"// Example {param_name}:"]
188
- for example in examples:
189
- if isinstance(example, dict) or isinstance(example, list):
190
- example_str = json.dumps(example, ensure_ascii=False).replace('\n', '\\n')
191
- else:
192
- example_str = str(example).replace('\n', '\\n')
193
- examples_list.append(f"// {example_str}")
194
-
195
- return examples_list
196
 
197
 
198
- def get_enum_option_str(enum_options: List) -> str:
199
- """get enum option separated by: "|"
 
 
 
200
 
201
- Args:
202
- enum_options (List): list of options
203
 
204
- Returns:
205
- _type_: concatenation of options separated by "|"
206
- """
207
- # if each option is string --> add quote
208
- return " | ".join([f'"{v}"' if type(v) is str else str(v) for v in enum_options])
209
 
 
 
 
210
 
211
- def get_array_typescript(
212
- param_name: Optional[str], param_dic: dict, depth: int = 0
213
- ) -> str:
214
- """recursive implementation for generating type script of array
215
 
216
- Args:
217
- param_name (Optional[str]): name of param, optional
218
- param_dic (dict): param_dic
219
- depth (int, optional): nested level. Defaults to 0.
 
 
220
 
221
- Returns:
222
- _type_: typescript of array
223
- """
224
- offset = ""
225
- if depth >= 1:
226
- offset = "".join([" " for _ in range(depth)])
227
- items_info = param_dic.get("items", {})
228
-
229
- if len(items_info) == 0:
230
- if param_name is not None:
231
- return f"{offset}{param_name}: []"
232
- else:
233
- return "[]"
234
- array_type = get_param_type(items_info)
235
- if array_type == "object":
236
- info_lines = []
237
- child_lines = get_parameter_typescript(
238
- items_info.get("properties", {}), items_info.get("required", []), depth + 1
239
- )
240
- # if comment_info is not None:
241
- # info_lines.append(f"{offset}{comment_info}")
242
- if param_name is not None:
243
- info_lines.append(f"{offset}{param_name}" + ": {")
244
- else:
245
- info_lines.append(f"{offset}" + "{")
246
- info_lines.extend(child_lines)
247
- info_lines.append(f"{offset}" + "}[]")
248
- return "\n".join(info_lines)
249
 
250
- elif array_type == "array":
251
- item_info = get_array_typescript(None, items_info, depth + 1)
252
- if param_name is None:
253
- return f"{item_info}[]"
254
- return f"{offset}{param_name}: {item_info.strip()}[]"
255
 
256
- else:
257
- if "enum" in items_info:
258
- item_type = get_enum_option_str(items_info["enum"])
259
- if param_name is None:
260
- return f"({item_type})[]"
261
- else:
262
- return f"{offset}{param_name}: ({item_type})[]"
263
- else:
264
- if param_name is None:
265
- return f"{array_type}[]"
266
- else:
267
- return f"{offset}{param_name}: {array_type}[],"
268
 
 
 
 
 
 
269
 
270
- def get_parameter_typescript(properties, required_params, depth=0) -> List[str]:
271
- """Recursion, returning the information about parameters including data type, description and other information
272
- These kinds of information will be put into the prompt
273
 
274
- Args:
275
- properties (_type_): properties in parameters
276
- required_params (_type_): List of required parameters
277
- depth (int, optional): the depth of params (nested level). Defaults to 0.
278
 
279
- Returns:
280
- _type_: list of lines containing information about all parameters
281
- """
282
- tp_lines = []
283
- for param_name, param in properties.items():
284
- # Sometimes properties have "required" field as a list of string.
285
- # Even though its supposed to be not under properties. So we skip it
286
- if not isinstance(param, dict):
287
- continue
288
- # Param Description
289
- comment_info = get_param_info(param)
290
- # Param Examples
291
- examples_info = []
292
- if "examples" in param:
293
- examples_info = get_examples_info(param_name, param["examples"])
294
- # Param Name declaration
295
- param_declaration = f"{param_name}"
296
- if isinstance(required_params, list):
297
- if param_name not in required_params:
298
- param_declaration += "?"
299
- param_type = get_param_type(param)
300
-
301
- offset = ""
302
- if depth >= 1:
303
- offset = "".join([" " for _ in range(depth)])
304
-
305
- if param_type == "object": # param_type is object
306
- child_lines = get_parameter_typescript(
307
- param.get("properties", {}), param.get("required", []), depth + 1
308
- )
309
- if comment_info is not None:
310
- tp_lines.append(f"{offset}{comment_info}")
311
- if len(examples_info) > 0:
312
- for example in examples_info:
313
- tp_lines.append(f"{offset}{example}")
314
-
315
- param_declaration += ": {"
316
- tp_lines.append(f"{offset}{param_declaration}")
317
- tp_lines.extend(child_lines)
318
- tp_lines.append(f"{offset}" + "},")
319
-
320
- elif param_type == "array": # param_type is an array
321
- item_info = param.get("items", {})
322
- if "type" not in item_info: # don't know type of array
323
- param_declaration += ": [],"
324
- append_new_param_info(
325
- tp_lines, param_declaration, comment_info, examples_info, depth
326
- )
327
- else:
328
- array_declaration = get_array_typescript(
329
- param_declaration, param, depth
330
- )
331
- if not array_declaration.endswith(","):
332
- array_declaration += ","
333
- if comment_info is not None:
334
- tp_lines.append(f"{offset}{comment_info}")
335
- if len(examples_info) > 0:
336
- for example in examples_info:
337
- tp_lines.append(f"{offset}{example}")
338
- tp_lines.append(array_declaration)
339
- else:
340
- if "enum" in param:
341
- param_type = get_enum_option_str(param["enum"])
342
- # param_type = " | ".join([f'"{v}"' for v in param["enum"]])
343
- if "nullable" in param and param["nullable"] is True:
344
- param_type += " | null"
345
- param_declaration += f": {param_type},"
346
- append_new_param_info(
347
- tp_lines, param_declaration, comment_info, examples_info, depth
348
- )
349
-
350
- return tp_lines
351
-
352
- def generate_schema_from_functions(
353
- functions: List[Function], namespace="functions"
354
- ) -> str:
355
- """
356
- Convert functions schema to a schema that language models can understand.
357
- """
358
-
359
- schema = "// Supported function definitions that should be called when necessary.\n"
360
- schema += f"namespace {namespace} {{\n\n"
361
-
362
- for function in functions:
363
- # Convert a Function object to dict, if necessary
364
- if not isinstance(function, dict):
365
- function = function.model_dump()
366
- function_name = function.get("name", None)
367
- if function_name is None:
368
- continue
369
-
370
- description = function.get("description", "")
371
- schema += f"// {description}\n"
372
- schema += f"type {function_name}"
373
-
374
- parameters = function.get("parameters", None)
375
- if parameters is not None and parameters.get("properties") is not None:
376
- parameters = deepcopy(jsonref.JsonRef.replace_refs(parameters))
377
- schema += " = (_: {\n"
378
- required_params = parameters.get("required", [])
379
- tp_lines = get_parameter_typescript(
380
- parameters.get("properties"),
381
- required_params,
382
- 0,
383
- )
384
- schema += "\n".join(tp_lines)
385
- schema += "\n}) => any;\n\n"
386
- else:
387
- # Doesn't have any parameters
388
- schema += " = () => any;\n\n"
389
-
390
- schema += f"}} // namespace {namespace}"
391
-
392
- return schema
393
 
394
  class FunctionaryTokenizer(PreTrainedTokenizerFast):
395
  def apply_chat_template(
@@ -465,16 +153,16 @@ class FunctionaryTokenizer(PreTrainedTokenizerFast):
465
  # Prepare tools/functions into schema
466
  functions_pydantic_to_render = []
467
  has_code_interpreter = False
468
- for i in range(len(tools)):
469
- tool_pydantic = Tool.model_validate(tools[i])
470
- if tool_pydantic.type == "function":
471
- functions_pydantic_to_render.append(tool_pydantic.function)
472
- else:
473
- has_code_interpreter = True
474
- conversation.insert(0, {"role": "system", "content": generate_schema_from_functions(functions_pydantic_to_render)})
475
- # Insert system prompt
476
- system_prompt_to_use = SYSTEM_PROMPT if not has_code_interpreter else CODE_INTERPRETER_SYSTEM_PROMPT
477
- conversation.insert(1, {"role": "system", "content": system_prompt_to_use})
478
 
479
  # Compilation function uses a cache to avoid recompiling the same template
480
  compiled_template = self._compile_jinja_template(chat_template)
 
1
  # Copyright (c) 2024, MeetKai Inc. All rights reserved.
2
 
3
  from copy import deepcopy
4
+ import datetime
5
  import json
6
  from typing import Any, Dict, List, Literal, Optional, Union
7
 
 
15
 
16
 
17
  logger = logging.get_logger(__name__)
 
 
18
 
19
+ def get_instruction_string(custom_tool_definition) -> str:
20
+ name, description = (
21
+ custom_tool_definition["name"],
22
+ custom_tool_definition["description"],
23
+ )
24
+ return f"Use the function '{name}' to '{description}'"
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ def get_parameters_string(custom_tool_definition) -> str:
28
+ return json.dumps(custom_tool_definition)
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
+ def get_system_prompt_for_custom_tools(custom_tools: List) -> str:
32
+ custom_tool_params = ""
33
+ for t in custom_tools:
34
+ custom_tool_params += get_instruction_string(t) + "\n"
35
+ custom_tool_params += get_parameters_string(t) + "\n\n"
36
 
37
+ content = f"""
38
+ You have access to the following functions:
39
 
40
+ {custom_tool_params}
41
+ Think very carefully before calling functions.
42
+ If a you choose to call a function ONLY reply in the following format:
43
+ <{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}
44
+ where
45
 
46
+ start_tag => `<function`
47
+ parameters => a JSON dict with the function argument name as key and function argument value as value.
48
+ end_tag => `</function>`
49
 
50
+ Here is an example,
51
+ <function=example_function_name>{{"example_name": "example_value"}}</function>
 
 
52
 
53
+ Reminder:
54
+ - If looking for real time information use relevant functions before falling back to brave_search
55
+ - Function calls MUST follow the specified format, start with <function= and end with </function>
56
+ - Required parameters MUST be specified
57
+ - Only call one function at a time
58
+ - Put the entire function call reply on one line
59
 
60
+ """
61
+ return content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
 
 
 
 
 
63
 
64
+ def get_system_message_for_tools(tools: List[Dict], use_code_interpreter) -> List[Dict]:
65
+ content = ""
66
+ if use_code_interpreter:
67
+ content += "Environment: ipython\n"
 
 
 
 
 
 
 
 
68
 
69
+ current_date = datetime.datetime.now()
70
+ formatted_date = current_date.strftime("%d %B %Y")
71
+ date_str = f"""
72
+ Cutting Knowledge Date: December 2023\n\n"""
73
+ content += date_str
74
 
75
+ if tools:
76
+ custom_message = get_system_prompt_for_custom_tools(tools)
77
+ content += custom_message
78
 
79
+ return {"role": "system", "content": content}
 
 
 
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  class FunctionaryTokenizer(PreTrainedTokenizerFast):
83
  def apply_chat_template(
 
153
  # Prepare tools/functions into schema
154
  functions_pydantic_to_render = []
155
  has_code_interpreter = False
156
+ if tools is not None:
157
+ for item in tools:
158
+ if "function" in item and item["function"] is not None:
159
+ functions_pydantic_to_render.append(item["function"])
160
+ elif "type" in item and item["type"] == "code_interpreter":
161
+ has_code_interpreter = True
162
+ else:
163
+ functions_pydantic_to_render.append(item)
164
+ tools_system_message = get_system_message_for_tools(functions_pydantic_to_render, has_code_interpreter)
165
+ conversation.insert(0, tools_system_message)
166
 
167
  # Compilation function uses a cache to avoid recompiling the same template
168
  compiled_template = self._compile_jinja_template(chat_template)