Juan Sebastian Giraldo
committed on
Commit
·
75cf81d
1
Parent(s):
57ff33f
Upload Lora app
Browse files- .gitignore +3 -0
- app.py +211 -0
- requirements.txt +0 -0
- safetensors_file.py +125 -0
- safetensors_util.py +98 -0
- safetensors_worker.py +243 -0
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**/__pycache__/
|
| 2 |
+
/.venv/
|
| 3 |
+
/scripts/
|
app.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import json
|
| 3 |
+
import sys
|
| 4 |
+
import io
|
| 5 |
+
import subprocess
|
| 6 |
+
import tempfile
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from safetensors_worker import PrintMetadata
|
| 9 |
+
|
| 10 |
+
class Context:
    """Minimal stand-in for a click context object.

    The safetensors_worker functions expect a dict of CLI options as their
    first argument; this carries that dict for in-process calls from the app.
    """
    def __init__(self):
        # quiet=True suppresses informational prints from the worker;
        # parse_more=True makes PrintMetadata expand JSON-encoded string
        # values in __metadata__ into real objects before printing.
        self.obj = {'quiet': True, 'parse_more': True}

# Single shared context used by every handler in this app.
ctx = Context()
|
| 15 |
+
|
| 16 |
+
def debug_log(message: str):
    """Print *message* to stdout with a [DEBUG] prefix."""
    print("[DEBUG] {}".format(message))
|
| 18 |
+
|
| 19 |
+
def load_metadata(file_path: str) -> tuple:
    """Load and parse the __metadata__ header of an uploaded .safetensors file.

    file_path is the object gradio passes for a gr.File upload; its ``.name``
    attribute is assumed to hold the on-disk path — confirm against the
    installed gradio version.

    Returns a 5-tuple for the wired outputs:
      (full metadata dict, key-metrics dict, pretty JSON text,
       status/error message, tracked source path)
    """
    try:
        debug_log(f"Loading file: {file_path}")

        if not file_path:
            return {"status": "Awaiting input"}, {}, "", "", ""

        # PrintMetadata writes its JSON to stdout, so capture it.  The
        # restore is in a finally block: the original swap-and-restore left
        # sys.stdout hijacked if PrintMetadata raised, silencing the whole
        # app from that point on.
        old_stdout = sys.stdout
        sys.stdout = buffer = io.StringIO()
        try:
            exit_code = PrintMetadata(ctx.obj, file_path.name)
        finally:
            sys.stdout = old_stdout

        metadata_str = buffer.getvalue().strip()

        if exit_code != 0:
            error_msg = f"Error code {exit_code}"
            return {"error": error_msg}, {}, "", error_msg, ""

        try:
            full_metadata = json.loads(metadata_str)
        except json.JSONDecodeError:
            error_msg = "Invalid metadata structure"
            return {"error": error_msg}, {}, "", error_msg, ""

        training_params = full_metadata.get("__metadata__", {})
        # Headline training hyper-parameters surfaced in the Key Metrics panel.
        key_metrics = {
            key: training_params.get(key, "N/A")
            for key in [
                "ss_optimizer", "ss_num_epochs", "ss_unet_lr",
                "ss_text_encoder_lr", "ss_steps"
            ]
        }

        return full_metadata, key_metrics, json.dumps(full_metadata, indent=2), "", file_path.name

    except Exception as e:
        # Surface any unexpected failure in the UI instead of crashing the app.
        return {"error": str(e)}, {}, "", str(e), ""
|
| 56 |
+
|
| 57 |
+
def validate_json(edited_json: str) -> tuple:
    """Attempt to parse *edited_json*.

    Returns (is_valid, parsed_value_or_None, error_text); error_text is ""
    on success.
    """
    try:
        parsed = json.loads(edited_json)
    except Exception as exc:
        return False, None, str(exc)
    return True, parsed, ""
