|
import os, sys, json |
|
from safetensors_file import SafeTensorsFile |
|
|
|
def _need_force_overwrite(output_file:str,cmdLine:dict) -> bool:
    """Tell the caller whether writing ``output_file`` must be refused.

    Args:
        output_file: path the caller is about to create or overwrite.
        cmdLine: parsed command-line options; ``cmdLine["force_overwrite"]``
            is the only entry read here.

    Returns:
        True when ``output_file`` already exists and overwriting was not
        forced (a hint is printed to stderr); False when the caller may write.

    Raises:
        KeyError: if ``cmdLine`` has no ``"force_overwrite"`` entry.
    """
    # Truthiness rather than ``== False``: a None/unset flag must be treated as
    # "not forced", never as permission to clobber an existing file.
    if cmdLine["force_overwrite"]:
        return False
    if not os.path.exists(output_file):
        return False
    print(f'output file "{output_file}" already exists, use -f flag to force overwrite',file=sys.stderr)
    return True
|
|
|
def WriteMetadataToHeader(cmdLine:dict,in_st_file:str,in_json_file:str,output_file:str) -> int:
    """Write a copy of ``in_st_file`` whose header carries new ``__metadata__``.

    The replacement metadata is read from the top-level ``"__metadata__"`` item
    of the JSON file ``in_json_file``. An empty list there means "remove the
    metadata": the output header then has no ``__metadata__`` entry at all.

    Args:
        cmdLine: parsed command-line options (``force_overwrite`` is consulted
            through ``_need_force_overwrite``).
        in_st_file: path of the source safetensors file.
        in_json_file: path of the JSON file holding the new metadata.
        output_file: path of the file to create.

    Returns:
        -1 if ``output_file`` exists and overwriting was not forced, -2 if the
        JSON file has no top-level ``__metadata__`` item, otherwise the value
        returned by ``copy_data_to_file`` (0 is reported as success).
    """
    if _need_force_overwrite(output_file,cmdLine): return -1

    with open(in_json_file,"rt") as f:
        inmeta=json.load(f)
    if "__metadata__" not in inmeta:
        print(f"file {in_json_file} does not contain a top-level __metadata__ item",file=sys.stderr)
        return -2
    inmeta=inmeta["__metadata__"]  # only the metadata item is used

    # Context manager so the input file is released on every exit path.
    with SafeTensorsFile.open_file(in_st_file) as s:
        js=s.get_header()

        if inmeta==[]:
            js.pop("__metadata__",0)
            print("loaded __metadata__ is an empty list, output file will not contain __metadata__ in header")
        else:
            print("adding __metadata__ to header:")
            json.dump(inmeta,fp=sys.stdout,indent=2)
            if isinstance(inmeta,dict):
                # Every metadata value is stored as a string. Nested containers
                # are JSON-encoded (not str(), which yields Python repr with
                # single quotes) so json.loads can read them back later.
                for k in inmeta:
                    v=inmeta[k]
                    if isinstance(v,(dict,list)):
                        inmeta[k]=json.dumps(v,ensure_ascii=False)
                    else:
                        inmeta[k]=str(v)
            else:
                inmeta=str(inmeta)

            js["__metadata__"]=inmeta
            print()

        newhdrbuf=json.dumps(js,separators=(',',':'),ensure_ascii=False).encode('utf-8')
        newhdrlen:int=len(newhdrbuf)
        # Number of filler bytes needed to round the header up to a multiple of 8.
        pad:int=((newhdrlen+7)&(~7))-newhdrlen

        with open(output_file,"wb") as f:
            # Layout: 8-byte little-endian header length, header JSON,
            # space padding, then the data copied from the input file.
            f.write(int(newhdrlen+pad).to_bytes(8,'little'))
            f.write(newhdrbuf)
            if pad>0: f.write(b" "*pad)
            i:int=s.copy_data_to_file(f)

    if i==0:
        print(f"file {output_file} saved successfully")
    else:
        print(f"error {i} occurred when writing to file {output_file}")
    return i
|
|
|
def PrintHeader(cmdLine:dict,input_file:str) -> int:
    """Print the parsed header of ``input_file`` to stdout.

    The output is laid out by hand instead of via ``json.dump(indent=...)``:
    an opening brace, then each top-level key and its compactly encoded value
    together on one line, entries separated by ``,`` + newline, then a closing
    brace.

    Args:
        cmdLine: parsed command-line options; ``cmdLine['quiet']`` is forwarded
            to ``SafeTensorsFile.open_file``.
        input_file: path of the safetensors file to inspect.

    Returns:
        Always 0.
    """
    # Context manager so the input file is closed once printing is done.
    with SafeTensorsFile.open_file(input_file,cmdLine['quiet']) as s:
        js=s.get_header()

        print("{")
        for index,key in enumerate(js):
            if index>0:
                print(",")  # separator goes before every entry except the first
            json.dump(key,fp=sys.stdout,ensure_ascii=False,separators=(',',':'))
            print(": ",end='')
            json.dump(js[key],fp=sys.stdout,ensure_ascii=False,separators=(',',':'))
        print("\n}")
    return 0
