"""Helpers for downloading task datasets from a CVAT server as ZIP files."""
import base64
import contextlib
import os
import shutil
import sys
import time
import warnings
from pathlib import Path

import requests
from tqdm.auto import tqdm


class CVATDataset:
    """Download the datasets linked to a list of CVAT tasks as local ZIP files."""

    # Seconds to wait before polling again while the server is still
    # preparing the export (HTTP 202).
    POLL_INTERVAL = 10
    # Seconds to wait before retrying after any other non-200 response.
    ERROR_INTERVAL = 30
    # Status codes that cannot succeed by simply retrying; the original loop
    # would spin on these forever, so fail fast instead.
    _FATAL_STATUS_CODES = frozenset({400, 401, 403, 404})
    # Class-level defaults so instances created without this __init__
    # (e.g. subclasses that do not call super().__init__) still work.
    max_retries = None
    timeout = None

    def __init__(self, cvat_url, org, task_ids, headers=None, params=None,
                 names=None, dest_folder=None, max_retries=None, timeout=None):
        """
        Configure a downloader for the datasets of several CVAT tasks.

        Args:
            cvat_url (str): CVAT base URL where the server is reachable.
            org (str): organization we are working with, e.g.: 'bulow'.
            task_ids (list): task IDs inside CVAT.
            headers (dict, optional): HTTP headers sent with every request.
                When omitted, the Authorization header is built from the
                environment (see ``_default_headers``).
            params (dict, optional): query parameters. Defaults to a
                "COCO 1.0" download of the dataset for ``org``.
            names (dict, optional): maps each task id to the label used in
                the local file name ``dataset_{label}.zip``. Must contain a
                key for every id in ``task_ids``.
            dest_folder (str or Path, optional): folder where the ZIP files
                are written; created if missing. Defaults to the current
                working directory.
            max_retries (int, optional): maximum number of re-requests per
                task before giving up. ``None`` (default) retries forever.
            timeout (float or tuple, optional): ``requests`` timeout for each
                HTTP request. ``None`` (default) waits indefinitely.

        Raises:
            ValueError: if ``names`` lacks a key for some task id.
        """
        self.cvat_url = cvat_url
        self.org = org
        self.task_ids = task_ids
        self.dest_folder = dest_folder
        self.max_retries = max_retries
        self.timeout = timeout

        self.names_dict = names
        if self.names_dict is not None:
            # Explicit check instead of `assert`, which is stripped under -O.
            if not all(id_ in self.names_dict for id_ in self.task_ids):
                raise ValueError("The keys in names do not match the task IDs.")

        self.headers = headers
        if self.headers is None:
            self.headers = self._default_headers()

        self.params = params
        if self.params is None:
            self.params = {
                "format": "COCO 1.0",
                "action": "download",
                "location": "local",
                "org": self.org,
            }

    @staticmethod
    def _default_headers():
        """
        Build the default Authorization header from the environment.

        Looks at ``CVAT_AUTHORIZATION`` (full header value, e.g.
        ``"Basic ..."`` or ``"Token ..."``) first, then at
        ``CVAT_USERNAME`` + ``CVAT_PASSWORD`` (encoded as HTTP Basic auth).
        If neither is set, falls back to the legacy hardcoded credential and
        emits a warning.
        """
        authorization = os.environ.get("CVAT_AUTHORIZATION")
        if authorization is None:
            user = os.environ.get("CVAT_USERNAME")
            password = os.environ.get("CVAT_PASSWORD")
            if user is not None and password is not None:
                token = base64.b64encode(
                    f"{user}:{password}".encode("utf-8")
                ).decode("ascii")
                authorization = f"Basic {token}"
        if authorization is None:
            # FIXME: legacy hardcoded credential kept only so existing callers
            # keep working; it is exposed in source control and should be
            # rotated and removed.
            warnings.warn(
                "No CVAT credentials found in CVAT_AUTHORIZATION or "
                "CVAT_USERNAME/CVAT_PASSWORD; using the legacy hardcoded "
                "credentials.",
                UserWarning,
                stacklevel=3,
            )
            authorization = "Basic ZGphbmdvOlMwbHNraW4xMjM0IQ=="
        return {"Authorization": authorization}

    @staticmethod
    def countdown_clock(waiting_time):
        """Block for ``waiting_time`` seconds, showing an MM:SS countdown on stdout."""
        t0 = time.monotonic()
        while True:
            remaining_time = waiting_time - (time.monotonic() - t0)
            if remaining_time <= 0:
                break
            mins, secs = divmod(int(remaining_time), 60)
            sys.stdout.write(f"\r{mins:02d}:{secs:02d}")
            sys.stdout.flush()
            # Never sleep past the deadline.
            time.sleep(min(1, remaining_time))
        sys.stdout.write("\n")

    def _get_dataset(self, endpoint):
        """Send one streaming GET to ``endpoint`` with the configured headers/params."""
        return requests.get(
            endpoint,
            headers=self.headers,
            params=self.params,
            stream=True,
            timeout=self.timeout,
        )

    def _wait_for_dataset(self, endpoint):
        """
        Poll ``endpoint`` until it answers 200 and return that response.

        Every intermediate response is closed so its pooled connection is
        released. Raises ``requests.HTTPError`` on a non-retryable status
        code or once ``max_retries`` re-requests have been made.
        """
        attempts = 0
        while True:
            r = self._get_dataset(endpoint)
            status = r.status_code
            if status == 200:
                print(f" Status code {status}: request is ready")
                return r
            r.close()

            exhausted = self.max_retries is not None and attempts >= self.max_retries
            if status in self._FATAL_STATUS_CODES or exhausted:
                raise requests.HTTPError(
                    f"Dataset request to {endpoint} failed with status code {status}",
                    response=r,
                )

            attempts += 1
            if status == 202:
                print(f" Status code {status}: server processing request")
                self.countdown_clock(self.POLL_INTERVAL)
            else:
                print(f" Status code {status}: connection error")
                self.countdown_clock(self.ERROR_INTERVAL)

    def _download_task(self, task_id: int, fname: str):
        """
        Download the dataset linked to a task into ``fname``.

        The payload is streamed to ``fname + ".part"`` and only renamed to
        ``fname`` once complete. An interrupted download therefore never
        leaves a file that ``download_tasks`` would mistake for a finished one.
        """
        endpoint = f"{self.cvat_url}/api/tasks/{task_id}/dataset"
        with self._wait_for_dataset(endpoint) as r:
            content_length = r.headers.get("Content-Length")
            # The header is absent for chunked responses; tqdm accepts None.
            total_length = int(content_length) if content_length is not None else None
            # Reading r.raw bypasses requests' decoding; make urllib3 undo any
            # Content-Encoding (e.g. gzip) so the bytes written are the ZIP itself.
            r.raw.decode_content = True

            part_name = f"{fname}.part"
            try:
                with tqdm.wrapattr(r.raw, "read", total=total_length, desc="") as raw:
                    with open(part_name, "wb") as file:
                        shutil.copyfileobj(raw, file)
            except BaseException:
                with contextlib.suppress(OSError):
                    os.remove(part_name)
                raise
        os.replace(part_name, fname)

    def download_tasks(self):
        """
        Download every task in ``task_ids``, skipping files that already exist.

        Each task is written to ``dataset_{label}.zip``. The label is the task
        id, or ``names[task_id]`` when ``names`` was given. The file goes into
        ``dest_folder`` when set, otherwise into the current working directory.
        """
        if self.dest_folder is not None:
            # Invariant across tasks: normalise and create the folder once.
            self.dest_folder = Path(self.dest_folder)
            self.dest_folder.mkdir(exist_ok=True, parents=True)

        for task_id in self.task_ids:
            name_label = task_id if self.names_dict is None else self.names_dict[task_id]
            fname = f"dataset_{name_label}.zip"
            if self.dest_folder is not None:
                fname = (self.dest_folder / fname).resolve().as_posix()

            if os.path.exists(fname):
                print(f"File {fname} already exists.")
                continue

            print(f"\nDownloading task {task_id}, with fname {fname}")
            self._download_task(task_id, fname)

    # TODO: implement unzip function for the tasks