""" Finetuning functions for doing transfer learning to new datasets. |
""" |
from __future__ import print_function

import math
import pickle
import uuid
from io import open
from time import sleep

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import BatchSampler, SequentialSampler

from torchmoji.global_variables import (FINETUNING_METHODS,
                                        FINETUNING_METRICS,
                                        WEIGHTS_DIR)
from torchmoji.sentence_tokenizer import SentenceTokenizer
from torchmoji.tokenizer import tokenize
# Python 2/3 compatibility shim: Python 3 has no `unicode` builtin, so the
# name lookup raises NameError there. In that case alias `unicode` to `str`;
# either way record which interpreter flavour is running.
try:
    unicode
except NameError:
    unicode = str
    IS_PYTHON2 = False
else:
    IS_PYTHON2 = True
def load_benchmark(path, vocab, extend_with=0):
    """ Loads a pickled benchmark dataset and prepares it for finetuning.

    The texts are tokenized with the provided vocabulary (which is extended
    with words from the training split when extend_with > 0) and divided into
    training, validation and testing parts (in that order). A suggested
    batch size and input length are derived from the raw texts.

    # Arguments:
        path: Path to the pickled dataset. The pickle must provide the keys
            'texts', 'info' (entries with a 'label' field), 'train_ind',
            'val_ind' and 'test_ind'.
        vocab: Vocabulary to be used for tokenizing texts.
        extend_with: If > 0, the vocabulary will be extended with up to
            extend_with tokens from the training set before tokenizing.

    # Returns:
        A dictionary with the following fields:
            texts: List of three lists, containing tokenized inputs for
                training, validation and testing (in that order).
            labels: List of three lists, containing labels for training,
                validation and testing (in that order).
            added: Number of tokens added to the vocabulary.
            batch_size: Batch size.
            maxlen: Maximum length of an input.
    """
    # NOTE(review): unpickling can execute arbitrary code -- this should only
    # ever be pointed at trusted benchmark files.
    with open(path, 'rb') as dataset:
        data = pickle.load(dataset) if IS_PYTHON2 \
            else pickle.load(dataset, fix_imports=True)

    # Texts may be stored as text or as UTF-8 encoded bytes; try the plain
    # conversion first and fall back to explicit decoding.
    raw_texts = data['texts']
    try:
        texts = [unicode(raw) for raw in raw_texts]
    except UnicodeDecodeError:
        texts = [raw.decode('utf-8') for raw in raw_texts]

    labels = [entry['label'] for entry in data['info']]

    batch_size, maxlen = calculate_batchsize_maxlen(texts)
    tokenizer = SentenceTokenizer(vocab, maxlen)

    # Index lists for the three splits, in train/val/test order.
    split_indices = [data[key] for key in ('train_ind', 'val_ind', 'test_ind')]
    texts, labels, added = tokenizer.split_train_val_test(
        texts, labels, split_indices, extend_with=extend_with)

    result = {}
    result['texts'] = texts
    result['labels'] = labels
    result['added'] = added
    result['batch_size'] = batch_size
    result['maxlen'] = maxlen
    return result
def calculate_batchsize_maxlen(texts):
    """ Derives an input length cap and a suitable batch size from the texts.

    The length cap is the 80th percentile of the tokenized text lengths
    (not the absolute maximum), rounded up to the nearest multiple of ten.
    Caps of at most 100 tokens get a batch size of 250, longer ones 50.

    # Arguments:
        texts: List of inputs.

    # Returns:
        Batch size,
        max length
    """
    token_counts = [len(tokenize(text)) for text in texts]
    percentile_80 = np.percentile(token_counts, 80.0)

    # Round up to the next multiple of ten.
    maxlen = int(math.ceil(percentile_80 / 10.0)) * 10

    if maxlen <= 100:
        batch_size = 250
    else:
        batch_size = 50
    return batch_size, maxlen
