Spaces:
Runtime error
Runtime error
import io | |
import torch | |
import tops | |
from .utils import collate_fn, jpg_decoder, get_num_workers, png_decoder | |
def get_dataloader(
    dataset, gpu_transform: torch.nn.Module,
    num_workers,
    batch_size,
    infinite: bool,
    drop_last: bool,
    prefetch_factor: int,
    shuffle,
    channels_last=False
):
    """Build a ``torch.utils.data.DataLoader`` and wrap it in ``tops.DataPrefetcher``.

    Sampling strategy, chosen in this order:

    * ``infinite`` is true: a ``tops.InfiniteSampler`` built with this
      process's rank and the world size. ``drop_last`` is not forwarded in
      this mode.
    * more than one process (``tops.world_size() > 1``): a
      ``torch.utils.data.DistributedSampler``; ``drop_last`` is forwarded to
      the ``DataLoader``.
    * otherwise: no sampler; ``shuffle`` and ``drop_last`` are forwarded to
      the ``DataLoader`` directly.

    Args:
        dataset: Dataset handed to the sampler and the ``DataLoader``.
        gpu_transform: Module handed to ``tops.DataPrefetcher``.
        num_workers: Number of ``DataLoader`` worker processes.
        batch_size: Batch size given to the ``DataLoader``.
        infinite: Use ``tops.InfiniteSampler`` instead of a finite sampler.
        drop_last: Forwarded to the ``DataLoader`` in the non-infinite modes.
        prefetch_factor: Forwarded to the ``DataLoader`` only when worker
            processes are used (see the comment in the body).
        shuffle: Forwarded to whichever sampler / loader handles ordering.
        channels_last: Forwarded to ``tops.DataPrefetcher``.

    Returns:
        The ``tops.DataPrefetcher`` wrapping the constructed ``DataLoader``.
    """
    sampler = None
    dl_kwargs = dict(
        pin_memory=True,
    )
    if infinite:
        sampler = tops.InfiniteSampler(
            dataset, rank=tops.rank(),
            num_replicas=tops.world_size(),
            shuffle=shuffle
        )
    elif tops.world_size() > 1:
        sampler = torch.utils.data.DistributedSampler(
            dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank())
        dl_kwargs["drop_last"] = drop_last
    else:
        # DataLoader rejects shuffle=True together with an explicit sampler,
        # so shuffle is only passed when no sampler is used.
        dl_kwargs["shuffle"] = shuffle
        dl_kwargs["drop_last"] = drop_last
    if num_workers != 0:
        # DataLoader raises ValueError when prefetch_factor is specified
        # without worker processes; prefetching is only meaningful with
        # num_workers > 0, so the option is skipped for in-process loading.
        dl_kwargs["prefetch_factor"] = prefetch_factor
    dataloader = torch.utils.data.DataLoader(
        dataset, sampler=sampler, collate_fn=collate_fn,
        batch_size=batch_size,
        num_workers=num_workers,
        **dl_kwargs
    )
    dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last)
    return dataloader
def get_dataloader_places2_wds(
    path,
    batch_size: int,
    num_workers: int,
    transform: torch.nn.Module,
    gpu_transform: torch.nn.Module,
    infinite: bool,
    shuffle: bool,
    partial_batches: bool,
    sample_shuffle=10_000,
    tar_shuffle=100,
    channels_last=False,
):
    """Build a webdataset-based loader for Places2 shards, wrapped in ``tops.DataPrefetcher``.

    The pipeline is assembled in this order: shard source
    (``wds.ResampledShards`` when ``infinite`` else ``wds.SimpleShardList``),
    optional ``wds.shuffle(tar_shuffle)``, ``wds.split_by_node``,
    ``wds.split_by_worker``, optional ``wds.shuffle(sample_shuffle)``,
    ``wds.tarfile_to_samples()``, ``wds.decode("torchrgb8")``, key renaming,
    optional ``wds.map(transform)`` and finally ``wds.batched``.

    Side effect: overwrites the ``RANK`` and ``WORLD_SIZE`` environment
    variables with the values reported by ``tops``.

    Args:
        path: Shard specification; converted with ``str()`` before use.
        batch_size: Batch size used by ``wds.batched``.
        num_workers: Requested worker count; passed through ``get_num_workers``.
        transform: Optional per-sample transform; skipped when ``None``.
        gpu_transform: Module handed to ``tops.DataPrefetcher``.
        infinite: Resample shards and repeat the pipeline for 1,000,000 epochs.
        shuffle: Enable both shuffle stages.
        partial_batches: Forwarded as ``partial`` to ``wds.batched``.
        sample_shuffle: Buffer size of the second shuffle stage.
        tar_shuffle: Buffer size of the first shuffle stage.
        channels_last: Forwarded to ``tops.DataPrefetcher``.

    Returns:
        The ``tops.DataPrefetcher`` wrapping the ``wds.WebLoader``.
    """
    import webdataset as wds
    import os
    # Presumably read by wds.split_by_node to discover the distributed
    # layout -- TODO confirm against the installed webdataset version.
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())
    if infinite:
        pipeline = [wds.ResampledShards(str(path))]
    else:
        pipeline = [wds.SimpleShardList(str(path))]
    if shuffle:
        pipeline.append(wds.shuffle(tar_shuffle))
    pipeline.extend([
        wds.split_by_node,
        wds.split_by_worker,
    ])
    if shuffle:
        # NOTE(review): this stage is added before tarfile_to_samples, so it
        # appears to shuffle the shard stream rather than individual samples,
        # despite the parameter name. Left as-is because moving it changes
        # memory use and data order -- confirm the intended placement.
        pipeline.append(wds.shuffle(sample_shuffle))
    pipeline.extend([
        wds.tarfile_to_samples(),
        wds.decode("torchrgb8"),
        # Presumably exposes the decoded "jpg" entry under the key "img".
        wds.rename_keys(["img", "jpg"], ["__key__", "__key__"]),
    ])
    if transform is not None:
        pipeline.append(wds.map(transform))
    pipeline.extend([
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
    ])
    pipeline = wds.DataPipeline(*pipeline)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)
    n_workers = get_num_workers(num_workers)
    loader = wds.WebLoader(
        # Batching already happens inside the pipeline, hence batch_size=None.
        pipeline, batch_size=None, shuffle=False,
        num_workers=n_workers,
        # DataLoader raises ValueError for persistent_workers=True when there
        # are no worker processes, so only enable it when workers exist.
        persistent_workers=n_workers != 0,
    )
    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
    return loader
