from unsloth import FastLanguageModel
import torch
from transformers import AutoTokenizer

# Continued pretraining (CPT) of the pretrained base checkpoint with QLoRA.
max_seq_length = 4096
dtype = torch.bfloat16
load_in_4bit = True
model_name = '../out/pretrain-base'
output_dir = '../out/cpt-base'

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

# Replace the tokenizer loaded by FastLanguageModel.from_pretrained with the
# project's own tokenizer.
print('Ignoring the tokenizer loaded by FastLanguageModel.from_pretrained; using AutoTokenizer.from_pretrained instead')
tokenizer = AutoTokenizer.from_pretrained('..', trust_remote_code=True, use_fast=True)

print(f'{model=}')
print(f'{tokenizer=}')
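
# Hypothetical sanity check, not in the original script: when swapping in a
# different tokenizer, the model's embedding table must cover its vocabulary.
# assert model.get_input_embeddings().weight.shape[0] >= len(tokenizer)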

# Attach LoRA adapters. embed_tokens and lm_head are included as targets so
# the embeddings can also adapt during continued pretraining.
model = FastLanguageModel.get_peft_model(
    model,
    r=64,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "embed_tokens", "lm_head",
    ],
    lora_alpha=16,
    lora_dropout=0,
    bias='none',
    use_gradient_checkpointing='unsloth',
    random_state=23,
    use_rslora=True,
    loftq_config=None,
)
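
# Optional sanity check, assuming the returned model is a standard PEFT model
# (peft.PeftModel exposes print_trainable_parameters):
# model.print_trainable_parameters()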

print(f'{model=}')

from datasets import concatenate_datasets
from cpt_base_datasets import cpt_base_datasets
from cpt_instruct_datasets import cpt_instruct_datasets
from unsloth_utils import load_text_dataset, load_chat_dataset

core_datasets = []

for dataset_config in cpt_base_datasets:
    dataset = load_text_dataset(tokenizer, **dataset_config)
    print(f'{dataset=}')
    core_datasets.append(dataset)
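
# Hedged sketch, not exercised in the original run: cpt_instruct_datasets and
# load_chat_dataset are imported above, so instruction-formatted corpora could
# be mixed in the same way, assuming load_chat_dataset mirrors
# load_text_dataset's signature:
# for dataset_config in cpt_instruct_datasets:
#     dataset = load_chat_dataset(tokenizer, **dataset_config)
#     print(f'{dataset=}')
#     core_datasets.append(dataset)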

final_dataset = concatenate_datasets(core_datasets)
print(f'{final_dataset=}')

from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments

trainer = UnslothTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=final_dataset,
    dataset_text_field='text',
    max_seq_length=max_seq_length,
    dataset_num_proc=32,
    args=UnslothTrainingArguments(
        # Effective batch size: 8 * 8 = 64 sequences per optimizer step.
        per_device_train_batch_size=8,
        gradient_accumulation_steps=8,

        warmup_ratio=0.1,
        num_train_epochs=1,

        # UnslothTrainingArguments accepts a separate, lower learning rate
        # for the embedding matrices.
        learning_rate=5e-5,
        embedding_learning_rate=5e-6,

        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim='adamw_8bit',
        weight_decay=0.01,
        lr_scheduler_type='cosine',
        seed=23,
        output_dir=output_dir,
        report_to='wandb',
    ),
)

trainer_stats = trainer.train()
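
# A typical follow-up, not part of the original script: persist the trained
# LoRA adapters and tokenizer via the standard PEFT / transformers API.
# model.save_pretrained(output_dir)
# tokenizer.save_pretrained(output_dir)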