# File size: 6,021 bytes (revision d4e7f2f)
import torch
from typing import Any, Optional, Union, List, Dict
import math
import os
from urllib.parse import urlparse
import errno
import sys
import validators
import requests
import json
def hwc2bchw(images: torch.Tensor) -> torch.Tensor:
    """Wrap a single h x w x c image into a 1 x c x h x w batch.

    Both steps are tensor views, so the result shares storage with
    ``images``.
    """
    batched = images[None]  # 1 x h x w x c
    return batched.permute(0, 3, 1, 2)
def bchw2hwc(images: torch.Tensor, nrows: Optional[int] = None, border: int = 2,
             background_value: float = 0) -> torch.Tensor:
    """Make a grid image from an image batch.

    Tiles are placed row by row, left to right.

    Args:
        images (torch.Tensor): input image batch, laid out n x c x h x w.
        nrows: number of grid rows; defaults to ``max(int(sqrt(n)), 1)``.
        border: gap in pixels between neighbouring tiles.
        background_value: fill value for the gaps and any unused grid cells.

    Returns:
        torch.Tensor: the grid laid out as H x W x c, with the same dtype and
        device as ``images``. An empty batch yields a zero-width grid.

    Raises:
        AssertionError: if ``images`` is not 4-dimensional.
        ValueError: if ``nrows`` is given and is not positive.
    """
    assert images.ndim == 4, f'expected n x c x h x w, got {images.ndim} dims'
    if nrows is not None and nrows < 1:
        raise ValueError(f'nrows must be positive, got {nrows}')
    images = images.permute(0, 2, 3, 1)  # n x h x w x c
    n, h, w, c = images.shape
    if nrows is None:
        nrows = max(int(math.sqrt(n)), 1)
    ncols = (n + nrows - 1) // nrows
    # Clamp at 0 so an empty batch (ncols == 0) yields an empty grid instead
    # of asking torch.full for a negative size.
    grid_h = max((h + border) * nrows - border, 0)
    grid_w = max((w + border) * ncols - border, 0)
    result = torch.full([grid_h, grid_w, c], background_value,
                        dtype=images.dtype, device=images.device)
    for i, single_image in enumerate(images):
        row, col = divmod(i, ncols)
        yy = (h + border) * row
        xx = (w + border) * col
        result[yy:yy + h, xx:xx + w, :] = single_image
    return result
def bchw2bhwc(images: torch.Tensor) -> torch.Tensor:
    """Reorder an n x c x h x w batch to n x h x w x c (returns a view)."""
    batch_dim, channel_dim, height_dim, width_dim = 0, 1, 2, 3
    return images.permute(batch_dim, height_dim, width_dim, channel_dim)
def bhwc2bchw(images: torch.Tensor) -> torch.Tensor:
    """Reorder an n x h x w x c batch to n x c x h x w (returns a view)."""
    order = (0, 3, 1, 2)  # batch, channel, height, width
    return images.permute(*order)
def bhwc2hwc(images: torch.Tensor, *kargs, **kwargs) -> torch.Tensor:
    """Make a grid image from a channels-last (n x h x w x c) batch.

    The batch is reordered to n x c x h x w, and every extra positional or
    keyword argument is forwarded unchanged to ``bchw2hwc``.
    """
    channels_first = images.permute(0, 3, 1, 2)
    return bchw2hwc(channels_first, *kargs, **kwargs)
