# tangled-alpha-0.9-core/scripts/cpt_core_model_4.py
# cpt core 4 (commit 6ffe1e7)
from unsloth import FastLanguageModel
import torch
from transformers import AutoTokenizer
max_seq_length = 16385
dtype = torch.bfloat16
load_in_4bit = True
model_name = '../out/pretrain-core-3/hf'
output_dir = '../out/cpt-core-4'
dataset_input_dir = '../core-data-4-8193-16385-16385-1000/'
dataset_block_size = 16385
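# Note: dataset_block_size and max_seq_length are both 16385 (16384 + 1);
# presumably the extra token lets input/target pairs be formed by shifting,
# as in litgpt-style pretraining data. TokensLoader below yields contiguous
# pre-tokenized blocks of exactly this length.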
#
# model
#
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
dtype=dtype,
load_in_4bit=load_in_4bit,
)
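# With load_in_4bit=True the base weights are loaded 4-bit quantized
# (QLoRA-style); only the LoRA adapters added below are trained in bf16.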
print(f'{model=}')
# print('Ignoring the tokenizer loaded by FastLanguageModel.from_pretrained; using AutoTokenizer.from_pretrained instead')
# tokenizer = AutoTokenizer.from_pretrained('..', trust_remote_code=True, use_fast=True)
# print(f'{tokenizer=}')
model = FastLanguageModel.get_peft_model(
    model,
    r=256,  # LoRA rank; any value > 0 works (commonly 8, 16, 32, 64, 128)
    target_modules=[
        'q_proj', 'k_proj', 'v_proj', 'o_proj',
        'gate_proj',
        'up_proj', 'down_proj',
        # embeddings and output head are also trained for continued pretraining
        'embed_tokens', 'lm_head',
    ],
    lora_alpha=32,
    lora_dropout=0,  # any value is supported, but 0 is optimized
    bias='none',  # any value is supported, but 'none' is optimized
    # 'unsloth' checkpointing uses ~30% less VRAM and fits ~2x larger batches
    use_gradient_checkpointing='unsloth',  # True or 'unsloth' for very long context
    random_state=3407,
    use_rslora=True,  # rank-stabilized LoRA
    loftq_config=None,  # no LoftQ quantization config
)
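# Optional sanity check (commented out; assumes the returned model is a
# standard PEFT wrapper, which exposes this helper):
# model.print_trainable_parameters()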
print(f'{model=}')
from datasets import Dataset
from litdata import TokensLoader, StreamingDataset
litgpt_streaming_dataset = StreamingDataset(
input_dir=dataset_input_dir,
item_loader=TokensLoader(block_size=dataset_block_size),
)
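# Optional sanity check (commented out; assumes each item is a 1-D tensor of
# dataset_block_size token ids, which is what TokensLoader produces):
# sample = next(iter(litgpt_streaming_dataset))
# print(len(litgpt_streaming_dataset), sample.shape)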
def unsloth_generator():
    # yield one pre-tokenized block per example, under the key the trainer expects
    for batch in litgpt_streaming_dataset:
        yield {'input_ids': batch}

# train_dataset = Dataset.from_generator(unsloth_generator, streaming=True)
train_dataset = Dataset.from_generator(unsloth_generator)
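# Note: Dataset.from_generator eagerly materializes the whole stream into an
# Arrow cache on first use. datasets.IterableDataset.from_generator would keep
# it lazy instead, but map-style features (e.g. dataset_num_proc) need a
# materialized dataset.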
from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments
trainer = UnslothTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
# dataset_text_field='text',
max_seq_length=max_seq_length,
dataset_num_proc=32,
    # Note: max_steps overrides num_train_epochs; it counts optimizer steps,
    # each consuming per_device_train_batch_size * gradient_accumulation_steps blocks.
    max_steps=len(litgpt_streaming_dataset),
    packing=False,  # blocks are already a fixed 16385 tokens; packing=True mainly helps short sequences
args = UnslothTrainingArguments(
per_device_train_batch_size=16,
gradient_accumulation_steps=64,
warmup_ratio=0,
num_train_epochs=1,
        # learning_rate=5e-5,
        # embedding_learning_rate=5e-6,
        learning_rate=5e-5 * 2,
        embedding_learning_rate=5e-5 / 2,  # separate, lower LR for embed_tokens/lm_head
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=1,
optim='adamw_8bit',
weight_decay=0.01,
lr_scheduler_type='cosine',
seed=23,
output_dir=output_dir,
report_to='wandb',
),
)
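# Effective batch size: 16 per device * 64 accumulation steps = 1024 blocks,
# i.e. 1024 * 16385 ~= 16.8M tokens per optimizer step on a single device.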
trainer_stats = trainer.train()
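# A typical follow-up (sketch, not part of the original script): persist the
# adapter weights and tokenizer for later merging/export.
# model.save_pretrained(output_dir)
# tokenizer.save_pretrained(output_dir)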