|
""" |
|
Helper functions for dataset setup and loading
|
""" |
|
import os |
|
from os.path import join |
|
import shutil |
|
import numpy as np |
|
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
from datasets import Dataset as HFDataset |
|
from huggingface_hub import hf_hub_download |
|
from transformers import AutoTokenizer, LlamaTokenizer |
|
from transformers import DataCollatorForSeq2Seq |
|
|
|
|
|
|
|
def get_seq2seq_loader(dataset: Dataset, tokenizer: AutoTokenizer,
                       split: str, **loader_kwargs: any):
    """
    Build a DataLoader for seq2seq tasks (evaluation).

    Forces right-padding on the tokenizer and masks label padding with
    -100 so loss computation ignores padded positions. Batches are
    shuffled only when the split name contains 'train'.
    """
    tokenizer.padding_side = 'right'
    seq2seq_collator = DataCollatorForSeq2Seq(
        tokenizer,
        label_pad_token_id=-100,
        return_tensors='pt',
    )
    is_training_split = 'train' in split
    return DataLoader(dataset,
                      shuffle=is_training_split,
                      collate_fn=seq2seq_collator,
                      **loader_kwargs)
|
|
|
|
|
def get_lm_loader(dataset: Dataset, tokenizer: AutoTokenizer,
                  split: str, max_length: int = None, **loader_kwargs: any):
    """
    Build a DataLoader for language modeling (training).

    Currently ends up the same as get_seq2seq_loader, except the
    tokenizer's padding side is left untouched here.
    NOTE(review): `max_length` is accepted but never used; kept for
    interface compatibility with existing callers.
    """
    lm_collator = DataCollatorForSeq2Seq(tokenizer,
                                         label_pad_token_id=-100,
                                         return_tensors='pt')
    return DataLoader(dataset,
                      collate_fn=lm_collator,
                      shuffle=('train' in split),
                      **loader_kwargs)
|
|
|
|
|
def convert_to_hf_dataset(dataset, cache_dir: str):
    """
    Convert an iterable dataset to a HuggingFace HFDataset object.

    Wraps the input in a generator so HFDataset.from_generator can
    consume it lazily; results are cached under `cache_dir`.
    """
    def _sample_generator():
        yield from dataset
    return HFDataset.from_generator(_sample_generator, cache_dir=cache_dir)
|
|
|
|
|
def get_tokenizer_from_config(model_config):
    """
    Get a pretrained tokenizer based on a (pretrained) model config.

    For Llama models, first tries LlamaTokenizer from the local cache
    path (`cache_dir/pretrained_model_name_or_path`); on failure falls
    back to AutoTokenizer with the full config. Mistral-7B-Instruct-v0.3
    uses LlamaTokenizer; everything else uses AutoTokenizer.

    Raises whatever the final tokenizer load raised if all attempts
    fail. (Bug fix: the original swallowed the second failure and then
    hit an UnboundLocalError at `return tokenizer`.)
    """
    model_name = model_config['pretrained_model_name_or_path']
    if 'llama' in model_name:
        try:
            # Prefer a locally cached checkpoint when present.
            model_path = join(model_config['cache_dir'], model_name)
            tokenizer = LlamaTokenizer.from_pretrained(model_path)
        except Exception as e:
            try:
                tokenizer = AutoTokenizer.from_pretrained(**model_config)
                print("-> Bad LlamaTokenizer.from_pretrained(model_path)", e)
                print("-> But resolved with: AutoTokenizer.from_pretrained(**model_config)")
            except Exception as e2:
                print("-> Error with AutoTokenizer.from_pretrained(**model_config)", e2)
                # Surface the real failure instead of falling through to
                # an UnboundLocalError on `tokenizer`.
                raise
    elif 'Mistral-7B-Instruct-v0.3' in model_name:
        tokenizer = LlamaTokenizer.from_pretrained(**model_config)
    elif 'Mistral-7B' in model_name:
        tokenizer = AutoTokenizer.from_pretrained(**model_config)
    else:
        tokenizer = AutoTokenizer.from_pretrained(**model_config)
    return tokenizer
|
|
|
|
|
def add_special_tokens_to_dataset(dataset, tokenizer):
    """
    Add special tokens as attributes to a dataset object.

    For each entry in tokenizer.special_tokens_map, sets both the token
    string (e.g. `dataset.eos_token`) and its id (`dataset.eos_token_id`).

    Bug fix: ids are now looked up per token via convert_tokens_to_ids
    instead of zipping special_tokens_map keys against all_special_ids
    by position — positional zipping misaligns whenever
    `additional_special_tokens` (a list-valued entry) is present, or
    when the two sequences are ordered differently.
    """
    token_map = dict(tokenizer.special_tokens_map)
    # Iterate over a snapshot since we add `*_id` keys while looping.
    for key, token in list(token_map.items()):
        if isinstance(token, (list, tuple)):
            # e.g. additional_special_tokens maps to a list of ids.
            token_map[f'{key}_id'] = [tokenizer.convert_tokens_to_ids(t)
                                      for t in token]
        else:
            token_map[f'{key}_id'] = tokenizer.convert_tokens_to_ids(token)
    for key, value in token_map.items():
        setattr(dataset, key, value)
    return dataset
|
|
|
|
|
def train_test_split(samples: any, train_size: int, test_size: int, seed: int):
    """
    Split samples into disjoint train and test sets.

    `samples` must support numpy fancy indexing. Test indices are drawn
    without replacement under the given seed; the remaining indices
    (in ascending order) form the train split.
    """
    try:
        assert len(samples) == train_size + test_size
    except AssertionError as err:
        # Show the mismatching sizes before propagating the failure.
        print(len(samples), train_size + test_size)
        raise err
    all_idx = np.arange(len(samples))
    np.random.seed(seed)
    test_idx = np.random.choice(all_idx, size=test_size, replace=False)
    train_idx = np.setdiff1d(all_idx, test_idx)
    return samples[train_idx], samples[test_idx]
|
|
|
|
|
def download_scrolls_metric():
    """
    Download ROUGE, F1, and other accuracy metrics included in the
    SCROLLS dataset, and copy the metric script to an importable name.

    Returns the path of the copied script: same directory as the
    download, with dots in the base name replaced by underscores and a
    '.py' suffix appended (e.g. 'scrolls.py' -> 'scrolls_py.py').

    Bug fix: the destination previously concatenated dirname and
    basename without a path separator, so the copy landed beside the
    cache directory (e.g. '/cache/metricsscrolls_py.py') instead of
    inside it.
    """
    scrolls_metric_path = hf_hub_download(
        repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset"
    )
    updated_scrolls_metric_path = os.path.join(
        os.path.dirname(scrolls_metric_path),
        os.path.basename(scrolls_metric_path).replace(".", "_") + ".py",
    )
    shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
    return updated_scrolls_metric_path
|
|