"""Gradio app that converts PyTorch checkpoints from the Hugging Face Hub to safetensors.

The conversion helpers are adapted from the safetensors project's ``convert.py``
script so that the application is self-contained.
Source: https://github.com/huggingface/safetensors/blob/main/safetensors/torch.py
"""

import os
import shutil
from collections import defaultdict
from tempfile import TemporaryDirectory, mkdtemp
from typing import Dict, List, Optional, Set, Tuple

import gradio as gr
import torch
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file


# --- Conversion helpers (shared-tensor handling) ---


def _storage_of(tensor: torch.Tensor):
    """Return the storage object backing *tensor*.

    ``untyped_storage()`` is preferred where the installed torch provides it
    because ``storage()`` emits a ``TypedStorage is deprecated`` warning there;
    ``storage()`` is kept as a fallback for releases without the newer API.
    """
    if hasattr(tensor, "untyped_storage"):
        return tensor.untyped_storage()
    return tensor.storage()


def _is_complete(tensor: torch.Tensor) -> bool:
    """Return True when *tensor* spans its entire underlying storage.

    A view or slice shares storage with a larger tensor but covers only part
    of it; such a tensor is "incomplete" and is a poor candidate to keep when
    de-duplicating aliases.
    """
    storage = _storage_of(tensor)
    starts_at_origin = tensor.data_ptr() == storage.data_ptr()
    covers_all_bytes = tensor.nelement() * tensor.element_size() == storage.nbytes()
    return starts_at_origin and covers_all_bytes


def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[List[str]]:
    """Group the names in *state_dict* whose tensors share one storage.

    Returns:
        One list of names per storage that is referenced by more than one
        entry, in order of first appearance in *state_dict*.
    """
    names_by_ptr: Dict[int, List[str]] = defaultdict(list)
    for name, tensor in state_dict.items():
        ptr = _storage_of(tensor).data_ptr()
        # Zero-element tensors all report a null pointer; they do not really
        # alias each other, so grouping them would drop unrelated keys.
        if ptr == 0:
            continue
        names_by_ptr[ptr].append(name)
    return [names for names in names_by_ptr.values() if len(names) > 1]


def _remove_duplicate_names(
    state_dict: Dict[str, torch.Tensor]
) -> Dict[str, List[str]]:
    """Decide which alias to keep for every group of shared tensors.

    For each group the alphabetically first "complete" name is kept. If no
    member covers the full storage, the first member is cloned in place inside
    *state_dict* (mutating it) so that a standalone copy exists to keep.

    Returns:
        Mapping of kept name -> sorted list of alias names to drop.
    """
    to_remove: Dict[str, List[str]] = defaultdict(list)
    for shared in _find_shared_tensors(state_dict):
        complete_names = {name for name in shared if _is_complete(state_dict[name])}
        if not complete_names:
            name = shared[0]
            state_dict[name] = state_dict[name].clone()
            complete_names = {name}
        keep_name = min(complete_names)
        for name in sorted(shared):
            if name != keep_name:
                to_remove[keep_name].append(name)
    return to_remove


def check_file_size(sf_filename: str, pt_filename: str) -> Optional[str]:
    """Warn when the converted file is more than 1% larger than the original.

    The check is one-sided: a smaller safetensors file is expected whenever
    shared tensors were de-duplicated and is not reported.

    Returns:
        A warning message, or None when the size is acceptable.
    """
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    # Compared without dividing so that an empty source file cannot raise
    # ZeroDivisionError.
    if sf_size - pt_size > 0.01 * pt_size:
        return (
            f"WARNING: The converted file size ({sf_size} bytes) "
            f"differs from the original ({pt_size} bytes) by more than 1%."
        )
    return None


def convert_file(pt_filename: str, sf_filename: str, device: str) -> Optional[str]:
    """Convert one PyTorch checkpoint file to safetensors and verify the result.

    Args:
        pt_filename: Path of the source ``.bin``/``.ckpt`` file.
        sf_filename: Path of the safetensors file to write; its parent
            directory is created when missing.
        device: Device passed to ``torch.load`` as ``map_location``.

    Returns:
        The size warning from ``check_file_size``, or None.

    Raises:
        RuntimeError: If a tensor read back from the written file differs
            from the tensor that was saved.
    """
    loaded = torch.load(pt_filename, map_location=device, weights_only=True)
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    to_removes = _remove_duplicate_names(loaded)

    # Record every dropped alias in the metadata so it can be mapped back to
    # the tensor that was kept.
    metadata = {"format": "pt"}
    for kept_name, to_remove_group in to_removes.items():
        for to_remove in to_remove_group:
            if to_remove not in metadata:
                metadata[to_remove] = kept_name
            del loaded[to_remove]
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    # dirname is "" for a bare filename, and os.makedirs("") raises.
    target_dir = os.path.dirname(sf_filename)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    save_file(loaded, sf_filename, metadata=metadata)
    size_warning = check_file_size(sf_filename, pt_filename)

    # Round-trip check: what was written must read back identically.
    reloaded = load_file(sf_filename)
    for k, saved in loaded.items():
        pt_tensor = saved.to("cpu")
        sf_tensor = reloaded[k].to("cpu")
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"Tensors do not match for key {k}!")
    return size_warning


# --- Main Gradio App Logic ---


