|
import io |
|
import ast |
|
import json |
|
import logging |
|
import math |
|
import pickle |
|
import os |
|
import random |
|
import sys |
|
import braceexpand |
|
from dataclasses import dataclass |
|
from multiprocessing import Value |
|
from tqdm import tqdm |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
import torchvision.datasets as datasets |
|
import webdataset as wds |
|
from PIL import Image |
|
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info, ConcatDataset |
|
from torch.utils.data.sampler import Sampler |
|
from torch.utils.data.distributed import DistributedSampler |
|
import wids |
|
from webdataset.filters import _shuffle |
|
from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample |
|
|
|
# Horovod is an optional distributed-training backend; fall back to None
# when it is not installed so single-process runs still work.
try:
    import horovod.torch as hvd
except ImportError:
    hvd = None
|
|
|
|
|
|
|
def custom_collate_fn(batch, first_dataset, second_dataset):
    """Collate ``(dataset_type, idx)`` pairs into stacked image/caption batches.

    Each element of ``batch`` selects a sample from ``first_dataset`` when
    ``dataset_type == 0`` and from ``second_dataset`` otherwise.  Indexed
    samples must be ``(image_tensor, caption_tensor)`` pairs of uniform shape.

    Returns:
        Tuple of ``(image_batch, caption_batch)`` tensors stacked on dim 0.
    """
    pairs = [
        (first_dataset if source == 0 else second_dataset)[sample_idx]
        for source, sample_idx in batch
    ]
    image_batch = torch.stack([image for image, _ in pairs], dim=0)
    caption_batch = torch.stack([caption for _, caption in pairs], dim=0)
    return image_batch, caption_batch
|
|
|
|
|
class UniformCombinedSampler(Sampler):
    """Batch sampler mixing a real dataset with a synthetic counterfactual one.

    Yields whole batches of indices into ``ConcatDataset([first, second])``:
    indices below ``len(first_dataset)`` address the first (real) dataset,
    larger ones the second (synthetic) dataset, offset by
    ``len(first_dataset)``.  Real samples are drawn with probability
    ``w / (w + 1)`` where ``w = first_dataset_weight``; each synthetic draw
    consumes one prefix and keeps one image of its counterfactual pair.

    Args:
        first_dataset: real dataset, indexed directly.
        second_dataset: synthetic dataset exposing ``prefix_dict``
            (prefix -> list of row indices sharing that prefix).
        batch_size: number of indices per yielded batch.
        first_dataset_weight: odds of sampling from ``first_dataset``
            relative to ``second_dataset``.
    """

    def __init__(self, first_dataset, second_dataset, batch_size, first_dataset_weight=3.3):
        self.first_dataset = first_dataset
        self.second_dataset = second_dataset
        self.batch_size = batch_size
        self.first_dataset_weight = first_dataset_weight
        self.first_dataset_indices = list(range(len(first_dataset)))
        self.second_dataset_prefixes = list(second_dataset.prefix_dict.keys())
        self._reshuffle_indices()

    def _reshuffle_indices(self):
        # Fresh random order each epoch; the pointers track consumption so
        # every real index / synthetic prefix is used at most once per epoch.
        random.shuffle(self.first_dataset_indices)
        random.shuffle(self.second_dataset_prefixes)
        self.first_dataset_pointer = 0
        self.second_dataset_pointer = 0

    def __iter__(self):
        self._reshuffle_indices()
        # Keep emitting batches until both sources are exhausted.
        while self.first_dataset_pointer < len(self.first_dataset_indices) or self.second_dataset_pointer < len(self.second_dataset_prefixes):
            batch = []
            while len(batch) < self.batch_size:
                if self.first_dataset_pointer < len(self.first_dataset_indices):
                    # Weighted coin flip between the two sources.
                    if random.random() < self.first_dataset_weight / (self.first_dataset_weight + 1):
                        idx = self.first_dataset_indices[self.first_dataset_pointer]
                        self.first_dataset_pointer += 1
                        batch.append(idx)
                    else:
                        if self.second_dataset_pointer >= len(self.second_dataset_prefixes):
                            # Synthetic prefixes exhausted: re-flip the coin.
                            continue
                        prefix = self.second_dataset_prefixes[self.second_dataset_pointer]
                        self.second_dataset_pointer += 1
                        indices = self.second_dataset.prefix_dict[prefix]
                        if len(indices) >= 2:
                            # Randomly keep one side of the counterfactual pair;
                            # offset it into the ConcatDataset's second segment.
                            pair_indices = random.sample(indices, 2)
                            if random.random() < 0.5:
                                batch.append(pair_indices[0] + len(self.first_dataset))
                            else:
                                batch.append(pair_indices[1] + len(self.first_dataset))
                else:
                    if self.second_dataset_pointer >= len(self.second_dataset_prefixes):
                        # Both sources drained mid-batch: pad the final batch by
                        # repeating already-shuffled real indices.
                        while len(batch) < self.batch_size:
                            print(f"Repeating a sample for the last batch. Current batch size: {len(batch)}")
                            batch.append(self.first_dataset_indices[len(batch)])
                    else:
                        prefix = self.second_dataset_prefixes[self.second_dataset_pointer]
                        self.second_dataset_pointer += 1
                        indices = self.second_dataset.prefix_dict[prefix]
                        if len(indices) >= 2:
                            pair_indices = random.sample(indices, 2)
                            if random.random() < 0.5:
                                batch.append(pair_indices[0] + len(self.first_dataset))
                            else:
                                batch.append(pair_indices[1] + len(self.first_dataset))
                # NOTE(review): the batch grows one element at a time, so this
                # trim/break branch appears to be unreachable dead code.
                if len(batch) > self.batch_size:
                    batch = batch[:self.batch_size]
                    break
            yield batch

    def __len__(self):
        # Approximate batch count: the number actually yielded can differ since
        # each synthetic prefix contributes at most one index per epoch.
        return (len(self.first_dataset) + len(self.second_dataset)) // self.batch_size
