Spaces:
Runtime error
Runtime error
File size: 2,487 Bytes
2df809d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from .utils.transforms import *
from .base.batched_sampler import BatchedRandomSampler # noqa
from .arkitscenes import ARKitScenes_Multi # noqa
from .arkitscenes_highres import ARKitScenesHighRes_Multi
from .bedlam import BEDLAM_Multi
from .blendedmvs import BlendedMVS_Multi # noqa
from .co3d import Co3d_Multi # noqa
from .cop3d import Cop3D_Multi
from .dl3dv import DL3DV_Multi
from .dynamic_replica import DynamicReplica
from .eden import EDEN_Multi
from .hypersim import HyperSim_Multi
from .hoi4d import HOI4D_Multi
from .irs import IRS
from .mapfree import MapFree_Multi
from .megadepth import MegaDepth_Multi # noqa
from .mp3d import MP3D_Multi
from .mvimgnet import MVImgNet_Multi
from .mvs_synth import MVS_Synth_Multi
from .omniobject3d import OmniObject3D_Multi
from .pointodyssey import PointOdyssey_Multi
from .realestate10k import RE10K_Multi
from .scannet import ScanNet_Multi
from .scannetpp import ScanNetpp_Multi # noqa
from .smartportraits import SmartPortraits_Multi
from .spring import Spring
from .synscapes import SynScapes
from .tartanair import TartanAir_Multi
from .threedkb import ThreeDKenBurns
from .uasol import UASOL_Multi
from .urbansyn import UrbanSyn
from .unreal4k import UnReal4K_Multi
from .vkitti2 import VirtualKITTI2_Multi # noqa
from .waymo import Waymo_Multi # noqa
from .wildrgbd import WildRGBD_Multi # noqa
from accelerate import Accelerator
def get_data_loader(
    dataset,
    batch_size,
    num_workers=8,
    shuffle=True,
    drop_last=True,
    pin_mem=True,
    accelerator: "Accelerator" = None,
    fixed_length=False,
):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset``.

    If the dataset exposes ``make_sampler``, the batch sampler it returns
    drives batching; ``shuffle``, ``drop_last``, the world size and
    ``fixed_length`` are forwarded to it. Otherwise, e.g. the attribute is
    missing or it raises ``NotImplementedError``, a plain DataLoader using
    ``batch_size`` / ``shuffle`` / ``drop_last`` is returned.

    Args:
        dataset: A dataset object, or a string expression that is ``eval``-ed
            in this module's namespace (where the dataset classes above are
            imported) to construct one.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        shuffle: Whether to shuffle samples.
        drop_last: Whether to drop the final incomplete batch.
        pin_mem: Forwarded to ``DataLoader(pin_memory=...)``.
        accelerator: Optional ``accelerate.Accelerator``. Its
            ``num_processes`` is passed to ``make_sampler`` as ``world_size``;
            ``None`` is treated as a single process (``world_size=1``).
        fixed_length: Forwarded to ``make_sampler``.

    Returns:
        The constructed ``torch.utils.data.DataLoader``.
    """
    import torch

    if isinstance(dataset, str):
        # SECURITY: eval executes arbitrary code; only pass trusted config strings.
        dataset = eval(dataset)

    # Resolve the world size outside the try below: reading
    # `accelerator.num_processes` there would let accelerator=None raise an
    # AttributeError that is mistaken for "dataset has no make_sampler".
    world_size = accelerator.num_processes if accelerator is not None else 1

    # Keep the try minimal so unrelated AttributeErrors are not swallowed.
    try:
        sampler = dataset.make_sampler(
            batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            world_size=world_size,
            fixed_length=fixed_length,
        )
    except (AttributeError, NotImplementedError):
        # The dataset provides no custom sampler (or declines to build one).
        sampler = None

    if sampler is not None:
        # batch_sampler is mutually exclusive with batch_size/shuffle/drop_last,
        # so those are delegated to the sampler instead.
        return torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            num_workers=num_workers,
            pin_memory=pin_mem,
        )

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )
|