# Pretrain Example

> [!IMPORTANT]
> The data must be used together with its corresponding map function (`map_fn`).
## Data

`./data.json`

```json
[{
    "toy_text": "I am an artificial intelligence (AI) assistant named InternLM. I was created by the Shanghai AI Laboratory and my purpose is to assist users with various tasks through natural language processing technology."
},
{
    "toy_text": "I am an artificial intelligence programmed to assist with various types of tasks, including answering questions, providing information, and performing automated processes."
}]
```
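To sanity-check the file before training, it can be loaded with the same HuggingFace `datasets` loader that the config below uses (a minimal sketch; the `toy_text` field name comes from the sample above):

```python
from datasets import load_dataset

# Load ./data.json exactly as the training config does (HF "json" builder).
ds = load_dataset('json', data_files=dict(train='./data.json'))['train']
print(ds.column_names)         # ['toy_text']
print(ds[0]['toy_text'][:60])  # first record, truncated for display
```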
## Map Function

`./map_fn.py`

```python
def pretrain_map_fn(example):
    # Convert a raw-text record into XTuner's single-turn conversation format,
    # with an empty `input` so the whole text becomes the training target.
    return {
        'conversation': [{
            'input': '',
            'output': example['toy_text'].strip()
        }]
    }
```
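Applied to the first record of `./data.json`, the map function yields a single-turn conversation whose `input` is empty, so the entire text becomes the training target. A quick check (assuming it is run from this directory, where `map_fn.py` and `data.json` live):

```python
import json

from map_fn import pretrain_map_fn

with open('data.json') as f:
    example = json.load(f)[0]

print(pretrain_map_fn(example))
# {'conversation': [{'input': '', 'output': 'I am an artificial intelligence (AI) assistant named InternLM. ...'}]}
```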
## Config

Based on [internlm_7b_qlora_json_e3](../../../xtuner/configs/internlm/internlm_7b/internlm_7b_qlora_json_e3.py).

```diff
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
+from mmengine.config import read_base
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
-from xtuner.dataset.map_fns import template_map_fn_factory
-from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook
+from xtuner.engine.hooks import DatasetInfoHook
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
-from xtuner.utils import PROMPT_TEMPLATE

+with read_base():
+    from .map_fn import pretrain_map_fn as dataset_map_fn
+
#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = 'internlm/internlm-7b'

# Data
-data_path = 'path/to/your/json_data'
+data_path = './data.json'
-prompt_template = PROMPT_TEMPLATE.default
max_length = 2048
pack_to_max_length = True

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 16
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip

# Save
save_steps = 500
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = ''
evaluation_inputs = [
    '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

#######################################################################
#                    PART 3  Dataset & Dataloader                     #
#######################################################################
train_dataset = dict(
    type=process_hf_dataset,
    dataset=dict(
        type=load_dataset, path='json', data_files=dict(train=data_path)),
    tokenizer=tokenizer,
    max_length=max_length,
+    dataset_map_fn=dataset_map_fn,
-    template_map_fn=dict(
-        type=template_map_fn_factory, template=prompt_template),
+    template_map_fn=None,
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = dict(
    type=CosineAnnealingLR,
    eta_min=0.0,
    by_epoch=True,
    end=max_epochs,
    convert_to_iter_based=True)

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
-custom_hooks = [
-    dict(type=DatasetInfoHook, tokenizer=tokenizer),
-    dict(
-        type=EvaluateChatHook,
-        tokenizer=tokenizer,
-        every_n_iters=evaluation_freq,
-        evaluation_inputs=evaluation_inputs,
-        system=SYSTEM,
-        prompt_template=prompt_template)
-]
+custom_hooks = [dict(type=DatasetInfoHook, tokenizer=tokenizer)]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
```
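The two key changes are `dataset_map_fn=dataset_map_fn` and `template_map_fn=None`: each raw text is wrapped as a single-turn conversation with an empty `input`, and no chat template is applied, so every token of the raw text is supervised. In other words, training reduces to ordinary next-token (causal LM) pretraining. Conceptually, the data pipeline amounts to something like this simplified sketch (for illustration only, not xtuner's actual implementation):

```python
# Simplified illustration (NOT xtuner's real code): with an empty `input` and no
# prompt template, the tokenized sample is just the raw text, and the labels
# equal the input ids, i.e. plain causal language modeling.
def build_pretrain_sample(example, tokenizer, max_length=2048):
    text = example['toy_text'].strip()  # what pretrain_map_fn puts in `output`
    ids = tokenizer(text, truncation=True, max_length=max_length)['input_ids']
    return {'input_ids': ids, 'labels': list(ids)}
```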
## Quick Start

```bash
cd ./examples/demo_data/pretrain
xtuner train config.py
```
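After training, the saved LoRA checkpoint can typically be converted to a HuggingFace adapter with xtuner's convert tool. The checkpoint path below is only an example; use an actual `.pth` file produced under your work directory:

```bash
# Example paths: iter_500.pth is hypothetical, pick a real checkpoint from work_dirs.
xtuner convert pth_to_hf config.py ./work_dirs/config/iter_500.pth ./hf_adapter
```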