import trl
from transformers import (
    AutoTokenizer, LlamaConfig, LlamaForCausalLM,
    PreTrainedTokenizerFast
)
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset, Dataset
from tokenizers import ByteLevelBPETokenizer
from huggingface_hub import HfApi
from itertools import islice

from logging import getLogger, StreamHandler, INFO

logger = getLogger(__name__)
logger.setLevel(INFO)
handler = StreamHandler()
logger.addHandler(handler)

class Config:
    def __init__(self):
        # Model and training hyperparameters
        self.BATCH_SIZE = 16
        self.EPOCHS = 3
        self.LEARNING_RATE = 2e-4
        self.MAX_SEQ_LENGTH = 512
        self.VOCAB_SIZE = 32000
        self.FP16 = True
        self.WEIGHT_DECAY = 1e-3
        self.GRADIENT_ACCUMULATION_STEPS = self.BATCH_SIZE // 4
    
        # Dataset configurations
        self.INPUT_DATASET = "HuggingFaceTB/smollm-corpus"
        self.INSTRUCT_DATASET = "nroggendorff/elephant"
        self.SHARD_SIZE = int(2e+5)
    
        # Output and repo settings
        self.OUTPUT_REPO = "nroggendorff/smallama"
        self.PUSH_TO_HUB = True
        self.INSTRUCT_FINETUNE_BOOL = False
    
        # Training steps and warmup
        self.FACTOR = 12 ** 3 // 3
        self.TOTAL_STEPS = (self.SHARD_SIZE * self.EPOCHS) // (self.BATCH_SIZE * self.GRADIENT_ACCUMULATION_STEPS)
        self.WARMUP_STEPS = int(self.TOTAL_STEPS * 0.1)
    
        # Initial state for shard offset
        self.INIT = 0

        # Convenience accessor: Config().getConfig() returns the SFTConfig built below.
        self.getConfig = lambda: self._args()

    def _args(self):
        return SFTConfig(
            output_dir="model",
            num_train_epochs=self.EPOCHS,
            per_device_train_batch_size=self.BATCH_SIZE,
            learning_rate=self.LEARNING_RATE,
            warmup_steps=self.WARMUP_STEPS,
            weight_decay=self.WEIGHT_DECAY,
            gradient_accumulation_steps=self.GRADIENT_ACCUMULATION_STEPS,
            fp16=self.FP16,
            save_steps=int(self.WARMUP_STEPS * 5),
            logging_steps=int(self.WARMUP_STEPS),
            save_total_limit=2,
            report_to="none",
        )

config = Config().getConfig()
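# With the defaults above, each optimizer step sees 16 * 4 = 64 sequences, so
# TOTAL_STEPS = (200_000 * 3) // 64 = 9375, WARMUP_STEPS = 937, checkpoints are
# written every 4685 steps, and metrics are logged every 937 steps.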

class Space:
    def __init__(self):
        self.api = HfApi()
        self.pause = lambda: self.api.pause_space("nroggendorff/train-llama")

space = Space()

class FineError(Exception):
    def __init__(self, message="Training completed successfully."):
        self.message = message
        super().__init__(self.message)
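# Note: FineError is meant to signal a clean finish, but nothing in this script
# raises it; the handler at the bottom pauses the Space on any exception instead.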

def load_data(dataset_name: str, split: str, shard_size: int, init_offset: int = 0) -> Dataset:
    dataset = load_dataset(dataset_name, split=split, streaming=True)
    shard_start = init_offset * shard_size
    data_list = list(islice(dataset, shard_start, shard_start + shard_size))
    return Dataset.from_dict({'text': [example.get('text', '') for example in data_list]})
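# Usage sketch (hypothetical offset): load_data(config.INPUT_DATASET, "train",
# config.SHARD_SIZE, init_offset=2) skips the first 400_000 streamed rows and
# materializes the next 200_000 as an in-memory Dataset.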

def encode_decode(texts, tokenizer):
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenized_texts = tokenizer(
        texts, padding="max_length", truncation=True, max_length=config.MAX_SEQ_LENGTH, return_tensors="pt"
    ).input_ids
    return tokenizer.batch_decode(tokenized_texts) if tokenized_texts.dim() >= 1 else [tokenizer.pad_token * config.MAX_SEQ_LENGTH]
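# Round-trip helper (tokenize with padding/truncation, then decode); it is not
# called anywhere else in this script.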

def create_tokenizer(training_corpus):
    tokenizer = ByteLevelBPETokenizer()
    special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
    tokenizer.train_from_iterator(training_corpus, vocab_size=config.VOCAB_SIZE, min_frequency=2, special_tokens=special_tokens)
    # Register the special tokens so bos/eos/pad ids are defined for create_model() and format_prompts().
    return PreTrainedTokenizerFast(
        tokenizer_object=tokenizer._tokenizer,
        bos_token="<s>", eos_token="</s>", unk_token="<unk>",
        pad_token="<pad>", mask_token="<mask>",
    )

def load_tokenizer(repo: str):
    return AutoTokenizer.from_pretrained(repo)
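
# main() calls load_model() when resuming, but the file never defines it. A minimal
# sketch, assuming the checkpoint to resume lives in the output repo:
def load_model():
    return LlamaForCausalLM.from_pretrained(config.OUTPUT_REPO)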

def get_training_corpus(dataset):
    for i in range(0, len(dataset['text']), 1000):
        yield dataset['text'][i : i + 1000]

def format_prompts(examples, tokenizer, is_instructional):
    texts = []
    for text in examples['text']:
        if text and len(text.strip()) > 0:
            if is_instructional:
                conversation = []
                parts = text.split('<|end|>')
                for i in range(0, len(parts) - 1, 2):
                    prompt = parts[i].replace("<|user|>", "").strip()
                    response = parts[i + 1].replace("<|bot|>", "").strip()
                    conversation.append({"role": "user", "content": prompt})
                    conversation.append({"role": "assistant", "content": response})
                # apply_chat_template already returns plain text; this assumes the
                # loaded instruct tokenizer ships a chat template.
                texts.append(tokenizer.apply_chat_template(conversation, tokenize=False))
            else:
                texts.append(tokenizer.bos_token + text + tokenizer.eos_token)
    if not texts:
        raise ValueError("No valid texts found in examples for formatting.")
    return {'text': texts}

def create_model(tokenizer):
    model_config = LlamaConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=config.FACTOR,
        intermediate_size=config.FACTOR * 4,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=config.MAX_SEQ_LENGTH,
        rms_norm_eps=1e-5,
        initializer_range=0.02,
        use_cache=True,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        tie_word_embeddings=False,
    )
    return LlamaForCausalLM(model_config)
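# With FACTOR = 12**3 // 3 = 576, this yields a 12-layer model with hidden size 576,
# intermediate size 2304, and 48-dimensional attention heads (576 / 12).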

def train_model(model, tokenizer, dataset, push_to_hub, is_instructional):
    dataset = dataset.map(
        lambda examples: format_prompts(examples, tokenizer, is_instructional), 
        batched=True, 
        remove_columns=dataset.column_names
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,  # renamed to processing_class in newer trl releases
        args=config,
        train_dataset=dataset
    )
    train_result = trainer.train()

    if push_to_hub:
        repo_id = config.OUTPUT_REPO + "-it" if config.INSTRUCT_FINETUNE_BOOL else config.OUTPUT_REPO
        msg = f"Training loss: {train_result.training_loss:.4f}"
        trainer.model.push_to_hub(repo_id, commit_message=msg)
        trainer.tokenizer.push_to_hub(repo_id, commit_message=msg)
    else:
        trainer.model.save_pretrained("model")
        trainer.tokenizer.save_pretrained("tokenizer")

def main():
    dataset = load_data(config.INPUT_DATASET, "train", config.SHARD_SIZE, config.INIT)
    tokenizer = (
        load_tokenizer(config.OUTPUT_REPO)
        # reuse the pushed tokenizer whenever an existing model is loaded below;
        # a freshly trained tokenizer would not match that model's token ids
        if config.INSTRUCT_FINETUNE_BOOL or config.INIT > 0
        else create_tokenizer(get_training_corpus(dataset))
    )
    model = (
        load_model()
        if config.INSTRUCT_FINETUNE_BOOL or config.INIT > 0
        else create_model(tokenizer)
    )
    train_model(model, tokenizer, dataset, config.PUSH_TO_HUB, config.INSTRUCT_FINETUNE_BOOL)

if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logger.error(f"{type(e).__name__}: {e}")
        space.pause()