import os
import sys
import requests
import shutil
import time
from pathlib import Path
from tqdm.auto import tqdm


class CVATDataset:
    def __init__(self, cvat_url, org, task_ids, headers=None, params=None, names=None, dest_folder=None):
        """
        Connects to a serverless CVAT instance to download datasets.

        Args:
            cvat_url (str): base URL where the CVAT server is reachable.
            org (str): organization we are working with, e.g. 'bulow'.
            task_ids (list): list of the task IDs inside CVAT.
            headers (dict): HTTP headers; defaults to a basic-auth header.
            params (dict): query parameters for the export request.
            names (dict): dict where the keys are the task IDs and the values
                are the names of the local files.
            dest_folder (str): destination folder for the downloaded ZIP files.

        Each downloaded ZIP file contains the COCO JSON annotations and the images.
        """
        self.cvat_url = cvat_url
        self.org = org
        self.task_ids = task_ids
        self.dest_folder = dest_folder
        self.names_dict = names
        if self.names_dict is not None:
            assert all(id_ in self.names_dict for id_ in self.task_ids), \
                "The keys in names do not match the task IDs."
        self.headers = headers
        if self.headers is None:
            # FIXME: avoid hardcoded authorization; read credentials from the environment instead.
            self.headers = {"Authorization": "Basic ZGphbmdvOlMwbHNraW4xMjM0IQ=="}
        self.params = params
        if self.params is None:
            self.params = {
                "format": "COCO 1.0",
                "action": "download",
                "location": "local",
                "org": self.org,
            }

    @staticmethod
    def countdown_clock(waiting_time):
        """Waits `waiting_time` seconds, printing a mm:ss countdown in place."""
        t0 = time.monotonic()
        while time.monotonic() - t0 < waiting_time:
            remaining_time = waiting_time - (time.monotonic() - t0)
            mins, secs = divmod(int(remaining_time), 60)
            sys.stdout.write("\r")
            sys.stdout.write(f"{mins:02d}:{secs:02d}")
            sys.stdout.flush()
            time.sleep(1)
        sys.stdout.write("\n")

    def _get_dataset(self, endpoint):
        """Requests the dataset export endpoint and returns the raw streaming response."""
        response = requests.get(
            endpoint,
            headers=self.headers,
            params=self.params,
            stream=True,
        )
        return response

    def _download_task(self, task_id: int, fname: str):
        """Downloads the dataset linked to a task, polling until the export is ready."""
        endpoint = f"{self.cvat_url}/api/tasks/{task_id}/dataset"
        r = self._get_dataset(endpoint)
        while r.status_code != 200:
            if r.status_code == 202:
                print(f" Status code {r.status_code}: server processing request")
                self.countdown_clock(10)
            else:
                print(f" Status code {r.status_code}: connection error")
                self.countdown_clock(30)
            r = self._get_dataset(endpoint)
        print(f" Status code {r.status_code}: request is ready")
        # Default to 0 so a missing Content-Length header does not crash int().
        total_length = int(r.headers.get("Content-Length", 0))
        with tqdm.wrapattr(r.raw, "read", total=total_length, desc="") as raw:
            with open(fname, "wb") as file:
                shutil.copyfileobj(raw, file)

    def download_tasks(self):
        """Downloads all the tasks passed as input, skipping archives that already exist."""
        if self.dest_folder is not None:
            self.dest_folder = Path(self.dest_folder)
            self.dest_folder.mkdir(exist_ok=True, parents=True)
        for task_id in self.task_ids:
            name_label = task_id
            if self.names_dict is not None:
                name_label = self.names_dict[task_id]
            fname = f"dataset_{name_label}.zip"
            if self.dest_folder is not None:
                fname = (self.dest_folder / fname).resolve().as_posix()
            if os.path.exists(fname):
                print(f"File {fname} already exists.")
                continue
            print(f"\nDownloading task {task_id}, with fname {fname}")
            self._download_task(task_id, fname)

# TODO: implement unzip function for the tasks
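

# ---------------------------------------------------------------------------
# A minimal sketch of the unzip step referenced in the TODO above. It relies
# only on the standard-library `zipfile` module; the helper name `unzip_task`
# and its behaviour are illustrative, not part of the original module.
# ---------------------------------------------------------------------------
def unzip_task(archive_path, extract_dir=None):
    """Extracts a downloaded dataset ZIP next to the archive (or into extract_dir)."""
    import zipfile

    archive_path = Path(archive_path)
    if extract_dir is None:
        extract_dir = archive_path.with_suffix("")  # dataset_train.zip -> dataset_train/
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(extract_dir)
    return extract_dir


# Example usage (a sketch only: the URL, task IDs, names, and destination
# folder below are placeholders, not values taken from this repository).
if __name__ == "__main__":
    dataset = CVATDataset(
        cvat_url="http://localhost:8080",  # hypothetical CVAT server URL
        org="bulow",                       # organization, as in the docstring example
        task_ids=[1, 2],                   # hypothetical task IDs
        names={1: "train", 2: "val"},      # optional: task ID -> local file label
        dest_folder="datasets",            # hypothetical destination folder
    )
    # Downloads dataset_train.zip and dataset_val.zip into ./datasets,
    # skipping any archive that already exists on disk.
    dataset.download_tasks()
    # Optionally extract each archive with the sketch above, e.g.:
    # unzip_task("datasets/dataset_train.zip")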