Spaces:
Runtime error
Runtime error
| from typing import Callable | |
| import importlib | |
| import yaml | |
| from argparse import ArgumentParser | |
| import os | |
| ROOT_DIR = os.path.basename(os.path.dirname(__file__)) | |
| def get_training_fn(id: str) -> Callable: | |
| module_name, fn_name = id.rsplit(".", 1) | |
| module = importlib.import_module("models." + module_name, ROOT_DIR) | |
| return getattr(module, fn_name) | |
| def get_config(filepath: str) -> dict: | |
| with open(filepath, "r") as f: | |
| config = yaml.safe_load(f) | |
| return config | |
| if __name__ == "__main__": | |
| parser = ArgumentParser( | |
| description="Trains models on the dance dataset and saves weights." | |
| ) | |
| parser.add_argument( | |
| "--config", | |
| help="Path to the yaml file that defines the training configuration.", | |
| default="models/config/train_local.yaml", | |
| ) | |
| args = parser.parse_args() | |
| config = get_config(args.config) | |
| training_fn_path = config["training_fn"] | |
| print(f"Config: {args.config}\nTrainer Id: {training_fn_path}") | |
| train = get_training_fn(training_fn_path) | |
| train(config) | |