# Lora-Metadata_Editor / safetensors_worker.py
# Author: Juan Sebastian Giraldo
# Upload Lora app
# commit 75cf81d
import os, sys, json
from safetensors_file import SafeTensorsFile
def _need_force_overwrite(output_file:str,cmdLine:dict) -> bool:
if cmdLine["force_overwrite"]==False:
if os.path.exists(output_file):
print(f'output file "{output_file}" already exists, use -f flag to force overwrite',file=sys.stderr)
return True
return False
def WriteMetadataToHeader(cmdLine:dict,in_st_file:str,in_json_file:str,output_file:str) -> int:
    """Replace the __metadata__ item in a .safetensors file's JSON header.

    Reads in_json_file (must contain a top-level "__metadata__" item) and
    writes output_file as: 8-byte little-endian header length, the rewritten
    JSON header (space-padded to a multiple of 8 bytes), then a verbatim copy
    of in_st_file's tensor data.

    Returns 0 on success, -1 if output exists without force_overwrite,
    -2 if the JSON file lacks __metadata__, or the error code from
    copy_data_to_file.
    """
    if _need_force_overwrite(output_file,cmdLine): return -1
    with open(in_json_file,"rt") as f:
        inmeta=json.load(f)
    if not "__metadata__" in inmeta:
        print(f"file {in_json_file} does not contain a top-level __metadata__ item",file=sys.stderr)
        #json.dump(inmeta,fp=sys.stdout,indent=2)
        return -2
    inmeta=inmeta["__metadata__"] #keep only metadata
    #json.dump(inmeta,fp=sys.stdout,indent=2)
    s=SafeTensorsFile.open_file(in_st_file)
    js=s.get_header()
    if inmeta==[]:
        # an empty list means "delete the metadata"; pop with a default so a
        # missing __metadata__ key is not an error
        js.pop("__metadata__",0)
        print("loaded __metadata__ is an empty list, output file will not contain __metadata__ in header")
    else:
        print("adding __metadata__ to header:")
        json.dump(inmeta,fp=sys.stdout,indent=2)
        # stringify all values — safetensors metadata is expected to be a
        # str->str mapping (presumably; see the safetensors format spec)
        if isinstance(inmeta,dict):
            for k in inmeta:
                inmeta[k]=str(inmeta[k])
        else:
            inmeta=str(inmeta)
        #js["__metadata__"]=json.dumps(inmeta,ensure_ascii=False)
        js["__metadata__"]=inmeta
        print()
    # serialize compactly, then pad so the header length is a multiple of 8
    newhdrbuf=json.dumps(js,separators=(',',':'),ensure_ascii=False).encode('utf-8')
    newhdrlen:int=int(len(newhdrbuf))
    pad:int=((newhdrlen+7)&(~7))-newhdrlen #pad to multiple of 8
    with open(output_file,"wb") as f:
        f.write(int(newhdrlen+pad).to_bytes(8,'little'))  # 8-byte LE length prefix
        f.write(newhdrbuf)
        if pad>0: f.write(bytearray([32]*pad))  # pad with ASCII spaces
        i:int=s.copy_data_to_file(f)  # append the original tensor data unchanged
    if i==0:
        print(f"file {output_file} saved successfully")
    else:
        print(f"error {i} occurred when writing to file {output_file}")
    return i
def PrintHeader(cmdLine: dict, input_file: str) -> int:
    """Print the full JSON header of input_file, one key per line.

    Returns 0. Uses a `with` block (like PrintMetadata) so the input file is
    closed deterministically instead of being left to the GC.
    """
    with SafeTensorsFile.open_file(input_file, cmdLine['quiet']) as s:
        js = s.get_header()
    # All the .safetensors files I've seen have long key names, and as a result,
    # neither json nor pprint package prints text in very readable format,
    # so we print it ourselves, putting key name & value on one long line.
    # Note the print out is in Python format, not valid JSON format.
    firstKey = True
    print("{")
    for key in js:
        if firstKey:
            firstKey = False
        else:
            print(",")
        json.dump(key, fp=sys.stdout, ensure_ascii=False, separators=(',',':'))
        print(": ", end='')
        json.dump(js[key], fp=sys.stdout, ensure_ascii=False, separators=(',',':'))
    print("\n}")
    return 0
