Attention mask and position id fixes for packing (#285)
* fix attention mask with packing
* set position ids and use block diagonal attn mask
* fix expand mask for multiple batch items, make sure we pad position_ids
* don't move masks to cpu
* use multi pack dataloader w random sampler
* add position_ids back
* more fixes for dataloader integration
* est total tokens, fix field loop
* more fixes, position_ids seems broken
* more fixes for sample packing
* use distributed sampler, avoid accelerate prepare
* use accelerator prepare for dataloader
* fix for position_ids w packing
* Update src/axolotl/utils/dataloader.py
* validation for sample packing and doc
* more fixes for 4k and optimizations
* optimized expand mask fn
* better handling of variance in multipack dataloader length and trainer hanging when it runs out of data
* fix rounding of len of batches to int
* better handling so that all devices have the same dataloader len
* fix step calc for packing
* pass sample packing efficiency to training args
* add a test for the mask expansion for sequence packing
* only process eval dataset for packing if not None
* don't split batches when packing
* weighted CE losses
* weighted CEL fixes
* limit packing to sequences of max seq len
* seq_len_multiple for packing
* make sure the chunk size is an int
* sample_packing_seq_len_multiplier config
* use cumulative seq len with var len flash attn v2 w packing
* properly calculate max len
* fix flash-attn, xformers, packing, support chatml
* fix chatml system prompt for openorca, legacy tokenizer opts
* add chatml
* add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test
* fix test and pylint checks
* more packing and dataset optimizations and fixes
* filter w multiple cpus
* more fixes and optimizations
* fixes and go back to distributed sampler since batch sampler won't work
* fix counts by accounting for num devices
* fix steps calculation
* previous accelerate is still most performant
* add numba to requirements.
* use custom distributed checks
* fix sampler to prevent overfit w new epochs
* let's not cleanup the cached datasets
* calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier
* speed optimizations and set accelerate fsdp env vars
* optimize dataset concatenation?
* more optimizations for dataset handling
* fix import for annotation
* manual pre-commit fixes
* another sum optimization and bug fix for calc steps
* fix packing estimations
* fix formatting
* pylint problems
* add back flash attention branch for handling unpacked sequences separately
* Address PR feedback
* add optional sample packing config params to readme
 - README.md +9 -1
 - requirements.txt +2 -0
 - scripts/finetune.py +25 -7
 - src/axolotl/datasets.py +32 -21
 - src/axolotl/monkeypatch/llama_attn_hijack_flash.py +21 -1
 - src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +1 -0
 - src/axolotl/monkeypatch/llama_expand_mask.py +52 -0
 - src/axolotl/monkeypatch/utils.py +103 -0
 - src/axolotl/prompt_strategies/alpaca_w_system.py +22 -1
 - src/axolotl/prompters.py +11 -0
 - src/axolotl/utils/collators.py +121 -0
 - src/axolotl/utils/data.py +63 -18
 - src/axolotl/utils/dataloader.py +288 -0
 - src/axolotl/utils/distributed.py +41 -0
 - src/axolotl/utils/models.py +18 -6
 - src/axolotl/utils/trainer.py +273 -11
 - src/axolotl/utils/validation.py +24 -0
 - tests/monkeypatch/test_llama_attn_hijack_flash.py +30 -0
 - tests/test_expand_mask.py +44 -0
 - tests/test_packed_dataset.py +6 -2
 - tests/test_prompt_tokenizers.py +9 -3
 - tests/test_prompters.py +1 -1
 - tests/test_validation.py +24 -0
 
README.md

````diff
@@ -375,7 +375,14 @@ dataset_shard_idx:
 sequence_len: 2048
 # max sequence length to concatenate training samples together up to
 # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
+# FutureWarning: This will soon be DEPRECATED
 max_packed_sequence_len: 1024
+# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
+sample_packing:
+# you can set these packing optimizations AFTER starting a training at least once.
+# The trainer will provide recommended values for these values.
+sample_packing_eff_est:
+total_num_tokens:

 # if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
 adapter: lora
@@ -421,6 +428,7 @@ learning_rate: 0.00003
 logging_steps:
 save_steps:
 eval_steps:
+save_total_limit:

 # save model as safetensors (require safetensors package)
 save_safetensors:
@@ -534,7 +542,7 @@ accelerate launch scripts/finetune.py configs/your_config.yml

 #### Multi-GPU

+You can optionally pre-tokenize dataset with the following before finetuning:
 ```bash
 CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
 ```
````
requirements.txt

```diff
@@ -13,6 +13,8 @@ einops
 xformers
 optimum
 hf_transfer
+numba
+numpy==1.24.4
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0
```
scripts/finetune.py

```diff
@@ -21,9 +21,14 @@ from axolotl.logging_config import configure_logging
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import barrier, is_main_process
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.trainer import setup_trainer
+from axolotl.utils.trainer import (
+    calculate_total_num_steps,
+    process_datasets_for_packing,
+    setup_trainer,
+)
 from axolotl.utils.validation import validate_config
 from axolotl.utils.wandb import setup_wandb_env_vars

@@ -232,12 +237,25 @@ def train(
                 cfg.pretraining_dataset,
                 tokenizer,
                 max_tokens=cfg.sequence_len,
+                seed=cfg.seed or 42,
             )
             # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
             train_dataset = train_dataset.with_format("torch")
             eval_dataset = None

+        if is_main_process():
+            # process on rank 0 first so it gets cached so other ranks load from cache
+            train_dataset, eval_dataset = process_datasets_for_packing(
+                cfg, train_dataset, eval_dataset
+            )
+        barrier()
+        if not is_main_process():
+            train_dataset, eval_dataset = process_datasets_for_packing(
+                cfg, train_dataset, eval_dataset
+            )
+        barrier()
+        total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
+
     if cfg.debug or "debug" in kwargs:
         LOG.info("check_dataset_labels...")
         check_dataset_labels(
@@ -254,7 +272,7 @@ def train(
     log_gpu_memory_usage(LOG, "baseline", cfg.device)

     # Load the model and tokenizer
-    LOG.info("loading model and peft_config...")
+    LOG.info("loading model and (optionally) peft_config...")
     model, peft_config = load_model(cfg, tokenizer)

     safe_serialization = cfg.save_safetensors is True
@@ -288,7 +306,9 @@ def train(
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
         return

-    trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
+    trainer = setup_trainer(
+        cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
+    )

     model.config.use_cache = False

@@ -347,14 +367,12 @@ def train(
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.fsdp:
+        trainer.save_model(cfg.output_dir)
     elif cfg.local_rank == 0:
         if cfg.flash_optimum:
             model = BetterTransformer.reverse(model)
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

-    # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
-

 if __name__ == "__main__":
     fire.Fire(train)
```
src/axolotl/datasets.py

```diff
@@ -5,7 +5,7 @@ import os
 from typing import List

 import torch
-from datasets import IterableDataset
+from datasets import Dataset, IterableDataset

 from .prompt_tokenizers import PromptTokenizingStrategy

@@ -18,9 +18,9 @@ from .prompt_tokenizers import PromptTokenizingStrategy
 LOG = logging.getLogger("axolotl")


-class TokenizedPromptDataset(IterableDataset):
+class TokenizedPromptDataset(Dataset):
     """
-    Iterable dataset that returns tokenized prompts from a stream of text files.
+    Dataset that returns tokenized prompts from a stream of text files.
         Args:
             prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
             dataset (dataset.Dataset): Dataset with text files.
@@ -30,19 +30,18 @@ class TokenizedPromptDataset(IterableDataset):
         self,
         prompt_tokenizer: PromptTokenizingStrategy,
         dataset: IterableDataset,
+        **kwargs,
     ):
         self.prompt_tokenizer = prompt_tokenizer
-        self.dataset = dataset
-
-    def __iter__(self):
-        features = self.dataset.features.keys()
-        num_proc = os.cpu_count()
-        return iter(
-            self.dataset.map(
-                self.prompt_tokenizer.tokenize_prompt,
-                num_proc=num_proc,
-                remove_columns=features,
-            )
+        super().__init__(self.process(dataset).data, **kwargs)
+
+    def process(self, dataset):
+        features = dataset.features.keys()
+        num_proc = min(64, os.cpu_count())
+        return dataset.map(
+            self.prompt_tokenizer.tokenize_prompt,
+            num_proc=num_proc,
+            remove_columns=features,
         )

@@ -77,14 +76,21 @@ class ConstantLengthDataset(IterableDataset):
         self.tokens_dtype = torch.int64

     def __iter__(self):
-        buffer = {"input_ids": [], "attention_mask": [], "labels": []}
+        buffer = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+            "position_ids": [],
+        }
         buffer_len = 0
         for dataset in self.datasets:
+            idx = 0
             iterator = iter(dataset)
             more_examples = True
             while more_examples:
                 try:
                     example = next(iterator)
+                    idx += 1
                 except StopIteration:
                     more_examples = False
                     example = None
@@ -106,6 +112,9 @@ class ConstantLengthDataset(IterableDataset):
                     attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
                         : self.seq_length
                     ]
+                    position_ids = torch.cat(buffer["position_ids"], dim=-1)[
+                        : self.seq_length
+                    ]
                     labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
                     if labels.size() == input_ids.size() and (
                         attention_mask.size() == input_ids.size()
@@ -114,6 +123,7 @@ class ConstantLengthDataset(IterableDataset):
                             "input_ids": input_ids,
                             "labels": labels,
                             "attention_mask": attention_mask,
+                            "position_ids": position_ids,
                         }
                     else:
                         LOG.warning(
@@ -123,8 +133,10 @@ class ConstantLengthDataset(IterableDataset):
                     "input_ids": [],
                     "attention_mask": [],
                     "labels": [],
+                    "position_ids": [],
                 }
                 buffer_len = 0
+                idx = 1

             if example:
                 # FIXME
@@ -133,11 +145,6 @@ class ConstantLengthDataset(IterableDataset):
                     input_ids = example["input_ids"]
                     attention_mask = example["attention_mask"]
                     labels = example["labels"]
-                    if (
-                        buffer["input_ids"]
-                        and input_ids[0] == self.tokenizer.bos_token_id
-                    ):
-                        attention_mask[0] = 0

                     if add_concat_token:
                         input_ids.append(self.concat_token_id)
@@ -148,13 +155,17 @@ class ConstantLengthDataset(IterableDataset):
                         input_ids, dtype=self.tokens_dtype
                     )
                     attention_mask_with_concat = torch.tensor(
-                        attention_mask, dtype=self.tokens_dtype
+                        [idx * m for m in attention_mask], dtype=torch.int16
                     )
                     labels_with_concat = torch.tensor(
                         labels, dtype=self.tokens_dtype
                     )
+                    position_ids = torch.arange(
+                        len(input_ids), dtype=self.tokens_dtype
+                    )

                     buffer["input_ids"].append(input_ids_with_concat)
                     buffer["attention_mask"].append(attention_mask_with_concat)
                     buffer["labels"].append(labels_with_concat)
+                    buffer["position_ids"].append(position_ids)
                     buffer_len += len(input_ids)
```
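To make the packing change concrete: after these edits each packed example carries a per-sample sequence index in `attention_mask` (1 for the first sample, 2 for the second, and so on) and `position_ids` that restart at 0 for every sample. A toy illustration (token ids are placeholders, not from the repo):

```python
# Two samples of lengths 3 and 2 packed into one example, mirroring the
# ConstantLengthDataset changes above.
packed_example = {
    "input_ids":      [11, 12, 13, 21, 22],  # placeholder token ids
    "attention_mask": [1, 1, 1, 2, 2],       # per-sample sequence index (idx * mask)
    "position_ids":   [0, 1, 2, 0, 1],       # torch.arange restarted for each sample
    "labels":         [11, 12, 13, 21, 22],
}
```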
src/axolotl/monkeypatch/llama_attn_hijack_flash.py

```diff
@@ -8,9 +8,18 @@ import torch
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+except ImportError:
+    from flash_attn.flash_attn_interface import (
+        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
+    )
+
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+

 def forward(
     self,
@@ -79,6 +88,16 @@ def forward(
             dtype=torch.int32,
             device=qkv.device,
         )
+        output = flash_attn_varlen_qkvpacked_func(
+            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+        )
+        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+    elif position_ids.shape[0] == 1:
+        # special handling using sample packing
+        qkv = rearrange(qkv, "b s ... -> (b s) ...")
+        cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
+        cu_q_lens = cu_q_lens.squeeze()
+
         output = flash_attn_varlen_qkvpacked_func(
             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
         )
@@ -113,6 +132,7 @@ def forward(
             "b s (h d) -> b s h d",
             h=nheads,
         )
+
     return (
         self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
         None,
```
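The packing branch above relies on `get_cu_seqlens_from_pos_ids` to turn the per-sample `position_ids` into the cumulative sequence lengths that `flash_attn_varlen_qkvpacked_func` expects. The actual implementation lives in `src/axolotl/monkeypatch/utils.py` (excerpted further below); the sketch here is only an illustration of the idea, assuming sample boundaries are wherever `position_ids` reset to 0:

```python
import torch


def cu_seqlens_from_pos_ids_sketch(position_ids: torch.Tensor):
    """Illustrative sketch, not the axolotl implementation."""
    pos = position_ids.view(-1)
    # each packed sample starts where its position counter resets to 0
    starts = (pos == 0).nonzero(as_tuple=True)[0]
    ends = torch.cat([starts[1:], torch.tensor([pos.numel()])])
    seq_lens = ends - starts
    # cumulative boundaries [0, len1, len1+len2, ...] as int32 for flash-attn
    cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0).to(torch.int32)]
    )
    return cu_seqlens, int(seq_lens.max())


# position_ids [0,1,2, 0,1, 0,1,2,3] -> cu_seqlens [0, 3, 5, 9], max_seqlen 4
```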
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py

```diff
@@ -128,6 +128,7 @@ def xformers_forward(
                query_states,
                key_states,
                value_states,
+               # attn_bias=attention_mask,
                attn_bias=xformers.ops.LowerTriangularMask(),
            )
        attn_weights = None
```
src/axolotl/monkeypatch/llama_expand_mask.py (new file)

```diff
@@ -0,0 +1,52 @@
+"""
+expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
+"""
+from typing import Optional
+
+import torch
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    This expansion handles packed sequences so that sequences share the same attention mask integer value
+    when they attend to each other within that sequence.
+    This expansion transforms the mask to lower triangular form to prevent future peeking.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    mask = mask.unsqueeze(1).unsqueeze(2)
+    mask = mask.expand(bsz, 1, tgt_len, src_len)
+
+    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
+    binary_mask = torch.where(
+        mask != 0,
+        torch.tensor(1).to(dtype),
+        torch.tensor(0).to(dtype),
+    )
+
+    # Create a block-diagonal mask.
+    # we multiply by the binary mask so that 0's in the original mask are correctly excluded
+    zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
+
+    # Now let's create a lower triangular mask of ones that will zero out the upper triangular part
+    lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
+        mask.device
+    )
+
+    # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
+    masked_zero_one_mask = zero_one_mask * lower_triangular_ones
+    inverted_mask = 1.0 - masked_zero_one_mask
+
+    return inverted_mask.masked_fill(
+        inverted_mask.to(torch.bool), torch.finfo(dtype).min
+    )
+
+
+def hijack_expand_mask():
+    import transformers
+
+    transformers.models.llama.modeling_llama._expand_mask = (  # pylint: disable=protected-access
+        _expand_mask
+    )
```
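A quick way to see the block-diagonal behaviour is to expand a toy packed mask where two samples share a row (sequence indices 1,1 and 2,2) followed by one padding position (0). In the additive output, 0 means "may attend" and `finfo(dtype).min` means "masked"; the check below is illustrative only:

```python
import torch

from axolotl.monkeypatch.llama_expand_mask import _expand_mask

mask = torch.tensor([[1, 1, 2, 2, 0]])
expanded = _expand_mask(mask, dtype=torch.float32)

# positions allowed to attend, per query row (1 = allowed)
print((expanded[0, 0] == 0).int())
# expected pattern: block-diagonal and lower-triangular, padding fully masked
# 1 0 0 0 0
# 1 1 0 0 0
# 0 0 1 0 0
# 0 0 1 1 0
# 0 0 0 0 0
```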
         @@ -0,0 +1,103 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
+"""
+Shared utils for the monkeypatches
+"""
+import torch
+
+
+def get_cu_seqlens(attn_mask):
+    """generate a cumulative sequence length mask for flash attention using attn mask"""
+    if len(attn_mask.shape) == 1:
+        attn_mask = attn_mask.unsqueeze(0)
+
+    device = attn_mask.device
+    results = []
+    max_seq_lens = []
+
+    for row in attn_mask:
+        # Exclude zeros to avoid adding their positions to the mask
+        t_non_zeros = row[row != 0]
+        # Find where the sequence number changes (including the first position)
+        seq_change = torch.cat(
+            [
+                torch.tensor([1], dtype=torch.int32, device=device),
+                t_non_zeros[1:] != t_non_zeros[:-1],
+            ]
+        )
+        # Get the indices where the sequence changes
+        change_indices = torch.cat(
+            [
+                (seq_change == 1).nonzero(as_tuple=True)[0],
+                torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
+            ]
+        )
+        # Calculate the sequence lengths
+        seq_lengths = change_indices[1:] - change_indices[:-1]
+        # Calculate the length of the final sequence or padding
+        final_seq_length = len(row) - change_indices[-1]
+        # Append the length of the final sequence or padding to seq_lengths
+        if final_seq_length.item():
+            seq_lengths = torch.cat(
+                [
+                    seq_lengths,
+                    torch.tensor(
+                        [final_seq_length.item()], dtype=torch.int32, device=device
+                    ),
+                ]
+            )
+        # Calculate the cumulative sequence lengths
+        cu_seqlens = torch.cat(
+            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
+        )
+        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+        results.append(cu_seqlens)
+        max_seq_lens.append(max_seq_len)
+
+    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
+
+
+def get_cu_seqlens_from_pos_ids(position_ids):
+    """generate a cumulative sequence length mask for flash attention using pos ids"""
+    if len(position_ids.shape) == 1:
+        position_ids = position_ids.unsqueeze(0)
+
+    device = position_ids.device
+    results = []
+    max_seq_lens = []
+
+    for row in position_ids:
+        # Count the number of consecutive zeros from the right side
+        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
+
+        # Adjust the row to exclude padding
+        adjusted_row = row[:-padding_length] if padding_length else row.clone()
+
+        # Find where the position resets to 0 (indicating a new sequence)
+        seq_starts = torch.cat(
+            [
+                torch.tensor([True], dtype=torch.bool, device=device),
+                adjusted_row[1:] == 0,
+            ]
+        )
+        # Get the indices where the sequence starts
+        start_indices = torch.cat(
+            [
+                (seq_starts).nonzero(as_tuple=True)[0],
+                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
+            ]
+        )
+        # Calculate the sequence lengths
+        seq_lengths = start_indices[1:] - start_indices[:-1]
+        # Calculate the cumulative sequence lengths
+        cu_seqlens = torch.cat(
+            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
+        )
+        # Append the padding length to the cumulative sequence lengths
+        if padding_length:
+            cu_seqlens = torch.cat(
+                [cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
+            )
+        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+        results.append(cu_seqlens)
+        max_seq_lens.append(max_seq_len)
+
+    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
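A small worked example of `get_cu_seqlens_from_pos_ids`, traced against the function above; the packed row itself is made up for illustration.

    import torch
    from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids

    # two packed sequences of lengths 4 and 3, followed by 3 padding positions
    position_ids = torch.tensor([0, 1, 2, 3, 0, 1, 2, 0, 0, 0])

    cu_seqlens, max_seq_len = get_cu_seqlens_from_pos_ids(position_ids)
    # cu_seqlens  -> tensor([[ 0,  4,  7, 10]], dtype=torch.int32)
    # max_seq_len -> tensor([4])
    # these cumulative boundaries (plus the max length) are what flash attention's
    # variable-length kernels consume in place of a per-token attention mask.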
         
src/axolotl/prompt_strategies/alpaca_w_system.py:

@@ -66,7 +66,11 @@ class SystemDataPrompter(AlpacaPrompter):
     ) -> Generator[str, None, None]:
         # returns the full prompt from instruction and optional input
         # if a label (=response, =output) is provided, it's also appended.
-        formatted_sys_prompt =
+        formatted_sys_prompt = (
+            self.system_format.format(system=system)
+            if system and self.system_format
+            else ""
+        )
         if input:
             res = formatted_sys_prompt + self.turn_format.format(
                 instruction=instruction, input=input
@@ -86,12 +90,20 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
     """

     def match_prompt_style(self):
+        # pylint: disable=duplicate-code
         if self.prompt_style == PromptStyle.INSTRUCT.value:
             self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
             self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
         if self.prompt_style == PromptStyle.CHAT.value:
             self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
             self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
+            self.system_format = "SYSTEM: {system}\n"
+        if self.prompt_style == PromptStyle.CHATML.value:
+            self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
+            self.turn_no_input_format = (
+                "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
+            )
+            self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"


 class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
@@ -137,3 +149,12 @@ def load_open_orca(tokenizer, cfg):
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+
+
+def load_open_orca_chatml(tokenizer, cfg):
+    return OpenOrcaPromptTokenizingStrategy(
+        OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
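To make the new ChatML templates concrete, here is a small sketch that simply applies the format strings added above; the system and instruction text are invented for the example.

    system_format = "<|im_start|>system\n{system}<|im_end|>\n"
    turn_no_input_format = "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"

    prompt = system_format.format(
        system="You are a helpful assistant."
    ) + turn_no_input_format.format(instruction="Summarize the article.")
    # <|im_start|>system
    # You are a helpful assistant.<|im_end|>
    # <|im_start|>user
    # Summarize the article.<|im_end|>
    # <|im_start|>assistant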
src/axolotl/prompters.py:

@@ -16,6 +16,7 @@ class PromptStyle(Enum):

     INSTRUCT = "instruct"
     CHAT = "chat"
+    CHATML = "chatml"


 class AlpacaPrompter:
@@ -25,6 +26,7 @@ class AlpacaPrompter:

     system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
     system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
+    system_format: str
     turn_format: str
     turn_no_input_format: str
     prompt_style: Optional[PromptStyle] = None
@@ -34,14 +36,23 @@ class AlpacaPrompter:
         self.match_prompt_style()

     def match_prompt_style(self):
+        # pylint: disable=duplicate-code
         if self.prompt_style == PromptStyle.INSTRUCT.value:
             self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
             self.turn_no_input_format = (
                 "### Instruction:\n{instruction}\n\n### Response:\n"
             )
+            self.system_format = "### System:\n{system}\n\n"
         if self.prompt_style == PromptStyle.CHAT.value:
             self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
             self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
+            self.system_format = "SYSTEM: {system}\n"
+        if self.prompt_style == PromptStyle.CHATML.value:
+            self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
+            self.turn_no_input_format = (
+                "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
+            )
+            self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"

     def build_prompt(
         self,
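A brief sketch of selecting the new style, assuming the prompter is constructed with the style value the same way `load_open_orca_chatml` does above:

    from axolotl.prompters import AlpacaPrompter, PromptStyle

    prompter = AlpacaPrompter(PromptStyle.CHATML.value)
    # match_prompt_style() has now set the ChatML templates:
    # prompter.system_format        == "<|im_start|>system\n{system}<|im_end|>\n"
    # prompter.turn_no_input_format == "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"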
src/axolotl/utils/collators.py (new file): @@ -0,0 +1,121 @@
+"""
+DataCollator for axolotl to pad labels and position_ids for packed sequences
+"""
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import numpy as np
+from transformers import PreTrainedTokenizerBase
+from transformers.utils import PaddingStrategy
+
+
+@dataclass
+class DataCollatorForSeq2Seq:
+    """
+    Data collator that will dynamically pad the inputs received, as well as the labels and position_ids
+
+    Args:
+        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
+            The tokenizer used for encoding the data.
+        model ([`PreTrainedModel`]):
+            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
+            prepare the *decoder_input_ids*
+
+            This is useful when using *label_smoothing* to avoid calculating loss twice.
+        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+            among:
+
+            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
+              sequence is provided).
+            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+              acceptable input length for the model if that argument is not provided.
+            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
+        max_length (`int`, *optional*):
+            Maximum length of the returned list and optionally padding length (see above).
+        pad_to_multiple_of (`int`, *optional*):
+            If set will pad the sequence to a multiple of the provided value.
+
+            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+            7.5 (Volta).
+        label_pad_token_id (`int`, *optional*, defaults to -100):
+            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
+        return_tensors (`str`):
+            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
+    """
+
+    tokenizer: PreTrainedTokenizerBase
+    model: Optional[Any] = None
+    padding: Union[bool, str, PaddingStrategy] = True
+    max_length: Optional[int] = None
+    pad_to_multiple_of: Optional[int] = None
+    label_pad_token_id: int = -100
+    position_pad_token_id: int = 0
+    return_tensors: str = "pt"
+
+    def __call__(self, features, return_tensors=None):
+        labels = None
+        if return_tensors is None:
+            return_tensors = self.return_tensors
+
+        for feature_name, pad_token_id in [
+            ("labels", self.label_pad_token_id),
+            ("position_ids", self.position_pad_token_id),
+        ]:
+            feat = (
+                [feature[feature_name] for feature in features]
+                if feature_name in features[0].keys()
+                else None
+            )
+            labels = feat if feat and feature_name == "labels" else labels
+            # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
+            # same length to return tensors.
+            if feat is not None:
+                max_feature_length = max(len(l) for l in feat)  # noqa: E741
+                if self.pad_to_multiple_of is not None:
+                    max_feature_length = (
+                        (max_feature_length + self.pad_to_multiple_of - 1)
+                        // self.pad_to_multiple_of
+                        * self.pad_to_multiple_of
+                    )
+
+                padding_side = self.tokenizer.padding_side
+                for feature in features:
+                    remainder = [pad_token_id] * (
+                        max_feature_length - len(feature[feature_name])
+                    )
+                    if isinstance(feature[feature_name], list):
+                        feature[feature_name] = (
+                            feature[feature_name] + remainder
+                            if padding_side == "right"
+                            else remainder + feature[feature_name]
+                        )
+                    elif padding_side == "right":
+                        feature[feature_name] = np.concatenate(
+                            [feature[feature_name], remainder]
+                        ).astype(np.int64)
+                    else:
+                        feature[feature_name] = np.concatenate(
+                            [remainder, feature[feature_name]]
+                        ).astype(np.int64)
+
+        features = self.tokenizer.pad(
+            features,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors=return_tensors,
+        )
+
+        # prepare decoder_input_ids
+        if (
+            labels is not None
+            and self.model is not None
+            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
+        ):
+            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
+                labels=features["labels"]
+            )
+            features["decoder_input_ids"] = decoder_input_ids
+
+        return features
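A minimal sketch of the collator in use; the token ids and the tokenizer checkpoint are placeholders, and it assumes the tokenizer has a pad token set.

    from transformers import AutoTokenizer
    from axolotl.utils.collators import DataCollatorForSeq2Seq

    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # placeholder checkpoint
    tokenizer.pad_token = tokenizer.eos_token

    collator = DataCollatorForSeq2Seq(tokenizer, padding="longest")
    batch = collator(
        [
            {"input_ids": [1, 2, 3, 4], "attention_mask": [1, 1, 1, 1],
             "labels": [1, 2, 3, 4], "position_ids": [0, 1, 2, 3]},
            {"input_ids": [5, 6], "attention_mask": [1, 1],
             "labels": [5, 6], "position_ids": [0, 1]},
        ]
    )
    # labels are pre-padded with -100 and position_ids with 0 so that
    # tokenizer.pad can return rectangular tensors for every key.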
         
src/axolotl/utils/data.py:

@@ -1,13 +1,19 @@
 """Module containing data utilities"""
 import functools
-import
+import hashlib
 import logging
 from hashlib import md5
 from pathlib import Path
-from typing import
+from typing import Tuple, Union

 import torch
-from datasets import
+from datasets import (
+    Dataset,
+    DatasetDict,
+    concatenate_datasets,
+    load_dataset,
+    load_from_disk,
+)
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase

@@ -35,6 +41,7 @@ from axolotl.prompters import (
     ShareGPTPrompter,
     SummarizeTLDRPrompter,
 )
+from axolotl.utils.distributed import barrier, is_main_process

 LOG = logging.getLogger("axolotl")

@@ -109,6 +116,7 @@ def load_tokenized_prepared_datasets(
             local_path = Path(d.path)
             if local_path.exists():
                 if local_path.is_dir():
+                    # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
                     ds = load_dataset(
                         d.path,
                         name=d.name,
@@ -262,20 +270,12 @@ def load_tokenized_prepared_datasets(
                 raise ValueError(
                     f"unhandled prompt tokenization strategy: {d.type} {suffix}"
                 )
-        LOG.info("
-
-
-
-
-
-        while True:
-            chunk = list(itertools.islice(d_iter, chunk_size))
-            if not chunk:
-                break
-            samples.extend(chunk)
-
-        LOG.info("shuffle")
-        dataset = Dataset.from_list(samples).shuffle(seed=seed)
+        LOG.info("merging datasets")
+        dataset = concatenate_datasets(datasets)
+
+        if len(datasets) > 1:
+            LOG.info("shuffle merged datasets")
+            dataset = dataset.shuffle(seed=seed)
         if cfg.local_rank == 0:
             LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
             dataset.save_to_disk(prepared_ds_path)
@@ -374,6 +374,7 @@ def load_prepare_datasets(
             dataset = Dataset.from_list(list(constant_len_dataset))

             # filter out bad data
+            # TODO convert to dataset.filter(...)
             dataset = Dataset.from_list(
                 [
                     d
@@ -413,7 +414,51 @@ def load_prepare_datasets(
         )

     if cfg.val_set_size:
-
+        # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
+        to_hash_train = (
+            dataset._fingerprint  # pylint: disable=protected-access
+            + "|"
+            + str(cfg.val_set_size)
+            + "|"
+            + "train"
+            + "|"
+            + str(cfg.seed or 42)
+        )
+        to_hash_test = (
+            dataset._fingerprint  # pylint: disable=protected-access
+            + "|"
+            + str(cfg.val_set_size)
+            + "|"
+            + "test"
+            + "|"
+            + str(cfg.seed or 42)
+        )
+        train_fingerprint = hashlib.md5(
+            to_hash_train.encode(), usedforsecurity=False
+        ).hexdigest()
+        test_fingerprint = hashlib.md5(
+            to_hash_test.encode(), usedforsecurity=False
+        ).hexdigest()
+
+        if is_main_process():
+            dataset = dataset.train_test_split(
+                test_size=cfg.val_set_size,
+                shuffle=False,
+                seed=cfg.seed or 42,
+                train_new_fingerprint=train_fingerprint,
+                test_new_fingerprint=test_fingerprint,
+            )
+        barrier()
+        if not is_main_process():
+            dataset = dataset.train_test_split(
+                test_size=cfg.val_set_size,
+                shuffle=False,
+                seed=cfg.seed or 42,
+                train_new_fingerprint=train_fingerprint,
+                test_new_fingerprint=test_fingerprint,
+            )
+        barrier()
+
         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
     else:
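`barrier()` and `is_main_process()` come from the new `src/axolotl/utils/distributed.py` (listed in this PR but not shown in this excerpt). As a rough sketch of what such helpers typically look like, assuming they are thin wrappers around `torch.distributed` rather than the PR's exact implementation:

    import torch.distributed as dist


    def is_distributed():
        # true only when a process group has actually been initialized
        return dist.is_available() and dist.is_initialized()


    def barrier():
        # no-op for single-process runs, otherwise synchronize all ranks
        if is_distributed():
            dist.barrier()


    def is_main_process():
        # rank 0 is treated as the main process
        if not is_distributed():
            return True
        return dist.get_rank() == 0

The rank-0-first, barrier, then everyone-else pattern above lets rank 0 materialize and cache the split once, so the remaining ranks reproduce the identical split from the cache instead of recomputing it.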
         
     | 
@@ -0,0 +1,288 @@ src/axolotl/utils/dataloader.py (new file)
# pylint: skip-file
import hashlib
import itertools
import logging
import math
from typing import Any, Callable, List, Union

import numba
import numpy as np
from torch.utils.data import DistributedSampler, Sampler

LOG = logging.getLogger("axolotl.utils.dataloader")


@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
    # First-fit-decreasing bin packing
    # Check if a[] could fit in n bins with capacity c
    # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing

    a = np.sort(a)[::-1]
    bins = np.full((n,), c, dtype=a.dtype)
    for size in a:
        not_found = True
        for idx in range(n):
            if bins[idx] >= size:
                bins[idx] -= size
                not_found = False
                break

        if not_found:
            return False

    return True


@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
    # First-fit-decreasing bin packing (with result return)

    indices = np.argsort(a)[::-1]
    a = a[indices]

    bins: List[Any] = []
    bins_result: List[Any] = []
    for a_id, size in enumerate(a):
        add_new = True
        for idx in range(len(bins)):
            if bins[idx] >= size:
                bins[idx] -= size
                bins_result[idx].append(indices[a_id] + start_index)
                add_new = False
                break

        if add_new:
            bins.append(c - size)
            bins_result.append([indices[a_id] + start_index])

    return bins_result, len(a)


@numba.njit
def allocate(
    lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
    """
    :param lengths: array of lengths of each sample
    :param lengths_cumsum: cumulative sum of consecutive lengths
    :param rank: rank for this process
    :param c: length of tokens per batch
    :param n: number of ranks
    :return:
    """
    # Dynamic batch allocator, similar to Multifit
    # https://en.wikipedia.org/wiki/Multifit_algorithm
    # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)

    s = 0
    start_index = 0
    result = []
    result_totseqs = []

    while True:
        # binary search [left, right)
        left = 1
        right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")

        while right - left > 1:
            mid = (left + right) // 2
            if ffd_check(lengths[start_index : start_index + mid], c, n):
                left = mid
            else:
                right = mid

        # use length left
        batch, tot_seqs = ffd_with_result(
            lengths[start_index : start_index + left], c, start_index
        )
        if len(batch) < n:
            break

        start_index += left
        s = lengths_cumsum[start_index - 1]

        # add local rank
        result.append(batch[rank])
        # add total seqs for all ranks
        result_totseqs.append(tot_seqs)
        # yield batch[rank], tot_seqs, s, len(result) * c * n
    return result, result_totseqs, s, len(result) * c * n


def chunk(iterable, n):
    """
    Chunk data into tuples of length n
    """
    # batched('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while batch := tuple(itertools.islice(it, n)):
        yield batch


def hash_indices(lst: List[int]) -> str:
    # Convert the list of integers to a string representation
    concatenated = ",".join(map(str, lst))

    # Generate the hash
    sha256 = hashlib.sha256()
    sha256.update(concatenated.encode())

    return sha256.hexdigest()


class MultipackDistributedDataloader:
    """Unpadded data loading using Multipack.
    Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
    Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
    """

    def __init__(
        self,
        dataset: Any,
        collate_fn: Callable,
        seq_max_length: int = 2048,
        batch_size: int = 1,
        sampler: Union[Sampler, DistributedSampler] = None,
        packing_efficiency_estimate: float = 1.0,
        sample_packing_seq_len_multiplier: int = 1,
        device_count: int = 1,
    ):
        # Dataset
        self.dataset = dataset
        self.lengths = (
            dataset.data.column("position_ids")
            .to_pandas()
            .apply(lambda x: x[-1] + 1)
            .values
        )
        assert isinstance(self.lengths, np.ndarray)
        assert batch_size % sample_packing_seq_len_multiplier == 0
        assert batch_size >= sample_packing_seq_len_multiplier
        self.sampler = sampler
        self.batch_size = batch_size
        self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
        self.seq_max_length = seq_max_length
        self.batch_max_length = batch_size * seq_max_length
        self.collate_fn = collate_fn

        self.num_replicas = 1
        self.rank = 0

        # statistics
        self.eff_total_used = 0
        self.eff_total_slots = 0
        self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
        self.device_count = device_count

    def generate_batches(self, set_stats=False):
        LOG.info("generating packed batches")
        if self.sampler:
            indices = [idx for idx in self.sampler]
        else:
            indices = range(0, len(self.dataset))

        LOG.info(hash_indices(indices))
        lengths = self.lengths[indices]
        lengths_cumsum = np.cumsum(lengths)

        batches, totseqs, total_used, total_slots = allocate(
            lengths=lengths,
            lengths_cumsum=lengths_cumsum,
            rank=self.rank,
            # c=self.batch_max_length,
            c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
            n=self.num_replicas,
        )

        batches = [[indices[b_idx] for b_idx in batch] for batch in batches]

        # statistics
        if set_stats:
            self.eff_total_used += total_used
            self.eff_total_slots += total_slots

        return batches, totseqs

    def __iter__(self):
        if hasattr(self.sampler, "set_epoch"):
            new_epoch = self.sampler.epoch + 1
            self.sampler.set_epoch(new_epoch)
            LOG.info(f"calling sampler.set_epoch({new_epoch})")
        all_batches, _ = self.generate_batches(set_stats=True)
        features = self.dataset.features.keys()
        len_remaining = self._len_est()
        for batches in chunk(
            all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
        ):
            chunked_data = []
            attn_mask_cum_idx = 0
            for batch in batches:
                concatenated = {}
                batched_data = [self.dataset[batch_idx] for batch_idx in batch]
                for feature in features:
                    if feature == "attention_mask":
                        arrays = [
                            (attn_mask_cum_idx + idx + 1) * np.array(item[feature])
                            for idx, item in enumerate(batched_data)
                            if feature in item
                        ]
                        attn_mask_cum_idx += len(batched_data)
                        concatenated[feature] = np.concatenate(arrays)
                    else:
                        arrays = [
                            np.array(item[feature])
                            for item in batched_data
                            if feature in item
                        ]
                        concatenated[feature] = np.concatenate(arrays)
                chunked_data.append(concatenated)
            yield self.collate_fn(chunked_data)
            len_remaining -= 1
            if not len_remaining:
                return

    def _len_est(self):
        lengths_sum = np.sum(self.lengths)
        lengths_sum_per_device = lengths_sum // self.device_count
        LOG.info(
            f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
            f"total_num_tokens per device: {lengths_sum_per_device}"
        )

        # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
        return (
            math.floor(
                0.99
                * lengths_sum_per_device
                / self.packing_efficiency_estimate
                // self.seq_max_length
                // self.batch_size
            )
            - 1
        )

    def __len__(self):
        # this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
        # the same share of total tokens
        # if not self.eff_total_used:
        #     batches, _ = self.generate_batches(set_stats=True)
        # LOG.info(
        #     f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
        #     f"actual packing efficiency: {self.efficiency()}"
        # )
        return max(1, self._len_est())

    def len_w_stats(self):
        if not self.eff_total_used:
            batches, _ = self.generate_batches(set_stats=True)
        LOG.info(
            f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
            f"actual packing efficiency: {self.efficiency()}"
        )
        return max(1, self._len_est())

    def efficiency(self):
        return self.eff_total_used / self.eff_total_slots
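To make the packing routine above concrete, here is a small, self-contained sketch (plain NumPy, no numba) of the first-fit-decreasing feasibility check that the binary search in `allocate` relies on: given per-sample token lengths, it asks whether they fit into `n` bins of capacity `c`. The function name and numbers are illustrative only, not part of the PR.

```python
import numpy as np


def ffd_fits(lengths: np.ndarray, capacity: int, n_bins: int) -> bool:
    """First-fit-decreasing feasibility check (same logic as ffd_check)."""
    remaining = np.full(n_bins, capacity, dtype=np.int64)
    for size in np.sort(lengths)[::-1]:          # place longest sequences first
        slot = np.argmax(remaining >= size)      # first bin with enough room
        if remaining[slot] < size:               # no bin fits -> infeasible
            return False
        remaining[slot] -= size
    return True


# Example: sequences of 1400, 900, 700, 600 tokens pack into two 2048-token
# slots (1400+600 and 900+700) but not into a single slot.
lengths = np.array([1400, 900, 700, 600])
assert ffd_fits(lengths, capacity=2048, n_bins=2)
assert not ffd_fits(lengths, capacity=2048, n_bins=1)
```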
@@ -0,0 +1,41 @@ src/axolotl/utils/distributed.py (new file)
"""
utility helpers for distributed checks
"""
import torch.distributed as dist
from accelerate import Accelerator

accelerate = None  # pylint: disable=invalid-name


def load_accelerate():
    global accelerate  # pylint: disable=global-statement
    accelerate = Accelerator()


def is_distributed():
    """
    Check if distributed training is initialized.
    """
    global accelerate  # pylint: disable=global-statement
    if not accelerate:
        accelerate = Accelerator()
    return dist.is_available() and dist.is_initialized()


def barrier():
    """
    Acts as a barrier to wait for all processes. This ensures that all processes
    reach the barrier before proceeding further.
    """
    if is_distributed():
        dist.barrier()


def is_main_process():
    """
    Check if the current process is the main process.
    If not in distributed mode, always return True.
    """
    if not is_distributed():
        return True
    return dist.get_rank() == 0
     | 
@@ -37,20 +37,26 @@ def load_tokenizer(
     tokenizer_type,
     cfg,
 ):
+    tokenizer_kwargs = {}
     use_fast = True  # this is the default
     if cfg.tokenizer_use_fast is not None:
         use_fast = cfg.tokenizer_use_fast
+    if cfg.tokenizer_legacy is not None:
+        # True is the default w/ https://github.com/huggingface/transformers/pull/25224
+        tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
     if tokenizer_type:
         tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
             tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
             use_fast=use_fast,
+            **tokenizer_kwargs,
         )
     else:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
             use_fast=use_fast,
+            **tokenizer_kwargs,
         )

     LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
@@ -90,8 +96,10 @@ def load_model(

     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
-    cfg.is_llama_derived_model =
-
+    cfg.is_llama_derived_model = (
+        "llama" in base_model
+        or (cfg.model_type and "llama" in cfg.model_type.lower())
+        or cfg.is_llama_derived_model
     )

     if cfg.is_llama_derived_model and cfg.flash_attention:
@@ -136,6 +144,14 @@ def load_model(
         LOG.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()

+    if cfg.is_llama_derived_model and (
+        cfg.max_packed_sequence_len or cfg.sample_packing
+    ):
+        from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
+
+        LOG.info("patching _expand_mask")
+        hijack_expand_mask()
+
     if cfg.bf16 or cfg.bfloat16:
         torch_dtype = torch.bfloat16
     elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
@@ -228,7 +244,6 @@ def load_model(
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
-                device_map="auto" if cfg.world_size == 1 else cfg.device_map,
                 **model_kwargs,
             )
         # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -263,7 +278,6 @@ def load_model(
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
-                device_map=cfg.device_map,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -294,7 +308,6 @@ def load_model(
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
-                device_map=cfg.device_map,
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
@@ -308,7 +321,6 @@ def load_model(
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
-            device_map=cfg.device_map,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
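The `hijack_expand_mask` patch above works together with the packed dataloader: each packed sub-sequence is tagged in `attention_mask` with a distinct positive integer (the `attn_mask_cum_idx + idx + 1` factor in `__iter__`), so the expanded mask can be made block-diagonal and tokens only attend within their own sub-sequence. A rough NumPy illustration of that idea, not the patched HF code itself:

```python
import numpy as np

# One packed row: first sequence labelled 1 (3 tokens), second labelled 2 (2 tokens).
attention_mask = np.array([1, 1, 1, 2, 2])

# Tokens may attend to each other only when their labels match (block diagonal),
# combined with the usual causal constraint.
same_seq = attention_mask[:, None] == attention_mask[None, :]
causal = np.tril(np.ones((5, 5), dtype=bool))
allowed = same_seq & causal

print(allowed.astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [0 0 0 1 0]
#  [0 0 0 1 1]]
```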
| 
         @@ -1,19 +1,23 @@ 
     | 
|
| 1 | 
         
             
            """Module containing the Trainer class and related functions"""
         
     | 
| 2 | 
         
            -
             
     | 
| 3 | 
         
             
            import importlib
         
     | 
| 4 | 
         
             
            import logging
         
     | 
| 5 | 
         
             
            import math
         
     | 
| 6 | 
         
             
            import os
         
     | 
| 7 | 
         
             
            import sys
         
     | 
| 
         | 
|
| 8 | 
         
             
            from dataclasses import dataclass, field
         
     | 
| 
         | 
|
| 9 | 
         
             
            from pathlib import Path
         
     | 
| 10 | 
         
            -
            from typing import Optional
         
     | 
| 11 | 
         | 
| 12 | 
         
             
            import bitsandbytes as bnb
         
     | 
| 
         | 
|
| 13 | 
         
             
            import torch.cuda
         
     | 
| 14 | 
         
             
            import transformers
         
     | 
| 
         | 
|
| 15 | 
         
             
            from torch import nn
         
     | 
| 16 | 
         
             
            from torch.optim.lr_scheduler import OneCycleLR
         
     | 
| 
         | 
|
| 17 | 
         
             
            from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
         
     | 
| 18 | 
         
             
            from transformers.trainer_pt_utils import get_parameter_names
         
     | 
| 19 | 
         | 
| 
         @@ -22,6 +26,8 @@ from axolotl.utils.callbacks import ( 
     | 
|
| 22 | 
         
             
                SaveBetterTransformerModelCallback,
         
     | 
| 23 | 
         
             
                SavePeftModelCallback,
         
     | 
| 24 | 
         
             
            )
         
     | 
| 
         | 
|
| 
         | 
|
| 25 | 
         
             
            from axolotl.utils.schedulers import (
         
     | 
| 26 | 
         
             
                InterpolatingLogScheduler,
         
     | 
| 27 | 
         
             
                get_cosine_schedule_with_quadratic_warmup,
         
     | 
| 
         @@ -30,6 +36,68 @@ from axolotl.utils.schedulers import ( 
     | 
|
| 30 | 
         
             
            LOG = logging.getLogger("axolotl")
         
     | 
| 31 | 
         | 
| 32 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 33 | 
         
             
            @dataclass
         
     | 
| 34 | 
         
             
            class AxolotlTrainingArguments(TrainingArguments):
         
     | 
| 35 | 
         
             
                """
         
     | 
| 
         @@ -40,6 +108,22 @@ class AxolotlTrainingArguments(TrainingArguments): 
     | 
|
| 40 | 
         
             
                    default=False,
         
     | 
| 41 | 
         
             
                    metadata={"help": "Use quadratic warmup for cosine scheduling."},
         
     | 
| 42 | 
         
             
                )
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 43 | 
         | 
| 44 | 
         | 
| 45 | 
         
             
            class AxolotlTrainer(Trainer):
         
     | 
| 
         @@ -77,6 +161,64 @@ class AxolotlTrainer(Trainer): 
     | 
|
| 77 | 
         
             
                            return super().create_scheduler(num_training_steps, optimizer)
         
     | 
| 78 | 
         
             
                    return self.lr_scheduler
         
     | 
| 79 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
src/axolotl/utils/trainer.py
 """Module containing the Trainer class and related functions"""
 import importlib
 import logging
 import math
 import os
 import sys
+from contextlib import contextmanager
 from dataclasses import dataclass, field
+from functools import partial
 from pathlib import Path
+from typing import Optional, Union

 import bitsandbytes as bnb
+import numpy as np
 import torch.cuda
 import transformers
+from datasets import Dataset, set_caching_enabled
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names

     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
+from axolotl.utils.collators import DataCollatorForSeq2Seq
+from axolotl.utils.dataloader import MultipackDistributedDataloader
 from axolotl.utils.schedulers import (
     InterpolatingLogScheduler,
     get_cosine_schedule_with_quadratic_warmup,

 LOG = logging.getLogger("axolotl")

+@torch.jit.script
+def weighted_cross_entropy(
+    logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
+):
+    # Flatten the logits, labels, and weights tensors
+    logits = logits.view(
+        -1, logits.size(-1)
+    )  # logits becomes of shape [batch_size*sequence_length, vocab_size]
+    labels = labels.view(-1)  # labels becomes of shape [batch_size*sequence_length]
+    weights = weights.view(-1)  # weights becomes of shape [batch_size*sequence_length]
+
+    # Compute the unweighted cross entropy loss
+    losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
+
+    # Apply the weights to the losses and compute their sum
+    return (weights * losses).sum()
+
+
+@torch.jit.script
+def create_weighted_mask(labels: torch.Tensor):
+    # Check if the tensor is 2D. If not, unsqueeze it to make it 2D
+    if len(labels.shape) == 1:
+        labels = labels.unsqueeze(0)
+
+    weights = torch.zeros_like(labels).float()
+    for i in range(labels.shape[0]):
+        mask = labels[i] != -100
+
+        # Create a tensor to track group ids
+        group_ids = torch.zeros_like(labels[i]).int()
+        curr_group_id = 0
+
+        for j in range(1, len(labels[i])):
+            if mask[j] and not mask[j - 1]:  # switch from masked to unmasked label
+                curr_group_id += 1  # start new group
+            group_ids[j] = (
+                curr_group_id if mask[j] else 0
+            )  # assign group id if unmasked label
+
+        # Count only unmasked labels in each group
+        group_counts = torch.bincount(group_ids[mask])
+
+        mask_weights = torch.zeros_like(labels[i]).float()
+        mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
+
+        weights[i] = mask_weights
+
+    return weights.squeeze()  # squeeze the output to match the input dimension
+
+
+def trainer_weighted_loss(model_output, labels, shift_labels=True):
+    logits = (
+        model_output["logits"] if isinstance(model_output, dict) else model_output[0]
+    )
+    if shift_labels:
+        logits = logits[..., :-1, :].contiguous()
+        labels = labels[..., 1:].contiguous()
+
+    weights = create_weighted_mask(labels)
+    return weighted_cross_entropy(logits, labels, weights)
+
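The weighted-loss helpers above give every packed sub-sequence equal influence on the summed loss: each contiguous run of labels that are not -100 forms a group, and every token in a group is weighted by 1/len(group). A minimal toy re-derivation of the weighting rule (not the jit-scripted code from this PR), just to show the intended weights:

    import torch

    def per_group_weights(labels: torch.Tensor) -> torch.Tensor:
        # Toy sketch: each contiguous run of labels != -100 is one group;
        # every token in a group gets weight 1 / len(group).
        weights = torch.zeros(len(labels))
        mask = labels != -100
        group_id, ids, counts = 0, [0] * len(labels), {}
        for j in range(len(labels)):
            if mask[j] and (j == 0 or not mask[j - 1]):
                group_id += 1
            if mask[j]:
                ids[j] = group_id
                counts[group_id] = counts.get(group_id, 0) + 1
        for j in range(len(labels)):
            if mask[j]:
                weights[j] = 1.0 / counts[ids[j]]
        return weights

    labels = torch.tensor([-100, 5, 6, 7, -100, -100, 8, 9])
    print(per_group_weights(labels))
    # tensor([0.0000, 0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.5000, 0.5000])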
 @dataclass
 class AxolotlTrainingArguments(TrainingArguments):
     """

         default=False,
         metadata={"help": "Use quadratic warmup for cosine scheduling."},
     )
+    sample_packing: bool = field(
+        default=False,
+        metadata={"help": "Use sample packing for efficient training."},
+    )
+    sample_packing_efficiency: float = field(
+        default=1.0,
+        metadata={"help": "Sample packing efficiency for calculating batch length."},
+    )
+    max_seq_length: int = field(
+        default=2048,
+        metadata={"help": "The maximum sequence length the model can handle"},
+    )
+    sample_packing_seq_len_multiplier: int = field(
+        default=1,
+        metadata={"help": "the multiplier for the max len for packed sequences"},
+    )


 class AxolotlTrainer(Trainer):

             return super().create_scheduler(num_training_steps, optimizer)
         return self.lr_scheduler

+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+        if self.args.world_size > 1 and self.args.sample_packing:
+            return DistributedSampler(
+                self.train_dataset,
+                num_replicas=self.args.world_size,
+                rank=self.args.process_index,
+                seed=self.args.seed,
+            )
+        return super()._get_train_sampler()
+
+    def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
+        if self.args.sample_packing:
+            train_sampler = self._get_train_sampler()
+            return self.accelerator.prepare(
+                MultipackDistributedDataloader(
+                    self.train_dataset,
+                    batch_size=self._train_batch_size,
+                    seq_max_length=self.args.max_seq_length,
+                    collate_fn=self.data_collator,
+                    sampler=train_sampler,
+                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                    sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
+                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                )
+            )
+        return super().get_train_dataloader()
+
+    def get_eval_dataloader(
+        self, eval_dataset: Optional[Dataset] = None
+    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+        if self.args.sample_packing:
+            eval_dataset = (
+                eval_dataset if eval_dataset is not None else self.eval_dataset
+            )
+            eval_sampler = self._get_eval_sampler(eval_dataset)
+            return self.accelerator.prepare(
+                MultipackDistributedDataloader(
+                    eval_dataset,
+                    batch_size=self.args.eval_batch_size,
+                    seq_max_length=self.args.max_seq_length,
+                    collate_fn=self.data_collator,
+                    sampler=eval_sampler,
+                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                    sample_packing_seq_len_multiplier=self.args.eval_batch_size,
+                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                )
+            )
+        return super().get_eval_dataloader(eval_dataset)
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        # use one's weighted cross entropy loss calc
+        # if self.args.sample_packing:
+        #     labels = inputs.pop("labels")
+        #     outputs = model(**inputs)
+        #     loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
+        #     return (loss, outputs) if return_outputs else loss
+        return super().compute_loss(model, inputs, return_outputs=return_outputs)
+

 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """

@@ -107,10 +249,121 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
         return self.lr_scheduler

+def add_position_ids(sample):
+    sample["position_ids"] = torch.arange(len(sample["input_ids"]))
+    return sample
+
+
+def drop_long_seq(sample, sequence_len=2048):
+    return len(sample["input_ids"]) <= sequence_len
+
+
+@contextmanager
+def disable_datasets_caching():
+    try:
+        set_caching_enabled(False)
+        yield
+    finally:
+        set_caching_enabled(True)
+
+
+def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
+    if cfg.sample_packing:
+        drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
+        train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
+            add_position_ids, num_proc=os.cpu_count()
+        )
+        if eval_dataset:
+            eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
+                add_position_ids, num_proc=os.cpu_count()
+            )
+    return train_dataset, eval_dataset
+
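For reference, add_position_ids simply attaches a fresh 0..n-1 range to each example before packing; a tiny illustration with hypothetical token ids:

    import torch

    sample = {"input_ids": [1, 306, 4700, 2]}  # hypothetical token ids
    sample = add_position_ids(sample)
    print(sample["position_ids"])  # tensor([0, 1, 2, 3])

When the multipack dataloader later concatenates several such examples into one row, each example keeps its own 0-based positions, which is what lets the flash-attention patch recover the packed sequence boundaries from position_ids.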
+def calculate_total_num_steps(cfg, train_dataset, tokenizer):
+    if cfg.sample_packing:
+        # we have to drop anything longer then sequence len otherwise
+        # flash attention with position ids fails
+        if not cfg.total_num_tokens:
+            LOG.info("calculating total_num_tokens")
+            total_num_tokens = np.sum(
+                train_dataset.data.column("input_ids")
+                .to_pandas()
+                .apply(lambda x: len(x))  # pylint: disable=unnecessary-lambda
+                .values
+            )
+            LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
+            cfg.total_num_tokens = total_num_tokens
+
+        if cfg.sample_packing_eff_est:
+            total_num_steps = (
+                # match count to len est in dataloader
+                (
+                    math.floor(
+                        0.99
+                        * cfg.total_num_tokens
+                        / cfg.sample_packing_eff_est
+                        / cfg.sequence_len
+                        // cfg.batch_size
+                        // int(os.environ.get("WORLD_SIZE", 1))
+                    )
+                    - 1
+                )
+                * cfg.num_epochs
+            )
+            LOG.info(
+                f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
+            )
+        else:
+            sampler = RandomSampler(train_dataset)
+            data_loader = MultipackDistributedDataloader(
+                train_dataset,
+                batch_size=cfg.micro_batch_size,
+                seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
+                collate_fn=DataCollatorForSeq2Seq(
+                    tokenizer,
+                    return_tensors="pt",
+                    padding="longest",
+                ),
+                sampler=sampler,
+                packing_efficiency_estimate=cfg.sample_packing_eff_est,
+                sample_packing_seq_len_multiplier=cfg.micro_batch_size,
+                device_count=int(os.environ.get("WORLD_SIZE", 1)),
+            )
+            data_loader_len = data_loader.len_w_stats()
+            actual_eff = data_loader.efficiency()
+            LOG.info(f"data_loader_len: {data_loader_len}")
+            total_num_steps = int(
+                math.floor(
+                    data_loader_len
+                    * cfg.micro_batch_size
+                    * cfg.num_epochs
+                    // cfg.batch_size
+                )
+            )
+            LOG.info(
+                f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
+            )
+            cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0
+    else:
+        total_num_steps = int(
+            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
+        )
+    LOG.info(f"total_num_steps: {total_num_steps}")
+    return total_num_steps
+
     | 
| 356 | 
         
            +
            def setup_fsdp_envs(cfg):
         
     | 
| 357 | 
         
            +
                os.environ["ACCELERATE_USE_FSDP"] = "true"
         
     | 
| 358 | 
         
            +
                if cfg.fsdp_config.fsdp_sync_module_states:
         
     | 
| 359 | 
         
            +
                    os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
         
     | 
| 360 | 
         
            +
                if cfg.fsdp_config.fsdp_state_dict_type:
         
     | 
| 361 | 
         
            +
                    os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
         
     | 
| 362 | 
         
            +
             
     | 
| 363 | 
         
            +
             
     | 
| 364 | 
         
            +
            def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
         
     | 
| 365 | 
         
            +
                if cfg.fsdp:
         
     | 
| 366 | 
         
            +
                    setup_fsdp_envs(cfg)
         
     | 
| 367 | 
         
             
                warmup_steps = (
         
     | 
| 368 | 
         
             
                    cfg.warmup_steps
         
     | 
| 369 | 
         
             
                    if cfg.warmup_steps is not None
         
     | 
| 
         | 
|
| 443 | 
         
             
                if cfg.save_safetensors:
         
     | 
| 444 | 
         
             
                    training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
         
     | 
| 445 | 
         | 
| 446 | 
         
            +
                if cfg.sample_packing_eff_est:
         
     | 
| 447 | 
         
            +
                    training_arguments_kwargs[
         
     | 
| 448 | 
         
            +
                        "sample_packing_efficiency"
         
     | 
| 449 | 
         
            +
                    ] = cfg.sample_packing_eff_est
         
     | 
| 450 | 
         
            +
             
     | 
| 451 | 
         
             
                training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
         
     | 
| 452 | 
         
            +
                    # max_steps=total_num_steps,  # this is helpful in case we don't actually know total # of steps
         
     | 
| 453 | 
         
            +
                    max_seq_length=cfg.sequence_len,
         
     | 
| 454 | 
         
             
                    per_device_train_batch_size=cfg.micro_batch_size,
         
     | 
| 455 | 
         
             
                    per_device_eval_batch_size=cfg.eval_batch_size
         
     | 
| 456 | 
         
             
                    if cfg.eval_batch_size is not None
         
     | 
| 
         | 
|
| 464 | 
         
             
                    eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
         
     | 
| 465 | 
         
             
                    save_steps=cfg.save_steps,
         
     | 
| 466 | 
         
             
                    output_dir=cfg.output_dir,
         
     | 
| 467 | 
         
            +
                    save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
         
     | 
| 468 | 
         
             
                    load_best_model_at_end=(
         
     | 
| 469 | 
         
             
                        cfg.load_best_model_at_end is not False
         
     | 
| 470 | 
         
             
                        and cfg.val_set_size > 0
         
     | 
| 
         | 
|
| 482 | 
         
             
                    if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
         
     | 
| 483 | 
         
             
                    else "cosine",
         
     | 
| 484 | 
         
             
                    weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         
     | 
| 485 | 
         
            +
                    sample_packing=cfg.sample_packing if cfg.sample_packing else False,
         
     | 
| 486 | 
         
            +
                    sample_packing_seq_len_multiplier=cfg.micro_batch_size,
         
     | 
| 487 | 
         
             
                    **training_arguments_kwargs,
         
     | 
| 488 | 
         
             
                )
         
     | 
| 489 | 
         | 
| 
         | 
|
@@ -316,11 +578,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.collator_pad_to_longest:
         data_collator_kwargs["padding"] = "longest"
     else:
-        …
+        # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
+        # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
+        data_collator_kwargs["pad_to_multiple_of"] = 64

     if cfg.is_llama_derived_model and cfg.landmark_attention:
-        from functools import partial
-
         from axolotl.monkeypatch.llama_landmark_attn import (
             add_mem_tokens,
             get_mem_id,

@@ -348,7 +610,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         args=training_args,
-        data_collator=…
+        data_collator=DataCollatorForSeq2Seq(
             tokenizer,
             return_tensors="pt",
             **data_collator_kwargs,
src/axolotl/utils/validation.py

@@ -8,6 +8,19 @@ LOG = logging.getLogger("axolotl")

 def validate_config(cfg):
+    if cfg.max_packed_sequence_len and cfg.sample_packing:
+        raise ValueError(
+            "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
+        )
+    if cfg.max_packed_sequence_len:
+        LOG.warning(
+            str(
+                PendingDeprecationWarning(
+                    "max_packed_sequence_len will be deprecated in favor of sample_packing"
+                )
+            )
+        )
+
     if cfg.gradient_accumulation_steps and cfg.batch_size:
         raise ValueError(
             "please set only one of gradient_accumulation_steps or batch_size"

@@ -104,6 +117,17 @@ def validate_config(cfg):
             + "point to its path, and remove model_revision from the config."
         )

+    if cfg.sample_packing and cfg.sdp_attention:
+        # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
+        raise ValueError(
+            "sample_packing not compatible with sdp_attention. Use flash_attention"
+        )
+
+    if cfg.sample_packing and cfg.xformers_attention:
+        raise ValueError(
+            "sample_packing not compatible with xformers_attention. Use flash_attention"
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
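A sketch of what the new packing checks reject; the DictDefault wrapper and the minimal set of keys are assumptions for illustration, since sample packing only works with the flash-attention patch:

    from axolotl.utils.dict import DictDefault  # assumed config wrapper used elsewhere in the repo
    from axolotl.utils.validation import validate_config

    cfg = DictDefault(
        {
            "sample_packing": True,
            "xformers_attention": True,
        }
    )
    validate_config(cfg)
    # expected to raise ValueError:
    # "sample_packing not compatible with xformers_attention. Use flash_attention"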
tests/monkeypatch/test_llama_attn_hijack_flash.py

@@ -0,0 +1,30 @@
+"""
+Unit tests for the monkeypatch utils
+"""
+import unittest
+
+import torch
+
+from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
+
+
+class TestMonkeyPatchUtils(unittest.TestCase):
+    """
+    Unit test class for monkeypatch utils
+    """
+
+    def test_get_cu_seqlens_1d(self):
+        attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
+        target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
+        self.assertTrue(torch.allclose(get_cu_seqlens(attn_mask)[0], target_res))
+
+    def test_get_cu_seqlens_from_pos_ids_1d(self):
+        position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
+        target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
+        self.assertTrue(
+            torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
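The second test recovers the same cumulative sequence lengths from position ids alone. A rough equivalent of the idea (not the repo's get_cu_seqlens_from_pos_ids implementation): a position id of 0 that follows a non-zero position starts a new packed sequence, while a trailing run of zeros (padding) stays in the final segment:

    import torch

    def cu_seqlens_from_pos_ids_sketch(position_ids: torch.Tensor) -> torch.Tensor:
        pos = position_ids.squeeze(0)
        starts = [0] + [
            i for i in range(1, len(pos)) if pos[i] == 0 and pos[i - 1] != 0
        ]
        return torch.tensor(starts + [len(pos)], dtype=torch.int32)

    pos_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
    print(cu_seqlens_from_pos_ids_sketch(pos_ids))
    # tensor([ 0,  4,  7, 12, 14, 16], dtype=torch.int32)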
tests/test_expand_mask.py

@@ -0,0 +1,44 @@
+"""
+Unit tests for the monkey patch for expand mask to handle packed sequences
+"""
+import unittest
+
+import torch
+
+from axolotl.monkeypatch.llama_expand_mask import _expand_mask
+
+
+class TestExpandMask(unittest.TestCase):
+    """
+    Test class for attention mask expansion for packed sequences
+    """
+
+    def test_output(self):
+        mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
+        dtype = torch.float32
+        expected_output = torch.tensor(
+            [
+                [
+                    [
+                        [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
+                        [0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
+                        [0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
+                        [-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
+                    ]
+                ],
+                [
+                    [
+                        [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
+                        [-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
+                        [-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
+                        [-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
+                    ]
+                ],
+            ]
+        )
+        # Check that the output matches the expected output
+        self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
+
+
+if __name__ == "__main__":
+    unittest.main()
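The expected matrices above are simply a block-diagonal causal mask: a query may only attend to earlier positions that carry the same non-zero sequence id, and padding (id 0) is never attended to. An equivalent construction for reference (not the patched _expand_mask itself):

    import torch

    def block_diagonal_causal_mask(seq_ids: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
        # seq_ids: [bsz, seq_len]; 0 = padding, k > 0 = k-th packed sequence
        bsz, slen = seq_ids.shape
        same_seq = seq_ids.unsqueeze(1) == seq_ids.unsqueeze(2)  # same packed sequence
        not_pad = (seq_ids != 0).unsqueeze(1)                    # never attend to padding
        causal = torch.tril(torch.ones(slen, slen, dtype=torch.bool))  # no future positions
        allowed = same_seq & not_pad & causal
        mask = torch.full((bsz, slen, slen), torch.finfo(dtype).min, dtype=dtype)
        mask.masked_fill_(allowed, 0.0)
        return mask.unsqueeze(1)  # [bsz, 1, tgt_len, src_len]

    print(block_diagonal_causal_mask(torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])))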
tests/test_packed_dataset.py

@@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase):
             }
         )

-    def …
+    def test_increments_attention(self):
         prompter = AlpacaPrompter("chat")
         strat = AlpacaPromptTokenizingStrategy(
             prompter,

@@ -55,10 +55,14 @@ class TestPacking(unittest.TestCase):
         # first example doesn't have mask reset
         assert example["input_ids"][0] == self.tokenizer.bos_token_id
         assert example["attention_mask"][0] == 1
+        assert example["position_ids"][0] == 0
+        assert example["position_ids"][1] == 1

         # but subsequent one does
         assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
-        assert example["attention_mask"][next_bos_index] == …
+        assert example["attention_mask"][next_bos_index] == 2
+        assert example["position_ids"][next_bos_index] == 0
+        assert example["position_ids"][next_bos_index + 1] == 1


 if __name__ == "__main__":
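Purely for illustration (hypothetical token ids), this is the shape of packed row the assertions above describe: the attention_mask carries an incrementing id per packed example rather than plain 0/1, and position_ids restart at 0 for each example:

    packed = {
        "input_ids":      [1, 306, 4700, 2,  1, 3492, 526, 2],  # two short examples packed together
        "attention_mask": [1, 1,   1,    1,  2, 2,    2,   2],  # 1 = first example, 2 = second
        "position_ids":   [0, 1,   2,    3,  0, 1,    2,   3],  # reset at each packed example
    }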
@@ -134,9 +134,15 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
             "output": "Hi! How can I help?",
         }
         example = strat.tokenize_prompt(sample)
-        assert example["input_ids"][0:
-
-
+        assert example["input_ids"][0:5] == [
+            1,
+            28962,
+            1254,
+            12665,
+            29901,
+        ]  # "<s>SYSTEM:"
+        assert example["input_ids"][5:7] == [671, 20118]  # " use cot"
+        assert example["input_ids"][8] == 11889  # USER


 class Llama2ChatTokenizationTest(unittest.TestCase):
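The hard-coded ids above come from the LLaMA vocabulary; a quick way to sanity-check them is to decode them with a LLaMA-family tokenizer. Sketch only: the checkpoint path is illustrative, and the expected strings are taken from the test's own comments.

    from transformers import LlamaTokenizer

    tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")  # illustrative checkpoint
    print(tokenizer.decode([1, 28962, 1254, 12665, 29901]))  # expected per the test: "<s>SYSTEM:"
    print(tokenizer.decode([671, 20118]))                    # expected per the test: " use cot"
    print(tokenizer.decode([11889]))                         # expected per the test: "USER"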
         
@@ -70,7 +70,7 @@ class AlpacaPrompterTest(unittest.TestCase):
             )
         )
         assert "use cot" in res
-        assert res.startswith("
+        assert res.startswith("SYSTEM:")
         assert "### Instruction:" not in res
         assert "### Input:" not in res
         assert "alpacas" in res
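Purely illustrative: a string consistent with the updated assertions, showing the chat-style layout being pinned down instead of the "### Instruction:" format (the real prompter output may differ in wording and spacing).

    res = "SYSTEM: use cot\nUSER: how many legs do two alpacas have?\nASSISTANT:"
    assert res.startswith("SYSTEM:")
    assert "use cot" in res
    assert "### Instruction:" not in res and "### Input:" not in res
    assert "alpacas" in res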
@@ -313,3 +313,27 @@ class ValidationTest(unittest.TestCase):
         )

         validate_config(cfg)
+
+    def test_packing(self):
+        cfg = DictDefault(
+            {
+                "max_packed_sequence_len": 2048,
+            }
+        )
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "max_packed_sequence_len will be deprecated in favor of sample_packing"
+                in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "max_packed_sequence_len": 2048,
+                "sample_packing": True,
+            }
+        )
+        regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
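The new test expects a deprecation warning when only max_packed_sequence_len is set, and a ValueError when it is combined with sample_packing. A minimal sketch of that kind of check, using a hypothetical validate_packing_options helper rather than the repository's validate_config:

    import logging

    LOG = logging.getLogger(__name__)

    # Sketch of the rule the test exercises; not the actual validate_config.
    def validate_packing_options(cfg: dict) -> None:
        if cfg.get("max_packed_sequence_len") and cfg.get("sample_packing"):
            raise ValueError(
                "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
            )
        if cfg.get("max_packed_sequence_len"):
            LOG.warning("max_packed_sequence_len will be deprecated in favor of sample_packing")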