def freeze_layers(model, unfrozen_types=(), unfrozen_keyword=None):
    """ Freezes all layers in the given model, except for ones that are
        explicitly specified to not be frozen.

    Only direct children of the model that own at least one parameter are
    touched. A child stays trainable when any entry of unfrozen_types occurs
    as a substring of the child's string representation, or when
    unfrozen_keyword occurs (case-insensitively) in the child's name.

    # Arguments:
        model: Model whose layers should be modified.
        unfrozen_types: Iterable of layer type names which shouldn't be
            frozen. Defaults to an empty (immutable) tuple.
        unfrozen_keyword: Name keywords of layers that shouldn't be frozen.

    # Returns:
        Model with the selected layers frozen.
    """
    # Lower-case the keyword once instead of once per child.
    keyword = None if unfrozen_keyword is None else unfrozen_keyword.lower()

    for name, module in model.named_children():
        # Skip parameter-free children (e.g. activations, dropout).
        if next(iter(module.parameters()), None) is None:
            continue
        trainable = (any(typ in str(module) for typ in unfrozen_types) or
                     (keyword is not None and keyword in name.lower()))
        change_trainable(module, trainable, verbose=False)
    return model
def change_trainable(module, trainable, verbose=False):
    """ Helper method that freezes or unfreezes a given layer.

    Sets requires_grad on every parameter of the module (as listed by
    named_parameters) to the requested value.

    # Arguments:
        module: Module to be modified.
        trainable: True to unfreeze the layer, False to freeze it.
        verbose: Verbosity flag; when set, progress is printed.
    """
    if verbose:
        print('Changing MODULE', module, 'to trainable =', trainable)

    for param_name, parameter in module.named_parameters():
        if verbose:
            print('Setting weight', param_name, 'to trainable =', trainable)
        parameter.requires_grad = trainable

    if verbose:
        print("{} {}".format('Unfroze' if trainable else 'Froze', module))
def _collect_labels_and_outputs(model, data_gen):
    """ Runs the model over a dataloader and returns (labels, outputs) as
        numpy arrays concatenated over all batches.
    """
    def as_numpy(value):
        if isinstance(value, np.ndarray):
            return value
        # `.data` unwraps a Variable / detaches from the autograd graph;
        # objects without that attribute are used as they are.
        return getattr(value, 'data', value).cpu().numpy()

    def stack(batches):
        stacked = np.concatenate([np.atleast_1d(b) for b in batches], axis=0)
        # Drop a trailing singleton column so (N, 1) outputs line up with
        # (N,) labels.
        if stacked.ndim > 1 and stacked.shape[-1] == 1:
            stacked = stacked.reshape(stacked.shape[:-1])
        return stacked

    labels, outputs = [], []
    for X, y in data_gen:
        labels.append(as_numpy(y))
        outputs.append(as_numpy(model(X)))
    return stack(labels), stack(outputs)


def find_f1_threshold(model, val_gen, test_gen, average='binary'):
    """ Choose a threshold for F1 based on the validation dataset
        (see https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4442797/
        for details on why to find another threshold than simply 0.5)

    The threshold is selected on the validation data only and then applied
    to the test data.

    # Arguments:
        model: pyTorch model
        val_gen: Validation set dataloader.
        test_gen: Testing set dataloader.
        average: Averaging mode forwarded to sklearn's f1_score.

    # Returns:
        F1 score for the given data and
        the corresponding F1 threshold
    """
    # NOTE(review): thresholds in (0, 0.5) presume the model outputs
    # probabilities; finetune() trains with BCEWithLogitsLoss, so the outputs
    # may be logits -- confirm against the model definition.
    thresholds = np.arange(0.01, 0.5, step=0.01)

    model.eval()
    y_val, y_pred_val = _collect_labels_and_outputs(model, val_gen)
    # The test predictions must come from the test loader (the validation
    # outputs were previously reused here by mistake).
    y_test, y_pred_test = _collect_labels_and_outputs(model, test_gen)

    f1_scores = [f1_score(y_val, (y_pred_val > t).astype(int), average=average)
                 for t in thresholds]

    best_t = thresholds[np.argmax(f1_scores)]
    y_pred_ind = (y_pred_test > best_t).astype(int)
    f1_test = f1_score(y_test, y_pred_ind, average=average)
    return f1_test, best_t
