File size: 694 Bytes
e8861c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
import torch.utils.data as data
from src.core import register
# Names re-exported by ``from <this module> import *``. ``default_collate_fn``
# is intentionally not listed here, so star-imports will not bring it in.
__all__ = [
    "DataLoader",
]
@register
class DataLoader(data.DataLoader):
    """``torch.utils.data.DataLoader`` subclass registered via ``register``.

    Behaviour is inherited unchanged from the PyTorch base class; this
    subclass only adds a multi-line ``__repr__`` summarising the loader's
    main settings.
    """

    # Constructor arguments listed for injection. Presumably resolved from
    # config by the ``register`` machinery in ``src.core`` -- TODO confirm.
    __inject__ = ['dataset', 'collate_fn']

    def __repr__(self) -> str:
        """Return ``ClassName(`` followed by one ``name: value`` line per setting."""
        header = self.__class__.__name__ + "("
        shown = ('dataset', 'batch_size', 'num_workers', 'drop_last', 'collate_fn')
        rows = [f" {attr}: {getattr(self, attr)}" for attr in shown]
        return header + "\n" + "\n".join(rows) + "\n)"
@register
def default_collate_fn(items):
    """Collate a batch of samples into ``(stacked_first, list_of_second)``.

    Each sample is indexed positionally. The ``sample[0]`` values are stacked
    along a new leading dimension, and the ``sample[1]`` values are gathered,
    unchanged, into a Python list. Presumably each sample is an
    ``(image, target)`` pair -- confirm against the dataset.

    Args:
        items: Sequence of indexable samples. It is iterated twice, so pass a
            list rather than a one-shot iterator.

    Returns:
        A tuple ``(batch, targets)``. ``batch`` is a tensor whose first
        dimension equals ``len(items)``; ``targets`` is a list of the second
        elements.

    ``torch.stack`` raises if ``items`` is empty or if the first elements
    differ in shape.
    """
    # Stack first, then collect targets, keeping the original evaluation order.
    batch = torch.stack([sample[0] for sample in items], dim=0)
    targets = [sample[1] for sample in items]
    return batch, targets
|