import gradio as gr
import torch
import os
import shutil
from tempfile import TemporaryDirectory
from huggingface_hub import hf_hub_download, HfApi
from safetensors.torch import save_file, load_file
from collections import defaultdict
from typing import Dict, List

# --- Logic copied from the original `convert.py` script ---
# These internal functions are necessary for correctly handling shared tensors.
# We copy them here to make the application self-contained.
# Source: https://github.com/huggingface/safetensors/blob/main/safetensors/torch.py

def _is_complete(tensor: torch.Tensor) -> bool:
    # A tensor is "complete" if it starts at the beginning of its underlying
    # storage and covers every byte of it (i.e. it is not a partial view).
    # Note: this is called with tensors, not storages, so the signature and
    # body are fixed accordingly; untyped_storage() avoids the deprecated
    # TypedStorage API.
    storage = tensor.untyped_storage()
    return (
        tensor.data_ptr() == storage.data_ptr()
        and tensor.nelement() * tensor.element_size() == storage.nbytes()
    )
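
# Illustrative example of a complete tensor vs. a partial view (kept as a
# comment so it does not run at import time):
#
#   base = torch.zeros(4)   # covers its whole 16-byte storage -> True
#   view = base[:2]         # covers only the first 8 bytes    -> False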

def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[List[str]]:
    # Group tensor names by the address of their underlying storage; any
    # group with more than one name is a set of shared (tied) tensors.
    storages = defaultdict(list)
    for name, tensor in state_dict.items():
        if tensor.numel() == 0:
            # Empty tensors report a null storage pointer; treat them as
            # unique rather than falsely grouping them together.
            continue
        storages[tensor.untyped_storage().data_ptr()].append(name)
    return [names for names in storages.values() if len(names) > 1]
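
# Illustrative example (comment only): a view shares storage with its base
# tensor, so both names land in one group:
#
#   sd = {"weight": torch.zeros(4)}
#   sd["weight_alias"] = sd["weight"][:2]
#   _find_shared_tensors(sd)   # -> [["weight", "weight_alias"]]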

def _remove_duplicate_names(
    state_dict: Dict[str, torch.Tensor]
) -> Dict[str, List[str]]:
    shareds = _find_shared_tensors(state_dict)
    to_remove = defaultdict(list)
    for shared in shareds:
        complete_names = {name for name in shared if _is_complete(state_dict[name])}
        if not complete_names:
            # Every alias is a partial view of the storage: clone one of
            # them so there is a complete tensor we can keep.
            name = shared[0]
            state_dict[name] = state_dict[name].clone()
            complete_names = {name}

        # Keep the alphabetically first complete name; drop the other aliases.
        keep_name = min(complete_names)
        for name in sorted(shared):
            if name != keep_name:
                to_remove[keep_name].append(name)
    return to_remove
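
# Illustrative example (comment only), mirroring tied weights such as an
# embedding matrix reused as the output head:
#
#   emb = torch.zeros(10, 4)
#   sd = {"embed.weight": emb, "lm_head.weight": emb}
#   _remove_duplicate_names(sd)   # -> {"embed.weight": ["lm_head.weight"]}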

def check_file_size(sf_filename: str, pt_filename: str):
    # A converted file may legitimately be smaller (duplicate tensors are
    # removed), but it should never be noticeably larger than the original.
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        return (
            f"WARNING: The converted file ({sf_size} bytes) is more than 1% "
            f"larger than the original ({pt_size} bytes)."
        )
    return None

def convert_file(pt_filename: str, sf_filename: str, device: str):
    """Main function to convert a single file."""
    loaded = torch.load(pt_filename, map_location=device, weights_only=True)
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]

    to_removes = _remove_duplicate_names(loaded)
    metadata = {"format": "pt"}
    for kept_name, to_remove_group in to_removes.items():
        for to_remove in to_remove_group:
            if to_remove not in metadata:
                metadata[to_remove] = kept_name
            del loaded[to_remove]

    loaded = {k: v.contiguous() for k, v in loaded.items()}
    os.makedirs(os.path.dirname(sf_filename), exist_ok=True)
    save_file(loaded, sf_filename, metadata=metadata)
    
    size_warning = check_file_size(sf_filename, pt_filename)
    
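    # Sanity check: reload the file we just wrote and verify every tensor
    # matches the in-memory original, value for value.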
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k].to("cpu")
        sf_tensor = reloaded[k].to("cpu")
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"Tensors do not match for key {k}!")
            
    return size_warning
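
# Standalone usage sketch (hypothetical file names; comment only):
#
#   warning = convert_file("pytorch_model.bin", "out/model.safetensors", "cpu")
#   if warning:
#       print(warning)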


# --- Main Gradio App Logic ---

def process_model(model_id: str, revision: str, progress=gr.Progress(track_tqdm=True)):
    if not model_id:
        return None, "Error: Model ID cannot be empty."

    device = "cuda" if torch.cuda.is_available() else "cpu"
    log_messages = [f"βœ… Detected device: {device.upper()}"]

    try:
        api = HfApi()
        info = api.model_info(repo_id=model_id, revision=revision)
        filenames = [s.rfilename for s in info.siblings]
    except Exception as e:
        return None, f"❌ Error: Failed to get model info for `{model_id}`.\n{e}"

    files_to_convert = [f for f in filenames if f.endswith(".bin") or f.endswith(".ckpt")]
    if not files_to_convert:
        return None, f"ℹ️ No .bin or .ckpt files found in model `{model_id}` for conversion."

    log_messages.append(f"πŸ” Found {len(files_to_convert)} file(s) to convert: {', '.join(files_to_convert)}")

    with TemporaryDirectory() as temp_dir:
        temp_converted_files = []
        for filename in progress.tqdm(files_to_convert, desc="Converting files"):
            try:
                log_messages.append(f"\nπŸš€ Downloading `{filename}`...")
                pt_path = hf_hub_download(
                    repo_id=model_id, filename=filename, revision=revision,
                    cache_dir=os.path.join(temp_dir, "downloads"),
                )
                
                log_messages.append(f"πŸ› οΈ Converting `{filename}`...")
                sf_filename = os.path.splitext(os.path.basename(filename))[0] + ".safetensors"
                sf_path = os.path.join(temp_dir, "converted", sf_filename)
                
                size_warning = convert_file(pt_path, sf_path, device)
                if size_warning:
                    log_messages.append(f"⚠️ {size_warning}")

                temp_converted_files.append(sf_path)
                log_messages.append(f"βœ… Successfully converted to `{sf_filename}`")
            except Exception as e:
                log_messages.append(f"❌ Error processing file `{filename}`: {e}")
                continue
        
        if not temp_converted_files:
            return None, "\n".join(log_messages) + "\n\nFailed to convert any files."
            
        # Copy the converted files out of the temporary directory so Gradio
        # can still serve them for download after the TemporaryDirectory
        # context exits and deletes everything inside it.
        persistent_files = []
        for temp_path in temp_converted_files:
            # shutil.copy() returns the path of the new copy, placed in the
            # current working directory, which outlives the temp directory.
            persistent_path = shutil.copy(temp_path, ".")
            persistent_files.append(persistent_path)
        
        final_message = "\n".join(log_messages) + "\n\n" + "πŸŽ‰ All files processed successfully! Ready for download."
        # Return the paths to the persistent files
        return persistent_files, final_message


# --- Create Gradio Interface ---

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Model Converter to `.safetensors`
        This utility converts PyTorch model weights (`.bin`, `.ckpt`) from Hugging Face repositories 
        to the safe and fast `.safetensors` format.
        
        **How to use:**
        1. Enter the Model ID from Hugging Face (e.g., `stabilityai/stable-diffusion-2-1-base`).
        2. Click the "Convert" button.
        3. Wait for the process to complete and download the resulting files.
        """
    )
    with gr.Row():
        model_id = gr.Textbox(label="Hugging Face Model ID", placeholder="e.g., runwayml/stable-diffusion-v1-5")
        revision = gr.Textbox(label="Revision (branch)", value="main")
    
    convert_button = gr.Button("Convert", variant="primary")
    
    gr.Markdown("### Result")
    log_output = gr.Markdown(value="Waiting for input...")
    file_output = gr.File(label="Download Converted Files")

    gr.Markdown(
        "<p style='color:grey;font-size:0.8em;'>"
        "<b>Note:</b> A `UserWarning: TypedStorage is deprecated` message may appear in the logs. "
        "This is normal and does not affect the result."
        "</p>"
    )

    convert_button.click(
        fn=process_model,
        inputs=[model_id, revision],
        outputs=[file_output, log_output],
    )

if __name__ == "__main__":
    demo.launch()