|
| 62 |
+
|
| 63 |
+
def update_metadata(edited_json: str) -> tuple:
    """Re-derive the key-metrics panel and live preview from the edited JSON.

    Fires on every keystroke in the editor, so the text is frequently not
    valid JSON mid-edit; in that case the current widget values are left
    untouched by returning gr.update() placeholders.

    Returns (key_metrics, parsed_metadata, status_text).
    """
    try:
        modified_data = json.loads(edited_json)
        metadata = modified_data.get("__metadata__", {})

        # Same key list as load_metadata's "Key Metrics" panel — keep in sync.
        key_fields = {
            param: metadata.get(param, "N/A")
            for param in [
                "ss_optimizer", "ss_num_epochs", "ss_unet_lr",
                "ss_text_encoder_lr", "ss_steps"
            ]
        }
        return key_fields, modified_data, ""
    except Exception:
        # Was a bare `except:`, which also traps SystemExit and
        # KeyboardInterrupt; Exception covers parse errors and bad types.
        return gr.update(), gr.update(), ""
|
| 78 |
+
|
| 79 |
+
def save_metadata(edited_json: str, source_file: str, output_name: str) -> tuple:
    """Validate the edited JSON and write it into a copy of the source file.

    The actual header rewrite is delegated to ``safetensors_util.py writemd``
    in a subprocess (run with an argument list, not a shell string).

    Returns (output_path_or_None, error_text); error_text is "" on success.
    """
    debug_log("Initiating save process")
    try:
        if not source_file:
            return None, "No source file provided"

        is_valid, parsed_data, error = validate_json(edited_json)
        if not is_valid:
            return None, f"Validation error: {error}"

        # Dump the edited metadata to a temp file the subprocess can read.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp:
            json.dump(parsed_data, tmp, indent=2)
            temp_path = tmp.name

        try:
            source_path = Path(source_file)

            if output_name.strip():
                base_name = output_name.strip()
                if not base_name.endswith(".safetensors"):
                    base_name += ".safetensors"
            else:
                base_name = f"{source_path.stem}_modified.safetensors"

            # Never clobber an existing file: append a counter.  Collision
            # names derive from the chosen base name — the original fell
            # back to the source stem even when a custom name was given.
            stem = Path(base_name).stem
            output_path = Path(base_name)
            version = 1
            while output_path.exists():
                output_path = Path(f"{stem}_{version}.safetensors")
                version += 1

            cmd = [
                sys.executable,
                "safetensors_util.py",
                "writemd",
                source_file,
                temp_path,
                str(output_path),
                "-f"
            ]

            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                check=False
            )
        finally:
            # Always remove the temp json, even if subprocess.run raised
            # (the original only unlinked on the normal path).
            Path(temp_path).unlink(missing_ok=True)

        if result.returncode != 0:
            return None, f"Save failure: {result.stderr}"

        return str(output_path), ""

    except Exception as e:
        return None, f"Critical error: {str(e)}"
