Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,27 +1,27 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import os
|
|
|
4 |
from tempfile import TemporaryDirectory
|
5 |
from huggingface_hub import hf_hub_download, HfApi
|
6 |
from safetensors.torch import save_file, load_file
|
7 |
from collections import defaultdict
|
8 |
from typing import Dict, List
|
9 |
|
10 |
-
# ---
|
11 |
-
#
|
12 |
-
#
|
13 |
-
#
|
14 |
|
15 |
def _is_complete(storage):
|
|
|
16 |
return storage.size() * storage.element_size() == storage.nbytes()
|
17 |
|
18 |
def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[List[str]]:
|
19 |
tensors = list(state_dict.values())
|
20 |
-
# Can't handle unpickled storages
|
21 |
storages = {tensor.storage().data_ptr(): [] for tensor in tensors}
|
22 |
for name, tensor in state_dict.items():
|
23 |
storages[tensor.storage().data_ptr()].append(name)
|
24 |
-
# Return only tensors that share storage
|
25 |
return [names for names in storages.values() if len(names) > 1]
|
26 |
|
27 |
def _remove_duplicate_names(
|
@@ -32,9 +32,6 @@ def _remove_duplicate_names(
|
|
32 |
for shared in shareds:
|
33 |
complete_names = set([name for name in shared if _is_complete(state_dict[name])])
|
34 |
if not complete_names:
|
35 |
-
# Fallback for very weird cases.
|
36 |
-
# The model is likely to be incorrect after this
|
37 |
-
# but it will be loadable.
|
38 |
name = list(shared)[0]
|
39 |
state_dict[name] = state_dict[name].clone()
|
40 |
complete_names = {name}
|
@@ -50,15 +47,14 @@ def check_file_size(sf_filename: str, pt_filename: str):
|
|
50 |
sf_size = os.stat(sf_filename).st_size
|
51 |
pt_size = os.stat(pt_filename).st_size
|
52 |
if (sf_size - pt_size) / pt_size > 0.01:
|
53 |
-
# Не бросаем ошибку, а возвращаем предупреждение
|
54 |
return (
|
55 |
-
f"
|
56 |
-
f"
|
57 |
)
|
58 |
return None
|
59 |
|
60 |
def convert_file(pt_filename: str, sf_filename: str, device: str):
|
61 |
-
"""
|
62 |
loaded = torch.load(pt_filename, map_location=device, weights_only=True)
|
63 |
if "state_dict" in loaded:
|
64 |
loaded = loaded["state_dict"]
|
@@ -72,110 +68,116 @@ def convert_file(pt_filename: str, sf_filename: str, device: str):
|
|
72 |
del loaded[to_remove]
|
73 |
|
74 |
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
75 |
-
|
76 |
os.makedirs(os.path.dirname(sf_filename), exist_ok=True)
|
77 |
save_file(loaded, sf_filename, metadata=metadata)
|
78 |
|
79 |
size_warning = check_file_size(sf_filename, pt_filename)
|
80 |
|
81 |
-
# Проверка на корректность
|
82 |
reloaded = load_file(sf_filename)
|
83 |
for k in loaded:
|
84 |
pt_tensor = loaded[k].to("cpu")
|
85 |
sf_tensor = reloaded[k].to("cpu")
|
86 |
if not torch.equal(pt_tensor, sf_tensor):
|
87 |
-
raise RuntimeError(f"
|
88 |
|
89 |
return size_warning
|
90 |
|
91 |
|
92 |
-
# ---
|
93 |
|
94 |
def process_model(model_id: str, revision: str, progress=gr.Progress(track_tqdm=True)):
|
95 |
-
"""
|
96 |
-
Скачивает, конвертирует и возвращает пути к файлам `.safetensors`.
|
97 |
-
"""
|
98 |
if not model_id:
|
99 |
-
return None, "
|
100 |
|
101 |
-
# 1. Определяем устройство (GPU или CPU)
|
102 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
103 |
-
log_messages = [f"✅
|
104 |
|
105 |
try:
|
106 |
api = HfApi()
|
107 |
info = api.model_info(repo_id=model_id, revision=revision)
|
108 |
filenames = [s.rfilename for s in info.siblings]
|
109 |
except Exception as e:
|
110 |
-
return None, f"❌
|
111 |
|
112 |
-
# Ищем файлы для конвертации
|
113 |
files_to_convert = [f for f in filenames if f.endswith(".bin") or f.endswith(".ckpt")]
|
114 |
if not files_to_convert:
|
115 |
-
return None, f"ℹ️
|
116 |
|
117 |
-
log_messages.append(f"🔍
|
118 |
|
119 |
-
# Используем временную директорию для чистоты
|
120 |
with TemporaryDirectory() as temp_dir:
|
121 |
-
|
122 |
-
for filename in progress.tqdm(files_to_convert, desc="
|
123 |
try:
|
124 |
-
|
125 |
-
log_messages.append(f"\n🚀 Скачивание `{filename}`...")
|
126 |
pt_path = hf_hub_download(
|
127 |
-
repo_id=model_id,
|
128 |
-
filename=filename,
|
129 |
-
revision=revision,
|
130 |
cache_dir=os.path.join(temp_dir, "downloads"),
|
131 |
)
|
132 |
|
133 |
-
|
134 |
-
|
135 |
-
sf_filename = os.path.splitext(filename)[0] + ".safetensors"
|
136 |
sf_path = os.path.join(temp_dir, "converted", sf_filename)
|
137 |
|
138 |
size_warning = convert_file(pt_path, sf_path, device)
|
139 |
if size_warning:
|
140 |
log_messages.append(f"⚠️ {size_warning}")
|
141 |
|
142 |
-
|
143 |
-
log_messages.append(f"✅
|
144 |
except Exception as e:
|
145 |
-
log_messages.append(f"❌
|
146 |
continue
|
147 |
|
148 |
-
if not
|
149 |
-
return None, "\n".join(log_messages) + "\n\
|
150 |
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
|
155 |
-
# ---
|
156 |
|
157 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
158 |
gr.Markdown(
|
159 |
"""
|
160 |
-
#
|
161 |
-
|
162 |
-
|
163 |
|
164 |
-
|
165 |
-
1.
|
166 |
-
2.
|
167 |
-
3.
|
168 |
"""
|
169 |
)
|
170 |
with gr.Row():
|
171 |
-
model_id = gr.Textbox(label="
|
172 |
-
revision = gr.Textbox(label="
|
173 |
|
174 |
-
convert_button = gr.Button("
|
175 |
|
176 |
-
gr.Markdown("###
|
177 |
-
log_output = gr.Markdown(value="
|
178 |
-
file_output = gr.File(label="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
convert_button.click(
|
181 |
fn=process_model,
|
@@ -184,4 +186,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
184 |
)
|
185 |
|
186 |
if __name__ == "__main__":
|
187 |
-
demo.launch(
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import os
|
4 |
+
import shutil
|
5 |
from tempfile import TemporaryDirectory
|
6 |
from huggingface_hub import hf_hub_download, HfApi
|
7 |
from safetensors.torch import save_file, load_file
|
8 |
from collections import defaultdict
|
9 |
from typing import Dict, List
|
10 |
|
11 |
+
# --- Logic copied from the original `convert.py` script ---
|
12 |
+
# These internal functions are necessary for correctly handling shared tensors.
|
13 |
+
# We copy them here to make the application self-contained.
|
14 |
+
# Source: https://github.com/huggingface/safetensors/blob/main/safetensors/torch.py
|
15 |
|
16 |
def _is_complete(storage):
|
17 |
+
# The UserWarning from this line can be ignored; it's expected.
|
18 |
return storage.size() * storage.element_size() == storage.nbytes()
|
19 |
|
20 |
def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[List[str]]:
|
21 |
tensors = list(state_dict.values())
|
|
|
22 |
storages = {tensor.storage().data_ptr(): [] for tensor in tensors}
|
23 |
for name, tensor in state_dict.items():
|
24 |
storages[tensor.storage().data_ptr()].append(name)
|
|
|
25 |
return [names for names in storages.values() if len(names) > 1]
|
26 |
|
27 |
def _remove_duplicate_names(
|
|
|
32 |
for shared in shareds:
|
33 |
complete_names = set([name for name in shared if _is_complete(state_dict[name])])
|
34 |
if not complete_names:
|
|
|
|
|
|
|
35 |
name = list(shared)[0]
|
36 |
state_dict[name] = state_dict[name].clone()
|
37 |
complete_names = {name}
|
|
|
47 |
sf_size = os.stat(sf_filename).st_size
|
48 |
pt_size = os.stat(pt_filename).st_size
|
49 |
if (sf_size - pt_size) / pt_size > 0.01:
|
|
|
50 |
return (
|
51 |
+
f"WARNING: The converted file size ({sf_size} bytes) "
|
52 |
+
f"differs from the original ({pt_size} bytes) by more than 1%."
|
53 |
)
|
54 |
return None
|
55 |
|
56 |
def convert_file(pt_filename: str, sf_filename: str, device: str):
|
57 |
+
"""Main function to convert a single file."""
|
58 |
loaded = torch.load(pt_filename, map_location=device, weights_only=True)
|
59 |
if "state_dict" in loaded:
|
60 |
loaded = loaded["state_dict"]
|
|
|
68 |
del loaded[to_remove]
|
69 |
|
70 |
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
|
|
71 |
os.makedirs(os.path.dirname(sf_filename), exist_ok=True)
|
72 |
save_file(loaded, sf_filename, metadata=metadata)
|
73 |
|
74 |
size_warning = check_file_size(sf_filename, pt_filename)
|
75 |
|
|
|
76 |
reloaded = load_file(sf_filename)
|
77 |
for k in loaded:
|
78 |
pt_tensor = loaded[k].to("cpu")
|
79 |
sf_tensor = reloaded[k].to("cpu")
|
80 |
if not torch.equal(pt_tensor, sf_tensor):
|
81 |
+
raise RuntimeError(f"Tensors do not match for key {k}!")
|
82 |
|
83 |
return size_warning
|
84 |
|
85 |
|
86 |
+
# --- Main Gradio App Logic ---
|
87 |
|
88 |
def process_model(model_id: str, revision: str, progress=gr.Progress(track_tqdm=True)):
    """Download a Hub model's ``.bin``/``.ckpt`` files and convert them to ``.safetensors``.

    Args:
        model_id: Hugging Face repository id. Surrounding whitespace is ignored.
        revision: Branch, tag or commit to read. A blank value falls back to
            ``"main"``, the default shown in the UI textbox.
        progress: Gradio progress tracker wrapped around the per-file loop.

    Returns:
        tuple: ``(files, log)``. ``files`` is a list of paths to the converted
        files, or ``None`` when nothing was converted. ``log`` is the
        newline-joined status report, ending with a summary that says whether
        every file, some files, or no files were converted.
    """
    # Function-scope import: the module itself only imports TemporaryDirectory.
    from tempfile import mkdtemp

    model_id = (model_id or "").strip()
    if not model_id:
        return None, "Error: Model ID cannot be empty."
    revision = (revision or "").strip() or "main"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    log_messages = [f"✅ Detected device: {device.upper()}"]

    try:
        info = HfApi().model_info(repo_id=model_id, revision=revision)
        filenames = [s.rfilename for s in info.siblings]
    except Exception as e:
        return None, f"❌ Error: Failed to get model info for `{model_id}`.\n{e}"

    files_to_convert = [f for f in filenames if f.endswith((".bin", ".ckpt"))]
    if not files_to_convert:
        return None, f"ℹ️ No .bin or .ckpt files found in model `{model_id}` for conversion."

    log_messages.append(f"🔍 Found {len(files_to_convert)} file(s) to convert: {', '.join(files_to_convert)}")

    persistent_files = []
    failed_files = []
    used_names = set()
    output_dir = None  # created lazily, only once a conversion succeeds

    with TemporaryDirectory() as temp_dir:
        for filename in progress.tqdm(files_to_convert, desc="Converting files"):
            try:
                log_messages.append(f"\n🚀 Downloading `{filename}`...")
                pt_path = hf_hub_download(
                    repo_id=model_id, filename=filename, revision=revision,
                    cache_dir=os.path.join(temp_dir, "downloads"),
                )

                log_messages.append(f"🛠️ Converting `{filename}`...")
                # Keep the repo-relative path: using only the basename makes
                # `unet/model.bin` and `vae/model.bin` overwrite each other.
                sf_filename = os.path.splitext(filename)[0] + ".safetensors"
                if sf_filename in used_names:
                    # e.g. `model.bin` and `model.ckpt` side by side.
                    sf_filename = filename + ".safetensors"
                used_names.add(sf_filename)
                sf_path = os.path.join(temp_dir, "converted", *sf_filename.split("/"))

                size_warning = convert_file(pt_path, sf_path, device)
                if size_warning:
                    log_messages.append(f"⚠️ {size_warning}")

                # Move the result out of the temporary directory before it is
                # deleted. A per-request directory (instead of the working
                # directory) keeps requests from clobbering each other and
                # avoids piling weight files up next to the app.
                if output_dir is None:
                    output_dir = mkdtemp(prefix="safetensors_")
                persistent_path = os.path.join(output_dir, sf_filename.replace("/", "__"))
                shutil.move(sf_path, persistent_path)
                persistent_files.append(persistent_path)
                log_messages.append(f"✅ Successfully converted to `{sf_filename}`")
            except Exception as e:
                # Best effort: record the failure and carry on with the rest.
                failed_files.append(filename)
                log_messages.append(f"❌ Error processing file `{filename}`: {e}")

    if not persistent_files:
        return None, "\n".join(log_messages) + "\n\nFailed to convert any files."

    if failed_files:
        # Do not claim full success when some files were skipped above.
        summary = (
            f"⚠️ Converted {len(persistent_files)} of {len(files_to_convert)} file(s). "
            f"Failed: {', '.join(failed_files)}"
        )
    else:
        summary = "🎉 All files processed successfully! Ready for download."

    final_message = "\n".join(log_messages) + "\n\n" + summary
    return persistent_files, final_message
|
148 |
|
149 |
|
150 |
+
# --- Create Gradio Interface ---
|
151 |
|
152 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
153 |
gr.Markdown(
|
154 |
"""
|
155 |
+
# Model Converter to `.safetensors`
|
156 |
+
This utility converts PyTorch model weights (`.bin`, `.ckpt`) from Hugging Face repositories
|
157 |
+
to the safe and fast `.safetensors` format.
|
158 |
|
159 |
+
**How to use:**
|
160 |
+
1. Enter the Model ID from Hugging Face (e.g., `stabilityai/stable-diffusion-2-1-base`).
|
161 |
+
2. Click the "Convert" button.
|
162 |
+
3. Wait for the process to complete and download the resulting files.
|
163 |
"""
|
164 |
)
|
165 |
with gr.Row():
|
166 |
+
model_id = gr.Textbox(label="Hugging Face Model ID", placeholder="e.g., runwayml/stable-diffusion-v1-5")
|
167 |
+
revision = gr.Textbox(label="Revision (branch)", value="main")
|
168 |
|
169 |
+
convert_button = gr.Button("Convert", variant="primary")
|
170 |
|
171 |
+
gr.Markdown("### Result")
|
172 |
+
log_output = gr.Markdown(value="Waiting for input...")
|
173 |
+
file_output = gr.File(label="Download Converted Files")
|
174 |
+
|
175 |
+
gr.Markdown(
|
176 |
+
"<p style='color:grey;font-size:0.8em;'>"
|
177 |
+
"<b>Note:</b> A `UserWarning: TypedStorage is deprecated` message may appear in the logs. "
|
178 |
+
"This is normal and does not affect the result."
|
179 |
+
"</p>"
|
180 |
+
)
|
181 |
|
182 |
convert_button.click(
|
183 |
fn=process_model,
|
|
|
186 |
)
|
187 |
|
188 |
if __name__ == "__main__":
|
189 |
+
demo.launch()
|