import datetime
import hashlib
from io import BytesIO
import os
from typing import List, Optional, Tuple, Union
import safetensors
import safetensors.torch  # "import safetensors" alone does not load the torch submodule used below
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

r"""
|
|
# Metadata Example
|
|
metadata = {
|
|
# === Must ===
|
|
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
|
|
"modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
|
|
"modelspec.implementation": "sgm",
|
|
"modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
|
|
# === Should ===
|
|
"modelspec.author": "Example Corp", # Your name or company name
|
|
"modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
|
|
"modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
|
|
# === Can ===
|
|
"modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
|
|
"modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
|
|
}
|
|
"""

BASE_METADATA = {
    # === Must ===
    "modelspec.sai_model_spec": "1.0.0",
    "modelspec.architecture": None,
    "modelspec.implementation": None,
    "modelspec.title": None,
    "modelspec.resolution": None,
    # === Should ===
    "modelspec.description": None,
    "modelspec.author": None,
    "modelspec.date": None,
    # === Can ===
    "modelspec.license": None,
    "modelspec.tags": None,
    "modelspec.merged_from": None,
    "modelspec.prediction_type": None,
    "modelspec.timestep_range": None,
    "modelspec.encoder_layer": None,
}

MODELSPEC_TITLE = "modelspec.title"

ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"

ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"

IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_DIFFUSERS = "diffusers"

PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"

def load_bytes_in_safetensors(tensors):
    # serialize the tensors to the safetensors format in memory
    tensor_bytes = safetensors.torch.save(tensors)  # renamed from "bytes" to avoid shadowing the builtin
    b = BytesIO(tensor_bytes)

    # the first 8 bytes are the little-endian length of the JSON header
    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")

    # skip past the header and return only the raw tensor data
    offset = n + 8
    b.seek(offset)
    return b.read()

def precalculate_safetensors_hashes(state_dict):
    # hash the raw data of each tensor (safetensors header stripped), in state_dict order,
    # so the result does not depend on tensor names or header layout
    hash_sha256 = hashlib.sha256()
    for tensor in state_dict.values():
        single_tensor_sd = {"tensor": tensor}
        bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
        hash_sha256.update(bytes_for_tensor)

    return f"0x{hash_sha256.hexdigest()}"

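
# A minimal usage sketch, kept as a docstring to match this file's other examples
# (the state dict here is hypothetical, not part of this module). The "0x"-prefixed
# digest is the format the spec's "modelspec.hash_sha256" field expects.
r"""
import torch

sd = {"weight": torch.zeros(4, 4)}
print(precalculate_safetensors_hashes(sd))  # -> "0x..." (64 hex digits)
"""
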
def update_hash_sha256(metadata: dict, state_dict: dict):
    # intended to fill "modelspec.hash_sha256" from the state dict; not implemented yet
    raise NotImplementedError

def build_metadata(
    state_dict: Optional[dict],
    v2: bool,
    v_parameterization: bool,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    timestamp: float,
    title: Optional[str] = None,
    reso: Optional[Union[int, Tuple[int, int]]] = None,
    is_stable_diffusion_ckpt: Optional[bool] = None,
    author: Optional[str] = None,
    description: Optional[str] = None,
    license: Optional[str] = None,
    tags: Optional[str] = None,
    merged_from: Optional[str] = None,
    timesteps: Optional[Tuple[int, int]] = None,
    clip_skip: Optional[int] = None,
):
    # state_dict is accepted for future hash support but is not used yet
    metadata = {}
    metadata.update(BASE_METADATA)

    # select the architecture ID, then append the adapter suffix if needed
    if sdxl:
        arch = ARCH_SD_XL_V1_BASE
    elif v2:
        if v_parameterization:
            arch = ARCH_SD_V2_768_V
        else:
            arch = ARCH_SD_V2_512
    else:
        arch = ARCH_SD_V1

    if lora:
        arch += f"/{ADAPTER_LORA}"
    elif textual_inversion:
        arch += f"/{ADAPTER_TEXTUAL_INVERSION}"

    metadata["modelspec.architecture"] = arch

    if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
        is_stable_diffusion_ckpt = True  # default to a Stable Diffusion checkpoint if not LoRA or Textual Inversion

    if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
        # Stable Diffusion checkpoints, Textual Inversion, and SDXL LoRA use the Stability AI implementation
        impl = IMPL_STABILITY_AI
    else:
        # v1/v2 LoRA uses the Diffusers-based implementation
        impl = IMPL_DIFFUSERS
    metadata["modelspec.implementation"] = impl

    if title is None:
        # default titles get a timestamp suffix so they stay unique
        if lora:
            title = "LoRA"
        elif textual_inversion:
            title = "TextualInversion"
        else:
            title = "Checkpoint"
        title += f"@{timestamp}"

    metadata[MODELSPEC_TITLE] = title

    # optional fields: drop the key entirely when no value is given
    if author is not None:
        metadata["modelspec.author"] = author
    else:
        del metadata["modelspec.author"]

    if description is not None:
        metadata["modelspec.description"] = description
    else:
        del metadata["modelspec.description"]

    if merged_from is not None:
        metadata["modelspec.merged_from"] = merged_from
    else:
        del metadata["modelspec.merged_from"]

    if license is not None:
        metadata["modelspec.license"] = license
    else:
        del metadata["modelspec.license"]

    if tags is not None:
        metadata["modelspec.tags"] = tags
    else:
        del metadata["modelspec.tags"]

    # remove microseconds from the timestamp
    int_ts = int(timestamp)

    # convert to an ISO-8601 compliant date
    date = datetime.datetime.fromtimestamp(int_ts).isoformat()
    metadata["modelspec.date"] = date

    if reso is not None:
        # comma-separated string to tuple
        if isinstance(reso, str):
            reso = tuple(map(int, reso.split(",")))
        if isinstance(reso, int):  # a bare int has no len(); expand it first
            reso = (reso, reso)
        elif len(reso) == 1:
            reso = (reso[0], reso[0])
    else:
        # no resolution given: fall back to the default for the architecture
        if sdxl:
            reso = 1024
        elif v2 and v_parameterization:
            reso = 768
        else:
            reso = 512
    if isinstance(reso, int):
        reso = (reso, reso)

    metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"

    if v_parameterization:
        metadata["modelspec.prediction_type"] = PRED_TYPE_V
    else:
        metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON

    if timesteps is not None:
        if isinstance(timesteps, (str, int)):
            timesteps = (timesteps, timesteps)
        if len(timesteps) == 1:
            timesteps = (timesteps[0], timesteps[0])
        metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
    else:
        del metadata["modelspec.timestep_range"]

    if clip_skip is not None:
        metadata["modelspec.encoder_layer"] = f"{clip_skip}"
    else:
        del metadata["modelspec.encoder_layer"]

    # make sure all remaining values are filled
    if any(v is None for v in metadata.values()):
        logger.error(f"Internal error: some metadata values are None: {metadata}")

    return metadata

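
# A hedged usage sketch, kept as a docstring like this file's other examples:
# build the metadata dict and embed it when saving. `save_file` from
# safetensors.torch accepts a str->str metadata dict; the tensors and file
# name below are hypothetical.
r"""
import time
from safetensors.torch import save_file

metadata = build_metadata(
    state_dict=None,
    v2=False,
    v_parameterization=False,
    sdxl=True,
    lora=True,
    textual_inversion=False,
    timestamp=time.time(),
    title="My SDXL LoRA",
    reso=1024,
    author="Example Corp",
)
save_file(tensors, "my_sdxl_lora.safetensors", metadata=metadata)  # tensors: your str -> Tensor dict
"""
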

def get_title(metadata: dict) -> Optional[str]:
    return metadata.get(MODELSPEC_TITLE, None)

def load_metadata_from_safetensors(model: str) -> dict:
    if not model.endswith(".safetensors"):
        return {}

    with safetensors.safe_open(model, framework="pt") as f:
        metadata = f.metadata()
    if metadata is None:
        metadata = {}
    return metadata

def build_merged_from(models: List[str]) -> str:
    def get_title(model: str):
        # prefer the embedded modelspec title; fall back to the file name
        metadata = load_metadata_from_safetensors(model)
        title = metadata.get(MODELSPEC_TITLE, None)
        if title is None:
            title = os.path.splitext(os.path.basename(model))[0]
        return title

    titles = [get_title(model) for model in models]
    return ", ".join(titles)

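
# A small hedged example (docstring; the file names are hypothetical): build
# the "modelspec.merged_from" value from the source models of a merge.
r"""
merged_from = build_merged_from(["modelA.safetensors", "modelB.safetensors"])
# -> "Model A Title, Model B Title", falling back to file names when no title is embedded
"""
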
r"""
|
|
if __name__ == "__main__":
|
|
import argparse
|
|
import torch
|
|
from safetensors.torch import load_file
|
|
from library import train_util
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--ckpt", type=str, required=True)
|
|
args = parser.parse_args()
|
|
|
|
print(f"Loading {args.ckpt}")
|
|
state_dict = load_file(args.ckpt)
|
|
|
|
print(f"Calculating metadata")
|
|
metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
|
|
print(metadata)
|
|
del state_dict
|
|
|
|
# by reference implementation
|
|
with open(args.ckpt, mode="rb") as file_data:
|
|
file_hash = hashlib.sha256()
|
|
head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
|
|
header = json.loads(file_data.read(head_len[0])) # header itself, json string
|
|
content = (
|
|
file_data.read()
|
|
) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
|
|
file_hash.update(content)
|
|
# ===== Update the hash for modelspec =====
|
|
by_ref = f"0x{file_hash.hexdigest()}"
|
|
print(by_ref)
|
|
print("is same?", by_ref == metadata["modelspec.hash_sha256"])
|
|
|
|
"""
|
|
|