|
| 135 |
+
|
| 136 |
+
def create_interface():
    """Build the gradio Blocks UI: a metadata viewer tab and an editor tab.

    Returns the Blocks object; the caller launches it.
    """
    with gr.Blocks(title="LoRA Metadata Editor") as app:
        gr.Markdown("# LoRA Metadata Editor")

        with gr.Tabs():
            # Typo fix: tab label was "Metdata Viewer".
            with gr.Tab("Metadata Viewer"):
                gr.Markdown("### LoRa Upload")
                file_input = gr.File(
                    file_types=[".safetensors"],
                    show_label=False
                )

                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Full Metadata")
                        full_viewer = gr.JSON(show_label=False)

                    with gr.Column():
                        gr.Markdown("### Key Metrics")
                        key_viewer = gr.JSON(show_label=False)

            with gr.Tab("Edit Metadata"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### JSON Workspace")
                        metadata_editor = gr.Textbox(
                            lines=25,
                            show_label=False,
                            placeholder="Edit metadata JSON here"
                        )
                        gr.Markdown("### Output Name")
                        filename_input = gr.Textbox(
                            placeholder="Leave empty for auto-naming",
                            show_label=False
                        )

                    with gr.Column():
                        gr.Markdown("### Live Preview")
                        modified_viewer = gr.JSON(show_label=False)
                        save_btn = gr.Button("💾 Save Metadata", variant="primary")
                        gr.Markdown("### Download Modified LoRa")
                        output_file = gr.File(
                            visible=False,
                            show_label=False
                        )

        # Hidden helpers: status text and the tracked source-file path.
        status_display = gr.HTML(visible=False)
        source_tracker = gr.State()

        file_input.upload(
            load_metadata,
            inputs=file_input,
            outputs=[full_viewer, key_viewer, metadata_editor, status_display, source_tracker]
        )

        metadata_editor.change(
            update_metadata,
            inputs=metadata_editor,
            outputs=[key_viewer, modified_viewer, status_display]
        )

        # After a save, reveal the download widget pointing at the new file.
        save_btn.click(
            save_metadata,
            inputs=[metadata_editor, source_tracker, filename_input],
            outputs=[output_file, status_display],
        ).then(
            lambda x: gr.File(value=x, visible=True),
            inputs=output_file,
            outputs=output_file
        )

    return app
|
| 208 |
+
|
| 209 |
+
if __name__ == "__main__":
    # Build and launch the gradio app when run as a script.
    interface = create_interface()
    interface.launch()
|
requirements.txt
ADDED
|
Binary file (4.84 kB). View file
|
|
|
safetensors_file.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys, json
|
| 2 |
+
|
| 3 |
+
class SafeTensorsException(Exception):
    """Error raised for structurally invalid .safetensors files."""

    def __init__(self, msg: str):
        self.msg = msg
        super().__init__(msg)

    @staticmethod
    def invalid_file(filename: str, whatiswrong: str):
        """Build an exception describing why *filename* is not valid.

        Fix: the original message never used the filename parameter, so the
        user could not tell which file failed; include it in the message.
        """
        s = f"{filename} is not a valid .safetensors file: {whatiswrong}"
        return SafeTensorsException(msg=s)

    def __str__(self):
        return self.msg
|
| 15 |
+
|
| 16 |
+
class SafeTensorsChunk:
    # Describes one tensor entry from the header: its dtype, shape, and the
    # [offset0, offset1) byte range of its data within the file's data area.
    # NOTE(review): not referenced elsewhere in this file — possibly used by
    # other modules, so it is kept as-is.
    def __init__(self,name:str,dtype:str,shape:list[int],offset0:int,offset1:int):
        self.name=name
        self.dtype=dtype
        self.shape=shape
        self.offset0=offset0
        self.offset1=offset1
|
| 23 |
+
|
| 24 |
+
class SafeTensorsFile:
    """Low-level reader for the .safetensors container format.

    File layout: an 8-byte little-endian header length, a JSON header of
    exactly that many bytes, then the raw tensor data area.
    """

    def __init__(self):
        self.f=None #file handle
        self.hdrbuf=None #header byte buffer
        self.header=None #parsed header as a dict
        self.error=0

    def __del__(self):
        self.close_file()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close_file()

    def close_file(self):
        # Idempotent; also invoked from __del__ and __exit__.
        if self.f is not None:
            self.f.close()
            self.f=None
            self.filename=""

    #test file: duplicate_keys_in_header.safetensors
    def _CheckDuplicateHeaderKeys(self):
        # json.loads silently keeps only the last duplicate key, so re-parse
        # the raw buffer with a pairs hook to surface every key occurrence.
        def parse_object_pairs(pairs):
            return [k for k,_ in pairs]

        keys=json.loads(self.hdrbuf,object_pairs_hook=parse_object_pairs)
        #print(keys)
        d={}
        for k in keys:
            if k in d: d[k]=d[k]+1
            else: d[k]=1
        hasError=False
        for k,v in d.items():
            if v>1:
                print(f"key {k} used {v} times in header",file=sys.stderr)
                hasError=True
        if hasError:
            raise SafeTensorsException.invalid_file(self.filename,"duplicate keys in header")

    @staticmethod
    def open_file(filename:str,quiet=False,parseHeader=True):
        # Convenience constructor: create an instance and open the file.
        s=SafeTensorsFile()
        s.open(filename,quiet,parseHeader)
        return s

    def open(self,fn:str,quiet=False,parseHeader=True)->int:
        """Open *fn*, validate and read its header; returns 0 on success.

        Raises SafeTensorsException for structurally invalid files.  With
        parseHeader=False only the raw header bytes are read (self.header
        stays None).
        """
        st=os.stat(fn)
        if st.st_size<8: #test file: zero_len_file.safetensors
            raise SafeTensorsException.invalid_file(fn,"length less than 8 bytes")

        # builtin open(): the method name does not shadow it here
        f=open(fn,"rb")
        b8=f.read(8) #read header size
        if len(b8)!=8:
            raise SafeTensorsException.invalid_file(fn,f"read only {len(b8)} bytes at start of file")
        headerlen=int.from_bytes(b8,'little',signed=False)

        if (8+headerlen>st.st_size): #test file: header_size_too_big.safetensors
            raise SafeTensorsException.invalid_file(fn,"header extends past end of file")

        if quiet==False:
            print(f"{fn}: length={st.st_size}, header length={headerlen}")
        hdrbuf=f.read(headerlen)
        if len(hdrbuf)!=headerlen:
            raise SafeTensorsException.invalid_file(fn,f"header size is {headerlen}, but read {len(hdrbuf)} bytes")
        self.filename=fn
        self.f=f
        self.st=st
        self.hdrbuf=hdrbuf
        self.error=0
        self.headerlen=headerlen
        if parseHeader==True:
            self._CheckDuplicateHeaderKeys()
            self.header=json.loads(self.hdrbuf)
        return 0

    def get_header(self):
        # Parsed header dict, or None when opened with parseHeader=False.
        return self.header

    def load_one_tensor(self,tensor_name:str):
        """Read and return the raw bytes of one tensor, or None if absent.

        A short read is reported to stderr but the partial bytes are still
        returned.
        """
        self.get_header()
        if tensor_name not in self.header: return None

        t=self.header[tensor_name]
        # data_offsets are relative to the start of the data area,
        # which begins right after the 8-byte length and the header.
        self.f.seek(8+self.headerlen+t['data_offsets'][0])
        bytesToRead=t['data_offsets'][1]-t['data_offsets'][0]
        # NOTE(review): `bytes` shadows the builtin within this method.
        bytes=self.f.read(bytesToRead)
        if len(bytes)!=bytesToRead:
            print(f"{tensor_name}: length={bytesToRead}, only read {len(bytes)} bytes",file=sys.stderr)
        return bytes

    def copy_data_to_file(self,file_handle) -> int:
        # Stream the entire data area (everything after the header) into
        # file_handle in 16 MB chunks.  Always returns 0.

        self.f.seek(8+self.headerlen)
        bytesLeft:int=self.st.st_size - 8 - self.headerlen
        while bytesLeft>0:
            chunklen:int=min(bytesLeft,int(16*1024*1024)) #copy in blocks of 16 MB
            file_handle.write(self.f.read(chunklen))
            bytesLeft-=chunklen

        return 0
|
safetensors_util.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys, click

import safetensors_worker
# This file deals with command line only. If the command line is parsed successfully,
# we will call one of the functions in safetensors_worker.py.

# Reusable click argument/option decorators shared by the subcommands below.
readonly_input_file=click.argument("input_file", metavar='input_file',
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True))
output_file=click.argument("output_file", metavar='output_file',
    type=click.Path(file_okay=True, dir_okay=False, writable=True))