|
|
|
|
|
class PartiallySyntheticCSVDataset(Dataset):
    r"""A dataset to yield an original sample X or a tuple of (X, X_modified_1, ..., X_modified_n) if
    synthetic versions of the sample exist. Currently only supports n == 2.

    NOTE(review): this class is byte-for-byte identical to ``MixedDataset``
    below; the two should probably be merged.

    Args:
        wds_dataset: indexable webdataset (e.g. ``wids.ShardListDataset``) whose
            samples carry ``.json`` metadata (``key``, ``caption``) and a ``.jpg`` image.
        transforms: image transform applied to every decoded image.
        path_to_samples: path to the directory with synthetic samples.
        fnames_modified: TSV listing synthetic versions, indexed by sample key.
        sep: separator of ``fnames_modified``.
        tokenizer: caption tokenizer; ``tokenizer(caption)[0]`` is returned.
    """

    def __init__(self, wds_dataset, transforms, path_to_samples, fnames_modified, sep="\t",
                 tokenizer=None):
        logging.debug(f'Loading csv data from {path_to_samples}.')
        # Column 0 (the original sample key) becomes the index so membership
        # tests tell us whether a sample has synthetic versions.
        self.df = pd.read_csv(fnames_modified, sep=sep, header=None, dtype=str).set_index(0)
        self.synthetic_versions = list()
        self.no_synthetic_versions = list()
        self.path_to_modified_samples = path_to_samples
        self.wds_dataset = wds_dataset
        self.n_wds_samples = len(wds_dataset)
        self.transforms = transforms
        self.tokenize = tokenizer
        # The key split requires a full scan of the webdataset, so it is cached
        # in pickle files in the current working directory.
        # NOTE(review): the cache file names are global (not keyed on the
        # dataset), so a stale cache from another run would be silently reused.
        if not os.path.exists("no_synthetic_versions"):
            self.merge_keys()
            with open("has_synthetic_versions", "wb") as f:
                pickle.dump(self.synthetic_versions, f)
            with open("no_synthetic_versions", "wb") as f:
                pickle.dump(self.no_synthetic_versions, f)
        else:
            with open("no_synthetic_versions", "rb") as f:
                self.no_synthetic_versions = pickle.load(f)

    def merge_keys(self):
        # Partition webdataset positions by whether their key appears in the
        # synthetic-versions table.
        print("Creating the mixed dataset")
        for i in tqdm(range(self.n_wds_samples)):
            wds_sample = self.wds_dataset[i]
            image_name = wds_sample[".json"]["key"]
            if image_name in self.df.index:
                self.synthetic_versions.append(i)
            else:
                self.no_synthetic_versions.append(i)

    def __len__(self):
        # NOTE(review): reports more items than __getitem__ can serve --
        # indices >= n_wds_samples would fail in ``self.wds_dataset[item]``.
        return self.n_wds_samples + self.df.shape[0] // 2

    def __getitem__(self, item):
        wds_sample = self.wds_dataset[item]
        image_name = wds_sample[".json"]["key"]
        caption = wds_sample[".json"]["caption"]
        text = self.tokenize(caption)[0]
        original_image = self.transforms(wds_sample['.jpg'])

        captions = list()
        if image_name in self.df.index:
            # Sample has synthetic versions: return the stack of all of them
            # (the original image/caption are not included in this branch).
            modified_samples = list()
            for _, row in self.df.loc[image_name].iterrows():
                version = row[1]
                caption = row[2]
                modified_samples.append(
                    self.transforms(Image.open(f"{self.path_to_modified_samples}/{image_name}_{version}.png")))
                captions.append(self.tokenize(caption)[0])
            return torch.stack(modified_samples, dim=0), torch.stack(captions, dim=0)
        else:
            # No synthetic versions: pair the sample with another nearby
            # non-synthetic sample so every item yields a 2-stack.
            samples = list()
            captions = list()
            captions.append(text)
            samples.append(original_image)

            # NOTE(review): ``item`` indexes the wds dataset but slices
            # ``no_synthetic_versions`` (a list of positions), so the +/-500
            # window is only approximately "nearby" -- confirm intent.
            key = random.choice(self.no_synthetic_versions[max(item - 500, 0):min(self.n_wds_samples, item + 500)])
            wds_sample = self.wds_dataset[key]

            caption = wds_sample[".json"]["caption"]
            text = self.tokenize(caption)[0]
            original_image = self.transforms(wds_sample['.jpg'])
            captions.append(text)
            samples.append(original_image)

            return torch.stack(samples, dim=0), torch.stack(captions, dim=0)
