File size: 6,807 Bytes
86b5e8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
# flake8: noqa
"""
pip install -U transformers accelerate trl wandb wheel packaging peft bitsandbytes liger-kernel flash_attn
python sft.py \
--run_name="llama3.1-8b-continued2" \
--model_name_or_path="meta-llama/Meta-Llama-3.1-8B" \
--dataset_name="mlfoundations/dclm-baseline-1.0-parquet,mlabonne/FineTome-100k" \
--report_to="wandb" \
--optim="adamw_torch_fused" \
--lr_scheduler_type="cosine" \
--max_steps=10000000 \
--max_seq_length=64000 \
--learning_rate=0.0001 \
--attn_implementation="flash_attention_2" \
--save_strategy="steps" \
--save_steps 50 \
--save_total_limit=10 \
--per_device_train_batch_size=1 \
--gradient_accumulation_steps=8 \
--logging_steps=1 \
--num_train_epochs=1 \
--load_in_4bit \
--push_to_hub \
--hub_model_id="ericflo/Llama-3.1-8B-ContinuedTraining2-LoRA" \
--hub_strategy="all_checkpoints" \
--gradient_checkpointing \
--use_peft \
--lora_r=128 \
--lora_alpha=256 \
--lora_dropout=0.05 \
--use_liger=true \
--packing=true \
--torch_dtype="bfloat16" \
--output_dir="continuedtraining2_output"
"""
import logging
import os
import random
from contextlib import nullcontext
from trl.commands.cli_utils import init_zero_verbose, SFTScriptArguments, TrlParser
from trl.env_utils import strtobool
# Opt into rich-powered console output via the TRL_USE_RICH env var (off by default).
TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0"))
if TRL_USE_RICH:
    # Silence TRL's default verbose logging before rich takes over the console.
    init_zero_verbose()
    # Minimal log format: RichHandler supplies its own timestamps/levels.
    FORMAT = "%(message)s"
    # Imported lazily so `rich` is only required when explicitly enabled.
    from rich.console import Console
    from rich.logging import RichHandler
import torch
from datasets import load_dataset, interleave_datasets
from tqdm.rich import tqdm
from transformers import AutoTokenizer
from trl import (
ModelConfig,
RichProgressCallback,
SFTConfig,
SFTTrainer,
get_peft_config,
get_quantization_config,
get_kbit_device_map,
)
# Enable tqdm's pandas integration (progress bars for .progress_apply).
tqdm.pandas()
if TRL_USE_RICH:
    logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)
# Two chat-template tokenizers used by formatting_prompts_func: one Llama-3
# instruct format, one ChatML (Hermes). Downloaded from the HF hub at import
# time, so this module requires network access (and auth for the Llama repo).
print("Loading tokenizers...")
METAML_TOK = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
CHATML_TOK = AutoTokenizer.from_pretrained("NousResearch/Hermes-3-Llama-3.1-8B")
print("Tokenizers loaded.")
def formatting_prompts_func(example):
    """Render one dataset row into a single training-text string.

    Dispatches on which fields of *example* carry values, covering the
    interleaved datasets used below:

    - ``language``/``url``/``text``   -> mlfoundations/dclm-baseline-1.0-parquet
    - ``title``/``url``/``text``      -> wikimedia/wikipedia
    - ``conversations``               -> mlabonne/FineTome-100k (chat template)
    - ``max_stars_repo_*``/``content``-> bigcode/starcoderdata

    Args:
        example: one dataset row as a dict-like mapping.

    Returns:
        The formatted text for this row.

    Raises:
        ValueError: if the row matches none of the known schemas.
    """
    language = example.get("language")
    url = example.get("url")
    text = example.get("text")
    title = example.get("title")
    conversations = example.get("conversations")
    source = example.get("source")
    repo_name = example.get("max_stars_repo_name")
    repo_path = example.get("max_stars_repo_path")
    star_count = example.get("max_stars_count")
    content = example.get("content")

    if language and url and text:  # mlfoundations/dclm-baseline-1.0-parquet
        return f'{language} {url} {text}'
    if title and url and text:  # wikimedia/wikipedia
        return f'{title} {url} {text}'
    if conversations:  # mlabonne/FineTome-100k
        role_map = {"system": "system", "gpt": "assistant", "human": "user"}
        rows = [
            {"role": role_map[turn["from"]], "content": turn["value"]}
            for turn in conversations
        ]
        # Randomly alternate chat formats so the model sees both templates.
        tok = random.choice([METAML_TOK, CHATML_TOK])
        return f'{source} {tok.apply_chat_template(rows, tokenize=False)}'
    # bigcode/starcoderdata. Check the *values*, not mere key presence:
    # interleaved/unified schemas can expose these columns as None on rows
    # from other datasets, which previously produced "None None None None".
    if repo_name is not None and content is not None:
        return f'{repo_name} {repo_path} {star_count} {content}'
    raise ValueError(f"Unknown example: {example}")
if __name__ == "__main__":
parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()
# Force use our print callback
if TRL_USE_RICH:
training_args.disable_tqdm = True
console = Console()
################
# Model init kwargs & Tokenizer
################
model_config.lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
attn_implementation=model_config.attn_implementation,
torch_dtype=model_config.torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
training_args.model_init_kwargs = model_kwargs
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
)
tokenizer.pad_token = tokenizer.eos_token
################
# Dataset
################
dataset_names = args.dataset_name.split(',')
train_datasets = [load_dataset(name, split="train", streaming=True) for name in dataset_names]
train_datasets.append(load_dataset("bigcode/starcoderdata", data_dir="python", split="train", streaming=True))
train_datasets.append(load_dataset("wikimedia/wikipedia", "20231101.en", split="train", streaming=True))
train_datasets.append(load_dataset("wikimedia/wikipedia", "20231101.es", split="train", streaming=True))
train_datasets.append(load_dataset("wikimedia/wikipedia", "20231101.fr", split="train", streaming=True))
interleaved_dataset = interleave_datasets(train_datasets)
eval_dataset = interleaved_dataset.take(100)
train_dataset = interleaved_dataset.skip(100)
print(train_dataset)
print(eval_dataset)
################
# Optional rich context managers
###############
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...")
save_context = (
nullcontext()
if not TRL_USE_RICH
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
)
################
# Training
################
with init_context:
trainer = SFTTrainer(
model=model_config.model_name_or_path,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
formatting_func=formatting_prompts_func,
)
trainer.train()
with save_context:
trainer.save_model(training_args.output_dir) |