import hashlib
from io import BytesIO
from typing import Optional

import safetensors.torch
import torch


def model_hash(filename):
    """Old model hash used by stable-diffusion-webui"""
    try:
        with open(filename, "rb") as file:
            m = hashlib.sha256()

            # Hash a 64 KiB window starting at offset 1 MiB; the first 8 hex
            # digits of the digest are the short legacy hash.
            file.seek(0x100000)
            m.update(file.read(0x10000))
            return m.hexdigest()[0:8]
    except FileNotFoundError:
        return "NOFILE"
    except IsADirectoryError:  # raised on Linux when filename is a directory
        return "IsADirectory"
    except PermissionError:  # Windows raises PermissionError for directories
        return "IsADirectory"


def calculate_sha256(filename):
    """New model hash used by stable-diffusion-webui"""
    try:
        hash_sha256 = hashlib.sha256()
        blksize = 1024 * 1024

        # Hash the whole file in 1 MiB chunks so large checkpoints are never
        # loaded into memory at once.
        with open(filename, "rb") as f:
            for chunk in iter(lambda: f.read(blksize), b""):
                hash_sha256.update(chunk)

        return hash_sha256.hexdigest()
    except FileNotFoundError:
        return "NOFILE"
    except IsADirectoryError:  # raised on Linux when filename is a directory
        return "IsADirectory"
    except PermissionError:  # Windows raises PermissionError for directories
        return "IsADirectory"


def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
    m = hashlib.sha256()

    # Same 64 KiB window at offset 1 MiB as model_hash(), but computed on an
    # already-open binary stream instead of a filename.
    b.seek(0x100000)
    m.update(b.read(0x10000))
    return m.hexdigest()[0:8]


def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    # A .safetensors file begins with an 8-byte little-endian integer giving
    # the size of the JSON header, followed by the header itself. Skip both
    # and hash only the tensor data, so metadata edits don't change the hash.
    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")

    offset = n + 8
    b.seek(offset)
    for chunk in iter(lambda: b.read(blksize), b""):
        hash_sha256.update(chunk)

    return hash_sha256.hexdigest()
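

# Hedged usage sketch (not part of the original module): both addnet hashes
# can be computed from a single open .safetensors file, since each function
# seeks before reading. "model.safetensors" is a placeholder path.
#
#   with open("model.safetensors", "rb") as f:
#       new_hash = addnet_hash_safetensors(f)  # hashes tensor data only
#       legacy_hash = addnet_hash_legacy(f)    # 64 KiB window at offset 1 MiB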


def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later."""

    # Because writing user metadata to the file can change the result of
    # sd_models.model_hash(), only retain the training metadata ("ss_" keys)
    # for hashing purposes, since it is meant to be immutable.
    metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    # Serialize to an in-memory buffer; model_bytes avoids shadowing the
    # built-in bytes type.
    model_bytes = safetensors.torch.save(tensors, metadata)
    b = BytesIO(model_bytes)

    model_hash = addnet_hash_safetensors(b)
    legacy_hash = addnet_hash_legacy(b)
    return model_hash, legacy_hash
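

# Hedged usage sketch (not part of the original module): the tensors and
# metadata below are placeholders. Non-"ss_" keys are dropped before hashing,
# so user-editable metadata does not affect the resulting hashes.
#
#   tensors = {"weight": torch.zeros(4, 4)}
#   metadata = {"ss_network_dim": "4", "note": "ignored by the hash"}
#   new_hash, legacy_hash = precalculate_safetensors_hashes(tensors, metadata)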


def dtype_to_str(dtype: torch.dtype) -> str:
    # Get the short name of the dtype, e.g. torch.float16 -> "float16"
    dtype_name = str(dtype).split(".")[-1]
    return dtype_name


def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> Optional[torch.dtype]:
    """
    Convert a string to a torch.dtype.

    Args:
        s: string representation of the dtype
        default_dtype: default dtype to return if s is None

    Returns:
        torch.dtype: the corresponding torch.dtype, or default_dtype if s is None

    Raises:
        ValueError: if the dtype is not supported

    Examples:
        >>> str_to_dtype("float32")
        torch.float32
        >>> str_to_dtype("fp32")
        torch.float32
        >>> str_to_dtype("float16")
        torch.float16
        >>> str_to_dtype("fp16")
        torch.float16
        >>> str_to_dtype("bfloat16")
        torch.bfloat16
        >>> str_to_dtype("bf16")
        torch.bfloat16
        >>> str_to_dtype("fp8")
        torch.float8_e4m3fn
        >>> str_to_dtype("fp8_e4m3fn")
        torch.float8_e4m3fn
        >>> str_to_dtype("fp8_e4m3fnuz")
        torch.float8_e4m3fnuz
        >>> str_to_dtype("fp8_e5m2")
        torch.float8_e5m2
        >>> str_to_dtype("fp8_e5m2fnuz")
        torch.float8_e5m2fnuz
    """
    if s is None:
        return default_dtype
    if s in ["bf16", "bfloat16"]:
        return torch.bfloat16
    elif s in ["fp16", "float16"]:
        return torch.float16
    elif s in ["fp32", "float32", "float"]:
        return torch.float32
    elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
        return torch.float8_e4m3fn
    elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
        return torch.float8_e4m3fnuz
    elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
        return torch.float8_e5m2
    elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
        return torch.float8_e5m2fnuz
    elif s in ["fp8", "float8"]:
        return torch.float8_e4m3fn  # plain "fp8" defaults to e4m3fn
    else:
        raise ValueError(f"Unsupported dtype: {s}")
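

# Hedged usage sketch (not part of the original module): default_dtype is used
# only when s is None; an unrecognized string still raises ValueError.
#
#   dt = str_to_dtype(None, torch.float32)  # -> torch.float32
#   dt = str_to_dtype("bf16")               # -> torch.bfloat16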