import copy

from anonymous_demo.functional.config.config_manager import ConfigManager
from anonymous_demo.core.tad.classic.__bert__.models import TADBERT

# Default settings. Every ``get_tad_config_*`` accessor starts from these
# values and overlays the variant-specific dict on top of them.
_tad_config_template = {
    'model': TADBERT,
    'optimizer': "adamw",
    'learning_rate': 0.00002,
    'patience': 99999,
    'pretrained_bert': "microsoft/mdeberta-v3-base",
    'cache_dataset': True,
    'warmup_step': -1,
    'show_metric': False,
    'max_seq_len': 80,
    'dropout': 0,
    'l2reg': 0.000001,
    'num_epoch': 10,
    'batch_size': 16,
    'initializer': 'xavier_uniform_',
    'seed': 52,
    'polarities_dim': 3,
    'log_step': 10,
    'evaluate_begin': 0,
    'cross_validate_fold': -1,
    'use_amp': False,
    # split train and test datasets into 5 folds and repeat 3 training
}

_tad_config_base = {
    'model': TADBERT,
    'optimizer': "adamw",
    'learning_rate': 0.00002,
    'pretrained_bert': "microsoft/deberta-v3-base",
    'cache_dataset': True,
    'warmup_step': -1,
    'show_metric': False,
    'max_seq_len': 80,
    'patience': 99999,
    'dropout': 0,
    'l2reg': 0.000001,
    'num_epoch': 10,
    'batch_size': 16,
    'initializer': 'xavier_uniform_',
    'seed': 52,
    'polarities_dim': 3,
    'log_step': 10,
    'evaluate_begin': 0,
    'cross_validate_fold': -1
    # split train and test datasets into 5 folds and repeat 3 training
}

_tad_config_english = {
    'model': TADBERT,
    'optimizer': "adamw",
    'learning_rate': 0.00002,
    'patience': 99999,
    'pretrained_bert': "microsoft/deberta-v3-base",
    'cache_dataset': True,
    'warmup_step': -1,
    'show_metric': False,
    'max_seq_len': 80,
    'dropout': 0,
    'l2reg': 0.000001,
    'num_epoch': 10,
    'batch_size': 16,
    'initializer': 'xavier_uniform_',
    'seed': 52,
    'polarities_dim': 3,
    'log_step': 10,
    'evaluate_begin': 0,
    'cross_validate_fold': -1
    # split train and test datasets into 5 folds and repeat 3 training
}

_tad_config_multilingual = {
    'model': TADBERT,
    'optimizer': "adamw",
    'learning_rate': 0.00002,
    'patience': 99999,
    'pretrained_bert': "microsoft/mdeberta-v3-base",
    'cache_dataset': True,
    'warmup_step': -1,
    'show_metric': False,
    'max_seq_len': 80,
    'dropout': 0,
    'l2reg': 0.000001,
    'num_epoch': 10,
    'batch_size': 16,
    'initializer': 'xavier_uniform_',
    'seed': 52,
    'polarities_dim': 3,
    'log_step': 10,
    'evaluate_begin': 0,
    'cross_validate_fold': -1
    # split train and test datasets into 5 folds and repeat 3 training
}

_tad_config_chinese = {
    'model': TADBERT,
    'optimizer': "adamw",
    'learning_rate': 0.00002,
    'patience': 99999,
    'cache_dataset': True,
    'warmup_step': -1,
    'show_metric': False,
    'pretrained_bert': "bert-base-chinese",
    'max_seq_len': 80,
    'dropout': 0,
    'l2reg': 0.000001,
    'num_epoch': 10,
    'batch_size': 16,
    'initializer': 'xavier_uniform_',
    'seed': 52,
    'polarities_dim': 3,
    'log_step': 10,
    'evaluate_begin': 0,
    'cross_validate_fold': -1
    # split train and test datasets into 5 folds and repeat 3 training
}

# Config-type name -> module-level dict it designates. The values are the
# very same dict objects as above, so updating through this table updates
# the module-level defaults.
_TAD_CONFIGS = {
    'template': _tad_config_template,
    'base': _tad_config_base,
    'english': _tad_config_english,
    'chinese': _tad_config_chinese,
    'multilingual': _tad_config_multilingual,
}

# 'glove' is part of the public API surface, but this module defines no
# GloVe config dict (referencing one used to fail with a NameError).
_GLOVE_UNAVAILABLE_MESSAGE = (
    "The 'glove' config type is not available: "
    "no GloVe config is defined in this module"
)


def _merge_with_template(overrides):
    """Return a deep copy of the template settings overlaid with *overrides*.

    The merge happens in a fresh dict, so neither the module-level template
    nor *overrides* is modified.
    """
    merged = dict(_tad_config_template)
    merged.update(overrides)
    return copy.deepcopy(merged)