def finetune(model, texts, labels, nb_classes, batch_size, method,
             metric='acc', epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
             verbose=1):
    """ Sets up data loaders, loss and optimizer for the requested finetuning
        method, runs the finetuning and evaluates the result.

    # Arguments:
        model: Model to be finetuned
        texts: List of three lists, containing tokenized inputs for training,
            validation and testing (in that order).
        labels: List of three lists, containing labels for training,
            validation and testing (in that order).
        nb_classes: Number of classes in the dataset. Two or fewer classes
            use BCEWithLogitsLoss, more use CrossEntropyLoss.
        batch_size: Batch size.
        method: Finetuning method to be used. For available methods, see
            FINETUNING_METHODS in global_variables.py.
        metric: Evaluation metric to be used. For available metrics, see
            FINETUNING_METRICS in global_variables.py.
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
        embed_l2: L2 regularization for the embedding layer.
        verbose: Verbosity flag.

    # Returns:
        Model after finetuning,
        score after finetuning using the provided metric.

    # Raises:
        ValueError: If method or metric is not one of the available options.
    """
    if method not in FINETUNING_METHODS:
        raise ValueError('ERROR (finetune): Invalid method parameter. '
                         'Available options: {}'.format(FINETUNING_METHODS))
    if metric not in FINETUNING_METRICS:
        raise ValueError('ERROR (finetune): Invalid metric parameter. '
                         'Available options: {}'.format(FINETUNING_METRICS))

    # Only the training loader uses the extended (resampling) batch sampler;
    # validation and testing are read sequentially.
    train_gen = get_data_loader(texts[0], labels[0], batch_size,
                                extended_batch_sampler=True, epoch_size=epoch_size)
    val_gen, test_gen = (get_data_loader(texts[split], labels[split], batch_size,
                                         extended_batch_sampler=False)
                         for split in (1, 2))

    # A unique checkpoint file per run.
    run_id = str(uuid.uuid4())
    checkpoint_path = '{}/torchmoji-checkpoint-{}.bin'.format(WEIGHTS_DIR, run_id)

    if method in ('last', 'new'):
        lr = 0.001
    elif method in ('full', 'chain-thaw'):
        lr = 0.0001

    if nb_classes <= 2:
        loss_op = nn.BCEWithLogitsLoss()
    else:
        loss_op = nn.CrossEntropyLoss()

    if method == 'last':
        # Freeze everything but the output layer and optimize what is left.
        model = freeze_layers(model, unfrozen_keyword='output_layer')
        adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    elif method in ('full', 'new'):
        # Three parameter groups: embedding (with weight decay), output layer
        # (with its own learning rate) and everything else.
        embed_ids = set(id(p) for p in model.embed.parameters())
        output_ids = set(id(p) for p in model.output_layer.parameters())
        base_params, embed_params, output_layer_params = [], [], []
        for param in model.parameters():
            if not param.requires_grad:
                continue
            in_embed = id(param) in embed_ids
            in_output = id(param) in output_ids
            if in_embed:
                embed_params.append(param)
            if in_output:
                output_layer_params.append(param)
            if not in_embed and not in_output:
                base_params.append(param)
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_params, 'weight_decay': embed_l2},
            {'params': output_layer_params, 'lr': 0.001},
        ], lr=lr)

    if verbose:
        for label, value in (('Method', method), ('Metric', metric), ('Classes', nb_classes)):
            print('{}: {}'.format(label, value))

    # Chain-thaw builds its own optimizers, one per thawing step.
    if method == 'chain-thaw':
        result = chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
                            embed_l2=embed_l2, evaluate=metric, verbose=verbose)
    else:
        result = tune_trainable(model, loss_op, adam, train_gen, val_gen, test_gen, nb_epochs,
                                checkpoint_path, evaluate=metric, verbose=verbose)
    return model, result
