tangled-alpha-0.9-core / scripts /backup /prepare_pretrain_base_datasets.py
mtasic85's picture
prepare datasets
734e414
raw
history blame
2.3 kB
from functools import partial
from litgpt.tokenizer import Tokenizer
from litdata import optimize, TokensLoader, StreamingDataset
from transformers import AutoTokenizer
from utils import tokenize_fn
from pretrain_base_datasets import pretrain_base_datasets
from pretrain_instruct_datasets import pretrain_instruct_datasets
from pretrain_reflection_datasets import pretrain_reflection_datasets
from pretrain_reasoning_datasets import pretrain_reasoning_datasets
#
# optimize datasets
#
# Tokenize every source dataset and write it out as a litdata-optimized token
# stream, once per (block_size, subchunk_size) configuration.
#
#   block_size    -- tokens per training sample; presumably context length + 1
#                    so inputs/targets can be shifted by one (TODO confirm).
#   subchunk_size -- number of block_size-token blocks packed into one chunk.
for i, (block_size, subchunk_size) in enumerate([(4097, 4000)]):
    # With a TokensLoader (below), litdata measures chunk_size in tokens:
    # 4097 * 4000 = 16_388_000 tokens per chunk, i.e. roughly 64MB if each
    # token is stored as a 4-byte integer (dtype depends on tokenize_fn --
    # NOTE(review): confirm). Being an exact multiple of block_size, every
    # full chunk splits into subchunk_size whole blocks with no leftover.
    chunk_size = block_size * subchunk_size
    output_dir = f'../pretrain-base-data-{i}-{block_size}-{subchunk_size}'

    optimize(
        fn=partial(
            tokenize_fn,
            hf_tokenizer=AutoTokenizer.from_pretrained('..', trust_remote_code=True, use_fast=True),
            tokenizer=Tokenizer('..'),
        ),
        inputs=(
            pretrain_base_datasets
            + pretrain_instruct_datasets
            + pretrain_reflection_datasets
            + pretrain_reasoning_datasets
        ),
        output_dir=output_dir,
        chunk_size=chunk_size,
        num_workers=32,
        reorder_files=False,
        # Encode the output as one contiguous 1D token array: LitData skips
        # per-sample metadata and concatenates all tokens into one large
        # tensor, and chunk_size is interpreted as a token count. The
        # token-counting loop in this script opens the same directory with
        # TokensLoader(block_size=block_size), so the writer must use the
        # matching loader for the on-disk format to agree.
        # Assumes tokenize_fn yields 1D token tensors -- TODO confirm.
        item_loader=TokensLoader(block_size=block_size),
    )
#
# total number of chunks in datasets
#
# Re-open each optimized dataset as a block-level token stream and report its
# size. The configurations must mirror those used when the data was written,
# since the directory name is derived from them.
token_configs = [(4097, 4000)]
for i, cfg in enumerate(token_configs):
    block_size, subchunk_size = cfg
    input_dir = f'../pretrain-base-data-{i}-{block_size}-{subchunk_size}'
    chunk_size = block_size * subchunk_size  # tokens per chunk, reported only

    # With TokensLoader, each dataset item is one block of block_size tokens.
    block_loader = TokensLoader(block_size=block_size)
    dataset = StreamingDataset(input_dir=input_dir, item_loader=block_loader)

    print(f'{i=}, {block_size=}, {chunk_size=}, {len(dataset)=}, {len(dataset) * block_size=}')

    # Number of blocks times block length. NOTE(review): this presumably counts
    # only tokens that fall into complete blocks; any trailing partial block
    # would not be included -- confirm against litdata's TokensLoader.
    total_tokens = block_size * len(dataset)
    print(f'Total number of tokens in the optimized dataset {input_dir!r} is {total_tokens}')