|
|
|
|
|
class MixedDataset(Dataset):
    r"""A dataset to yield an original sample X or a tuple of (X, X_modified_1, ..., X_modified_n) if
    synthetic versions of the sample exist. Currently only supports n == 2.

    NOTE(review): this class is byte-for-byte identical to
    ``PartiallySyntheticCSVDataset`` above; the two should probably be merged.

    Args:
        wds_dataset: indexable webdataset (e.g. ``wids.ShardListDataset``) whose
            samples carry ``.json`` metadata (``key``, ``caption``) and a ``.jpg`` image.
        transforms: image transform applied to every decoded image.
        path_to_samples: path to the directory with synthetic samples.
        fnames_modified: TSV listing synthetic versions, indexed by sample key.
        sep: separator of ``fnames_modified``.
        tokenizer: caption tokenizer; ``tokenizer(caption)[0]`` is returned.
    """

    def __init__(self, wds_dataset, transforms, path_to_samples, fnames_modified, sep="\t",
                 tokenizer=None):
        logging.debug(f'Loading csv data from {path_to_samples}.')
        # Column 0 (the original sample key) becomes the index so membership
        # tests tell us whether a sample has synthetic versions.
        self.df = pd.read_csv(fnames_modified, sep=sep, header=None, dtype=str).set_index(0)
        self.synthetic_versions = list()
        self.no_synthetic_versions = list()
        self.path_to_modified_samples = path_to_samples
        self.wds_dataset = wds_dataset
        self.n_wds_samples = len(wds_dataset)
        self.transforms = transforms
        self.tokenize = tokenizer
        # The key split requires a full scan of the webdataset, so it is cached
        # in pickle files in the current working directory.
        # NOTE(review): the cache file names are global (not keyed on the
        # dataset), so a stale cache from another run would be silently reused.
        if not os.path.exists("no_synthetic_versions"):
            self.merge_keys()
            with open("has_synthetic_versions", "wb") as f:
                pickle.dump(self.synthetic_versions, f)
            with open("no_synthetic_versions", "wb") as f:
                pickle.dump(self.no_synthetic_versions, f)
        else:
            with open("no_synthetic_versions", "rb") as f:
                self.no_synthetic_versions = pickle.load(f)

    def merge_keys(self):
        # Partition webdataset positions by whether their key appears in the
        # synthetic-versions table.
        print("Creating the mixed dataset")
        for i in tqdm(range(self.n_wds_samples)):
            wds_sample = self.wds_dataset[i]
            image_name = wds_sample[".json"]["key"]
            if image_name in self.df.index:
                self.synthetic_versions.append(i)
            else:
                self.no_synthetic_versions.append(i)

    def __len__(self):
        # NOTE(review): reports more items than __getitem__ can serve --
        # indices >= n_wds_samples would fail in ``self.wds_dataset[item]``.
        return self.n_wds_samples + self.df.shape[0] // 2

    def __getitem__(self, item):
        wds_sample = self.wds_dataset[item]
        image_name = wds_sample[".json"]["key"]
        caption = wds_sample[".json"]["caption"]
        text = self.tokenize(caption)[0]
        original_image = self.transforms(wds_sample['.jpg'])

        captions = list()
        if image_name in self.df.index:
            # Sample has synthetic versions: return the stack of all of them
            # (the original image/caption are not included in this branch).
            modified_samples = list()
            for _, row in self.df.loc[image_name].iterrows():
                version = row[1]
                caption = row[2]
                modified_samples.append(
                    self.transforms(Image.open(f"{self.path_to_modified_samples}/{image_name}_{version}.png")))
                captions.append(self.tokenize(caption)[0])
            return torch.stack(modified_samples, dim=0), torch.stack(captions, dim=0)
        else:
            # No synthetic versions: pair the sample with another nearby
            # non-synthetic sample so every item yields a 2-stack.
            samples = list()
            captions = list()
            captions.append(text)
            samples.append(original_image)

            # NOTE(review): ``item`` indexes the wds dataset but slices
            # ``no_synthetic_versions`` (a list of positions), so the +/-500
            # window is only approximately "nearby" -- confirm intent.
            key = random.choice(self.no_synthetic_versions[max(item - 500, 0):min(self.n_wds_samples, item + 500)])
            wds_sample = self.wds_dataset[key]

            caption = wds_sample[".json"]["caption"]
            text = self.tokenize(caption)[0]
            original_image = self.transforms(wds_sample['.jpg'])
            captions.append(text)
            samples.append(original_image)

            return torch.stack(samples, dim=0), torch.stack(captions, dim=0)
|
|
|
|
|
|
|
|
|
class CsvDataset(Dataset):
    """Image/caption dataset backed by a headerless separator-delimited file.

    Column 0 holds image paths (or zero-padded ids when ``img_root`` is set)
    and column 1 holds captions.  ``img_key``/``caption_key`` are accepted for
    interface compatibility but unused: columns are taken positionally.
    """

    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", tokenizer=None, img_root=None):
        logging.debug(f'Loading csv data from {input_filename}.')
        frame = pd.read_csv(input_filename, sep=sep, header=None)

        self.root = img_root
        self.images = frame.iloc[:, 0].tolist()
        self.captions = frame.iloc[:, 1].tolist()
        self.transforms = transforms
        logging.debug('Done loading data.')

        self.tokenize = tokenizer

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        if self.root is None:
            # Column 0 already holds a usable path.
            images = self.transforms(Image.open(str(self.images[idx])))
        else:
            # Column 0 holds a numeric id, zero-padded to 9 digits on disk.
            images = self.transforms(Image.open(self.root + "/" + str(self.images[idx]).zfill(9) + ".jpg"))
        texts = self.tokenize([str(self.captions[idx])])[0]
        return images, texts