def tune_trainable(model, loss_op, optim_op, train_gen, val_gen, test_gen,
                   nb_epochs, checkpoint_path, patience=5, evaluate='acc',
                   verbose=2):
    """ Finetunes the currently trainable weights of the model and evaluates
        the result on the test set.

    # Arguments:
        model: Model to be finetuned.
        loss_op: Loss operation used during training.
        optim_op: Optimizer used during training.
        train_gen: Training data iterator (DataLoader)
        val_gen: Validation data iterator (DataLoader)
        test_gen: Testing data iterator (DataLoader)
        nb_epochs: Number of epochs.
        checkpoint_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        patience: Patience for callback methods.
        evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
        verbose: Verbosity flag.

    # Returns:
        Score of the trained model for the chosen evaluation method, or
        None if 'evaluate' is neither 'acc' nor 'weighted_f1'.
    """
    def score():
        if evaluate == 'acc':
            return evaluate_using_acc(model, test_gen)
        if evaluate == 'weighted_f1':
            return evaluate_using_weighted_f1(model, test_gen, val_gen)
        return None

    if verbose:
        trainable_names = [n for n, p in model.named_parameters() if p.requires_grad]
        print("Trainable weights: {}".format(trainable_names))
        print("Training...")
        if evaluate in ('acc', 'weighted_f1'):
            print("Evaluation on test set prior training:", score())

    fit_model(model, loss_op, optim_op, train_gen, val_gen, nb_epochs, checkpoint_path, patience)

    # Pause briefly, then restore the checkpointed weights written by
    # fit_model before evaluating.
    sleep(1)
    model.load_state_dict(torch.load(checkpoint_path))
    if verbose >= 2:
        print("Loaded weights from {}".format(checkpoint_path))

    return score()
def evaluate_using_weighted_f1(model, test_gen, val_gen):
    """ Evaluation function using weighted F1 score.

    The decision threshold is chosen on the validation data and the score is
    reported on the testing data.

    # Arguments:
        model: Model to be evaluated.
        test_gen: Testing data iterator (DataLoader)
        val_gen: Validation data iterator (DataLoader)

    # Returns:
        Weighted F1 score of the given model.
    """
    # Keyword arguments keep the validation/test loaders from being swapped:
    # find_f1_threshold takes the validation loader first. 'weighted' is the
    # averaging mode name understood by sklearn's f1_score.
    f1_test, _ = find_f1_threshold(model, val_gen=val_gen, test_gen=test_gen,
                                   average='weighted')
    return f1_test
def evaluate_using_acc(model, test_gen):
    """ Evaluation function using accuracy.

    A prediction counts as class 1 when the model output is >= 0 and as
    class 0 otherwise. The data is traversed exactly once, so one-shot
    iterators are supported as well as DataLoaders and lists.

    # Arguments:
        model: Model to be evaluated.
        test_gen: Testing data iterator (DataLoader)

    # Returns:
        Accuracy of the given model.

    # Raises:
        ZeroDivisionError: If test_gen yields no samples.
    """
    model.eval()

    correct_count = 0.0
    total_y = 0
    for x, y in test_gen:
        total_y += len(y)
        outs = model(x)
        pred = (outs >= 0).long()
        # Outputs shaped (batch, 1) compared against labels shaped (batch,)
        # would broadcast to (batch, batch) and overcount matches; when both
        # hold the same number of elements compare them element by element.
        if pred.size() != y.size() and pred.numel() == y.numel():
            pred = pred.contiguous().view(-1)
            y = y.contiguous().view(-1)
        correct_count += (pred == y).double().sum()
    return correct_count / total_y
