import errno
import json
import math
import os
import sys
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse

import requests
import torch
import validators


def hwc2bchw(images: torch.Tensor) -> torch.Tensor:
    """Turn a single ``h x w x c`` image into a ``1 x c x h x w`` batch."""
    return images.unsqueeze(0).permute(0, 3, 1, 2)


def bchw2hwc(images: torch.Tensor, nrows: Optional[int] = None, border: int = 2,
             background_value: float = 0) -> torch.Tensor:
    """ make a grid image from an image batch.

    Args:
        images (torch.Tensor): input image batch, ``n x c x h x w``.
        nrows: rows of grid. Defaults to ``max(int(sqrt(n)), 1)``.
        border: border size in pixel, inserted only between neighbouring tiles.
        background_value: color value of background (and of the borders).

    Returns:
        torch.Tensor: one ``H x W x c`` grid image with the same device and
        dtype as ``images``; images are placed row by row.
    """
    assert images.ndim == 4  # n x c x h x w
    images = images.permute(0, 2, 3, 1)  # n x h x w x c
    n, h, w, c = images.shape
    if nrows is None:
        nrows = max(int(math.sqrt(n)), 1)
    # ceil(n / nrows): enough columns for every image to get a cell.
    ncols = (n + nrows - 1) // nrows
    result = torch.full([(h + border) * nrows - border,
                         (w + border) * ncols - border, c],
                        background_value, device=images.device, dtype=images.dtype)
    for i, single_image in enumerate(images):
        row, col = divmod(i, ncols)
        yy = (h + border) * row
        xx = (w + border) * col
        result[yy:(yy + h), xx:(xx + w), :] = single_image
    return result


def bchw2bhwc(images: torch.Tensor) -> torch.Tensor:
    """Permute a ``b x c x h x w`` tensor to ``b x h x w x c``."""
    return images.permute(0, 2, 3, 1)


def bhwc2bchw(images: torch.Tensor) -> torch.Tensor:
    """Permute a ``b x h x w x c`` tensor to ``b x c x h x w``."""
    return images.permute(0, 3, 1, 2)


def bhwc2hwc(images: torch.Tensor, *kargs, **kwargs) -> torch.Tensor:
    """Grid a ``b x h x w x c`` batch into one image; extra args go to ``bchw2hwc``."""
    return bchw2hwc(bhwc2bchw(images), *kargs, **kwargs)


def select_data(selection, data):
    """Index every tensor inside a (possibly nested) container with ``selection``.

    Dicts keep their keys; lists and tuples are both returned as lists; any
    value that is not a dict, list, tuple or tensor is returned unchanged.
    """
    if isinstance(data, dict):
        return {name: select_data(selection, val) for name, val in data.items()}
    elif isinstance(data, (list, tuple)):
        return [select_data(selection, val) for val in data]
    elif isinstance(data, torch.Tensor):
        return data[selection]
    return data


def download_from_github(to_path, organisation, repository, file_path, branch='main',
                         username=None, access_token=None):
    """ download files (including LFS files) from github.

    For example, in order to download
    https://github.com/FacePerceiver/facer/blob/main/README.md, call with
    ```
    download_from_github(
        to_path='README.md', organisation='FacePerceiver',
        repository='facer', file_path='README.md', branch='main')
    ```

    The file's metadata is requested from the GitHub contents API and the file
    is then fetched from the ``download_url`` that the API reports.

    Args:
        to_path: local destination path.
        organisation: GitHub user or organisation owning the repository.
        repository: repository name.
        file_path: path of the file inside the repository.
        branch: git ref passed as ``?ref=`` to the contents API.
        username: optional user for HTTP basic auth; requires ``access_token``.
        access_token: token paired with ``username``.

    Raises:
        ValueError: if ``username`` is given without ``access_token``.
        requests.HTTPError: if the contents API answers with an error status.
        RuntimeError: if the API response carries no ``download_url`` (for
            instance when it returns a listing instead of a single file).
    """
    if username is not None:
        if access_token is None:
            raise ValueError('access_token is required when username is given')
        auth = (username, access_token)
    else:
        auth = None
    r = requests.get(
        f'https://api.github.com/repos/{organisation}/{repository}/contents/{file_path}?ref={branch}',
        auth=auth)
    # Surface 404s / rate limiting directly instead of a confusing KeyError below.
    r.raise_for_status()
    data = json.loads(r.content)
    download_url = data.get('download_url') if isinstance(data, dict) else None
    if not download_url:
        raise RuntimeError(
            f'no download_url for {organisation}/{repository}/{file_path}@{branch}')
    torch.hub.download_url_to_file(download_url, to_path)


