VOIDER's picture
Update app.py
659e2ba verified
import gradio as gr
import torch
import os
import shutil
from tempfile import TemporaryDirectory
from huggingface_hub import hf_hub_download, HfApi
from safetensors.torch import save_file, load_file
from collections import defaultdict
from typing import Dict, List
# --- Shared-tensor helpers adapted from the original `convert.py` script ---
# These internal functions detect tensors that share storage so that duplicates
# can be removed before saving (safetensors refuses to save shared tensors).
# They are reproduced here, with fixes, to keep the application self-contained.
# Source: https://github.com/huggingface/safetensors/blob/main/safetensors/torch.py
def _is_complete(storage):
# The UserWarning from this line can be ignored; it's expected.
return storage.size() * storage.element_size() == storage.nbytes()
def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[List[str]]:
tensors = list(state_dict.values())
storages = {tensor.storage().data_ptr(): [] for tensor in tensors}
for name, tensor in state_dict.items():
storages[tensor.storage().data_ptr()].append(name)
return [names for names in storages.values() if len(names) > 1]
def _remove_duplicate_names(
    state_dict: Dict[str, torch.Tensor]
) -> Dict[str, List[str]]:
    """Decide which storage-sharing tensors can be dropped before saving.

    For each group returned by `_find_shared_tensors`, the alphabetically
    first name whose tensor covers its whole storage (`_is_complete`) is kept.
    Other names in the group are scheduled for removal only when they are
    exact aliases of the kept tensor: same start pointer, shape, strides and
    dtype (for example, tied weights). Any other view holds different data,
    so it is replaced by its own copy instead of being dropped.

    If no member of a group covers the whole storage, every member is cloned
    and nothing is removed for that group.

    Note: `state_dict` is modified in place (views are replaced by clones).

    Returns:
        Mapping `kept_name -> [alias names to delete from state_dict]`.
    """
    shareds = _find_shared_tensors(state_dict)
    to_remove = defaultdict(list)
    for shared in shareds:
        complete_names = sorted(
            name for name in shared if _is_complete(state_dict[name])
        )
        if not complete_names:
            # E.g. separate slices of a fused buffer that was not itself saved.
            # They hold *different* data, so none may be dropped; give each
            # one independent storage instead.
            for name in shared:
                state_dict[name] = state_dict[name].clone()
            continue

        keep_name = complete_names[0]
        kept = state_dict[keep_name]
        for name in sorted(shared):
            if name == keep_name:
                continue
            tensor = state_dict[name]
            is_alias = (
                tensor.data_ptr() == kept.data_ptr()
                and tensor.dtype == kept.dtype
                and tensor.shape == kept.shape
                and tensor.stride() == kept.stride()
            )
            if is_alias:
                to_remove[keep_name].append(name)
            else:
                # A partial or reshaped view: removing it would lose its data.
                state_dict[name] = tensor.clone()
    return to_remove
def check_file_size(sf_filename: str, pt_filename: str, tolerance: float = 0.01):
    """Warn when the converted file is noticeably larger than the original.

    Only growth is flagged. Conversion can legitimately shrink a file:
    `convert_file` unwraps a nested `state_dict` and removes shared aliases.

    Args:
        sf_filename: Path of the converted `.safetensors` file.
        pt_filename: Path of the original checkpoint.
        tolerance: Allowed relative growth (0.01 means 1%).

    Returns:
        A warning string, or None if the size increase is within `tolerance`.
    """
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if pt_size == 0:
        # Avoid ZeroDivisionError: any output from an empty source is growth.
        too_large = sf_size > 0
    else:
        too_large = (sf_size - pt_size) / pt_size > tolerance
    if not too_large:
        return None
    return (
        f"WARNING: The converted file size ({sf_size} bytes) "
        f"exceeds the original ({pt_size} bytes) by more than {tolerance * 100:g}%."
    )