def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
               patience=5, initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, evaluate='acc', verbose=1):
    """ Finetunes given model using chain-thaw and evaluates it on the test set.

    # Arguments:
        model: Model to be finetuned.
        train_gen: Training data iterator (DataLoader)
        val_gen: Validation data iterator (DataLoader)
        test_gen: Testing data iterator (DataLoader)
        nb_epochs: Number of epochs.
        checkpoint_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        loss_op: Loss function to be used during training.
        patience: Patience for callback methods.
        initial_lr: Initial learning rate. Will only be used for the first
            training step (i.e. the output_layer layer)
        next_lr: Learning rate for every subsequent step.
        embed_l2: L2 regularization for the embedding layer.
        evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
        verbose: Verbosity flag.

    # Returns:
        Score of the finetuned model for the chosen evaluation method, or
        None if 'evaluate' is neither 'acc' nor 'weighted_f1'.
    """
    if verbose:
        print('Training..')

    train_by_chain_thaw(model=model, train_gen=train_gen, val_gen=val_gen,
                        loss_op=loss_op, patience=patience, nb_epochs=nb_epochs,
                        checkpoint_path=checkpoint_path, initial_lr=initial_lr,
                        next_lr=next_lr, embed_l2=embed_l2, verbose=verbose)

    if evaluate == 'acc':
        return evaluate_using_acc(model, test_gen)
    if evaluate == 'weighted_f1':
        return evaluate_using_weighted_f1(model, test_gen, val_gen)
    return None
def train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path,
                        initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, verbose=1):
    """ Finetunes model using the chain-thaw method.

    This is done as follows:
    1) Freeze every layer except the last (output_layer) layer and train it.
    2) Freeze every layer except the first layer and train it.
    3) Freeze every layer except the second etc., until the second last layer.
    4) Unfreeze all layers and train entire model.

    # Arguments:
        model: Model to be trained.
        train_gen: Training sample generator.
        val_gen: Validation sample generator.
        loss_op: Loss function to be used.
        patience: Patience for early stopping.
        nb_epochs: Number of epochs.
        checkpoint_path: Where weight checkpoints should be saved.
        initial_lr: Initial learning rate. Will only be used for the first
            training step (i.e. the output_layer layer)
        next_lr: Learning rate for every subsequent step.
        embed_l2: L2 regularization for the embedding layer.
        verbose: Verbosity flag.
    """
    # Children that actually own parameters, in definition order.
    parametrized = [child for child in model.children() if list(child.parameters())]

    # Training schedule: the last layer first, then the remaining layers in
    # order, and finally None, which stands for "train all layers".
    last_layer = parametrized.pop()
    schedule = [last_layer] + parametrized + [None]

    for step, layer in enumerate(schedule):
        # Only the very first step uses the initial learning rate.
        lr = initial_lr if step == 0 else next_lr

        # Thaw the current layer (or everything on the final step) and
        # freeze all others.
        for candidate in schedule:
            if candidate is None:
                continue
            thaw = candidate == layer or layer is None
            change_trainable(candidate, trainable=thaw, verbose=False)

        # Sanity check: exactly the selected layer is trainable.
        if layer is not None:
            for child in model.children():
                assert all(p.requires_grad == (child == layer) for p in child.parameters())

        if verbose:
            print('Finetuning all layers' if layer is None else 'Finetuning {}'.format(layer))

        # The embedding parameters get their own group with weight decay.
        embed_ids = set(id(p) for p in model.embed.parameters())
        base_params, embed_parameters = [], []
        for param in model.parameters():
            if not param.requires_grad:
                continue
            if id(param) in embed_ids:
                embed_parameters.append(param)
            else:
                base_params.append(param)
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_parameters, 'weight_decay': embed_l2},
        ], lr=lr)

        fit_model(model, loss_op, adam, train_gen, val_gen, nb_epochs,
                  checkpoint_path, patience)

        # Pause briefly, then restore the checkpointed weights written by
        # fit_model before moving on to the next layer.
        sleep(1)
        model.load_state_dict(torch.load(checkpoint_path))
        if verbose >= 2:
            print("Loaded weights from {}".format(checkpoint_path))
