File size: 5,915 Bytes
19c4ddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
Adapted from: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/download.py
"""

import hashlib
import os
from functools import lru_cache
from typing import Dict, Optional

import requests
import torch
import yaml
from filelock import FileLock
from tqdm.auto import tqdm

# Download URL of each model checkpoint, keyed by the checkpoint name that
# load_checkpoint() accepts.
MODEL_PATHS = {
    "transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt",
    "decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt",
    "text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt",
    "image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt",
}

# Download URL of each YAML config, keyed by the config name that
# load_config() accepts. Note that "diffusion" has a config here but no
# corresponding checkpoint entry in MODEL_PATHS.
CONFIG_PATHS = {
    "transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml",
    "decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml",
    "text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml",
    "image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml",
    "diffusion": "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml",
}

# Expected SHA-256 hex digest of the file behind each URL. fetch_file_cached()
# indexes this dict directly with the URL, so every URL listed in MODEL_PATHS
# or CONFIG_PATHS needs an entry here or the fetch raises KeyError.
URL_HASHES = {
    "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt": "af02a0b85a8abdfb3919584b63c540ba175f6ad4790f574a7fef4617e5acdc3b",
    "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt": "d7e7ebbfe3780499ae89b2da5e7c1354012dba5a6abfe295bed42f25c3be1b98",
    "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt": "e6b4fa599a7b3c3b16c222d5f5fe56f9db9289ff0b6575fbe5c11bc97106aad4",
    "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt": "cb8072c64bbbcf6910488814d212227de5db291780d4ea99c6152f9346cf12aa",
    "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml": "ffe1bcb405104a37d9408391182ab118a4ef313c391e07689684f1f62071605e",
    "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml": "e6d373649f8e24d85925f4674b9ac41c57aba5f60e42cde6d10f87381326365c",
    "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml": "f290beeea3d3e9ff15db01bde5382b6e549e463060c0744f89c049505be246c1",
    "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml": "4e0745605a533c543c72add803a78d233e2a6401e0abfa0cad58afb4d74ad0b0",
    "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml": "efcb2cd7ee545b2d27223979d41857802448143990572a42645cd09c2942ed57",
}


@lru_cache()
def default_cache_dir() -> str:
    """
    Return the default cache directory: a "shap_e_model_cache" folder under
    the current working directory, as an absolute path.

    The result is memoized, so the working directory at the time of the first
    call is the one used for the rest of the process.
    """
    working_dir = os.path.abspath(os.getcwd())
    return os.path.join(working_dir, "shap_e_model_cache")


def fetch_file_cached(
    url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096
) -> str:
    """
    Download the file at the given URL into a local file and return the path.
    If cache_dir is specified, it will be used to download the files.
    Otherwise, default_cache_dir() is used.

    The download is streamed into a temporary "<name>.tmp" file and is only
    moved to its final location once its SHA-256 digest matches the entry in
    URL_HASHES, so a failed or corrupted download never ends up in the cache.
    A "<name>.lock" file lock serializes concurrent downloads of the same file.

    :param url: the URL to fetch; must be a key of URL_HASHES.
    :param progress: if True, show a tqdm progress bar while downloading.
    :param cache_dir: directory holding cached files (created if missing).
    :param chunk_size: chunk size passed to the streamed response iterator.
    :return: the local path of the hash-verified file.
    :raises KeyError: if url has no entry in URL_HASHES.
    :raises RuntimeError: if a cached or freshly downloaded file has the
                          wrong hash.
    :raises requests.HTTPError: if the server answers with an error status.
    """
    expected_hash = URL_HASHES[url]

    if cache_dir is None:
        cache_dir = default_cache_dir()
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, url.split("/")[-1])
    if os.path.exists(local_path):
        check_hash(local_path, expected_hash)
        return local_path

    with FileLock(local_path + ".lock"):
        # Another process may have completed the download while we were
        # waiting for the lock; reuse its result instead of fetching again.
        if os.path.exists(local_path):
            check_hash(local_path, expected_hash)
            return local_path

        tmp_path = local_path + ".tmp"
        try:
            # The request is issued under the lock and closed by the context
            # manager even when the transfer fails midway.
            with requests.get(url, stream=True) as response:
                # Fail early rather than saving an HTTP error page to disk.
                response.raise_for_status()
                size = int(response.headers.get("content-length", "0"))
                pbar = tqdm(total=size, unit="iB", unit_scale=True) if progress else None
                try:
                    with open(tmp_path, "wb") as f:
                        for chunk in response.iter_content(chunk_size):
                            if pbar is not None:
                                pbar.update(len(chunk))
                            f.write(chunk)
                finally:
                    if pbar is not None:
                        pbar.close()

            # Verify before publishing: a bad file must never reach local_path,
            # otherwise every later call would find it and fail again.
            actual_hash = hash_file(tmp_path)
            if actual_hash != expected_hash:
                raise RuntimeError(
                    f"The file downloaded from {url} should have hash {expected_hash} "
                    f"but has {actual_hash}. The download was discarded; "
                    "try running this call again."
                )
            # os.replace overwrites atomically where the platform allows it.
            os.replace(tmp_path, local_path)
        finally:
            # After a successful replace the temporary file no longer exists;
            # on any failure this removes the partial or invalid download.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        return local_path


def check_hash(path: str, expected_hash: str):
    """
    Verify that the file at path hashes to expected_hash.

    :param path: path of the file to verify.
    :param expected_hash: the hex digest the file is expected to have, as
                          produced by hash_file().
    :raises RuntimeError: if the digests differ.
    """
    digest = hash_file(path)
    if digest == expected_hash:
        return
    message = (
        f"The file {path} should have hash {expected_hash} but has {digest}. "
        "Try deleting it and running this call again."
    )
    raise RuntimeError(message)


def hash_file(path: str) -> str:
    """
    Compute the SHA-256 digest of a file.

    The file is read in 4096-byte blocks so it never has to be held in memory
    as a whole.

    :param path: path of the file to hash.
    :return: the digest as a lowercase hexadecimal string.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # iter() with a sentinel stops at the first empty read, i.e. at EOF.
        for block in iter(lambda: stream.read(4096), b""):
            digest.update(block)
    return digest.hexdigest()


