#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Download model files from a Hugging Face Hub repository.

Every file in ``--models_repo_id`` whose path matches ``--model_pattern``
(by default the ``sound-*-ch32.zip`` archives) is fetched into
``--trained_model_dir`` via ``huggingface_hub.snapshot_download``.
"""
import argparse
from pathlib import Path

from huggingface_hub import snapshot_download

from project_settings import environment, project_path


def get_args() -> argparse.Namespace:
    """Parse the command-line options of the download script.

    Returns:
        Namespace with ``trained_model_dir``, ``models_repo_id``,
        ``model_pattern`` and ``hf_token`` attributes (all ``str``;
        ``hf_token`` may be ``None`` when no default is configured).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--trained_model_dir",
        default=(project_path / "trained_models").as_posix(),
        type=str,
        help="local directory the matching files are downloaded into",
    )
    parser.add_argument(
        "--models_repo_id",
        default="qgyd2021/vm_sound_classification",
        type=str,
        help="Hugging Face Hub repository to download from",
    )
    parser.add_argument(
        "--model_pattern",
        default="sound-*-ch32.zip",
        type=str,
        help="glob pattern selecting which repository files to download",
    )
    parser.add_argument(
        "--hf_token",
        # Falls back to the project's "hf_token" setting when not given.
        default=environment.get("hf_token"),
        type=str,
        help="Hugging Face access token",
    )
    return parser.parse_args()


def main() -> None:
    """Download every file matching ``--model_pattern`` into ``--trained_model_dir``."""
    args = get_args()

    # Make sure the target directory exists before the download writes into it.
    trained_model_dir = Path(args.trained_model_dir)
    trained_model_dir.mkdir(parents=True, exist_ok=True)

    snapshot_download(
        repo_id=args.models_repo_id,
        allow_patterns=[args.model_pattern],
        local_dir=trained_model_dir.as_posix(),
        token=args.hf_token,
    )


if __name__ == '__main__':
    main()