Spaces:
Sleeping
Sleeping
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| import copy | |
| import os | |
| from typing import Any, Dict, Union | |
| import yaml | |
# File name that PretrainedConfig.get_config_dict appends when it is given a directory.
| CONFIG_FILE = "config.yaml" | |
class PretrainedConfig(object):
    """Base class for configuration objects persisted as YAML files.

    Instances are built by passing the loaded YAML mapping to the class
    constructor as keyword arguments (see ``from_dict``), and are serialized
    from whatever attributes are present in ``self.__dict__`` (see
    ``to_dict``). The base ``__init__`` itself stores nothing.
    """

    def __init__(self, **kwargs):
        # The base class accepts and ignores every keyword argument.
        pass

    @classmethod
    def _dict_from_yaml_file(
        cls, yaml_file: Union[str, os.PathLike]
    ) -> Dict[str, Any]:
        """Load ``yaml_file`` with ``yaml.safe_load`` and return its content.

        Args:
            yaml_file: Path of a UTF-8 encoded YAML file.

        Returns:
            The parsed document, or an empty dict when the file is empty.
        """
        with open(yaml_file, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        # safe_load yields None for an empty document; normalise it so the
        # result can always be expanded with ** by from_dict.
        return {} if config_dict is None else config_dict

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike]
    ) -> Dict[str, Any]:
        """Return the configuration mapping stored at the given location.

        Args:
            pretrained_model_name_or_path: Either a directory, in which case
                the file named ``CONFIG_FILE`` inside it is read, or the path
                of the YAML file itself.

        Returns:
            The parsed configuration dictionary.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_FILE)
        else:
            config_file = pretrained_model_name_or_path
        return cls._dict_from_yaml_file(config_file)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
        """Instantiate the class from a mapping, with optional overrides.

        Args:
            config_dict: Keyword arguments for the class constructor. It is
                not modified.
            **kwargs: Override values. Only names that already exist in
                ``config_dict`` are applied; any other name is ignored.

        Returns:
            A new instance of ``cls``.
        """
        # Build a fresh dict so the caller's mapping is left untouched.
        merged = {
            key: kwargs.get(key, value) for key, value in config_dict.items()
        }
        return cls(**merged)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        **kwargs,
    ) -> "PretrainedConfig":
        """Load a configuration from disk and instantiate the class.

        Args:
            pretrained_model_name_or_path: Directory or YAML file path, as
                accepted by ``get_config_dict``.
            **kwargs: Overrides forwarded to ``from_dict``.

        Returns:
            A new instance of ``cls``.
        """
        config_dict = cls.get_config_dict(pretrained_model_name_or_path)
        return cls.from_dict(config_dict, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """Return a deep copy of the instance attributes."""
        return copy.deepcopy(self.__dict__)

    def to_yaml_file(self, yaml_file_path: Union[str, os.PathLike]) -> None:
        """Write the instance attributes to ``yaml_file_path`` as YAML.

        Args:
            yaml_file_path: Destination path; the file is overwritten.
        """
        config_dict = self.to_dict()
        with open(yaml_file_path, "w", encoding="utf-8") as writer:
            yaml.safe_dump(config_dict, writer)
if __name__ == "__main__":
    # Running the module as a script does nothing.
    pass