VOIDER committed on
Commit
659e2ba
·
verified ·
1 Parent(s): 8915205

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -60
app.py CHANGED
@@ -1,27 +1,27 @@
1
  import gradio as gr
2
  import torch
3
  import os
 
4
  from tempfile import TemporaryDirectory
5
  from huggingface_hub import hf_hub_download, HfApi
6
  from safetensors.torch import save_file, load_file
7
  from collections import defaultdict
8
  from typing import Dict, List
9
 
10
- # --- Логика, скопированная из оригинального скрипта `convert.py` ---
11
- # Эти внутренние функции нужны для корректной обработки общих (shared) тензоров.
12
- # Копируем их, чтобы сделать приложение самодостаточным.
13
- # Источник: https://github.com/huggingface/safetensors/blob/main/safetensors/torch.py
14
 
15
  def _is_complete(storage):
 
16
  return storage.size() * storage.element_size() == storage.nbytes()
17
 
18
  def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[List[str]]:
19
  tensors = list(state_dict.values())
20
- # Can't handle unpickled storages
21
  storages = {tensor.storage().data_ptr(): [] for tensor in tensors}
22
  for name, tensor in state_dict.items():
23
  storages[tensor.storage().data_ptr()].append(name)
24
- # Return only tensors that share storage
25
  return [names for names in storages.values() if len(names) > 1]
26
 
