"""
Basic completion text prompt-tokenizing strategy
"""
from collections import defaultdict
from typing import Any, Dict, Generator, Optional, Tuple

from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy


class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Completion prompts.
    """
    _field: str = "text"
    def __init__(self, *args, max_length=None, **kwargs):
        super().__init__(*args, **kwargs)
        if max_length is not None:
            self.max_length = max_length

    @property
    def supports_batched(self):
        return True

    @property
    def field(self) -> str:
        return self._field

    @field.setter
    def field(self, new_field: str):
        self._field = new_field

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        return (
            prompt[self.field],
            "",
            "",
        )

    def tokenize_prompt(self, prompt):
        # `prompt` arrives as a batched, column-oriented dict (supports_batched
        # is True): each key maps to a list of values, one per example.
        res = defaultdict(lambda: [])
        feature_names = list(prompt.keys())
        for row in zip(*prompt.values()):
            prompt_row = dict(zip(feature_names, row))
            (
                instruction,
                _,
                _,
            ) = self.parse_instruction_fields(prompt_row)

            full_prompt = self._build_full_prompt(instruction, None, None)
            tokenized_full_prompt = self._tokenize(full_prompt)

            # Split the tokenized text into chunks of at most `sequence_len`
            # tokens; each chunk becomes its own training example.
            for key, val in tokenized_full_prompt.items():
                for i in range(0, len(val), self.sequence_len):
                    res[key].append(val[i : i + self.sequence_len])

        return dict(res)
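    # Illustration of tokenize_prompt's output (hypothetical token ids, not
    # from the original source): with sequence_len = 4, a batch such as
    #     {"text": ["one long document ..."]}
    # is tokenized once and then split, yielding something like
    #     {"input_ids": [[101, 102, 103, 104], [105, 106, ...]],
    #      "attention_mask": [...], "labels": [...]}
    # so a single long document becomes several fixed-size training rows.
    # The exact keys depend on the base strategy's _tokenize().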
    def _build_full_prompt(
        self, instruction, input, response
    ):  # pylint: disable=redefined-builtin
        return next(iter(self.prompter.build_prompt(instruction, input, response)))


class CompletionPrompter:
    """
    Prompter for completion
    """

    def build_prompt(
        self,
        instruction: str,
        input=None,  # pylint: disable=redefined-builtin, unused-argument
        output=None,  # pylint: disable=unused-argument
    ) -> Generator[str, None, None]:
        yield instruction


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    strat = CompletionPromptTokenizingStrategy(
        CompletionPrompter(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
        # max_length caps tokenization far above sequence_len so long documents
        # survive until tokenize_prompt splits them into sequence_len-sized chunks.
        max_length=cfg.sequence_len * 64,
    )
    if ds_cfg and "field" in ds_cfg:
        strat.field = ds_cfg["field"]

    return strat
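

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a
# Hugging Face tokenizer and a throwaway cfg namespace exposing the two
# attributes `load` reads (`train_on_inputs`, `sequence_len`); the "gpt2"
# checkpoint and the demo batch are placeholders chosen for illustration.
# Run the file directly to see a long text split into sequence_len chunks.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import AutoTokenizer  # assumed to be installed alongside axolotl

    demo_cfg = SimpleNamespace(train_on_inputs=False, sequence_len=16)
    demo_tokenizer = AutoTokenizer.from_pretrained("gpt2")

    strategy = load(demo_tokenizer, demo_cfg, ds_cfg={"field": "text"})

    # Batched, column-oriented input: one key per dataset column.
    batch = {"text": ["word " * 100]}  # long enough to need several chunks
    tokenized = strategy.tokenize_prompt(batch)

    for key, chunks in tokenized.items():
        print(key, [len(chunk) for chunk in chunks])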