# tangled-alpha-0.9-core/scripts/prepare_core_datasets.py
from functools import partial

from transformers import AutoTokenizer
from litgpt.tokenizer import Tokenizer
from litdata import optimize, TokensLoader, StreamingDataset

from utils import tokenize_fn
from core_base_datasets import core_base_datasets
from core_instruct_datasets import core_instruct_datasets


tokenizer_path = '../tokenizer'
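
# Each `seqs` entry is (min_len, max_len, block_size, subchunk_size):
# min_len/max_len are forwarded to tokenize_fn (presumably to bucket documents by
# tokenized length), block_size is the sequence length used by TokensLoader when the
# optimized data is read back, and block_size * subchunk_size is the chunk size in tokens.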
seqs = [
    (0, 1073741824, 1025, 16000),
    (1025, 2049, 2049, 8000),
    (2049, 4097, 4097, 4000),
    (4097, 8193, 8193, 2000),
    (8193, 16385, 16385, 1000),
    (16385, 32769, 32769, 500),
    (32769, 65537, 65537, 250),
    (65537, 131073, 131073, 125),
]

#
# optimize datasets
#
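# chunk_size is the number of tokens written per chunk file; for the first bucket
# that is 1025 * 16000 = 16,400,000 tokens, i.e. roughly 64 MB if tokens are stored
# as 4-byte integers.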
for i, (min_len, max_len, block_size, subchunk_size) in enumerate(seqs):
    chunk_size = block_size * subchunk_size
    output_dir = f'../core-data-{i}-{min_len}-{max_len}-{block_size}-{subchunk_size}'

    outputs = optimize(
        fn=partial(
            tokenize_fn,
            min_len=min_len,
            max_len=max_len,
            hf_tokenizer=AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True, use_fast=True),
            tokenizer=Tokenizer(tokenizer_path),
        ),
        inputs=core_base_datasets + core_instruct_datasets,
        output_dir=output_dir,
        chunk_size=chunk_size,  # Number of tokens to store per chunk; this is roughly 64 MB of tokens per chunk.
        num_workers=32,
        reorder_files=False,
        # Passing TokensLoader informs LitData that we are encoding a contiguous 1D array of
        # tokens, so LitData skips storing per-sample metadata, i.e. all the tokens are
        # concatenated to form one large tensor.
        # item_loader=TokensLoader(block_size=block_size),
    )

#
# total number of samples and tokens in the optimized datasets
#
for i, (min_len, max_len, block_size, subchunk_size) in enumerate(seqs):
    chunk_size = block_size * subchunk_size
    input_dir = f'../core-data-{i}-{min_len}-{max_len}-{block_size}-{subchunk_size}'

    dataset = StreamingDataset(
        input_dir=input_dir,
        item_loader=TokensLoader(block_size=block_size),
    )

    print(f'{i=}, {min_len=}, {max_len=}, {block_size=}, {chunk_size=}, {len(dataset)=}, {len(dataset) * block_size=}')

    total_tokens = len(dataset) * block_size
    print(f'Total number of tokens in the optimized dataset {input_dir!r} is {total_tokens}')
    print()
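
#
# Example (sketch, not executed by this preparation script): one way the optimized data
# could be consumed for training, using litdata's StreamingDataLoader. The directory name
# and batch size below are illustrative, not taken from this repository's training code.
#
# from litdata import StreamingDataLoader
#
# train_dataset = StreamingDataset(
#     input_dir='../core-data-0-0-1073741824-1025-16000',
#     item_loader=TokensLoader(block_size=1025),
#     shuffle=True,
# )
# train_dataloader = StreamingDataLoader(train_dataset, batch_size=16)
#
# for batch in train_dataloader:
#     # batch is a tensor of token ids with shape (batch_size, block_size)
#     ...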