def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
              checkpoint_path, patience):
    """ Analog to Keras fit_generator function, with checkpointing of the
        best weights and early stopping on the validation loss.

    # Arguments:
        model: Model to be finetuned.
        loss_op: loss operation (BCEWithLogitsLoss or CrossEntropy for e.g.)
        optim_op: optimization operation (Adam e.g.)
        train_gen: Training data iterator (DataLoader)
        val_gen: Validation data iterator (DataLoader)
        epochs: Number of epochs.
        checkpoint_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        patience: Number of consecutive epochs without validation loss
            improvement after which training stops.

    # Returns:
        None. The best weights found are stored at checkpoint_path.
    """
    def scalar(loss_value):
        # Works for both 1-element and 0-dim loss tensors; indexing the raw
        # numpy array with [0] fails on the 0-dim case.
        return loss_value.data.cpu().numpy().reshape(-1)[0]

    def validation_loss():
        return np.mean([scalar(loss_op(model(Variable(xv)).squeeze(),
                                       Variable(yv.float()).squeeze()))
                        for xv, yv in val_gen])

    # Save the starting weights so a checkpoint exists even if no epoch
    # improves on the initial validation loss.
    torch.save(model.state_dict(), checkpoint_path)

    model.eval()
    best_loss = validation_loss()
    print("original val loss", best_loss)

    epoch_without_impr = 0
    for epoch in range(epochs):
        for i, (X_train, y_train) in enumerate(train_gen):
            X_train = Variable(X_train, requires_grad=False)
            y_train = Variable(y_train, requires_grad=False)
            # evaluate_using_acc below switches to eval mode, so training
            # mode has to be re-enabled on every step.
            model.train()
            optim_op.zero_grad()
            output = model(X_train)
            # NOTE(review): unlike the validation loss, output and target are
            # not squeezed here -- confirm their shapes agree for loss_op.
            loss = loss_op(output, y_train.float())
            loss.backward()
            clip_grad_norm(model.parameters(), 1)
            optim_op.step()

            acc = evaluate_using_acc(model, [(X_train.data, y_train.data)])
            print("== Epoch", epoch, "step", i, "train loss", scalar(loss), "train acc", acc)

        model.eval()
        acc = evaluate_using_acc(model, val_gen)
        print("val acc", acc)

        val_loss = validation_loss()
        print("val loss", val_loss)
        if val_loss < best_loss:
            best_loss = val_loss
            # Patience counts consecutive epochs without improvement, so an
            # improvement starts the count over.
            epoch_without_impr = 0
            torch.save(model.state_dict(), checkpoint_path)
            print('Saving model at', checkpoint_path)
        elif val_loss >= best_loss:
            epoch_without_impr += 1
            print('No improvement over previous best loss: ', best_loss)
        if epoch_without_impr >= patience:
            break
def get_data_loader(X_in, y_in, batch_size, extended_batch_sampler=True, epoch_size=25000, upsample=False, seed=42):
    """ Builds a DataLoader over the given inputs and labels.

    With extended_batch_sampler set, batches come from DeepMojiBatchSampler,
    which enables larger epochs on small datasets and has upsampling
    functionality. Otherwise the data is read sequentially in batches,
    keeping the last incomplete batch.

    # Arguments:
        X_in: Inputs of the given dataset.
        y_in: Outputs of the given dataset.
        batch_size: Batch size.
        extended_batch_sampler: Whether to use DeepMojiBatchSampler instead
            of a plain sequential batch sampler.
        epoch_size: Number of samples in an epoch. Only used by the extended
            batch sampler.
        upsample: Whether upsampling should be done. This flag should only be
            set on binary class problems. Only used by the extended batch
            sampler.
        seed: Random number generator seed. Only used by the extended batch
            sampler.

    # Returns:
        DataLoader.
    """
    data = DeepMojiDataset(X_in, y_in)

    if not extended_batch_sampler:
        sampler = BatchSampler(SequentialSampler(y_in), batch_size, drop_last=False)
    else:
        sampler = DeepMojiBatchSampler(y_in, batch_size, epoch_size=epoch_size,
                                       upsample=upsample, seed=seed)

    return DataLoader(data, batch_sampler=sampler, num_workers=0)
