# Provenance (from Hugging Face file viewer): author mtasic85, commit 734e414
# ("prepare datasets"), file size 2.94 kB. The viewer chrome lines ("raw",
# "history blame") were non-Python text and have been commented out so the
# script parses.
from unsloth import FastLanguageModel
import torch
from transformers import AutoTokenizer

# --- Continued-pretraining (CPT) configuration ---
max_seq_length: int = 4096          # max context length for tokenization and training
dtype: torch.dtype = torch.bfloat16 # compute dtype passed to the model loader
load_in_4bit: bool = True           # load weights 4-bit quantized (QLoRA-style) to save VRAM
model_name: str = '../out/pretrain-base'  # path to the previously pretrained base checkpoint
output_dir: str = '../out/cpt-base'       # where trainer checkpoints/logs are written
# Load the base model (and a tokenizer) via Unsloth's fast loader.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
print('Ignore loaded tokenizer by FastLanguageModel.from_pretrained and using AutoTokenizer.from_pretrained')
# Deliberately replace the loader's tokenizer with the one found at '..'
# (the project root relative to this script's working directory — presumably
# holds a custom tokenizer; TODO confirm that path is correct at runtime).
tokenizer = AutoTokenizer.from_pretrained('..', trust_remote_code=True, use_fast=True)
print(f'{model=}')
print(f'{tokenizer=}')
# Wrap the model with LoRA adapters for continued pretraining.
# Note: embed_tokens and lm_head are included in target_modules, which is the
# Unsloth-recommended setup for continual pretraining (adapts embeddings too).
model = FastLanguageModel.get_peft_model(
    model,
    r=64,  # 128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "embed_tokens", "lm_head",
    ],  # Add for continual pretraining
    lora_alpha=16,
    lora_dropout=0,  # Supports any, but = 0 is optimized
    bias='none',  # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing='unsloth',  # True or "unsloth" for very long context
    random_state=23,
    use_rslora=True,  # We support rank stabilized LoRA
    loftq_config=None,  # And LoftQ
)
print(f'{model=}')
from datasets import concatenate_datasets
from cpt_base_datasets import cpt_base_datasets
from cpt_instruct_datasets import cpt_instruct_datasets
from unsloth_utils import load_text_dataset, load_chat_dataset

# Load every configured base (plain-text) dataset and collect them for
# concatenation. Each config dict is expanded as keyword arguments to the
# project-local loader.
core_datasets = []
for dataset_config in cpt_base_datasets:
    dataset = load_text_dataset(tokenizer, **dataset_config)
    print(f'{dataset=}')
    core_datasets.append(dataset)

# Instruct/chat datasets are currently excluded from this CPT run
# (left here as a toggle for a future combined run).
# for dataset_config in cpt_instruct_datasets:
#     dataset = load_chat_dataset(tokenizer, **dataset_config)
#     print(f'{dataset=}')
#     core_datasets.append(dataset)

# Single training dataset = concatenation of all loaded base datasets.
final_dataset = concatenate_datasets(core_datasets)
print(f'{final_dataset=}')
# NOTE(review): SFTTrainer and TrainingArguments are imported but unused below
# (UnslothTrainer / UnslothTrainingArguments are used instead) — candidates
# for removal if nothing else in the project relies on this module's imports.
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments

# Configure and run continued pretraining. Effective batch size is
# per_device_train_batch_size * gradient_accumulation_steps = 64 per device.
trainer = UnslothTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=final_dataset,
    dataset_text_field='text',  # column of final_dataset holding raw text
    max_seq_length=max_seq_length,
    dataset_num_proc=32,  # parallel workers for dataset tokenization/packing
    args=UnslothTrainingArguments(
        per_device_train_batch_size=8,
        gradient_accumulation_steps=8,
        warmup_ratio=0.1,
        num_train_epochs=1,
        learning_rate=5e-5,
        # Unsloth extension: separate (lower) LR for embed_tokens/lm_head,
        # since embeddings are in target_modules above.
        embedding_learning_rate=5e-6,
        # Pick exactly one mixed-precision mode based on hardware support.
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim='adamw_8bit',
        weight_decay=0.01,
        lr_scheduler_type='cosine',
        seed=23,
        output_dir=output_dir,
        report_to='wandb',
    ),
)
trainer_stats = trainer.train()