|
import os, shutil, json
|
|
import pickle, argparse
|
|
|
|
"""Unzip the data and and save it as a pickle file."""
|
|
|
|
def make_pkl(data_dir, dataset_json, train_flag=False):
|
|
coco_data_list = []
|
|
for i, data in enumerate(dataset_json['annotations']):
|
|
image_id = data['image_id']
|
|
caption = data['caption']
|
|
for img in dataset_json['images']:
|
|
if img['id'] == image_id:
|
|
image_url = img['coco_url']
|
|
file_name = img['file_name']
|
|
break
|
|
coco_data_list.append({'image_id': image_id,'image_url': image_url, 'file_name': file_name, 'caption': caption})
|
|
if train_flag:
|
|
with open(os.path.join(data_dir, f'coco_train.pkl'), 'wb') as f:
|
|
pickle.dump(coco_data_list, f)
|
|
else:
|
|
with open(os.path.join(data_dir, f'coco_val.pkl'), 'wb') as f:
|
|
pickle.dump(coco_data_list, f)
|
|
|
|
|
|
def main(coco_path, data_dir):
|
|
coco_dir = os.path.dirname(coco_path)
|
|
|
|
with open(os.path.join(coco_dir, 'annotations/captions_train2017.json')) as f:
|
|
coco_train_dataset = json.load(f)
|
|
with open(os.path.join(coco_dir, 'annotations/captions_val2017.json')) as f:
|
|
coco_val_dataset = json.load(f)
|
|
make_pkl(data_dir, coco_train_dataset, train_flag=True)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--coco_path', type=str, default='coco.zip')
|
|
parser.add_argument('--data_dir', type=str, default='data')
|
|
args = parser.parse_args()
|
|
main(args.coco_path, args.data_dir)
|
|
|