| """Module for testing models utils file.""" | |
| import unittest | |
| from unittest.mock import patch | |
| import pytest | |
| from axolotl.utils.dict import DictDefault | |
| from axolotl.utils.models import load_model | |
class ModelsUtilsTest(unittest.TestCase):
    """Test cases for the helpers in ``axolotl.utils.models``."""

    def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
        """Check that ``load_model`` raises ``ValueError`` for s2_attention + sample_packing.

        The config enables both ``s2_attention`` and ``sample_packing``; the
        test expects ``load_model`` to raise with a message stating that the
        combination is unsupported.
        """
        expected_message = (
            "shifted-sparse attention does not currently support sample packing"
        )
        settings = {
            "s2_attention": True,
            "sample_packing": True,
            "base_model": "",
            "model_type": "LlamaForCausalLM",
        }
        cfg = DictDefault(settings)

        # Replace the config loader with a stub returning an empty dict so the
        # test does not reach out to the HF hub.
        config_loader_patch = patch(
            "axolotl.utils.models.load_model_config", return_value={}
        )
        with config_loader_patch:
            with pytest.raises(ValueError) as raised:
                # The error is expected before the tokenizer is used, so an
                # empty string stands in for it.
                load_model(cfg, tokenizer="")

        assert expected_message in str(raised.value)