Spaces:
Sleeping
Sleeping
import os | |
import torch | |
from huggingface_hub import HfApi, upload_file | |
from pathlib import Path | |
import shutil | |
import json | |
def _extract_state_dict(checkpoint):
    """Return the model state dict stored in a loaded checkpoint object."""
    if isinstance(checkpoint, dict):
        # Training checkpoints store the weights under this key.
        if "model_state_dict" in checkpoint:
            return checkpoint["model_state_dict"]
        # A bare state dict, e.g. saved with torch.save(model.state_dict(), ...).
        if checkpoint and all(torch.is_tensor(v) for v in checkpoint.values()):
            return checkpoint
    raise KeyError(
        "checkpoint has no 'model_state_dict' entry and is not a state dict"
    )


def _default_model_card(model_name, organization):
    """Return a minimal README used when no model card file is available."""
    repo_id = f"{organization}/{model_name}" if organization else model_name
    return (
        f"# {repo_id}\n\n"
        "Files: `pytorch_model.bin` (model state dict) and `config.json`.\n"
    )


def prepare_model_for_upload(
    checkpoint_path: str,
    output_dir: str,
    model_name: str = "voice-cloning-model",
    organization: str = None,
    config: dict = None,
    model_card_path: str = None,
):
    """Prepare model files in ``output_dir`` for upload to the Hugging Face Hub.

    Writes into ``output_dir`` (created if missing):

    * ``pytorch_model.bin`` -- the model state dict taken from the checkpoint;
    * ``config.json`` -- the model configuration;
    * ``README.md`` -- the model card.

    Args:
        checkpoint_path: File written by ``torch.save``. Either a dict with a
            ``"model_state_dict"`` entry or a bare name -> tensor state dict.
        output_dir: Directory the files are written into.
        model_name: Repository name; used in the generated model card when no
            model card file is available.
        organization: Optional organization; used in the generated model card.
        config: Dict written to ``config.json``. Defaults to the speaker-encoder
            settings this script has always written.
        model_card_path: Model card copied to ``README.md``. When omitted,
            ``model_card.md`` next to this script is used if it exists;
            otherwise a minimal card is generated instead of failing.

    Returns:
        ``output_dir`` as a ``pathlib.Path``.

    Raises:
        KeyError: If the checkpoint holds no recognizable state dict.
        FileNotFoundError: If an explicit ``model_card_path`` does not exist.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    torch.save(_extract_state_dict(checkpoint), output_dir / "pytorch_model.bin")

    if config is None:
        config = {
            "model_type": "speaker_encoder",
            "hidden_dim": 256,
            "embedding_dim": 512,
            "num_layers": 3,
            "dropout": 0.1,
            "version": "1.0.0",
        }
    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    readme_path = output_dir / "README.md"
    if model_card_path is not None:
        # Explicitly requested card: a missing file is an error.
        shutil.copy(Path(model_card_path), readme_path)
    else:
        default_card = Path(__file__).parent / "model_card.md"
        if default_card.is_file():
            shutil.copy(default_card, readme_path)
        else:
            readme_path.write_text(
                _default_model_card(model_name, organization), encoding="utf-8"
            )
    return output_dir
def upload_to_hub(
    model_dir: str,
    model_name: str,
    organization: str = None,
    token: str = None,
):
    """Upload the files in ``model_dir`` to a Hugging Face Hub model repository.

    The repository ``organization/model_name`` (or just ``model_name`` when no
    organization is given) is created if it does not already exist. Each regular
    file directly inside ``model_dir`` is uploaded to the repo root under its
    own name; subdirectories are skipped.

    Args:
        model_dir: Directory holding the files to upload.
        model_name: Repository name on the Hub.
        organization: Optional organization (namespace) that owns the repo.
        token: Hugging Face access token passed to every Hub call.
    """
    api = HfApi()
    repo_id = f"{organization}/{model_name}" if organization else model_name
    api.create_repo(repo_id=repo_id, exist_ok=True, token=token)

    # glob("*") also yields directories, which upload_file cannot upload, so
    # keep regular files only; sort for a deterministic upload order.
    files = sorted(p for p in Path(model_dir).glob("*") if p.is_file())
    for file_path in files:
        api.upload_file(
            path_or_fileobj=str(file_path),
            path_in_repo=file_path.name,
            repo_id=repo_id,
            token=token,
        )
        print(f"Uploaded {file_path.name}")
if __name__ == "__main__":
    import argparse
    import tempfile

    parser = argparse.ArgumentParser(description="Upload model to Hugging Face Hub")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to model checkpoint")
    parser.add_argument("--model_name", type=str, required=True,
                        help="Name for the model on HuggingFace Hub")
    parser.add_argument("--organization", type=str,
                        help="Optional organization name")
    parser.add_argument("--token", type=str,
                        help="HuggingFace token (or set via HUGGING_FACE_TOKEN "
                             "or HF_TOKEN env var)")
    args = parser.parse_args()

    # Resolve the token before any work so a missing token fails fast instead
    # of after the checkpoint has already been converted.
    token = (
        args.token
        or os.environ.get("HUGGING_FACE_TOKEN")
        or os.environ.get("HF_TOKEN")
    )
    if not token:
        raise ValueError("Please provide a HuggingFace token")

    # Stage files in a fresh temporary directory. A fixed "tmp_model" path
    # could upload stale files from an earlier run, and rmtree would delete a
    # pre-existing user directory of that name. The context manager also
    # removes the staging directory when preparation or upload fails.
    with tempfile.TemporaryDirectory(prefix="tmp_model_") as output_dir:
        model_dir = prepare_model_for_upload(
            args.checkpoint,
            output_dir,
            args.model_name,
            args.organization,
        )
        upload_to_hub(
            model_dir,
            args.model_name,
            args.organization,
            token,
        )