| """Module containing the InstructShareGPTPromptTokenizingStrategy class""" | |
| from typing import Any, Dict, Optional | |
| from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy | |
| from axolotl.prompters import ShareGPTPrompterV2 | |

def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    # Allow the conversation template to be overridden per dataset via
    # ds_cfg["conversation"]; fall back to the prompter's default otherwise.
    conversation = (
        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
    )
    strategy = InstructShareGPTPromptTokenizingStrategy(
        # pylint: disable=duplicate-code
        ShareGPTPrompterV2(
            conversation=conversation,
        ),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    return strategy
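
# A minimal sketch of how a dataset might point at this strategy from an axolotl
# config. The module name ("instruct") and the exact YAML layout below are
# assumptions for illustration, not confirmed by this file; only the optional
# "conversation" key is read directly by load() above.
#
#   datasets:
#     - path: ./data/instruct.jsonl   # hypothetical local dataset
#       type: instruct                # assumed to resolve to this module's load()
#       conversation: chatml          # optional, passed through ds_cfg["conversation"]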

class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
    """
    Basic ShareGPT strategy that builds a two-turn conversation
    (human instruction -> gpt output) from a single sample row.
    """

    def get_conversation_thread(self, prompt):
        return [
            {"from": "human", "value": prompt["instruction"]},
            {"from": "gpt", "value": prompt["output"]},
        ]
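
# Example of the row-to-thread mapping above (the values are made up for
# illustration; only the "instruction"/"output" field names and the
# "from"/"value" structure come from get_conversation_thread):
#
#   row = {"instruction": "Summarize the text.", "output": "Here is a summary."}
#   strategy.get_conversation_thread(row)
#   # -> [
#   #      {"from": "human", "value": "Summarize the text."},
#   #      {"from": "gpt", "value": "Here is a summary."},
#   #    ]
#
# The ShareGPTPrompterV2 configured in load() is then expected to render this
# thread with the selected conversation template before tokenization.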