File size: 4,158 Bytes
463b952
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import sys
import requests
import shutil
import time 
from pathlib import Path
from tqdm.auto import tqdm

class CVATDataset:
    """Download the datasets of CVAT tasks as local ZIP files."""

    def __init__(self, cvat_url, org, task_ids, headers=None, params=None, names=None, dest_folder=None):
        """
        Store the configuration used to download task datasets from CVAT.

        No request is made here; the downloads happen in `download_tasks`.

        Args:
            cvat_url    (str) : CVAT base URL where the server is loaded.
            org         (str) : organization we are working with, e.g.: 'bulow'
            task_ids    (list): list with the task IDs inside CVAT.
            headers     (dict): HTTP headers sent with every request. When
                                None, a built-in Authorization header is used.
            params      (dict): query parameters. When None, a COCO 1.0
                                local download for `org` is requested.
            names       (dict): dict where the keys are the task id and values
                                the names of the local files.
            dest_folder (str) : destination folder of the zip files. When
                                None, the current working directory is used.

        Raises:
            ValueError: if `names` is given but lacks an entry for one of the
                        task IDs.
        """
        self.cvat_url = cvat_url
        self.org = org
        self.task_ids = task_ids
        self.dest_folder = dest_folder
        self.names_dict = names
        if self.names_dict is not None:
            # Explicit raise instead of `assert`: asserts are stripped when
            # Python runs with -O, which would silently skip the validation.
            missing = [id_ for id_ in self.task_ids if id_ not in self.names_dict]
            if missing:
                raise ValueError(
                    "The keys in names do not match the task IDs. "
                    f"Missing task IDs: {missing}"
                )

        self.headers = headers
        if self.headers is None:
            # FIXME(security): hardcoded credential kept only for backward
            # compatibility. Callers should pass `headers` explicitly and this
            # default should be removed (and the credential rotated).
            self.headers = {"Authorization": "Basic ZGphbmdvOlMwbHNraW4xMjM0IQ=="}

        self.params = params
        if self.params is None:
            self.params = {
                "format"  : "COCO 1.0",
                "action"  : "download",
                "location": "local",
                "org"     : self.org
            }

    @staticmethod
    def countdown_clock(waiting_time):
        """
        Block for `waiting_time` seconds while printing a MM:SS countdown.

        The countdown is rewritten in place on stdout once per second and a
        newline is written when the wait is over.

        Args:
            waiting_time (float): number of seconds to wait.
        """
        t0 = time.monotonic()
        while time.monotonic() - t0 < waiting_time:
            remaining_time = waiting_time - (time.monotonic() - t0)
            mins, secs = divmod(int(remaining_time), 60)
            # "\r" returns to the start of the line so the clock overwrites itself.
            sys.stdout.write("\r")
            sys.stdout.write(f"{mins:02d}:{secs:02d}")
            sys.stdout.flush()
            time.sleep(1)
        sys.stdout.write("\n")

    def _get_dataset(self, endpoint):
        """
        Issue a streaming GET request to `endpoint`.

        The response is opened with `stream=True`, so the caller is
        responsible for consuming or closing it.

        Args:
            endpoint (str): full URL to request.

        Returns:
            The `requests` response object.
        """
        response = requests.get(
            endpoint,
            headers = self.headers,
            params = self.params,
            stream = True
        )
        return response

    def _download_task(self, task_id: int, fname: str):
        """
        Downloads dataset linked to a task.

        Polls the task's dataset endpoint until it answers with status 200,
        waiting 10 seconds after a 202 and 30 seconds after any other status,
        then streams the body into `fname`.

        NOTE: there is no retry limit; a persistent error status keeps this
        method polling indefinitely.

        Args:
            task_id (int): CVAT task ID.
            fname   (str): path of the file the dataset is written to.
        """
        endpoint = f"{self.cvat_url}/api/tasks/{task_id}/dataset"
        r = self._get_dataset(endpoint)
        while r.status_code != 200:
            if r.status_code == 202:
                print(f"  Status code {r.status_code}: server processing request")
                waiting_time = 10
            else:
                print(f"  Status code {r.status_code}: connection error")
                waiting_time = 30
            # Streamed responses hold their connection until closed; release
            # it before waiting so polling does not leak connections.
            r.close()
            self.countdown_clock(waiting_time)
            r = self._get_dataset(endpoint)

        print(f"  Status code {r.status_code}: request is ready")
        # Content-Length may be absent (e.g. chunked transfer); tqdm accepts
        # total=None and then shows progress without a known total.
        content_length = r.headers.get("Content-Length")
        total_length = int(content_length) if content_length is not None else None

        # Write to a temporary name first: an interrupted transfer must not
        # leave a truncated file under `fname`, because `download_tasks`
        # skips any file that already exists.
        partial_fname = f"{fname}.part"
        try:
            with tqdm.wrapattr(r.raw, "read", total=total_length, desc="") as raw:
                with open(partial_fname, "wb") as file:
                    shutil.copyfileobj(raw, file)
        finally:
            r.close()
        os.replace(partial_fname, fname)

    def download_tasks(self):
        """
        Download all the tasks passed as input.

        Each task is saved as `dataset_<label>.zip`, where the label is the
        entry of `names` for that task, or the task ID when `names` was not
        given. Files go to `dest_folder` (created if needed) or to the current
        working directory when `dest_folder` is None. Tasks whose file already
        exists are skipped.
        """
        if self.dest_folder is not None:
            self.dest_folder = Path(self.dest_folder)
            dest_dir = self.dest_folder
            dest_dir.mkdir(exist_ok=True, parents=True)
        else:
            # No destination given: fall back to the current directory
            # instead of failing on `None / fname`.
            dest_dir = Path.cwd()

        for task_id in self.task_ids:
            name_label = task_id
            if self.names_dict is not None:
                name_label = self.names_dict[task_id]
            fname = (dest_dir / f"dataset_{name_label}.zip").resolve().as_posix()

            if os.path.exists(fname):
                print(f"File {fname} already exists.")
                continue

            print(f"\nDownloading task {task_id}, with fname {fname}")
            self._download_task(task_id, fname)

    # TODO: implement unzip function for the tasks