|
|
|
class CsvDatasetSyn(Dataset):
    """Caption dataset over a TSV that may carry synthetic (gender-edited) images.

    A three-column file is read as (image id, gender tag, caption) and enables
    ``synthetic`` mode; a two-column file is plain (image id, caption).
    ``prefix_dict`` maps each image id to the row indices that share it, which
    the combined sampler uses to draw counterfactual pairs.
    """

    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", tokenizer=None, img_root=None):
        logging.debug(f'Loading csv data from {input_filename}.')
        df = pd.read_csv(input_filename, sep=sep)

        self.root = img_root
        self.synthetic = True

        n_cols = len(df.columns)
        if n_cols == 3:
            self.images = df.iloc[:, 0].tolist()
            self.genders = df.iloc[:, 1].tolist()
            self.captions = df.iloc[:, 2].tolist()
        elif n_cols == 2:
            self.images = df.iloc[:, 0].tolist()
            self.captions = df.iloc[:, 1].tolist()
            self.synthetic = False
        else:
            raise NotImplementedError

        self.transforms = transforms
        logging.debug('Done loading data.')

        self.tokenize = tokenizer
        self.prefix_dict = self._create_prefix_dict()

    def _create_prefix_dict(self):
        # Group row indices by image id (the "prefix").
        grouped = {}
        for row_idx, image_id in enumerate(self.images):
            grouped.setdefault(image_id, []).append(row_idx)
        return grouped

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        stem = str(self.images[idx]).zfill(9)
        if self.synthetic:
            # Synthetic renders are saved as <id>_<gender>.png.
            fname = stem + "_" + self.genders[idx] + ".png"
        else:
            fname = stem + ".jpg"
        if self.root is not None:
            fname = self.root + "/" + fname
        images = self.transforms(Image.open(fname))
        texts = self.tokenize([str(self.captions[idx])])[0]
        return images, texts

    def get_pair_with_same_prefix(self):
        """Sample one image id and return its (man, woman) synthetic pair."""
        prefix = random.choice(list(self.prefix_dict.keys()))
        indices = self.prefix_dict[prefix]

        if len(indices) < 2:
            raise ValueError(f"Not enough samples with the prefix {prefix} to form a pair.")
        image_1 = self.transforms(Image.open(self.root + "/" + str(self.images[indices[0]]).zfill(9) + "_man.png"))
        image_2 = self.transforms(Image.open(self.root + "/" + str(self.images[indices[1]]).zfill(9) + "_woman.png"))

        text_1 = self.tokenize([str(self.captions[indices[0]])])[0]
        text_2 = self.tokenize([str(self.captions[indices[1]])])[0]

        return (image_1, text_1), (image_2, text_2)
|
|
|
|
|
|
|
class SyntheticCsvDataset(Dataset):
    """Dataset over a CSV of synthetic samples, two consecutive rows per item.

    Rows ``2*idx`` and ``2*idx + 1`` are treated as the two synthetic versions
    of one underlying sample; ``__getitem__`` returns both stacked.

    Args:
        input_filename: path/buffer of the CSV (col 0 = image id,
            col 1 = version tag, col 2 = caption).
        transforms: image transform applied to every loaded image.
        path_to_samples: directory holding the synthetic ``<id>_<version>.png`` files.
        img_key, caption_key: accepted for interface compatibility; unused.
        sep: CSV separator.
        tokenizer: caption tokenizer; ``tokenizer(caption)[0]`` is stacked.
    """

    def __init__(self, input_filename, transforms, path_to_samples, img_key, caption_key, sep="\t", tokenizer=None):
        logging.debug(f'Loading csv data from {input_filename}.')
        self.df = pd.read_csv(input_filename, sep=sep)
        self.path_to_modified_samples = path_to_samples
        self.transforms = transforms
        logging.debug('Done loading data.')

        self.tokenize = tokenizer

    def __len__(self):
        # Two CSV rows form one dataset item.
        return len(self.df) // 2

    def __getitem__(self, idx):
        captions = list()
        modified_samples = list()
        # BUG FIX: ``self.df.iloc[idx*2, :]`` selected a single row (a Series),
        # which has no ``.iterrows()``; slice the consecutive row *pair* instead.
        for _, row in self.df.iloc[idx * 2:idx * 2 + 2, :].iterrows():
            # Access columns positionally so the header row's labels don't
            # matter: col 0 = image id, col 1 = version tag, col 2 = caption.
            version = row.iloc[1]
            caption = row.iloc[2]
            modified_samples.append(
                self.transforms(Image.open(f"{self.path_to_modified_samples}/{row.iloc[0]}_{version}.png")))
            captions.append(self.tokenize(caption)[0])
        return torch.stack(modified_samples, dim=0), torch.stack(captions, dim=0)
|
|
|
|
|
class SharedEpoch:
    """Epoch number held in a multiprocessing ``Value``.

    DataLoader workers run in separate processes; keeping the epoch in shared
    memory lets the main process bump it and have every worker observe the
    new value.
    """

    def __init__(self, epoch: int = 0):
        # 'i' = C signed int, shared across processes.
        self.shared_epoch = Value('i', epoch)

    def set_value(self, epoch):
        """Store ``epoch`` into the shared slot."""
        self.shared_epoch.value = epoch

    def get_value(self):
        """Read the current epoch from the shared slot."""
        return self.shared_epoch.value
|
|
|
|
|
@dataclass
class DataInfo:
    """Bundle of a dataloader plus the objects that must track the epoch."""

    dataloader: DataLoader
    sampler: DistributedSampler = None
    shared_epoch: SharedEpoch = None

    def set_epoch(self, epoch):
        """Propagate ``epoch`` to the shared counter and the distributed sampler."""
        shared = self.shared_epoch
        if shared is not None:
            shared.set_value(epoch)
        # isinstance(None, ...) is False, so no separate None check is needed.
        if isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)
