# NOTE: removed non-Python artifacts from a scraped file-viewer page
# (page header, "Runtime error" banners, file size, commit hash, and the
# line-number gutter) — they were not part of the module and broke parsing.
"""
Adapted from: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/download.py
"""
import hashlib
import os
from functools import lru_cache
from typing import Dict, Optional
import requests
import torch
import yaml
from filelock import FileLock
from tqdm.auto import tqdm
# Download URLs for the pretrained model checkpoints, keyed by model name.
MODEL_PATHS = {
    "transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt",
    "decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt",
    "text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt",
    "image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt",
}
# Download URLs for the YAML model/diffusion configs, keyed by config name.
# Note "diffusion" has a config but no checkpoint in MODEL_PATHS.
CONFIG_PATHS = {
    "transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml",
    "decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml",
    "text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml",
    "image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml",
    "diffusion": "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml",
}
# Expected SHA-256 hex digests for every downloadable file above, keyed by URL.
# Used by fetch_file_cached/check_hash to detect corrupt or tampered downloads.
URL_HASHES = {
    "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt": "af02a0b85a8abdfb3919584b63c540ba175f6ad4790f574a7fef4617e5acdc3b",
    "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt": "d7e7ebbfe3780499ae89b2da5e7c1354012dba5a6abfe295bed42f25c3be1b98",
    "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt": "e6b4fa599a7b3c3b16c222d5f5fe56f9db9289ff0b6575fbe5c11bc97106aad4",
    "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt": "cb8072c64bbbcf6910488814d212227de5db291780d4ea99c6152f9346cf12aa",
    "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml": "ffe1bcb405104a37d9408391182ab118a4ef313c391e07689684f1f62071605e",
    "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml": "e6d373649f8e24d85925f4674b9ac41c57aba5f60e42cde6d10f87381326365c",
    "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml": "f290beeea3d3e9ff15db01bde5382b6e549e463060c0744f89c049505be246c1",
    "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml": "4e0745605a533c543c72add803a78d233e2a6401e0abfa0cad58afb4d74ad0b0",
    "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml": "efcb2cd7ee545b2d27223979d41857802448143990572a42645cd09c2942ed57",
}
@lru_cache()
def default_cache_dir() -> str:
    """Return (and memoize) the default model-cache directory.

    This is a ``shap_e_model_cache`` folder under the current working
    directory at the time of the first call (cached thereafter).
    """
    working_dir = os.path.abspath(os.getcwd())
    return os.path.join(working_dir, "shap_e_model_cache")
def fetch_file_cached(
    url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096
) -> str:
    """
    Download the file at the given URL into a local file and return the path.

    If cache_dir is specified, it will be used to download the files.
    Otherwise, default_cache_dir() is used.

    :param url: source URL; must be a key of URL_HASHES.
    :param progress: if True, show a tqdm progress bar while downloading.
    :param cache_dir: directory for the cached file (created if missing).
    :param chunk_size: streaming read size in bytes.
    :return: path to the verified local copy of the file.
    :raises RuntimeError: (from check_hash) if the file's SHA-256 digest
        does not match the expected value in URL_HASHES.
    """
    expected_hash = URL_HASHES[url]
    if cache_dir is None:
        cache_dir = default_cache_dir()
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, url.split("/")[-1])
    # Fast path: file already downloaded; verify and return without locking.
    if os.path.exists(local_path):
        check_hash(local_path, expected_hash)
        return local_path
    # Hold the lock for the entire download. The original code started the
    # HTTP request and checked for the file before acquiring the lock, so two
    # processes could both download and clobber each other's .tmp file.
    with FileLock(local_path + ".lock"):
        # Another process may have finished the download while we waited.
        if os.path.exists(local_path):
            check_hash(local_path, expected_hash)
            return local_path
        response = requests.get(url, stream=True)
        size = int(response.headers.get("content-length", "0"))
        pbar = tqdm(total=size, unit="iB", unit_scale=True) if progress else None
        tmp_path = local_path + ".tmp"
        try:
            # Write to a .tmp file and rename at the end so a partially
            # downloaded file is never mistaken for a complete one.
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size):
                    if pbar is not None:
                        pbar.update(len(chunk))
                    f.write(chunk)
            os.rename(tmp_path, local_path)
        finally:
            # Close the bar even if the download raises (original leaked it).
            if pbar is not None:
                pbar.close()
        check_hash(local_path, expected_hash)
        return local_path