def load_config(
    config_name: str,
    progress: bool = False,
    cache_dir: Optional[str] = None,
    chunk_size: int = 4096,
):
    """
    Fetch (or reuse from the cache) the named YAML config and parse it.

    :param config_name: a key of CONFIG_PATHS.
    :param progress: forwarded to fetch_file_cached().
    :param cache_dir: forwarded to fetch_file_cached().
    :param chunk_size: forwarded to fetch_file_cached().
    :return: whatever yaml.safe_load() produces for the config file.
    :raises ValueError: if config_name is not a key of CONFIG_PATHS.
    """
    if config_name not in CONFIG_PATHS:
        raise ValueError(
            f"Unknown config name {config_name}. Known names are: {CONFIG_PATHS.keys()}."
        )

    config_url = CONFIG_PATHS[config_name]
    local_file = fetch_file_cached(
        config_url,
        progress=progress,
        cache_dir=cache_dir,
        chunk_size=chunk_size,
    )
    with open(local_file, "r") as stream:
        parsed = yaml.safe_load(stream)
    return parsed


def load_checkpoint(
    checkpoint_name: str,
    device: torch.device,
    progress: bool = True,
    cache_dir: Optional[str] = None,
    chunk_size: int = 4096,
) -> Dict[str, torch.Tensor]:
    """
    Fetch (or reuse from the cache) the named checkpoint and load it.

    :param checkpoint_name: a key of MODEL_PATHS.
    :param device: passed to torch.load() as map_location.
    :param progress: forwarded to fetch_file_cached().
    :param cache_dir: forwarded to fetch_file_cached().
    :param chunk_size: forwarded to fetch_file_cached().
    :return: the object stored in the checkpoint file, as returned by
             torch.load().
    :raises ValueError: if checkpoint_name is not a key of MODEL_PATHS.
    """
    if checkpoint_name not in MODEL_PATHS:
        raise ValueError(
            f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}."
        )
    path = fetch_file_cached(
        MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
    )
    # NOTE(security): torch.load() may unpickle arbitrary objects. This is
    # only acceptable here because fetch_file_cached() has checked the file
    # against its pinned hash in URL_HASHES; do not reuse this call for files
    # that have not been verified that way.
    return torch.load(path, map_location=device)


def load_model(
    model_name: str,
    device: torch.device,
    **kwargs,
) -> torch.nn.Module:
    """
    Build the named model from its config, load its checkpoint weights and
    switch it to eval mode.

    :param model_name: name used for both load_config() and
                       load_checkpoint(), so it must be known to both.
    :param device: passed to model_from_config() and load_checkpoint().
    :param kwargs: forwarded unchanged to both load_config() and
                   load_checkpoint() (e.g. progress, cache_dir, chunk_size).
    :return: the model object built by model_from_config(); presumably a
             torch.nn.Module, since it is used through load_state_dict()
             and eval() -- confirm against .configs.
    """
    # Imported locally rather than at module level; presumably to avoid a
    # circular import with .configs -- TODO confirm.
    from .configs import model_from_config

    model = model_from_config(load_config(model_name, **kwargs), device=device)
    model.load_state_dict(load_checkpoint(model_name, device=device, **kwargs))
    model.eval()
    return model