voice-clone-app / src /deploy /upload_to_hub.py
hengjie yang
Initial commit: Voice Clone App with Gradio interface
9580089
raw
history blame
3.07 kB
import os
import torch
from huggingface_hub import HfApi, upload_file
from pathlib import Path
import shutil
import json
def prepare_model_for_upload(
    checkpoint_path: str,
    output_dir: str,
    model_name: str = "voice-cloning-model",
    organization: str = None
):
    """Stage the files for a Hugging Face Hub upload inside ``output_dir``.

    Three files are written into ``output_dir`` (created if it is missing):

    * ``pytorch_model.bin`` -- the ``"model_state_dict"`` entry of the checkpoint
    * ``config.json`` -- a fixed speaker-encoder configuration
    * ``README.md`` -- a copy of ``model_card.md`` found next to this module

    Args:
        checkpoint_path: Path to a checkpoint written with ``torch.save`` that
            stores the weights under the ``"model_state_dict"`` key.
        output_dir: Directory that receives the staged files.
        model_name: Not used by this function; kept so existing callers that
            pass it keep working.
        organization: Not used by this function; kept for the same reason.

    Returns:
        ``output_dir`` as a ``pathlib.Path``.

    Raises:
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
        FileNotFoundError: If the checkpoint or ``model_card.md`` is missing.
    """
    # Create the staging directory.
    output_dir = Path(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # Load the checkpoint on CPU so no GPU is required for packaging.
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source (or pass
    # weights_only=True on PyTorch versions that support it).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    try:
        state_dict = checkpoint['model_state_dict']
    except KeyError:
        # Same exception type as before, but tell the user what the file
        # actually contains instead of only echoing the missing key.
        preview = ", ".join(sorted(str(key) for key in checkpoint)[:10])
        raise KeyError(
            f"Checkpoint {checkpoint_path} has no 'model_state_dict' entry "
            f"(found keys: {preview or '<none>'})"
        ) from None

    # Save only the model weights.
    model_path = output_dir / "pytorch_model.bin"
    torch.save(state_dict, model_path)

    # Write the configuration file. These values are hard-coded here; they
    # are not read from the checkpoint.
    config = {
        "model_type": "speaker_encoder",
        "hidden_dim": 256,
        "embedding_dim": 512,
        "num_layers": 3,
        "dropout": 0.1,
        "version": "1.0.0"
    }
    # Explicit encoding keeps the output independent of the platform locale.
    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # Copy the model card; the Hub renders README.md as the repository page.
    shutil.copy(
        Path(__file__).parent / "model_card.md",
        output_dir / "README.md"
    )
    return output_dir
def upload_to_hub(
    model_dir: str,
    model_name: str,
    organization: str = None,
    token: str = None
):
    """Upload every file in ``model_dir`` to a Hugging Face Hub repository.

    The repository is created if it does not exist yet. Only regular files
    directly inside ``model_dir`` are uploaded; sub-directories are skipped.

    Args:
        model_dir: Local directory holding the files to upload.
        model_name: Repository name on the Hub.
        organization: Optional namespace; when given, the repository id is
            ``"<organization>/<model_name>"``, otherwise just ``model_name``.
        token: Access token forwarded to the Hub calls.

    Raises:
        FileNotFoundError: If ``model_dir`` is not an existing directory.
    """
    model_dir = Path(model_dir)
    # Validate locally before talking to the Hub so that a wrong path does
    # not leave an empty repository behind.
    if not model_dir.is_dir():
        raise FileNotFoundError(f"Model directory not found: {model_dir}")

    # Initialise the API client.
    api = HfApi()

    # Create the repository (no-op when it already exists).
    repo_id = f"{organization}/{model_name}" if organization else model_name
    api.create_repo(
        repo_id=repo_id,
        exist_ok=True,
        token=token
    )

    # Upload the files. glob("*") also yields directories, which upload_file
    # cannot take as a path, so keep regular files only; sorting makes the
    # upload order deterministic.
    for file_path in sorted(p for p in model_dir.glob("*") if p.is_file()):
        upload_file(
            path_or_fileobj=str(file_path),
            path_in_repo=file_path.name,
            repo_id=repo_id,
            token=token
        )
        print(f"Uploaded {file_path.name}")
def main(argv=None):
    """Command-line entry point: stage a checkpoint and upload it to the Hub.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: If no token is given via ``--token`` or the
            ``HUGGING_FACE_TOKEN`` environment variable.
    """
    import argparse
    import tempfile

    parser = argparse.ArgumentParser(description="Upload model to Hugging Face Hub")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to model checkpoint")
    parser.add_argument("--model_name", type=str, required=True,
                        help="Name for the model on HuggingFace Hub")
    parser.add_argument("--organization", type=str,
                        help="Optional organization name")
    parser.add_argument("--token", type=str,
                        help="HuggingFace token (or set via HUGGING_FACE_TOKEN env var)")
    args = parser.parse_args(argv)

    # Resolve the token first: without it the upload cannot succeed, so there
    # is no point in staging the model files.
    token = args.token or os.environ.get("HUGGING_FACE_TOKEN")
    if not token:
        raise ValueError("Please provide a HuggingFace token")

    # Stage into a fresh temporary directory. The context manager removes it
    # even when staging or uploading fails, and a unique directory means an
    # existing "tmp_model" folder in the working directory is never uploaded
    # or deleted by accident.
    with tempfile.TemporaryDirectory(prefix="tmp_model_") as output_dir:
        # Prepare the model files.
        model_dir = prepare_model_for_upload(
            args.checkpoint,
            output_dir,
            args.model_name,
            args.organization
        )
        # Upload to the Hub.
        upload_to_hub(
            model_dir,
            args.model_name,
            args.organization,
            token
        )


if __name__ == "__main__":
    main()