# Provenance (from the hosting page this file was copied from):
# uploaded by cheng-hust in commit e8861c0 ("Upload 91 files"); 694 bytes.
import torch
import torch.utils.data as data
from src.core import register
__all__ = ["DataLoader"]


@register
class DataLoader(data.DataLoader):
    """``torch.utils.data.DataLoader`` subclass registered via ``src.core.register``.

    Construction and iteration are inherited unchanged from the torch base
    class; this subclass only adds registry metadata and a readable ``repr``.
    """

    # Constructor arguments listed here are presumably resolved by the
    # ``register`` machinery (e.g. built from config) rather than passed as
    # plain values -- NOTE(review): confirm against src.core.
    __inject__ = ['dataset', 'collate_fn']

    def __repr__(self) -> str:
        """Return a multi-line summary of the loader's key settings.

        The first line is ``<ClassName>(``, then one line per setting in the
        form `` name: value``, and finally a closing ``)`` line.
        """
        shown = ('dataset', 'batch_size', 'num_workers', 'drop_last', 'collate_fn')
        lines = [type(self).__name__ + "("]
        lines.extend(" {0}: {1}".format(attr, getattr(self, attr)) for attr in shown)
        lines.append(")")
        return "\n".join(lines)
@register
def default_collate_fn(items):
    """Collate a list of samples into ``(batched_first_elements, second_elements)``.

    Each sample's first element is given a new leading axis (``[None]``) and
    the results are concatenated along that axis into one tensor. Each
    sample's second element is collected unchanged into a Python list in the
    same order.

    NOTE(review): presumably samples are ``(image_tensor, target)`` pairs with
    identically shaped images -- confirm against the dataset classes.
    """
    # Two separate passes over ``items`` (first elements, then second
    # elements), with the concatenation performed before the second pass.
    batched = torch.cat([sample[0][None] for sample in items], dim=0)
    targets = [sample[1] for sample in items]
    return batched, targets