27
  def _remove_duplicate_names(
@@ -32,9 +32,6 @@ def _remove_duplicate_names(
32
  for shared in shareds:
33
  complete_names = set([name for name in shared if _is_complete(state_dict[name])])
34
  if not complete_names:
35
- # Fallback for very weird cases.
36
- # The model is likely to be incorrect after this
37
- # but it will be loadable.
38
  name = list(shared)[0]
39
  state_dict[name] = state_dict[name].clone()
40
  complete_names = {name}
@@ -50,15 +47,14 @@ def check_file_size(sf_filename: str, pt_filename: str):
50
  sf_size = os.stat(sf_filename).st_size
51
  pt_size = os.stat(pt_filename).st_size
52
  if (sf_size - pt_size) / pt_size > 0.01:
53
- # Не бросаем ошибку, а возвращаем предупреждение
54
  return (
55
- f"ВНИМАНИЕ: Размер сконвертированного файла ({sf_size} байт) "
56
- f"более чем на 1% отличается от оригинала ({pt_size} байт)."
57
  )
58
  return None
59
 
60
  def convert_file(pt_filename: str, sf_filename: str, device: str):
61
- """Основная функция конвертации одного файла."""
62
  loaded = torch.load(pt_filename, map_location=device, weights_only=True)
63
  if "state_dict" in loaded:
64
  loaded = loaded["state_dict"]
@@ -72,110 +68,116 @@ def convert_file(pt_filename: str, sf_filename: str, device: str):
72
  del loaded[to_remove]
73
 
74
  loaded = {k: v.contiguous() for k, v in loaded.items()}
75
-
76
  os.makedirs(os.path.dirname(sf_filename), exist_ok=True)
77
  save_file(loaded, sf_filename, metadata=metadata)
78
 
79
  size_warning = check_file_size(sf_filename, pt_filename)
80
 
81
- # Проверка на корректность
82
  reloaded = load_file(sf_filename)
83
  for k in loaded:
84
  pt_tensor = loaded[k].to("cpu")
85
  sf_tensor = reloaded[k].to("cpu")
86
  if not torch.equal(pt_tensor, sf_tensor):
87
- raise RuntimeError(f"Тензоры не совпадают для ключа {k}!")
88
 
89
  return size_warning
90
 
91
 
92
- # --- Основная логика Gradio-приложения ---
93
 
94
  def process_model(model_id: str, revision: str, progress=gr.Progress(track_tqdm=True)):
95
- """
96
- Скачивает, конвертирует и возвращает пути к файлам `.safetensors`.
97
- """
98
  if not model_id:
99
- return None, "Ошибка: ID модели не может быть пустым."
100
 
101
- # 1. Определяем устройство (GPU или CPU)
102
  device = "cuda" if torch.cuda.is_available() else "cpu"
103
- log_messages = [f"✅ Обнаружено устройство: {device.upper()}"]
104
 
105
  try:
106
  api = HfApi()
107
  info = api.model_info(repo_id=model_id, revision=revision)
108
  filenames = [s.rfilename for s in info.siblings]
109
  except Exception as e:
110
- return None, f"❌ Ошибка: Не удалось получить информацию о модели `{model_id}`.\n{e}"
111
 
112
- # Ищем файлы для конвертации
113
  files_to_convert = [f for f in filenames if f.endswith(".bin") or f.endswith(".ckpt")]
114
  if not files_to_convert:
115
- return None, f"ℹ️ В модели `{model_id}` не найдено файлов `.bin` или `.ckpt` для конвертации."
116
 
117
- log_messages.append(f"🔍 Найдено {len(files_to_convert)} файлов для конвертации: {', '.join(files_to_convert)}")
118
 
119
- # Используем временную директорию для чистоты
120
  with TemporaryDirectory() as temp_dir:
121
- converted_files = []
122
- for filename in progress.tqdm(files_to_convert, desc="Конвертация файлов"):
123
  try:
124
- # Скачиваем файл
125
- log_messages.append(f"\n🚀 Скачивание `{filename}`...")
126
  pt_path = hf_hub_download(
127
- repo_id=model_id,
128
- filename=filename,
129
- revision=revision,
130
  cache_dir=os.path.join(temp_dir, "downloads"),
131
  )
132
 
133
- # Конвертируем
134
- log_messages.append(f"🛠️ Конвертация `{filename}`...")
135
- sf_filename = os.path.splitext(filename)[0] + ".safetensors"
136
  sf_path = os.path.join(temp_dir, "converted", sf_filename)
137
 
138
  size_warning = convert_file(pt_path, sf_path, device)
139
  if size_warning:
140
  log_messages.append(f"⚠️ {size_warning}")
141
 
142
- converted_files.append(sf_path)
143
- log_messages.append(f"✅ Успешно сконвертировано в `{sf_filename}`")
144
  except Exception as e:
145
- log_messages.append(f"❌ Ошибка при обработке файла `{filename}`: {e}")
146
  continue
147
 
148
- if not converted_files:
149
- return None, "\n".join(log_messages) + "\n\nНе удалось сконвертировать ни один файл."
150
 
151
- final_message = "\n".join(log_messages) + "\n\n" + "🎉 Все файлы успешно обработаны! Готово к скачиванию."
152
- return converted_files, final_message
 
 
 
 
 
 
 
 
 
 
 
153
 
154
 
155
- # --- Создание интерфейса Gradio ---
156
 
157
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
158
  gr.Markdown(
159
  """
160
- # Конвертер моделей в `.safetensors`
161
- Эта утилита преобразует веса моделей PyTorch (`.bin`, `.ckpt`) из репозиториев Hugging Face
162
- в безопасный и быстрый формат `.safetensors`.
163
 
164
- **Как использовать:**
165
- 1. Введите ID модели с Hugging Face (например, `stabilityai/stable-diffusion-2-1-base`).
166
- 2. Нажмите кнопку "Конвертировать".
167
- 3. Дождитесь завершения процесса и скачайте полученные файлы.
168
  """
169
  )
170
  with gr.Row():
171
- model_id = gr.Textbox(label="ID модели на Hugging Face", placeholder="например, runwayml/stable-diffusion-v1-5")
172
- revision = gr.Textbox(label="Ревизия (ветка)", value="main")
173
 
174
- convert_button = gr.Button("Конвертировать", variant="primary")
175
 
176
- gr.Markdown("### Результат")
177
- log_output = gr.Markdown(value="Ожидание запуска...")
178
- file_output = gr.File(label="Скачать сконвертированные файлы")
 
 
 
 
 
 
 
179
 
180
  convert_button.click(
181
  fn=process_model,
@@ -184,4 +186,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
184
  )
185
 
186
  if __name__ == "__main__":
187
- demo.launch(debug=True)
 
1
  import gradio as gr
2
  import torch
3
  import os
4
+ import shutil
5
  from tempfile import TemporaryDirectory
6
  from huggingface_hub import hf_hub_download, HfApi
7
  from safetensors.torch import save_file, load_file
8
  from collections import defaultdict
9
  from typing import Dict, List
10
 
11
+ # --- Logic copied from the original `convert.py` script ---
12
+ # These internal functions are necessary for correctly handling shared tensors.
13
+ # We copy them here to make the application self-contained.
14
+ # Source: https://github.com/huggingface/safetensors/blob/main/safetensors/torch.py
15
 
16
  def _is_complete(storage):
17
+ # The UserWarning from this line can be ignored; it's expected.
18
  return storage.size() * storage.element_size() == storage.nbytes()
19
 
20
  def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[List[str]]:
21
  tensors = list(state_dict.values())
 
22
  storages = {tensor.storage().data_ptr(): [] for tensor in tensors}
23
  for name, tensor in state_dict.items():
24
  storages[tensor.storage().data_ptr()].append(name)
 
25
  return [names for names in storages.values() if len(names) > 1]
26
 
27
  def _remove_duplicate_names(
 
32
  for shared in shareds:
33
  complete_names = set([name for name in shared if _is_complete(state_dict[name])])
34
  if not complete_names:
 
 
 
35
  name = list(shared)[0]
36
  state_dict[name] = state_dict[name].clone()
37
  complete_names = {name}
 
47
  sf_size = os.stat(sf_filename).st_size
48
  pt_size = os.stat(pt_filename).st_size
49
  if (sf_size - pt_size) / pt_size > 0.01:
 
50
  return (
51
+ f"WARNING: The converted file size ({sf_size} bytes) "
52
+ f"differs from the original ({pt_size} bytes) by more than 1%."
53
  )
54
  return None
55
 
56
  def convert_file(pt_filename: str, sf_filename: str, device: str):
57
+ """Main function to convert a single file."""
58
  loaded = torch.load(pt_filename, map_location=device, weights_only=True)
59
  if "state_dict" in loaded:
60
  loaded = loaded["state_dict"]
 
68
  del loaded[to_remove]
69
 
70
  loaded = {k: v.contiguous() for k, v in loaded.items()}
 
71
  os.makedirs(os.path.dirname(sf_filename), exist_ok=True)
72
  save_file(loaded, sf_filename, metadata=metadata)
73
 
74
  size_warning = check_file_size(sf_filename, pt_filename)
75
 
 
76
  reloaded = load_file(sf_filename)
77
  for k in loaded:
78
  pt_tensor = loaded[k].to("cpu")
79
  sf_tensor = reloaded[k].to("cpu")
80
  if not torch.equal(pt_tensor, sf_tensor):
81
+ raise RuntimeError(f"Tensors do not match for key {k}!")
82
 
83
  return size_warning
84
 
85
 
86
+ # --- Main Gradio App Logic ---
87
 
88
  def process_model(model_id: str, revision: str, progress=gr.Progress(track_tqdm=True)):
 
 
 
89
  if not model_id:
90
+ return None, "Error: Model ID cannot be empty."
91
 
 
92
  device = "cuda" if torch.cuda.is_available() else "cpu"
93
+ log_messages = [f"✅ Detected device: {device.upper()}"]
94
 
95
  try:
96
  api = HfApi()
97
  info = api.model_info(repo_id=model_id, revision=revision)
98
  filenames = [s.rfilename for s in info.siblings]
99
  except Exception as e:
100
+ return None, f"❌ Error: Failed to get model info for `{model_id}`.\n{e}"
101
 
 
102
  files_to_convert = [f for f in filenames if f.endswith(".bin") or f.endswith(".ckpt")]
103
  if not files_to_convert:
104
+ return None, f"ℹ️ No .bin or .ckpt files found in model `{model_id}` for conversion."
105
 
106
+ log_messages.append(f"🔍 Found {len(files_to_convert)} file(s) to convert: {', '.join(files_to_convert)}")
107
 
 
108
  with TemporaryDirectory() as temp_dir:
109
+ temp_converted_files = []
110
+ for filename in progress.tqdm(files_to_convert, desc="Converting files"):
111
  try:
112
+ log_messages.append(f"\n🚀 Downloading `{filename}`...")
 
113
  pt_path = hf_hub_download(
114
+ repo_id=model_id, filename=filename, revision=revision,
 
 
115
  cache_dir=os.path.join(temp_dir, "downloads"),
116
  )
117
 
118
+ log_messages.append(f"🛠️ Converting `{filename}`...")
119
+ sf_filename = os.path.splitext(os.path.basename(filename))[0] + ".safetensors"
 
120
  sf_path = os.path.join(temp_dir, "converted", sf_filename)
121
 
122
  size_warning = convert_file(pt_path, sf_path, device)
123
  if size_warning:
124
  log_messages.append(f"⚠️ {size_warning}")
125
 
126
+ temp_converted_files.append(sf_path)
127
+ log_messages.append(f"✅ Successfully converted to `{sf_filename}`")
128
  except Exception as e:
129
+ log_messages.append(f"❌ Error processing file `{filename}`: {e}")
130
  continue
131
 
132
+ if not temp_converted_files:
133
+ return None, "\n".join(log_messages) + "\n\nFailed to convert any files."
134
 
135
+ # --- KEY CHANGE ---
136
+ # Copy files from the temporary directory to a persistent (for Gradio) location
137
+ # before the directory is deleted.
138
+ persistent_files = []
139
+ for temp_path in temp_converted_files:
140
+ # shutil.copy() creates a new file that won't be deleted
141
+ persistent_path = shutil.copy(temp_path, ".")
142
+ persistent_files.append(persistent_path)
143
+ # --------------------
144
+
145
+ final_message = "\n".join(log_messages) + "\n\n" + "🎉 All files processed successfully! Ready for download."
146
+ # Return the paths to the persistent files
147
+ return persistent_files, final_message
148
 
149
 
150
+ # --- Create Gradio Interface ---
151
 
152
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
153
  gr.Markdown(
154
  """
155
+ # Model Converter to `.safetensors`
156
+ This utility converts PyTorch model weights (`.bin`, `.ckpt`) from Hugging Face repositories
157
+ to the safe and fast `.safetensors` format.
158
 
159
+ **How to use:**
160
+ 1. Enter the Model ID from Hugging Face (e.g., `stabilityai/stable-diffusion-2-1-base`).
161
+ 2. Click the "Convert" button.
162
+ 3. Wait for the process to complete and download the resulting files.
163
  """
164
  )
165
  with gr.Row():
166
+ model_id = gr.Textbox(label="Hugging Face Model ID", placeholder="e.g., runwayml/stable-diffusion-v1-5")
167
+ revision = gr.Textbox(label="Revision (branch)", value="main")
168
 
169
+ convert_button = gr.Button("Convert", variant="primary")
170
 
171
+ gr.Markdown("### Result")
172
+ log_output = gr.Markdown(value="Waiting for input...")
173
+ file_output = gr.File(label="Download Converted Files")
174
+
175
+ gr.Markdown(
176
+ "<p style='color:grey;font-size:0.8em;'>"
177
+ "<b>Note:</b> A `UserWarning: TypedStorage is deprecated` message may appear in the logs. "
178
+ "This is normal and does not affect the result."
179
+ "</p>"
180
+ )
181
 
182
  convert_button.click(
183
  fn=process_model,
 
186
  )
187
 
188
  if __name__ == "__main__":
189
+ demo.launch()