import gradio as gr import torch import os from tempfile import TemporaryDirectory from huggingface_hub import hf_hub_download, HfApi from safetensors.torch import save_file, load_file from collections import defaultdict from typing import Dict, List # --- Логика, скопированная из оригинального скрипта `convert.py` --- # Эти внутренние функции нужны для корректной обработки общих (shared) тензоров. # Копируем их, чтобы сделать приложение самодостаточным. # Источник: https://github.com/huggingface/safetensors/blob/main/safetensors/torch.py def _is_complete(storage): return storage.size() * storage.element_size() == storage.nbytes() def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[List[str]]: tensors = list(state_dict.values()) # Can't handle unpickled storages storages = {tensor.storage().data_ptr(): [] for tensor in tensors} for name, tensor in state_dict.items(): storages[tensor.storage().data_ptr()].append(name) # Return only tensors that share storage return [names for names in storages.values() if len(names) > 1] def _remove_duplicate_names( state_dict: Dict[str, torch.Tensor] ) -> Dict[str, List[str]]: shareds = _find_shared_tensors(state_dict) to_remove = defaultdict(list) for shared in shareds: complete_names = set([name for name in shared if _is_complete(state_dict[name])]) if not complete_names: # Fallback for very weird cases. # The model is likely to be incorrect after this # but it will be loadable. name = list(shared)[0] state_dict[name] = state_dict[name].clone() complete_names = {name} keep_name = sorted(list(complete_names))[0] for name in sorted(shared): if name != keep_name: to_remove[keep_name].append(name) return to_remove def check_file_size(sf_filename: str, pt_filename: str): sf_size = os.stat(sf_filename).st_size pt_size = os.stat(pt_filename).st_size if (sf_size - pt_size) / pt_size > 0.01: # Не бросаем ошибку, а возвращаем предупреждение return ( f"ВНИМАНИЕ: Размер сконвертированного файла ({sf_size} байт) " f"более чем на 1% отличается от оригинала ({pt_size} байт)." ) return None def convert_file(pt_filename: str, sf_filename: str, device: str): """Основная функция конвертации одного файла.""" loaded = torch.load(pt_filename, map_location=device, weights_only=True) if "state_dict" in loaded: loaded = loaded["state_dict"] to_removes = _remove_duplicate_names(loaded) metadata = {"format": "pt"} for kept_name, to_remove_group in to_removes.items(): for to_remove in to_remove_group: if to_remove not in metadata: metadata[to_remove] = kept_name del loaded[to_remove] loaded = {k: v.contiguous() for k, v in loaded.items()} os.makedirs(os.path.dirname(sf_filename), exist_ok=True) save_file(loaded, sf_filename, metadata=metadata) size_warning = check_file_size(sf_filename, pt_filename) # Проверка на корректность reloaded = load_file(sf_filename) for k in loaded: pt_tensor = loaded[k].to("cpu") sf_tensor = reloaded[k].to("cpu") if not torch.equal(pt_tensor, sf_tensor): raise RuntimeError(f"Тензоры не совпадают для ключа {k}!") return size_warning # --- Основная логика Gradio-приложения --- def process_model(model_id: str, revision: str, progress=gr.Progress(track_tqdm=True)): """ Скачивает, конвертирует и возвращает пути к файлам `.safetensors`. """ if not model_id: return None, "Ошибка: ID модели не может быть пустым." # 1. Определяем устройство (GPU или CPU) device = "cuda" if torch.cuda.is_available() else "cpu" log_messages = [f"✅ Обнаружено устройство: {device.upper()}"] try: api = HfApi() info = api.model_info(repo_id=model_id, revision=revision) filenames = [s.rfilename for s in info.siblings] except Exception as e: return None, f"❌ Ошибка: Не удалось получить информацию о модели `{model_id}`.\n{e}" # Ищем файлы для конвертации files_to_convert = [f for f in filenames if f.endswith(".bin") or f.endswith(".ckpt")] if not files_to_convert: return None, f"ℹ️ В модели `{model_id}` не найдено файлов `.bin` или `.ckpt` для конвертации." log_messages.append(f"🔍 Найдено {len(files_to_convert)} файлов для конвертации: {', '.join(files_to_convert)}") # Используем временную директорию для чистоты with TemporaryDirectory() as temp_dir: converted_files = [] for filename in progress.tqdm(files_to_convert, desc="Конвертация файлов"): try: # Скачиваем файл log_messages.append(f"\n🚀 Скачивание `{filename}`...") pt_path = hf_hub_download( repo_id=model_id, filename=filename, revision=revision, cache_dir=os.path.join(temp_dir, "downloads"), ) # Конвертируем log_messages.append(f"🛠️ Конвертация `{filename}`...") sf_filename = os.path.splitext(filename)[0] + ".safetensors" sf_path = os.path.join(temp_dir, "converted", sf_filename) size_warning = convert_file(pt_path, sf_path, device) if size_warning: log_messages.append(f"⚠️ {size_warning}") converted_files.append(sf_path) log_messages.append(f"✅ Успешно сконвертировано в `{sf_filename}`") except Exception as e: log_messages.append(f"❌ Ошибка при обработке файла `{filename}`: {e}") continue if not converted_files: return None, "\n".join(log_messages) + "\n\nНе удалось сконвертировать ни один файл." final_message = "\n".join(log_messages) + "\n\n" + "🎉 Все файлы успешно обработаны! Готово к скачиванию." return converted_files, final_message # --- Создание интерфейса Gradio --- with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown( """ # Конвертер моделей в `.safetensors` Эта утилита преобразует веса моделей PyTorch (`.bin`, `.ckpt`) из репозиториев Hugging Face в безопасный и быстрый формат `.safetensors`. **Как использовать:** 1. Введите ID модели с Hugging Face (например, `stabilityai/stable-diffusion-2-1-base`). 2. Нажмите кнопку "Конвертировать". 3. Дождитесь завершения процесса и скачайте полученные файлы. """ ) with gr.Row(): model_id = gr.Textbox(label="ID модели на Hugging Face", placeholder="например, runwayml/stable-diffusion-v1-5") revision = gr.Textbox(label="Ревизия (ветка)", value="main") convert_button = gr.Button("Конвертировать", variant="primary") gr.Markdown("### Результат") log_output = gr.Markdown(value="Ожидание запуска...") file_output = gr.File(label="Скачать сконвертированные файлы") convert_button.click( fn=process_model, inputs=[model_id, revision], outputs=[file_output, log_output], ) if __name__ == "__main__": demo.launch(debug=True)