Spaces:
Runtime error
Runtime error
File size: 1,199 Bytes
97a6728 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import torch
import tops
from .utils import collate_fn
def get_dataloader(
        dataset, gpu_transform: torch.nn.Module,
        num_workers,
        batch_size,
        infinite: bool,
        drop_last: bool,
        prefetch_factor: int,
        shuffle,
        channels_last=False
):
    """Build a ``DataLoader`` for ``dataset`` and wrap it in ``tops.DataPrefetcher``.

    Sampler selection:
      * ``infinite=True``: a ``tops.InfiniteSampler`` parameterised with the
        current rank / world size. Shuffling is delegated to the sampler and
        ``drop_last`` is not forwarded to the DataLoader.
      * ``tops.world_size() > 1``: a ``torch.utils.data.DistributedSampler``.
        Shuffling is delegated to the sampler; ``drop_last`` is forwarded.
      * otherwise: no sampler. ``shuffle`` and ``drop_last`` are passed to the
        DataLoader directly.

    Args:
        dataset: Dataset to load from (must be usable with the samplers above).
        gpu_transform: Module handed to ``tops.DataPrefetcher``
            (presumably applied to batches after transfer — confirm in tops).
        num_workers: Number of DataLoader worker processes; ``0`` loads in the
            main process.
        batch_size: Samples per batch.
        infinite: Use ``tops.InfiniteSampler`` instead of epoch-based sampling.
        drop_last: Drop the final incomplete batch (ignored when ``infinite``).
        prefetch_factor: Batches prefetched per worker. Forwarded only when
            ``num_workers > 0`` and it is not ``None``.
        shuffle: Whether to shuffle sample order.
        channels_last: Forwarded to ``tops.DataPrefetcher``.

    Returns:
        The ``tops.DataPrefetcher`` wrapping the constructed ``DataLoader``.
    """
    sampler = None
    dl_kwargs = dict(
        pin_memory=True,
    )
    if infinite:
        sampler = tops.InfiniteSampler(
            dataset, rank=tops.rank(),
            num_replicas=tops.world_size(),
            shuffle=shuffle
        )
    elif tops.world_size() > 1:
        # The DataLoader must not also receive `shuffle` when a sampler is set;
        # the DistributedSampler handles shuffling per rank.
        sampler = torch.utils.data.DistributedSampler(
            dataset, shuffle=shuffle, num_replicas=tops.world_size(), rank=tops.rank())
        dl_kwargs["drop_last"] = drop_last
    else:
        dl_kwargs["shuffle"] = shuffle
        dl_kwargs["drop_last"] = drop_last

    # PyTorch's DataLoader raises ValueError when prefetch_factor is given
    # explicitly while num_workers == 0 (there are no workers to prefetch),
    # so only forward it for multi-process loading. `None` keeps the default.
    if num_workers > 0 and prefetch_factor is not None:
        dl_kwargs["prefetch_factor"] = prefetch_factor

    dataloader = torch.utils.data.DataLoader(
        dataset, sampler=sampler, collate_fn=collate_fn,
        batch_size=batch_size,
        num_workers=num_workers,
        **dl_kwargs
    )
    dataloader = tops.DataPrefetcher(dataloader, gpu_transform, channels_last=channels_last)
    return dataloader
|