Spaces:
Build error
Build error
import os
import shutil
import sys
import time
from pathlib import Path

import requests
from tqdm.auto import tqdm


class CVATDataset:
    """Download task datasets from a CVAT server as local ZIP archives.

    Each task is requested from ``<cvat_url>/api/tasks/<task_id>/dataset`` and
    the response body is streamed into ``dataset_<label>.zip``.
    """

    def __init__(self, cvat_url, org, task_ids, headers=None, params=None, names=None, dest_folder=None):
        """
        Store the connection settings used to download datasets from CVAT.

        Args:
            cvat_url (str) : CVAT base URL where the server is loaded.
            org (str) : organization we are working with, e.g.: 'bulow'
            task_ids (list): list with the task IDs inside CVAT.
            headers (dict): HTTP headers sent with every request. When None,
                            a built-in Authorization header is used.
            params (dict): query parameters. When None, a COCO 1.0 local
                           download for `org` is requested.
            names (dict): dict where the keys are the task id and values
                          the names of the local files.
            dest_folder (str) : destination folder of the zip files.

        Raises:
            ValueError: if `names` is given but lacks an entry for one of
                        the task IDs.
        """
        self.cvat_url = cvat_url
        self.org = org
        self.task_ids = task_ids
        self.dest_folder = dest_folder

        self.names_dict = names
        if self.names_dict is not None:
            # Explicit raise instead of `assert`: assertions are stripped
            # under `python -O`, which would defer the failure to a KeyError
            # in the middle of download_tasks().
            missing = [id_ for id_ in self.task_ids if id_ not in self.names_dict]
            if missing:
                raise ValueError("The keys in names do not match the task IDs.")

        self.headers = headers
        if self.headers is None:
            # FIXME: avoid hardcoded authorization.
            # SECURITY: this fallback embeds a credential in the source. It is
            # kept only so existing callers keep working; pass `headers`
            # explicitly and rotate this credential.
            self.headers = {"Authorization": "Basic ZGphbmdvOlMwbHNraW4xMjM0IQ=="}

        self.params = params
        if self.params is None:
            self.params = {
                "format": "COCO 1.0",
                "action": "download",
                "location": "local",
                "org": self.org,
            }

    @staticmethod
    def countdown_clock(waiting_time):
        """Block for `waiting_time` seconds, printing an MM:SS countdown.

        Declared as a static method so that it works both as
        ``self.countdown_clock(n)`` and ``CVATDataset.countdown_clock(n)``;
        the function never needed instance state.
        """
        t0 = time.monotonic()
        while time.monotonic() - t0 < waiting_time:
            remaining_time = waiting_time - (time.monotonic() - t0)
            mins, secs = divmod(int(remaining_time), 60)
            # "\r" rewinds to the start of the line so the clock is redrawn
            # in place instead of scrolling.
            sys.stdout.write("\r")
            sys.stdout.write(f"{mins:02d}:{secs:02d}")
            sys.stdout.flush()
            time.sleep(1)
        sys.stdout.write("\n")

    def _get_dataset(self, endpoint):
        """Issue a streaming GET to `endpoint` with the stored headers/params.

        The caller owns the returned response and must close it.
        """
        response = requests.get(
            endpoint,
            headers=self.headers,
            params=self.params,
            stream=True,
        )
        return response

    def _download_task(self, task_id: int, fname: str):
        """ Downloads dataset linked to a task.

        Polls the dataset endpoint until it answers 200, then streams the
        body into `fname`. The data is first written to ``<fname>.part`` and
        renamed on success, so an interrupted download never leaves a
        truncated archive that download_tasks() would later skip as
        "already exists".
        """
        endpoint = f"{self.cvat_url}/api/tasks/{task_id}/dataset"
        r = self._get_dataset(endpoint)
        # NOTE(review): any status other than 200 is retried without limit,
        # including client errors that will presumably never succeed.
        while r.status_code != 200:
            if r.status_code == 202:
                print(f" Status code {r.status_code}: server processing request")
                waiting_time = 10
            else:
                print(f" Status code {r.status_code}: connection error")
                waiting_time = 30
            # Streamed responses hold their connection until closed; release
            # it before waiting and asking again.
            r.close()
            self.countdown_clock(waiting_time)
            r = self._get_dataset(endpoint)
        print(f" Status code {r.status_code}: request is ready")

        # Content-Length may be absent; tqdm accepts total=None and then
        # shows a progress bar without a known end.
        content_length = r.headers.get("Content-Length")
        total_length = int(content_length) if content_length is not None else None

        partial_fname = f"{fname}.part"
        try:
            with r:
                with tqdm.wrapattr(r.raw, "read", total=total_length, desc="") as raw:
                    with open(partial_fname, "wb") as file:
                        shutil.copyfileobj(raw, file)
            os.replace(partial_fname, fname)
        except BaseException:
            # Also covers KeyboardInterrupt: drop the incomplete file, then
            # let the original error propagate.
            if os.path.exists(partial_fname):
                os.remove(partial_fname)
            raise

    def download_tasks(self):
        """ Download all the tasks passed as input.

        Files are named ``dataset_<label>.zip`` where the label is the entry
        in `names` for the task, or the task id when no names were given.
        Tasks whose target file already exists are skipped.
        """
        # The destination folder does not depend on the task: normalise and
        # create it once instead of on every iteration.
        if self.dest_folder is not None:
            self.dest_folder = Path(self.dest_folder)
            self.dest_folder.mkdir(exist_ok=True, parents=True)

        for task_id in self.task_ids:
            name_label = task_id
            if self.names_dict is not None:
                name_label = self.names_dict[task_id]
            fname = f"dataset_{name_label}.zip"
            if self.dest_folder is not None:
                fname = (self.dest_folder / fname).resolve().as_posix()
            if os.path.exists(fname):
                print(f"File {fname} already exists.")
                continue
            print(f"\nDownloading task {task_id}, with fname {fname}")
            self._download_task(task_id, fname)
        # TODO: implement unzip function for the tasks