force_overwrite_flag=click.option("-f","--force-overwrite",default=False,is_flag=True, show_default=True,
    help="overwrite existing files")
fix_ued_flag=click.option("-pm","--parse-more",default=False,is_flag=True, show_default=True,
    help="when printing metadata, unescaped doublequotes to make text more readable" )
quiet_flag=click.option("-q","--quiet",default=False,is_flag=True, show_default=True,
    help="Quiet mode, don't print informational stuff" )
|
| 18 |
+
|
| 19 |
+
# Root command group: every subcommand shares the -q/--quiet option via ctx.obj.
@click.group()
@click.version_option(version=7)
@quiet_flag
@click.pass_context
def cli(ctx,quiet:bool):
    # ensure that ctx.obj exists and is a dict (in case `cli()` is called
    # by means other than the `if` block below)
    ctx.ensure_object(dict)
    ctx.obj['quiet'] = quiet
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@cli.command(name="header",short_help="print file header")
@readonly_input_file
@click.pass_context
def cmd_header(ctx,input_file:str) -> int:
    """Print the full JSON header of input_file; exit with the worker's status."""
    sys.exit( safetensors_worker.PrintHeader(ctx.obj,input_file) )
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@cli.command(name="metadata",short_help="print only __metadata__ in file header")
@readonly_input_file
@fix_ued_flag
@click.pass_context
def cmd_meta(ctx,input_file:str,parse_more:bool)->int:
    """Print just the __metadata__ section; exit with the worker's status."""
    ctx.obj['parse_more'] = parse_more
    sys.exit( safetensors_worker.PrintMetadata(ctx.obj,input_file) )
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@cli.command(name="listkeys",short_help="print header key names (except __metadata__) as a Python list")
@readonly_input_file
@click.pass_context
def cmd_keyspy(ctx,input_file:str) -> int:
    """Emit the header keys as Python source (see HeaderKeysToLists)."""
    sys.exit( safetensors_worker.HeaderKeysToLists(ctx.obj,input_file) )
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# This is the subcommand the gradio app invokes via subprocess to save edits.
@cli.command(name="writemd",short_help="read __metadata__ from json and write to safetensors file")
@click.argument("in_st_file", metavar='input_st_file',
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True))
@click.argument("in_json_file", metavar='input_json_file',
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True))
@output_file
@force_overwrite_flag
@click.pass_context
def cmd_writemd(ctx,in_st_file:str,in_json_file:str,output_file:str,force_overwrite:bool) -> int:
    """Read "__metadata__" from json file and write to safetensors header"""
    ctx.obj['force_overwrite'] = force_overwrite
    sys.exit( safetensors_worker.WriteMetadataToHeader(ctx.obj,in_st_file,in_json_file,output_file) )
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@cli.command(name="extracthdr",short_help="extract file header and save to output file")
@readonly_input_file
@output_file
@force_overwrite_flag
@click.pass_context
def cmd_extractheader(ctx,input_file:str,output_file:str,force_overwrite:bool) -> int:
    """Save the raw (unparsed) header bytes of input_file to output_file."""
    ctx.obj['force_overwrite'] = force_overwrite
    sys.exit( safetensors_worker.ExtractHeader(ctx.obj,input_file,output_file) )
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@cli.command(name="extractdata",short_help="extract one tensor and save to file")
@readonly_input_file
@click.argument("key_name", metavar='key_name',type=click.STRING)
@output_file
@force_overwrite_flag
@click.pass_context
def cmd_extractdata(ctx,input_file:str,key_name:str,output_file:str,force_overwrite:bool) -> int:
    """Extract the raw bytes of one tensor (key_name) into output_file.

    Renamed from cmd_extractheader: the original file defined a second
    function under that name, shadowing the 'extracthdr' handler at module
    level.  Both click commands still registered, but the module attribute
    was ambiguous; the CLI command names ("extracthdr"/"extractdata") are
    unchanged.
    """
    ctx.obj['force_overwrite'] = force_overwrite
    sys.exit( safetensors_worker.ExtractData(ctx.obj,input_file,key_name,output_file) )
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@cli.command(name="checklora",short_help="see if input file is a SD 1.x LoRA file")
@readonly_input_file
@click.pass_context
def cmd_checklora(ctx,input_file:str)->int:
    """Validate input_file's keys against the known SD 1.x LoRA key set."""
    sys.exit( safetensors_worker.CheckLoRA(ctx.obj,input_file) )
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
if __name__ == '__main__':
    # Force UTF-8 stdout so metadata containing non-ASCII text prints
    # correctly regardless of the console's default encoding.
    sys.stdout.reconfigure(encoding='utf-8')
    cli(obj={},max_content_width=96)
