"""
Test classes for checking the functionality of cfg normalization
"""
					
						
						|  | import unittest | 
					
						
						|  | from unittest.mock import patch | 
					
						
						|  |  | 
					
						
						|  | from axolotl.utils.config import normalize_cfg_datasets, normalize_config | 
					
						
						|  | from axolotl.utils.dict import DictDefault | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class NormalizeConfigTestCase(unittest.TestCase): | 
					
						
						|  | """ | 
					
						
						|  | test class for normalize_config checks | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def _get_base_cfg(self): | 
					
						
						|  | return DictDefault( | 
					
						
						|  | { | 
					
						
						|  | "base_model": "JackFram/llama-68m", | 
					
						
						|  | "base_model_config": "JackFram/llama-68m", | 
					
						
						|  | "tokenizer_type": "LlamaTokenizer", | 
					
						
						|  | "num_epochs": 1, | 
					
						
						|  | "micro_batch_size": 1, | 
					
						
						|  | "gradient_accumulation_steps": 1, | 
					
						
						|  | } | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def test_base_model_config_set_when_empty(self): | 
					
						
						|  | cfg = self._get_base_cfg() | 
					
						
						|  | del cfg.base_model_config | 
					
						
						|  | normalize_config(cfg) | 
					
						
						|  |  | 
					
						
						|  | assert cfg.base_model_config == cfg.base_model | 
					
						
						|  |  | 
					
						
						|  | def test_chat_template_chatml(self): | 
					
						
						|  | cfg = DictDefault( | 
					
						
						|  | { | 
					
						
						|  | "chat_template": "chatml", | 
					
						
						|  | "datasets": [ | 
					
						
						|  | { | 
					
						
						|  | "path": "lorem/ipsum", | 
					
						
						|  | "type": "sharegpt", | 
					
						
						|  | "conversation": "vicuna_v1.1", | 
					
						
						|  | }, | 
					
						
						|  | { | 
					
						
						|  | "path": "sit/amet", | 
					
						
						|  | "type": "sharegpt", | 
					
						
						|  | }, | 
					
						
						|  | ], | 
					
						
						|  | } | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | normalize_cfg_datasets(cfg) | 
					
						
						|  |  | 
					
						
						|  | assert cfg.datasets[0].conversation == "vicuna_v1.1" | 
					
						
						|  | assert cfg.datasets[1].conversation == "chatml" | 
					
						
						|  |  | 
					
						
						|  | @patch("axolotl.utils.config.is_torch_bf16_gpu_available") | 
					
						
						|  | def test_bf16_auto_setter_available(self, mock_bf16_avail): | 
					
						
						|  | cfg = self._get_base_cfg() | 
					
						
						|  | cfg.bf16 = "auto" | 
					
						
						|  | mock_bf16_avail.return_value = True | 
					
						
						|  |  | 
					
						
						|  | normalize_config(cfg) | 
					
						
						|  |  | 
					
						
						|  | self.assertTrue(cfg.bf16) | 
					
						
						|  | self.assertFalse(cfg.fp16) | 
					
						
						|  |  | 
					
						
						|  | @patch("axolotl.utils.config.is_torch_bf16_gpu_available") | 
					
						
						|  | def test_bf16_auto_setter_not_available(self, mock_bf16_avail): | 
					
						
						|  | cfg = self._get_base_cfg() | 
					
						
						|  | cfg.bf16 = "auto" | 
					
						
						|  | cfg.fp16 = None | 
					
						
						|  | mock_bf16_avail.return_value = False | 
					
						
						|  |  | 
					
						
						|  | normalize_config(cfg) | 
					
						
						|  |  | 
					
						
						|  | self.assertFalse(cfg.bf16) | 
					
						
						|  | self.assertTrue(cfg.fp16) | 
					
						
						|  |  | 
					
						
						|  | @patch("axolotl.utils.config.is_torch_bf16_gpu_available") | 
					
						
						|  | def test_bf16_disables_fp16(self, mock_bf16_avail): | 
					
						
						|  | cfg = self._get_base_cfg() | 
					
						
						|  | cfg.bf16 = True | 
					
						
						|  | cfg.fp16 = False | 
					
						
						|  | mock_bf16_avail.return_value = True | 
					
						
						|  |  | 
					
						
						|  | normalize_config(cfg) | 
					
						
						|  |  | 
					
						
						|  | self.assertTrue(cfg.bf16) | 
					
						
						|  | self.assertFalse(cfg.fp16) | 
					
						
						|  |  |