class TADConfigManager(ConfigManager):
    """Config manager exposing the predefined TAD configurations.

    The ``set_tad_config*`` static methods edit the module-level defaults;
    the ``get_tad_config_*`` static methods return a new manager built from
    an independent copy of those defaults.
    """

    def __init__(self, args, **kwargs):
        """
        Available Params:   {'model': BERT,
                            'optimizer': "adamw",
                            'learning_rate': 0.00002,
                            'pretrained_bert': "roberta-base",
                            'cache_dataset': True,
                            'warmup_step': -1,
                            'show_metric': False,
                            'max_seq_len': 80,
                            'patience': 99999,
                            'dropout': 0,
                            'l2reg': 0.000001,
                            'num_epoch': 10,
                            'batch_size': 16,
                            'initializer': 'xavier_uniform_',
                            'seed': {52, 25},
                            'embed_dim': 768,
                            'hidden_dim': 768,
                            'polarities_dim': 3,
                            'log_step': 10,
                            'evaluate_begin': 0,
                            'cross_validate_fold': -1
                            # split train and test datasets into 5 folds and repeat 3 training
                            }
        :param args: forwarded unchanged to ``ConfigManager.__init__``
        :param kwargs: forwarded unchanged to ``ConfigManager.__init__``
        """
        super().__init__(args, **kwargs)

    @staticmethod
    def set_tad_config(configType: str, newitem: dict):
        """Update the module-level defaults of one config type in place.

        :param configType: one of 'template', 'base', 'english', 'chinese',
            'multilingual'
        :param newitem: dict of settings to add or overwrite
        :raises TypeError: if *newitem* is not a dict
        :raises ValueError: if *configType* is unknown, or is 'glove'
            (no GloVe config is defined in this module)
        """
        # Validate the payload first, as the original implementation did.
        if not isinstance(newitem, dict):
            raise TypeError("Wrong type of new config item supplied, please use dict e.g.{'NewConfig': NewValue}")
        if configType == 'glove':
            raise ValueError(_GLOVE_UNAVAILABLE_MESSAGE)
        # The isinstance guard keeps unhashable arguments on the ValueError
        # path instead of failing inside the dict lookup.
        if not isinstance(configType, str) or configType not in _TAD_CONFIGS:
            # Message text kept verbatim in case callers match on it.
            raise ValueError(
                "Wrong value of config type supplied, please use one from following type: "
                "template, base, english, chinese, multilingual, glove")
        _TAD_CONFIGS[configType].update(newitem)

    @staticmethod
    def set_tad_config_template(newitem):
        """Update the 'template' defaults; see :meth:`set_tad_config`."""
        TADConfigManager.set_tad_config('template', newitem)

    @staticmethod
    def set_tad_config_base(newitem):
        """Update the 'base' defaults; see :meth:`set_tad_config`."""
        TADConfigManager.set_tad_config('base', newitem)

    @staticmethod
    def set_tad_config_english(newitem):
        """Update the 'english' defaults; see :meth:`set_tad_config`."""
        TADConfigManager.set_tad_config('english', newitem)

    @staticmethod
    def set_tad_config_chinese(newitem):
        """Update the 'chinese' defaults; see :meth:`set_tad_config`."""
        TADConfigManager.set_tad_config('chinese', newitem)

    @staticmethod
    def set_tad_config_multilingual(newitem):
        """Update the 'multilingual' defaults; see :meth:`set_tad_config`."""
        TADConfigManager.set_tad_config('multilingual', newitem)

    @staticmethod
    def set_tad_config_glove(newitem):
        """Delegate to :meth:`set_tad_config` with type 'glove'.

        :raises TypeError: if *newitem* is not a dict
        :raises ValueError: otherwise, because no GloVe config is defined
        """
        TADConfigManager.set_tad_config('glove', newitem)

    @staticmethod
    def get_tad_config_template() -> ConfigManager:
        """Return a manager holding a deep copy of the template settings."""
        return TADConfigManager(copy.deepcopy(_tad_config_template))

    @staticmethod
    def get_tad_config_base() -> ConfigManager:
        """Return a manager with the template overlaid by the 'base' settings."""
        return TADConfigManager(_merge_with_template(_tad_config_base))

    @staticmethod
    def get_tad_config_english() -> ConfigManager:
        """Return a manager with the template overlaid by the 'english' settings."""
        return TADConfigManager(_merge_with_template(_tad_config_english))

    @staticmethod
    def get_tad_config_chinese() -> ConfigManager:
        """Return a manager with the template overlaid by the 'chinese' settings."""
        return TADConfigManager(_merge_with_template(_tad_config_chinese))

    @staticmethod
    def get_tad_config_multilingual() -> ConfigManager:
        """Return a manager with the template overlaid by the 'multilingual' settings."""
        return TADConfigManager(_merge_with_template(_tad_config_multilingual))

    @staticmethod
    def get_tad_config_glove() -> ConfigManager:
        """Kept for interface compatibility; always raises.

        :raises ValueError: because no GloVe config is defined in this module
        """
        raise ValueError(_GLOVE_UNAVAILABLE_MESSAGE)