Tom Jobbins
Debug tokenization output: Add ability to output text only (no tokens), and/or specify num samples to see (#511)
48434be
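This change adds two knobs to `TrainerCliArgs`: `--debug_num_examples` (how many tokenized samples to print) and `--debug_text_only` (print decoded text without per-token IDs). Both are consumed in `load_datasets` below. A possible invocation might look like the following, assuming the script path and config name shown here (both are placeholders, not taken from the commit):

    python scripts/finetune.py examples/your-config.yml --debug --debug_text_only --debug_num_examples 3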
| """Prepare and train a model on a dataset. Can also infer from a model or merge lora""" | |
| import importlib | |
| import logging | |
| import os | |
| import random | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Union | |
| import fire | |
| import torch | |
| import transformers | |
| import yaml | |
| # add src to the pythonpath so we don't need to pip install this | |
| from art import text2art | |
| from transformers import GenerationConfig, TextStreamer | |
| from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer | |
| from axolotl.logging_config import configure_logging | |
| from axolotl.train import TrainDatasetMeta, train | |
| from axolotl.utils.config import normalize_config, validate_config | |
| from axolotl.utils.data import prepare_dataset | |
| from axolotl.utils.dict import DictDefault | |
| from axolotl.utils.distributed import is_main_process | |
| from axolotl.utils.models import load_model_config, load_tokenizer | |
| from axolotl.utils.tokenization import check_dataset_labels | |
| from axolotl.utils.wandb import setup_wandb_env_vars | |
| project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | |
| src_dir = os.path.join(project_root, "src") | |
| sys.path.insert(0, src_dir) | |
| configure_logging() | |
| LOG = logging.getLogger("axolotl.scripts") | |
| os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" | |

def print_axolotl_text_art(suffix=None):
    font = "nancyj"
    ascii_text = " axolotl"
    if suffix:
        ascii_text += f" x {suffix}"
    ascii_art = text2art(ascii_text, font=font)

    if is_main_process():
        print(ascii_art)

def get_multi_line_input() -> Optional[str]:
    print("Give me an instruction (Ctrl + D to finish): ")
    instruction = ""
    for line in sys.stdin:
        instruction += line  # pylint: disable=consider-using-join
    # instruction = pathlib.Path("/proc/self/fd/0").read_text()
    return instruction
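
# Merge the trained LoRA adapter into the base model weights and save the
# combined fp16 model (and tokenizer) under <output_dir>/merged.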
def do_merge_lora(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
):
    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
    safe_serialization = cfg.save_safetensors is True

    LOG.info("running merge of LoRA with base model")
    model = model.merge_and_unload()
    model.to(dtype=torch.float16)

    if cfg.local_rank == 0:
        LOG.info("saving merged model")
        model.save_pretrained(
            str(Path(cfg.output_dir) / "merged"),
            safe_serialization=safe_serialization,
        )
        tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))

def shard(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
):
    model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
    safe_serialization = cfg.save_safetensors is True
    LOG.debug("Re-saving model w/ sharding")
    model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
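
# Interactive generation loop: read a (possibly multi-line) instruction from
# stdin, optionally wrap it with a prompter class from axolotl.prompters, and
# stream sampled tokens back until an empty instruction ends the session.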
def do_inference(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
):
    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
    prompter = cli_args.prompter
    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}

    for token, symbol in default_tokens.items():
        # If the token isn't already specified in the config, add it
        if not (cfg.special_tokens and token in cfg.special_tokens):
            tokenizer.add_special_tokens({token: symbol})

    prompter_module = None
    if prompter:
        prompter_module = getattr(
            importlib.import_module("axolotl.prompters"), prompter
        )

    if cfg.landmark_attention:
        from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id

        set_model_mem_id(model, tokenizer)
        model.set_mem_cache_args(
            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
        )

    model = model.to(cfg.device)

    while True:
        print("=" * 80)
        # support for multiline inputs
        instruction = get_multi_line_input()
        if not instruction:
            return
        if prompter_module:
            prompt: str = next(
                prompter_module().build_prompt(instruction=instruction.strip("\n"))
            )
        else:
            prompt = instruction.strip()
        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

        print("=" * 40)
        model.eval()
        with torch.no_grad():
            generation_config = GenerationConfig(
                repetition_penalty=1.1,
                max_new_tokens=1024,
                temperature=0.9,
                top_p=0.95,
                top_k=40,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )
            streamer = TextStreamer(tokenizer)
            generated = model.generate(
                inputs=batch["input_ids"].to(cfg.device),
                generation_config=generation_config,
                streamer=streamer,
            )
        print("=" * 40)
        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))

def choose_config(path: Path):
    yaml_files = list(path.glob("*.yml"))

    if not yaml_files:
        raise ValueError(
            "No YAML config files found in the specified directory. Are you using a .yml extension?"
        )

    if len(yaml_files) == 1:
        print(f"Using default YAML file '{yaml_files[0]}'")
        return yaml_files[0]

    print("Choose a YAML file:")
    for idx, file in enumerate(yaml_files):
        print(f"{idx + 1}. {file}")

    chosen_file = None
    while chosen_file is None:
        try:
            choice = int(input("Enter the number of your choice: "))
            if 1 <= choice <= len(yaml_files):
                chosen_file = yaml_files[choice - 1]
            else:
                print("Invalid choice. Please choose a number from the list.")
        except ValueError:
            print("Invalid input. Please enter a number.")

    return chosen_file

def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
    return not any(el in list2 for el in list1)
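
# Build the run config: pick/load the YAML file, overlay any CLI keyword
# overrides, detect llama-derived models, then validate and normalize.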
def load_cfg(config: Path = Path("examples/"), **kwargs):
    if Path(config).is_dir():
        config = choose_config(config)

    # load the config from the yaml file
    with open(config, encoding="utf-8") as file:
        cfg: DictDefault = DictDefault(yaml.safe_load(file))
    # if an option was passed on the cli and it looks valid against the yaml,
    # overwrite the yaml value with it
    cfg_keys = cfg.keys()
    for k, _ in kwargs.items():
        # if not strict, allow writing to cfg even if it's not in the yml already
        if k in cfg_keys or not cfg.strict:
            # handle booleans
            if isinstance(cfg[k], bool):
                cfg[k] = bool(kwargs[k])
            else:
                cfg[k] = kwargs[k]

    model_config = load_model_config(cfg)

    # figure out if the model is llama
    cfg.is_llama_derived_model = (
        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
        or cfg.is_llama_derived_model
        or "llama" in cfg.base_model
        or (cfg.model_type and "llama" in cfg.model_type.lower())
    )
    validate_config(cfg)
    normalize_config(cfg)
    setup_wandb_env_vars(cfg)
    return cfg
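
# Tokenize and prepare the train/eval datasets. With --debug (or debug in the
# yaml), print a random sample of debug_num_examples tokenized rows via
# check_dataset_labels, decoded text only when debug_text_only is set.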
def load_datasets(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
    tokenizer = load_tokenizer(cfg)
    train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)

    if cli_args.debug or cfg.debug:
        LOG.info("check_dataset_labels...")
        check_dataset_labels(
            train_dataset.select(
                [
                    random.randrange(0, len(train_dataset))  # nosec
                    for _ in range(cli_args.debug_num_examples)
                ]
            ),
            tokenizer,
            num_examples=cli_args.debug_num_examples,
            text_only=cli_args.debug_text_only,
        )

    return TrainDatasetMeta(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        total_num_steps=total_num_steps,
    )
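
# Entry point: fire passes the config path and any YAML overrides, while
# HfArgumentParser picks the TrainerCliArgs flags out of the remaining
# arguments; dispatch to inference, LoRA merge, sharding, or training.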
def do_cli(config: Path = Path("examples/"), **kwargs):
    print_axolotl_text_art()
    parsed_cfg = load_cfg(config, **kwargs)
    parser = transformers.HfArgumentParser(TrainerCliArgs)
    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    if parsed_cli_args.inference:
        do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
    elif parsed_cli_args.merge_lora:
        do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
    elif parsed_cli_args.shard:
        shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
    else:
        dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
        if parsed_cli_args.prepare_ds_only:
            return
        train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)


if __name__ == "__main__":
    fire.Fire(do_cli)