import ast
import json
import os
import pickle
import signal
import threading
import time
import zipfile

import gdown
import numpy as np
import requests
import torch
import tqdm
from autocuda import auto_cuda, auto_cuda_name
from findfile import find_files, find_cwd_file, find_file
from functools import wraps
from termcolor import colored
from update_checker import parse_version

from anonymous_demo import __version__


def save_args(config, save_path):
    # Persist only the arguments that were actually accessed during the run.
    with open(save_path, mode='w', encoding='utf8') as f:
        for arg in config.args:
            if config.args_call_count[arg]:
                f.write('{}: {}\n'.format(arg, config.args[arg]))


def print_args(config, logger=None, mode=0):
    for arg in sorted(config.args.keys()):
        line = '{0}: {1}\t-->\tCalling Count: {2}'.format(
            arg, config.args[arg], config.args_call_count[arg]
        )
        if logger:
            logger.info(line)
        else:
            print(line)


def check_and_fix_labels(label_set: set, label_name, all_data, opt):
    # Map the sorted label set to contiguous integer indices. The ignore
    # label '-100' keeps its conventional value of -100.
    if '-100' in label_set:
        label_to_index = {
            origin_label: -100 if origin_label == '-100' else idx - 1
            for idx, origin_label in enumerate(sorted(label_set))
        }
        index_to_label = {
            (-100 if origin_label == '-100' else idx - 1): origin_label
            for idx, origin_label in enumerate(sorted(label_set))
        }
    else:
        label_to_index = {origin_label: idx for idx, origin_label in enumerate(sorted(label_set))}
        index_to_label = {idx: origin_label for idx, origin_label in enumerate(sorted(label_set))}
    if 'index_to_label' not in opt.args:
        opt.index_to_label = index_to_label
        opt.label_to_index = label_to_index
    if opt.index_to_label != index_to_label:
        opt.index_to_label.update(index_to_label)
        opt.label_to_index.update(label_to_index)
    num_label = {label: 0 for label in label_set}
    num_label['Sum'] = len(all_data)
    for item in all_data:
        try:
            num_label[item[label_name]] += 1
            item[label_name] = label_to_index[item[label_name]]
        except Exception:
            # Some datasets store the label as a `polarity` attribute rather
            # than as a mapping key.
            num_label[item.polarity] += 1
            item.polarity = label_to_index[item.polarity]
    print('Dataset Label Details: {}'.format(num_label))


def check_and_fix_IOB_labels(label_map, opt):
    # Expose the inverse IOB mapping (index -> tag) on the option object.
    opt.index_to_IOB_label = {
        int(label_map[origin_label]): origin_label for origin_label in label_map
    }


def get_device(auto_device):
    if auto_device == 'allcuda':
        device = 'cuda'
    elif isinstance(auto_device, str):
        device = auto_device
    elif isinstance(auto_device, bool):
        device = auto_cuda() if auto_device else 'cpu'
    else:
        device = auto_cuda()
    try:
        torch.device(device)
    except RuntimeError as e:
        print(colored('Device assignment error: {}, redirect to CPU'.format(e), 'red'))
        device = 'cpu'
    device_name = auto_cuda_name()
    return device, device_name


def _load_word_vec(path, word2idx=None, embed_dim=300):
    word_vec = {}
    with open(path, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
        for line in tqdm.tqdm(fin, postfix='Loading embedding file...'):
            tokens = line.rstrip().split()
            # The word itself may contain spaces, so everything except the
            # last `embed_dim` tokens belongs to the word.
            word, vec = ' '.join(tokens[:-embed_dim]), tokens[-embed_dim:]
            # Guard against the default word2idx=None, which previously
            # crashed on word2idx.keys().
            if word2idx is None or word in word2idx:
                word_vec[word] = np.asarray(vec, dtype='float32')
    return word_vec


def build_embedding_matrix(word2idx, embed_dim, dat_fname, opt):
    embed_matrix_path = 'run/{}'.format(os.path.join(opt.dataset_name, dat_fname))
    # Create the full cache directory (not just 'run/'), otherwise dumping
    # the pickle below fails for a fresh dataset.
    os.makedirs(os.path.dirname(embed_matrix_path), exist_ok=True)
    if os.path.exists(embed_matrix_path):
        print(colored(
            'Loading cached embedding_matrix from {} (Please remove all cached files if there is any problem!)'.format(embed_matrix_path),
            'green',
        ))
        with open(embed_matrix_path, 'rb') as f:
            embedding_matrix = pickle.load(f)
    else:
        # prepare_glove840_embedding is provided elsewhere in this package.
        glove_path = prepare_glove840_embedding(embed_matrix_path)
        # Two extra rows are reserved for the padding and unknown tokens.
        embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))
        word_vec = _load_word_vec(glove_path, word2idx=word2idx, embed_dim=embed_dim)
        for word, i in tqdm.tqdm(
            word2idx.items(),
            postfix=colored('Building embedding_matrix {}'.format(dat_fname), 'yellow'),
        ):
            vec = word_vec.get(word)
            if vec is not None:
                embedding_matrix[i] = vec
        with open(embed_matrix_path, 'wb') as f:
            pickle.dump(embedding_matrix, f)
    return embedding_matrix


def pad_and_truncate(sequence, maxlen, dtype='int64', padding='post', truncating='post', value=0):
    # Fixed-length vector filled with `value`, overlaid with the (possibly
    # truncated) sequence at the front ('post') or at the back ('pre').
    x = np.full(maxlen, value, dtype=dtype)
    trunc = sequence[-maxlen:] if truncating == 'pre' else sequence[:maxlen]
    trunc = np.asarray(trunc, dtype=dtype)
    if padding == 'post':
        x[:len(trunc)] = trunc
    else:
        x[-len(trunc):] = trunc
    return x


class TransformerConnectionError(ValueError):
    pass


def retry(f):
    @wraps(f)
    def decorated(*args, **kwargs):
        count = 5
        while count:
            try:
                return f(*args, **kwargs)
            except (
                TransformerConnectionError,
                # Every concrete requests exception (ConnectionError,
                # HTTPError, ConnectTimeout, ProxyError, SSLError, ...)
                # subclasses RequestException, so the base class covers them
                # all; requests.exceptions.BaseHTTPError was dropped from
                # modern requests releases.
                requests.exceptions.RequestException,
            ) as e:
                print(colored('Training Exception: {}, will retry later'.format(e), 'red'))
                time.sleep(60)
                count -= 1
        # Gives up silently (returns None) after five failed attempts.

    return decorated


def save_json(dic, save_path):
    if isinstance(dic, str):
        # Accept a dict serialized via str(dict); literal_eval is a safe
        # replacement for the bare eval() used previously.
        dic = ast.literal_eval(dic)
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(dic, ensure_ascii=False))


def load_json(save_path):
    # save_json writes the whole object on a single line.
    with open(save_path, 'r', encoding='utf-8') as f:
        return json.loads(f.read().strip())


def init_optimizer(optimizer):
    # Accept either a (lower-case) optimizer name or a torch.optim class.
    optimizers = {
        'adadelta': torch.optim.Adadelta,  # default lr=1.0
        'adagrad': torch.optim.Adagrad,  # default lr=0.01
        'adam': torch.optim.Adam,  # default lr=0.001
        'adamax': torch.optim.Adamax,  # default lr=0.002
        'asgd': torch.optim.ASGD,  # default lr=0.01
        'rmsprop': torch.optim.RMSprop,  # default lr=0.01
        'sgd': torch.optim.SGD,
        'adamw': torch.optim.AdamW,
        torch.optim.Adadelta: torch.optim.Adadelta,
        torch.optim.Adagrad: torch.optim.Adagrad,
        torch.optim.Adam: torch.optim.Adam,
        torch.optim.Adamax: torch.optim.Adamax,
        torch.optim.ASGD: torch.optim.ASGD,
        torch.optim.RMSprop: torch.optim.RMSprop,
        torch.optim.SGD: torch.optim.SGD,
        torch.optim.AdamW: torch.optim.AdamW,
    }
    if optimizer in optimizers:
        return optimizers[optimizer]
    # getattr guards against unknown plain strings, which previously raised
    # AttributeError instead of the intended KeyError.
    elif hasattr(torch.optim, getattr(optimizer, '__name__', '')):
        return optimizer
    else:
        raise KeyError(
            'Unsupported optimizer: {}. Please pass a string or an optimizer class from torch.optim.'.format(optimizer)
        )
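

# ---------------------------------------------------------------------------
# Minimal usage sketch of the helpers above (illustrative only; the function
# name `download_checkpoint` and the values below are hypothetical, not part
# of this module's API):
#
#   >>> pad_and_truncate([1, 2, 3], maxlen=5)
#   array([1, 2, 3, 0, 0])
#   >>> pad_and_truncate(list(range(10)), maxlen=5, truncating='pre')
#   array([5, 6, 7, 8, 9])
#
#   @retry
#   def download_checkpoint(url):
#       # Retried up to 5 times on requests' connection-level exceptions.
#       return requests.get(url, timeout=10)
# ---------------------------------------------------------------------------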