Spaces:
Sleeping
Sleeping
""" Dataset factory

Updated 2021 by WiMLDS as part of the Detect Waste in Pomerania project.
"""
from collections import OrderedDict | |
from pathlib import Path | |
from .dataset_config import * | |
from .parsers import * | |
from .dataset import DetectionDatset | |
from .parsers import create_parser | |
# Names of the detect-waste datasets accepted by create_dataset().
# Every element needs its own trailing comma: adjacent string literals are
# silently concatenated by Python ('icra' 'drinkwaste' -> 'icradrinkwaste').
waste_datasets_list = ['taco', 'detectwaste', 'binary', 'multi',
                       'uav', 'mju', 'trashcan', 'wade', 'icra',
                       'drinkwaste']
def create_dataset(name, root, ann, splits=('train', 'val')):
    """Build detection dataset(s) for the requested splits.

    Args:
        name: Dataset name, matched case-insensitively. Names starting with
            'coco' use ``Coco2014Cfg`` when the name contains 'coco2014' and
            ``Coco2017Cfg`` otherwise. Names listed in ``waste_datasets_list``
            use the matching detect-waste config class.
        root: Dataset root directory. For COCO, the annotation file and image
            directory of each split are resolved relative to it; for the
            detect-waste datasets it is forwarded to the config constructor.
        ann: Forwarded to the detect-waste config constructor (presumably the
            annotation file location -- confirm against dataset_config).
            Not used for COCO datasets.
        splits: A single split name or an iterable of split names.

    Returns:
        The dataset itself when exactly one split was built, otherwise a list
        of datasets in the order the splits were requested.

    Raises:
        RuntimeError: If a requested split is not defined in the dataset
            config.
        AssertionError: If ``name`` does not identify a known dataset.
        IndexError: If ``splits`` is empty.
    """
    if isinstance(splits, str):
        splits = (splits,)
    name = name.lower()
    root = Path(root)
    dataset_cls = DetectionDatset

    datasets = OrderedDict()
    if name.startswith('coco'):
        if 'coco2014' in name:
            dataset_cfg = Coco2014Cfg()
        else:
            dataset_cfg = Coco2017Cfg()
        for s in splits:
            if s not in dataset_cfg.splits:
                raise RuntimeError(f'{s} split not found in config')
            split_cfg = dataset_cfg.splits[s]
            ann_file = root / split_cfg['ann_filename']
            parser_cfg = CocoParserCfg(
                ann_filename=ann_file,
                has_labels=split_cfg['has_labels'],
            )
            datasets[s] = dataset_cls(
                data_dir=root / Path(split_cfg['img_dir']),
                parser=create_parser(dataset_cfg.parser, cfg=parser_cfg),
            )
        # NOTE: `datasets` must not be re-initialised here; doing so throws
        # away the COCO datasets that were just built.
    elif name in waste_datasets_list:
        if name.startswith('taco'):
            dataset_cfg = TACOCfg(root=root, ann=ann)
        elif name.startswith('detectwaste'):
            dataset_cfg = DetectwasteCfg(root=root, ann=ann)
        elif name.startswith('binary'):
            dataset_cfg = BinaryCfg(root=root, ann=ann)
        elif name.startswith('multi'):
            dataset_cfg = BinaryMultiCfg(root=root, ann=ann)
        elif name.startswith('uav'):
            dataset_cfg = UAVVasteCfg(root=root, ann=ann)
        elif name.startswith('trashcan'):
            dataset_cfg = TrashCanCfg(root=root, ann=ann)
        elif name.startswith('drinkwaste'):
            dataset_cfg = DrinkWasteCfg(root=root, ann=ann)
        elif name.startswith('mju'):
            dataset_cfg = MJU_WasteCfg(root=root, ann=ann)
        elif name.startswith('wade'):
            dataset_cfg = WadeCfg(root=root, ann=ann)
        elif name.startswith('icra'):
            dataset_cfg = ICRACfg(root=root, ann=ann)
        else:
            # Explicit raise instead of `assert False`: asserts are stripped
            # under `python -O`, which would leave dataset_cfg unbound.
            raise AssertionError(f'Unknown dataset parser ({name})')

        # The split table is populated on demand before it is looked up.
        dataset_cfg.add_split()
        for s in splits:
            if s not in dataset_cfg.splits:
                raise RuntimeError(f'{s} split not found in config')
            split_cfg = dataset_cfg.splits[s]
            parser_cfg = CocoParserCfg(
                ann_filename=split_cfg['ann_filename'],
                has_labels=split_cfg['has_labels'],
            )
            datasets[s] = dataset_cls(
                data_dir=split_cfg['img_dir'],
                parser=create_parser(dataset_cfg.parser, cfg=parser_cfg),
            )
    else:
        # Same exception type as the original assert, but survives `-O`.
        raise AssertionError(f'Unknown dataset parser ({name})')

    datasets = list(datasets.values())
    return datasets if len(datasets) > 1 else datasets[0]