								"""Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class"""
import logging
from typing import Tuple
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
LOG = logging.getLogger("axolotl")
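
# -100 is the default ignore_index of torch.nn.CrossEntropyLoss; labels set to
# this value are excluded from the loss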
IGNORE_TOKEN_ID = -100


# pylint: disable=duplicate-code
class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for the Metharme models
    """
    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        return (prompt["prompt"], "", prompt["generation"])

    def _tokenize(
        self,
        prompt: str,
        add_eos_token: bool = True,
        strip_bos_token: bool = False,
        num_eos_tokens: int = 3,
    ):
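        # Tokenize without padding; BOS/EOS handling is adjusted manually
        # below, appending up to num_eos_tokens EOS tokens (three by default)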
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        if len(result["input_ids"]) == 0:
            LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
            # Return early so the indexing below can't raise on an empty result
            result["labels"] = result["input_ids"].copy()
            return result

        # If the tokenizer already produced a trailing EOS token, subtract it
        # from the number of EOS tokens still to append
        if result["input_ids"][-1] == self.tokenizer.eos_token_id:
            num_eos_tokens -= 1

        if num_eos_tokens > 0 and add_eos_token:
            for _ in range(num_eos_tokens):
                if len(result["input_ids"]) < self.sequence_len:
                    result["input_ids"].append(self.tokenizer.eos_token_id)
                    result["attention_mask"].append(1)

        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result


class MetharmePrompter(AlpacaPrompter):
    """
    Prompter for the Metharme models.
    """

    system_prompt = ""
    system_no_input_prompt = ""
    system_format = ""
    turn_format = "{instruction}"
    turn_no_input_format = "{instruction}"
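
    # The templates pass the instruction through unchanged: Metharme datasets
    # arrive fully formatted, so none of the Alpaca-style system or turn
    # wrapping is applied.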
    def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
        pass


def load(tokenizer, cfg):
    return MetharmePromptTokenizingStrategy(
        MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
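

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The tokenizer
    # name, cfg fields, and example row below are illustrative assumptions,
    # not values taken from this file.
    from types import SimpleNamespace

    from transformers import AutoTokenizer  # assumes transformers is installed

    demo_tokenizer = AutoTokenizer.from_pretrained("PygmalionAI/metharme-7b")
    demo_cfg = SimpleNamespace(train_on_inputs=False, sequence_len=2048)
    strategy = load(demo_tokenizer, demo_cfg)
    example = {
        "prompt": "<|system|>Enter RP mode.<|user|>Hi there!<|model|>",
        "generation": "Hello! How can I help you today?",
    }
    # tokenize_prompt is inherited from axolotl's PromptTokenizingStrategy base
    print(strategy.tokenize_prompt(example))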