VOIDER commited on
Commit
8915205
·
verified ·
1 Parent(s): df14df3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -0
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ from tempfile import TemporaryDirectory
5
+ from huggingface_hub import hf_hub_download, HfApi
6
+ from safetensors.torch import save_file, load_file
7
+ from collections import defaultdict
8
+ from typing import Dict, List
9
+
10
+ # --- Логика, скопированная из оригинального скрипта `convert.py` ---
11
+ # Эти внутренние функции нужны для корректной обработки общих (shared) тензоров.
12
+ # Копируем их, чтобы сделать приложение самодостаточным.
13
+ # Источник: https://github.com/huggingface/safetensors/blob/main/safetensors/torch.py
14
+
15
+ def _is_complete(storage):
16
+ return storage.size() * storage.element_size() == storage.nbytes()
17
+
18
+ def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[List[str]]:
19
+ tensors = list(state_dict.values())
20
+ # Can't handle unpickled storages
21
+ storages = {tensor.storage().data_ptr(): [] for tensor in tensors}
22
+ for name, tensor in state_dict.items():
23
+ storages[tensor.storage().data_ptr()].append(name)
24
+ # Return only tensors that share storage
25
+ return [names for names in storages.values() if len(names) > 1]
26
+
27
+ def _remove_duplicate_names(
28
+ state_dict: Dict[str, torch.Tensor]
29
+ ) -> Dict[str, List[str]]:
30
+ shareds = _find_shared_tensors(state_dict)
31
+ to_remove = defaultdict(list)
32
+ for shared in shareds:
33
+ complete_names = set([name for name in shared if _is_complete(state_dict[name])])
34
+ if not complete_names:
35
+ # Fallback for very weird cases.
36
+ # The model is likely to be incorrect after this
37
+ # but it will be loadable.
38
+ name = list(shared)[0]
39
+ state_dict[name] = state_dict[name].clone()
40
+ complete_names = {name}
41
+
42
+ keep_name = sorted(list(complete_names))[0]
43
+
44
+ for name in sorted(shared):
45
+ if name != keep_name:
46
+ to_remove[keep_name].append(name)
47
+ return to_remove
48
+
49
+ def check_file_size(sf_filename: str, pt_filename: str):
50
+ sf_size = os.stat(sf_filename).st_size
51
+ pt_size = os.stat(pt_filename).st_size
52
+ if (sf_size - pt_size) / pt_size > 0.01:
53
+ # Не бросаем ошибку, а возвращаем предупреждение
54
+ return (
55
+ f"ВНИМАНИЕ: Размер сконвертированного файла ({sf_size} байт) "
56
+ f"более чем на 1% отличается от оригинала ({pt_size} байт)."
57
+ )
58
+ return None
59
+
60
+ def convert_file(pt_filename: str, sf_filename: str, device: str):
61
+ """Основная функция конвертации одного файла."""
62
+ loaded = torch.load(pt_filename, map_location=device, weights_only=True)
63
+ if "state_dict" in loaded:
64
+ loaded = loaded["state_dict"]
65
+
66
+ to_removes = _remove_duplicate_names(loaded)
67
+ metadata = {"format": "pt"}
68
+ for kept_name, to_remove_group in to_removes.items():
69
+ for to_remove in to_remove_group:
70
+ if to_remove not in metadata:
71
+ metadata[to_remove] = kept_name
72
+ del loaded[to_remove]
73
+
74
+ loaded = {k: v.contiguous() for k, v in loaded.items()}
75
+
76
+ os.makedirs(os.path.dirname(sf_filename), exist_ok=True)
77
+ save_file(loaded, sf_filename, metadata=metadata)
78
+
79
+ size_warning = check_file_size(sf_filename, pt_filename)
80
+
81
+ # Проверка на корректность
82
+ reloaded = load_file(sf_filename)
83
+ for k in loaded:
84
+ pt_tensor = loaded[k].to("cpu")
85
+ sf_tensor = reloaded[k].to("cpu")
86
+ if not torch.equal(pt_tensor, sf_tensor):
87
+ raise RuntimeError(f"Тензоры не совпадают для ключа {k}!")
88
+
89
+ return size_warning
90
+
91
+
92
+ # --- Основная логика Gradio-приложения ---
93
+
94
+ def process_model(model_id: str, revision: str, progress=gr.Progress(track_tqdm=True)):
95
+ """
96
+ Скачивает, конвертирует и возвращает пути к файлам `.safetensors`.
97
+ """
98
+ if not model_id:
99
+ return None, "Ошибка: ID модели не может быть пустым."
100
+
101
+ # 1. Определяем устройство (GPU или CPU)
102
+ device = "cuda" if torch.cuda.is_available() else "cpu"
103
+ log_messages = [f"✅ Обнаружено устройство: {device.upper()}"]
104
+
105
+ try:
106
+ api = HfApi()
107
+ info = api.model_info(repo_id=model_id, revision=revision)
108
+ filenames = [s.rfilename for s in info.siblings]
109
+ except Exception as e:
110
+ return None, f"❌ Ошибка: Не удалось получить информацию о модели `{model_id}`.\n{e}"
111
+
112
+ # Ищем файлы для конвертации
113
+ files_to_convert = [f for f in filenames if f.endswith(".bin") or f.endswith(".ckpt")]
114
+ if not files_to_convert:
115
+ return None, f"ℹ️ В ��одели `{model_id}` не найдено файлов `.bin` или `.ckpt` для конвертации."
116
+
117
+ log_messages.append(f"🔍 Найдено {len(files_to_convert)} файлов для конвертации: {', '.join(files_to_convert)}")
118
+
119
+ # Используем временную директорию для чистоты
120
+ with TemporaryDirectory() as temp_dir:
121
+ converted_files = []
122
+ for filename in progress.tqdm(files_to_convert, desc="Конвертация файлов"):
123
+ try:
124
+ # Скачиваем файл
125
+ log_messages.append(f"\n🚀 Скачивание `{filename}`...")
126
+ pt_path = hf_hub_download(
127
+ repo_id=model_id,
128
+ filename=filename,
129
+ revision=revision,
130
+ cache_dir=os.path.join(temp_dir, "downloads"),
131
+ )
132
+
133
+ # Конвертируем
134
+ log_messages.append(f"🛠️ Конвертация `{filename}`...")
135
+ sf_filename = os.path.splitext(filename)[0] + ".safetensors"
136
+ sf_path = os.path.join(temp_dir, "converted", sf_filename)
137
+
138
+ size_warning = convert_file(pt_path, sf_path, device)
139
+ if size_warning:
140
+ log_messages.append(f"⚠️ {size_warning}")
141
+
142
+ converted_files.append(sf_path)
143
+ log_messages.append(f"✅ Успешно сконвертировано в `{sf_filename}`")
144
+ except Exception as e:
145
+ log_messages.append(f"❌ Ошибка при обработке файла `{filename}`: {e}")
146
+ continue
147
+
148
+ if not converted_files:
149
+ return None, "\n".join(log_messages) + "\n\nНе удалось сконвертировать ни один файл."
150
+
151
+ final_message = "\n".join(log_messages) + "\n\n" + "🎉 Все файлы успешно обработаны! Готово к скачиванию."
152
+ return converted_files, final_message
153
+
154
+
155
+ # --- Создание интерфейса Gradio ---
156
+
157
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
158
+ gr.Markdown(
159
+ """
160
+ # Конвертер моделей в `.safetensors`
161
+ Эта утилита преобразует веса моделей PyTorch (`.bin`, `.ckpt`) из репозиториев Hugging Face
162
+ в безопасный и быстрый формат `.safetensors`.
163
+
164
+ **Как использовать:**
165
+ 1. Введите ID модели с Hugging Face (например, `stabilityai/stable-diffusion-2-1-base`).
166
+ 2. Нажмите кнопку "Конвертировать".
167
+ 3. Дождитесь завершения процесса и скачайте полученные файлы.
168
+ """
169
+ )
170
+ with gr.Row():
171
+ model_id = gr.Textbox(label="ID модели на Hugging Face", placeholder="например, runwayml/stable-diffusion-v1-5")
172
+ revision = gr.Textbox(label="Ревизия (ветка)", value="main")
173
+
174
+ convert_button = gr.Button("Конвертировать", variant="primary")
175
+
176
+ gr.Markdown("### Результат")
177
+ log_output = gr.Markdown(value="Ожидание запуска...")
178
+ file_output = gr.File(label="Скачать сконвертированные файлы")
179
+
180
+ convert_button.click(
181
+ fn=process_model,
182
+ inputs=[model_id, revision],
183
+ outputs=[file_output, log_output],
184
+ )
185
+
186
+ if __name__ == "__main__":
187
+ demo.launch(debug=True)