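"""Gradio front end for viewing and editing the metadata embedded in LoRA
.safetensors files.

Reading is done in-process via safetensors_worker.PrintMetadata; writing is
delegated to the safetensors_util.py command-line script.
"""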
import gradio as gr
import json
import sys
import io
import subprocess
import tempfile
from pathlib import Path

from safetensors_worker import PrintMetadata


class Context:
    def __init__(self):
        self.obj = {'quiet': True, 'parse_more': True}


ctx = Context()


def debug_log(message: str):
    print(f"[DEBUG] {message}")


def load_metadata(file_path) -> tuple:
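    """Read the metadata from an uploaded .safetensors file.

    file_path is the uploaded-file object provided by gr.File, so its .name
    attribute holds the path on disk. Returns the full metadata dict, the
    key-metric dict, the editable JSON text, an error message, and the source
    path for the Gradio outputs.
    """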
    try:
        debug_log(f"Loading file: {file_path}")
        if not file_path:
            return {"status": "Awaiting input"}, {}, "", "", ""
        # PrintMetadata writes to stdout, so capture it in a buffer and make
        # sure stdout is restored even if the call raises.
        old_stdout = sys.stdout
        sys.stdout = buffer = io.StringIO()
        try:
            exit_code = PrintMetadata(ctx.obj, file_path.name)
        finally:
            sys.stdout = old_stdout
        metadata_str = buffer.getvalue().strip()
        if exit_code != 0:
            error_msg = f"Error code {exit_code}"
            return {"error": error_msg}, {}, "", error_msg, ""
        try:
            full_metadata = json.loads(metadata_str)
        except json.JSONDecodeError:
            error_msg = "Invalid metadata structure"
            return {"error": error_msg}, {}, "", error_msg, ""
        # Training parameters are stored as strings under the "__metadata__" key.
        training_params = full_metadata.get("__metadata__", {})
        key_metrics = {
            key: training_params.get(key, "N/A")
            for key in [
                "ss_optimizer", "ss_num_epochs", "ss_unet_lr",
                "ss_text_encoder_lr", "ss_steps"
            ]
        }
        return full_metadata, key_metrics, json.dumps(full_metadata, indent=2), "", file_path.name
    except Exception as e:
        return {"error": str(e)}, {}, "", str(e), ""


def validate_json(edited_json: str) -> tuple:
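    """Return (is_valid, parsed object, error message) for a JSON string."""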
    try:
        return True, json.loads(edited_json), ""
    except Exception as e:
        return False, None, str(e)


def update_metadata(edited_json: str) -> tuple:
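    """Refresh the key-metric and live-preview panes as the JSON text is edited."""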
    try:
        modified_data = json.loads(edited_json)
        metadata = modified_data.get("__metadata__", {})
        key_fields = {
            param: metadata.get(param, "N/A")
            for param in [
                "ss_optimizer", "ss_num_epochs", "ss_unet_lr",
                "ss_text_encoder_lr", "ss_steps"
            ]
        }
        return key_fields, modified_data, ""
    except Exception:
        # Leave the displayed values unchanged while the JSON is mid-edit.
        return gr.update(), gr.update(), ""


def save_metadata(edited_json: str, source_file: str, output_name: str) -> tuple:
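    """Write the edited metadata into a new .safetensors file.

    The edited JSON is dumped to a temporary file and applied to the source
    model with the safetensors_util.py "writemd" command. Returns the output
    path (or None on failure) and an error message.
    """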
    debug_log("Initiating save process")
    try:
        if not source_file:
            return None, "No source file provided"
        is_valid, parsed_data, error = validate_json(edited_json)
        if not is_valid:
            return None, f"Validation error: {error}"
        # Write the edited metadata to a temporary JSON file for the CLI call.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp:
            json.dump(parsed_data, tmp, indent=2)
            temp_path = tmp.name
        source_path = Path(source_file)
        if output_name.strip():
            base_name = output_name.strip()
            if not base_name.endswith(".safetensors"):
                base_name += ".safetensors"
        else:
            base_name = f"{source_path.stem}_modified.safetensors"
        # Avoid overwriting an existing file by appending a version suffix.
        output_path = Path(base_name)
        version = 1
        while output_path.exists():
            output_path = Path(f"{source_path.stem}_modified_{version}.safetensors")
            version += 1
        cmd = [
            sys.executable,
            "safetensors_util.py",
            "writemd",
            source_file,
            temp_path,
            str(output_path),
            "-f"
        ]
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=False
        )
        Path(temp_path).unlink(missing_ok=True)
        if result.returncode != 0:
            error_msg = f"Save failure: {result.stderr}"
            return None, error_msg
        return str(output_path), ""
    except Exception as e:
        return None, f"Critical error: {str(e)}"


def create_interface():
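    """Build the two-tab Gradio Blocks UI (viewer and editor) and wire up its events."""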
    with gr.Blocks(title="LoRA Metadata Editor") as app:
        gr.Markdown("# LoRA Metadata Editor")
        with gr.Tabs():
            with gr.Tab("Metadata Viewer"):
                gr.Markdown("### LoRA Upload")
                file_input = gr.File(
                    file_types=[".safetensors"],
                    show_label=False
                )
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Full Metadata")
                        full_viewer = gr.JSON(show_label=False)
                    with gr.Column():
                        gr.Markdown("### Key Metrics")
                        key_viewer = gr.JSON(show_label=False)
            with gr.Tab("Edit Metadata"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### JSON Workspace")
                        metadata_editor = gr.Textbox(
                            lines=25,
                            show_label=False,
                            placeholder="Edit metadata JSON here"
                        )
                        gr.Markdown("### Output Name")
                        filename_input = gr.Textbox(
                            placeholder="Leave empty for auto-naming",
                            show_label=False
                        )
                    with gr.Column():
                        gr.Markdown("### Live Preview")
                        modified_viewer = gr.JSON(show_label=False)
                save_btn = gr.Button("💾 Save Metadata", variant="primary")
                gr.Markdown("### Download Modified LoRA")
                output_file = gr.File(
                    visible=False,
                    show_label=False
                )
        status_display = gr.HTML(visible=False)
        source_tracker = gr.State()

        # Wire the UI events to the handlers above.
        file_input.upload(
            load_metadata,
            inputs=file_input,
            outputs=[full_viewer, key_viewer, metadata_editor, status_display, source_tracker]
        )
        metadata_editor.change(
            update_metadata,
            inputs=metadata_editor,
            outputs=[key_viewer, modified_viewer, status_display]
        )
        save_btn.click(
            save_metadata,
            inputs=[metadata_editor, source_tracker, filename_input],
            outputs=[output_file, status_display],
        ).then(
            lambda x: gr.File(value=x, visible=True),
            inputs=output_file,
            outputs=output_file
        )
    return app


if __name__ == "__main__":
    interface = create_interface()
    interface.launch()