import os
import sys
import argparse
from datetime import datetime
from functools import partial

import datasets
import torch
from torch.utils.tensorboard import SummaryWriter
import wandb
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq
)
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.integrations import TensorBoardCallback

# Importing LoRA-specific modules
from peft import (
    TaskType,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict
)

# Project-local helpers (parse_model_name, load_dataset, tokenize, lora_module_dict, ...)
from utils import *

# Replace with your own api_key and project name
os.environ['WANDB_API_KEY'] = 'ecf1e5e4f47441d46822d38a3249d62e8fc94db4'
os.environ['WANDB_PROJECT'] = 'fingpt-benchmark'

def main(args):
    """
    Main function to execute the training script.

    :param args: Command line arguments
    """
    # Parse the model name and determine if it should be fetched from a remote source
    model_name = parse_model_name(args.base_model, args.from_remote)

    # Load the pre-trained causal language model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # load_in_8bit=True,
        # device_map="auto",
        trust_remote_code=True
    )

    # Print model architecture for the first process in distributed training
    if args.local_rank == 0:
        print(model)

    # Load tokenizer associated with the pre-trained model
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Apply model-specific tokenization settings
    if args.base_model != 'mpt':
        tokenizer.padding_side = "left"
    if args.base_model == 'qwen':
        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|extra_0|>')

    # Ensure a padding token is set; resize embeddings if a new token is added
    if not tokenizer.pad_token or tokenizer.pad_token_id == tokenizer.eos_token_id:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        model.resize_token_embeddings(len(tokenizer))

    # Load training and testing datasets
    dataset_list = load_dataset(args.dataset, args.from_remote)
    dataset_train = datasets.concatenate_datasets([d['train'] for d in dataset_list]).shuffle(seed=42)
    if args.test_dataset:
        dataset_list = load_dataset(args.test_dataset, args.from_remote)
    dataset_test = datasets.concatenate_datasets([d['test'] for d in dataset_list])
    dataset = datasets.DatasetDict({'train': dataset_train, 'test': dataset_test})

    # Display the first sample from the training dataset
    print(dataset['train'][0])

    # Tokenize, filter out samples that exceed the maximum token length, and remove unused columns
    dataset = dataset.map(partial(tokenize, args, tokenizer))
    print('original dataset length: ', len(dataset['train']))
    dataset = dataset.filter(lambda x: not x['exceed_max_length'])
    print('filtered dataset length: ', len(dataset['train']))
    dataset = dataset.remove_columns(['instruction', 'input', 'output', 'exceed_max_length'])
    print(dataset['train'][0])
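
    # Note: `tokenize` (imported from utils) is assumed to build the model inputs from the
    # instruction/input/output fields and to set an 'exceed_max_length' flag based on
    # args.max_length; that flag drives the filtering above.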

    # Create a timestamp for model saving
    current_time = datetime.now()
    formatted_time = current_time.strftime('%Y%m%d%H%M')

    # Set up training arguments
    training_args = TrainingArguments(
        output_dir=f'finetuned_models/{args.run_name}_{formatted_time}',  # Save location
        logging_steps=args.log_interval,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_steps,
        dataloader_num_workers=args.num_workers,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.scheduler,
        save_steps=args.eval_steps,
        eval_steps=args.eval_steps,
        fp16=True,
        # fp16_full_eval=True,
        deepspeed=args.ds_config,
        evaluation_strategy=args.evaluation_strategy,
        load_best_model_at_end=args.load_best_model,
        remove_unused_columns=False,
        report_to='wandb',
        run_name=args.run_name
    )
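
    # Note (assumes a recent transformers release): float values in (0, 1) passed to
    # save_steps/eval_steps, such as the default --eval_steps 0.1, are interpreted as a
    # fraction of the total number of training steps.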

    if args.base_model != 'mpt':
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()
    model.is_parallelizable = True
    model.model_parallel = True
    model.config.use_cache = False
    # model = prepare_model_for_int8_training(model)

    # Set up PEFT with a LoRA configuration
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=lora_module_dict[args.base_model],
        bias='none',
    )
    model = get_peft_model(model, peft_config)
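
    # Optional (commented out so the script's behavior is unchanged): PEFT models expose
    # a helper that reports how many parameters LoRA leaves trainable.
    # model.print_trainable_parameters()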

    # Initialize TensorBoard for logging
    writer = SummaryWriter()

    # Initialize the trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=DataCollatorForSeq2Seq(
            tokenizer, padding=True,
            return_tensors="pt"
        ),
        callbacks=[TensorBoardCallback(writer)],
    )

    # if torch.__version__ >= "2" and sys.platform != "win32":
    #     model = torch.compile(model)

    # Clear CUDA cache and start training
    torch.cuda.empty_cache()
    trainer.train()
    writer.close()

    # Save the fine-tuned model
    model.save_pretrained(training_args.output_dir)
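
    # Note: save_pretrained on a PEFT-wrapped model writes only the LoRA adapter weights.
    # Saving the tokenizer next to them (optional, commented out) makes the checkpoint
    # self-contained for later loading.
    # tokenizer.save_pretrained(training_args.output_dir)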


if __name__ == "__main__":
    # Argument parser for command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--run_name", default='local-test', type=str)
    parser.add_argument("--dataset", required=True, type=str)
    parser.add_argument("--test_dataset", type=str)
    parser.add_argument("--base_model", required=True, type=str,
                        choices=['chatglm2', 'llama2', 'llama2-13b', 'llama2-13b-nr', 'baichuan', 'falcon', 'internlm', 'qwen', 'mpt', 'bloom'])
    parser.add_argument("--max_length", default=512, type=int)
    parser.add_argument("--batch_size", default=4, type=int, help="The train batch size per device")
    parser.add_argument("--learning_rate", default=1e-4, type=float, help="The learning rate")
    parser.add_argument("--num_epochs", default=8, type=float, help="The number of training epochs")
    parser.add_argument("--gradient_steps", default=8, type=int, help="The gradient accumulation steps")
    parser.add_argument("--num_workers", default=8, type=int, help="Dataloader workers")
    parser.add_argument("--log_interval", default=20, type=int)
    parser.add_argument("--warmup_ratio", default=0.05, type=float)
    parser.add_argument("--ds_config", default='./config_new.json', type=str)
    parser.add_argument("--scheduler", default='linear', type=str)
    parser.add_argument("--instruct_template", default='default')
    parser.add_argument("--evaluation_strategy", default='steps', type=str)
    # Note: argparse's type=bool treats any non-empty string as True, so --load_best_model
    # and --from_remote effectively act as "pass any value to enable" flags.
    parser.add_argument("--load_best_model", default=False, type=bool)
    parser.add_argument("--eval_steps", default=0.1, type=float)
    parser.add_argument("--from_remote", default=False, type=bool)
    args = parser.parse_args()

    # Log in to Weights & Biases
    wandb.login()

    # Run the main function
    main(args)
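
# Example invocation (the script name, dataset names, and launcher below are illustrative,
# not prescribed by this file; adjust them to your setup):
#   deepspeed train_lora.py \
#       --run_name sentiment-llama2-lora \
#       --base_model llama2 \
#       --dataset sentiment-train \
#       --test_dataset sentiment-test \
#       --max_length 512 \
#       --batch_size 4 \
#       --num_epochs 8 \
#       --ds_config ./config_new.json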