def _output_name(repo_filename: str, taken: Set[str]) -> str:
    """Build a flat, unique ``.safetensors`` name for a file in a repository.

    Subfolders are folded into the name (``unet/model.bin`` ->
    ``unet_model.safetensors``) so that same-named files from different
    folders do not overwrite each other. A numeric suffix is added if the
    name is still taken. The chosen name is added to *taken*.
    """
    stem = os.path.splitext(repo_filename)[0].replace("/", "_")
    candidate = f"{stem}.safetensors"
    counter = 2
    while candidate in taken:
        candidate = f"{stem}_{counter}.safetensors"
        counter += 1
    taken.add(candidate)
    return candidate


def process_model(
    model_id: str, revision: str, progress=gr.Progress(track_tqdm=True)
) -> Tuple[Optional[List[str]], str]:
    """Download and convert every ``.bin``/``.ckpt`` file of a Hub model.

    Args:
        model_id: Repository id on the Hugging Face Hub.
        revision: Branch, tag or commit; blank falls back to ``"main"``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(paths, log)`` where *paths* lists the converted files (None when
        nothing was converted) and *log* is a Markdown status report.
    """
    model_id = (model_id or "").strip()
    if not model_id:
        return None, "Error: Model ID cannot be empty."
    # A cleared textbox arrives as ""; use the same default the UI shows.
    revision = (revision or "").strip() or "main"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    log_messages = [f"✅ Detected device: {device.upper()}"]

    try:
        api = HfApi()
        info = api.model_info(repo_id=model_id, revision=revision)
        filenames = [s.rfilename for s in info.siblings or []]
    except Exception as e:
        return None, f"❌ Error: Failed to get model info for `{model_id}`.\n{e}"

    files_to_convert = [f for f in filenames if f.endswith((".bin", ".ckpt"))]
    if not files_to_convert:
        return None, f"ℹ️ No .bin or .ckpt files found in model `{model_id}` for conversion."

    log_messages.append(
        f"🔍 Found {len(files_to_convert)} file(s) to convert: {', '.join(files_to_convert)}"
    )

    failed_files: List[str] = []
    with TemporaryDirectory() as temp_dir:
        temp_converted_files: List[str] = []
        taken_names: Set[str] = set()
        for filename in progress.tqdm(files_to_convert, desc="Converting files"):
            # Best-effort per file: one broken checkpoint must not abort the
            # remaining conversions, so the failure is logged and skipped.
            try:
                log_messages.append(f"\n📥 Downloading `{filename}`...")
                pt_path = hf_hub_download(
                    repo_id=model_id,
                    filename=filename,
                    revision=revision,
                    cache_dir=os.path.join(temp_dir, "downloads"),
                )

                log_messages.append(f"🛠️ Converting `{filename}`...")
                sf_filename = _output_name(filename, taken_names)
                sf_path = os.path.join(temp_dir, "converted", sf_filename)

                size_warning = convert_file(pt_path, sf_path, device)
                if size_warning:
                    log_messages.append(f"⚠️ {size_warning}")

                temp_converted_files.append(sf_path)
                log_messages.append(f"✅ Successfully converted to `{sf_filename}`")
            except Exception as e:
                failed_files.append(filename)
                log_messages.append(f"❌ Error processing file `{filename}`: {e}")
                continue

        if not temp_converted_files:
            return None, "\n".join(log_messages) + "\n\nFailed to convert any files."

        # The results must outlive the TemporaryDirectory, so they are moved
        # into a fresh directory of their own. A per-request directory keeps
        # concurrent conversions from overwriting each other, and moving
        # avoids duplicating multi-gigabyte files.
        output_dir = mkdtemp(prefix="safetensors_")
        persistent_files = [
            shutil.move(temp_path, output_dir) for temp_path in temp_converted_files
        ]

    if failed_files:
        summary = (
            f"⚠️ Converted {len(persistent_files)} of {len(files_to_convert)} file(s). "
            f"Failed: {', '.join(failed_files)}"
        )
    else:
        summary = "🎉 All files processed successfully! Ready for download."
    final_message = "\n".join(log_messages) + "\n\n" + summary
    return persistent_files, final_message


# --- Create Gradio Interface ---

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# Model Converter to `.safetensors`
This utility converts PyTorch model weights (`.bin`, `.ckpt`) from Hugging Face repositories to the safe and fast `.safetensors` format.

**How to use:**
1. Enter the Model ID from Hugging Face (e.g., `stabilityai/stable-diffusion-2-1-base`).
2. Click the "Convert" button.
3. Wait for the process to complete and download the resulting files.
"""
    )

    with gr.Row():
        model_id = gr.Textbox(
            label="Hugging Face Model ID",
            placeholder="e.g., runwayml/stable-diffusion-v1-5",
        )
        revision = gr.Textbox(label="Revision (branch)", value="main")

    convert_button = gr.Button("Convert", variant="primary")

    gr.Markdown("### Result")
    log_output = gr.Markdown(value="Waiting for input...")
    file_output = gr.File(label="Download Converted Files")

    gr.Markdown(
        "\n"
        "Note: A `UserWarning: TypedStorage is deprecated` message may appear in the logs. "
        "This is normal and does not affect the result."
        "\n"
    )

    convert_button.click(
        fn=process_model,
        inputs=[model_id, revision],
        outputs=[file_output, log_output],
    )

if __name__ == "__main__":
    demo.launch()