""" Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import json import logging import os import time from multiprocessing import Pool import numpy as np import requests import tqdm from lavis.common.utils import cleanup_dir, get_abs_path, get_cache_path from omegaconf import OmegaConf header_mzl = { "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36", # "User-Agent": "Googlebot-Image/1.0", # Pretend to be googlebot # "X-Forwarded-For": "64.18.15.200", } header_gbot = { "User-Agent": "Googlebot-Image/1.0", # Pretend to be googlebot } headers = [header_mzl, header_gbot] # Setup logging.basicConfig(filename="download_nocaps.log", filemode="w", level=logging.INFO) requests.packages.urllib3.disable_warnings( requests.packages.urllib3.exceptions.InsecureRequestWarning ) def download_file(url, filename): max_retries = 20 cur_retries = 0 header = headers[0] while cur_retries < max_retries: try: r = requests.get(url, headers=header, timeout=10) with open(filename, "wb") as f: f.write(r.content) break except Exception as e: logging.info(" ".join(repr(e).splitlines())) logging.error(url) cur_retries += 1 # random sample a header from headers header = headers[np.random.randint(0, len(headers))] time.sleep(3 + cur_retries * 2) def download_image_from_url_val(url): basename = os.path.basename(url) filename = os.path.join(storage_dir, "val", basename) download_file(url, filename) def download_image_from_url_test(url): basename = os.path.basename(url) filename = os.path.join(storage_dir, "test", basename) download_file(url, filename) if __name__ == "__main__": os.makedirs("tmp", exist_ok=True) # storage dir config_path = get_abs_path("configs/datasets/nocaps/defaults.yaml") storage_dir = OmegaConf.load(config_path).datasets.nocaps.build_info.images.storage storage_dir = get_cache_path(storage_dir) # make sure the storage dir exists os.makedirs(storage_dir, exist_ok=True) print("Storage dir:", storage_dir) # make sure the storage dir for val and test exists os.makedirs(os.path.join(storage_dir, "val"), exist_ok=True) os.makedirs(os.path.join(storage_dir, "test"), exist_ok=True) # download annotations val_url = "https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json" tst_url = "https://s3.amazonaws.com/nocaps/nocaps_test_image_info.json" print("Downloading validation annotations from %s" % val_url) download_file(val_url, "tmp/nocaps_val_ann.json") print("Downloading testing annotations from %s" % tst_url) download_file(tst_url, "tmp/nocaps_tst_ann.json") # open annotations val_ann = json.load(open("tmp/nocaps_val_ann.json")) tst_ann = json.load(open("tmp/nocaps_tst_ann.json")) # collect image urls val_info = val_ann["images"] tst_info = tst_ann["images"] val_urls = [info["coco_url"] for info in val_info] tst_urls = [info["coco_url"] for info in tst_info] # setup multiprocessing # large n_procs possibly causes server to reject requests n_procs = 16 with Pool(n_procs) as pool: print("Downloading validation images...") list( tqdm.tqdm( pool.imap(download_image_from_url_val, val_urls), total=len(val_urls) ) ) with Pool(n_procs) as pool: print("Downloading test images...") list( tqdm.tqdm( pool.imap(download_image_from_url_test, tst_urls), total=len(tst_urls) ) ) # clean tmp cleanup_dir("tmp")