def _ParseMore(d:dict):
'''Basically try to turn this:
"ss_dataset_dirs":"{\"abc\": {\"n_repeats\": 2, \"img_count\": 60}}",
into this:
"ss_dataset_dirs":{
"abc":{
"n_repeats":2,
"img_count":60
}
},
'''
for key in d:
value=d[key]
#print("+++",key,value,type(value),"+++",sep='|')
if isinstance(value,str):
try:
v2=json.loads(value)
d[key]=v2
value=v2
except json.JSONDecodeError as e:
pass
if isinstance(value,dict):
_ParseMore(value)
def PrintMetadata(cmdLine: dict, input_file: str) -> int:
    """Dump the __metadata__ item of input_file's header to stdout as JSON.

    With cmdLine['parse_more'] set, string values that themselves contain
    JSON are decoded in place first (via _ParseMore). Returns 0 on success,
    -2 when the header has no __metadata__ item.
    """
    with SafeTensorsFile.open_file(input_file, cmdLine['quiet']) as stf:
        header = stf.get_header()
        if "__metadata__" not in header:
            print("file header does not contain a __metadata__ item", file=sys.stderr)
            return -2
        metadata = header["__metadata__"]
        if cmdLine['parse_more']:
            _ParseMore(metadata)
        json.dump({"__metadata__": metadata}, fp=sys.stdout, ensure_ascii=False, separators=(',',':'), indent=1)
    return 0
def HeaderKeysToLists(cmdLine: dict, input_file: str) -> int:
    """Print input_file's tensor keys as Python source for a `_lora_keys`
    list of (name, is_scalar) tuples, sorted by name.

    The output is meant to be pasted into a lora_keys module (see
    _CheckLoRA_internal's import). Returns 0.
    """
    # Use `with` so the file is closed deterministically; fixed the local
    # annotation, which used tuple(str,bool) (a call) instead of tuple[str,bool].
    with SafeTensorsFile.open_file(input_file, cmdLine['quiet']) as s:
        js = s.get_header()
    _lora_keys: list[tuple[str, bool]] = []  # use list to sort by name
    for key in js:
        if key == '__metadata__': continue
        v = js[key]
        # a scalar tensor is one whose "shape" is the empty list
        isScalar = False
        if isinstance(v, dict):
            if 'shape' in v:
                if 0 == len(v['shape']):
                    isScalar = True
        _lora_keys.append((key, isScalar))
    _lora_keys.sort(key=lambda x: x[0])

    def printkeylist(kl):
        # comma-separated, one tuple per line, no trailing comma
        firstKey = True
        for key in kl:
            if firstKey: firstKey = False
            else: print(",")
            print(key, end='')
        print()

    print("# use list to keep insertion order")
    print("_lora_keys:list[tuple[str,bool]]=[")
    printkeylist(_lora_keys)
    print("]")
    return 0
def ExtractHeader(cmdLine: dict, input_file: str, output_file: str) -> int:
    """Save the raw, unparsed header bytes of input_file into output_file.

    Returns 0 on success, -1 on overwrite refusal or short write, or the
    open error code from SafeTensorsFile.
    """
    if _need_force_overwrite(output_file, cmdLine): return -1
    stf = SafeTensorsFile.open_file(input_file, parseHeader=False)
    if stf.error != 0: return stf.error
    raw_header = stf.hdrbuf
    stf.close_file()  # close it in case user wants to write back to input_file itself
    with open(output_file, "wb") as out:
        written = out.write(raw_header)
        if written != len(raw_header):
            print(f"write output file failed, tried to write {len(raw_header)} bytes, only wrote {written} bytes", file=sys.stderr)
            return -1
    print(f"raw header saved to file {output_file}")
    return 0
