File size: 6,021 Bytes
d4e7f2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import torch
from typing import Any, Optional, Union, List, Dict
import math
import os
from urllib.parse import urlparse
import errno
import sys
import validators
import requests
import json


def hwc2bchw(images: torch.Tensor) -> torch.Tensor:
    """Turn a single h x w x c image into a 1 x c x h x w batch.

    A leading axis of size one is inserted, then the channel axis is moved
    in front of the two spatial axes.
    """
    batched = images[None]  # 1 x h x w x c
    return batched.permute(0, 3, 1, 2)


def bchw2hwc(images: torch.Tensor, nrows: Optional[int] = None, border: int = 2,
             background_value: float = 0) -> torch.Tensor:
    """ tile an image batch into one grid image.

    Images are placed left-to-right, top-to-bottom. Neighbouring cells are
    separated by `border` pixels; there is no border around the outside of
    the grid. Cells that receive no image keep `background_value`.

    Args:
        images (torch.Tensor): input image batch, n x c x h x w.
        nrows: rows of grid. When omitted, floor(sqrt(n)) is used (at least 1).
        border: gap between neighbouring cells in pixels.
        background_value: fill value for gaps and unused cells.

    Returns:
        torch.Tensor: the grid, laid out as height x width x c, with the
        same dtype and device as `images`.
    """
    assert images.ndim == 4  # n x c x h x w
    tiles = images.permute(0, 2, 3, 1)  # n x h x w x c
    count, height, width, channels = tiles.shape
    if nrows is None:
        nrows = max(int(math.sqrt(count)), 1)
    # Smallest column count that fits all images into `nrows` rows.
    ncols = (count + nrows - 1) // nrows

    # Each cell is one image plus its trailing gap; the last gap in each
    # direction is trimmed from the canvas size.
    cell_h = height + border
    cell_w = width + border
    canvas = torch.full(
        [cell_h * nrows - border, cell_w * ncols - border, channels],
        background_value,
        device=tiles.device,
        dtype=tiles.dtype)

    for index in range(count):
        row, col = divmod(index, ncols)
        top = cell_h * row
        left = cell_w * col
        canvas[top:top + height, left:left + width, :] = tiles[index]
    return canvas


def bchw2bhwc(images: torch.Tensor) -> torch.Tensor:
    """Reorder a 4-D batch from b x c x h x w to b x h x w x c."""
    channels_last = (0, 2, 3, 1)
    return images.permute(*channels_last)


def bhwc2bchw(images: torch.Tensor) -> torch.Tensor:
    """Reorder a 4-D batch from b x h x w x c to b x c x h x w."""
    channels_first = (0, 3, 1, 2)
    return images.permute(*channels_first)


def bhwc2hwc(images: torch.Tensor, *kargs, **kwargs) -> torch.Tensor:
    """Make a grid image from a b x h x w x c batch.

    The batch is converted to b x c x h x w and handed to `bchw2hwc`; all
    extra positional and keyword arguments are forwarded to it unchanged.
    """
    channels_first = bhwc2bchw(images)
    return bchw2hwc(channels_first, *kargs, **kwargs)


def select_data(selection, data):
    """Index every tensor inside a nested structure with `selection`.

    Dicts are rebuilt with the same keys, lists and tuples are both rebuilt
    as lists, tensors are replaced by `tensor[selection]`, and any other
    value is returned untouched.

    NOTE(review): this module defines `select_data` a second time further
    down with the same logic; the later definition is the one that remains
    bound after import.
    """
    if isinstance(data, dict):
        picked = {}
        for key, item in data.items():
            picked[key] = select_data(selection, item)
        return picked
    if isinstance(data, (list, tuple)):
        return list(map(lambda item: select_data(selection, item), data))
    if isinstance(data, torch.Tensor):
        return data[selection]
    return data


def download_from_github(to_path, organisation, repository, file_path, branch='main', username=None,
                         access_token=None, timeout=30):
    """ download files (including LFS files) from github.

    The GitHub contents API is queried for `file_path` at `branch`, and the
    `download_url` it reports is fetched into `to_path` with
    `torch.hub.download_url_to_file`.

    For example, in order to download https://github.com/FacePerceiver/facer/blob/main/README.md, call with
    ```
    download_from_github(
        to_path='README.md', organisation='FacePerceiver',
        repository='facer', file_path='README.md', branch='main')
    ```

    Args:
        to_path: local destination file.
        organisation: repository owner.
        repository: repository name.
        file_path: path of the file inside the repository.
        branch: git ref the file is read from.
        username: optional user name for HTTP basic auth.
        access_token: token used with `username`; required when `username` is given.
        timeout: seconds to wait for the contents API before giving up.

    Raises:
        ValueError: if `username` is given without `access_token`, or the
            API response does not describe a downloadable file.
        requests.HTTPError: if the contents API answers with an error status.
    """
    if username is not None:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if access_token is None:
            raise ValueError('access_token is required when username is given')
        auth = (username, access_token)
    else:
        auth = None
    response = requests.get(
        f'https://api.github.com/repos/{organisation}/{repository}/contents/{file_path}?ref={branch}',
        auth=auth, timeout=timeout)
    # Surface 4xx/5xx (missing file, bad credentials, rate limit) directly
    # instead of failing later on a missing 'download_url' key.
    response.raise_for_status()
    data = json.loads(response.content)
    # A directory listing is a JSON list, and some entries carry a null
    # download_url; neither can be downloaded as a file.
    download_url = data.get('download_url') if isinstance(data, dict) else None
    if not download_url:
        raise ValueError(
            f'{organisation}/{repository}/{file_path}@{branch} is not a downloadable file')
    torch.hub.download_url_to_file(download_url, to_path)


