# Captured from a Hugging Face Space page (page status banner: "Spaces: Running").
import os
import shutil
from collections import defaultdict
from tempfile import TemporaryDirectory, mkdtemp
from typing import Dict, List

import gradio as gr
import torch
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file
# --- Logic adapted from the original safetensors `convert.py` script ---
# These internal helpers detect tensors that share storage and de-duplicate
# them before saving. They are inlined here to keep the application self-contained.
# Source: https://github.com/huggingface/safetensors/blob/main/safetensors/torch.py
def _is_complete(storage): | |
# The UserWarning from this line can be ignored; it's expected. | |
return storage.size() * storage.element_size() == storage.nbytes() | |
def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[List[str]]: | |
tensors = list(state_dict.values()) | |
storages = {tensor.storage().data_ptr(): [] for tensor in tensors} | |
for name, tensor in state_dict.items(): | |
storages[tensor.storage().data_ptr()].append(name) | |
return [names for names in storages.values() if len(names) > 1] | |
def _remove_duplicate_names(
    state_dict: Dict[str, torch.Tensor]
) -> Dict[str, List[str]]:
    """Decide which names of storage-sharing tensors to drop before saving.

    For every group returned by ``_find_shared_tensors``:

    * If at least one tensor in the group covers its whole storage
      (``_is_complete``), the alphabetically first such name is kept and
      every other name in the group is reported for removal under it.
    * If no tensor covers the whole storage (e.g. several slices of one flat
      buffer), the members are not interchangeable, so keeping only one would
      silently drop the data held by the others. Instead each member is
      replaced in ``state_dict`` by its own copy and nothing is removed.
      Identical views are then stored twice, which is larger but lossless.

    Args:
        state_dict: name -> tensor mapping; modified in place when tensors
            are cloned.

    Returns:
        Mapping from each kept name to the names that duplicate it; the caller
        is expected to delete those names from ``state_dict``.
    """
    to_remove = defaultdict(list)
    for shared in _find_shared_tensors(state_dict):
        complete_names = {name for name in shared if _is_complete(state_dict[name])}
        if not complete_names:
            for name in shared:
                state_dict[name] = state_dict[name].clone()
            continue
        keep_name = min(complete_names)
        for name in sorted(shared):
            if name != keep_name:
                to_remove[keep_name].append(name)
    return to_remove
def check_file_size(sf_filename: str, pt_filename: str):
    """Report when the converted file is noticeably larger than its source.

    Only growth is flagged: the relative change ``(sf - pt) / pt`` is compared
    against 1%, so an output smaller than its source never yields a warning.

    Args:
        sf_filename: path of the converted ``.safetensors`` file.
        pt_filename: path of the original PyTorch checkpoint.

    Returns:
        A warning string when the output is more than 1% larger than the
        source, otherwise ``None``.
    """
    converted_bytes = os.path.getsize(sf_filename)
    original_bytes = os.path.getsize(pt_filename)
    growth = (converted_bytes - original_bytes) / original_bytes
    if growth <= 0.01:
        return None
    return (
        f"WARNING: The converted file size ({converted_bytes} bytes) "
        f"differs from the original ({original_bytes} bytes) by more than 1%."
    )
def convert_file(pt_filename: str, sf_filename: str, device: str):
    """Convert one PyTorch checkpoint to ``.safetensors`` and verify the result.

    The checkpoint is loaded with ``weights_only=True``. A top-level
    ``"state_dict"`` entry is unwrapped if present. Duplicate names of
    storage-sharing tensors are dropped, and each dropped name -> kept name
    is recorded in the file metadata. All tensors are then made contiguous and
    saved. Finally the file is reloaded and every saved tensor is compared
    with its in-memory original.

    Args:
        pt_filename: path of the source ``.bin`` / ``.ckpt`` file.
        sf_filename: destination path; missing parent directories are created.
        device: ``map_location`` passed to ``torch.load`` (e.g. ``"cpu"``).

    Returns:
        The warning string from ``check_file_size``, or ``None``.

    Raises:
        RuntimeError: if the checkpoint does not contain a dict of tensors, or
            if a reloaded tensor differs from the original. Errors raised by
            ``torch.load`` / ``save_file`` propagate unchanged.
    """
    loaded = torch.load(pt_filename, map_location=device, weights_only=True)
    # Only test for the wrapper key on dicts; `in` on e.g. a bare tensor raises.
    if isinstance(loaded, dict) and "state_dict" in loaded:
        loaded = loaded["state_dict"]
    if not isinstance(loaded, dict):
        raise RuntimeError(
            f"Expected a dict of tensors in {pt_filename}, got {type(loaded).__name__}."
        )

    to_removes = _remove_duplicate_names(loaded)
    metadata = {"format": "pt"}
    for kept_name, to_remove_group in to_removes.items():
        for to_remove in to_remove_group:
            if to_remove not in metadata:
                metadata[to_remove] = kept_name
            del loaded[to_remove]
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    # os.path.dirname() is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(sf_filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_file(loaded, sf_filename, metadata=metadata)
    size_warning = check_file_size(sf_filename, pt_filename)

    # Round-trip check: only the names that were actually saved are compared.
    reloaded = load_file(sf_filename)
    for k, original in loaded.items():
        pt_tensor = original.to("cpu")
        sf_tensor = reloaded[k].to("cpu")
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"Tensors do not match for key {k}!")
    return size_warning
