haakohu's picture
:)
548d634
raw
history blame
4.45 kB
import io
import torch
import tops
from .utils import collate_fn, jpg_decoder, get_num_workers, png_decoder
def get_dataloader(
        dataset, gpu_transform: torch.nn.Module,
        num_workers,
        batch_size,
        infinite: bool,
        drop_last: bool,
        prefetch_factor: int,
        shuffle,
        channels_last=False
        ):
    """Wrap `dataset` in a DataLoader feeding a GPU-side prefetcher.

    The sampler is chosen from `infinite` and the distributed world size:
    an endless tops.InfiniteSampler, a DistributedSampler for multi-rank
    runs, or no sampler at all for plain single-process loading. Batches
    are handed to `gpu_transform` through tops.DataPrefetcher.
    """
    extra_kwargs = {"pin_memory": True}
    if infinite:
        # Endless sampler: drop_last never applies, so only pin_memory
        # is forwarded to the DataLoader in this branch.
        sampler = tops.InfiniteSampler(
            dataset, rank=tops.rank(),
            num_replicas=tops.world_size(),
            shuffle=shuffle)
    elif tops.world_size() > 1:
        sampler = torch.utils.data.DistributedSampler(
            dataset, shuffle=shuffle,
            num_replicas=tops.world_size(), rank=tops.rank())
        extra_kwargs["drop_last"] = drop_last
    else:
        # Single process: let the DataLoader shuffle itself.
        sampler = None
        extra_kwargs.update(shuffle=shuffle, drop_last=drop_last)
    loader = torch.utils.data.DataLoader(
        dataset, sampler=sampler, collate_fn=collate_fn,
        batch_size=batch_size,
        num_workers=num_workers, prefetch_factor=prefetch_factor,
        **extra_kwargs)
    return tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last)
def get_dataloader_places2_wds(
        path,
        batch_size: int,
        num_workers: int,
        transform: torch.nn.Module,
        gpu_transform: torch.nn.Module,
        infinite: bool,
        shuffle: bool,
        partial_batches: bool,
        sample_shuffle=10_000,
        tar_shuffle=100,
        channels_last=False,
        ):
    """Build a WebDataset pipeline over Places2 tar shards, prefetched to GPU.

    Samples are decoded with webdataset's "torchrgb8" decoder, optionally
    transformed on CPU via `transform`, batched with `collate_fn`, and then
    streamed through tops.DataPrefetcher which applies `gpu_transform`.
    """
    import webdataset as wds
    import os
    # wds.split_by_node consults RANK/WORLD_SIZE to shard tars across ranks.
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())
    shard_source = (
        wds.ResampledShards(str(path)) if infinite
        else wds.SimpleShardList(str(path)))
    stages = [shard_source]
    if shuffle:
        # Shuffle at the tar level before splitting across nodes/workers.
        stages.append(wds.shuffle(tar_shuffle))
    stages += [wds.split_by_node, wds.split_by_worker]
    if shuffle:
        # Second, finer-grained shuffle at the sample level.
        stages.append(wds.shuffle(sample_shuffle))
    stages += [
        wds.tarfile_to_samples(),
        wds.decode("torchrgb8"),
        wds.rename_keys(["img", "jpg"], ["__key__", "__key__"]),
    ]
    if transform is not None:
        stages.append(wds.map(transform))
    stages.append(
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches))
    pipeline = wds.DataPipeline(*stages)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)
    loader = wds.WebLoader(
        pipeline, batch_size=None, shuffle=False,
        num_workers=get_num_workers(num_workers),
        persistent_workers=True,
    )
    # NOTE(review): to_float=False here but not in the celebA-HQ loader —
    # presumably torchrgb8 already yields uint8 tensors; confirm intent.
    return tops.DataPrefetcher(
        loader, gpu_transform, channels_last=channels_last, to_float=False)
def get_dataloader_celebAHQ_wds(
        path,
        batch_size: int,
        num_workers: int,
        transform: torch.nn.Module,
        gpu_transform: torch.nn.Module,
        infinite: bool,
        shuffle: bool,
        partial_batches: bool,
        sample_shuffle=10_000,
        tar_shuffle=100,
        channels_last=False,
        ):
    """Build a WebDataset pipeline over CelebA-HQ tar shards, prefetched to GPU.

    PNG entries are decoded with the project's `png_decoder`, optionally
    transformed on CPU via `transform`, batched with `collate_fn`, and then
    streamed through tops.DataPrefetcher which applies `gpu_transform`.
    """
    import webdataset as wds
    import os
    # wds.split_by_node consults RANK/WORLD_SIZE to shard tars across ranks.
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())
    shard_source = (
        wds.ResampledShards(str(path)) if infinite
        else wds.SimpleShardList(str(path)))
    stages = [shard_source]
    if shuffle:
        # Shuffle at the tar level before splitting across nodes/workers.
        stages.append(wds.shuffle(tar_shuffle))
    stages += [wds.split_by_node, wds.split_by_worker]
    if shuffle:
        # Second, finer-grained shuffle at the sample level.
        stages.append(wds.shuffle(sample_shuffle))
    stages += [
        wds.tarfile_to_samples(),
        wds.decode(wds.handle_extension(".png", png_decoder)),
        wds.rename_keys(["img", "png"], ["__key__", "__key__"]),
    ]
    if transform is not None:
        stages.append(wds.map(transform))
    stages.append(
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches))
    pipeline = wds.DataPipeline(*stages)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)
    loader = wds.WebLoader(
        pipeline, batch_size=None, shuffle=False,
        num_workers=get_num_workers(num_workers),
        persistent_workers=True,
    )
    return tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last)