|
|
|
def _ParseMore(d:dict):
    r'''Expand JSON text stored inside string values of ``d``, in place.

    A value such as

        "ss_dataset_dirs":"{\"abc\": {\"n_repeats\": 2, \"img_count\": 60}}"

    is replaced by the object it encodes:

        "ss_dataset_dirs":{"abc":{"n_repeats":2,"img_count":60}}

    Strings that are not valid JSON are left untouched. Dict values, whether
    they were already dicts or were just decoded from a string, are processed
    recursively. Nothing is returned; ``d`` itself is modified.
    '''
    for name in d:
        current=d[name]

        if isinstance(current,str):
            try:
                current=json.loads(current)
            except json.JSONDecodeError:
                # Ordinary text: keep it as is. It is still a str, so there is
                # nothing to recurse into either.
                continue
            d[name]=current

        if isinstance(current,dict):
            _ParseMore(current)
|
|
|
def PrintMetadata(cmdLine:dict,input_file:str) -> int:
    """Print the ``__metadata__`` item of the file header to stdout as JSON.

    The metadata is wrapped as ``{"__metadata__": ...}`` and written with an
    indent of 1. When ``cmdLine['parse_more']`` is set, ``_ParseMore`` first
    expands JSON text embedded in the metadata's string values.

    Args:
        cmdLine: parsed command-line options; ``quiet`` and ``parse_more`` are read.
        input_file: path of the safetensors file to inspect.

    Returns:
        0 on success, -2 (with a message on stderr) when the header has no
        ``__metadata__`` item.
    """
    with SafeTensorsFile.open_file(input_file,cmdLine['quiet']) as st_file:
        header=st_file.get_header()

        if "__metadata__" not in header:
            print("file header does not contain a __metadata__ item",file=sys.stderr)
            return -2

        metadata=header["__metadata__"]
        if cmdLine['parse_more']:
            _ParseMore(metadata)

        wrapped={"__metadata__":metadata}
        json.dump(
            wrapped,
            fp=sys.stdout,
            ensure_ascii=False,
            separators=(',',':'),
            indent=1,
        )
    return 0
|
|
|
def HeaderKeysToLists(cmdLine:dict,input_file:str) -> int:
    """Print the header's keys as Python source for a ``_lora_keys`` list.

    Each entry is a ``(key, is_scalar)`` tuple, where ``is_scalar`` is True
    only when the header value is a dict whose ``'shape'`` is empty. The
    ``__metadata__`` item is skipped and the entries are sorted by key.

    Args:
        cmdLine: parsed command-line options; ``cmdLine['quiet']`` is forwarded
            to ``SafeTensorsFile.open_file``.
        input_file: path of the safetensors file to inspect.

    Returns:
        Always 0.
    """
    # Context manager so the input file is closed when we are done with it.
    with SafeTensorsFile.open_file(input_file,cmdLine['quiet']) as s:
        js=s.get_header()

        _lora_keys:list[tuple[str,bool]]=[]
        for key,v in js.items():
            if key=='__metadata__': continue
            isScalar=isinstance(v,dict) and 'shape' in v and len(v['shape'])==0
            _lora_keys.append((key,isScalar))
        _lora_keys.sort(key=lambda x:x[0])

    print("# use list to keep insertion order")
    print("_lora_keys:list[tuple[str,bool]]=[")
    # One tuple per line, comma-separated; an empty list yields one blank line.
    print(",\n".join(str(entry) for entry in _lora_keys))
    print("]")

    return 0
|
|
|
|
|
def ExtractHeader(cmdLine:dict,input_file:str,output_file:str)->int:
    """Save the raw header bytes of ``input_file`` to ``output_file``.

    The input is opened with ``parseHeader=False`` and its ``hdrbuf`` is
    written out unchanged.

    Returns:
        -1 when the overwrite guard refuses or fewer bytes than expected were
        written, the reader's ``error`` value when it is non-zero after
        opening, and 0 on success.
    """
    if _need_force_overwrite(output_file,cmdLine):
        return -1

    reader=SafeTensorsFile.open_file(input_file,parseHeader=False)
    if reader.error!=0:
        return reader.error

    raw_header=reader.hdrbuf
    # The input is closed before the output is opened.
    reader.close_file()

    expected=len(raw_header)
    with open(output_file,"wb") as out:
        written=out.write(raw_header)
        if written!=expected:
            print(f"write output file failed, tried to write {expected} bytes, only wrote {written} bytes",file=sys.stderr)
            return -1

    print(f"raw header saved to file {output_file}")
    return 0