def _CheckLoRA_internal(s:SafeTensorsFile)->int:
    """Check s's header against the expected SD 1.x LoRA key list.

    Prints diagnostics (unknown keys, missing keys, scalar/nonscalar
    mismatches) to stdout. Returns 1 if any hard error was found, else 0;
    unrecognized keys are reported as INFO only and do not count as errors.
    """
    import lora_keys_sd15 as lora_keys
    js=s.get_header()
    # partition the expected key names by whether each is a scalar tensor
    set_scalar=set()
    set_nonscalar=set()
    for x in lora_keys._lora_keys:
        if x[1]==True: set_scalar.add(x[0])
        else: set_nonscalar.add(x[0])
    bad_unknowns:list[str]=[] # unrecognized keys
    bad_scalars:list[str]=[] #bad scalar
    bad_nonscalars:list[str]=[] #bad nonscalar
    # walk the header, removing each seen key from its expected set; whatever
    # remains in the sets afterwards is a missing key
    for key in js:
        if key in set_nonscalar:
            # expected nonscalar: shape==[] means it is actually scalar
            if js[key]['shape']==[]: bad_nonscalars.append(key)
            set_nonscalar.remove(key)
        elif key in set_scalar:
            # expected scalar: a non-empty shape means it is actually nonscalar
            if js[key]['shape']!=[]: bad_scalars.append(key)
            set_scalar.remove(key)
        else:
            if "__metadata__"!=key:
                bad_unknowns.append(key)
    hasError=False
    if len(bad_unknowns)!=0:
        print("INFO: unrecognized items:")
        for x in bad_unknowns: print(" ",x)
        #hasError=True
    if len(set_scalar)>0:
        print("missing scalar keys:")
        for x in set_scalar: print(" ",x)
        hasError=True
    if len(set_nonscalar)>0:
        print("missing nonscalar keys:")
        for x in set_nonscalar: print(" ",x)
        hasError=True
    if len(bad_scalars)!=0:
        print("keys expected to be scalar but are nonscalar:")
        for x in bad_scalars: print(" ",x)
        hasError=True
    if len(bad_nonscalars)!=0:
        print("keys expected to be nonscalar but are scalar:")
        for x in bad_nonscalars: print(" ",x)
        hasError=True
    return (1 if hasError else 0)
def CheckLoRA(cmdLine: dict, input_file: str) -> int:
    """Check input_file against the expected SD 1.x LoRA layout.

    Returns the internal check's result: 0 when the file looks OK, 1 when
    problems were reported. (Previously the result was computed and then
    discarded — the function always returned 0, so callers using the return
    value as an exit code could never see a failure.)
    """
    s = SafeTensorsFile.open_file(input_file)
    i: int = _CheckLoRA_internal(s)
    if i == 0: print("looks like an OK SD 1.x LoRA file")
    return i
def ExtractData(cmdLine: dict, input_file: str, key_name: str, output_file: str) -> int:
    """Save the raw bytes of the tensor named key_name into output_file.

    Key lookup is case-sensitive. Returns 0 on success, -1 on overwrite
    refusal, unknown key, or short write, or the open error code from
    SafeTensorsFile.
    """
    if _need_force_overwrite(output_file, cmdLine): return -1
    stf = SafeTensorsFile.open_file(input_file, cmdLine['quiet'])
    if stf.error != 0: return stf.error
    tensor_bytes = stf.load_one_tensor(key_name)
    stf.close_file()  # close it just in case user wants to write back to input_file itself
    if tensor_bytes is None:
        print(f'key "{key_name}" not found in header (key names are case-sensitive)', file=sys.stderr)
        return -1
    with open(output_file, "wb") as out:
        written = out.write(tensor_bytes)
        if written != len(tensor_bytes):
            print(f"write output file failed, tried to write {len(tensor_bytes)} bytes, only wrote {written} bytes", file=sys.stderr)
            return -1
    if cmdLine['quiet'] == False: print(f"{key_name} saved to {output_file}, len={written}")
    return 0