File size: 2,633 Bytes
			
			d2e7f27  | 
								1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99  | 
								"""
User-defined prompts, configured from the YAML config.
"""
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple
from axolotl.prompt_strategies.alpaca_w_system import (
    InstructionWSystemPromptTokenizingStrategy,
    SystemDataPrompter,
)
@dataclass
class UserDefinedDatasetConfig:
    """
    Settings for a dataset whose prompt layout is defined by the user.

    The ``field_*`` attributes name the keys to read from each dataset row.
    The ``*format`` attributes are ``str.format`` templates for the prompt.

    Subscript access (``cfg["field_input"]``) is forwarded to attribute
    lookup, so an unknown name raises ``AttributeError`` (not ``KeyError``).
    """

    # Default system prompt; load() uses it when a row has no system field.
    system_prompt: str = ""
    # Row keys holding each part of an example.
    field_system: str = "system"
    field_instruction: str = "instruction"
    field_input: str = "input"
    field_output: str = "output"
    # Turn template; the default references {instruction} and {input}.
    format: str = "{instruction} {input} "
    # Turn template without input (per the name); default uses only {instruction}.
    no_input_format: str = "{instruction} "
    # Template wrapping the system prompt via the {system} placeholder.
    system_format: str = "{system}"

    def __getitem__(self, item):
        attribute_value = getattr(self, item)
        return attribute_value
class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
    """
    Prompt tokenization strategy for datasets with a user-defined prompt layout.

    Adds no behavior of its own over the base strategy. ``load`` in this module
    instantiates it and replaces ``parse_instruction_fields`` on the instance,
    so row fields are read using the keys configured in
    ``UserDefinedDatasetConfig``.
    """
def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None):
    """
    Build a tokenization strategy for a dataset with a user-defined prompt layout.

    :param tokenizer: tokenizer passed through to the tokenization strategy.
    :param cfg: training config; ``train_on_inputs`` and ``sequence_len`` are read.
    :param ds_cfg: row keys and prompt templates for this dataset.
    :return: a ``UserDefinedPromptTokenizationStrategy`` whose
        ``parse_instruction_fields`` reads the keys named in ``ds_cfg``.
    :raises ValueError: if ``ds_cfg`` is not provided.
    """
    if not ds_cfg:
        raise ValueError("Missing dataset prompt configuration")

    # An unset/empty configured system prompt means "no default system prompt".
    system_prompt = ds_cfg.system_prompt or ""

    def parse_instruction_fields(
        field_instruction,
        field_input,
        field_output,
        field_system,
        system_prompt,
        prompt,
    ) -> Tuple[str, str, str, str]:
        """
        Return ``(instruction, input, output, system)`` for one dataset row.

        The instruction key is required. An optional field falls back to its
        default when the key is absent from the row. It also falls back when the
        key is present with a ``None`` value, e.g. a column that exists but is
        unset for this row. So ``None`` is never returned for these fields.
        """

        def optional(field, default):
            value = prompt[field] if field in prompt else None
            return default if value is None else value

        return (
            prompt[field_instruction],
            optional(field_input, ""),
            optional(field_output, ""),
            optional(field_system, system_prompt),
        )

    # Captured by the prompter subclass below.
    turn_format = ds_cfg.format
    turn_no_input_format = ds_cfg.no_input_format
    system_format = ds_cfg.system_format

    class UserDefinedPrompter(SystemDataPrompter):
        """
        Prompter for user defined prompts; templates come from ``ds_cfg``.
        """

        def match_prompt_style(self):
            # NOTE(review): presumably invoked by the base prompter during
            # construction to install its templates; here the user-supplied
            # templates are installed instead — confirm against the base class.
            self.turn_format = turn_format
            self.turn_no_input_format = turn_no_input_format
            self.system_format = system_format

    prompter = UserDefinedPrompter()
    strat = UserDefinedPromptTokenizationStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    # Bind the configured keys and default system prompt, leaving only the
    # row (``prompt``) as the remaining argument. Setting it on the instance
    # replaces the method for this strategy object only.
    setattr(
        strat,
        "parse_instruction_fields",
        partial(
            parse_instruction_fields,
            ds_cfg.field_instruction,
            ds_cfg.field_input,
            ds_cfg.field_output,
            ds_cfg.field_system,
            system_prompt,
        ),
    )
    return strat
 |