Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
from tempfile import TemporaryDirectory
|
5 |
+
from huggingface_hub import hf_hub_download, HfApi
|
6 |
+
from safetensors.torch import save_file, load_file
|
7 |
+
from collections import defaultdict
|
8 |
+
from typing import Dict, List
|
9 |
+
|
10 |
+
# --- Логика, скопированная из оригинального скрипта `convert.py` ---
|
11 |
+
# Эти внутренние функции нужны для корректной обработки общих (shared) тензоров.
|
12 |
+
# Копируем их, чтобы сделать приложение самодостаточным.
|
13 |
+
# Источник: https://github.com/huggingface/safetensors/blob/main/safetensors/torch.py
|
14 |
+
|
15 |
+
def _is_complete(storage):
|
16 |
+
return storage.size() * storage.element_size() == storage.nbytes()
|
17 |
+
|
18 |
+
def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[List[str]]:
|
19 |
+
tensors = list(state_dict.values())
|
20 |
+
# Can't handle unpickled storages
|
21 |
+
storages = {tensor.storage().data_ptr(): [] for tensor in tensors}
|
22 |
+
for name, tensor in state_dict.items():
|
23 |
+
storages[tensor.storage().data_ptr()].append(name)
|
24 |
+
# Return only tensors that share storage
|
25 |
+
return [names for names in storages.values() if len(names) > 1]
|
26 |
+
|
27 |
+
def _remove_duplicate_names(
|
28 |
+
state_dict: Dict[str, torch.Tensor]
|
29 |
+
) -> Dict[str, List[str]]:
|
30 |
+
shareds = _find_shared_tensors(state_dict)
|
31 |
+
to_remove = defaultdict(list)
|
32 |
+
for shared in shareds:
|
33 |
+
complete_names = set([name for name in shared if _is_complete(state_dict[name])])
|
34 |
+
if not complete_names:
|
35 |
+
# Fallback for very weird cases.
|
36 |
+
# The model is likely to be incorrect after this
|
37 |
+
# but it will be loadable.
|
38 |
+
name = list(shared)[0]
|
39 |
+
state_dict[name] = state_dict[name].clone()
|
40 |
+
complete_names = {name}
|
41 |
+
|
42 |
+
keep_name = sorted(list(complete_names))[0]
|
43 |
+
|
44 |
+
for name in sorted(shared):
|
45 |
+
if name != keep_name:
|
46 |
+
to_remove[keep_name].append(name)
|
47 |
+
return to_remove
|
48 |
+
|
49 |
+
def check_file_size(sf_filename: str, pt_filename: str):
|
50 |
+
sf_size = os.stat(sf_filename).st_size
|
51 |
+
pt_size = os.stat(pt_filename).st_size
|
52 |
+
if (sf_size - pt_size) / pt_size > 0.01:
|
53 |
+
# Не бросаем ошибку, а возвращаем предупреждение
|
54 |
+
return (
|
55 |
+
f"ВНИМАНИЕ: Размер сконвертированного файла ({sf_size} байт) "
|
56 |
+
f"более чем на 1% отличается от оригинала ({pt_size} байт)."
|
57 |
+
)
|
58 |
+
return None
|
59 |
+
|
60 |
+
def convert_file(pt_filename: str, sf_filename: str, device: str):
|
61 |
+
"""Основная функция конвертации одного файла."""
|
62 |
+
loaded = torch.load(pt_filename, map_location=device, weights_only=True)
|
63 |
+
if "state_dict" in loaded:
|
64 |
+
loaded = loaded["state_dict"]
|
65 |
+
|
66 |
+
to_removes = _remove_duplicate_names(loaded)
|
67 |
+
metadata = {"format": "pt"}
|
68 |
+
for kept_name, to_remove_group in to_removes.items():
|
69 |
+
for to_remove in to_remove_group:
|
70 |
+
if to_remove not in metadata:
|
71 |
+
metadata[to_remove] = kept_name
|
72 |
+
del loaded[to_remove]
|
73 |
+
|
74 |
+
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
75 |
+
|
76 |
+
os.makedirs(os.path.dirname(sf_filename), exist_ok=True)
|
77 |
+
save_file(loaded, sf_filename, metadata=metadata)
|
78 |
+
|
79 |
+
size_warning = check_file_size(sf_filename, pt_filename)
|
80 |
+
|
81 |
+
# Проверка на корректность
|
82 |
+
reloaded = load_file(sf_filename)
|
83 |
+
for k in loaded:
|
84 |
+
pt_tensor = loaded[k].to("cpu")
|
85 |
+
sf_tensor = reloaded[k].to("cpu")
|
86 |
+
if not torch.equal(pt_tensor, sf_tensor):
|
87 |
+
raise RuntimeError(f"Тензоры не совпадают для ключа {k}!")
|
88 |
+
|
89 |
+
return size_warning
|
90 |
+
|
91 |
+
|
92 |
+
# --- Основная логика Gradio-приложения ---
|
93 |
+
|
94 |
+
def process_model(model_id: str, revision: str, progress=gr.Progress(track_tqdm=True)):
|
95 |
+
"""
|
96 |
+
Скачивает, конвертирует и возвращает пути к файлам `.safetensors`.
|
97 |
+
"""
|
98 |
+
if not model_id:
|
99 |
+
return None, "Ошибка: ID модели не может быть пустым."
|
100 |
+
|
101 |
+
# 1. Определяем устройство (GPU или CPU)
|
102 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
103 |
+
log_messages = [f"✅ Обнаружено устройство: {device.upper()}"]
|
104 |
+
|
105 |
+
try:
|
106 |
+
api = HfApi()
|
107 |
+
info = api.model_info(repo_id=model_id, revision=revision)
|
108 |
+
filenames = [s.rfilename for s in info.siblings]
|
109 |
+
except Exception as e:
|
110 |
+
return None, f"❌ Ошибка: Не удалось получить информацию о модели `{model_id}`.\n{e}"
|
111 |
+
|
112 |
+
# Ищем файлы для конвертации
|
113 |
+
files_to_convert = [f for f in filenames if f.endswith(".bin") or f.endswith(".ckpt")]
|
114 |
+
if not files_to_convert:
|
115 |
+
return None, f"ℹ️ В ��одели `{model_id}` не найдено файлов `.bin` или `.ckpt` для конвертации."
|
116 |
+
|
117 |
+
log_messages.append(f"🔍 Найдено {len(files_to_convert)} файлов для конвертации: {', '.join(files_to_convert)}")
|
118 |
+
|
119 |
+
# Используем временную директорию для чистоты
|
120 |
+
with TemporaryDirectory() as temp_dir:
|
121 |
+
converted_files = []
|
122 |
+
for filename in progress.tqdm(files_to_convert, desc="Конвертация файлов"):
|
123 |
+
try:
|
124 |
+
# Скачиваем файл
|
125 |
+
log_messages.append(f"\n🚀 Скачивание `{filename}`...")
|
126 |
+
pt_path = hf_hub_download(
|
127 |
+
repo_id=model_id,
|
128 |
+
filename=filename,
|
129 |
+
revision=revision,
|
130 |
+
cache_dir=os.path.join(temp_dir, "downloads"),
|
131 |
+
)
|
132 |
+
|
133 |
+
# Конвертируем
|
134 |
+
log_messages.append(f"🛠️ Конвертация `{filename}`...")
|
135 |
+
sf_filename = os.path.splitext(filename)[0] + ".safetensors"
|
136 |
+
sf_path = os.path.join(temp_dir, "converted", sf_filename)
|
137 |
+
|
138 |
+
size_warning = convert_file(pt_path, sf_path, device)
|
139 |
+
if size_warning:
|
140 |
+
log_messages.append(f"⚠️ {size_warning}")
|
141 |
+
|
142 |
+
converted_files.append(sf_path)
|
143 |
+
log_messages.append(f"✅ Успешно сконвертировано в `{sf_filename}`")
|
144 |
+
except Exception as e:
|
145 |
+
log_messages.append(f"❌ Ошибка при обработке файла `{filename}`: {e}")
|
146 |
+
continue
|
147 |
+
|
148 |
+
if not converted_files:
|
149 |
+
return None, "\n".join(log_messages) + "\n\nНе удалось сконвертировать ни один файл."
|
150 |
+
|
151 |
+
final_message = "\n".join(log_messages) + "\n\n" + "🎉 Все файлы успешно обработаны! Готово к скачиванию."
|
152 |
+
return converted_files, final_message
|
153 |
+
|
154 |
+
|
155 |
+
# --- Создание интерфейса Gradio ---
|
156 |
+
|
157 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
158 |
+
gr.Markdown(
|
159 |
+
"""
|
160 |
+
# Конвертер моделей в `.safetensors`
|
161 |
+
Эта утилита преобразует веса моделей PyTorch (`.bin`, `.ckpt`) из репозиториев Hugging Face
|
162 |
+
в безопасный и быстрый формат `.safetensors`.
|
163 |
+
|
164 |
+
**Как использовать:**
|
165 |
+
1. Введите ID модели с Hugging Face (например, `stabilityai/stable-diffusion-2-1-base`).
|
166 |
+
2. Нажмите кнопку "Конвертировать".
|
167 |
+
3. Дождитесь завершения процесса и скачайте полученные файлы.
|
168 |
+
"""
|
169 |
+
)
|
170 |
+
with gr.Row():
|
171 |
+
model_id = gr.Textbox(label="ID модели на Hugging Face", placeholder="например, runwayml/stable-diffusion-v1-5")
|
172 |
+
revision = gr.Textbox(label="Ревизия (ветка)", value="main")
|
173 |
+
|
174 |
+
convert_button = gr.Button("Конвертировать", variant="primary")
|
175 |
+
|
176 |
+
gr.Markdown("### Результат")
|
177 |
+
log_output = gr.Markdown(value="Ожидание запуска...")
|
178 |
+
file_output = gr.File(label="Скачать сконвертированные файлы")
|
179 |
+
|
180 |
+
convert_button.click(
|
181 |
+
fn=process_model,
|
182 |
+
inputs=[model_id, revision],
|
183 |
+
outputs=[file_output, log_output],
|
184 |
+
)
|
185 |
+
|
186 |
+
if __name__ == "__main__":
|
187 |
+
demo.launch(debug=True)
|