|
|
|
|
|
def _CheckLoRA_internal(s:SafeTensorsFile)->int:
    """Compare the header keys of ``s`` with the list in ``lora_keys_sd15``.

    Every ``(name, is_scalar)`` entry of ``lora_keys_sd15._lora_keys`` is
    expected to appear in the header, with ``'shape' == []`` exactly when
    ``is_scalar`` is true. Findings are printed to stdout.

    Header keys that are not in the list (other than ``__metadata__``) are
    reported as INFO only and do not count as an error.

    Returns:
        1 if an expected key is missing or has the wrong scalar/non-scalar
        kind, otherwise 0.
    """
    import lora_keys_sd15 as lora_keys
    js=s.get_header()

    # Expected names, split by kind. Names are removed as they are found in
    # the header, so whatever remains afterwards is missing from the file.
    set_scalar=set()
    set_nonscalar=set()
    for name,is_scalar in lora_keys._lora_keys:
        if is_scalar: set_scalar.add(name)
        else: set_nonscalar.add(name)

    bad_unknowns:list[str]=[]    # header keys not in the expected list
    bad_scalars:list[str]=[]     # expected scalar, header shape is not []
    bad_nonscalars:list[str]=[]  # expected non-scalar, header shape is []
    for key in js:
        if key in set_nonscalar:
            if js[key]['shape']==[]: bad_nonscalars.append(key)
            set_nonscalar.remove(key)
        elif key in set_scalar:
            if js[key]['shape']!=[]: bad_scalars.append(key)
            set_scalar.remove(key)
        elif key!="__metadata__":
            bad_unknowns.append(key)

    def report(title:str,keys)->bool:
        """Print ``title`` followed by ``keys``; return whether anything was printed."""
        if not keys:
            return False
        print(title)
        for k in keys: print(" ",k)
        return True

    # Informational only: unknown keys do not fail the check.
    report("INFO: unrecognized items:",bad_unknowns)

    # The leftover sets are sorted so the report is the same on every run;
    # plain set iteration order for strings varies with hash randomization.
    hasError=False
    for title,keys in (
        ("missing scalar keys:",sorted(set_scalar)),
        ("missing nonscalar keys:",sorted(set_nonscalar)),
        ("keys expected to be scalar but are nonscalar:",bad_scalars),
        ("keys expected to be nonscalar but are scalar:",bad_nonscalars),
    ):
        if report(title,keys):
            hasError=True

    return (1 if hasError else 0)
|
|
|
def CheckLoRA(cmdLine:dict,input_file:str)->int:
    """Check whether ``input_file`` has the header keys of an SD 1.x LoRA file.

    Args:
        cmdLine: parsed command-line options (not read here).
        input_file: path of the safetensors file to check.

    Returns:
        The result of ``_CheckLoRA_internal``: 0 when the file looks OK,
        non-zero when problems were reported.
    """
    # Context manager so the input file is closed after the check.
    with SafeTensorsFile.open_file(input_file) as s:
        i:int=_CheckLoRA_internal(s)
    if i==0: print("looks like an OK SD 1.x LoRA file")
    # Pass the outcome on instead of always returning 0, so a failed check
    # is visible to the caller.
    return i
|
|
|
def ExtractData(cmdLine:dict,input_file:str,key_name:str,output_file:str)->int:
    """Write the data loaded for header key ``key_name`` to ``output_file``.

    Args:
        cmdLine: parsed command-line options; ``quiet`` is read here and
            ``force_overwrite`` through ``_need_force_overwrite``.
        input_file: path of the safetensors file to read.
        key_name: header key of the tensor to extract.
        output_file: path of the file to create.

    Returns:
        -1 when the overwrite guard refuses, ``load_one_tensor`` returns None
        for the key, or the write comes up short; the reader's ``error`` value
        when it is non-zero after opening; 0 on success.
    """
    if _need_force_overwrite(output_file,cmdLine):
        return -1

    reader=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    if reader.error!=0:
        return reader.error

    payload=reader.load_one_tensor(key_name)
    # The input is closed before the output is opened.
    reader.close_file()
    if payload is None:
        print(f'key "{key_name}" not found in header (key names are case-sensitive)',file=sys.stderr)
        return -1

    expected=len(payload)
    with open(output_file,"wb") as out:
        written=out.write(payload)
        if written!=expected:
            print(f"write output file failed, tried to write {expected} bytes, only wrote {written} bytes",file=sys.stderr)
            return -1

    if cmdLine['quiet']==False:
        print(f"{key_name} saved to {output_file}, len={written}")
    return 0
|
|