def check_hash(path: str, expected_hash: str):
    """Verify that the file at ``path`` has the given SHA-256 hex digest.

    :param path: file whose contents are hashed via hash_file.
    :param expected_hash: expected lowercase SHA-256 hex digest.
    :raises RuntimeError: if the digest does not match, with a hint to
        delete the (possibly corrupt) file and retry.
    """
    actual_hash = hash_file(path)
    if actual_hash == expected_hash:
        return
    raise RuntimeError(
        f"The file {path} should have hash {expected_hash} but has {actual_hash}. "
        "Try deleting it and running this call again."
    )
def hash_file(path: str, chunk_size: int = 4096) -> str:
    """Compute the SHA-256 hex digest of a file, streaming it in chunks.

    :param path: file to hash (opened in binary mode).
    :param chunk_size: read size in bytes; parameterized (default 4096,
        the previously hard-coded value) for consistency with the
        chunk_size arguments used elsewhere in this module.
    :return: lowercase hexadecimal digest string.
    """
    sha256_hash = hashlib.sha256()
    with open(path, "rb") as file:
        # iter(callable, sentinel) replaces the manual read/len/break loop;
        # read() returns b"" exactly at EOF.
        for data in iter(lambda: file.read(chunk_size), b""):
            sha256_hash.update(data)
    return sha256_hash.hexdigest()
def load_config(
    config_name: str,
    progress: bool = False,
    cache_dir: Optional[str] = None,
    chunk_size: int = 4096,
):
    """Fetch (with local caching) and parse the YAML config for a model.

    :param config_name: key of CONFIG_PATHS.
    :param progress: forwarded to fetch_file_cached.
    :param cache_dir: forwarded to fetch_file_cached.
    :param chunk_size: forwarded to fetch_file_cached.
    :return: the parsed YAML document (via yaml.safe_load).
    :raises ValueError: if config_name is not a known config.
    """
    if config_name not in CONFIG_PATHS:
        raise ValueError(
            f"Unknown config name {config_name}. Known names are: {CONFIG_PATHS.keys()}."
        )
    local_path = fetch_file_cached(
        CONFIG_PATHS[config_name],
        progress=progress,
        cache_dir=cache_dir,
        chunk_size=chunk_size,
    )
    with open(local_path, "r") as config_file:
        return yaml.safe_load(config_file)
def load_checkpoint(
    checkpoint_name: str,
    device: torch.device,
    progress: bool = True,
    cache_dir: Optional[str] = None,
    chunk_size: int = 4096,
) -> Dict[str, torch.Tensor]:
    """Download (with local caching) and load the named model checkpoint.

    :param checkpoint_name: key of MODEL_PATHS.
    :param device: passed to torch.load as map_location.
    :param progress: forwarded to fetch_file_cached.
    :param cache_dir: forwarded to fetch_file_cached.
    :param chunk_size: forwarded to fetch_file_cached.
    :return: the loaded checkpoint state dict.
    :raises ValueError: if checkpoint_name is not a known checkpoint.
    """
    if checkpoint_name not in MODEL_PATHS:
        raise ValueError(
            f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}."
        )
    # (stray debug print of checkpoint_name removed)
    path = fetch_file_cached(
        MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
    )
    return torch.load(path, map_location=device)
def load_model(
    model_name: str,
    device: torch.device,
    **kwargs,
) -> torch.nn.Module:
    """Build the named model from its config, load its checkpoint weights,
    and return it in eval mode.

    The return annotation was previously ``Dict[str, torch.Tensor]``, which
    was wrong: this returns the constructed model, not a state dict.

    :param model_name: key of both CONFIG_PATHS and MODEL_PATHS.
    :param device: device for model construction and checkpoint loading.
    :param kwargs: forwarded to load_config and load_checkpoint
        (progress, cache_dir, chunk_size).
    :return: the model with checkpoint weights loaded, in eval mode.
    """
    # Function-scope import kept from the original; presumably avoids an
    # import cycle with the configs module — TODO confirm.
    from .configs import model_from_config

    model = model_from_config(load_config(model_name, **kwargs), device=device)
    model.load_state_dict(load_checkpoint(model_name, device=device, **kwargs))
    model.eval()
    return model
|