|
|
|
|
|
def expand_urls(urls, weights=None):
    """Expand brace-notation shard specs into a flat url list (plus weights).

    ``urls`` may be a ``::``-separated string of brace patterns or an iterable
    of urls.  With ``weights`` given, each weight pairs with one ``::``
    component and every expanded shard inherits its component's weight.
    """
    if weights is None:
        return wds.shardlists.expand_urls(urls), None
    if not isinstance(urls, str):
        # Already a list of urls; pass the weights through unchanged.
        return list(urls), weights
    urllist = urls.split("::")
    weights = weights.split('::')
    assert len(weights) == len(urllist), \
        f"Expected the number of data components ({len(urllist)}) and weights({len(weights)}) to match."
    all_urls, all_weights = [], []
    for url, weight_str in zip(urllist, weights):
        shard_urls = list(braceexpand.braceexpand(url))
        all_urls.extend(shard_urls)
        # Every shard of a component shares that component's weight.
        all_weights.extend([float(weight_str)] * len(shard_urls))
    return all_urls, all_weights
|
|
|
|
|
def get_dataset_size(shards):
    """Return ``(total_sample_count or None, num_shards)`` for a shard spec.

    Looks next to the shards for a ``sizes.json`` file (per-shard sample
    counts) first, then for a ``__len__`` file holding a literal total;
    otherwise the size is unknown and ``None`` is returned.
    """
    shards_list, _ = expand_urls(shards)
    dir_path = os.path.dirname(shards_list[0])
    sizes_filename = os.path.join(dir_path, 'sizes.json')
    len_filename = os.path.join(dir_path, '__len__')
    if os.path.exists(sizes_filename):
        # FIX: close the metadata files instead of leaking the handles.
        with open(sizes_filename, 'r') as f:
            sizes = json.load(f)
        total_size = sum(int(sizes[os.path.basename(shard)]) for shard in shards_list)
    elif os.path.exists(len_filename):
        with open(len_filename, 'r') as f:
            # literal_eval, not eval: the file must hold a plain int literal.
            total_size = ast.literal_eval(f.read())
    else:
        # Size unknown (e.g. webdataset without metadata files).
        total_size = None
    num_shards = len(shards_list)
    return total_size, num_shards
|
|
|
|
|
def get_imagenet(args, preprocess_fns, split):
    """Build a DataLoader over an ImageNet split ("train", "val" or "v2").

    For the train split a fixed 50-images-per-class random subset is sampled;
    "val" and "v2" iterate the full set without a sampler.
    """
    assert split in ["train", "val", "v2"]
    is_train = split == "train"
    preprocess_train, preprocess_val = preprocess_fns

    if split == "v2":
        from imagenetv2_pytorch import ImageNetV2Dataset
        dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
    else:
        data_path = args.imagenet_train if is_train else args.imagenet_val
        preprocess_fn = preprocess_train if is_train else preprocess_val
        assert data_path
        dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)

    sampler = None
    if is_train:
        # Mark exactly k=50 random examples of each of the 1000 classes.
        k = 50
        target_array = np.array(dataset.targets)
        idxs = np.zeros(len(dataset.targets))
        for c in range(1000):
            class_mask = target_array == c
            flags = np.zeros(int(class_mask.sum()))
            flags[:k] = 1
            np.random.shuffle(flags)
            idxs[class_mask] = flags
        sampler = SubsetRandomSampler(np.where(idxs.astype('int'))[0])

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        sampler=sampler,
    )

    return DataInfo(dataloader=dataloader, sampler=sampler)
|
|
|
|
|
def count_samples(dataloader):
    """Iterate ``dataloader`` once and return ``(total_samples, total_batches)``."""
    os.environ["WDS_EPOCH"] = "0"
    n_elements = 0
    n_batches = 0
    for images, texts in dataloader:
        n_batches += 1
        n_elements += len(images)
        # Every batch must pair each image with exactly one text.
        assert len(images) == len(texts)
    return n_elements, n_batches
|
|
|
|
|
def filter_no_caption_or_no_image(sample):
    """Keep only webdataset samples carrying both a caption and an image."""
    if 'txt' not in sample:
        return False
    return any(key in sample for key in ('png', 'jpg', 'jpeg', 'webp'))
|
|
|
|
|
def log_and_continue(exn):
    """Exception handler for webdataset: warn about the error and keep going."""
    message = f'Handling webdataset error ({repr(exn)}). Ignoring.'
    logging.warning(message)
    # Returning True tells webdataset to continue with the next item.
    return True
|
|
|
|
|
def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    Non-throwing variant: malformed or duplicated tar entries start a fresh
    sample instead of raising.

    :param data: iterator of ``{"fname", "data", "__url__"}`` dicts
    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    :param suffixes: if given, only these suffixes are kept in a sample
    :param handler: accepted for interface compatibility; unused here
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            # Entry is not part of a sample (e.g. tar metadata): skip it.
            continue
        if lcase:
            suffix = suffix.lower()
        # A new prefix -- or a repeated suffix within the same prefix, which
        # occurs in some malformed tars -- closes the current sample and
        # starts a new one (the no-throw behaviour this helper exists for).
        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    # Flush the trailing sample once the input is exhausted.
    if valid_sample(current_sample):
        yield current_sample
