|
|
|
|
|
from copy import deepcopy |
|
import datetime |
|
import json |
|
from typing import Any, Dict, List, Literal, Optional, Union |
|
|
|
import jsonref |
|
from pydantic import BaseModel, Field, model_validator |
|
from typing_extensions import Self |
|
|
|
from transformers.tokenization_utils_base import BatchEncoding |
|
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast |
|
from transformers.utils import TensorType, logging |
|
|
|
|
|
# Module-level logger, following the transformers convention of keying on __name__.
logger = logging.get_logger(__name__)
|
|
|
def get_instruction_string(custom_tool_definition) -> str:
    """Build the one-line usage instruction for a single custom tool.

    `custom_tool_definition` must be a mapping with at least the keys
    "name" and "description".
    """
    tool_name = custom_tool_definition["name"]
    tool_description = custom_tool_definition["description"]
    return f"Use the function '{tool_name}' to '{tool_description}'"
|
|
|
|
|
def get_parameters_string(custom_tool_definition) -> str:
    """Serialize the complete tool definition to a compact JSON string."""
    serialized_definition = json.dumps(custom_tool_definition)
    return serialized_definition
|
|
|
|
|
def get_system_prompt_for_custom_tools(custom_tools: List) -> str:
    """Render the system-prompt section that teaches the model to call custom tools.

    Each tool contributes an instruction line plus its JSON definition; the
    combined listing is embedded in a fixed template describing the
    `<function=...>...</function>` calling format.
    """
    # One instruction line + JSON definition per tool, each entry followed by
    # a blank line.
    custom_tool_params = "".join(
        get_instruction_string(tool) + "\n" + get_parameters_string(tool) + "\n\n"
        for tool in custom_tools
    )

    # NOTE(review): "If a you choose" below reads like a typo, but it matches
    # the upstream Llama 3.1 custom-tool prompt wording; confirm before
    # changing it, since models may have been trained on this exact text.
    return f"""
You have access to the following functions:

{custom_tool_params}
Think very carefully before calling functions.
If a you choose to call a function ONLY reply in the following format:
<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}
where

start_tag => `<function`
parameters => a JSON dict with the function argument name as key and function argument value as value.
end_tag => `</function>`

Here is an example,
<function=example_function_name>{{"example_name": "example_value"}}</function>

Reminder:
- If looking for real time information use relevant functions before falling back to brave_search
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line

"""
|
|
|
|
|
def get_system_message_for_tools(tools: List[Dict], use_code_interpreter) -> Dict:
    """Build the system chat message that advertises the available tools.

    Args:
        tools: Custom tool definitions to describe in the prompt; may be empty.
        use_code_interpreter: When truthy, prepend the "Environment: ipython"
            line that enables the built-in code interpreter.

    Returns:
        A single chat message dict, ``{"role": "system", "content": ...}``.
        (The return annotation previously read ``List[Dict]``, which was wrong.)
    """
    content = ""
    if use_code_interpreter:
        content += "Environment: ipython\n"

    # NOTE(review): the original computed today's date here but never used it —
    # presumably an omitted "Today Date: ..." prompt line. Confirm the intended
    # prompt before re-adding it; the rendered text is unchanged here.
    content += "\nCutting Knowledge Date: December 2023\n\n"

    if tools:
        content += get_system_prompt_for_custom_tools(tools)

    return {"role": "system", "content": content}