class DeepMojiDataset(Dataset):
    """ A simple Dataset class that serves (input, label) pairs.

    Anything that is not already a torch.LongTensor is converted from a
    numpy array via an int64 cast. Both tensors are then split along the
    first dimension into one chunk per sample.

    # Arguments:
        X_in: Inputs of the given dataset.
        y_in: Outputs of the given dataset.

    # __getitem__ output:
        (torch.LongTensor, torch.LongTensor)
    """
    def __init__(self, X_in, y_in):
        converted = []
        for source in (X_in, y_in):
            if not isinstance(source, torch.LongTensor):
                source = torch.from_numpy(source.astype('int64')).long()
            converted.append(source)
        inputs, targets = converted

        # One chunk of size 1 per sample.
        self.X_in = torch.split(inputs, 1, dim=0)
        self.y_in = torch.split(targets, 1, dim=0)

    def __len__(self):
        return len(self.X_in)

    def __getitem__(self, idx):
        # squeeze() drops the size-1 sample dimension left over by split().
        sample_x = self.X_in[idx]
        sample_y = self.y_in[idx]
        return sample_x.squeeze(), sample_y.squeeze()
class DeepMojiBatchSampler(object):
    """A Batch sampler that enables larger epochs on small datasets and
        has upsampling functionality.

    All epoch_size sample indices are drawn once, with replacement, at
    construction time; iterating yields them in consecutive batches. The
    last batch is smaller when batch_size does not divide epoch_size, so the
    number of yielded batches always equals len(sampler).

    Note that construction seeds numpy's global random number generator.

    # Arguments:
        y_in: Labels of the dataset (torch tensor or array-like).
        batch_size: Batch size.
        epoch_size: Number of samples in an epoch.
        upsample: Whether upsampling should be done. This flag should only be
            set on binary class problems.
        seed: Random number generator seed.

    # __iter__ output:
        iterator of lists (batches) of indices in the dataset
    """

    def __init__(self, y_in, batch_size, epoch_size, upsample, seed):
        self.batch_size = batch_size
        self.epoch_size = epoch_size
        self.upsample = upsample

        np.random.seed(seed)

        if upsample:
            # Accept torch tensors as well as plain array-likes.
            labels = y_in.numpy() if hasattr(y_in, 'numpy') else np.asarray(y_in)

            # Should only be used on binary class problems
            assert len(labels.shape) == 1
            neg = np.where(labels == 0)[0]
            pos = np.where(labels == 1)[0]
            assert epoch_size % 2 == 0
            samples_pr_class = epoch_size // 2

            # Draw the same number of samples from each class, then shuffle.
            sample_neg = np.random.choice(neg, samples_pr_class, replace=True)
            sample_pos = np.random.choice(pos, samples_pr_class, replace=True)
            concat_ind = np.concatenate((sample_neg, sample_pos), axis=0)

            p = np.random.permutation(len(concat_ind))
            self.sample_ind = concat_ind[p]

            # Sanity check: the sampled labels are (close to) balanced.
            label_dist = np.mean(labels[self.sample_ind])
            assert(label_dist > 0.45)
            assert(label_dist < 0.55)
        else:
            ind = range(len(y_in))
            self.sample_ind = np.random.choice(ind, epoch_size, replace=True)

    def __iter__(self):
        # Iterate over exactly len(self) batches so that the final, possibly
        # shorter, batch is yielded and __iter__ agrees with __len__.
        for i in range(len(self)):
            start = i * self.batch_size
            end = min(start + self.batch_size, self.epoch_size)
            yield self.sample_ind[start:end]

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return (self.epoch_size + self.batch_size - 1) // self.batch_size