|
safetensors_worker.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys, json
|
| 2 |
+
from safetensors_file import SafeTensorsFile
|
| 3 |
+
|
| 4 |
+
def _need_force_overwrite(output_file:str,cmdLine:dict) -> bool:
|
| 5 |
+
if cmdLine["force_overwrite"]==False:
|
| 6 |
+
if os.path.exists(output_file):
|
| 7 |
+
print(f'output file "{output_file}" already exists, use -f flag to force overwrite',file=sys.stderr)
|
| 8 |
+
return True
|
| 9 |
+
return False
|
| 10 |
+
|
| 11 |
+
def WriteMetadataToHeader(cmdLine:dict,in_st_file:str,in_json_file:str,output_file:str) -> int:
    """Replace (or delete) __metadata__ in a safetensors header.

    Reads {"__metadata__": ...} from in_json_file, rewrites the header of
    in_st_file accordingly, and writes header + original tensor data to
    output_file.  Returns 0 on success, negative on error.  Passing an
    empty list as __metadata__ removes the item entirely.
    """
    if _need_force_overwrite(output_file,cmdLine): return -1

    with open(in_json_file,"rt") as f:
        inmeta=json.load(f)
    if not "__metadata__" in inmeta:
        print(f"file {in_json_file} does not contain a top-level __metadata__ item",file=sys.stderr)
        #json.dump(inmeta,fp=sys.stdout,indent=2)
        return -2
    inmeta=inmeta["__metadata__"] #keep only metadata
    #json.dump(inmeta,fp=sys.stdout,indent=2)

    s=SafeTensorsFile.open_file(in_st_file)
    js=s.get_header()

    if inmeta==[]:
        # Empty list is the sentinel for "strip metadata from the header".
        js.pop("__metadata__",0)
        print("loaded __metadata__ is an empty list, output file will not contain __metadata__ in header")
    else:
        print("adding __metadata__ to header:")
        json.dump(inmeta,fp=sys.stdout,indent=2)
        # Stringify values: safetensors metadata is a string-to-string map.
        if isinstance(inmeta,dict):
            for k in inmeta:
                inmeta[k]=str(inmeta[k])
        else:
            inmeta=str(inmeta)
        #js["__metadata__"]=json.dumps(inmeta,ensure_ascii=False)
        js["__metadata__"]=inmeta
        print()

    # Serialize the new header and pad with spaces (0x20) to a multiple of 8.
    newhdrbuf=json.dumps(js,separators=(',',':'),ensure_ascii=False).encode('utf-8')
    newhdrlen:int=int(len(newhdrbuf))
    pad:int=((newhdrlen+7)&(~7))-newhdrlen #pad to multiple of 8

    with open(output_file,"wb") as f:
        f.write(int(newhdrlen+pad).to_bytes(8,'little'))
        f.write(newhdrbuf)
        if pad>0: f.write(bytearray([32]*pad))
        # Stream the untouched tensor data from the source file.
        i:int=s.copy_data_to_file(f)
    if i==0:
        print(f"file {output_file} saved successfully")
    else:
        print(f"error {i} occurred when writing to file {output_file}")
    return i