|
|
|
|
|
def tarfile_to_samples_nothrow(src, handler=log_and_continue):
    """Like ``wds.tarfile_to_samples`` but using the non-throwing key grouper."""
    opened = url_opener(src, handler=handler)
    expanded = tar_file_expander(opened, handler=handler, eof_value=None)
    return group_by_keys_nothrow(expanded, handler=handler)
|
|
|
|
|
def pytorch_worker_seed(increment=0):
    """Derive a per-worker RNG seed, optionally offset per epoch.

    Inside a DataLoader worker the PyTorch-assigned worker seed is used, and
    ``increment`` (typically the epoch) is scaled by the worker count so the
    seed spaces of consecutive epochs do not overlap.  Outside a worker we
    defer to webdataset's fallback seed.
    """
    worker_info = get_worker_info()
    if worker_info is None:
        # Not running in a DataLoader worker process.
        return wds.utils.pytorch_worker_seed()
    seed = worker_info.seed
    if increment:
        seed += increment * max(1, worker_info.num_workers)
    return seed
|
|
|
|
|
# Buffer sizes for the two shuffle stages of the webdataset training pipeline:
# shards are shuffled with a modest buffer, individual samples with a larger
# one.  "INITIAL" is how full the buffer must be before items start flowing.
_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 5
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000
|
|
|
|
|
class detshuffle2(wds.PipelineStage):
    """Deterministic shuffle stage seeded from ``(seed, epoch)``.

    With a ``SharedEpoch`` the epoch is read from shared memory; otherwise a
    local counter is bumped on every ``run``.  A negative ``seed`` switches to
    the per-worker pytorch seed so each worker shuffles differently.
    """

    def __init__(self, bufsize=1000, initial=100, seed=0, epoch=-1):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src):
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # Stand-alone counter: each pass through the pipeline is one epoch.
            self.epoch += 1
            epoch = self.epoch
        if self.seed < 0:
            # Per-worker seeding; varies across workers and epochs.
            seed = pytorch_worker_seed(epoch)
        else:
            # Same seed on every worker for a given epoch.
            seed = self.seed + epoch
        rng = random.Random(seed)
        return _shuffle(src, self.bufsize, self.initial, rng)
