"""
Helper functions dataset setup and loading
"""
import os
from os.path import join
import shutil
import numpy as np
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as HFDataset
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, LlamaTokenizer
from transformers import DataCollatorForSeq2Seq
# from transformers import DefaultDataCollator, DataCollatorWithPadding


def get_seq2seq_loader(dataset: Dataset, tokenizer: AutoTokenizer,
                       split: str, **loader_kwargs: any):
    """
    Get dataloader for seq2seq tasks (evaluation)
    """
    tokenizer.padding_side = 'right'
    collate_fn = DataCollatorForSeq2Seq(
        tokenizer, label_pad_token_id=-100, return_tensors='pt')
    return DataLoader(
        dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs)
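
# Example usage (illustrative sketch, not part of the original module): assuming
# `eval_dataset` is a tokenized Hugging Face dataset with `input_ids` and `labels`
# columns, an evaluation loader could be built as:
#
#   eval_loader = get_seq2seq_loader(eval_dataset, tokenizer, split='validation',
#                                    batch_size=8, num_workers=2)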


def get_lm_loader(dataset: Dataset, tokenizer: AutoTokenizer,
                  split: str, max_length: int = None, **loader_kwargs: any):
    """
    Get dataloader for language modeling (training)
    -> Currently this ends up being the same as get_seq2seq_loader
    """
    # collate_fn = DefaultDataCollator(return_tensors='pt')
    # collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, padding=True,
    #                                      max_length=max_length, return_tensors='pt')
    collate_fn = DataCollatorForSeq2Seq(
        tokenizer, label_pad_token_id=-100, return_tensors='pt')
    return DataLoader(
        dataset, shuffle='train' in split, collate_fn=collate_fn, **loader_kwargs)


def convert_to_hf_dataset(dataset, cache_dir: str):
    """
    Convert iterable dataset to HuggingFace HFDataset object
    """
    def gen():
        for _, sample in enumerate(dataset):
            yield sample  # dataset[idx]
    return HFDataset.from_generator(gen, cache_dir=cache_dir)
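
# Example usage (sketch; `my_iterable_dataset` is a hypothetical iterable that
# yields dict-like samples):
#
#   hf_dataset = convert_to_hf_dataset(my_iterable_dataset, cache_dir='./data_cache')
#   hf_dataset = hf_dataset.map(lambda x: tokenizer(x['text']), batched=True)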


def get_tokenizer_from_config(model_config):
    """
    Get pretrained tokenizer based on (pretrained) model config
    """
    # Get tokenizer
    if 'llama' in model_config['pretrained_model_name_or_path']:
        try:  # if we store locally
            model_path = join(model_config['cache_dir'],
                              model_config['pretrained_model_name_or_path'])
            tokenizer = LlamaTokenizer.from_pretrained(model_path)
        except Exception as e:
            try:
                tokenizer = AutoTokenizer.from_pretrained(**model_config)
                print("-> Bad LlamaTokenizer.from_pretrained(model_path)", e)
                print("-> But resolved with: AutoTokenizer.from_pretrained(**model_config)")
            except Exception as e2:
                print("-> Error with AutoTokenizer.from_pretrained(**model_config)", e2)
                # tokenizer = LlamaTokenizer.from_pretrained(**model_config)  # v4.43 errors with `*** TypeError: not a string`
    elif 'Mistral-7B-Instruct-v0.3' in model_config['pretrained_model_name_or_path']:
        tokenizer = LlamaTokenizer.from_pretrained(**model_config)  # hack where AutoTokenizer doesn't recognize
    elif 'Mistral-7B' in model_config['pretrained_model_name_or_path']:
        tokenizer = AutoTokenizer.from_pretrained(**model_config)
    else:
        tokenizer = AutoTokenizer.from_pretrained(**model_config)
    return tokenizer
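
# Example usage (sketch; the config keys mirror the keyword arguments of
# `AutoTokenizer.from_pretrained`, and the model name and cache path below are
# purely illustrative):
#
#   model_config = {'pretrained_model_name_or_path': 'meta-llama/Meta-Llama-3-8B',
#                   'cache_dir': './checkpoints'}
#   tokenizer = get_tokenizer_from_config(model_config)
#   if tokenizer.pad_token is None:  # common follow-up when no pad token is defined
#       tokenizer.pad_token = tokenizer.eos_token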


def add_special_tokens_to_dataset(dataset, tokenizer):
    """
    Add special tokens as attributes to a dataset object
    """
    token_map = {k: v for k, v in tokenizer.special_tokens_map.items()}
    special_ids = tokenizer.all_special_ids
    for idx, k in enumerate(tokenizer.special_tokens_map.keys()):
        token_map[f'{k}_id'] = special_ids[idx]
    for k, v in token_map.items():
        setattr(dataset, k, v)
    return dataset
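
# Example usage (sketch): after this call, special tokens and their ids are
# available as plain attributes on the dataset object, e.g.:
#
#   dataset = add_special_tokens_to_dataset(dataset, tokenizer)
#   print(dataset.eos_token, dataset.eos_token_id)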


def train_test_split(samples: any, train_size: int, test_size: int, seed: int):
    """
    Split samples into train and test sets
    """
    try:
        assert len(samples) == train_size + test_size
    except Exception as e:
        print(len(samples), train_size + test_size)
        raise e
    arange = np.arange(len(samples))
    np.random.seed(seed)
    test_idx = np.random.choice(arange, size=test_size, replace=False)
    train_idx = np.setdiff1d(arange, test_idx)
    return samples[train_idx], samples[test_idx]
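
# Example usage (sketch; works for NumPy arrays or anything else that supports
# indexing with an integer index array):
#
#   samples = np.arange(1000)
#   train, test = train_test_split(samples, train_size=900, test_size=100, seed=42)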


def download_scrolls_metric():
    """
    Download ROUGE, F1, and other accuracy metrics included in the SCROLLS dataset
    """
    scrolls_metric_path = hf_hub_download(
        repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset"
    )
    # Copy the script to a dot-free filename (scrolls.py -> scrolls_py.py) so it
    # can be loaded as a metric script without module-name issues
    updated_scrolls_metric_path = join(
        os.path.dirname(scrolls_metric_path),
        os.path.basename(scrolls_metric_path).replace(".", "_") + ".py",
    )
    shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
    return updated_scrolls_metric_path
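
# Example usage (sketch; loads the copied metric script with the `evaluate`
# library, where 'gov_report' is just one illustrative SCROLLS config name):
#
#   import evaluate
#   scrolls_metric = evaluate.load(download_scrolls_metric(), 'gov_report')
#   scores = scrolls_metric.compute(predictions=predictions, references=references)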