"""Module containing Dataset functionality"""
import logging
import os
from typing import List
import torch
from datasets import IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded.
# Let's use the concept of middlewares to wrap each dataset, for example:
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# Let's check to ensure we don't truncate an item in the middle; we'll use
# the collators later on to pad the datasets.
LOG = logging.getLogger("axolotl")
class TokenizedPromptDataset(IterableDataset):
    """
    Iterable view over a dataset whose rows are tokenized by a prompt strategy.

    Every iteration maps ``prompt_tokenizer.tokenize_prompt`` over the wrapped
    dataset and drops the dataset's original columns from the result.

        Args:
            prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
            dataset (dataset.Dataset): Dataset with text files.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        prompt_tokenizer: PromptTokenizingStrategy,
        dataset: IterableDataset,
    ):
        self.dataset = dataset
        self.prompt_tokenizer = prompt_tokenizer

    def __iter__(self):
        # The original columns are removed so that only the tokenizer's
        # output fields remain on each row.
        original_columns = self.dataset.features.keys()
        worker_count = os.cpu_count()
        tokenized = self.dataset.map(
            self.prompt_tokenizer.tokenize_prompt,
            num_proc=worker_count,
            remove_columns=original_columns,
        )
        return iter(tokenized)
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that packs tokenized examples into chunks of at most `seq_length` tokens.

    Examples are collected in a buffer until the next one would overflow
    `seq_length`; the buffer is then concatenated and yielded as one chunk.
    The buffer is also flushed when a dataset is exhausted, so a chunk never
    mixes examples from two different datasets. Chunks are not padded here.

        Args:
            tokenizer (Tokenizer): Tokenizer providing `eos_token_id`, `bos_token_id` and `get_vocab()`.
            datasets (List[IterableDataset]): Datasets whose examples hold `input_ids`, `attention_mask` and `labels` sequences.
            seq_length (int): Maximum length of the token sequences to return.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        tokenizer,
        datasets,
        seq_length=2048,
    ):
        self.tokenizer = tokenizer
        # The EOS token separates consecutive examples inside a packed chunk.
        self.concat_token_id = tokenizer.eos_token_id
        self.datasets: List[IterableDataset] = datasets
        self.seq_length = seq_length

        # Use the narrowest integer dtype whose range covers the vocabulary size.
        vocab_size = len(tokenizer.get_vocab())
        if vocab_size <= torch.iinfo(torch.int16).max:
            self.tokens_dtype = torch.int16
        elif vocab_size <= torch.iinfo(torch.int32).max:
            self.tokens_dtype = torch.int32
        else:
            self.tokens_dtype = torch.int64

    def _pack_buffer(self, buffer):
        """Concatenate the buffered tensors into one chunk.

        Returns the chunk dict, or None when the buffer is empty or when the
        three concatenated tensors disagree in size (a warning is logged).
        """
        if not buffer["input_ids"]:
            return None

        # The slice guards against a buffer that grew past seq_length, e.g. an
        # example of exactly seq_length tokens that then received a concat token.
        input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length]
        attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
            : self.seq_length
        ]
        labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]

        if labels.size() == input_ids.size() and (
            attention_mask.size() == input_ids.size()
        ):
            return {
                "input_ids": input_ids,
                "labels": labels,
                "attention_mask": attention_mask,
            }

        LOG.warning(
            f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
        )
        return None

    def __iter__(self):
        """Yield dicts of `input_ids`, `labels` and `attention_mask` tensors."""
        buffer = {"input_ids": [], "attention_mask": [], "labels": []}
        buffer_len = 0
        for dataset in self.datasets:
            iterator = iter(dataset)
            more_examples = True
            while more_examples:
                try:
                    example = next(iterator)
                except StopIteration:
                    # Run one more pass with no example to flush the buffer.
                    more_examples = False
                    example = None

                add_concat_token = False
                example_len = 0
                if example:
                    example_len = len(example["input_ids"])
                    if example_len == 0:
                        # Nothing to pack; skipping also avoids an IndexError
                        # on the [-1] lookup below.
                        continue
                    add_concat_token = example["input_ids"][-1] != self.concat_token_id

                # Flush when the data ran out (or the example was falsy), or
                # when adding this example would overflow the chunk.
                if not example_len or (
                    buffer_len + int(add_concat_token) + example_len > self.seq_length
                ):
                    packed = self._pack_buffer(buffer)
                    if packed is not None:
                        yield packed
                    buffer = {
                        "input_ids": [],
                        "attention_mask": [],
                        "labels": [],
                    }
                    buffer_len = 0

                if example:
                    # FIXME
                    # just going to drop data points that are too long
                    if example_len <= self.seq_length:
                        # Work on copies so the caller's example is never
                        # modified by the edits below.
                        input_ids = list(example["input_ids"])
                        attention_mask = list(example["attention_mask"])
                        labels = list(example["labels"])

                        # Mask the BOS token of every example except the
                        # first one in a chunk.
                        if (
                            buffer["input_ids"]
                            and input_ids[0] == self.tokenizer.bos_token_id
                        ):
                            attention_mask[0] = 0

                        if add_concat_token:
                            input_ids.append(self.concat_token_id)
                            attention_mask.append(1)
                            labels.append(self.concat_token_id)

                        buffer["input_ids"].append(
                            torch.tensor(input_ids, dtype=self.tokens_dtype)
                        )
                        buffer["attention_mask"].append(
                            torch.tensor(attention_mask, dtype=self.tokens_dtype)
                        )
                        buffer["labels"].append(
                            torch.tensor(labels, dtype=self.tokens_dtype)
                        )
                        buffer_len += len(input_ids)
 |