santit96's picture
Create the streamlit app that classifies the trash in an image into classes
fa84113
raw
history blame
3.25 kB
""" Dataset factory
Updated 2021 Wimlds in Detect Waste in Pomerania
"""
from collections import OrderedDict
from pathlib import Path
from .dataset_config import *
from .parsers import *
from .dataset import DetectionDatset
from .parsers import create_parser
# list of detect-waste datasets
# Dataset names (lower-case) that create_dataset() accepts in addition to the
# COCO variants. Keep a comma after every entry: adjacent string literals are
# implicitly concatenated, which previously merged 'icra' and 'drinkwaste'
# into a single bogus 'icradrinkwaste' entry.
waste_datasets_list = ['taco', 'detectwaste', 'binary', 'multi',
                       'uav', 'mju', 'trashcan', 'wade', 'icra',
                       'drinkwaste']
def _create_waste_cfg(name, root, ann):
    """Instantiate the detect-waste dataset config whose prefix matches ``name``.

    The chain is deliberately an if/return sequence so that only the config
    class actually requested is looked up (the classes come from a star
    import of ``dataset_config``).

    Raises:
        RuntimeError: If ``name`` matches none of the known prefixes.
    """
    if name.startswith('taco'):
        return TACOCfg(root=root, ann=ann)
    if name.startswith('detectwaste'):
        return DetectwasteCfg(root=root, ann=ann)
    if name.startswith('binary'):
        return BinaryCfg(root=root, ann=ann)
    if name.startswith('multi'):
        return BinaryMultiCfg(root=root, ann=ann)
    if name.startswith('uav'):
        return UAVVasteCfg(root=root, ann=ann)
    if name.startswith('trashcan'):
        return TrashCanCfg(root=root, ann=ann)
    if name.startswith('drinkwaste'):
        return DrinkWasteCfg(root=root, ann=ann)
    if name.startswith('mju'):
        return MJU_WasteCfg(root=root, ann=ann)
    if name.startswith('wade'):
        return WadeCfg(root=root, ann=ann)
    if name.startswith('icra'):
        return ICRACfg(root=root, ann=ann)
    raise RuntimeError(f'Unknown dataset parser ({name})')


def create_dataset(name, root, ann, splits=('train', 'val')):
    """Build a ``DetectionDatset`` for each requested split.

    Args:
        name: Dataset identifier (case-insensitive). Names starting with
            ``'coco'`` select ``Coco2014Cfg`` if they contain ``'coco2014'``
            and ``Coco2017Cfg`` otherwise; any other name must appear in
            ``waste_datasets_list``.
        root: Dataset root directory; converted to ``pathlib.Path``.
        ann: Annotation argument forwarded to the detect-waste config
            constructors; not used for COCO.
        splits: A single split name or an iterable of split names.

    Returns:
        A list of datasets (in split order) when more than one distinct split
        was requested, otherwise the single dataset itself.

    Raises:
        RuntimeError: If ``name`` is not a known dataset, or a requested split
            is not present in the dataset config.
        IndexError: If ``splits`` is empty.
    """
    if isinstance(splits, str):
        splits = (splits,)
    name = name.lower()
    root = Path(root)

    is_coco = name.startswith('coco')
    if is_coco:
        dataset_cfg = Coco2014Cfg() if 'coco2014' in name else Coco2017Cfg()
    elif name in waste_datasets_list:
        dataset_cfg = _create_waste_cfg(name, root, ann)
        dataset_cfg.add_split()
    else:
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        raise RuntimeError(f'Unknown dataset parser ({name})')

    datasets = OrderedDict()
    for s in splits:
        if s not in dataset_cfg.splits:
            raise RuntimeError(f'{s} split not found in config')
        split_cfg = dataset_cfg.splits[s]
        ann_file = split_cfg['ann_filename']
        img_dir = split_cfg['img_dir']
        if is_coco:
            # COCO split paths are joined onto ``root`` here; the waste
            # configs were constructed with ``root`` and their split paths
            # are used as-is.
            ann_file = root / ann_file
            img_dir = root / Path(img_dir)
        parser_cfg = CocoParserCfg(
            ann_filename=ann_file,
            has_labels=split_cfg['has_labels'],
        )
        datasets[s] = DetectionDatset(
            data_dir=img_dir,
            parser=create_parser(dataset_cfg.parser, cfg=parser_cfg),
        )

    datasets = list(datasets.values())
    return datasets if len(datasets) > 1 else datasets[0]