lhoestq (HF Staff) committed on
Commit f2041d1 · verified · 1 Parent(s): ff4534f

Upload 3 files

Files changed (3)
  1. README.md +54 -10
  2. requirements.txt +7 -0
  3. train.py +133 -0
README.md CHANGED
@@ -1,10 +1,54 @@
- ---
- title: Test Smollm
- emoji: 🏆
- colorFrom: blue
- colorTo: gray
- sdk: docker
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Fine-tuning
+
+ ## SmolLM2 Instruct
+
+ We build the SmolLM2 Instruct family by fine-tuning the base 1.7B model on [SmolTalk](https://huggingface.co/datasets/HuggingFaceTB/smoltalk) and the base 360M and 135M models on [Smol-smoltalk](https://huggingface.co/datasets/HuggingFaceTB/smol-smoltalk) using `TRL` and the alignment handbook, and then doing DPO on [UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback). You can find the scripts and instructions for doing this here: https://github.com/huggingface/alignment-handbook/tree/main/recipes/smollm2#instructions-to-train-smollm2-17b-instruct
+
+ ## Custom script
+ Here, we provide a simple script for fine-tuning SmolLM2. In this case, we fine-tune the base 1.7B model on Python data.
+
+ ### Setup
+
+ Install `pytorch` ([see the documentation](https://pytorch.org/)), and then install the requirements:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ Before you run any of the scripts, make sure you are logged in to `wandb` and the Hugging Face Hub so you can push the checkpoints, and that you have `accelerate` configured:
+ ```bash
+ wandb login
+ huggingface-cli login
+ accelerate config
+ ```
+ Now that everything is set up, you can clone the repository and change into the corresponding directory:
+
+ ```bash
+ git clone https://github.com/huggingface/smollm
+ cd smollm/finetune
+ ```
+
+ ### Training
+ To fine-tune efficiently at a low cost, we use the [PEFT](https://github.com/huggingface/peft) library for Low-Rank Adaptation (LoRA) training. We also use the `SFTTrainer` from [TRL](https://github.com/huggingface/trl).
+
+ For this example, we will fine-tune SmolLM2-1.7B on the `Python` subset of [the-stack-smol](https://huggingface.co/datasets/bigcode/the-stack-smol). This is just for illustration purposes.
+
+ To launch the training:
+ ```bash
+ accelerate launch train.py \
+     --model_id "HuggingFaceTB/SmolLM2-1.7B" \
+     --dataset_name "bigcode/the-stack-smol" \
+     --subset "data/python" \
+     --dataset_text_field "content" \
+     --split "train" \
+     --max_seq_length 2048 \
+     --max_steps 5000 \
+     --micro_batch_size 1 \
+     --gradient_accumulation_steps 8 \
+     --learning_rate 3e-4 \
+     --warmup_steps 100 \
+     --num_proc "$(nproc)"
+ ```
+
+ If you want to fine-tune on other text datasets, change the `dataset_text_field` argument to the name of the column containing the code/text you want to train on.
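+
+ If you are not sure which column holds the text in another dataset, a quick check with `datasets` helps. A minimal sketch, using the example dataset from above (swap in your own dataset name):
+
+ ```python
+ from datasets import load_dataset
+
+ # Inspect the available columns to pick the value for --dataset_text_field.
+ ds = load_dataset("bigcode/the-stack-smol", data_dir="data/python", split="train")
+ print(ds.column_names)  # the Python subset stores the code in the "content" column
+ ```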
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers
+ trl>=0.15
+ peft
+ accelerate
+ datasets
+ wandb
+ bitsandbytes
train.py ADDED
@@ -0,0 +1,133 @@
+ # Code adapted from https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py
+ # and https://huggingface.co/blog/gemma-peft
+ import argparse
+ import multiprocessing
+ import os
+
+ import torch
+ import transformers
+ from accelerate import PartialState
+ from datasets import load_dataset
+ from peft import AutoPeftModelForCausalLM, LoraConfig
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+     is_torch_npu_available,
+     is_torch_xpu_available,
+     logging,
+     set_seed,
+ )
+ from trl import SFTConfig, SFTTrainer
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_id", type=str, default="HuggingFaceTB/SmolLM2-1.7B")
+     parser.add_argument("--tokenizer_id", type=str, default="")
+     parser.add_argument("--dataset_name", type=str, default="bigcode/the-stack-smol")
+     parser.add_argument("--subset", type=str, default="data/python")
+     parser.add_argument("--split", type=str, default="train")
+     parser.add_argument("--streaming", type=bool, default=False)
+     parser.add_argument("--dataset_text_field", type=str, default="content")
+
+     parser.add_argument("--max_seq_length", type=int, default=2048)
+     parser.add_argument("--max_steps", type=int, default=1000)
+     parser.add_argument("--micro_batch_size", type=int, default=1)
+     parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
+     parser.add_argument("--weight_decay", type=float, default=0.01)
+     parser.add_argument("--bf16", type=bool, default=True)
+
+     parser.add_argument("--use_bnb", type=bool, default=False)
+     parser.add_argument("--attention_dropout", type=float, default=0.1)
+     parser.add_argument("--learning_rate", type=float, default=2e-4)
+     parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
+     parser.add_argument("--warmup_steps", type=int, default=100)
+     parser.add_argument("--seed", type=int, default=0)
+     parser.add_argument("--output_dir", type=str, default="finetune_smollm2_python")
+     parser.add_argument("--num_proc", type=int, default=None)
+     parser.add_argument("--push_to_hub", type=bool, default=True)
+     parser.add_argument("--repo_id", type=str, default="SmolLM2-1.7B-finetune")
+     return parser.parse_args()
+
+
+ def main(args):
+     # config
+     lora_config = LoraConfig(
+         r=16,
+         lora_alpha=32,
+         lora_dropout=0.05,
+         target_modules=["q_proj", "v_proj"],
+         bias="none",
+         task_type="CAUSAL_LM",
+     )
+     bnb_config = None
+     if args.use_bnb:
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.bfloat16,
+         )
+     # load model and dataset
+     token = os.environ.get("HF_TOKEN", None)
+     model = AutoModelForCausalLM.from_pretrained(
+         args.model_id,
+         quantization_config=bnb_config,
+         device_map={"": PartialState().process_index},
+         attention_dropout=args.attention_dropout,
+     )
+     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id or args.model_id)
+
+     data = load_dataset(
+         args.dataset_name,
+         data_dir=args.subset,
+         split=args.split,
+         token=token,
+         num_proc=args.num_proc if args.num_proc or args.streaming else multiprocessing.cpu_count(),
+         streaming=args.streaming,
+     )
+
+     # setup the trainer
+     trainer = SFTTrainer(
+         model=model,
+         processing_class=tokenizer,
+         train_dataset=data,
+         args=SFTConfig(
+             dataset_text_field=args.dataset_text_field,
+             dataset_num_proc=args.num_proc,
+             max_seq_length=args.max_seq_length,
+             per_device_train_batch_size=args.micro_batch_size,
+             gradient_accumulation_steps=args.gradient_accumulation_steps,
+             warmup_steps=args.warmup_steps,
+             max_steps=args.max_steps,
+             learning_rate=args.learning_rate,
+             lr_scheduler_type=args.lr_scheduler_type,
+             weight_decay=args.weight_decay,
+             bf16=args.bf16,
+             logging_strategy="steps",
+             logging_steps=10,
+             output_dir=args.output_dir,
+             optim="paged_adamw_8bit",
+             seed=args.seed,
+             run_name=f"train-{args.model_id.split('/')[-1]}",
+             report_to="wandb",
+             push_to_hub=args.push_to_hub,
+             hub_model_id=args.repo_id,
+         ),
+         peft_config=lora_config,
+     )
+
+     # launch
+     print("Training...")
+     trainer.train()
+     print("Training Done! 💥")
+
+
+ if __name__ == "__main__":
+     args = get_args()
+     set_seed(args.seed)
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     logging.set_verbosity_error()
+
+     main(args)
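
Since the script trains with a LoRA `peft_config` and is configured to push to the Hub (`push_to_hub`, `hub_model_id`), the saved checkpoints contain the adapter weights rather than a full model. A minimal sketch of loading the trained adapter for generation, assuming it was saved under the default `output_dir`; the local path and prompt are illustrative:

```python
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# Load the LoRA adapter; the base model referenced in the adapter config is fetched automatically.
adapter_path = "finetune_smollm2_python"  # or a checkpoint subfolder, or your Hub repo id
model = AutoPeftModelForCausalLM.from_pretrained(adapter_path, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B")

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```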