|
| 55 |
+
|
| 56 |
+
def PrintHeader(cmdLine:dict,input_file:str) -> int:
    """Pretty-print the full header of input_file, one key per line.

    Output is Python-style, not strictly valid JSON (see note below).
    Always returns 0.
    """
    s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    js=s.get_header()

    # All the .safetensors files I've seen have long key names, and as a result,
    # neither json nor pprint package prints text in very readable format,
    # so we print it ourselves, putting key name & value on one long line.
    # Note the print out is in Python format, not valid JSON format.
    firstKey=True
    print("{")
    for key in js:
        if firstKey:
            firstKey=False
        else:
            print(",")
        json.dump(key,fp=sys.stdout,ensure_ascii=False,separators=(',',':'))
        print(": ",end='')
        json.dump(js[key],fp=sys.stdout,ensure_ascii=False,separators=(',',':'))
    print("\n}")
    return 0
|
| 76 |
+
|
| 77 |
+
def _ParseMore(d:dict):
|
| 78 |
+
'''Basically try to turn this:
|
| 79 |
+
|
| 80 |
+
"ss_dataset_dirs":"{\"abc\": {\"n_repeats\": 2, \"img_count\": 60}}",
|
| 81 |
+
|
| 82 |
+
into this:
|
| 83 |
+
|
| 84 |
+
"ss_dataset_dirs":{
|
| 85 |
+
"abc":{
|
| 86 |
+
"n_repeats":2,
|
| 87 |
+
"img_count":60
|
| 88 |
+
}
|
| 89 |
+
},
|
| 90 |
+
|
| 91 |
+
'''
|
| 92 |
+
for key in d:
|
| 93 |
+
value=d[key]
|
| 94 |
+
#print("+++",key,value,type(value),"+++",sep='|')
|
| 95 |
+
if isinstance(value,str):
|
| 96 |
+
try:
|
| 97 |
+
v2=json.loads(value)
|
| 98 |
+
d[key]=v2
|
| 99 |
+
value=v2
|
| 100 |
+
except json.JSONDecodeError as e:
|
| 101 |
+
pass
|
| 102 |
+
if isinstance(value,dict):
|
| 103 |
+
_ParseMore(value)
|
| 104 |
+
|
| 105 |
+
def PrintMetadata(cmdLine:dict,input_file:str) -> int:
    """Dump {"__metadata__": ...} from input_file's header to stdout.

    Returns 0 on success, -2 when the header has no __metadata__ item.
    With cmdLine['parse_more'], JSON-encoded string values are expanded
    into real objects before printing.
    """
    with SafeTensorsFile.open_file(input_file,cmdLine['quiet']) as stf:
        header=stf.get_header()

        if "__metadata__" not in header:
            print("file header does not contain a __metadata__ item",file=sys.stderr)
            return -2

        metadata=header["__metadata__"]
        if cmdLine['parse_more']:
            _ParseMore(metadata)
        json.dump({"__metadata__":metadata},fp=sys.stdout,ensure_ascii=False,separators=(',',':'),indent=1)
    return 0
