"""Module containing Dataset functionality"""
import logging
import os
from typing import List, Optional
import torch
from datasets import Dataset, IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy
# The classes here are meant to wrap an existing dataset that has already been
# loaded. Let's use the concept of middlewares to wrap each dataset, for example:
#   ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# Let's make sure we don't truncate an item in the middle; the collators are
# used later on to pad the datasets.

# Module-level logger, registered under the name "axolotl".
LOG = logging.getLogger("axolotl")
class TokenizedPromptDataset(Dataset):
    """
    Dataset that returns tokenized prompts from a stream of text files.

        Args:
            prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
            dataset (dataset.Dataset): Dataset with text files.
            process_count (Optional[int]): Number of worker processes to use for
                tokenization (capped at 64). When unset or 0, the CPU count is used.
            **kwargs: Forwarded unchanged to the parent ``Dataset`` constructor.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        prompt_tokenizer: PromptTokenizingStrategy,
        dataset: Dataset,
        process_count: Optional[int] = None,
        **kwargs,
    ):
        self.prompt_tokenizer = prompt_tokenizer
        self.process_count = process_count
        # Tokenize eagerly, then initialise the parent with the `.data` of the
        # mapped dataset so this instance holds the tokenized rows.
        super().__init__(self.process(dataset).data, **kwargs)

    def process(self, dataset):
        """
        Tokenize every row of ``dataset`` with the prompt tokenizer.

        All columns of the incoming dataset are removed, so the result only
        contains what ``prompt_tokenizer.tokenize_prompt`` returns.

        Args:
            dataset: Dataset exposing ``features`` and ``map``.

        Returns:
            The result of ``dataset.map(...)``.
        """
        features = dataset.features.keys()
        # os.cpu_count() may return None when the count cannot be determined;
        # fall back to a single process instead of failing in min().
        num_proc = min(64, self.process_count or os.cpu_count() or 1)

        map_kwargs = {}
        if self.prompt_tokenizer.supports_batched:
            # Strategies that support batching receive 100 rows per call.
            map_kwargs["batched"] = True
            map_kwargs["batch_size"] = 100
        return dataset.map(
            self.prompt_tokenizer.tokenize_prompt,
            num_proc=num_proc,
            remove_columns=features,
            **map_kwargs,
        )
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that packs tokenized examples into chunks of at most
    ``seq_length`` tokens.

    Chunks are not padded here, so a yielded chunk may be shorter than
    ``seq_length``.

        Args:
            tokenizer (Tokenizer): The processor used for processing the data.
                Its ``eos_token_id`` is used as the separator between packed
                examples and its vocabulary size selects the tensor dtype.
            datasets (List[IterableDataset]): Tokenized datasets whose examples
                provide ``input_ids``, ``attention_mask`` and ``labels``.
            seq_length (int): Maximum length of token sequences to return.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        tokenizer,
        datasets,
        seq_length=2048,
    ):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.eos_token_id
        self.datasets: List[IterableDataset] = datasets
        self.seq_length = seq_length

        # Pick the smallest signed integer dtype able to hold the vocabulary size.
        vocab_size = len(tokenizer.get_vocab())
        if vocab_size <= torch.iinfo(torch.int16).max:
            self.tokens_dtype = torch.int16
        elif vocab_size <= torch.iinfo(torch.int32).max:
            self.tokens_dtype = torch.int32
        else:
            self.tokens_dtype = torch.int64

    def __iter__(self):
        """
        Yield packed samples as dicts of 1-D tensors.

        Each sample has the keys ``input_ids``, ``labels``, ``attention_mask``
        and ``position_ids``. Within a sample, the attention mask of the n-th
        packed example is multiplied by its running index so that examples can
        be told apart, and ``position_ids`` restart at 0 for every example.

        The buffer is flushed when the next example would not fit and at the
        end of each dataset, so examples from different datasets are never
        packed together. Examples longer than ``seq_length`` are dropped, and
        examples with empty ``input_ids`` are skipped.
        """
        buffer = {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
            "position_ids": [],
        }
        buffer_len = 0
        for dataset in self.datasets:
            # Running index of the example inside the current pack.
            idx = 0
            iterator = iter(dataset)
            more_examples = True
            while more_examples:
                try:
                    example = next(iterator)
                    idx += 1
                except StopIteration:
                    # End of this dataset: fall through once more with no
                    # example so the remaining buffer gets flushed.
                    more_examples = False
                    example = None

                add_concat_token = False
                if example:
                    example_len = len(example["input_ids"])
                    if not example_len:
                        # Nothing to pack, and indexing [-1] below would raise
                        # IndexError. Skip it without flushing the buffer and
                        # undo the index bump made above.
                        LOG.warning("skipping example with empty input_ids")
                        idx -= 1
                        continue
                    add_concat_token = example["input_ids"][-1] != self.concat_token_id
                else:
                    example_len = 0

                # Flush when the stream ended or the next example would not fit.
                if not example_len or (
                    buffer_len + int(add_concat_token) + example_len > self.seq_length
                ):
                    if buffer["input_ids"]:
                        input_ids = torch.cat(buffer["input_ids"], dim=-1)[
                            : self.seq_length
                        ]
                        attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
                            : self.seq_length
                        ]
                        position_ids = torch.cat(buffer["position_ids"], dim=-1)[
                            : self.seq_length
                        ]
                        labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
                        if labels.size() == input_ids.size() and (
                            attention_mask.size() == input_ids.size()
                        ):
                            yield {
                                "input_ids": input_ids,
                                "labels": labels,
                                "attention_mask": attention_mask,
                                "position_ids": position_ids,
                            }
                        else:
                            LOG.warning(
                                f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
                            )
                    buffer = {
                        "input_ids": [],
                        "attention_mask": [],
                        "labels": [],
                        "position_ids": [],
                    }
                    buffer_len = 0
                    # The example handled below becomes the first of the new pack.
                    idx = 1

                if example:
                    # FIXME
                    # just going to drop data points that are too long
                    if example_len <= self.seq_length:
                        # Work on copies so that appending the concat token does
                        # not modify the lists owned by the source dataset.
                        input_ids = list(example["input_ids"])
                        attention_mask = list(example["attention_mask"])
                        labels = list(example["labels"])

                        if add_concat_token:
                            input_ids.append(self.concat_token_id)
                            attention_mask.append(1)
                            labels.append(self.concat_token_id)

                        input_ids_with_concat = torch.tensor(
                            input_ids, dtype=self.tokens_dtype
                        )
                        # Scale the mask by the example's index within the pack.
                        attention_mask_with_concat = torch.tensor(
                            [idx * m for m in attention_mask], dtype=torch.int16
                        )
                        labels_with_concat = torch.tensor(
                            labels, dtype=self.tokens_dtype
                        )
                        position_ids = torch.arange(
                            len(input_ids), dtype=self.tokens_dtype
                        )

                        buffer["input_ids"].append(input_ids_with_concat)
                        buffer["attention_mask"].append(attention_mask_with_concat)
                        buffer["labels"].append(labels_with_concat)
                        buffer["position_ids"].append(position_ids)
                        buffer_len += len(input_ids)
 |