def convert_file(pt_filename: str, sf_filename: str, device: str):
    """Convert one PyTorch checkpoint to `.safetensors` and verify the result.

    Args:
        pt_filename: Path of the `.bin` / `.ckpt` file to read.
        sf_filename: Destination path; missing parent directories are created.
        device: Device passed to `torch.load` as `map_location`.

    Returns:
        The warning string from `check_file_size`, or None.

    Raises:
        TypeError: if the checkpoint is not a dict of tensors.
        RuntimeError: if a tensor read back from the new file differs from
            the one that was written.
    """
    # weights_only=True makes torch.load refuse arbitrary pickled objects,
    # which matters because these checkpoints come from arbitrary Hub repos.
    loaded = torch.load(pt_filename, map_location=device, weights_only=True)
    # Some checkpoints nest the actual weights under a "state_dict" key.
    if isinstance(loaded, dict) and "state_dict" in loaded:
        loaded = loaded["state_dict"]

    source_name = os.path.basename(pt_filename)
    if not isinstance(loaded, dict):
        raise TypeError(
            f"Expected a dict of tensors in `{source_name}`, got {type(loaded).__name__}."
        )
    non_tensors = [k for k, v in loaded.items() if not isinstance(v, torch.Tensor)]
    if non_tensors:
        # Fail with a clear message instead of an AttributeError deep inside
        # the shared-tensor helpers.
        raise TypeError(
            f"`{source_name}` contains non-tensor entries (e.g. {non_tensors[:5]}); "
            "only plain tensor state dicts can be converted."
        )

    # Drop exact aliases of shared storage; record them in the metadata.
    to_removes = _remove_duplicate_names(loaded)
    metadata = {"format": "pt"}
    for kept_name, to_remove_group in to_removes.items():
        for to_remove in to_remove_group:
            if to_remove not in metadata:
                metadata[to_remove] = kept_name
            del loaded[to_remove]
    # safetensors only serializes contiguous tensors.
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    parent_dir = os.path.dirname(sf_filename)
    if parent_dir:  # os.makedirs("") would raise for a bare filename
        os.makedirs(parent_dir, exist_ok=True)
    save_file(loaded, sf_filename, metadata=metadata)

    size_warning = check_file_size(sf_filename, pt_filename)

    # Round-trip check: every tensor written must read back identical.
    reloaded = load_file(sf_filename)
    for k, tensor in loaded.items():
        if not torch.equal(tensor.to("cpu"), reloaded[k].to("cpu")):
            raise RuntimeError(f"Tensors do not match for key {k}!")
    return size_warning