|
| 118 |
+
|
| 119 |
+
def HeaderKeysToLists(cmdLine:dict,input_file:str) -> int:
    """Print the header's key names (minus __metadata__) as Python source
    for a list of (key_name, is_scalar) tuples, sorted by key name.

    Used to regenerate key-list modules such as lora_keys_sd15.py.
    Always returns 0.
    """
    s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    js=s.get_header()

    _lora_keys:list[tuple[str,bool]]=[] # use list to sort by name
    for key in js:
        if key=='__metadata__': continue
        v=js[key]
        isScalar=False
        if isinstance(v,dict):
            if 'shape' in v:
                # an empty shape list means a zero-dimensional (scalar) tensor
                if 0==len(v['shape']):
                    isScalar=True
        _lora_keys.append((key,isScalar))
    _lora_keys.sort(key=lambda x:x[0])

    def printkeylist(kl):
        # comma-separate the entries without a trailing comma
        firstKey=True
        for key in kl:
            if firstKey: firstKey=False
            else: print(",")
            print(key,end='')
        print()

    print("# use list to keep insertion order")
    print("_lora_keys:list[tuple[str,bool]]=[")
    printkeylist(_lora_keys)
    print("]")

    return 0
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def ExtractHeader(cmdLine:dict,input_file:str,output_file:str)->int:
    """Save the raw (unparsed) JSON header bytes of input_file to output_file.

    Returns 0 on success, -1 on overwrite refusal or short write, or the
    open error code.
    """
    if _need_force_overwrite(output_file,cmdLine): return -1

    stf=SafeTensorsFile.open_file(input_file,parseHeader=False)
    if stf.error!=0: return stf.error

    raw_header=stf.hdrbuf
    stf.close_file() #close it in case user wants to write back to input_file itself
    with open(output_file,"wb") as out:
        written=out.write(raw_header)
        if written!=len(raw_header):
            print(f"write output file failed, tried to write {len(raw_header)} bytes, only wrote {written} bytes",file=sys.stderr)
            return -1
    print(f"raw header saved to file {output_file}")
    return 0
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _CheckLoRA_internal(s:SafeTensorsFile)->int:
    """Compare the header keys of *s* against the known SD 1.x LoRA key set.

    Prints discrepancies and returns 1 if an error was found, else 0.
    Unrecognized extra keys are reported as INFO only and do not count
    as errors.
    """
    import lora_keys_sd15 as lora_keys
    js=s.get_header()
    # Split the reference key list into expected scalar / nonscalar name sets.
    set_scalar=set()
    set_nonscalar=set()
    for x in lora_keys._lora_keys:
        if x[1]==True: set_scalar.add(x[0])
        else: set_nonscalar.add(x[0])

    bad_unknowns:list[str]=[] # unrecognized keys
    bad_scalars:list[str]=[] #bad scalar
    bad_nonscalars:list[str]=[] #bad nonscalar
    # Remove each key we see from the expected sets; whatever remains in
    # the sets afterwards is missing from the file.
    for key in js:
        if key in set_nonscalar:
            if js[key]['shape']==[]: bad_nonscalars.append(key)
            set_nonscalar.remove(key)
        elif key in set_scalar:
            if js[key]['shape']!=[]: bad_scalars.append(key)
            set_scalar.remove(key)
        else:
            if "__metadata__"!=key:
                bad_unknowns.append(key)

    hasError=False

    if len(bad_unknowns)!=0:
        print("INFO: unrecognized items:")
        for x in bad_unknowns: print(" ",x)
        #hasError=True

    if len(set_scalar)>0:
        print("missing scalar keys:")
        for x in set_scalar: print(" ",x)
        hasError=True
    if len(set_nonscalar)>0:
        print("missing nonscalar keys:")
        for x in set_nonscalar: print(" ",x)
        hasError=True

    if len(bad_scalars)!=0:
        print("keys expected to be scalar but are nonscalar:")
        for x in bad_scalars: print(" ",x)
        hasError=True

    if len(bad_nonscalars)!=0:
        print("keys expected to be nonscalar but are scalar:")
        for x in bad_nonscalars: print(" ",x)
        hasError=True

    return (1 if hasError else 0)
