import torch
from typing import Any, Optional, Union, List, Dict
import math
import os
from urllib.parse import urlparse
import errno
import sys
import validators
import requests
import json


def hwc2bchw(images: torch.Tensor) -> torch.Tensor:
    """ Convert a single HWC image into a BCHW batch of size 1. """
    return images.unsqueeze(0).permute(0, 3, 1, 2)


def bchw2hwc(images: torch.Tensor, nrows: Optional[int] = None, border: int = 2,
             background_value: float = 0) -> torch.Tensor:
    """ Make a single HWC grid image from a BCHW image batch.

    Args:
        images (torch.Tensor): input image batch of shape b x c x h x w.
        nrows (Optional[int]): number of rows in the grid; defaults to roughly sqrt(b).
        border (int): border size in pixels between grid cells.
        background_value (float): color value used for the background and borders.
    """
    assert images.ndim == 4
    images = images.permute(0, 2, 3, 1)
    n, h, w, c = images.shape
    if nrows is None:
        nrows = max(int(math.sqrt(n)), 1)
    ncols = (n + nrows - 1) // nrows
    result = torch.full([(h + border) * nrows - border,
                         (w + border) * ncols - border, c], background_value,
                        device=images.device,
                        dtype=images.dtype)

    for i, single_image in enumerate(images):
        row = i // ncols
        col = i % ncols
        yy = (h + border) * row
        xx = (w + border) * col
        result[yy:(yy + h), xx:(xx + w), :] = single_image
    return result
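

# Illustrative sketch (not part of the original module): building a grid image
# from a random batch and converting it to a displayable HWC uint8 array.
# The tensor shapes below are assumptions chosen for the example.
def _example_make_grid() -> torch.Tensor:
    batch = torch.rand(9, 3, 64, 64)           # 9 RGB images of 64 x 64
    grid = bchw2hwc(batch, nrows=3, border=4)  # -> 200 x 200 x 3 (3*64 + 2*4 per side)
    return (grid * 255).to(torch.uint8)        # ready for e.g. PIL.Image.fromarray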


def bchw2bhwc(images: torch.Tensor) -> torch.Tensor:
    """ Convert a BCHW batch into BHWC layout. """
    return images.permute(0, 2, 3, 1)


def bhwc2bchw(images: torch.Tensor) -> torch.Tensor:
    """ Convert a BHWC batch into BCHW layout. """
    return images.permute(0, 3, 1, 2)


def bhwc2hwc(images: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """ Make a single HWC grid image from a BHWC image batch (see bchw2hwc). """
    return bchw2hwc(bhwc2bchw(images), *args, **kwargs)
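

# Illustrative sketch (not part of the original module): the layout helpers
# compose into simple round trips; the shapes below are example assumptions.
def _example_layout_round_trip() -> bool:
    bhwc = torch.rand(4, 32, 32, 3)
    bchw = bhwc2bchw(bhwc)                     # 4 x 3 x 32 x 32
    return torch.equal(bchw2bhwc(bchw), bhwc)  # the permutations are exact inverses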


def select_data(selection, data):
    """ Recursively apply an index/mask `selection` to every tensor in `data`,
    keeping the (possibly nested) dict/list structure intact. """
    if isinstance(data, dict):
        return {name: select_data(selection, val) for name, val in data.items()}
    elif isinstance(data, (list, tuple)):
        return [select_data(selection, val) for val in data]
    elif isinstance(data, torch.Tensor):
        return data[selection]
    return data
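

# Illustrative sketch (not part of the original module): selecting a subset of a
# nested result dict with a boolean mask. The field names are example assumptions.
def _example_select_data() -> Dict[str, Any]:
    data = {'points': torch.rand(5, 2), 'scores': torch.rand(5)}
    keep = data['scores'] > 0.5      # boolean mask over the batch dimension
    return select_data(keep, data)   # same structure, only the kept rows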


def download_from_github(to_path, organisation, repository, file_path, branch='main', username=None, access_token=None):
    """ Download files (including LFS files) from github.

    For example, in order to download https://github.com/FacePerceiver/facer/blob/main/README.md, call with
    ```
    download_from_github(
        to_path='README.md', organisation='FacePerceiver',
        repository='facer', file_path='README.md', branch='main')
    ```
    """
    if username is not None:
        assert access_token is not None
        auth = (username, access_token)
    else:
        auth = None
    r = requests.get(f'https://api.github.com/repos/{organisation}/{repository}/contents/{file_path}?ref={branch}',
                     auth=auth)
    data = json.loads(r.content)
    torch.hub.download_url_to_file(data['download_url'], to_path)


def is_github_url(url: str):
    """
    A typical github url looks like
    https://github.com/FacePerceiver/facer/blob/main/facer/util.py or
    https://github.com/FacePerceiver/facer/raw/main/facer/util.py.
    """
    return ('blob' in url or 'raw' in url) and url.startswith('https://github.com/')


def get_github_components(url: str):
    """ Split a github blob/raw URL into (organisation, repository, branch, path). """
    assert is_github_url(url)
    organisation, repository, blob_or_raw, branch, *path = \
        url[len('https://github.com/'):].split('/')
    assert blob_or_raw in {'blob', 'raw'}
    return organisation, repository, branch, '/'.join(path)
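

# Illustrative sketch (not part of the original module): parsing the components
# of a github file URL; the URL is the one used in the docstring above.
def _example_github_components() -> tuple:
    url = 'https://github.com/FacePerceiver/facer/blob/main/facer/util.py'
    org, repo, branch, path = get_github_components(url)
    return org, repo, branch, path  # ('FacePerceiver', 'facer', 'main', 'facer/util.py')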


def download_url_to_file(url, dst, **kwargs):
    """ Download `url` to the local file `dst`, routing github blob/raw URLs
    through the github contents API and everything else through torch.hub. """
    if is_github_url(url):
        org, rep, branch, path = get_github_components(url)
        download_from_github(dst, org, rep, path, branch, kwargs.get(
            'username', None), kwargs.get('access_token', None))
    else:
        torch.hub.download_url_to_file(url, dst)
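

# Illustrative sketch (not part of the original module): the same entry point
# handles both github and plain URLs. The target paths are example assumptions.
def _example_download() -> None:
    # github blob URL -> resolved through the github contents API
    download_url_to_file(
        'https://github.com/FacePerceiver/facer/blob/main/README.md', 'README.md')
    # any other URL -> plain torch.hub download
    download_url_to_file('https://example.com/weights.pt', 'weights.pt')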


def download_jit(url_or_paths: Union[str, List[str]], model_dir=None, map_location=None, jit=True, **kwargs):
    """ Load a model from the first of `url_or_paths` that works.

    URLs are downloaded into the torch hub checkpoints cache; local paths are
    loaded directly. Set `jit=False` to load with torch.load instead of torch.jit.load.
    """
    if isinstance(url_or_paths, str):
        url_or_paths = [url_or_paths]

    for url_or_path in url_or_paths:
        try:
            if validators.url(url_or_path):
                url = url_or_path
                if model_dir is None:
                    if hasattr(torch.hub, 'get_dir'):
                        hub_dir = torch.hub.get_dir()
                    else:
                        hub_dir = os.path.join(os.path.expanduser(
                            '~'), '.cache', 'torch', 'hub')
                    model_dir = os.path.join(hub_dir, 'checkpoints')

                try:
                    os.makedirs(model_dir)
                except OSError as e:
                    if e.errno == errno.EEXIST:
                        # directory already exists, nothing to do
                        pass
                    else:
                        raise

                parts = urlparse(url)
                filename = os.path.basename(parts.path)
                cached_file = os.path.join(model_dir, filename)
                if not os.path.exists(cached_file):
                    sys.stderr.write(
                        'Downloading: "{}" to {}\n'.format(url, cached_file))
                    download_url_to_file(url, cached_file)
            else:
                cached_file = url_or_path
            if jit:
                return torch.jit.load(cached_file, map_location=map_location, **kwargs)
            else:
                return torch.load(cached_file, map_location=map_location, **kwargs)
        except Exception:
            # try the next candidate instead of aborting immediately; the
            # RuntimeError below reports the overall failure
            sys.stderr.write(f'failed downloading from {url_or_path}\n')
            continue

    raise RuntimeError('failed to download jit models from all given urls')
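

# Illustrative sketch (not part of the original module): loading a scripted model
# with a local fallback path. The URL and file name are example assumptions.
def _example_download_jit():
    model = download_jit(
        ['https://example.com/models/face_model.pt', './face_model.pt'],
        map_location='cpu')
    return model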