import ast
import json
import os
import pickle
import signal
import threading
import time
import zipfile
from functools import wraps

import gdown
import numpy as np
import requests
import torch
import tqdm
from autocuda import auto_cuda, auto_cuda_name
from findfile import find_files, find_cwd_file, find_file
from termcolor import colored
from update_checker import parse_version

from anonymous_demo import __version__


def save_args(config, save_path):
    # Persist only the arguments that were actually accessed at least once.
    with open(save_path, mode='w', encoding='utf8') as f:
        for arg in config.args:
            if config.args_call_count[arg]:
                f.write('{}: {}\n'.format(arg, config.args[arg]))


def print_args(config, logger=None, mode=0):
    for arg in sorted(config.args.keys()):
        line = '{0}:{1}\t-->\tCalling Count:{2}'.format(arg, config.args[arg], config.args_call_count[arg])
        if logger:
            logger.info(line)
        else:
            print(line)


def check_and_fix_labels(label_set: set, label_name, all_data, opt):
    # '-100' is the conventional ignore-index: keep it mapped to -100 and shift
    # the remaining (sorted) labels down so that real labels start from 0.
    if '-100' in label_set:
        label_to_index = {origin_label: idx - 1 if origin_label != '-100' else -100
                          for idx, origin_label in enumerate(sorted(label_set))}
        index_to_label = {idx - 1 if origin_label != '-100' else -100: origin_label
                          for idx, origin_label in enumerate(sorted(label_set))}
    else:
        label_to_index = {origin_label: idx for idx, origin_label in enumerate(sorted(label_set))}
        index_to_label = {idx: origin_label for idx, origin_label in enumerate(sorted(label_set))}
    if 'index_to_label' not in opt.args:
        opt.index_to_label = index_to_label
        opt.label_to_index = label_to_index

    if opt.index_to_label != index_to_label:
        # Merge mappings built from different data splits instead of overwriting.
        opt.index_to_label.update(index_to_label)
        opt.label_to_index.update(label_to_index)
    num_label = {l: 0 for l in label_set}
    num_label['Sum'] = len(all_data)
    for item in all_data:
        try:
            num_label[item[label_name]] += 1
            item[label_name] = label_to_index[item[label_name]]
        except Exception:
            # Some samples store the label as an attribute (e.g., item.polarity)
            # rather than as a mapping key; fall back to attribute access.
            num_label[item.polarity] += 1
            item.polarity = label_to_index[item.polarity]
    print('Dataset Label Details: {}'.format(num_label))
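
# Illustrative mapping (hypothetical labels): with label_set = {'Negative', 'Neutral', 'Positive'},
# label_to_index becomes {'Negative': 0, 'Neutral': 1, 'Positive': 2}; if '-100' is also present,
# it stays mapped to the ignore-index -100 and the remaining labels still start from 0.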


def check_and_fix_IOB_labels(label_map, opt):
    # Invert the IOB tag map so predicted indices can be decoded back to tag names.
    index_to_IOB_label = {int(label_map[origin_label]): origin_label for origin_label in label_map}
    opt.index_to_IOB_label = index_to_IOB_label


def get_device(auto_device):
    # auto_device may be a device string (e.g., 'cuda:0'), the special string
    # 'allcuda' (use all GPUs), or a bool (auto-select CUDA vs. force CPU).
    if isinstance(auto_device, str) and auto_device == 'allcuda':
        device = 'cuda'
    elif isinstance(auto_device, str):
        device = auto_device
    elif isinstance(auto_device, bool):
        device = auto_cuda() if auto_device else 'cpu'
    else:
        device = auto_cuda()
    try:
        # Validate the device string; fall back to CPU if it is not usable.
        torch.device(device)
    except RuntimeError as e:
        print(colored('Device assignment error: {}, redirect to CPU'.format(e), 'red'))
        device = 'cpu'
    device_name = auto_cuda_name()
    return device, device_name
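
# Usage sketch:
#   device, device_name = get_device(True)       # auto-select a CUDA device if available
#   device, device_name = get_device('cuda:0')   # pin to a specific device
#   device, device_name = get_device(False)      # force CPU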


def _load_word_vec(path, word2idx=None, embed_dim=300):
    word_vec = {}
    with open(path, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
        for line in tqdm.tqdm(fin.readlines(), postfix='Loading embedding file...'):
            tokens = line.rstrip().split()
            # The last embed_dim tokens are the vector; anything before them is the
            # (possibly multi-token) word.
            word, vec = ' '.join(tokens[:-embed_dim]), tokens[-embed_dim:]
            if word2idx is None or word in word2idx:
                word_vec[word] = np.asarray(vec, dtype='float32')
    return word_vec


def build_embedding_matrix(word2idx, embed_dim, dat_fname, opt):
    embed_matrix_path = 'run/{}'.format(os.path.join(opt.dataset_name, dat_fname))
    # Ensure the cache directory (run/<dataset_name>) exists before reading or writing.
    os.makedirs(os.path.dirname(embed_matrix_path), exist_ok=True)
    if os.path.exists(embed_matrix_path):
        print(colored('Loading cached embedding_matrix from {} (Please remove all cached files if there is any problem!)'.format(embed_matrix_path), 'green'))
        with open(embed_matrix_path, 'rb') as f:
            embedding_matrix = pickle.load(f)
    else:
        glove_path = prepare_glove840_embedding(embed_matrix_path)
        # Two extra rows, conventionally reserved for padding and out-of-vocabulary tokens.
        embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))

        word_vec = _load_word_vec(glove_path, word2idx=word2idx, embed_dim=embed_dim)

        for word, i in tqdm.tqdm(word2idx.items(), postfix=colored('Building embedding_matrix {}'.format(dat_fname), 'yellow')):
            vec = word_vec.get(word)
            if vec is not None:
                # Words missing from the embedding file keep their all-zero rows.
                embedding_matrix[i] = vec
        with open(embed_matrix_path, 'wb') as f:
            pickle.dump(embedding_matrix, f)
    return embedding_matrix


def pad_and_truncate(sequence, maxlen, dtype='int64', padding='post', truncating='post', value=0):
    # Start from a constant array of the pad value, then copy the (possibly
    # truncated) sequence into the front ('post') or the back ('pre').
    x = (np.ones(maxlen) * value).astype(dtype)
    if truncating == 'pre':
        trunc = sequence[-maxlen:]
    else:
        trunc = sequence[:maxlen]
    trunc = np.asarray(trunc, dtype=dtype)
    if padding == 'post':
        x[:len(trunc)] = trunc
    else:
        x[-len(trunc):] = trunc
    return x
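
# Worked examples (defaults: dtype='int64', padding='post', truncating='post'):
#   pad_and_truncate([3, 1, 4, 1, 5], maxlen=3)        -> array([3, 1, 4])
#   pad_and_truncate([3, 1], maxlen=4)                 -> array([3, 1, 0, 0])
#   pad_and_truncate([3, 1], maxlen=4, padding='pre')  -> array([0, 0, 3, 1])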


class TransformerConnectionError(ValueError):
    def __init__(self):
        pass


def retry(f):
    # Retry wrapper for flaky network operations: up to 5 attempts, 60s apart.
    @wraps(f)
    def decorated(*args, **kwargs):
        count = 5
        while count:
            try:
                return f(*args, **kwargs)
            except (
                TransformerConnectionError,
                requests.exceptions.RequestException,
                requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.ConnectTimeout,
                requests.exceptions.ProxyError,
                requests.exceptions.SSLError,
                requests.exceptions.BaseHTTPError,
            ) as e:
                print(colored('Training Exception: {}, will retry later'.format(e), 'red'))
                time.sleep(60)
                count -= 1
        # All attempts failed; implicitly returns None.

    return decorated
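
# Usage sketch (hypothetical function): the decorator retries up to 5 times,
# sleeping 60s between attempts, and returns None if every attempt fails.
# @retry
# def fetch_checkpoint(url):
#     return requests.get(url, timeout=60)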


def save_json(dic, save_path):
    if isinstance(dic, str):
        # Parse a stringified dict safely instead of eval-ing arbitrary code.
        dic = ast.literal_eval(dic)
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(dic, ensure_ascii=False))


def load_json(save_path):
    with open(save_path, 'r', encoding='utf-8') as f:
        data = f.read().strip()
    return json.loads(data)
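
# Round-trip sketch (hypothetical path):
#   save_json({'accuracy': 0.91}, 'metrics.json')
#   assert load_json('metrics.json') == {'accuracy': 0.91}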


def init_optimizer(optimizer):
    # Accept either a lowercase name string (e.g., 'adam') or a torch.optim class.
    optimizers = {
        'adadelta': torch.optim.Adadelta,
        'adagrad': torch.optim.Adagrad,
        'adam': torch.optim.Adam,
        'adamax': torch.optim.Adamax,
        'asgd': torch.optim.ASGD,
        'rmsprop': torch.optim.RMSprop,
        'sgd': torch.optim.SGD,
        'adamw': torch.optim.AdamW,
        torch.optim.Adadelta: torch.optim.Adadelta,
        torch.optim.Adagrad: torch.optim.Adagrad,
        torch.optim.Adam: torch.optim.Adam,
        torch.optim.Adamax: torch.optim.Adamax,
        torch.optim.ASGD: torch.optim.ASGD,
        torch.optim.RMSprop: torch.optim.RMSprop,
        torch.optim.SGD: torch.optim.SGD,
        torch.optim.AdamW: torch.optim.AdamW,
    }
    if optimizer in optimizers:
        return optimizers[optimizer]
    elif hasattr(optimizer, '__name__') and hasattr(torch.optim, optimizer.__name__):
        # Any other optimizer class that lives in torch.optim is accepted as-is.
        return optimizer
    else:
        raise KeyError('Unsupported optimizer: {}. Please use string or the optimizer objects in torch.optim as your optimizer'.format(optimizer))
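
# Usage sketch: both spellings resolve to the same class.
#   init_optimizer('adamw') is torch.optim.AdamW             # True
#   init_optimizer(torch.optim.AdamW) is torch.optim.AdamW   # True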