|
| 218 |
+
|
| 219 |
+
def CheckLoRA(cmdLine:dict,input_file:str)->int:
    """Check whether input_file looks like an SD 1.x LoRA file.

    Returns the internal check's status (0 = OK, 1 = problems found).
    Fix: the original computed the status and then unconditionally
    returned 0, so the CLI exit code never reflected a failed check.
    """
    s=SafeTensorsFile.open_file(input_file)
    i:int=_CheckLoRA_internal(s)
    if i==0: print("looks like an OK SD 1.x LoRA file")
    return i
|
| 224 |
+
|
| 225 |
+
def ExtractData(cmdLine:dict,input_file:str,key_name:str,output_file:str)->int:
    """Extract one tensor's raw bytes (key_name is case-sensitive) to output_file.

    Returns 0 on success, -1 on overwrite refusal / missing key / short
    write, or the open error code.
    """
    if _need_force_overwrite(output_file,cmdLine): return -1

    s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    if s.error!=0: return s.error

    bindata=s.load_one_tensor(key_name)
    s.close_file() #close it just in case user wants to write back to input_file itself
    if bindata is None:
        print(f'key "{key_name}" not found in header (key names are case-sensitive)',file=sys.stderr)
        return -1

    with open(output_file,"wb") as fo:
        wn=fo.write(bindata)
        if wn!=len(bindata):
            print(f"write output file failed, tried to write {len(bindata)} bytes, only wrote {wn} bytes",file=sys.stderr)
            return -1
    if cmdLine['quiet']==False: print(f"{key_name} saved to {output_file}, len={wn}")
    return 0
|