MotionLCM / mld /data /base.py
wxDai's picture
[Init]
eb339cb
raw
history blame
2.42 kB
import copy
from os.path import join as pjoin
from typing import Any, Callable
from torch.utils.data import DataLoader
class BaseDataModule:
    """Base class that lazily builds datasets and wraps them in DataLoaders.

    The methods below read ``self.cfg``, ``self.name``, ``self.hparams`` and
    ``self.Dataset``. None of them is assigned in this class, so they are
    presumably provided by subclasses before any dataset is requested.

    Datasets are exposed as virtual ``<stage>_dataset`` attributes (for
    example ``train_dataset``); each one is created on first access and
    cached on the instance under ``_<stage>_dataset``.
    """

    def __init__(self, collate_fn: Callable) -> None:
        """Store the collate function handed to every DataLoader.

        Args:
            collate_fn: Callable passed as ``collate_fn`` to ``DataLoader``.
        """
        super().__init__()
        self.collate_fn = collate_fn
        # When True, test_dataloader() forces a batch size of 1.
        # NOTE(review): "mm" presumably means multimodality evaluation - confirm.
        self.is_mm = False

    def _split_file_for(self, stage: str) -> str:
        """Return the path ``<dataset ROOT>/<cfg.<STAGE>.SPLIT>.txt``.

        Plain ``getattr`` lookups are used instead of ``eval`` so that config
        keys are never interpreted as Python source.

        Args:
            stage: Config section name (case-insensitive), e.g. ``"train"``.
        """
        root = getattr(self.cfg.DATASET, self.name.upper()).ROOT
        split = getattr(self.cfg, stage.upper()).SPLIT
        return pjoin(root, split + ".txt")

    def get_sample_set(self, overrides: dict) -> Any:
        """Build a fresh, uncached dataset on the TEST split.

        Args:
            overrides: Keyword arguments that replace or extend a deep copy
                of ``self.hparams``; ``self.hparams`` itself is not mutated.

        Returns:
            Whatever ``self.Dataset`` returns for the merged parameters.
        """
        sample_params = copy.deepcopy(self.hparams)
        sample_params.update(overrides)
        split_file = self._split_file_for("TEST")
        return self.Dataset(split_file=split_file, **sample_params)

    def __getattr__(self, item: str) -> Any:
        """Resolve ``<stage>_dataset`` attributes lazily.

        Only invoked when normal attribute lookup fails. A public name ending
        in ``_dataset`` triggers construction of ``self.Dataset`` for the
        split configured under ``cfg.<STAGE>.SPLIT``; the result is cached in
        the instance ``__dict__`` so later accesses reuse the same object.

        Raises:
            AttributeError: For any other missing attribute, including
                private names such as ``_train_dataset`` before it is cached.
        """
        suffix = "_dataset"
        if item.endswith(suffix) and not item.startswith("_"):
            cache_key = "_" + item
            if cache_key not in self.__dict__:
                split_file = self._split_file_for(item[:-len(suffix)])
                self.__dict__[cache_key] = self.Dataset(
                    split_file=split_file, **self.hparams
                )
            return self.__dict__[cache_key]
        classname = self.__class__.__name__
        raise AttributeError(f"'{classname}' object has no attribute '{item}'")

    def get_dataloader_options(self, stage: str) -> dict:
        """Collect the DataLoader keyword arguments for one config section.

        Args:
            stage: Config section name (case-insensitive), e.g. ``"TRAIN"``.

        Returns:
            Dict with ``batch_size``, ``num_workers``, ``collate_fn`` and
            ``persistent_workers`` taken from ``cfg.<STAGE>`` and this
            instance's ``collate_fn``.
        """
        stage_args = getattr(self.cfg, stage.upper())
        return {
            "batch_size": stage_args.BATCH_SIZE,
            "num_workers": stage_args.NUM_WORKERS,
            "collate_fn": self.collate_fn,
            "persistent_workers": stage_args.PERSISTENT_WORKERS,
        }

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling DataLoader over ``self.train_dataset``."""
        dataloader_options = self.get_dataloader_options('TRAIN')
        return DataLoader(self.train_dataset, shuffle=True, **dataloader_options)

    def val_dataloader(self) -> DataLoader:
        """Return a non-shuffling DataLoader over ``self.val_dataset``."""
        dataloader_options = self.get_dataloader_options('VAL')
        return DataLoader(self.val_dataset, shuffle=False, **dataloader_options)

    def test_dataloader(self) -> DataLoader:
        """Return a non-shuffling DataLoader over ``self.test_dataset``.

        The batch size is 1 when ``self.is_mm`` is set, otherwise
        ``cfg.TEST.BATCH_SIZE``.
        """
        dataloader_options = self.get_dataloader_options('TEST')
        dataloader_options["batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
        return DataLoader(self.test_dataset, shuffle=False, **dataloader_options)