import gradio as gr
import torch
import os
import shutil
from tempfile import TemporaryDirectory
from huggingface_hub import hf_hub_download, HfApi
from safetensors.torch import save_file, load_file
from collections import defaultdict
from typing import Dict, List
# --- Logic copied from the original `convert.py` script ---
# These internal functions are necessary for correctly handling shared tensors.
# We copy them here to make the application self-contained.
# Source: https://github.com/huggingface/safetensors/blob/main/safetensors/torch.py
def _is_complete(tensor: torch.Tensor) -> bool:
    # A tensor is "complete" if it starts at the beginning of its underlying
    # storage and covers the storage's full byte range. The TypedStorage
    # UserWarning raised here can be ignored; it's expected.
    storage = tensor.storage()
    return (
        tensor.data_ptr() == storage.data_ptr()
        and tensor.nelement() * tensor.element_size() == storage.nbytes()
    )
def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[List[str]]:
    # Group tensor names by the data pointer of their underlying storage;
    # any group with more than one name is a set of aliased (shared) tensors.
    storages = defaultdict(list)
    for name, tensor in state_dict.items():
        storages[tensor.storage().data_ptr()].append(name)
    return [names for names in storages.values() if len(names) > 1]
def _remove_duplicate_names(
    state_dict: Dict[str, torch.Tensor],
) -> Dict[str, List[str]]:
    # For each group of shared tensors, keep one complete tensor and mark the
    # rest for removal; the mapping {kept_name: [removed_names]} is returned.
    shareds = _find_shared_tensors(state_dict)
    to_remove = defaultdict(list)
    for shared in shareds:
        complete_names = {name for name in shared if _is_complete(state_dict[name])}
        if not complete_names:
            # No tensor covers its whole storage: clone one so it can be kept.
            name = shared[0]
            state_dict[name] = state_dict[name].clone()
            complete_names = {name}
        keep_name = sorted(complete_names)[0]
        for name in sorted(shared):
            if name != keep_name:
                to_remove[keep_name].append(name)
    return to_remove
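# Illustration (hypothetical names, not from the original script): for a model
# with tied embeddings such as GPT-2, where `lm_head.weight` aliases
# `transformer.wte.weight`, this returns something like
# {"lm_head.weight": ["transformer.wte.weight"]}. safetensors refuses to save
# aliased tensors, so the duplicates must be dropped and recorded in metadata.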
def check_file_size(sf_filename: str, pt_filename: str):
    # Warn if the safetensors file is noticeably larger than the original;
    # it may legitimately be smaller, since duplicated tensors are dropped.
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        return (
            f"WARNING: The converted file size ({sf_size} bytes) "
            f"is more than 1% larger than the original ({pt_size} bytes)."
        )
    return None
def convert_file(pt_filename: str, sf_filename: str, device: str):
    """Convert a single PyTorch checkpoint to safetensors and verify it."""
    loaded = torch.load(pt_filename, map_location=device, weights_only=True)
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    to_removes = _remove_duplicate_names(loaded)
    # Record each removed alias in the metadata so consumers can re-tie weights.
    metadata = {"format": "pt"}
    for kept_name, to_remove_group in to_removes.items():
        for to_remove in to_remove_group:
            if to_remove not in metadata:
                metadata[to_remove] = kept_name
            del loaded[to_remove]
    # safetensors requires contiguous tensors.
    loaded = {k: v.contiguous() for k, v in loaded.items()}
    os.makedirs(os.path.dirname(sf_filename), exist_ok=True)
    save_file(loaded, sf_filename, metadata=metadata)
    size_warning = check_file_size(sf_filename, pt_filename)
    # Round-trip check: reload the file and compare every tensor exactly.
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k].to("cpu")
        sf_tensor = reloaded[k].to("cpu")
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"Tensors do not match for key {k}!")
    return size_warning
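# Standalone usage sketch (hypothetical local paths, independent of the app):
#   warning = convert_file("pytorch_model.bin", "out/model.safetensors", "cpu")
#   if warning:
#       print(warning)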
# --- Main Gradio App Logic ---
def process_model(model_id: str, revision: str, progress=gr.Progress(track_tqdm=True)):
    if not model_id:
        return None, "Error: Model ID cannot be empty."
    device = "cuda" if torch.cuda.is_available() else "cpu"
    log_messages = [f"✅ Detected device: {device.upper()}"]
    try:
        api = HfApi()
        info = api.model_info(repo_id=model_id, revision=revision)
        filenames = [s.rfilename for s in info.siblings]
    except Exception as e:
        return None, f"❌ Error: Failed to get model info for `{model_id}`.\n{e}"
    files_to_convert = [f for f in filenames if f.endswith(".bin") or f.endswith(".ckpt")]
    if not files_to_convert:
        return None, f"ℹ️ No .bin or .ckpt files found in model `{model_id}` for conversion."
    log_messages.append(f"🔍 Found {len(files_to_convert)} file(s) to convert: {', '.join(files_to_convert)}")
    with TemporaryDirectory() as temp_dir:
        temp_converted_files = []
        for filename in progress.tqdm(files_to_convert, desc="Converting files"):
            try:
                log_messages.append(f"\n📥 Downloading `{filename}`...")
                pt_path = hf_hub_download(
                    repo_id=model_id, filename=filename, revision=revision,
                    cache_dir=os.path.join(temp_dir, "downloads"),
                )
                log_messages.append(f"🛠️ Converting `{filename}`...")
                sf_filename = os.path.splitext(os.path.basename(filename))[0] + ".safetensors"
                sf_path = os.path.join(temp_dir, "converted", sf_filename)
                size_warning = convert_file(pt_path, sf_path, device)
                if size_warning:
                    log_messages.append(f"⚠️ {size_warning}")
                temp_converted_files.append(sf_path)
                log_messages.append(f"✅ Successfully converted to `{sf_filename}`")
            except Exception as e:
                log_messages.append(f"❌ Error processing file `{filename}`: {e}")
                continue
        if not temp_converted_files:
            return None, "\n".join(log_messages) + "\n\nFailed to convert any files."
        # --- KEY CHANGE ---
        # Copy files out of the temporary directory to a location that
        # outlives it, so Gradio can still serve them for download.
        persistent_files = []
        for temp_path in temp_converted_files:
            # shutil.copy() creates a new file that isn't deleted with temp_dir.
            persistent_path = shutil.copy(temp_path, ".")
            persistent_files.append(persistent_path)
        # --------------------
    final_message = "\n".join(log_messages) + "\n\n🎉 All files processed successfully! Ready for download."
    # Return the paths to the persistent files.
    return persistent_files, final_message
# --- Create Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Model Converter to `.safetensors`
        This utility converts PyTorch model weights (`.bin`, `.ckpt`) from Hugging Face repositories
        to the safe and fast `.safetensors` format.

        **How to use:**
        1. Enter the Model ID from Hugging Face (e.g., `stabilityai/stable-diffusion-2-1-base`).
        2. Click the "Convert" button.
        3. Wait for the process to complete and download the resulting files.
        """
    )
    with gr.Row():
        model_id = gr.Textbox(label="Hugging Face Model ID", placeholder="e.g., runwayml/stable-diffusion-v1-5")
        revision = gr.Textbox(label="Revision (branch)", value="main")
    convert_button = gr.Button("Convert", variant="primary")
    gr.Markdown("### Result")
    log_output = gr.Markdown(value="Waiting for input...")
    file_output = gr.File(label="Download Converted Files")
    gr.Markdown(
        "<p style='color:grey;font-size:0.8em;'>"
        "<b>Note:</b> A `UserWarning: TypedStorage is deprecated` message may appear in the logs. "
        "This is normal and does not affect the result."
        "</p>"
    )
    convert_button.click(
        fn=process_model,
        inputs=[model_id, revision],
        outputs=[file_output, log_output],
    )
if __name__ == "__main__":
    demo.launch()
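# To run locally (assuming `gradio`, `torch`, `safetensors`, and
# `huggingface_hub` are installed): run this script with Python; Gradio
# serves the UI on http://127.0.0.1:7860 by default.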