"""Package a speaker-encoder checkpoint and publish it to the Hugging Face Hub.

Usage:
    python <this script> --checkpoint PATH --model_name NAME \
        [--organization ORG] [--token TOKEN]

The token may also be supplied through the HUGGING_FACE_TOKEN environment
variable.
"""

import argparse
import json
import os
import shutil
import tempfile
from pathlib import Path
from typing import Optional, Union

import torch
from huggingface_hub import HfApi, upload_file


def prepare_model_for_upload(
    checkpoint_path: str,
    output_dir: Union[str, Path],
    model_name: str = "voice-cloning-model",
    organization: Optional[str] = None,
) -> Path:
    """Write the files that make up the Hub repository into *output_dir*.

    Three files are produced: ``pytorch_model.bin`` (the checkpoint's
    ``model_state_dict`` entry), ``config.json`` and ``README.md`` (a copy of
    ``model_card.md`` located next to this script).

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            contains a ``"model_state_dict"`` entry.
        output_dir: Directory to write into; created if it does not exist.
        model_name: Currently unused; kept for interface compatibility.
        organization: Currently unused; kept for interface compatibility.

    Returns:
        *output_dir* as a ``Path``.

    Raises:
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
        FileNotFoundError: If the checkpoint or ``model_card.md`` is missing.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # SECURITY: torch.load unpickles the file and can execute arbitrary code.
    # Only pass checkpoints from a trusted source. (Recent PyTorch versions
    # offer `weights_only=True`; not enabled here because the checkpoint
    # layout and the minimum supported PyTorch version are not known.)
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    try:
        state_dict = checkpoint["model_state_dict"]
    except KeyError:
        raise KeyError(
            f"Checkpoint {checkpoint_path!r} has no 'model_state_dict' entry"
        ) from None
    torch.save(state_dict, output_dir / "pytorch_model.bin")

    # NOTE(review): these hyper-parameters are hard-coded rather than read
    # from the checkpoint — confirm they match the model that was trained.
    config = {
        "model_type": "speaker_encoder",
        "hidden_dim": 256,
        "embedding_dim": 512,
        "num_layers": 3,
        "dropout": 0.1,
        "version": "1.0.0",
    }
    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # The Hub renders README.md as the model card.
    shutil.copy(
        Path(__file__).parent / "model_card.md",
        output_dir / "README.md",
    )

    return output_dir


def upload_to_hub(
    model_dir: Union[str, Path],
    model_name: str,
    organization: Optional[str] = None,
    token: Optional[str] = None,
) -> None:
    """Create the Hub repository (if needed) and upload the files in *model_dir*.

    Only regular files directly inside *model_dir* are uploaded; each one is
    stored at the repository root under its own file name.

    Args:
        model_dir: Directory holding the files to upload.
        model_name: Repository name on the Hub.
        organization: Optional namespace; when given the repository id is
            ``"<organization>/<model_name>"``, otherwise just ``model_name``.
        token: Hugging Face access token forwarded to the Hub client.

    Raises:
        NotADirectoryError: If *model_dir* is not an existing directory.
    """
    model_dir = Path(model_dir)
    # Check before any network call so a bad path cannot leave behind an
    # empty repository.
    if not model_dir.is_dir():
        raise NotADirectoryError(f"Model directory not found: {model_dir}")

    api = HfApi()

    repo_id = f"{organization}/{model_name}" if organization else model_name
    api.create_repo(
        repo_id=repo_id,
        exist_ok=True,
        token=token,
    )

    # upload_file handles single files only, so skip sub-directories; sort
    # for a deterministic upload order.
    for file_path in sorted(p for p in model_dir.glob("*") if p.is_file()):
        upload_file(
            path_or_fileobj=str(file_path),
            path_in_repo=file_path.name,
            repo_id=repo_id,
            token=token,
        )
        print(f"Uploaded {file_path.name}")


def _parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the upload script."""
    parser = argparse.ArgumentParser(description="Upload model to Hugging Face Hub")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to model checkpoint")
    parser.add_argument("--model_name", type=str, required=True,
                        help="Name for the model on HuggingFace Hub")
    parser.add_argument("--organization", type=str,
                        help="Optional organization name")
    parser.add_argument("--token", type=str,
                        help="HuggingFace token (or set via HUGGING_FACE_TOKEN env var)")
    return parser.parse_args()


def main() -> None:
    """Prepare the model files in a temporary directory and upload them.

    Raises:
        ValueError: If no token is given via ``--token`` or the
            ``HUGGING_FACE_TOKEN`` environment variable.
    """
    args = _parse_args()

    # Validate credentials first: no point loading and re-saving a large
    # checkpoint if the upload cannot happen.
    token = args.token or os.environ.get("HUGGING_FACE_TOKEN")
    if not token:
        raise ValueError("Please provide a HuggingFace token")

    # A fresh temporary directory guarantees that no stale files from an
    # earlier run are uploaded, and it is removed even if a step fails.
    with tempfile.TemporaryDirectory(prefix="hf_upload_") as output_dir:
        model_dir = prepare_model_for_upload(
            args.checkpoint,
            output_dir,
            args.model_name,
            args.organization,
        )
        upload_to_hub(
            model_dir,
            args.model_name,
            args.organization,
            token,
        )


if __name__ == "__main__":
    main()