|
|
|
|
|
class FunctionaryTokenizer(PreTrainedTokenizerFast):
    """Tokenizer that injects a tool-description system message before rendering
    the chat template (Functionary-style tool calling)."""

    def apply_chat_template(
        self,
        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], str],
        tools: Optional[List[Dict[str, Any]]],
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = False,
        tokenize: bool = True,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
        """Render (and optionally tokenize) a conversation, prepending a system
        message built from ``tools``.

        Mirrors the base ``apply_chat_template``, except that a system message
        describing ``tools`` is inserted at position 0 of the conversation
        before the Jinja template is rendered.

        Args:
            conversation: A single chat (list of message dicts), a batch of
                chats, or an object exposing ``.messages``.
            tools: Tool specs. Entries with a non-None ``"function"`` key
                contribute that function; entries with ``"type" ==
                "code_interpreter"`` enable the ipython environment line;
                anything else is passed through as a function definition.
            (Remaining arguments behave as in the base class implementation.)

        Raises:
            ValueError: If ``return_dict=True`` is combined with
                ``tokenize=False``, or if the tokenizer carries several named
                templates and none is selected.
        """
        if return_dict and not tokenize:
            raise ValueError(
                "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
                "of tokenizer outputs to return."
            )

        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}

        using_default_template = False

        # Resolve which Jinja template string to use: the tokenizer may carry a
        # dict of named templates, a single template string, or only the legacy
        # class-level default.
        if isinstance(self.chat_template, dict) or (
            self.chat_template is None and isinstance(self.default_chat_template, dict)
        ):
            if self.chat_template is not None:
                template_dict = self.chat_template
                using_default_dict = False
            else:
                template_dict = self.default_chat_template
                using_default_dict = True
            if chat_template is not None and chat_template in template_dict:
                # The user can pass the key of a named template.
                chat_template = template_dict[chat_template]
                if using_default_dict:
                    using_default_template = True
            elif chat_template is None and "default" in template_dict:
                chat_template = template_dict["default"]
                if using_default_dict:
                    using_default_template = True
            elif chat_template is None:
                raise ValueError(
                    "This model has multiple chat templates with no default specified! Please either pass a chat "
                    "template or the name of the template you wish to use to the `chat_template` argument. Available "
                    f"template names are {sorted(template_dict.keys())}."
                )
        elif chat_template is None:
            # No explicit template requested; fall back to the instance
            # template, then to the legacy class default.
            if self.chat_template is not None:
                chat_template = self.chat_template
            else:
                chat_template = self.default_chat_template
                using_default_template = True

        if using_default_template:
            logger.warning_once(
                "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
                "very error-prone, because models are often trained with templates different from the class default! "
                "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
                "point any code depending on them will stop working. We recommend setting a valid chat template before "
                "then to ensure that this model continues working without issues."
            )

        # Collect plain function definitions and detect the code interpreter.
        functions_pydantic_to_render = []
        has_code_interpreter = False
        if tools is not None:
            for item in tools:
                if "function" in item and item["function"] is not None:
                    functions_pydantic_to_render.append(item["function"])
                elif "type" in item and item["type"] == "code_interpreter":
                    has_code_interpreter = True
                else:
                    functions_pydantic_to_render.append(item)
        tools_system_message = get_system_message_for_tools(functions_pydantic_to_render, has_code_interpreter)
        # BUGFIX: insert the system message into a *copy* of the conversation so
        # the caller's list is not mutated — previously, repeated calls kept
        # prepending duplicate system messages to the caller's object. This also
        # uses the until-now unused `deepcopy` import.
        # NOTE(review): for batched input (a list of chats) this still inserts
        # the message dict at the top level, and a plain-string conversation
        # will fail on .insert — both unchanged from the original behavior;
        # confirm whether per-chat insertion is intended for batches.
        conversation = deepcopy(conversation)
        conversation.insert(0, tools_system_message)

        compiled_template = self._compile_jinja_template(chat_template)

        # Detect batched input: a list/tuple of chats, or of objects carrying
        # a `.messages` attribute.
        if isinstance(conversation, (list, tuple)) and (
            isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
        ):
            conversations = conversation
            is_batched = True
        else:
            conversations = [conversation]
            is_batched = False

        rendered = []
        # Explicit kwargs take precedence over the tokenizer's special tokens.
        template_kwargs = {**self.special_tokens_map, **kwargs}
        for chat in conversations:
            if hasattr(chat, "messages"):
                # Conversation-like object: render its message list.
                chat = chat.messages
            rendered_chat = compiled_template.render(
                messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs
            )
            rendered.append(rendered_chat)

        if not is_batched:
            rendered = rendered[0]

        if tokenize:
            out = self(
                rendered,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                add_special_tokens=False,
                return_tensors=return_tensors,
                **tokenizer_kwargs,
            )
            if return_dict:
                return out
            return out["input_ids"]
        return rendered