def select_data(selection, data):
    """Recursively index every tensor nested inside ``data`` with ``selection``.

    Dicts keep their keys, lists and tuples come back as lists, tensors are
    indexed as ``data[selection]``, and any other value is returned as is.
    """
    if isinstance(data, torch.Tensor):
        return data[selection]
    if isinstance(data, dict):
        return {key: select_data(selection, value) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return [select_data(selection, item) for item in data]
    return data
def download_from_github(to_path, organisation, repository, file_path, branch='main', username=None, access_token=None,
                         timeout=60):
    """Download files (including LFS files) from GitHub.

    The file is looked up through the GitHub REST "repository contents"
    endpoint. The ``download_url`` it reports is then fetched with
    ``torch.hub.download_url_to_file``.

    For example, to download ``README.md`` from the ``main`` branch of
    FacePerceiver/facer, call::

        download_from_github(
            to_path='README.md', organisation='FacePerceiver',
            repository='facer', file_path='README.md', branch='main')

    Args:
        to_path: local destination file path.
        organisation: user or organisation that owns the repository.
        repository: repository name.
        file_path: path of the file inside the repository.
        branch: branch (or other git ref) to read from.
        username: optional user name for authenticated requests; requires
            ``access_token``.
        access_token: access token paired with ``username``.
        timeout: seconds to wait for the GitHub API response.

    Raises:
        AssertionError: if ``username`` is given without ``access_token``.
        requests.HTTPError: if the GitHub API answers with an error status.
    """
    if username is not None:
        assert access_token is not None, 'access_token is required when username is given'
        auth = (username, access_token)
    else:
        auth = None
    api_url = (f'https://api.github.com/repos/{organisation}/{repository}'
               f'/contents/{file_path}')
    # Pass the ref via `params` so requests URL-encodes branch names.
    response = requests.get(api_url, params={'ref': branch}, auth=auth, timeout=timeout)
    # Fail loudly on 404/403 instead of a confusing KeyError on 'download_url'.
    response.raise_for_status()
    data = response.json()
    torch.hub.download_url_to_file(data['download_url'], to_path)
def is_github_url(url: str):
    """Return True if ``url`` points at a file on github.com.

    A typical GitHub file URL looks like
    ``https://github.com/<organisation>/<repository>/(blob|raw)/<branch>/<path>``,
    e.g. https://github.com/FacePerceiver/facer/blob/main/README.md.
    Query strings and fragments (such as ``?raw=true``) are ignored.
    """
    parsed = urlparse(url)
    if parsed.scheme != 'https' or parsed.netloc != 'github.com':
        return False
    parts = parsed.path.lstrip('/').split('/')
    # organisation, repository, blob|raw, branch, then a non-empty file path.
    # Checking the segment (not a substring) avoids matching repos named
    # e.g. "problob".
    return len(parts) >= 5 and parts[2] in ('blob', 'raw') and parts[-1] != ''
def get_github_components(url: str):
    """Split a GitHub file URL into its components.

    Args:
        url: a URL accepted by ``is_github_url``, e.g.
            https://github.com/FacePerceiver/facer/blob/main/README.md.

    Returns:
        tuple: ``(organisation, repository, branch, path)``, where ``path`` is
        the file path inside the repository, re-joined with '/'. Query
        strings and fragments are not included in ``path``.

    Raises:
        AssertionError: if ``url`` is not a GitHub blob/raw file URL.
    """
    assert is_github_url(url), f'not a GitHub file URL: {url}'
    # Split only the URL path, so the scheme/host never leak into the parts.
    segments = urlparse(url).path.lstrip('/').split('/')
    assert len(segments) >= 4, f'not a GitHub file URL: {url}'
    organisation, repository, blob_or_raw, branch, *path = segments
    assert blob_or_raw in {'blob', 'raw'}
    return organisation, repository, branch, '/'.join(path)
def download_url_to_file(url, dst, **kwargs):
    """Download ``url`` into the local file ``dst``.

    GitHub blob/raw file URLs are resolved through ``download_from_github``
    (which handles LFS files and optional authentication). Every other URL is
    fetched directly with ``torch.hub.download_url_to_file``.

    Args:
        url: URL to download.
        dst: destination file path.
        **kwargs: ``username`` and ``access_token`` are forwarded to
            ``download_from_github`` for GitHub URLs; other keys are ignored.
    """
    if is_github_url(url):
        org, rep, branch, path = get_github_components(url)
        download_from_github(dst, org, rep, path, branch,
                             username=kwargs.get('username'),
                             access_token=kwargs.get('access_token'))
    else:
        # Only for non-GitHub URLs: re-fetching a blob URL here would
        # overwrite the file just downloaded with GitHub's HTML page.
        torch.hub.download_url_to_file(url, dst)
def select_data(selection, data):
    """Index every tensor nested inside ``data`` with ``selection``.

    Dicts are rebuilt with the same keys, lists and tuples become lists,
    tensors become ``data[selection]`` and anything else passes through.

    NOTE(review): an identical ``select_data`` is defined earlier in this
    module; this later definition is the one bound once the module has been
    executed. One of the two copies can be removed.
    """
    def _pick(item):
        return select_data(selection, item)

    if isinstance(data, dict):
        return {k: _pick(v) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return list(map(_pick, data))
    if isinstance(data, torch.Tensor):
        return data[selection]
    return data
def _default_model_dir():
    """Return ``<torch hub dir>/checkpoints``, the default download cache."""
    if hasattr(torch.hub, 'get_dir'):
        hub_dir = torch.hub.get_dir()
    else:
        # Older torch without torch.hub.get_dir(): fall back to the legacy
        # ~/.cache/torch/hub location.
        hub_dir = os.path.join(os.path.expanduser('~'), '.cache', 'torch', 'hub')
    return os.path.join(hub_dir, 'checkpoints')


def _download_to_cache(url, model_dir):
    """Download ``url`` into ``model_dir`` unless already cached; return the local path."""
    if model_dir is None:
        model_dir = _default_model_dir()
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        download_url_to_file(url, cached_file)
    return cached_file


def download_jit(url_or_paths: Union[str, List[str]], model_dir=None, map_location=None, jit=True, **kwargs):
    """Load a model from the first URL or local path that works.

    Candidates are tried in order. URLs are downloaded into ``model_dir``
    (default: ``<torch hub dir>/checkpoints``) unless a file with the same
    basename is already there. Anything that is not a URL is treated as a
    local file path. A failing candidate is reported on stderr, and the next
    one is tried.

    Args:
        url_or_paths: a single URL/path, or a list of fallbacks.
        model_dir: cache directory for downloaded files.
        map_location: forwarded to ``torch.jit.load`` / ``torch.load``.
        jit: load with ``torch.jit.load`` if True, else with ``torch.load``.
        **kwargs: extra keyword arguments for the loader.

    Returns:
        The object returned by the loader for the first successful candidate.

    Raises:
        RuntimeError: if every candidate failed; chained to the last error.
    """
    if isinstance(url_or_paths, str):
        url_or_paths = [url_or_paths]

    last_error = None
    for url_or_path in url_or_paths:
        try:
            if validators.url(url_or_path):
                cached_file = _download_to_cache(url_or_path, model_dir)
            else:
                cached_file = url_or_path
            if jit:
                return torch.jit.load(cached_file, map_location=map_location, **kwargs)
            return torch.load(cached_file, map_location=map_location, **kwargs)
        except Exception as e:  # fall through to the next candidate
            sys.stderr.write(f'failed downloading from {url_or_path}: {e}\n')
            last_error = e

    raise RuntimeError('failed to download jit models from all given urls') from last_error