def get_dataloader_celebAHQ_wds(
    path,
    batch_size: int,
    num_workers: int,
    transform: torch.nn.Module,
    gpu_transform: torch.nn.Module,
    infinite: bool,
    shuffle: bool,
    partial_batches: bool,
    sample_shuffle=10_000,
    tar_shuffle=100,
    channels_last=False,
):
    """Build a webdataset-based loader for CelebA-HQ shards, wrapped in ``tops.DataPrefetcher``.

    The pipeline is assembled in this order: shard source
    (``wds.ResampledShards`` when ``infinite`` else ``wds.SimpleShardList``),
    optional ``wds.shuffle(tar_shuffle)``, ``wds.split_by_node``,
    ``wds.split_by_worker``, optional ``wds.shuffle(sample_shuffle)``,
    ``wds.tarfile_to_samples()``, decoding of ``.png`` entries with
    ``png_decoder``, key renaming, optional ``wds.map(transform)`` and
    finally ``wds.batched``.

    Side effect: overwrites the ``RANK`` and ``WORLD_SIZE`` environment
    variables with the values reported by ``tops``.

    Args:
        path: Shard specification; converted with ``str()`` before use.
        batch_size: Batch size used by ``wds.batched``.
        num_workers: Requested worker count; passed through ``get_num_workers``.
        transform: Optional per-sample transform; skipped when ``None``.
        gpu_transform: Module handed to ``tops.DataPrefetcher``.
        infinite: Resample shards and repeat the pipeline for 1,000,000 epochs.
        shuffle: Enable both shuffle stages.
        partial_batches: Forwarded as ``partial`` to ``wds.batched``.
        sample_shuffle: Buffer size of the second shuffle stage.
        tar_shuffle: Buffer size of the first shuffle stage.
        channels_last: Forwarded to ``tops.DataPrefetcher``.

    Returns:
        The ``tops.DataPrefetcher`` wrapping the ``wds.WebLoader``.
    """
    import webdataset as wds
    import os
    # Presumably read by wds.split_by_node to discover the distributed
    # layout -- TODO confirm against the installed webdataset version.
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())
    if infinite:
        pipeline = [wds.ResampledShards(str(path))]
    else:
        pipeline = [wds.SimpleShardList(str(path))]
    if shuffle:
        pipeline.append(wds.shuffle(tar_shuffle))
    pipeline.extend([
        wds.split_by_node,
        wds.split_by_worker,
    ])
    if shuffle:
        # NOTE(review): this stage is added before tarfile_to_samples, so it
        # appears to shuffle the shard stream rather than individual samples,
        # despite the parameter name. Left as-is because moving it changes
        # memory use and data order -- confirm the intended placement.
        pipeline.append(wds.shuffle(sample_shuffle))
    pipeline.extend([
        wds.tarfile_to_samples(),
        wds.decode(wds.handle_extension(".png", png_decoder)),
        # Presumably exposes the decoded "png" entry under the key "img".
        wds.rename_keys(["img", "png"], ["__key__", "__key__"]),
    ])
    if transform is not None:
        pipeline.append(wds.map(transform))
    pipeline.extend([
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
    ])
    pipeline = wds.DataPipeline(*pipeline)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)
    n_workers = get_num_workers(num_workers)
    loader = wds.WebLoader(
        # Batching already happens inside the pipeline, hence batch_size=None.
        pipeline, batch_size=None, shuffle=False,
        num_workers=n_workers,
        # DataLoader raises ValueError for persistent_workers=True when there
        # are no worker processes, so only enable it when workers exist.
        persistent_workers=n_workers != 0,
    )
    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last)
    return loader