various fixes

Files changed:
- src/axolotl/utils/models.py (+0 -1)
- src/axolotl/utils/trainer.py (+5 -12)
src/axolotl/utils/models.py

```diff
@@ -120,7 +120,6 @@ def load_model(
             base_model,
             trust_remote_code=True if cfg.trust_remote_code is True else False,
         )
-        config.attn_config['attn_impl'] = 'triton'
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             config=config,
```
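This drops the forced triton attention implementation for remote-code models. A minimal sketch of the resulting load path, assuming the config comes from `AutoConfig.from_pretrained` as in the surrounding code (`mosaicml/mpt-7b` is only an illustrative checkpoint); anyone who still wants triton attention can opt back in on the config themselves:

```python
from transformers import AutoConfig, AutoModelForCausalLM

base_model = "mosaicml/mpt-7b"  # illustrative checkpoint with an attn_config

config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)
# The attention implementation is now left at the checkpoint's default.
# Restoring the old forced behavior would be an explicit opt-in:
# config.attn_config["attn_impl"] = "triton"

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    config=config,
    trust_remote_code=True,
)
```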
         
src/axolotl/utils/trainer.py
```diff
@@ -30,16 +30,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         if cfg.logging_steps is not None
         else max(min(int(0.005 * total_num_steps), 10), 1)
     )
-    save_steps = (
-        cfg.save_steps
-        if cfg.save_steps is not None
-        else min(int(0.05 * total_num_steps), 200)
-    )
-    eval_steps = (
-        cfg.eval_steps
-        if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0
-        else save_steps
-    )
+    save_steps = cfg.save_steps
+    eval_steps = cfg.eval_steps

     training_arguments_kwargs = {}
     if cfg.bf16 == "full":
```
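The derived fallbacks are gone: `save_steps` and `eval_steps` now flow straight from the config and may be `None`. A before/after sketch with hypothetical helper names (not in the codebase) to make the behavior change concrete:

```python
def resolve_save_steps_old(cfg_save_steps, total_num_steps):
    # Old behavior: default to ~5% of total steps, capped at 200.
    if cfg_save_steps is not None:
        return cfg_save_steps
    return min(int(0.05 * total_num_steps), 200)

def resolve_save_steps_new(cfg_save_steps):
    # New behavior: pass the config value through, None included;
    # save_strategy below falls back to "epoch" when it is None.
    return cfg_save_steps

assert resolve_save_steps_old(None, 1000) == 50   # auto-derived
assert resolve_save_steps_new(None) is None       # left to the Trainer
```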
```diff
@@ -92,13 +84,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         num_train_epochs=cfg.num_epochs,
         learning_rate=cfg.learning_rate,
         evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
-        save_strategy="steps",
+        save_strategy="steps" if save_steps else "epoch",
         eval_steps=eval_steps if cfg.val_set_size > 0 else None,
         save_steps=save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=3,
         load_best_model_at_end=True
-        if cfg.val_set_size > 0 and save_steps % eval_steps == 0 and cfg.load_in_8bit is not True
+        if cfg.val_set_size > 0 and save_steps is not None and save_steps % eval_steps == 0 and cfg.load_in_8bit is not True
         else False,
         ddp_find_unused_parameters=False if cfg.ddp else None,
         group_by_length=cfg.group_by_length,
```
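How the pieces compose in `TrainingArguments` under the new logic, as a runnable sketch with made-up stand-ins for the `cfg` fields. HF's Trainer expects checkpoint saves to line up with evaluations when `load_best_model_at_end` is enabled, which is what the modulo guard protects:

```python
from transformers import TrainingArguments

# Stand-ins for cfg.val_set_size, cfg.save_steps, cfg.eval_steps, cfg.load_in_8bit.
val_set_size, save_steps, eval_steps, load_in_8bit = 0.05, 100, 50, False

args = TrainingArguments(
    output_dir="./out",
    evaluation_strategy="steps" if val_set_size > 0 else "no",
    # With save_steps unset, checkpointing falls back to once per epoch.
    save_strategy="steps" if save_steps else "epoch",
    eval_steps=eval_steps if val_set_size > 0 else None,
    save_steps=save_steps,
    save_total_limit=3,
    load_best_model_at_end=(
        val_set_size > 0
        and save_steps is not None
        and save_steps % eval_steps == 0  # saves must align with evals
        and load_in_8bit is not True
    ),
)
```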
```diff
@@ -158,6 +150,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             cfg.learning_rate,
             total_steps=total_num_steps,
             epochs=cfg.num_epochs,
+            div_factor=10,
             **lr_scheduler_kwargs,
         )
     elif cfg.lr_scheduler == "log_sweep":
```
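The scheduler in this branch looks like torch's `OneCycleLR`, where `div_factor` sets the warmup floor (`initial_lr = max_lr / div_factor`); 10 starts the cycle at a tenth of the peak rather than the PyTorch default of 1/25. A self-contained sketch with a placeholder model and step count:

```python
import torch

model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=3e-4,       # peak learning rate
    total_steps=1000,  # placeholder for total_num_steps
    div_factor=10,     # start at max_lr / 10 = 3e-5 (default is 25)
)
```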