# --- Main Gradio App Logic --- | |
def process_model(model_id: str, revision: str, progress=gr.Progress(track_tqdm=True)):
    """Convert every ``.bin`` / ``.ckpt`` file of a Hugging Face model repo.

    Each matching file is downloaded into its own temporary directory. That
    directory is deleted as soon as the file has been converted, so at most
    one source checkpoint is on disk at a time. The file is converted with
    ``convert_file`` and written into a fresh per-request output directory,
    so concurrent requests never overwrite each other's results.

    Args:
        model_id: Hub repository id; surrounding whitespace is ignored.
        revision: branch, tag or commit. A blank value is passed to the Hub
            client as ``None`` (its default revision).
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(files, log)``: ``files`` is a list of converted file paths, or
        ``None`` if nothing was converted; ``log`` is the text for the log panel.
    """
    # NOTE(review): the non-ASCII prefixes in the log strings below (e.g. "β",
    # "π") look like mis-decoded emoji; confirm and restore the intended characters.
    model_id = (model_id or "").strip()
    if not model_id:
        return None, "Error: Model ID cannot be empty."
    # Sending revision="" would build malformed Hub URLs; None selects the default.
    revision = (revision or "").strip() or None

    device = "cuda" if torch.cuda.is_available() else "cpu"
    log_messages = [f"β Detected device: {device.upper()}"]
    try:
        api = HfApi()
        info = api.model_info(repo_id=model_id, revision=revision)
        filenames = [s.rfilename for s in (info.siblings or [])]
    except Exception as e:
        return None, f"β Error: Failed to get model info for `{model_id}`.\n{e}"

    files_to_convert = [f for f in filenames if f.endswith((".bin", ".ckpt"))]
    if not files_to_convert:
        return None, f"βΉοΈ No .bin or .ckpt files found in model `{model_id}` for conversion."
    log_messages.append(f"π Found {len(files_to_convert)} file(s) to convert: {', '.join(files_to_convert)}")

    # Must outlive this function: Gradio reads the returned paths after we return.
    # NOTE(review): these directories are never deleted; consider periodic cleanup.
    output_dir = mkdtemp(prefix="safetensors_")
    converted_files = []
    used_names = set()
    for filename in progress.tqdm(files_to_convert, desc="Converting files"):
        # Flatten sub-folders into the name so that e.g. text_encoder/pytorch_model.bin
        # and safety_checker/pytorch_model.bin do not map to the same output file.
        flat_name = filename.replace("/", "__")
        sf_filename = os.path.splitext(flat_name)[0] + ".safetensors"
        if sf_filename in used_names:
            # e.g. model.bin and model.ckpt side by side: keep the source extension.
            sf_filename = flat_name + ".safetensors"
        used_names.add(sf_filename)
        sf_path = os.path.join(output_dir, sf_filename)
        try:
            with TemporaryDirectory() as download_dir:
                log_messages.append(f"\nπ Downloading `{filename}`...")
                pt_path = hf_hub_download(
                    repo_id=model_id, filename=filename, revision=revision,
                    cache_dir=download_dir,
                )
                log_messages.append(f"π οΈ Converting `{filename}`...")
                size_warning = convert_file(pt_path, sf_path, device)
        except Exception as e:
            log_messages.append(f"β Error processing file `{filename}`: {e}")
            # Do not leave a partially written or unverified output behind.
            try:
                os.remove(sf_path)
            except OSError:
                pass  # best effort: the file may never have been created
            continue
        if size_warning:
            log_messages.append(f"β οΈ {size_warning}")
        converted_files.append(sf_path)
        log_messages.append(f"β Successfully converted to `{sf_filename}`")

    if not converted_files:
        shutil.rmtree(output_dir, ignore_errors=True)
        return None, "\n".join(log_messages) + "\n\nFailed to convert any files."

    failed = len(files_to_convert) - len(converted_files)
    if failed:
        summary = (
            f"Converted {len(converted_files)} of {len(files_to_convert)} file(s); "
            f"{failed} failed (see errors above). The converted files are ready for download."
        )
    else:
        summary = "π All files processed successfully! Ready for download."
    return converted_files, "\n".join(log_messages) + "\n\n" + summary
# --- Create Gradio Interface --- | |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header: what the tool does and how to use it.
    gr.Markdown(
        """
# Model Converter to `.safetensors`
This utility converts PyTorch model weights (`.bin`, `.ckpt`) from Hugging Face repositories
to the safe and fast `.safetensors` format.
**How to use:**
1. Enter the Model ID from Hugging Face (e.g., `stabilityai/stable-diffusion-2-1-base`).
2. Click the "Convert" button.
3. Wait for the process to complete and download the resulting files.
"""
    )

    # Inputs: repository id and revision side by side, then the trigger button.
    with gr.Row():
        model_id = gr.Textbox(
            label="Hugging Face Model ID",
            placeholder="e.g., runwayml/stable-diffusion-v1-5",
        )
        revision = gr.Textbox(label="Revision (branch)", value="main")
    convert_button = gr.Button("Convert", variant="primary")

    # Outputs: the conversion log followed by the downloadable files.
    gr.Markdown("### Result")
    log_output = gr.Markdown("Waiting for input...")
    file_output = gr.File(label="Download Converted Files")

    gr.Markdown(
        "<p style='color:grey;font-size:0.8em;'><b>Note:</b> "
        "A `UserWarning: TypedStorage is deprecated` message may appear in the logs. "
        "This is normal and does not affect the result.</p>"
    )

    # (model id, revision) -> (converted files, log text)
    convert_button.click(
        process_model,
        [model_id, revision],
        [file_output, log_output],
    )

if __name__ == "__main__":
    demo.launch()