File size: 2,938 Bytes
734e414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from unsloth import FastLanguageModel
import torch
from transformers import AutoTokenizer

# Checkpoint the run starts from, and the directory handed to the trainer
# as its output_dir.
model_name: str = '../out/pretrain-base'
output_dir: str = '../out/cpt-base'

# Model-loading options for FastLanguageModel.from_pretrained
# (max_seq_length is also passed to the trainer).
load_in_4bit: bool = True
dtype = torch.bfloat16
max_seq_length: int = 4096

def _load_model_and_tokenizer():
    """Load the base checkpoint through Unsloth and pair it with the project tokenizer.

    The tokenizer returned by ``FastLanguageModel.from_pretrained`` is
    deliberately discarded. The tokenizer stored in the parent directory
    ('..') is loaded with ``AutoTokenizer`` and used instead.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Local names stay `model` / `tokenizer`: the f-string `=` specifier
    # prints the expression text, so renaming would change the log output.
    model, _discarded_tokenizer = FastLanguageModel.from_pretrained(
        load_in_4bit=load_in_4bit,
        dtype=dtype,
        max_seq_length=max_seq_length,
        model_name=model_name,
    )

    print('Ignore loaded tokenizer by FastLanguageModel.from_pretrained and using AutoTokenizer.from_pretrained')
    tokenizer = AutoTokenizer.from_pretrained(
        '..',
        use_fast=True,
        trust_remote_code=True,
    )

    print(f'{model=}')
    print(f'{tokenizer=}')
    return model, tokenizer


model, tokenizer = _load_model_and_tokenizer()

def _attach_lora_adapters(base_model):
    """Wrap ``base_model`` with LoRA adapters via Unsloth's ``get_peft_model``.

    Besides the projection layers, ``embed_tokens`` and ``lm_head`` are also
    targeted, which the original setup added for continual pretraining.

    Returns:
        The adapter-wrapped model produced by ``get_peft_model``.
    """
    lora_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        # Embedding and output head: included for continual pretraining.
        "embed_tokens", "lm_head",
    ]
    return FastLanguageModel.get_peft_model(
        base_model,
        target_modules=lora_targets,
        r=64,  # LoRA rank
        lora_alpha=16,
        use_rslora=True,  # rank-stabilised LoRA scaling
        # 0 dropout / no bias: described by Unsloth as its optimised settings.
        lora_dropout=0,
        bias='none',
        # Unsloth's own gradient-checkpointing mode, intended for long contexts.
        use_gradient_checkpointing='unsloth',
        loftq_config=None,  # LoftQ initialisation disabled
        random_state=23,
    )


model = _attach_lora_adapters(model)
print(f'{model=}')

from datasets import concatenate_datasets
from cpt_base_datasets import cpt_base_datasets
from cpt_instruct_datasets import cpt_instruct_datasets
from unsloth_utils import load_text_dataset, load_chat_dataset

# Set to True to also mix the chat/instruct corpora (cpt_instruct_datasets,
# loaded with load_chat_dataset) into this continual-pretraining run.
# Replaces a previously commented-out loop; False keeps base-text-only training.
include_instruct_datasets = False

core_datasets = []

# Plain-text corpora, each tokenized/prepared by load_text_dataset.
for dataset_config in cpt_base_datasets:
    dataset = load_text_dataset(tokenizer, **dataset_config)
    print(f'{dataset=}')
    core_datasets.append(dataset)

if include_instruct_datasets:
    for dataset_config in cpt_instruct_datasets:
        dataset = load_chat_dataset(tokenizer, **dataset_config)
        print(f'{dataset=}')
        core_datasets.append(dataset)

# Fail with a script-specific message instead of the generic error
# concatenate_datasets gives for an empty list.
if not core_datasets:
    raise ValueError(
        'No training datasets were loaded: cpt_base_datasets is empty '
        'and include_instruct_datasets is disabled or empty.'
    )

final_dataset = concatenate_datasets(core_datasets)
print(f'{final_dataset=}')


from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments


import os

trainer = UnslothTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=final_dataset,
    dataset_text_field='text',
    max_seq_length=max_seq_length,
    # Keep the original ceiling of 32 preprocessing workers, but never spawn
    # more than this host has CPUs.
    dataset_num_proc=min(32, os.cpu_count() or 1),

    args=UnslothTrainingArguments(
        # Per device, each optimizer step sees 8 * 8 = 64 sequences.
        per_device_train_batch_size=8,
        gradient_accumulation_steps=8,

        warmup_ratio=0.1,
        num_train_epochs=1,

        learning_rate=5e-5,
        # Separate, 10x lower LR; presumably applied to the trained
        # embed_tokens/lm_head modules (Unsloth CPT option) - confirm in docs.
        embedding_learning_rate=5e-6,

        # Prefer bf16 mixed precision when the device supports it, else fp16.
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim='adamw_8bit',
        weight_decay=0.01,
        lr_scheduler_type='cosine',
        seed=23,
        output_dir=output_dir,
        report_to='wandb',
    ),
)

trainer_stats = trainer.train()
print(f'{trainer_stats=}')

# Persist the final adapter weights and the tokenizer actually used for
# training. Without this, only the intermediate Trainer checkpoints (if any)
# would be left in output_dir.
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)