import torch import torch.utils.data as data from src.core import register __all__ = ['DataLoader'] @register class DataLoader(data.DataLoader): __inject__ = ['dataset', 'collate_fn'] def __repr__(self) -> str: format_string = self.__class__.__name__ + "(" for n in ['dataset', 'batch_size', 'num_workers', 'drop_last', 'collate_fn']: format_string += "\n" format_string += " {0}: {1}".format(n, getattr(self, n)) format_string += "\n)" return format_string @register def default_collate_fn(items): '''default collate_fn ''' return torch.cat([x[0][None] for x in items], dim=0), [x[1] for x in items]