Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
import os.path as osp | |
import numpy as np | |
from mmengine.fileio import dump, load | |
from mmengine.utils import mkdir_or_exist, track_parallel_progress | |
prog_description = '''K-Fold coco split. | |
To split coco data for semi-supervised object detection: | |
python tools/misc/split_coco.py | |
''' | |
def parse_args(): | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
'--data-root', | |
type=str, | |
help='The data root of coco dataset.', | |
default='./data/coco/') | |
parser.add_argument( | |
'--out-dir', | |
type=str, | |
help='The output directory of coco semi-supervised annotations.', | |
default='./data/coco/semi_anns/') | |
parser.add_argument( | |
'--labeled-percent', | |
type=float, | |
nargs='+', | |
help='The percentage of labeled data in the training set.', | |
default=[1, 2, 5, 10]) | |
parser.add_argument( | |
'--fold', | |
type=int, | |
help='K-fold cross validation for semi-supervised object detection.', | |
default=5) | |
args = parser.parse_args() | |
return args | |
def split_coco(data_root, out_dir, percent, fold): | |
"""Split COCO data for Semi-supervised object detection. | |
Args: | |
data_root (str): The data root of coco dataset. | |
out_dir (str): The output directory of coco semi-supervised | |
annotations. | |
percent (float): The percentage of labeled data in the training set. | |
fold (int): The fold of dataset and set as random seed for data split. | |
""" | |
def save_anns(name, images, annotations): | |
sub_anns = dict() | |
sub_anns['images'] = images | |
sub_anns['annotations'] = annotations | |
sub_anns['licenses'] = anns['licenses'] | |
sub_anns['categories'] = anns['categories'] | |
sub_anns['info'] = anns['info'] | |
mkdir_or_exist(out_dir) | |
dump(sub_anns, f'{out_dir}/{name}.json') | |
# set random seed with the fold | |
np.random.seed(fold) | |
ann_file = osp.join(data_root, 'annotations/instances_train2017.json') | |
anns = load(ann_file) | |
image_list = anns['images'] | |
labeled_total = int(percent / 100. * len(image_list)) | |
labeled_inds = set( | |
np.random.choice(range(len(image_list)), size=labeled_total)) | |
labeled_ids, labeled_images, unlabeled_images = [], [], [] | |
for i in range(len(image_list)): | |
if i in labeled_inds: | |
labeled_images.append(image_list[i]) | |
labeled_ids.append(image_list[i]['id']) | |
else: | |
unlabeled_images.append(image_list[i]) | |
# get all annotations of labeled images | |
labeled_ids = set(labeled_ids) | |
labeled_annotations, unlabeled_annotations = [], [] | |
for ann in anns['annotations']: | |
if ann['image_id'] in labeled_ids: | |
labeled_annotations.append(ann) | |
else: | |
unlabeled_annotations.append(ann) | |
# save labeled and unlabeled | |
labeled_name = f'instances_train2017.{fold}@{percent}' | |
unlabeled_name = f'instances_train2017.{fold}@{percent}-unlabeled' | |
save_anns(labeled_name, labeled_images, labeled_annotations) | |
save_anns(unlabeled_name, unlabeled_images, unlabeled_annotations) | |
def multi_wrapper(args): | |
return split_coco(*args) | |
if __name__ == '__main__': | |
args = parse_args() | |
arguments_list = [(args.data_root, args.out_dir, p, f) | |
for f in range(1, args.fold + 1) | |
for p in args.labeled_percent] | |
track_parallel_progress(multi_wrapper, arguments_list, args.fold) | |