import ast
import json
import logging
import math
import os
import random
import sys
from dataclasses import dataclass
from io import BytesIO
from multiprocessing import Value

import braceexpand
import numpy as np
import pandas as pd
import torch
import webdataset as wds
from PIL import Image
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from webdataset.filters import _shuffle
from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample

# from data_utils import get_normalized_weights_and_num_samples
from typing import List, Tuple

def get_normalized_weights_and_num_samples(
    weights: List[float], num_samples: int
) -> Tuple[List[float], List[int]]:
    # Normalize weights
    weight_sum = sum(weights)
    assert weight_sum > 0.0
    weights = [weight / weight_sum for weight in weights]
    # Add 0.5% (the 1.005 factor) so that, if the blended dataset does not
    # distribute the number of samples perfectly uniformly, we still have
    # samples left to feed to the network.
    weighted_num_samples = []
    for weight in weights:
        weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
    return weights, weighted_num_samples
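
# Worked example (illustrative numbers, not taken from any real config): with raw
# weights [2.0, 1.0, 1.0] and num_samples=1000, the weights normalize to
# [0.5, 0.25, 0.25] and the padded per-component counts are ceil(1000 * w * 1.005):
#
#   >>> get_normalized_weights_and_num_samples([2.0, 1.0, 1.0], 1000)
#   ([0.5, 0.25, 0.25], [503, 252, 252])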

class SharedEpoch:
    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value('i', epoch)

    def set_value(self, epoch):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value

@dataclass
class DataInfo:
    dataloader: DataLoader
    shared_epoch: SharedEpoch = None

    def set_epoch(self, epoch):
        if self.shared_epoch is not None:
            self.shared_epoch.set_value(epoch)

def expand_urls(urls, weights=None):
    if weights is None:
        expanded_urls = wds.shardlists.expand_urls(urls)
        # keep only shard URLs that exist on disk
        expanded_urls = [url for url in expanded_urls if os.path.exists(url)]
        return expanded_urls, None
    if isinstance(urls, str):
        urllist = urls.split("::")
        weights = weights.split('::')
        assert len(weights) == len(urllist), \
            f"Expected the number of data components ({len(urllist)}) and weights ({len(weights)}) to match."
        weights = [float(weight) for weight in weights]
        all_urls, all_weights = [], []
        for url, weight in zip(urllist, weights):
            expanded_url = list(braceexpand.braceexpand(url))
            # keep only shard URLs that exist on disk
            expanded_url = [u for u in expanded_url if os.path.exists(u)]
            expanded_weights = [weight for _ in expanded_url]
            all_urls.extend(expanded_url)
            all_weights.extend(expanded_weights)
        return all_urls, all_weights
    else:
        all_urls = list(urls)
        return all_urls, weights
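
# Example (hypothetical shard paths): a "::"-separated pair of brace-notation
# patterns with matching "::"-separated weights expands to one weight per shard.
# Assuming all three shard files exist on disk (missing shards are silently
# dropped by the os.path.exists filter above):
#
#   >>> expand_urls("/data/a-{000..001}.tar::/data/b-000.tar", "2::1")
#   (['/data/a-000.tar', '/data/a-001.tar', '/data/b-000.tar'], [2.0, 2.0, 1.0])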

def get_dataset_size(shards):
    shards_list = list(shards)
    dir_path = os.path.dirname(shards_list[0])
    sizes_filename = os.path.join(dir_path, 'sizes.json')
    len_filename = os.path.join(dir_path, '__len__')
    if os.path.exists(sizes_filename):
        sizes = json.load(open(sizes_filename, 'r'))
        total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
    elif os.path.exists(len_filename):
        # FIXME this used to be eval(open(...)) but that seemed rather unsafe
        total_size = ast.literal_eval(open(len_filename, 'r').read())
    else:
        total_size = None  # num samples undefined
        # some common dataset sizes (at time of authors' last download)
        # CC3M (train): 2905954
        # CC12M: 10968539
        # LAION-400M: 407332084
        # LAION-2B (english): 2170337258
    num_shards = len(shards_list)
    return total_size, num_shards

def count_samples(dataloader):
    os.environ["WDS_EPOCH"] = "0"
    n_elements, n_batches = 0, 0
    for images, texts in dataloader:
        n_batches += 1
        n_elements += len(images)
        assert len(images) == len(texts)
    return n_elements, n_batches

def filter_no_caption_or_no_image(sample):
    has_caption = ('txt' in sample)
    has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample)
    return has_caption and has_image


def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
    return True

def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Group key/value pairs from an iterator into samples.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME the webdataset version throws if suffix is already in current_sample, but we have a potential
        # for this happening in the current LAION-400M dataset if a tar ends with the same prefix as the next
        # begins; rare, but it can happen since prefixes aren't unique across tar files in that dataset
        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample

def tarfile_to_samples_nothrow(src, handler=log_and_continue):
    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_nothrow(files, handler=handler)
    return samples

def pytorch_worker_seed(increment=0):
    """get dataloader worker seed from pytorch"""
    worker_info = get_worker_info()
    if worker_info is not None:
        # favour using the seed already created for pytorch dataloader workers if it exists
        seed = worker_info.seed
        if increment:
            # space out seed increments so they can't overlap across workers in different iterations
            seed += increment * max(1, worker_info.num_workers)
        return seed
    # fallback to wds rank-based seed
    return wds.utils.pytorch_worker_seed()

_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000

class detshuffle2(wds.PipelineStage):
    def __init__(
        self,
        bufsize=1000,
        initial=100,
        seed=0,
        epoch=-1,
    ):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src):
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        rng = random.Random()
        if self.seed < 0:
            # If seed is negative, we use the worker's seed; this will be different across all nodes/workers
            seed = pytorch_worker_seed(epoch)
        else:
            # This seed is deterministic AND the same across all nodes/workers in each epoch
            seed = self.seed + epoch
        rng.seed(seed)
        return _shuffle(src, self.bufsize, self.initial, rng)

class ResampledShards2(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
        self,
        urls,
        weights=None,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
        epoch=-1,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        urls, weights = expand_urls(urls, weights)
        self.urls = urls
        self.weights = weights
        if self.weights is not None:
            assert len(self.urls) == len(self.weights), \
                f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match."
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.rng = random.Random()
        self.worker_seed = worker_seed
        self.deterministic = deterministic
        self.epoch = epoch

    def __iter__(self):
        """Return an iterator over the shards."""
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        if self.deterministic:
            # reset seed w/ epoch if deterministic
            if self.worker_seed is None:
                # pytorch worker seed should be deterministic due to being init by args.seed + rank + worker id
                seed = pytorch_worker_seed(epoch)
            else:
                seed = self.worker_seed() + epoch
            self.rng.seed(seed)
        for _ in range(self.nshards):
            if self.weights is None:
                yield dict(url=self.rng.choice(self.urls))
            else:
                yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0])
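
# Usage sketch (hypothetical local shard paths; note that expand_urls drops any
# shard that does not exist on disk): sample shards with replacement, weighting
# the first pattern twice as heavily as the second, with a deterministic
# per-epoch seed.
#
#   shards = ResampledShards2(
#       "/data/laion/{00000..00127}.tar::/data/cc12m/{0000..0063}.tar",
#       weights="2::1",
#       deterministic=True,
#       epoch=SharedEpoch(0),
#   )
#   # iterating yields dicts such as {'url': '/data/laion/00042.tar'}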

def image_text_dict_collation_fn(samples):
    """Customized collation_fn that generates a dict batch."""
    assert isinstance(samples[0], (list, tuple)), type(samples[0])
    batched = list(zip(*samples))
    result = dict()
    for b in batched:
        b = torch.stack(list(b))
        if b.dim() >= 3:  # tensors with >= 3 dims are images
            result['img'] = b
        else:
            result['text'] = b
    return result
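
# Worked example: for a batch of (image_tensor, text_tensor) tuples the collation
# function returns {'img': ..., 'text': ...}. Shapes below are illustrative only:
#
#   samples = [(torch.zeros(3, 224, 224), torch.zeros(77, dtype=torch.long))
#              for _ in range(4)]
#   batch = image_text_dict_collation_fn(samples)
#   # batch['img'].shape  -> torch.Size([4, 3, 224, 224])
#   # batch['text'].shape -> torch.Size([4, 77])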

def decode_image(png_bytes):
    return Image.open(BytesIO(png_bytes))


def process_sample(sample):
    if "png" not in sample:
        sample["png"] = b''
    else:
        sample["png"] = decode_image(sample["png"])
    sample = {"png": sample["png"], "json": sample["json"]}
    return sample

def get_wds_data(args, is_train, epoch=0, floor=False, wds_processor=None):
    if args.data_path and (args.train_data_weights is None):
        args.train_data_weights = [1.0] * len(args.data_path)
    input_shards = args.data_path if is_train else args.valid_data_path
    input_weights = args.train_data_weights if is_train else args.valid_data_weights
    assert input_shards is not None
    resampled = getattr(args, 'dataset_resampled', False)
    num_shards = None
    if is_train:
        if args.train_num_samples is not None:
            num_samples = args.train_num_samples
        else:
            num_samples, num_shards = get_dataset_size(input_shards)
            if not num_samples:
                raise RuntimeError(
                    'Currently, the number of dataset samples must be specified for the training dataset. '
                    'Please specify it via `--train-num-samples` if no dataset length info is present.')
    else:
        # Eval will just exhaust the iterator if the size is not specified.
        num_samples = args.val_num_samples or 0
    weights, weighted_num_samples = get_normalized_weights_and_num_samples(input_weights, num_samples)
    shared_epoch = SharedEpoch(epoch=epoch)  # create a shared epoch store to sync epoch to dataloader worker proc

    if resampled:
        complete_url_list = []
        complete_weights = []
        for urls, weight in zip(input_shards, weights):
            current_url_list = expand_urls(urls)[0]
            complete_url_list.extend(current_url_list)
            per_url_weight = weight / len(current_url_list)
            complete_weights.extend([per_url_weight] * len(current_url_list))
        # pipeline = [ResampledShards2(
        #     complete_url_list,
        #     weights=complete_weights,
        #     deterministic=True,
        #     epoch=shared_epoch,
        # )]
        pipeline = [wds.SimpleShardList(complete_url_list)]
    else:
        # assert args.train_data_upsampling_factors is None, \
        #     "--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."
        pipeline = [wds.SimpleShardList(input_shards)]

    # at this point we have an iterator over all the shards
    if is_train:
        if not resampled:
            pipeline.extend([
                detshuffle2(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    # seed=args.seed,
                    epoch=shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
            ])
        pipeline.extend([
            # at this point, we have an iterator over the shards assigned to each worker at each node
            # tarfile_to_samples_nothrow,  # wds.tarfile_to_samples(handler=log_and_continue),
            wds.tarfile_to_samples(handler=log_and_continue),
            wds.shuffle(
                bufsize=_SAMPLE_SHUFFLE_SIZE,
                initial=_SAMPLE_SHUFFLE_INITIAL,
            ),
        ])
    else:
        pipeline.extend([
            wds.split_by_worker,
            # at this point, we have an iterator over the shards assigned to each worker
            wds.tarfile_to_samples(handler=log_and_continue),
        ])

    ### build preprocess_img and preprocess_text from args
    # from .transforms import get_clip_transforms
    # preprocess_img = get_clip_transforms(image_size=data_args.image_processor.crop_size)
    # assert (
    #     args.tokenizer.name in ['HFGPT2Tokenizer', 'HFGPT2TokenizerFast', 'HFTokenizer']
    # ), f"Webdataset only support HFTokenizer, HFGPT2Tokenizer or HFGPT2TokenizerFast"
    # tokenize = args.tokenizer.tokenize
    pipeline.extend([
        # wds.select(filter_no_caption_or_no_image),
        wds.decode("pilrgb", handler=log_and_continue),
        wds.rename(image="jpg;png;jpeg;webp", text="json"),
        wds.to_tuple("image", "text"),
        wds.map(wds_processor),
        # wds.map_dict(image=preprocess_img, text=lambda text: tokenize(text)[0]),
        # wds.batched(args.batch_size, collation_fn=image_text_dict_collation_fn, partial=not is_train)
    ])
    # pipeline.extend([
    #     wds.map(process_sample),
    #     wds.rename(image="jpg;png;jpeg;webp", text="json"),
    #     wds.to_tuple("image", "text"),
    #     wds.map(wds_processor),
    # ])
    dataset = wds.DataPipeline(*pipeline)

    if is_train:
        if not resampled:
            num_shards = num_shards or len(expand_urls(input_shards)[0])
            # assert num_shards >= args.num_workers * args.world_size, 'number of shards must be >= total workers'
        # roll over and repeat a few samples to get same number of full batches on each node
        round_fn = math.floor if floor else math.ceil
        global_batch_size = args.batch_size * args.world_size
        num_batches = round_fn(num_samples / global_batch_size)
        num_workers = max(1, args.dataloader_num_workers)
        num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
        num_batches = num_worker_batches * num_workers
        num_samples = num_batches * global_batch_size
        dataset = dataset.with_epoch(num_worker_batches)  # each worker is iterating over this
        # dataset = dataset.with_epoch(num_samples)
    else:
        # last batches are partial, eval is done on single (master) node
        num_batches = math.ceil(num_samples / args.batch_size)

    # dataloader = wds.WebLoader(
    #     dataset,
    #     batch_size=None,
    #     shuffle=False,
    #     num_workers=args.num_workers,
    #     persistent_workers=not (args.num_workers == 0),  # set persistent_workers to false if num_workers is 0
    # )

    # FIXME not clear which approach is better, with_epoch before vs after dataloader?
    # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
    # if is_train:
    #     # roll over and repeat a few samples to get same number of full batches on each node
    #     global_batch_size = args.batch_size * args.world_size
    #     num_batches = math.ceil(num_samples / global_batch_size)
    #     num_workers = max(1, args.num_workers)
    #     num_batches = math.ceil(num_batches / num_workers) * num_workers
    #     num_samples = num_batches * global_batch_size
    #     dataloader = dataloader.with_epoch(num_batches)
    # else:
    #     # last batches are partial, eval is done on single (master) node
    #     num_batches = math.ceil(num_samples / args.batch_size)

    # add meta-data to dataloader instance for convenience
    # dataloader.num_batches = num_batches
    # dataloader.num_samples = num_samples

    return dataset

# def get_data(args, preprocess_fns, epoch=0, tokenizer=None):
#     preprocess_train, preprocess_val = preprocess_fns
#     data = {}
#     if args.train_data or args.dataset_type == "synthetic":
#         data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
#             args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)
#     if args.val_data:
#         data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
#             args, preprocess_val, is_train=False, tokenizer=tokenizer)
#     return data
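
# Usage sketch (hypothetical `args` object exposing the attributes referenced in
# get_wds_data: data_path, train_data_weights, train_num_samples, batch_size,
# world_size, dataloader_num_workers, dataset_resampled, ...; `my_processor` is a
# hypothetical callable applied to each (image, text) tuple). Wrapping the
# returned DataPipeline in wds.WebLoader, as in the commented-out block above,
# is one way to obtain an iterable loader:
#
#   dataset = get_wds_data(args, is_train=True, epoch=0, wds_processor=my_processor)
#   loader = wds.WebLoader(dataset, batch_size=None, shuffle=False,
#                          num_workers=args.dataloader_num_workers)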