File size: 5,340 Bytes
734e414 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import gc
from typing import Optional, Iterator, Callable
import torch
from datasets import load_dataset
from litgpt.tokenizer import Tokenizer
from transformers import AutoTokenizer
def batch_text_iterator(kind: str,
                        path: str,
                        name: Optional[str]=None,
                        data_dir: Optional[str]=None,
                        data_files: Optional[str]=None,
                        keep_in_memory: bool=False,
                        revision: Optional[str]=None,
                        split: str='train',
                        num_proc: Optional[int]=None,
                        format: Optional[Callable|str]=None) -> Iterator[str]:
    """Stream non-empty text samples from a dataset loaded with ``load_dataset``.

    Each row is rendered to text by ``format``:

    - if ``format`` is callable, the text is ``format(row)``;
    - otherwise ``format`` is a ``str.format`` template whose placeholders are
      filled from the row's columns (``format.format(**row)``).

    Rows whose rendered text is falsy (e.g. ``''`` or ``None``) are skipped.

    Args:
        kind: Must be ``'base'``.
        path, name, data_dir, data_files, keep_in_memory, revision, split,
            num_proc: Forwarded unchanged to ``load_dataset``.
        format: Callable or template string. Required in practice despite
            the ``None`` default.

    Yields:
        The rendered text of every row that renders to a truthy value.

    Raises:
        AssertionError: If ``format`` is neither a ``str`` nor callable, or
            ``kind != 'base'``. These are plain ``assert`` checks and are
            stripped under ``python -O``. Because this is a generator, they
            fire on the first ``next()``, not at call time.
    """
    assert isinstance(format, str) or callable(format), f'{path=} {format=}'
    assert kind == 'base'

    # NOTE(review): trust_remote_code=True lets the dataset repository run its
    # own loading code; only point this at trusted datasets.
    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    # Choose the row -> text renderer once instead of duplicating the loop.
    render = format if callable(format) else (lambda row: format.format(**row))

    try:
        for row in dataset:
            text = render(row)
            if not text:
                continue
            yield text
    finally:
        # Drop the (possibly large) dataset and force a collection. Doing this
        # in `finally` means it also runs when the consumer stops early and
        # the generator is closed, not only after full exhaustion.
        del dataset
        gc.collect()
def batch_chat_iterator(kind: str,
                        path: str,
                        name: Optional[str]=None,
                        data_dir: Optional[str]=None,
                        data_files: Optional[str]=None,
                        keep_in_memory: bool=False,
                        revision: Optional[str]=None,
                        split: str='train',
                        num_proc: Optional[int]=None,
                        field: Optional[str]=None,
                        transform: Optional[Callable]=None) -> Iterator[list[dict[str, str]]]:
    """Stream non-empty chat conversations from a dataset loaded with ``load_dataset``.

    Each row is turned into a message list as follows:

    - ``transform`` callable and ``field`` set: ``transform(row[field])``;
    - ``transform`` callable and no ``field``: ``transform(row)``;
    - no callable ``transform``: ``row[field]``. ``field`` is then required.

    Falsy results (e.g. empty lists or ``None``) are skipped. The messages are
    presumably role/content dicts; ``batch_dataset_iterator`` reads
    ``m['content']`` from them. Confirm against your dataset schema.

    Args:
        kind: Must be ``'instruct'``.
        path, name, data_dir, data_files, keep_in_memory, revision, split,
            num_proc: Forwarded unchanged to ``load_dataset``.
        field: Column holding the conversation (or ``transform``'s input).
        transform: Optional callable converting a row or a field value into
            a message list.

    Yields:
        One message list per row that produced a truthy value.

    Raises:
        AssertionError: If ``kind != 'instruct'``.
        ValueError: If there is neither a callable ``transform`` nor a
            ``field``. This is raised before the dataset is loaded.
    """
    assert kind == 'instruct'

    has_transform = callable(transform)
    if not has_transform and not field:
        # Fail fast: without a transform the messages must come from `field`.
        # Check before load_dataset so a misconfiguration does not first pay
        # for a full download/prepare (and is not missed on an empty split).
        raise ValueError(f'{path=} {field=} {transform=}: need a callable transform or a field')

    # NOTE(review): trust_remote_code=True lets the dataset repository run its
    # own loading code; only point this at trusted datasets.
    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    try:
        for row in dataset:
            messages = row[field] if field else row
            if has_transform:
                messages = transform(messages)
            if not messages:
                continue
            yield messages
    finally:
        # Release the dataset even if the consumer stops early.
        del dataset
        gc.collect()