def is_github_url(url: str) -> bool:
    """ A typical github url should be like
    https://github.com/FacePerceiver/facer/blob/main/facer/util.py or
    https://github.com/FacePerceiver/facer/raw/main/facer/util.py.

    Returns True only when the url starts with ``https://github.com/`` and its
    path has the shape ``<org>/<repo>/(blob|raw)/<branch>/<file path...>``.
    """
    if not url.startswith('https://github.com/'):
        return False
    segments = urlparse(url).path.lstrip('/').split('/')
    # Check the structural position rather than searching for 'blob'/'raw'
    # anywhere in the url (which matched e.g. release assets named 'raw_*').
    return len(segments) >= 5 and segments[2] in ('blob', 'raw')


def get_github_components(url: str):
    """Split a GitHub blob/raw file url into ``(organisation, repository, branch, path)``.

    The branch is taken to be the single path segment after ``blob``/``raw``,
    so branch names containing ``/`` are not supported. Query strings and
    fragments (e.g. ``?raw=true``) are ignored.

    Raises:
        ValueError: if ``url`` is not accepted by :func:`is_github_url`.
    """
    if not is_github_url(url):
        raise ValueError(f'not a GitHub blob/raw file url: {url}')
    organisation, repository, _blob_or_raw, branch, *path = \
        urlparse(url).path.lstrip('/').split('/')
    return organisation, repository, branch, '/'.join(path)


def download_url_to_file(url, dst, **kwargs):
    """Download ``url`` to the local path ``dst``.

    GitHub blob/raw file urls are fetched through :func:`download_from_github`,
    optionally authenticated with the ``username``/``access_token`` kwargs;
    every other url is fetched with ``torch.hub.download_url_to_file``.
    """
    if is_github_url(url):
        org, rep, branch, path = get_github_components(url)
        download_from_github(dst, org, rep, path, branch,
                             kwargs.get('username', None), kwargs.get('access_token', None))
    else:
        torch.hub.download_url_to_file(url, dst)


def _default_checkpoint_dir() -> str:
    """Return ``<torch hub dir>/checkpoints``, the default download cache."""
    if hasattr(torch.hub, 'get_dir'):
        hub_dir = torch.hub.get_dir()
    else:
        # torch versions without torch.hub.get_dir(): use the conventional location.
        hub_dir = os.path.join(os.path.expanduser('~'), '.cache', 'torch', 'hub')
    return os.path.join(hub_dir, 'checkpoints')


def _download_to_cache(url: str, model_dir: str) -> str:
    """Download ``url`` into ``model_dir`` unless already present; return the local path."""
    # Also creates a caller-supplied model_dir, which previously had to pre-exist.
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write(
            'Downloading: "{}" to {}\n'.format(url, cached_file))
        download_url_to_file(url, cached_file)
    return cached_file


def download_jit(url_or_paths: Union[str, List[str]], model_dir=None, map_location=None,
                 jit=True, **kwargs):
    """Load a model from the first of several candidate urls / local paths that works.

    Candidates are tried in order. A candidate that ``validators.url`` accepts
    is downloaded into ``model_dir`` (default: torch hub's ``checkpoints``
    directory) unless a file of the same name is already cached there; any
    other candidate is used as a local path.

    Args:
        url_or_paths: a single url/path or a list of fallbacks.
        model_dir: download cache directory; created if missing.
        map_location: forwarded to the loader.
        jit: load with ``torch.jit.load`` when True, else ``torch.load``.
        **kwargs: forwarded to the loader.

    Returns:
        Whatever the chosen loader returns for the first successful candidate.

    Raises:
        Exception: the last candidate's exception when every candidate fails
            (each failure is also reported on stderr).
        RuntimeError: when ``url_or_paths`` is empty.
    """
    if isinstance(url_or_paths, str):
        url_or_paths = [url_or_paths]

    last_error: Optional[Exception] = None
    for url_or_path in url_or_paths:
        try:
            if validators.url(url_or_path):
                if model_dir is None:
                    model_dir = _default_checkpoint_dir()
                cached_file = _download_to_cache(url_or_path, model_dir)
            else:
                cached_file = url_or_path
            if jit:
                return torch.jit.load(cached_file, map_location=map_location, **kwargs)
            else:
                return torch.load(cached_file, map_location=map_location, **kwargs)
        except Exception as e:
            # Previously re-raised here, so the remaining fallbacks were never tried.
            sys.stderr.write(f'failed downloading from {url_or_path}\n')
            last_error = e

    if last_error is not None:
        # Keep the underlying exception type so single-candidate callers see
        # the same error as before.
        raise last_error
    raise RuntimeError('failed to download jit models from all given urls')