def is_github_url(url: str) -> bool:
    """Tell whether `url` points at a file view on github.com.

    A typical github url should be like
        https://github.com/FacePerceiver/facer/blob/main/facer/util.py or
        https://github.com/FacePerceiver/facer/raw/main/facer/util.py.

    The URL must start with ``https://github.com/`` and its third path
    segment must be exactly ``blob`` or ``raw``, followed by at least a
    branch segment. Matching the segment itself (rather than searching for
    the substrings anywhere) keeps other github.com URLs, such as release
    assets whose file name happens to contain "raw", from being classified
    as blob/raw URLs.
    """
    prefix = 'https://github.com/'
    if not url.startswith(prefix):
        return False
    # Expected layout: organisation / repository / blob|raw / branch / path...
    parts = url[len(prefix):].split('/')
    return len(parts) >= 4 and parts[2] in ('blob', 'raw')


def get_github_components(url: str):
    """Split a github blob/raw url into its parts.

    Returns:
        tuple: (organisation, repository, branch, file path), where the file
        path is every segment after the branch joined with '/'.
    """
    assert is_github_url(url)
    prefix = 'https://github.com/'
    segments = url[len(prefix):].split('/')
    organisation, repository, blob_or_raw, branch, *path_parts = segments
    assert blob_or_raw in {'blob', 'raw'}
    file_path = '/'.join(path_parts)
    return organisation, repository, branch, file_path


def download_url_to_file(url, dst, **kwargs):
    """Download `url` to the local file `dst`.

    URLs that `is_github_url` accepts are fetched with `download_from_github`,
    using the optional `username` and `access_token` keyword arguments.
    Every other URL goes straight to `torch.hub.download_url_to_file`.
    """
    if not is_github_url(url):
        torch.hub.download_url_to_file(url, dst)
        return

    org, rep, branch, path = get_github_components(url)
    username = kwargs.get('username', None)
    access_token = kwargs.get('access_token', None)
    download_from_github(dst, org, rep, path, branch, username, access_token)


def select_data(selection, data):
    """Apply `selection` as an index to all tensors nested inside `data`.

    Containers are walked recursively: dict values and list/tuple items are
    processed in place of the originals (lists and tuples both come back as
    lists). Tensors are indexed with `selection`. Anything else passes
    through unchanged.

    NOTE(review): `select_data` is also defined earlier in this module with
    the same logic; this later definition is the one that remains bound
    after import.
    """
    def recurse(item):
        return select_data(selection, item)

    if isinstance(data, dict):
        result = {name: recurse(val) for name, val in data.items()}
    elif isinstance(data, (list, tuple)):
        result = [recurse(val) for val in data]
    elif isinstance(data, torch.Tensor):
        result = data[selection]
    else:
        result = data
    return result


def download_jit(url_or_paths: Union[str, List[str]], model_dir=None, map_location=None, jit=True, **kwargs):
    """Load a model from the first usable URL or local path.

    Candidates are tried in order. A candidate that `validators.url` accepts
    is downloaded into `model_dir`, unless a file with the same base name is
    already there; any other candidate is used as a local file path. The
    resulting file is then loaded. A candidate that fails at any step is
    reported on stderr and the next one is tried.

    Args:
        url_or_paths: one URL/path, or a list of alternatives tried in order.
        model_dir: directory for downloaded files; defaults to
            `<torch hub dir>/checkpoints`.
        map_location: forwarded to the loader.
        jit: load with `torch.jit.load` when true, otherwise with `torch.load`.
        **kwargs: forwarded to the loader.

    Returns:
        The object returned by the loader for the first candidate that works.

    Raises:
        RuntimeError: if `url_or_paths` is empty.
        Exception: if every candidate fails, the error of the last one is
            re-raised unchanged.
    """
    if isinstance(url_or_paths, str):
        url_or_paths = [url_or_paths]

    last_error = None
    for url_or_path in url_or_paths:
        try:
            if validators.url(url_or_path):
                if model_dir is None:
                    if hasattr(torch.hub, 'get_dir'):
                        hub_dir = torch.hub.get_dir()
                    else:
                        hub_dir = os.path.join(
                            os.path.expanduser('~'), '.cache', 'torch', 'hub')
                    model_dir = os.path.join(hub_dir, 'checkpoints')
                os.makedirs(model_dir, exist_ok=True)

                filename = os.path.basename(urlparse(url_or_path).path)
                cached_file = os.path.join(model_dir, filename)
                if not os.path.exists(cached_file):
                    sys.stderr.write(
                        'Downloading: "{}" to {}\n'.format(url_or_path, cached_file))
                    download_url_to_file(url_or_path, cached_file)
            else:
                cached_file = url_or_path

            if jit:
                return torch.jit.load(cached_file, map_location=map_location, **kwargs)
            # NOTE(review): torch.load unpickles the file; only use it with
            # checkpoints from trusted sources.
            return torch.load(cached_file, map_location=map_location, **kwargs)
        except Exception as error:
            # Remember the failure and fall through to the next candidate
            # instead of aborting on the first one.
            sys.stderr.write(f'failed downloading from {url_or_path}\n')
            last_error = error

    if last_error is not None:
        # Every candidate failed: surface the most recent underlying error so
        # a single-candidate call raises exactly what it raised before.
        raise last_error
    raise RuntimeError('failed to download jit models from all given urls')