feat: Add LLaMA-3 instruct prompt strategies for fine-tuning (#1553)
Browse files* Add prompt strategies
* Update modified URL
* Update modified URL
* Update fastchat_conversation_turns.py
* Update register function
* Remove extra /n for system prompt
* Fix return
* Fix BOS
* Update requirements, pylint
* Linting
* Linting
* fix tuples, make sure to set system message in template
* tests for llama3 tokenization
* fix conditionals for loading chat template
---------
Co-authored-by: Ram <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
- src/axolotl/cli/preprocess.py +20 -8
- src/axolotl/cli/train.py +12 -1
- src/axolotl/prompt_strategies/sharegpt.py +18 -2
- src/axolotl/prompters.py +1 -0
- src/axolotl/utils/chat_templates.py +1 -0
- src/axolotl/utils/config/models/input/v0_4_1/__init__.py +1 -0
- tests/prompt_strategies/test_sharegpt.py +49 -1
src/axolotl/cli/preprocess.py
CHANGED
|
@@ -19,7 +19,10 @@ from axolotl.cli import (
|
|
| 19 |
)
|
| 20 |
from axolotl.common.cli import PreprocessCliArgs
|
| 21 |
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
| 22 |
-
from axolotl.prompt_strategies.sharegpt import
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
LOG = logging.getLogger("axolotl.cli.preprocess")
|
| 25 |
|
|
@@ -36,13 +39,22 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|
| 36 |
return_remaining_strings=True
|
| 37 |
)
|
| 38 |
|
| 39 |
-
if parsed_cfg.chat_template == "chatml"
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
if not parsed_cfg.dataset_prepared_path:
|
| 48 |
msg = (
|
|
|
|
| 19 |
)
|
| 20 |
from axolotl.common.cli import PreprocessCliArgs
|
| 21 |
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
| 22 |
+
from axolotl.prompt_strategies.sharegpt import (
|
| 23 |
+
register_chatml_template,
|
| 24 |
+
register_llama3_template,
|
| 25 |
+
)
|
| 26 |
|
| 27 |
LOG = logging.getLogger("axolotl.cli.preprocess")
|
| 28 |
|
|
|
|
| 39 |
return_remaining_strings=True
|
| 40 |
)
|
| 41 |
|
| 42 |
+
if parsed_cfg.chat_template == "chatml":
|
| 43 |
+
if parsed_cfg.default_system_message:
|
| 44 |
+
LOG.info(
|
| 45 |
+
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
| 46 |
+
)
|
| 47 |
+
register_chatml_template(parsed_cfg.default_system_message)
|
| 48 |
+
else:
|
| 49 |
+
register_chatml_template()
|
| 50 |
+
elif parsed_cfg.chat_template == "llama3":
|
| 51 |
+
if parsed_cfg.default_system_message:
|
| 52 |
+
LOG.info(
|
| 53 |
+
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
|
| 54 |
+
)
|
| 55 |
+
register_llama3_template(parsed_cfg.default_system_message)
|
| 56 |
+
else:
|
| 57 |
+
register_llama3_template()
|
| 58 |
|
| 59 |
if not parsed_cfg.dataset_prepared_path:
|
| 60 |
msg = (
|
src/axolotl/cli/train.py
CHANGED
|
@@ -19,7 +19,10 @@ from axolotl.cli import (
|
|
| 19 |
print_axolotl_text_art,
|
| 20 |
)
|
| 21 |
from axolotl.common.cli import TrainerCliArgs
|
| 22 |
-
from axolotl.prompt_strategies.sharegpt import
|
|
|
|
|
|
|
|
|
|
| 23 |
from axolotl.train import train
|
| 24 |
|
| 25 |
LOG = logging.getLogger("axolotl.cli.train")
|
|
@@ -47,6 +50,14 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|
| 47 |
else:
|
| 48 |
register_chatml_template()
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
if cfg.rl: # and cfg.rl != "orpo":
|
| 51 |
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
| 52 |
else:
|
|
|
|
| 19 |
print_axolotl_text_art,
|
| 20 |
)
|
| 21 |
from axolotl.common.cli import TrainerCliArgs
|
| 22 |
+
from axolotl.prompt_strategies.sharegpt import (
|
| 23 |
+
register_chatml_template,
|
| 24 |
+
register_llama3_template,
|
| 25 |
+
)
|
| 26 |
from axolotl.train import train
|
| 27 |
|
| 28 |
LOG = logging.getLogger("axolotl.cli.train")
|
|
|
|
| 50 |
else:
|
| 51 |
register_chatml_template()
|
| 52 |
|
| 53 |
+
if cfg.chat_template == "llama3" and cfg.default_system_message:
|
| 54 |
+
LOG.info(
|
| 55 |
+
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
|
| 56 |
+
)
|
| 57 |
+
register_llama3_template(cfg.default_system_message)
|
| 58 |
+
else:
|
| 59 |
+
register_llama3_template()
|
| 60 |
+
|
| 61 |
if cfg.rl: # and cfg.rl != "orpo":
|
| 62 |
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
| 63 |
else:
|
src/axolotl/prompt_strategies/sharegpt.py
CHANGED
|
@@ -22,7 +22,7 @@ def register_chatml_template(system_message=None):
|
|
| 22 |
name="chatml",
|
| 23 |
system_template="<|im_start|>system\n{system_message}",
|
| 24 |
system_message=system_message,
|
| 25 |
-
roles=
|
| 26 |
sep_style=SeparatorStyle.CHATML,
|
| 27 |
sep="<|im_end|>",
|
| 28 |
)
|
|
@@ -32,13 +32,29 @@ def register_chatml_template(system_message=None):
|
|
| 32 |
name="chatml_glaive",
|
| 33 |
system_template="<|im_start|>system\n{system_message}",
|
| 34 |
system_message=system_message,
|
| 35 |
-
roles=
|
| 36 |
sep_style=SeparatorStyle.CHATML,
|
| 37 |
sep="<|im_end|>",
|
| 38 |
)
|
| 39 |
)
|
| 40 |
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
def build_loader(
|
| 43 |
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
|
| 44 |
prompter_cls: Type["ShareGPTPrompterV2"],
|
|
|
|
| 22 |
name="chatml",
|
| 23 |
system_template="<|im_start|>system\n{system_message}",
|
| 24 |
system_message=system_message,
|
| 25 |
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
| 26 |
sep_style=SeparatorStyle.CHATML,
|
| 27 |
sep="<|im_end|>",
|
| 28 |
)
|
|
|
|
| 32 |
name="chatml_glaive",
|
| 33 |
system_template="<|im_start|>system\n{system_message}",
|
| 34 |
system_message=system_message,
|
| 35 |
+
roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
|
| 36 |
sep_style=SeparatorStyle.CHATML,
|
| 37 |
sep="<|im_end|>",
|
| 38 |
)
|
| 39 |
)
|
| 40 |
|
| 41 |
|
| 42 |
+
def register_llama3_template(system_message=None):
|
| 43 |
+
system_message = system_message or "You are a helpful assistant."
|
| 44 |
+
register_conv_template(
|
| 45 |
+
Conversation(
|
| 46 |
+
name="llama3",
|
| 47 |
+
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
|
| 48 |
+
system_message=system_message,
|
| 49 |
+
roles=("user", "assistant"),
|
| 50 |
+
sep_style=SeparatorStyle.LLAMA3,
|
| 51 |
+
sep="",
|
| 52 |
+
stop_str="<|eot_id|>",
|
| 53 |
+
stop_token_ids=[128001, 128009],
|
| 54 |
+
)
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
def build_loader(
|
| 59 |
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
|
| 60 |
prompter_cls: Type["ShareGPTPrompterV2"],
|
src/axolotl/prompters.py
CHANGED
|
@@ -263,6 +263,7 @@ CONVERSATION_ROLE_FORMAT = {
|
|
| 263 |
"chatml": "<|im_start|>{ROLE}",
|
| 264 |
"zephyr": "<|{ROLE}|>",
|
| 265 |
"vicuna_v1.1": "{ROLE}",
|
|
|
|
| 266 |
}
|
| 267 |
|
| 268 |
|
|
|
|
| 263 |
"chatml": "<|im_start|>{ROLE}",
|
| 264 |
"zephyr": "<|{ROLE}|>",
|
| 265 |
"vicuna_v1.1": "{ROLE}",
|
| 266 |
+
"llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
|
| 267 |
}
|
| 268 |
|
| 269 |
|
src/axolotl/utils/chat_templates.py
CHANGED
|
@@ -24,6 +24,7 @@ def chat_templates(user_choice: str):
|
|
| 24 |
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
| 25 |
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
| 26 |
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
|
|
|
| 27 |
}
|
| 28 |
|
| 29 |
if user_choice in templates:
|
|
|
|
| 24 |
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
| 25 |
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
| 26 |
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
| 27 |
+
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ eos_token }}{% endif %}",
|
| 28 |
}
|
| 29 |
|
| 30 |
if user_choice in templates:
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
|
@@ -143,6 +143,7 @@ class ChatTemplate(str, Enum):
|
|
| 143 |
inst = "inst" # pylint: disable=invalid-name
|
| 144 |
gemma = "gemma" # pylint: disable=invalid-name
|
| 145 |
cohere = "cohere" # pylint: disable=invalid-name
|
|
|
|
| 146 |
|
| 147 |
|
| 148 |
class LoftQConfig(BaseModel):
|
|
|
|
| 143 |
inst = "inst" # pylint: disable=invalid-name
|
| 144 |
gemma = "gemma" # pylint: disable=invalid-name
|
| 145 |
cohere = "cohere" # pylint: disable=invalid-name
|
| 146 |
+
llama3 = "llama3" # pylint: disable=invalid-name
|
| 147 |
|
| 148 |
|
| 149 |
class LoftQConfig(BaseModel):
|
tests/prompt_strategies/test_sharegpt.py
CHANGED
|
@@ -12,10 +12,12 @@ from axolotl.prompt_strategies.sharegpt import (
|
|
| 12 |
GlaiveShareGPTPromptTokenizingStrategy,
|
| 13 |
SimpleShareGPTPromptTokenizingStrategy,
|
| 14 |
register_chatml_template,
|
|
|
|
| 15 |
)
|
| 16 |
from axolotl.prompters import ShareGPTPrompterV2
|
| 17 |
|
| 18 |
register_chatml_template()
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
@pytest.fixture(name="sharegpt_dataset")
|
|
@@ -115,7 +117,53 @@ def fixture_tokenizer():
|
|
| 115 |
return tokenizer
|
| 116 |
|
| 117 |
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
"""
|
| 120 |
Test class for sharegpt prompter
|
| 121 |
"""
|
|
|
|
| 12 |
GlaiveShareGPTPromptTokenizingStrategy,
|
| 13 |
SimpleShareGPTPromptTokenizingStrategy,
|
| 14 |
register_chatml_template,
|
| 15 |
+
register_llama3_template,
|
| 16 |
)
|
| 17 |
from axolotl.prompters import ShareGPTPrompterV2
|
| 18 |
|
| 19 |
register_chatml_template()
|
| 20 |
+
register_llama3_template()
|
| 21 |
|
| 22 |
|
| 23 |
@pytest.fixture(name="sharegpt_dataset")
|
|
|
|
| 117 |
return tokenizer
|
| 118 |
|
| 119 |
|
| 120 |
+
@pytest.fixture(name="llama3_tokenizer")
|
| 121 |
+
def fixture_llama3_tokenizer():
|
| 122 |
+
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
| 123 |
+
tokenizer.eos_token = "<|eot_id|>"
|
| 124 |
+
|
| 125 |
+
return tokenizer
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class TestSharegptLlama3:
|
| 129 |
+
"""Test class for ShareGPT style datasets with llama-3 prompts"""
|
| 130 |
+
|
| 131 |
+
def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
|
| 132 |
+
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
| 133 |
+
ShareGPTPrompterV2(
|
| 134 |
+
conversation="llama3",
|
| 135 |
+
role_key_model=None,
|
| 136 |
+
role_key_human=None,
|
| 137 |
+
),
|
| 138 |
+
llama3_tokenizer,
|
| 139 |
+
False, # train_on_inputs
|
| 140 |
+
2048, # sequence_len
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
dataset_wrapper = TokenizedPromptDataset(
|
| 144 |
+
strategy, sharegpt_dataset, process_count=1
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
input_ids = dataset_wrapper[0]["input_ids"]
|
| 148 |
+
|
| 149 |
+
# fmt: off
|
| 150 |
+
assert input_ids == [
|
| 151 |
+
128000, # bos
|
| 152 |
+
128006, 9125, 128007, # system header
|
| 153 |
+
271, 31724, 128009, # sys prompt, eot
|
| 154 |
+
128006, 882, 128007, # user header
|
| 155 |
+
271, 15339, 128009, # user prompt eot
|
| 156 |
+
128006, 78191, 128007, # assistant header
|
| 157 |
+
271, 15339, 128009, # assistant response eot
|
| 158 |
+
128006, 882, 128007,
|
| 159 |
+
271, 19045, 29474, 128009,
|
| 160 |
+
128006, 78191, 128007,
|
| 161 |
+
271, 19045, 29474, 128009,
|
| 162 |
+
]
|
| 163 |
+
# fmt: on
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class TestSharegptChatML:
|
| 167 |
"""
|
| 168 |
Test class for sharegpt prompter
|
| 169 |
"""
|