|
|
|
|
|
class ResampledShards2(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
        self,
        urls,
        weights=None,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
        epoch=-1,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        :param weights: optional per-url sampling weights (paired by expand_urls)
        :param nshards: number of shard urls to emit per iteration
        :param worker_seed: callable returning a base seed; defaults to the
            pytorch per-worker seed when ``deterministic`` is set
        :param deterministic: reseed the RNG from (worker seed, epoch) on each
            iteration so resampling is reproducible across runs
        :param epoch: int counter or ``SharedEpoch`` read at iteration time
        """
        super().__init__()
        urls, weights = expand_urls(urls, weights)
        self.urls = urls
        self.weights = weights
        if self.weights is not None:
            assert len(self.urls) == len(self.weights), \
                f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match."
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.rng = random.Random()
        self.worker_seed = worker_seed
        self.deterministic = deterministic
        self.epoch = epoch

    def __iter__(self):
        """Return an iterator over the shards."""
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # Stand-alone counter: each full iteration counts as one epoch.
            self.epoch += 1
            epoch = self.epoch
        if self.deterministic:
            # Reseed so every (worker, epoch) pair gets a reproducible but
            # distinct shard sequence; statement order matters here.
            if self.worker_seed is None:
                seed = pytorch_worker_seed(epoch)
            else:
                seed = self.worker_seed() + epoch
            self.rng.seed(seed)
        for _ in range(self.nshards):
            if self.weights is None:
                # Uniform sampling with replacement.
                yield dict(url=self.rng.choice(self.urls))
            else:
                # Weighted sampling with replacement.
                yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0])
|
|
|
|
|
def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None):
    """Build the webdataset pipeline + WebLoader and wrap them in a DataInfo.

    Training pipelines shuffle at both the shard and sample level (or resample
    shards with replacement when ``--dataset-resampled`` is set); validation
    pipelines iterate shards in order.  Sample counts are taken from the args
    or from shard metadata via ``get_dataset_size``.
    """
    input_shards = args.train_data if is_train else args.val_data
    assert input_shards is not None
    resampled = getattr(args, 'dataset_resampled', False) and is_train

    num_shards = None
    if is_train:
        if args.train_num_samples is not None:
            num_samples = args.train_num_samples
        else:
            num_samples, num_shards = get_dataset_size(input_shards)
            if not num_samples:
                raise RuntimeError(
                    'Currently, the number of dataset samples must be specified for the training dataset. '
                    'Please specify it via `--train-num-samples` if no dataset length info is present.')
    else:
        # Eval: an unknown sample count just means "iterate until exhausted".
        num_samples = args.val_num_samples or 0

    # Shared between main process and dataloader workers (see SharedEpoch).
    shared_epoch = SharedEpoch(epoch=epoch)

    if is_train and args.train_data_upsampling_factors is not None:
        assert resampled, "--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."

    if resampled:
        # Infinite weighted shard resampling with replacement.
        pipeline = [ResampledShards2(
            input_shards,
            weights=args.train_data_upsampling_factors,
            deterministic=True,
            epoch=shared_epoch,
        )]
    else:
        pipeline = [wds.SimpleShardList(input_shards)]

    # at this point we have an iterator over all the shards
    if is_train:
        if not resampled:
            # Deterministic shard shuffle + split across nodes and workers.
            pipeline.extend([
                detshuffle2(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=args.seed,
                    epoch=shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
            ])
        # at this point, we have an iterator over the shards assigned to each
        # worker at each node; the non-throwing expander skips corrupt entries.
        pipeline.extend([
            tarfile_to_samples_nothrow,
            wds.shuffle(
                bufsize=_SAMPLE_SHUFFLE_SIZE,
                initial=_SAMPLE_SHUFFLE_INITIAL,
            ),
        ])
    else:
        pipeline.extend([
            wds.split_by_worker,
            # Eval keeps the stock expander but logs-and-continues on errors.
            wds.tarfile_to_samples(handler=log_and_continue),
        ])
    # Decode, preprocess and batch the samples.
    pipeline.extend([
        wds.select(filter_no_caption_or_no_image),
        wds.decode("pilrgb", handler=log_and_continue),
        wds.rename(image="jpg;png;jpeg;webp", text="txt"),
        wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]),
        wds.to_tuple("image", "text"),
        # Training drops the ragged final batch; eval keeps it.
        wds.batched(args.batch_size, partial=not is_train)
    ])

    dataset = wds.DataPipeline(*pipeline)

    if is_train:
        if not resampled:
            num_shards = num_shards or len(expand_urls(input_shards)[0])
            assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
        # Round the sample count so every rank/worker sees the same number of
        # batches per epoch (required for synchronized distributed training).
        round_fn = math.floor if floor else math.ceil
        global_batch_size = args.batch_size * args.world_size
        num_batches = round_fn(num_samples / global_batch_size)
        num_workers = max(1, args.workers)
        num_worker_batches = round_fn(num_batches / num_workers)
        num_batches = num_worker_batches * num_workers
        num_samples = num_batches * global_batch_size
        # with_epoch makes each worker emit exactly num_worker_batches batches.
        dataset = dataset.with_epoch(num_worker_batches)
    else:
        num_batches = math.ceil(num_samples / args.batch_size)

    # Batching already happened in the pipeline, hence batch_size=None.
    dataloader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=args.workers > 0,
    )

    # Expose the (possibly rounded) epoch size for the training loop.
    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples

    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
|
|
|
|
|
def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    """Build a DataInfo around a ``CsvDataset`` for the train or val CSV."""
    input_filename = args.train_data if is_train else args.val_data
    assert input_filename
    dataset = CsvDataset(
        input_filename,
        preprocess_fn,
        img_key=args.csv_img_key,
        caption_key=args.csv_caption_key,
        sep=args.csv_separator,
        tokenizer=tokenizer,
        img_root=args.train_data_root,
    )
    use_dist_sampler = args.distributed and is_train
    sampler = DistributedSampler(dataset) if use_dist_sampler else None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        # Shuffle in the loader only when no sampler owns the ordering.
        shuffle=is_train and sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader.num_samples = len(dataset)
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)
|
|
|
|
|
class SyntheticDataset(Dataset):
    """Dummy dataset yielding one blank image and a fixed caption.

    Useful for measuring pipeline throughput without touching real data.

    Args:
        transform: optional transform applied to the blank image.
        image_size: size of the generated RGB image.
        caption: caption returned (tokenized) with every sample.
        dataset_size: reported length of the dataset.
        tokenizer: callable used to tokenize the caption.
    """

    def __init__(
        self,
        transform=None,
        image_size=(224, 224),
        caption="Dummy caption",
        dataset_size=100,
        tokenizer=None,
    ):
        self.transform = transform
        self.image_size = image_size
        self.caption = caption
        # One blank image is created once and shared by every item.
        self.image = Image.new('RGB', image_size)
        self.dataset_size = dataset_size

        self.preprocess_txt = lambda text: tokenizer(text)[0]

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        # FIX: previously the ``return`` sat inside the ``if``, so a ``None``
        # transform made this method return None; now the raw image is used.
        image = self.image if self.transform is None else self.transform(self.image)
        return image, self.preprocess_txt(self.caption)
|
|
|
|
|
def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    """Build a DataInfo over the dummy ``SyntheticDataset`` (throughput tests)."""
    # The first transform in the preprocessing chain carries the target size.
    image_size = preprocess_fn.transforms[0].size
    dataset = SyntheticDataset(
        transform=preprocess_fn,
        image_size=image_size,
        dataset_size=args.train_num_samples,
        tokenizer=tokenizer,
    )
    use_dist_sampler = args.distributed and is_train
    sampler = DistributedSampler(dataset) if use_dist_sampler else None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        # Shuffle in the loader only when no sampler owns the ordering.
        shuffle=is_train and sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader.num_samples = len(dataset)
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)
|
|
|
|
|
def get_dataset_fn(data_path, dataset_type):
    """Map a dataset type name (or the file extension for "auto") to a builder.

    Raises ``ValueError`` for unknown types or undetectable extensions.
    """
    if dataset_type == "webdataset":
        return get_wds_dataset
    if dataset_type == "csv":
        return get_csv_dataset
    if dataset_type == "synthetic":
        return get_synthetic_dataset
    if dataset_type == "auto":
        # Guess from the file extension of the data path.
        ext = data_path.split('.')[-1]
        if ext in ['csv', 'tsv']:
            return get_csv_dataset
        if ext in ['tar']:
            return get_wds_dataset
        raise ValueError(
            f"Tried to figure out dataset type, but failed for extension {ext}.")
    raise ValueError(f"Unsupported dataset type: {dataset_type}")
|
|
|
|
|
|
|
|
|
def get_dataset_synthetic_counterfactual(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None):
    """Build a DataInfo over the synthetic counterfactual CSV dataset.

    Reads ``args.synthetic_csv`` via ``CsvDatasetSyn`` with images rooted at
    ``args.synthetic_path``.
    """
    input_filename_synthetic = args.synthetic_csv

    assert input_filename_synthetic

    dataset = CsvDatasetSyn(
        args.synthetic_csv,
        preprocess_img,
        img_key=args.csv_img_key,
        caption_key=args.csv_caption_key,
        sep="\t",
        tokenizer=tokenizer,
        img_root=args.synthetic_path
    )

    num_samples = len(dataset)

    sampler = DistributedSampler(dataset) if args.distributed and is_train else None
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        # FIX: DataLoader raises ValueError when both ``shuffle`` and a custom
        # ``sampler`` are set (they are mutually exclusive); with a
        # DistributedSampler the sampler owns the (shuffled) ordering.
        shuffle=sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=is_train,
        sampler=sampler,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)
|
|
|
|
|
def get_dataset_mixed(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None):
    """Build a DataInfo mixing a real CSV dataset with a synthetic CSV dataset.

    Both datasets are concatenated and batches are drawn through a
    UniformCombinedSampler (used as a batch_sampler) weighted by the
    real/synthetic size ratio.

    Args:
        args: parsed CLI namespace; reads train_data/val_data, synthetic_csv,
            synthetic_path, csv_img_key, csv_caption_key, batch_size, workers,
            and optionally train_img_root (root directory of the real images).
        preprocess_img: image transform applied by both datasets.
        is_train: selects train_data vs val_data as the real CSV.
        epoch: unused here; kept for signature parity with other builders.
        floor: unused here; kept for signature parity with other builders.
        tokenizer: caption tokenizer forwarded to both datasets.

    Returns:
        DataInfo wrapping the DataLoader and the combined sampler.
    """
    input_filename = args.train_data if is_train else args.val_data
    input_filename_synthetic = args.synthetic_csv
    assert input_filename
    assert input_filename_synthetic

    # FIX: the real-image root was hard-coded to a machine-specific path.
    # Make it configurable while keeping the old value as a backward-
    # compatible default for existing call sites.
    real_img_root = getattr(args, "train_img_root", "/home/kis/datasets/cc3m_attempt12")

    dataset_real = CsvDataset(
        input_filename,
        preprocess_img,
        img_key=args.csv_img_key,
        caption_key=args.csv_caption_key,
        sep="\t",
        tokenizer=tokenizer,
        img_root=real_img_root,
    )
    dataset_synthetic = CsvDatasetSyn(
        args.synthetic_csv,
        preprocess_img,
        img_key=args.csv_img_key,
        caption_key=args.csv_caption_key,
        sep="\t",
        tokenizer=tokenizer,
        img_root=args.synthetic_path,
    )

    num_samples = len(dataset_real) + len(dataset_synthetic)
    dataset = ConcatDataset([dataset_real, dataset_synthetic])

    uniform_combined_sampler = UniformCombinedSampler(
        dataset_real,
        dataset_synthetic,
        args.batch_size,
        first_dataset_weight=len(dataset_real) / len(dataset_synthetic),
    )
    # batch_sampler yields whole batches, so batch_size/shuffle/sampler/
    # drop_last must not be passed alongside it.
    dataloader = DataLoader(
        dataset,
        num_workers=args.workers,
        pin_memory=True,
        batch_sampler=uniform_combined_sampler,
    )
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, uniform_combined_sampler)
    # NOTE(review): ~300 lines of unreachable code that followed this return
    # (an older wids/webdataset-based implementation of the mixed pipeline)
    # were removed as dead code; recover from version control if needed.
|
|
|
|
|
|
|
|
|
def get_data(args, preprocess_fns, epoch=0, tokenizer=None):
    """Assemble the dict of DataInfo objects for training and evaluation.

    Keys present when configured: "train", "val", "imagenet-val",
    "imagenet-v2".

    Args:
        args: parsed CLI namespace; reads dataset_type, train_data, val_data,
            imagenet_val, imagenet_v2 (plus whatever the chosen builder uses).
        preprocess_fns: (train_transform, val_transform) pair.
        epoch: starting epoch forwarded to the train builder.
        tokenizer: caption tokenizer forwarded to the builders.
    """
    preprocess_train, preprocess_val = preprocess_fns
    data = {}

    if args.dataset_type == "wds_csv_mixed":
        # BUG FIX: these branches used `if/elif`, so a run configured with
        # both --train-data and --val-data silently got no "val" loader.
        # The generic branch below builds both; do the same here.
        if args.train_data:
            data["train"] = get_dataset_mixed(args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)
        if args.val_data:
            data["val"] = get_dataset_mixed(args, preprocess_val, is_train=False, tokenizer=tokenizer)
    elif args.dataset_type == "synthetic_counterfactual":
        # Same if/elif bug fixed here as well.
        if args.train_data:
            data["train"] = get_dataset_synthetic_counterfactual(args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)
        if args.val_data:
            data["val"] = get_dataset_synthetic_counterfactual(args, preprocess_val, is_train=False, tokenizer=tokenizer)
    else:
        # "synthetic" needs no train_data path, hence the extra condition.
        if args.train_data or args.dataset_type == "synthetic":
            data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
                args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)
        if args.val_data:
            data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
                args, preprocess_val, is_train=False, tokenizer=tokenizer)

    if args.imagenet_val is not None:
        data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")

    if args.imagenet_v2 is not None:
        data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")

    return data
|
|
|
|
|
# This module is a library of dataset builders; there is no CLI entry point.
if __name__ == '__main__':
    pass