# NOTE: used only by tokenizer trainer
def batch_dataset_iterator(dataset_config: dict) -> Iterator[str]:
    """Yield plain text from a dataset config, for tokenizer training.

    ``'base'`` configs yield the texts from ``batch_text_iterator`` unchanged.
    ``'instruct'`` configs join the ``'content'`` of every message in a
    conversation with newlines, giving one string per conversation.

    Args:
        dataset_config: Keyword arguments for the matching iterator, including
            ``'kind'``.

    Yields:
        One text string per sample.

    Raises:
        ValueError: If ``dataset_config['kind']`` is neither ``'base'`` nor
            ``'instruct'``. This matches ``tokenize_fn``; previously such
            configs were silently skipped.
    """
    kind = dataset_config['kind']
    if kind == 'base':
        yield from batch_text_iterator(**dataset_config)
    elif kind == 'instruct':
        for messages in batch_chat_iterator(**dataset_config):
            yield '\n'.join(m['content'] for m in messages)
    else:
        raise ValueError(kind)
def tokenize_text_fn(dataset_config: dict, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    """Encode every text yielded by ``batch_text_iterator(**dataset_config)``.

    Each text is encoded with ``bos=False, eos=True``. No length filtering or
    error handling is done here.

    ``hf_tokenizer`` is not used. Presumably it is kept so the signature
    matches ``tokenize_chat_fn``.
    """
    texts = batch_text_iterator(**dataset_config)
    yield from (tokenizer.encode(sample, bos=False, eos=True) for sample in texts)
def tokenize_chat_fn(dataset_config: dict, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    """Render and encode every conversation from ``batch_chat_iterator(**dataset_config)``.

    Each conversation is rendered to a string with
    ``hf_tokenizer.apply_chat_template(..., tokenize=False)``. That string is
    then encoded by ``tokenizer`` with ``bos=False, eos=False``.
    """
    for conversation in batch_chat_iterator(**dataset_config):
        rendered = hf_tokenizer.apply_chat_template(conversation, tokenize=False)
        # No BOS/EOS added here; presumably the chat template already emits
        # its own special tokens. TODO confirm for the templates in use.
        yield tokenizer.encode(rendered, bos=False, eos=False)
def tokenize_fn(dataset_config: dict, min_len: int, max_len: int, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    """Tokenize a dataset and yield only samples whose length is within bounds.

    - ``'base'`` samples are encoded directly with ``bos=False, eos=True``.
    - ``'instruct'`` samples are first rendered with
      ``hf_tokenizer.apply_chat_template(..., tokenize=False)``, then encoded
      with ``bos=False, eos=False``.

    A sample whose rendering or encoding raises is reported with ``print`` and
    skipped (best-effort). Errors raised by the dataset iterators themselves
    are not caught.

    Args:
        dataset_config: Keyword arguments for the matching iterator, including
            ``'kind'``.
        min_len: Minimum token count, inclusive.
        max_len: Maximum token count, inclusive.
        hf_tokenizer: Provides ``apply_chat_template`` for instruct data.
        tokenizer: Provides ``encode(text, bos=..., eos=...)``.

    Yields:
        Token ids of each sample with ``min_len <= len(ids) <= max_len``.

    Raises:
        ValueError: If ``dataset_config['kind']`` is neither ``'base'`` nor
            ``'instruct'``. Since this is a generator, it is raised on the
            first ``next()``.
    """
    kind = dataset_config['kind']
    if kind == 'base':
        for text in batch_text_iterator(**dataset_config):
            try:
                text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=True)
            except Exception as e:
                # Best-effort: one malformed row must not abort the whole dataset.
                print(f'Skip base row: {e=} {type(text)=} {text=}')
                continue
            if min_len <= len(text_ids) <= max_len:
                yield text_ids
    elif kind == 'instruct':
        for messages in batch_chat_iterator(**dataset_config):
            try:
                text: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)
                text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=False)
            except Exception as e:
                # Best-effort: one malformed conversation must not abort the run.
                print(f'Skip instruct row: {e=} {type(messages)=} {messages=}')
                continue
            if min_len <= len(text_ids) <= max_len:
                yield text_ids
    else:
        raise ValueError(kind)
|