Spaces:
Runtime error
Runtime error
| import hashlib | |
| import os | |
| import requests | |
| from tqdm import tqdm | |
# One row per downloadable checkpoint:
# (name, download URL, local file name, expected MD5 hex digest).
# The three lookup tables below are all derived from this single source so
# the per-checkpoint data cannot drift apart.
_CHECKPOINTS = (
    (
        'vggishish_lpaps',
        'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt',
        'vggishish16.pt',
        '197040c524a07ccacf7715d7080a80bd',
    ),
    (
        'vggishish_mean_std_melspec_10s_22050hz',
        'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt',
        'train_means_stds_melspec_10s_22050hz.txt',
        'f449c6fd0e248936c16f6d22492bb625',
    ),
    (
        'melception',
        'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt',
        'melception-21-05-10T09-28-40.pt',
        'a71a41041e945b457c7d3d814bbcf72d',
    ),
)

# name -> remote URL the file is fetched from
URL_MAP = {name: url for name, url, _fname, _md5 in _CHECKPOINTS}
# name -> file name inside the ``root`` directory given to get_ckpt_path
CKPT_MAP = {name: fname for name, _url, fname, _md5 in _CHECKPOINTS}
# name -> MD5 hex digest the downloaded file must match
MD5_MAP = {name: md5 for name, _url, _fname, md5 in _CHECKPOINTS}
def download(url, local_path, chunk_size=1024):
    """Stream ``url`` to ``local_path`` while showing a tqdm progress bar.

    The payload is first written to ``local_path + ".part"`` and moved onto
    ``local_path`` only after the transfer completed. An interrupted download
    therefore never leaves a truncated file at the final location, which
    ``get_ckpt_path(..., check=False)`` would otherwise accept as-is.

    Parameters
    ----------
    url : str
        Address to fetch.
    local_path : str
        Destination file. Missing parent directories are created.
    chunk_size : int
        Number of bytes requested per ``iter_content`` iteration.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status (4xx/5xx).
    """
    directory = os.path.dirname(local_path)
    if directory:
        # os.makedirs("") raises, so skip it for bare file names.
        os.makedirs(directory, exist_ok=True)
    tmp_path = local_path + ".part"
    with requests.get(url, stream=True) as r:
        # Fail before touching the disk rather than saving an error page.
        r.raise_for_status()
        total_size = int(r.headers.get("content-length", 0))
        # total=None lets tqdm show a plain byte counter when the size is unknown.
        with tqdm(total=total_size or None, unit="B", unit_scale=True) as pbar:
            with open(tmp_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Chunks (notably the last one) may be shorter than chunk_size.
                        pbar.update(len(data))
    # Publish the finished file at its final name in one step.
    os.replace(tmp_path, local_path)
def md5_hash(path, chunk_size=1 << 20):
    """Return the hexadecimal MD5 digest of the file at ``path``.

    The file is hashed incrementally, ``chunk_size`` bytes at a time, so large
    checkpoint files are never loaded into memory in one piece.

    Parameters
    ----------
    path : str
        File to hash.
    chunk_size : int
        Number of bytes read per step (default 1 MiB).

    Returns
    -------
    str
        32-character lowercase hex digest.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter(callable, sentinel) keeps reading until read() returns b"" (EOF).
        for block in iter(lambda: f.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
def get_ckpt_path(name, root, check=False):
    """Return the local path of checkpoint ``name``, downloading it if needed.

    Parameters
    ----------
    name : str
        Key into ``URL_MAP`` / ``CKPT_MAP`` / ``MD5_MAP``.
    root : str
        Directory in which the checkpoint file is stored.
    check : bool
        If ``True``, a file that already exists is re-downloaded when its MD5
        digest does not match ``MD5_MAP[name]``. A freshly downloaded file is
        always verified, regardless of this flag.

    Returns
    -------
    str
        ``os.path.join(root, CKPT_MAP[name])``.

    Raises
    ------
    AssertionError
        If ``name`` is unknown or a freshly downloaded file fails the MD5
        check. These are raised explicitly rather than via ``assert`` so the
        checks are not stripped when Python runs with ``-O``.
    """
    if name not in URL_MAP:
        raise AssertionError(
            "Unknown checkpoint {!r}; expected one of {}".format(name, sorted(URL_MAP))
        )
    path = os.path.join(root, CKPT_MAP[name])
    expected_md5 = MD5_MAP[name]
    if not os.path.exists(path) or (check and md5_hash(path) != expected_md5):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        if md5 != expected_md5:
            raise AssertionError(
                "MD5 mismatch for {} at {}: got {}, expected {}".format(
                    name, path, md5, expected_md5
                )
            )
    return path
class KeyNotFoundError(Exception):
    """Error raised by :func:`retrieve` when a key path cannot be followed.

    Attributes
    ----------
    cause : object
        What went wrong, typically the exception raised while indexing.
    keys : list or None
        All path components that were being resolved.
    visited : list or None
        Path components successfully traversed before the failure.
    """

    def __init__(self, cause, keys=None, visited=None):
        self.cause, self.keys, self.visited = cause, keys, visited
        # Optional context lines come first, each only when provided.
        optional_parts = (
            ("Key not found: {}", keys),
            ("Visited: {}", visited),
        )
        lines = [
            template.format(value)
            for template, value in optional_parts
            if value is not None
        ]
        lines.append("Cause:\n{}".format(cause))
        super().__init__("\n".join(lines))
def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary. The root itself may also be a
        callable returning one; it is expanded, but having no parent it cannot
        be written back.
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found.
    expand : bool
        Whether to expand callable nodes on the path or not.
    pass_success : bool
        If ``True``, return ``(value, success)`` where ``success`` is
        ``False`` when :attr:`default` was substituted for a missing key.

    Returns
    -------
    The desired value or if :attr:`default` is not ``None`` and the
    :attr:`key` is not found returns ``default``. Wrapped in a
    ``(value, success)`` tuple when :attr:`pass_success` is ``True``.

    Raises
    ------
    KeyNotFoundError
        If ``key`` cannot be resolved (missing dict key, out-of-range or
        non-integer list index, descending into a non-indexable leaf, or a
        callable node met with ``expand=False``) and :attr:`default` is
        ``None``.
    """
    keys = key.split(splitval)
    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                # The root node has no parent to store its expansion in.
                if parent is not None:
                    parent[last_key] = list_or_dict
            last_key = key
            parent = list_or_dict
            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError, TypeError) as e:
                # TypeError: the current node is a leaf (int, None, ...) that
                # cannot be indexed, i.e. the path goes deeper than the data.
                raise KeyNotFoundError(e, keys=keys, visited=visited) from e
            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            # keys is never empty (str.split returns at least one item), so
            # parent was set inside the loop.
            parent[last_key] = list_or_dict
    except KeyNotFoundError:
        if default is None:
            raise
        list_or_dict = default
        success = False
    if pass_success:
        return list_or_dict, success
    return list_or_dict