# --- Main Gradio App Logic ---
def process_model(model_id: str, revision: str, progress=gr.Progress(track_tqdm=True)):
    """Convert every `.bin` / `.ckpt` file of a Hub model repo to `.safetensors`.

    Args:
        model_id: Hugging Face repository id, e.g. `"user/model"`.
        revision: Branch, tag or commit to read. A blank value falls back to
            the repository's default revision.
        progress: Gradio progress tracker (supplied via the default argument).

    Returns:
        `(files, log)` in the order of the `[gr.File, gr.Markdown]` outputs.
        `files` is a list of paths to the converted files, or None when
        nothing was converted. `log` is the human-readable conversion log.
    """
    # Local import keeps this handler self-contained.
    from tempfile import mkdtemp

    model_id = (model_id or "").strip()
    if not model_id:
        return None, "Error: Model ID cannot be empty."
    # An empty textbox would otherwise be sent to the Hub as the revision "".
    revision = (revision or "").strip() or None

    device = "cuda" if torch.cuda.is_available() else "cpu"
    log_messages = [f"✅ Detected device: {device.upper()}"]

    try:
        info = HfApi().model_info(repo_id=model_id, revision=revision)
    except Exception as e:  # not found / gated / network: report to the user
        return None, f"❌ Error: Failed to get model info for `{model_id}`.\n{e}"
    filenames = [s.rfilename for s in (info.siblings or [])]

    files_to_convert = [f for f in filenames if f.endswith((".bin", ".ckpt"))]
    if not files_to_convert:
        return None, f"ℹ️ No .bin or .ckpt files found in model `{model_id}` for conversion."
    log_messages.append(
        f"🔍 Found {len(files_to_convert)} file(s) to convert: {', '.join(files_to_convert)}"
    )

    # Results must outlive this call so Gradio can serve them, so they go to a
    # fresh per-request directory under the system temp dir. A shared location
    # such as the CWD lets concurrent requests overwrite each other's files.
    output_dir = mkdtemp(prefix="safetensors_")
    converted = {}  # repo-relative output name -> path on disk
    failed_files = []
    # Downloads are only needed while converting, so they clean themselves up.
    with TemporaryDirectory() as download_dir:
        for filename in progress.tqdm(files_to_convert, desc="Converting files"):
            # Keep the repo sub-folder (e.g. `unet/`, `vae/`): repos often hold
            # several files with the same basename, which would otherwise
            # overwrite each other.
            sf_relpath = os.path.splitext(filename)[0] + ".safetensors"
            if sf_relpath in converted:
                # `x.bin` and `x.ckpt` would both map to `x.safetensors`.
                sf_relpath = filename + ".safetensors"
            sf_path = os.path.join(output_dir, sf_relpath)
            try:
                log_messages.append(f"\n🚀 Downloading `{filename}`...")
                pt_path = hf_hub_download(
                    repo_id=model_id, filename=filename, revision=revision,
                    cache_dir=download_dir,
                )
                log_messages.append(f"🛠️ Converting `{filename}`...")
                size_warning = convert_file(pt_path, sf_path, device)
            except Exception as e:
                # One unconvertible file must not abort the remaining ones.
                failed_files.append(filename)
                log_messages.append(f"❌ Error processing file `{filename}`: {e}")
                # Do not leave a partially written output behind.
                if os.path.isfile(sf_path):
                    os.remove(sf_path)
                continue
            if size_warning:
                log_messages.append(f"⚠️ {size_warning}")
            converted[sf_relpath] = sf_path
            log_messages.append(f"✅ Successfully converted to `{sf_relpath}`")

    if not converted:
        shutil.rmtree(output_dir, ignore_errors=True)
        return None, "\n".join(log_messages) + "\n\nFailed to convert any files."

    if failed_files:
        summary = (
            f"⚠️ Converted {len(converted)} of {len(files_to_convert)} file(s); "
            f"failed: {', '.join(failed_files)}. The converted files are ready for download."
        )
    else:
        summary = "🎉 All files processed successfully! Ready for download."
    return list(converted.values()), "\n".join(log_messages) + "\n\n" + summary
# --- Create Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header: what the tool does and how to use it.
    gr.Markdown(
        """
# Model Converter to `.safetensors`
This utility converts PyTorch model weights (`.bin`, `.ckpt`) from Hugging Face repositories
to the safe and fast `.safetensors` format.
**How to use:**
1. Enter the Model ID from Hugging Face (e.g., `stabilityai/stable-diffusion-2-1-base`).
2. Click the "Convert" button.
3. Wait for the process to complete and download the resulting files.
"""
    )

    # Inputs, side by side: the repository id and the branch to read from.
    with gr.Row():
        model_id = gr.Textbox(
            label="Hugging Face Model ID",
            placeholder="e.g., runwayml/stable-diffusion-v1-5",
        )
        revision = gr.Textbox(value="main", label="Revision (branch)")

    convert_button = gr.Button("Convert", variant="primary")

    # Outputs: the conversion log, then the downloadable results.
    gr.Markdown("### Result")
    log_output = gr.Markdown(value="Waiting for input...")
    file_output = gr.File(label="Download Converted Files")

    # Footnote about a PyTorch deprecation warning that may show in the logs.
    gr.Markdown(
        "<p style='color:grey;font-size:0.8em;'><b>Note:</b> A `UserWarning: "
        "TypedStorage is deprecated` message may appear in the logs. This is "
        "normal and does not affect the result.</p>"
    )

    # process_model returns (files, log), matching this output order.
    convert_button.click(
        process_model,
        inputs=[model_id, revision],
        outputs=[file_output, log_output],
    )

if __name__ == "__main__":
    demo.launch()