Spaces:
Paused
Paused
| # Copyright 2020-2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import logging | |
| from typing import Callable, Literal, Optional | |
| import datasets | |
| from datasets import Dataset, Value | |
| from packaging import version | |
| from transformers import AutoTokenizer | |
# Feature schema of a single chat message: a string `content` and a string `role`.
_chatml_message = {"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}

# With `datasets` >= 4.0.0 the sequence of messages is described with the `List` feature type
# (only imported in that case); with older versions a one-element Python list is used instead.
if version.parse(datasets.__version__) >= version.parse("4.0.0"):
    from datasets import List

    _chatml_feature = List(_chatml_message)
else:
    _chatml_feature = [_chatml_message]

# Maps a supported dataset format name to the feature schema that identifies it.
FORMAT_MAPPING = {
    "chatml": _chatml_feature,
    "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
}
def conversations_formatting_function(
    tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"], tools: Optional[list] = None
):
    r"""
    Build a formatting function for conversational data stored under `messages_field`.

    The returned callable renders the conversation(s) with
    `tokenizer.apply_chat_template(..., tokenize=False, tools=tools)`.

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose `apply_chat_template` method renders the conversations.
        messages_field (str): Key holding the conversation(s), either `"messages"` or `"conversations"`.
        tools (list, *optional*): Forwarded unchanged as the `tools` argument of `apply_chat_template`.

    Returns:
        Callable: Function taking `examples`. When `examples[messages_field]` is a list of lists (a batch of
        conversations) it returns a list with one rendered result per conversation; otherwise the field is
        treated as a single conversation and the result of one `apply_chat_template` call is returned.
    """

    def format_dataset(examples):
        conversations = examples[messages_field]

        # A non-list first element means the field holds one conversation rather than a batch of them.
        if not isinstance(conversations[0], list):
            return tokenizer.apply_chat_template(conversations, tokenize=False, tools=tools)

        return [
            tokenizer.apply_chat_template(conversation, tokenize=False, tools=tools)
            for conversation in conversations
        ]

    return format_dataset
def instructions_formatting_function(tokenizer: AutoTokenizer):
    r"""
    Build a formatting function for instruction data made of `"prompt"` and `"completion"` fields.

    Each prompt/completion pair is turned into a two-message conversation (the prompt as the `user` turn, the
    completion as the `assistant` turn) and rendered with `tokenizer.apply_chat_template(..., tokenize=False)`.

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose `apply_chat_template` method renders the conversations.

    Returns:
        Callable: Function taking `examples`. When `examples["prompt"]` is a list it returns a list with one
        rendered result per prompt, pairing each prompt with the completion at the same index; otherwise it
        returns the rendered result for the single prompt/completion pair.
    """

    def _as_conversation(prompt, completion):
        # Express one prompt/completion pair as a user turn followed by an assistant turn.
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": completion},
        ]

    def format_dataset(examples):
        prompts = examples["prompt"]

        if not isinstance(prompts, list):
            single = _as_conversation(prompts, examples["completion"])
            return tokenizer.apply_chat_template(single, tokenize=False)

        # Completions are looked up by index so that each one is only read when its prompt is processed.
        return [
            tokenizer.apply_chat_template(_as_conversation(prompt, examples["completion"][idx]), tokenize=False)
            for idx, prompt in enumerate(prompts)
        ]

    return format_dataset
def get_formatting_func_from_dataset(
    dataset: Dataset, tokenizer: AutoTokenizer, tools: Optional[list] = None
) -> Optional[Callable]:
    r"""
    Pick the formatting function that matches the structure of `dataset`.

    Supported layouts:
    - `ChatML`: a `"messages"` or `"conversations"` column with [{"role": str, "content": str}]
    - `instruction`: a dataset whose features are exactly {"prompt": str, "completion": str}

    Args:
        dataset (Dataset): User dataset
        tokenizer (AutoTokenizer): Tokenizer used for formatting
        tools (list, *optional*): Passed on to the conversational formatting function; not used for the
            instruction format.

    Returns:
        Callable: Formatting function if the dataset format is supported else None
    """
    if not isinstance(dataset, Dataset):
        return None

    # "messages" takes precedence over "conversations" when both columns match the chatml schema.
    for field in ("messages", "conversations"):
        if field in dataset.features and dataset.features[field] == FORMAT_MAPPING["chatml"]:
            logging.info("Formatting dataset with chatml format")
            return conversations_formatting_function(tokenizer, field, tools)

    # The instruction schema is only considered when there is no "conversations" column.
    if "conversations" not in dataset.features and dataset.features == FORMAT_MAPPING["instruction"]:
        logging.info("Formatting dataset with instruction format")
        return instructions_formatting_function(tokenizer)

    return None