# Code adapted from https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py
# and https://huggingface.co/blog/gemma-peft
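# Supervised fine-tuning (SFT) of SmolLM2 on the Python subset of the-stack-smol,
# training LoRA adapters (optionally on a 4-bit quantized base model) with TRL's SFTTrainer.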
import argparse
import multiprocessing
import os

import torch
import transformers
from accelerate import PartialState
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    is_torch_npu_available,
    is_torch_xpu_available,
    logging,
    set_seed,
)
from trl import SFTConfig, SFTTrainer

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="HuggingFaceTB/SmolLM2-1.7B")
    parser.add_argument("--tokenizer_id", type=str, default="")
    parser.add_argument("--dataset_name", type=str, default="bigcode/the-stack-smol")
    parser.add_argument("--subset", type=str, default="data/python")
    parser.add_argument("--split", type=str, default="train")
    # Boolean flags use BooleanOptionalAction (--flag / --no-flag); plain type=bool
    # would treat any non-empty string, including "False", as True.
    parser.add_argument("--streaming", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--dataset_text_field", type=str, default="content")
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--max_steps", type=int, default=1000)
    parser.add_argument("--micro_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--use_bnb", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--attention_dropout", type=float, default=0.1)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--output_dir", type=str, default="finetune_smollm2_python")
    parser.add_argument("--num_proc", type=int, default=None)
    parser.add_argument("--push_to_hub", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--repo_id", type=str, default="SmolLM2-1.7B-finetune")
    return parser.parse_args()
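
# Example launch (illustrative only; the script filename and flag values are assumptions):
#   accelerate launch finetune.py --model_id HuggingFaceTB/SmolLM2-1.7B --subset data/python --max_steps 1000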

def main(args):
    # config
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
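    # With r=16 and lora_alpha=32, LoRA updates are scaled by lora_alpha / r = 2;
    # only the attention query/value projections receive trainable adapters, so the
    # base model's weights stay frozen during training.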
    bnb_config = None
    if args.use_bnb:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
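    # QLoRA-style option: with --use_bnb, the frozen base weights are loaded in 4-bit
    # NF4 while compute runs in bfloat16, reducing GPU memory use.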

    # load model and dataset
    token = os.environ.get("HF_TOKEN", None)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        quantization_config=bnb_config,
        device_map={"": PartialState().process_index},
        attention_dropout=args.attention_dropout,
    )
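    # device_map={"": process_index} places the full model on this process's device,
    # so the same script works for a single GPU and for multi-GPU data parallelism
    # launched with `accelerate launch` or torchrun.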
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id or args.model_id)
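    # Dataset: the Python subset of the-stack-smol by default; if --num_proc is not
    # given and the dataset is not streamed, loading uses one worker per CPU core.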
    data = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        token=token,
        num_proc=args.num_proc if args.num_proc or args.streaming else multiprocessing.cpu_count(),
        streaming=args.streaming,
    )

    # setup the trainer
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=data,
        args=SFTConfig(
            dataset_text_field=args.dataset_text_field,
            dataset_num_proc=args.num_proc,
            max_seq_length=args.max_seq_length,
            per_device_train_batch_size=args.micro_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            learning_rate=args.learning_rate,
            lr_scheduler_type=args.lr_scheduler_type,
            weight_decay=args.weight_decay,
            bf16=args.bf16,
            logging_strategy="steps",
            logging_steps=10,
            output_dir=args.output_dir,
            optim="paged_adamw_8bit",
            seed=args.seed,
            run_name=f"train-{args.model_id.split('/')[-1]}",
            report_to="wandb",
            push_to_hub=args.push_to_hub,
            hub_model_id=args.repo_id,
        ),
        peft_config=lora_config,
    )
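    # Effective batch size per device = micro_batch_size * gradient_accumulation_steps
    # (1 * 4 = 4 with the defaults); paged_adamw_8bit stores optimizer state in 8-bit
    # to save memory, and run metrics are reported to Weights & Biases.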

    # launch
    print("Training...")
    trainer.train()
    print("Training Done! 💥")

if __name__ == "__main__":
    args = get_args()
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)
    logging.set_verbosity_error()
    main(args)