torchmoji itself
- torchmoji/.gitkeep +1 -0
- torchmoji/__init__.py +0 -0
- torchmoji/attlayer.py +68 -0
- torchmoji/class_avg_finetuning.py +315 -0
- torchmoji/create_vocab.py +271 -0
- torchmoji/filter_input.py +36 -0
- torchmoji/filter_utils.py +194 -0
- torchmoji/finetuning.py +674 -0
- torchmoji/global_variables.py +28 -0
- torchmoji/lstm.py +357 -0
- torchmoji/model_def.py +311 -0
- torchmoji/sentence_tokenizer.py +245 -0
- torchmoji/tokenizer.py +162 -0
- torchmoji/word_generator.py +312 -0
torchmoji/.gitkeep
ADDED
@@ -0,0 +1 @@
torchmoji/__init__.py
ADDED
File without changes
torchmoji/attlayer.py
ADDED
@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
""" Define the Attention Layer of the model.
"""

from __future__ import print_function, division

import torch

from torch.autograd import Variable
from torch.nn import Module
from torch.nn.parameter import Parameter

class Attention(Module):
    """
    Computes a weighted average of the different channels across timesteps.
    Uses 1 parameter pr. channel to compute the attention value for a single timestep.
    """

    def __init__(self, attention_size, return_attention=False):
        """ Initialize the attention layer

        # Arguments:
            attention_size: Size of the attention vector.
            return_attention: If true, output will include the weight for each input token
                              used for the prediction

        """
        super(Attention, self).__init__()
        self.return_attention = return_attention
        self.attention_size = attention_size
        self.attention_vector = Parameter(torch.FloatTensor(attention_size))
        self.attention_vector.data.normal_(std=0.05) # Initialize attention vector

    def __repr__(self):
        s = '{name}({attention_size}, return attention={return_attention})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def forward(self, inputs, input_lengths):
        """ Forward pass.

        # Arguments:
            inputs (Torch.Variable): Tensor of input sequences
            input_lengths (torch.LongTensor): Lengths of the sequences

        # Return:
            Tuple with (representations and attentions if self.return_attention else None).
        """
        logits = inputs.matmul(self.attention_vector)
        unnorm_ai = (logits - logits.max()).exp()

        # Compute a mask for the attention on the padded sequences
        # See e.g. https://discuss.pytorch.org/t/self-attention-on-words-and-masking/5671/5
        max_len = unnorm_ai.size(1)
        idxes = torch.arange(0, max_len, out=torch.LongTensor(max_len)).unsqueeze(0)
        mask = Variable((idxes < input_lengths.unsqueeze(1)).float())

        # apply mask and renormalize attention scores (weights)
        masked_weights = unnorm_ai * mask
        att_sums = masked_weights.sum(dim=1, keepdim=True)  # sums per sequence
        attentions = masked_weights.div(att_sums)

        # apply attention weights
        weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs))

        # get the final fixed vector representations of the sentences
        representations = weighted.sum(dim=1)

        return (representations, attentions if self.return_attention else None)
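The layer above can be sanity-checked in isolation. The sketch below is illustrative only (the batch shape, the sequence lengths and the printed comments are made-up assumptions, not part of the commit); it pushes a random padded batch through the layer and inspects the outputs.

# Illustrative sanity check of the Attention layer (toy sizes, random inputs).
import torch
from torch.autograd import Variable
from torchmoji.attlayer import Attention

batch_size, timesteps, channels = 4, 10, 8             # made-up dimensions
layer = Attention(attention_size=channels, return_attention=True)

inputs = Variable(torch.randn(batch_size, timesteps, channels))
input_lengths = torch.LongTensor([10, 7, 5, 3])        # valid (non-padded) tokens per sequence

representations, attentions = layer(inputs, input_lengths)
print(representations.size())   # 4 x 8: one fixed-size vector per sequence
print(attentions.sum(dim=1))    # attention weights sum to 1 over the unmasked positions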
torchmoji/class_avg_finetuning.py
ADDED
@@ -0,0 +1,315 @@
# -*- coding: utf-8 -*-
""" Class average finetuning functions. Before using any of these finetuning
functions, ensure that the model is set up with nb_classes=2.
"""
from __future__ import print_function

import uuid
from time import sleep
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from torchmoji.global_variables import (
    FINETUNING_METHODS,
    WEIGHTS_DIR)
from torchmoji.finetuning import (
    freeze_layers,
    get_data_loader,
    fit_model,
    train_by_chain_thaw,
    find_f1_threshold)

def relabel(y, current_label_nr, nb_classes):
    """ Makes a binary classification for a specific class in a
        multi-class dataset.

    # Arguments:
        y: Outputs to be relabelled.
        current_label_nr: Current label number.
        nb_classes: Total number of classes.

    # Returns:
        Relabelled outputs of a given multi-class dataset into a binary
        classification dataset.
    """

    # Handling binary classification
    if nb_classes == 2 and len(y.shape) == 1:
        return y

    y_new = np.zeros(len(y))
    y_cut = y[:, current_label_nr]
    label_pos = np.where(y_cut == 1)[0]
    y_new[label_pos] = 1
    return y_new


def class_avg_finetune(model, texts, labels, nb_classes, batch_size,
                       method, epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
                       verbose=True):
    """ Compiles and finetunes the given model.

    # Arguments:
        model: Model to be finetuned
        texts: List of three lists, containing tokenized inputs for training,
            validation and testing (in that order).
        labels: List of three lists, containing labels for training,
            validation and testing (in that order).
        nb_classes: Number of classes in the dataset.
        batch_size: Batch size.
        method: Finetuning method to be used. For available methods, see
            FINETUNING_METHODS in global_variables.py. Note that the model
            should be defined accordingly (see docstring for torchmoji_transfer())
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
        embed_l2: L2 regularization for the embedding layer.
        verbose: Verbosity flag.

    # Returns:
        Model after finetuning,
        score after finetuning using the class average F1 metric.
    """

    if method not in FINETUNING_METHODS:
        raise ValueError('ERROR (class_avg_tune_trainable): '
                         'Invalid method parameter. '
                         'Available options: {}'.format(FINETUNING_METHODS))

    (X_train, y_train) = (texts[0], labels[0])
    (X_val, y_val) = (texts[1], labels[1])
    (X_test, y_test) = (texts[2], labels[2])

    checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \
                      .format(WEIGHTS_DIR, str(uuid.uuid4()))

    f1_init_path = '{}/torchmoji-f1-init-{}.bin' \
                   .format(WEIGHTS_DIR, str(uuid.uuid4()))

    if method in ['last', 'new']:
        lr = 0.001
    elif method in ['full', 'chain-thaw']:
        lr = 0.0001

    loss_op = nn.BCEWithLogitsLoss()

    # Freeze layers if using last
    if method == 'last':
        model = freeze_layers(model, unfrozen_keyword='output_layer')

    # Define optimizer, for chain-thaw we define it later (after freezing)
    if method == 'last':
        adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    elif method in ['full', 'new']:
        # Add L2 regulation on embeddings only
        special_params = [id(p) for p in model.embed.parameters()]
        base_params = [p for p in model.parameters() if id(p) not in special_params and p.requires_grad]
        embed_parameters = [p for p in model.parameters() if id(p) in special_params and p.requires_grad]
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_parameters, 'weight_decay': embed_l2},
            ], lr=lr)

    # Training
    if verbose:
        print('Method: {}'.format(method))
        print('Classes: {}'.format(nb_classes))

    if method == 'chain-thaw':
        result = class_avg_chainthaw(model, nb_classes=nb_classes,
                                     loss_op=loss_op,
                                     train=(X_train, y_train),
                                     val=(X_val, y_val),
                                     test=(X_test, y_test),
                                     batch_size=batch_size,
                                     epoch_size=epoch_size,
                                     nb_epochs=nb_epochs,
                                     checkpoint_weight_path=checkpoint_path,
                                     f1_init_weight_path=f1_init_path,
                                     verbose=verbose)
    else:
        result = class_avg_tune_trainable(model, nb_classes=nb_classes,
                                          loss_op=loss_op,
                                          optim_op=adam,
                                          train=(X_train, y_train),
                                          val=(X_val, y_val),
                                          test=(X_test, y_test),
                                          epoch_size=epoch_size,
                                          nb_epochs=nb_epochs,
                                          batch_size=batch_size,
                                          init_weight_path=f1_init_path,
                                          checkpoint_weight_path=checkpoint_path,
                                          verbose=verbose)
    return model, result


def prepare_labels(y_train, y_val, y_test, iter_i, nb_classes):
    # Relabel into binary classification
    y_train_new = relabel(y_train, iter_i, nb_classes)
    y_val_new = relabel(y_val, iter_i, nb_classes)
    y_test_new = relabel(y_test, iter_i, nb_classes)
    return y_train_new, y_val_new, y_test_new

def prepare_generators(X_train, y_train_new, X_val, y_val_new, batch_size, epoch_size):
    # Create sample generators
    # Make a fixed validation set to avoid fluctuations in validation
    train_gen = get_data_loader(X_train, y_train_new, batch_size,
                                extended_batch_sampler=True)
    val_gen = get_data_loader(X_val, y_val_new, epoch_size,
                              extended_batch_sampler=True)
    X_val_resamp, y_val_resamp = next(iter(val_gen))
    return train_gen, X_val_resamp, y_val_resamp


def class_avg_tune_trainable(model, nb_classes, loss_op, optim_op, train, val, test,
                             epoch_size, nb_epochs, batch_size,
                             init_weight_path, checkpoint_weight_path, patience=5,
                             verbose=True):
    """ Finetunes the given model using the F1 measure.

    # Arguments:
        model: Model to be finetuned.
        nb_classes: Number of classes in the given dataset.
        train: Training data, given as a tuple of (inputs, outputs)
        val: Validation data, given as a tuple of (inputs, outputs)
        test: Testing data, given as a tuple of (inputs, outputs)
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        batch_size: Batch size.
        init_weight_path: Filepath where weights will be initially saved before
            training each class. This file will be rewritten by the function.
        checkpoint_weight_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        verbose: Verbosity flag.

    # Returns:
        F1 score of the trained model
    """
    total_f1 = 0
    nb_iter = nb_classes if nb_classes > 2 else 1

    # Unpack args
    X_train, y_train = train
    X_val, y_val = val
    X_test, y_test = test

    # Save and reload initial weights after running for
    # each class to avoid learning across classes
    torch.save(model.state_dict(), init_weight_path)
    for i in range(nb_iter):
        if verbose:
            print('Iteration number {}/{}'.format(i+1, nb_iter))

        model.load_state_dict(torch.load(init_weight_path))
        y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val,
                                                            y_test, i, nb_classes)
        train_gen, X_val_resamp, y_val_resamp = \
            prepare_generators(X_train, y_train_new, X_val, y_val_new,
                               batch_size, epoch_size)

        if verbose:
            print("Training..")
        fit_model(model, loss_op, optim_op, train_gen, [(X_val_resamp, y_val_resamp)],
                  nb_epochs, checkpoint_weight_path, patience, verbose=0)

        # Reload the best weights found to avoid overfitting
        # Wait a bit to allow proper closing of weights file
        sleep(1)
        model.load_state_dict(torch.load(checkpoint_weight_path))

        # Evaluate
        y_pred_val = model(X_val).cpu().numpy()
        y_pred_test = model(X_test).cpu().numpy()

        f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
                                            y_test_new, y_pred_test)
        if verbose:
            print('f1_test: {}'.format(f1_test))
            print('best_t: {}'.format(best_t))
        total_f1 += f1_test

    return total_f1 / nb_iter


def class_avg_chainthaw(model, nb_classes, loss_op, train, val, test, batch_size,
                        epoch_size, nb_epochs, checkpoint_weight_path,
                        f1_init_weight_path, patience=5,
                        initial_lr=0.001, next_lr=0.0001, verbose=True):
    """ Finetunes given model using chain-thaw and evaluates using F1.
        For a dataset with multiple classes, the model is trained once for
        each class, relabeling those classes into a binary classification task.
        The result is an average of all F1 scores for each class.

    # Arguments:
        model: Model to be finetuned.
        nb_classes: Number of classes in the given dataset.
        train: Training data, given as a tuple of (inputs, outputs)
        val: Validation data, given as a tuple of (inputs, outputs)
        test: Testing data, given as a tuple of (inputs, outputs)
        batch_size: Batch size.
        loss: Loss function to be used during training.
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        checkpoint_weight_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        f1_init_weight_path: Filepath where weights will be saved to and
            reloaded from before training each class. This ensures that
            each class is trained independently. This file will be rewritten.
        initial_lr: Initial learning rate. Will only be used for the first
            training step (i.e. the softmax layer)
        next_lr: Learning rate for every subsequent step.
        seed: Random number generator seed.
        verbose: Verbosity flag.

    # Returns:
        Averaged F1 score.
    """

    # Unpack args
    X_train, y_train = train
    X_val, y_val = val
    X_test, y_test = test

    total_f1 = 0
    nb_iter = nb_classes if nb_classes > 2 else 1

    torch.save(model.state_dict(), f1_init_weight_path)

    for i in range(nb_iter):
        if verbose:
            print('Iteration number {}/{}'.format(i+1, nb_iter))

        model.load_state_dict(torch.load(f1_init_weight_path))
        y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val,
                                                            y_test, i, nb_classes)
        train_gen, X_val_resamp, y_val_resamp = \
            prepare_generators(X_train, y_train_new, X_val, y_val_new,
                               batch_size, epoch_size)

        if verbose:
            print("Training..")

        # Train using chain-thaw
        train_by_chain_thaw(model=model, train_gen=train_gen,
                            val_gen=[(X_val_resamp, y_val_resamp)],
                            loss_op=loss_op, patience=patience,
                            nb_epochs=nb_epochs,
                            checkpoint_path=checkpoint_weight_path,
                            initial_lr=initial_lr, next_lr=next_lr,
                            verbose=verbose)

        # Evaluate
        y_pred_val = model(X_val).cpu().numpy()
        y_pred_test = model(X_test).cpu().numpy()

        f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
                                            y_test_new, y_pred_test)

        if verbose:
            print('f1_test: {}'.format(f1_test))
            print('best_t: {}'.format(best_t))
        total_f1 += f1_test

    return total_f1 / nb_iter
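To make the one-vs-rest relabelling concrete, here is a toy call to relabel() on a hand-written 3-class one-hot matrix; the values are invented purely for illustration.

# Toy illustration of relabel(): class 1 versus the rest.
import numpy as np
from torchmoji.class_avg_finetuning import relabel

y = np.array([[1, 0, 0],
              [0, 1, 0],
              [0, 0, 1],
              [0, 1, 0]])

print(relabel(y, current_label_nr=1, nb_classes=3))   # [0. 1. 0. 1.]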
torchmoji/create_vocab.py
ADDED
@@ -0,0 +1,271 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division

import glob
import json
import uuid
from copy import deepcopy
from collections import defaultdict, OrderedDict
import numpy as np

from torchmoji.filter_utils import is_special_token
from torchmoji.word_generator import WordGenerator
from torchmoji.global_variables import SPECIAL_TOKENS, VOCAB_PATH

class VocabBuilder():
    """ Create vocabulary with words extracted from sentences as fed from a
        word generator.
    """
    def __init__(self, word_gen):
        # initialize any new key with value of 0
        self.word_counts = defaultdict(lambda: 0, {})
        self.word_length_limit=30

        for token in SPECIAL_TOKENS:
            assert len(token) < self.word_length_limit
            self.word_counts[token] = 0
        self.word_gen = word_gen

    def count_words_in_sentence(self, words):
        """ Generates word counts for all tokens in the given sentence.

        # Arguments:
            words: Tokenized sentence whose words should be counted.
        """
        for word in words:
            if 0 < len(word) and len(word) <= self.word_length_limit:
                try:
                    self.word_counts[word] += 1
                except KeyError:
                    self.word_counts[word] = 1

    def save_vocab(self, path=None):
        """ Saves the vocabulary into a file.

        # Arguments:
            path: Where the vocabulary should be saved. If not specified, a
                  randomly generated filename is used instead.
        """
        dtype = ([('word','|S{}'.format(self.word_length_limit)),('count','int')])
        np_dict = np.array(self.word_counts.items(), dtype=dtype)

        # sort from highest to lowest frequency
        np_dict[::-1].sort(order='count')
        data = np_dict

        if path is None:
            path = str(uuid.uuid4())

        np.savez_compressed(path, data=data)
        print("Saved dict to {}".format(path))

    def get_next_word(self):
        """ Returns next tokenized sentence from the word generator.

        # Returns:
            List of strings, representing the next tokenized sentence.
        """
        return self.word_gen.__iter__().next()

    def count_all_words(self):
        """ Generates word counts for all words in all sentences of the word
            generator.
        """
        for words, _ in self.word_gen:
            self.count_words_in_sentence(words)

class MasterVocab():
    """ Combines vocabularies.
    """
    def __init__(self):

        # initialize custom tokens
        self.master_vocab = {}

    def populate_master_vocab(self, vocab_path, min_words=1, force_appearance=None):
        """ Populates the master vocabulary using all vocabularies found in the
            given path. Vocabularies should be named *.npz. Expects the
            vocabularies to be numpy arrays with counts. Normalizes the counts
            and combines them.

        # Arguments:
            vocab_path: Path containing vocabularies to be combined.
            min_words: Minimum amount of occurences a word must have in order
                to be included in the master vocabulary.
            force_appearance: Optional vocabulary filename that will be added
                to the master vocabulary no matter what. This vocabulary must
                be present in vocab_path.
        """

        paths = glob.glob(vocab_path + '*.npz')
        sizes = {path: 0 for path in paths}
        dicts = {path: {} for path in paths}

        # set up and get sizes of individual dictionaries
        for path in paths:
            np_data = np.load(path)['data']

            for entry in np_data:
                word, count = entry
                if count < min_words:
                    continue
                if is_special_token(word):
                    continue
                dicts[path][word] = count

            sizes[path] = sum(dicts[path].values())
            print('Overall word count for {} -> {}'.format(path, sizes[path]))
            print('Overall word number for {} -> {}'.format(path, len(dicts[path])))

        vocab_of_max_size = max(sizes, key=sizes.get)
        max_size = sizes[vocab_of_max_size]
        print('Min: {}, {}, {}'.format(sizes, vocab_of_max_size, max_size))

        # can force one vocabulary to always be present
        if force_appearance is not None:
            force_appearance_path = [p for p in paths if force_appearance in p][0]
            force_appearance_vocab = deepcopy(dicts[force_appearance_path])
            print(force_appearance_path)
        else:
            force_appearance_path, force_appearance_vocab = None, None

        # normalize word counts before inserting into master dict
        for path in paths:
            normalization_factor = max_size / sizes[path]
            print('Norm factor for path {} -> {}'.format(path, normalization_factor))

            for word in dicts[path]:
                if is_special_token(word):
                    print("SPECIAL - ", word)
                    continue
                normalized_count = dicts[path][word] * normalization_factor

                # can force one vocabulary to always be present
                if force_appearance_vocab is not None:
                    try:
                        force_word_count = force_appearance_vocab[word]
                    except KeyError:
                        continue
                    #if force_word_count < 5:
                    #    continue

                if word in self.master_vocab:
                    self.master_vocab[word] += normalized_count
                else:
                    self.master_vocab[word] = normalized_count

        print('Size of master_dict {}'.format(len(self.master_vocab)))
        print("Hashes for master dict: {}".format(
            len([w for w in self.master_vocab if '#' in w[0]])))

    def save_vocab(self, path_count, path_vocab, word_limit=100000):
        """ Saves the master vocabulary into a file.
        """

        # reserve space for 10 special tokens
        words = OrderedDict()
        for token in SPECIAL_TOKENS:
            # store -1 instead of np.inf, which can overflow
            words[token] = -1

        # sort words by frequency
        desc_order = OrderedDict(sorted(self.master_vocab.items(),
                                        key=lambda kv: kv[1], reverse=True))
        words.update(desc_order)

        # use encoding of up to 30 characters (no token conversions)
        # use float to store large numbers (we don't care about precision loss)
        np_vocab = np.array(words.items(),
                            dtype=([('word','|S30'),('count','float')]))

        # output count for debugging
        counts = np_vocab[:word_limit]
        np.savez_compressed(path_count, counts=counts)

        # output the index of each word for easy lookup
        final_words = OrderedDict()
        for i, w in enumerate(words.keys()[:word_limit]):
            final_words.update({w:i})
        with open(path_vocab, 'w') as f:
            f.write(json.dumps(final_words, indent=4, separators=(',', ': ')))


def all_words_in_sentences(sentences):
    """ Extracts all unique words from a given list of sentences.

    # Arguments:
        sentences: List or word generator of sentences to be processed.

    # Returns:
        List of all unique words contained in the given sentences.
    """
    vocab = []
    if isinstance(sentences, WordGenerator):
        sentences = [s for s, _ in sentences]

    for sentence in sentences:
        for word in sentence:
            if word not in vocab:
                vocab.append(word)

    return vocab


def extend_vocab_in_file(vocab, max_tokens=10000, vocab_path=VOCAB_PATH):
    """ Extends JSON-formatted vocabulary with words from vocab that are not
        present in the current vocabulary. Adds up to max_tokens words.
        Overwrites file in vocab_path.

    # Arguments:
        new_vocab: Vocabulary to be added. MUST have word_counts populated, i.e.
            must have run count_all_words() previously.
        max_tokens: Maximum number of words to be added.
        vocab_path: Path to the vocabulary json which is to be extended.
    """
    try:
        with open(vocab_path, 'r') as f:
            current_vocab = json.load(f)
    except IOError:
        print('Vocabulary file not found, expected at ' + vocab_path)
        return

    extend_vocab(current_vocab, vocab, max_tokens)

    # Save back to file
    with open(vocab_path, 'w') as f:
        json.dump(current_vocab, f, sort_keys=True, indent=4, separators=(',',': '))


def extend_vocab(current_vocab, new_vocab, max_tokens=10000):
    """ Extends current vocabulary with words from vocab that are not
        present in the current vocabulary. Adds up to max_tokens words.

    # Arguments:
        current_vocab: Current dictionary of tokens.
        new_vocab: Vocabulary to be added. MUST have word_counts populated, i.e.
            must have run count_all_words() previously.
        max_tokens: Maximum number of words to be added.

    # Returns:
        How many new tokens have been added.
    """
    if max_tokens < 0:
        max_tokens = 10000

    words = OrderedDict()

    # sort words by frequency
    desc_order = OrderedDict(sorted(new_vocab.word_counts.items(),
                                    key=lambda kv: kv[1], reverse=True))
    words.update(desc_order)

    base_index = len(current_vocab.keys())
    added = 0
    for word in words:
        if added >= max_tokens:
            break
        if word not in current_vocab.keys():
            current_vocab[word] = base_index + added
            added += 1

    return added
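A rough sketch of the intended flow, feeding VocabBuilder an in-memory list of (tokens, info) pairs in place of a real WordGenerator stream; the sentences and the starting vocabulary below are made up for illustration.

# Sketch: count words from a toy token stream and extend an existing vocabulary.
from torchmoji.create_vocab import VocabBuilder, extend_vocab

toy_stream = [(['i', 'love', 'emojis'], {}),      # VocabBuilder iterates (tokens, info) pairs
              (['emojis', 'are', 'fun'], {})]

vb = VocabBuilder(toy_stream)
vb.count_all_words()
print(vb.word_counts['emojis'])                   # 2

current_vocab = {'SOME_TOKEN': 0, 'ANOTHER_TOKEN': 1}   # made-up word -> index mapping
added = extend_vocab(current_vocab, vb, max_tokens=3)
print(added, current_vocab)                       # up to 3 of the most frequent new words appended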
torchmoji/filter_input.py
ADDED
@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division
import codecs
import csv
import numpy as np
from emoji import UNICODE_EMOJI

def read_english(path="english_words.txt", add_emojis=True):
    # read english words for filtering (includes emojis as part of set)
    english = set()
    with codecs.open(path, "r", "utf-8") as f:
        for line in f:
            line = line.strip().lower().replace('\n', '')
            if len(line):
                english.add(line)
    if add_emojis:
        for e in UNICODE_EMOJI:
            english.add(e)
    return english

def read_wanted_emojis(path="wanted_emojis.csv"):
    emojis = []
    with open(path, 'rb') as f:
        reader = csv.reader(f)
        for line in reader:
            line = line[0].strip().replace('\n', '')
            line = line.decode('unicode-escape')
            emojis.append(line)
    return emojis

def read_non_english_users(path="unwanted_users.npz"):
    try:
        neu_set = set(np.load(path)['userids'])
    except IOError:
        neu_set = set()
    return neu_set
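read_wanted_emojis() and read_non_english_users() expect pre-built data files, so the sketch below only exercises read_english() against a tiny, hypothetical word list written on the fly; the file contents are invented, and the emoji additions depend on the installed emoji package exposing UNICODE_EMOJI.

# Sketch: build the filtering set from a tiny, hypothetical word list.
from torchmoji.filter_input import read_english

with open('english_words.txt', 'w', encoding='utf-8') as f:   # toy stand-in for the real word list
    f.write('Hello\nworld\n')

english = read_english('english_words.txt', add_emojis=True)
print('hello' in english)    # True -- entries are lower-cased when loaded
print(len(english) > 2)      # True once the UNICODE_EMOJI entries are added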
torchmoji/filter_utils.py
ADDED
@@ -0,0 +1,194 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division, unicode_literals
import sys
import re
import string
import emoji
from itertools import groupby

import numpy as np
from torchmoji.tokenizer import RE_MENTION, RE_URL
from torchmoji.global_variables import SPECIAL_TOKENS

try:
    unichr  # Python 2
except NameError:
    unichr = chr  # Python 3


AtMentionRegex = re.compile(RE_MENTION)
urlRegex = re.compile(RE_URL)

# from http://bit.ly/2rdjgjE (UTF-8 encodings and Unicode chars)
VARIATION_SELECTORS = [ '\ufe00',
                        '\ufe01',
                        '\ufe02',
                        '\ufe03',
                        '\ufe04',
                        '\ufe05',
                        '\ufe06',
                        '\ufe07',
                        '\ufe08',
                        '\ufe09',
                        '\ufe0a',
                        '\ufe0b',
                        '\ufe0c',
                        '\ufe0d',
                        '\ufe0e',
                        '\ufe0f']

# from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python
ALL_CHARS = (unichr(i) for i in range(sys.maxunicode))
CONTROL_CHARS = ''.join(map(unichr, list(range(0,32)) + list(range(127,160))))
CONTROL_CHAR_REGEX = re.compile('[%s]' % re.escape(CONTROL_CHARS))

def is_special_token(word):
    equal = False
    for spec in SPECIAL_TOKENS:
        if word == spec:
            equal = True
            break
    return equal

def mostly_english(words, english, pct_eng_short=0.5, pct_eng_long=0.6, ignore_special_tokens=True, min_length=2):
    """ Ensure text meets threshold for containing English words """

    n_words = 0
    n_english = 0

    if english is None:
        return True, 0, 0

    for w in words:
        if len(w) < min_length:
            continue
        if punct_word(w):
            continue
        if ignore_special_tokens and is_special_token(w):
            continue
        n_words += 1
        if w in english:
            n_english += 1

    if n_words < 2:
        return True, n_words, n_english
    if n_words < 5:
        valid_english = n_english >= n_words * pct_eng_short
    else:
        valid_english = n_english >= n_words * pct_eng_long
    return valid_english, n_words, n_english

def correct_length(words, min_words, max_words, ignore_special_tokens=True):
    """ Ensure text meets threshold for containing English words
        and that it's within the min and max words limits. """

    if min_words is None:
        min_words = 0

    if max_words is None:
        max_words = 99999

    n_words = 0
    for w in words:
        if punct_word(w):
            continue
        if ignore_special_tokens and is_special_token(w):
            continue
        n_words += 1
    valid = min_words <= n_words and n_words <= max_words
    return valid

def punct_word(word, punctuation=string.punctuation):
    return all([True if c in punctuation else False for c in word])

def load_non_english_user_set():
    non_english_user_set = set(np.load('uids.npz')['data'])
    return non_english_user_set

def non_english_user(userid, non_english_user_set):
    neu_found = int(userid) in non_english_user_set
    return neu_found

def separate_emojis_and_text(text):
    emoji_chars = []
    non_emoji_chars = []
    for c in text:
        if c in emoji.UNICODE_EMOJI:
            emoji_chars.append(c)
        else:
            non_emoji_chars.append(c)
    return ''.join(emoji_chars), ''.join(non_emoji_chars)

def extract_emojis(text, wanted_emojis):
    text = remove_variation_selectors(text)
    return [c for c in text if c in wanted_emojis]

def remove_variation_selectors(text):
    """ Remove styling glyph variants for Unicode characters.
        For instance, remove skin color from emojis.
    """
    for var in VARIATION_SELECTORS:
        text = text.replace(var, '')
    return text

def shorten_word(word):
    """ Shorten groupings of 3+ identical consecutive chars to 2, e.g. '!!!!' --> '!!'
    """

    # only shorten ASCII words
    try:
        word.decode('ascii')
    except (UnicodeDecodeError, UnicodeEncodeError, AttributeError) as e:
        return word

    # must have at least 3 char to be shortened
    if len(word) < 3:
        return word

    # find groups of 3+ consecutive letters
    letter_groups = [list(g) for k, g in groupby(word)]
    triple_or_more = [''.join(g) for g in letter_groups if len(g) >= 3]
    if len(triple_or_more) == 0:
        return word

    # replace letters to find the short word
    short_word = word
    for trip in triple_or_more:
        short_word = short_word.replace(trip, trip[0]*2)

    return short_word

def detect_special_tokens(word):
    try:
        int(word)
        word = SPECIAL_TOKENS[4]
    except ValueError:
        if AtMentionRegex.findall(word):
            word = SPECIAL_TOKENS[2]
        elif urlRegex.findall(word):
            word = SPECIAL_TOKENS[3]
    return word

def process_word(word):
    """ Shortening and converting the word to a special token if relevant.
    """
    word = shorten_word(word)
    word = detect_special_tokens(word)
    return word

def remove_control_chars(text):
    return CONTROL_CHAR_REGEX.sub('', text)

def convert_nonbreaking_space(text):
    # ugly hack handling non-breaking space no matter how badly it's been encoded in the input
    for r in ['\\\\xc2', '\\xc2', '\xc2', '\\\\xa0', '\\xa0', '\xa0']:
        text = text.replace(r, ' ')
    return text

def convert_linebreaks(text):
    # ugly hack handling non-breaking space no matter how badly it's been encoded in the input
    # space around to ensure proper tokenization
    for r in ['\\\\n', '\\n', '\n', '\\\\r', '\\r', '\r', '<br>']:
        text = text.replace(r, ' ' + SPECIAL_TOKENS[5] + ' ')
    return text
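A few toy calls showing what the word-level helpers do; the tokens are invented, and the exact replacement strings come from SPECIAL_TOKENS in global_variables.py rather than being spelled out here.

# Toy calls to the word-level cleaning helpers.
from torchmoji.filter_utils import punct_word, process_word, remove_control_chars

print(punct_word('!!!'))                  # True  -- token is punctuation only
print(punct_word('hi!'))                  # False -- contains non-punctuation characters
print(process_word('12345'))              # pure numbers map to the number special token
print(process_word('@someone'))           # @-mentions map to the mention special token
print(remove_control_chars('ab\x00c'))    # 'abc' -- control characters are stripped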
torchmoji/finetuning.py
ADDED
@@ -0,0 +1,674 @@
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
""" Finetuning functions for doing transfer learning to new datasets.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import print_function
|
| 5 |
+
|
| 6 |
+
import uuid
|
| 7 |
+
from time import sleep
|
| 8 |
+
from io import open
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import pickle
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.optim as optim
|
| 17 |
+
from sklearn.metrics import accuracy_score
|
| 18 |
+
from torch.autograd import Variable
|
| 19 |
+
from torch.utils.data import Dataset, DataLoader
|
| 20 |
+
from torch.utils.data.sampler import BatchSampler, SequentialSampler
|
| 21 |
+
from torch.nn.utils import clip_grad_norm_
|
| 22 |
+
|
| 23 |
+
from sklearn.metrics import f1_score
|
| 24 |
+
|
| 25 |
+
from torchmoji.global_variables import (FINETUNING_METHODS,
|
| 26 |
+
FINETUNING_METRICS,
|
| 27 |
+
WEIGHTS_DIR)
|
| 28 |
+
from torchmoji.tokenizer import tokenize
|
| 29 |
+
from torchmoji.sentence_tokenizer import SentenceTokenizer
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
unicode
|
| 33 |
+
IS_PYTHON2 = True
|
| 34 |
+
except NameError:
|
| 35 |
+
unicode = str
|
| 36 |
+
IS_PYTHON2 = False
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def load_benchmark(path, vocab, extend_with=0):
|
| 40 |
+
""" Loads the given benchmark dataset.
|
| 41 |
+
|
| 42 |
+
Tokenizes the texts using the provided vocabulary, extending it with
|
| 43 |
+
words from the training dataset if extend_with > 0. Splits them into
|
| 44 |
+
three lists: training, validation and testing (in that order).
|
| 45 |
+
|
| 46 |
+
Also calculates the maximum length of the texts and the
|
| 47 |
+
suggested batch_size.
|
| 48 |
+
|
| 49 |
+
# Arguments:
|
| 50 |
+
path: Path to the dataset to be loaded.
|
| 51 |
+
vocab: Vocabulary to be used for tokenizing texts.
|
| 52 |
+
extend_with: If > 0, the vocabulary will be extended with up to
|
| 53 |
+
extend_with tokens from the training set before tokenizing.
|
| 54 |
+
|
| 55 |
+
# Returns:
|
| 56 |
+
A dictionary with the following fields:
|
| 57 |
+
texts: List of three lists, containing tokenized inputs for
|
| 58 |
+
training, validation and testing (in that order).
|
| 59 |
+
labels: List of three lists, containing labels for training,
|
| 60 |
+
validation and testing (in that order).
|
| 61 |
+
added: Number of tokens added to the vocabulary.
|
| 62 |
+
batch_size: Batch size.
|
| 63 |
+
maxlen: Maximum length of an input.
|
| 64 |
+
"""
|
| 65 |
+
# Pre-processing dataset
|
| 66 |
+
with open(path, 'rb') as dataset:
|
| 67 |
+
if IS_PYTHON2:
|
| 68 |
+
data = pickle.load(dataset)
|
| 69 |
+
else:
|
| 70 |
+
data = pickle.load(dataset, fix_imports=True)
|
| 71 |
+
|
| 72 |
+
# Decode data
|
| 73 |
+
try:
|
| 74 |
+
texts = [unicode(x) for x in data['texts']]
|
| 75 |
+
except UnicodeDecodeError:
|
| 76 |
+
texts = [x.decode('utf-8') for x in data['texts']]
|
| 77 |
+
|
| 78 |
+
# Extract labels
|
| 79 |
+
labels = [x['label'] for x in data['info']]
|
| 80 |
+
|
| 81 |
+
batch_size, maxlen = calculate_batchsize_maxlen(texts)
|
| 82 |
+
|
| 83 |
+
st = SentenceTokenizer(vocab, maxlen)
|
| 84 |
+
|
| 85 |
+
# Split up dataset. Extend the existing vocabulary with up to extend_with
|
| 86 |
+
# tokens from the training dataset.
|
| 87 |
+
texts, labels, added = st.split_train_val_test(texts,
|
| 88 |
+
labels,
|
| 89 |
+
[data['train_ind'],
|
| 90 |
+
data['val_ind'],
|
| 91 |
+
data['test_ind']],
|
| 92 |
+
extend_with=extend_with)
|
| 93 |
+
return {'texts': texts,
|
| 94 |
+
'labels': labels,
|
| 95 |
+
'added': added,
|
| 96 |
+
'batch_size': batch_size,
|
| 97 |
+
'maxlen': maxlen}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def calculate_batchsize_maxlen(texts):
|
| 101 |
+
""" Calculates the maximum length in the provided texts and a suitable
|
| 102 |
+
batch size. Rounds up maxlen to the nearest multiple of ten.
|
| 103 |
+
|
| 104 |
+
# Arguments:
|
| 105 |
+
texts: List of inputs.
|
| 106 |
+
|
| 107 |
+
# Returns:
|
| 108 |
+
Batch size,
|
| 109 |
+
max length
|
| 110 |
+
"""
|
| 111 |
+
def roundup(x):
|
| 112 |
+
return int(math.ceil(x / 10.0)) * 10
|
| 113 |
+
|
| 114 |
+
# Calculate max length of sequences considered
|
| 115 |
+
# Adjust batch_size accordingly to prevent GPU overflow
|
| 116 |
+
lengths = [len(tokenize(t)) for t in texts]
|
| 117 |
+
maxlen = roundup(np.percentile(lengths, 80.0))
|
| 118 |
+
batch_size = 250 if maxlen <= 100 else 50
|
| 119 |
+
return batch_size, maxlen
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def freeze_layers(model, unfrozen_types=[], unfrozen_keyword=None):
|
| 124 |
+
""" Freezes all layers in the given model, except for ones that are
|
| 125 |
+
explicitly specified to not be frozen.
|
| 126 |
+
|
| 127 |
+
# Arguments:
|
| 128 |
+
model: Model whose layers should be modified.
|
| 129 |
+
unfrozen_types: List of layer types which shouldn't be frozen.
|
| 130 |
+
unfrozen_keyword: Name keywords of layers that shouldn't be frozen.
|
| 131 |
+
|
| 132 |
+
# Returns:
|
| 133 |
+
Model with the selected layers frozen.
|
| 134 |
+
"""
|
| 135 |
+
# Get trainable modules
|
| 136 |
+
trainable_modules = [(n, m) for n, m in model.named_children() if len([id(p) for p in m.parameters()]) != 0]
|
| 137 |
+
for name, module in trainable_modules:
|
| 138 |
+
trainable = (any(typ in str(module) for typ in unfrozen_types) or
|
| 139 |
+
(unfrozen_keyword is not None and unfrozen_keyword.lower() in name.lower()))
|
| 140 |
+
change_trainable(module, trainable, verbose=False)
|
| 141 |
+
return model
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def change_trainable(module, trainable, verbose=False):
|
| 145 |
+
""" Helper method that freezes or unfreezes a given layer.
|
| 146 |
+
|
| 147 |
+
# Arguments:
|
| 148 |
+
module: Module to be modified.
|
| 149 |
+
trainable: Whether the layer should be frozen or unfrozen.
|
| 150 |
+
verbose: Verbosity flag.
|
| 151 |
+
"""
|
| 152 |
+
|
| 153 |
+
if verbose: print('Changing MODULE', module, 'to trainable =', trainable)
|
| 154 |
+
for name, param in module.named_parameters():
|
| 155 |
+
if verbose: print('Setting weight', name, 'to trainable =', trainable)
|
| 156 |
+
param.requires_grad = trainable
|
| 157 |
+
|
| 158 |
+
if verbose:
|
| 159 |
+
action = 'Unfroze' if trainable else 'Froze'
|
| 160 |
+
if verbose: print("{} {}".format(action, module))
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def find_f1_threshold(model, val_gen, test_gen, average='binary'):
|
| 164 |
+
""" Choose a threshold for F1 based on the validation dataset
|
| 165 |
+
(see https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4442797/
|
| 166 |
+
for details on why to find another threshold than simply 0.5)
|
| 167 |
+
|
| 168 |
+
# Arguments:
|
| 169 |
+
model: pyTorch model
|
| 170 |
+
val_gen: Validation set dataloader.
|
| 171 |
+
test_gen: Testing set dataloader.
|
| 172 |
+
|
| 173 |
+
# Returns:
|
| 174 |
+
F1 score for the given data and
|
| 175 |
+
the corresponding F1 threshold
|
| 176 |
+
"""
|
| 177 |
+
thresholds = np.arange(0.01, 0.5, step=0.01)
|
| 178 |
+
f1_scores = []
|
| 179 |
+
|
| 180 |
+
model.eval()
|
| 181 |
+
val_out = [(y, model(X)) for X, y in val_gen]
|
| 182 |
+
y_val, y_pred_val = (list(t) for t in zip(*val_out))
|
| 183 |
+
|
| 184 |
+
test_out = [(y, model(X)) for X, y in test_gen]
|
| 185 |
+
y_test, y_pred_test = (list(t) for t in zip(*val_out))
|
| 186 |
+
|
| 187 |
+
for t in thresholds:
|
| 188 |
+
y_pred_val_ind = (y_pred_val > t)
|
| 189 |
+
f1_val = f1_score(y_val, y_pred_val_ind, average=average)
|
| 190 |
+
f1_scores.append(f1_val)
|
| 191 |
+
|
| 192 |
+
best_t = thresholds[np.argmax(f1_scores)]
|
| 193 |
+
y_pred_ind = (y_pred_test > best_t)
|
| 194 |
+
f1_test = f1_score(y_test, y_pred_ind, average=average)
|
| 195 |
+
return f1_test, best_t
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def finetune(model, texts, labels, nb_classes, batch_size, method,
|
| 199 |
+
metric='acc', epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
|
| 200 |
+
verbose=1):
|
| 201 |
+
""" Compiles and finetunes the given pytorch model.
|
| 202 |
+
|
| 203 |
+
# Arguments:
|
| 204 |
+
model: Model to be finetuned
|
| 205 |
+
texts: List of three lists, containing tokenized inputs for training,
|
| 206 |
+
validation and testing (in that order).
|
| 207 |
+
labels: List of three lists, containing labels for training,
|
| 208 |
+
validation and testing (in that order).
|
| 209 |
+
nb_classes: Number of classes in the dataset.
|
| 210 |
+
batch_size: Batch size.
|
| 211 |
+
method: Finetuning method to be used. For available methods, see
|
| 212 |
+
FINETUNING_METHODS in global_variables.py.
|
| 213 |
+
metric: Evaluation metric to be used. For available metrics, see
|
| 214 |
+
FINETUNING_METRICS in global_variables.py.
|
| 215 |
+
epoch_size: Number of samples in an epoch.
|
| 216 |
+
nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
|
| 217 |
+
embed_l2: L2 regularization for the embedding layer.
|
| 218 |
+
verbose: Verbosity flag.
|
| 219 |
+
|
| 220 |
+
# Returns:
|
| 221 |
+
Model after finetuning,
|
| 222 |
+
score after finetuning using the provided metric.
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
if method not in FINETUNING_METHODS:
|
| 226 |
+
raise ValueError('ERROR (finetune): Invalid method parameter. '
|
| 227 |
+
'Available options: {}'.format(FINETUNING_METHODS))
|
| 228 |
+
if metric not in FINETUNING_METRICS:
|
| 229 |
+
raise ValueError('ERROR (finetune): Invalid metric parameter. '
|
| 230 |
+
'Available options: {}'.format(FINETUNING_METRICS))
|
| 231 |
+
|
| 232 |
+
train_gen = get_data_loader(texts[0], labels[0], batch_size,
|
| 233 |
+
extended_batch_sampler=True, epoch_size=epoch_size)
|
| 234 |
+
val_gen = get_data_loader(texts[1], labels[1], batch_size,
|
| 235 |
+
extended_batch_sampler=False)
|
| 236 |
+
test_gen = get_data_loader(texts[2], labels[2], batch_size,
|
| 237 |
+
extended_batch_sampler=False)
|
| 238 |
+
|
| 239 |
+
checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \
|
| 240 |
+
.format(WEIGHTS_DIR, str(uuid.uuid4()))
|
| 241 |
+
|
| 242 |
+
if method in ['last', 'new']:
|
| 243 |
+
lr = 0.001
|
| 244 |
+
elif method in ['full', 'chain-thaw']:
|
| 245 |
+
lr = 0.0001
|
| 246 |
+
|
| 247 |
+
loss_op = nn.BCEWithLogitsLoss() if nb_classes <= 2 \
|
| 248 |
+
else nn.CrossEntropyLoss()
|
| 249 |
+
|
| 250 |
+
# Freeze layers if using last
|
| 251 |
+
if method == 'last':
|
| 252 |
+
model = freeze_layers(model, unfrozen_keyword='output_layer')
|
| 253 |
+
|
| 254 |
+
# Define optimizer, for chain-thaw we define it later (after freezing)
|
| 255 |
+
if method == 'last':
|
| 256 |
+
adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
|
| 257 |
+
elif method in ['full', 'new']:
|
| 258 |
+
# Add L2 regulation on embeddings only
|
| 259 |
+
embed_params_id = [id(p) for p in model.embed.parameters()]
|
| 260 |
+
output_layer_params_id = [id(p) for p in model.output_layer.parameters()]
|
| 261 |
+
base_params = [p for p in model.parameters()
|
| 262 |
+
if id(p) not in embed_params_id and id(p) not in output_layer_params_id and p.requires_grad]
|
| 263 |
+
embed_params = [p for p in model.parameters() if id(p) in embed_params_id and p.requires_grad]
|
| 264 |
+
output_layer_params = [p for p in model.parameters() if id(p) in output_layer_params_id and p.requires_grad]
|
| 265 |
+
adam = optim.Adam([
|
| 266 |
+
{'params': base_params},
|
| 267 |
+
{'params': embed_params, 'weight_decay': embed_l2},
|
| 268 |
+
{'params': output_layer_params, 'lr': 0.001},
|
| 269 |
+
], lr=lr)
|
| 270 |
+
|
| 271 |
+
# Training
|
| 272 |
+
if verbose:
|
| 273 |
+
print('Method: {}'.format(method))
|
| 274 |
+
print('Metric: {}'.format(metric))
|
| 275 |
+
print('Classes: {}'.format(nb_classes))
|
| 276 |
+
|
| 277 |
+
if method == 'chain-thaw':
|
| 278 |
+
result = chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op, embed_l2=embed_l2,
|
| 279 |
+
evaluate=metric, verbose=verbose)
|
| 280 |
+
else:
|
| 281 |
+
result = tune_trainable(model, loss_op, adam, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path,
|
| 282 |
+
evaluate=metric, verbose=verbose)
|
| 283 |
+
return model, result
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def tune_trainable(model, loss_op, optim_op, train_gen, val_gen, test_gen,
|
| 287 |
+
nb_epochs, checkpoint_path, patience=5, evaluate='acc',
|
| 288 |
+
verbose=2):
|
| 289 |
+
""" Finetunes the given model using the accuracy measure.
|
| 290 |
+
|
| 291 |
+
# Arguments:
|
| 292 |
+
model: Model to be finetuned.
|
| 293 |
+
loss_op: Loss operation (e.g. BCEWithLogitsLoss or CrossEntropyLoss).
|
| 294 |
+
optim_op: Optimization operation (e.g. Adam).
|
| 295 |
+
train_gen: Training data iterator (DataLoader).
|
| 296 |
+
val_gen: Validation data iterator (DataLoader).
|
| 297 |
+
test_gen: Testing data iterator (DataLoader).
|
| 298 |
+
nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
|
| 299 |
+
checkpoint_path: Filepath where weights will be checkpointed to
|
| 300 |
+
during training. This file will be rewritten by the function.
|
| 301 |
+
patience: Number of epochs with no improvement on the validation
|
| 302 |
+
loss before training is stopped early.
|
| 303 |
+
evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
|
| 304 |
+
verbose: Verbosity flag.
|
| 305 |
+
|
| 306 |
+
# Returns:
|
| 307 |
+
Accuracy of the trained model, ONLY if 'evaluate' is set.
|
| 308 |
+
"""
|
| 309 |
+
if verbose:
|
| 310 |
+
print("Trainable weights: {}".format([n for n, p in model.named_parameters() if p.requires_grad]))
|
| 311 |
+
print("Training...")
|
| 312 |
+
if evaluate == 'acc':
|
| 313 |
+
print("Evaluation on test set prior training:", evaluate_using_acc(model, test_gen))
|
| 314 |
+
elif evaluate == 'weighted_f1':
|
| 315 |
+
print("Evaluation on test set prior training:", evaluate_using_weighted_f1(model, test_gen, val_gen))
|
| 316 |
+
|
| 317 |
+
fit_model(model, loss_op, optim_op, train_gen, val_gen, nb_epochs, checkpoint_path, patience)
|
| 318 |
+
|
| 319 |
+
# Reload the best weights found to avoid overfitting
|
| 320 |
+
# Wait a bit to allow proper closing of weights file
|
| 321 |
+
sleep(1)
|
| 322 |
+
model.load_state_dict(torch.load(checkpoint_path))
|
| 323 |
+
if verbose >= 2:
|
| 324 |
+
print("Loaded weights from {}".format(checkpoint_path))
|
| 325 |
+
|
| 326 |
+
if evaluate == 'acc':
|
| 327 |
+
return evaluate_using_acc(model, test_gen)
|
| 328 |
+
elif evaluate == 'weighted_f1':
|
| 329 |
+
return evaluate_using_weighted_f1(model, test_gen, val_gen)
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def evaluate_using_weighted_f1(model, test_gen, val_gen):
|
| 333 |
+
""" Evaluation function using macro weighted F1 score.
|
| 334 |
+
|
| 335 |
+
# Arguments:
|
| 336 |
+
model: Model to be evaluated.
|
| 337 |
+
test_gen: Testing data iterator (DataLoader) used for the final
|
| 338 |
+
F1 evaluation.
|
| 339 |
+
val_gen: Validation data iterator (DataLoader) used to find the
|
| 340 |
+
decision threshold that maximises F1 before the test set
|
| 341 |
+
is evaluated.
|
| 342 |
+
|
| 343 |
+
# Returns:
|
| 344 |
+
Weighted F1 score of the given model.
|
| 345 |
+
"""
|
| 346 |
+
# Evaluate on test and val data
|
| 347 |
+
f1_test, _ = find_f1_threshold(model, test_gen, val_gen, average='weighted_f1')
|
| 348 |
+
return f1_test
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def evaluate_using_acc(model, test_gen):
|
| 352 |
+
""" Evaluation function using accuracy.
|
| 353 |
+
|
| 354 |
+
# Arguments:
|
| 355 |
+
model: Model to be evaluated.
|
| 356 |
+
test_gen: Testing data iterator (DataLoader)
|
| 357 |
+
|
| 358 |
+
# Returns:
|
| 359 |
+
Accuracy of the given model.
|
| 360 |
+
"""
|
| 361 |
+
|
| 362 |
+
# Validate on test_data
|
| 363 |
+
model.eval()
|
| 364 |
+
accs = []
|
| 365 |
+
for i, data in enumerate(test_gen):
|
| 366 |
+
x, y = data
|
| 367 |
+
outs = model(x)
|
| 368 |
+
if model.nb_classes > 2:
|
| 369 |
+
pred = torch.max(outs, 1)[1]
|
| 370 |
+
acc = accuracy_score(y.squeeze().numpy(), pred.squeeze().numpy())
|
| 371 |
+
else:
|
| 372 |
+
pred = (outs >= 0).long()
|
| 373 |
+
acc = (pred == y).double().sum() / len(pred)
|
| 374 |
+
accs.append(acc)
|
| 375 |
+
return np.mean(accs)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
|
| 379 |
+
patience=5, initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, evaluate='acc', verbose=1):
|
| 380 |
+
""" Finetunes given model using chain-thaw and evaluates using accuracy.
|
| 381 |
+
|
| 382 |
+
# Arguments:
|
| 383 |
+
model: Model to be finetuned.
|
| 384 |
+
train_gen: Training data iterator (DataLoader).
|
| 385 |
+
val_gen: Validation data iterator (DataLoader).
|
| 386 |
+
test_gen: Testing data iterator (DataLoader).
|
| 387 |
+
nb_epochs: Number of epochs.
|
| 388 |
+
checkpoint_path: Filepath where weights will be checkpointed to
|
| 389 |
+
during training. This file will be rewritten by the function.
|
| 390 |
+
loss_op: Loss function to be used during training.
|
| 391 |
+
patience: Number of epochs with no improvement on the validation
|
| 392 |
+
loss before training is stopped early.
|
| 393 |
+
initial_lr: Initial learning rate. Will only be used for the first
|
| 394 |
+
training step (i.e. the output_layer layer)
|
| 395 |
+
next_lr: Learning rate for every subsequent step.
|
| 396 |
+
embed_l2: L2 regularization for the embedding layer.
|
| 397 |
+
evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
|
| 398 |
+
verbose: Verbosity flag.
|
| 399 |
+
|
| 400 |
+
# Returns:
|
| 401 |
+
Accuracy of the finetuned model.
|
| 402 |
+
"""
|
| 403 |
+
if verbose:
|
| 404 |
+
print('Training..')
|
| 405 |
+
|
| 406 |
+
# Train using chain-thaw
|
| 407 |
+
train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path,
|
| 408 |
+
initial_lr, next_lr, embed_l2, verbose)
|
| 409 |
+
|
| 410 |
+
if evaluate == 'acc':
|
| 411 |
+
return evaluate_using_acc(model, test_gen)
|
| 412 |
+
elif evaluate == 'weighted_f1':
|
| 413 |
+
return evaluate_using_weighted_f1(model, test_gen, val_gen)
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path,
|
| 417 |
+
initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, verbose=1):
|
| 418 |
+
""" Finetunes model using the chain-thaw method.
|
| 419 |
+
|
| 420 |
+
This is done as follows:
|
| 421 |
+
1) Freeze every layer except the last (output_layer) layer and train it.
|
| 422 |
+
2) Freeze every layer except the first layer and train it.
|
| 423 |
+
3) Do the same for the second layer, then the third, and so on up to the second-to-last layer.
|
| 424 |
+
4) Unfreeze all layers and train entire model.
|
| 425 |
+
|
| 426 |
+
# Arguments:
|
| 427 |
+
model: Model to be trained.
|
| 428 |
+
train_gen: Training sample generator.
|
| 429 |
+
val_gen: Validation data iterator (DataLoader).
|
| 430 |
+
loss_op: Loss function to be used.
|
| 431 |
+
patience: Number of epochs with no improvement on the validation
|
| 432 |
+
loss before training is stopped early.
|
| 433 |
+
nb_epochs: Number of epochs.
|
| 434 |
+
checkpoint_path: Where weight checkpoints should be saved.
|
| 435 |
+
embed_l2: L2 regularization for the embedding layer.
|
| 436 |
+
initial_lr: Initial learning rate. Will only be used for the first
|
| 437 |
+
training step (i.e. the output_layer layer)
|
| 438 |
+
next_lr: Learning rate for every subsequent step.
|
| 439 |
+
verbose: Verbosity flag.
|
| 440 |
+
"""
|
| 441 |
+
# Get trainable layers
|
| 442 |
+
layers = [m for m in model.children() if len([id(p) for p in m.parameters()]) != 0]
|
| 443 |
+
|
| 444 |
+
# Bring last layer to front
|
| 445 |
+
layers.insert(0, layers.pop(len(layers) - 1))
|
| 446 |
+
|
| 447 |
+
# Add None to the end to signify finetuning all layers
|
| 448 |
+
layers.append(None)
|
| 449 |
+
|
| 450 |
+
lr = None
|
| 451 |
+
# Finetune each layer one by one and finetune all of them at once
|
| 452 |
+
# at the end
|
| 453 |
+
for layer in layers:
|
| 454 |
+
if lr is None:
|
| 455 |
+
lr = initial_lr
|
| 456 |
+
elif lr == initial_lr:
|
| 457 |
+
lr = next_lr
|
| 458 |
+
|
| 459 |
+
# Freeze all except current layer
|
| 460 |
+
for _layer in layers:
|
| 461 |
+
if _layer is not None:
|
| 462 |
+
trainable = _layer == layer or layer is None
|
| 463 |
+
change_trainable(_layer, trainable=trainable, verbose=False)
|
| 464 |
+
|
| 465 |
+
# Verify we froze the right layers
|
| 466 |
+
for _layer in model.children():
|
| 467 |
+
assert all(p.requires_grad == (_layer == layer) for p in _layer.parameters()) or layer is None
|
| 468 |
+
|
| 469 |
+
if verbose:
|
| 470 |
+
if layer is None:
|
| 471 |
+
print('Finetuning all layers')
|
| 472 |
+
else:
|
| 473 |
+
print('Finetuning {}'.format(layer))
|
| 474 |
+
|
| 475 |
+
special_params = [id(p) for p in model.embed.parameters()]
|
| 476 |
+
base_params = [p for p in model.parameters() if id(p) not in special_params and p.requires_grad]
|
| 477 |
+
embed_parameters = [p for p in model.parameters() if id(p) in special_params and p.requires_grad]
|
| 478 |
+
adam = optim.Adam([
|
| 479 |
+
{'params': base_params},
|
| 480 |
+
{'params': embed_parameters, 'weight_decay': embed_l2},
|
| 481 |
+
], lr=lr)
|
| 482 |
+
|
| 483 |
+
fit_model(model, loss_op, adam, train_gen, val_gen, nb_epochs,
|
| 484 |
+
checkpoint_path, patience)
|
| 485 |
+
|
| 486 |
+
# Reload the best weights found to avoid overfitting
|
| 487 |
+
# Wait a bit to allow proper closing of weights file
|
| 488 |
+
sleep(1)
|
| 489 |
+
model.load_state_dict(torch.load(checkpoint_path))
|
| 490 |
+
if verbose >= 2:
|
| 491 |
+
print("Loaded weights from {}".format(checkpoint_path))
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def calc_loss(loss_op, pred, yv):
|
| 495 |
+
if type(loss_op) is nn.CrossEntropyLoss:
|
| 496 |
+
return loss_op(pred.squeeze(), yv.squeeze())
|
| 497 |
+
else:
|
| 498 |
+
return loss_op(pred.squeeze(), yv.squeeze().float())
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
|
| 502 |
+
checkpoint_path, patience):
|
| 503 |
+
""" Analog to Keras fit_generator function.
|
| 504 |
+
|
| 505 |
+
# Arguments:
|
| 506 |
+
model: Model to be finetuned.
|
| 507 |
+
loss_op: loss operation (BCEWithLogitsLoss or CrossEntropy for e.g.)
|
| 508 |
+
optim_op: optimization operation (Adam e.g.)
|
| 509 |
+
train_gen: Training data iterator (DataLoader)
|
| 510 |
+
val_gen: Validation data iterator (DataLoader)
|
| 511 |
+
epochs: Number of epochs.
|
| 512 |
+
checkpoint_path: Filepath where weights will be checkpointed to
|
| 513 |
+
during training. This file will be rewritten by the function.
|
| 514 |
+
patience: Patience for callback methods.
|
| 515 |
+
verbose: Verbosity flag.
|
| 516 |
+
|
| 517 |
+
# Returns:
|
| 518 |
+
Nothing. The best weights found during training are saved to checkpoint_path.
|
| 519 |
+
"""
|
| 520 |
+
# Save original checkpoint
|
| 521 |
+
torch.save(model.state_dict(), checkpoint_path)
|
| 522 |
+
|
| 523 |
+
model.eval()
|
| 524 |
+
best_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy() for xv, yv in val_gen])
|
| 525 |
+
print("original val loss", best_loss)
|
| 526 |
+
|
| 527 |
+
epoch_without_impr = 0
|
| 528 |
+
for epoch in range(epochs):
|
| 529 |
+
for i, data in enumerate(train_gen):
|
| 530 |
+
X_train, y_train = data
|
| 531 |
+
X_train = Variable(X_train, requires_grad=False)
|
| 532 |
+
y_train = Variable(y_train, requires_grad=False)
|
| 533 |
+
model.train()
|
| 534 |
+
optim_op.zero_grad()
|
| 535 |
+
output = model(X_train)
|
| 536 |
+
loss = calc_loss(loss_op, output, y_train)
|
| 537 |
+
loss.backward()
|
| 538 |
+
clip_grad_norm_(model.parameters(), 1)
|
| 539 |
+
optim_op.step()
|
| 540 |
+
|
| 541 |
+
acc = evaluate_using_acc(model, [(X_train.data, y_train.data)])
|
| 542 |
+
print("== Epoch", epoch, "step", i, "train loss", loss.data.cpu().numpy(), "train acc", acc)
|
| 543 |
+
|
| 544 |
+
model.eval()
|
| 545 |
+
acc = evaluate_using_acc(model, val_gen)
|
| 546 |
+
print("val acc", acc)
|
| 547 |
+
|
| 548 |
+
val_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy() for xv, yv in val_gen])
|
| 549 |
+
print("val loss", val_loss)
|
| 550 |
+
if best_loss is not None and val_loss >= best_loss:
|
| 551 |
+
epoch_without_impr += 1
|
| 552 |
+
print('No improvement over previous best loss: ', best_loss)
|
| 553 |
+
|
| 554 |
+
# Save checkpoint
|
| 555 |
+
if best_loss is None or val_loss < best_loss:
|
| 556 |
+
best_loss = val_loss
|
| 557 |
+
torch.save(model.state_dict(), checkpoint_path)
|
| 558 |
+
print('Saving model at', checkpoint_path)
|
| 559 |
+
|
| 560 |
+
# Early stopping
|
| 561 |
+
if epoch_without_impr >= patience:
|
| 562 |
+
break
|
| 563 |
+
|
| 564 |
+
def get_data_loader(X_in, y_in, batch_size, extended_batch_sampler=True, epoch_size=25000, upsample=False, seed=42):
|
| 565 |
+
""" Returns a dataloader that enables larger epochs on small datasets and
|
| 566 |
+
has upsampling functionality.
|
| 567 |
+
|
| 568 |
+
# Arguments:
|
| 569 |
+
X_in: Inputs of the given dataset.
|
| 570 |
+
y_in: Outputs of the given dataset.
|
| 571 |
+
batch_size: Batch size.
|
| 572 |
+
epoch_size: Number of samples in an epoch.
|
| 573 |
+
upsample: Whether upsampling should be done. This flag should only be
|
| 574 |
+
set on binary class problems.
|
| 575 |
+
|
| 576 |
+
# Returns:
|
| 577 |
+
DataLoader.
|
| 578 |
+
"""
|
| 579 |
+
dataset = DeepMojiDataset(X_in, y_in)
|
| 580 |
+
|
| 581 |
+
if extended_batch_sampler:
|
| 582 |
+
batch_sampler = DeepMojiBatchSampler(y_in, batch_size, epoch_size=epoch_size, upsample=upsample, seed=seed)
|
| 583 |
+
else:
|
| 584 |
+
batch_sampler = BatchSampler(SequentialSampler(y_in), batch_size, drop_last=False)
|
| 585 |
+
|
| 586 |
+
return DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0)
|
| 587 |
+
|
| 588 |
+
class DeepMojiDataset(Dataset):
|
| 589 |
+
""" A simple Dataset class.
|
| 590 |
+
|
| 591 |
+
# Arguments:
|
| 592 |
+
X_in: Inputs of the given dataset.
|
| 593 |
+
y_in: Outputs of the given dataset.
|
| 594 |
+
|
| 595 |
+
# __getitem__ output:
|
| 596 |
+
(torch.LongTensor, torch.LongTensor)
|
| 597 |
+
"""
|
| 598 |
+
def __init__(self, X_in, y_in):
|
| 599 |
+
# Check if we have Torch.LongTensor inputs (assume Numpy array otherwise)
|
| 600 |
+
if not isinstance(X_in, torch.LongTensor):
|
| 601 |
+
X_in = torch.from_numpy(X_in.astype('int64')).long()
|
| 602 |
+
if not isinstance(y_in, torch.LongTensor):
|
| 603 |
+
y_in = torch.from_numpy(y_in.astype('int64')).long()
|
| 604 |
+
|
| 605 |
+
self.X_in = torch.split(X_in, 1, dim=0)
|
| 606 |
+
self.y_in = torch.split(y_in, 1, dim=0)
|
| 607 |
+
|
| 608 |
+
def __len__(self):
|
| 609 |
+
return len(self.X_in)
|
| 610 |
+
|
| 611 |
+
def __getitem__(self, idx):
|
| 612 |
+
return self.X_in[idx].squeeze(), self.y_in[idx].squeeze()
|
| 613 |
+
|
| 614 |
+
class DeepMojiBatchSampler(object):
|
| 615 |
+
"""A Batch sampler that enables larger epochs on small datasets and
|
| 616 |
+
has upsampling functionality.
|
| 617 |
+
|
| 618 |
+
# Arguments:
|
| 619 |
+
y_in: Labels of the dataset.
|
| 620 |
+
batch_size: Batch size.
|
| 621 |
+
epoch_size: Number of samples in an epoch.
|
| 622 |
+
upsample: Whether upsampling should be done. This flag should only be
|
| 623 |
+
set on binary class problems.
|
| 624 |
+
seed: Random number generator seed.
|
| 625 |
+
|
| 626 |
+
# __iter__ output:
|
| 627 |
+
iterator of lists (batches) of indices in the dataset
|
| 628 |
+
"""
|
| 629 |
+
|
| 630 |
+
def __init__(self, y_in, batch_size, epoch_size, upsample, seed):
|
| 631 |
+
self.batch_size = batch_size
|
| 632 |
+
self.epoch_size = epoch_size
|
| 633 |
+
self.upsample = upsample
|
| 634 |
+
|
| 635 |
+
np.random.seed(seed)
|
| 636 |
+
|
| 637 |
+
if upsample:
|
| 638 |
+
# Should only be used on binary class problems
|
| 639 |
+
assert len(y_in.shape) == 1
|
| 640 |
+
neg = np.where(y_in.numpy() == 0)[0]
|
| 641 |
+
pos = np.where(y_in.numpy() == 1)[0]
|
| 642 |
+
assert epoch_size % 2 == 0
|
| 643 |
+
samples_pr_class = int(epoch_size / 2)
|
| 644 |
+
else:
|
| 645 |
+
ind = range(len(y_in))
|
| 646 |
+
|
| 647 |
+
if not upsample:
|
| 648 |
+
# Randomly sample observations (with replacement) to fill the epoch
|
| 649 |
+
self.sample_ind = np.random.choice(ind, epoch_size, replace=True)
|
| 650 |
+
else:
|
| 651 |
+
# Randomly sample observations in a balanced way
|
| 652 |
+
sample_neg = np.random.choice(neg, samples_pr_class, replace=True)
|
| 653 |
+
sample_pos = np.random.choice(pos, samples_pr_class, replace=True)
|
| 654 |
+
concat_ind = np.concatenate((sample_neg, sample_pos), axis=0)
|
| 655 |
+
|
| 656 |
+
# Shuffle to avoid labels being in specific order
|
| 657 |
+
# (all negative then positive)
|
| 658 |
+
p = np.random.permutation(len(concat_ind))
|
| 659 |
+
self.sample_ind = concat_ind[p]
|
| 660 |
+
|
| 661 |
+
label_dist = np.mean(y_in.numpy()[self.sample_ind])
|
| 662 |
+
assert(label_dist > 0.45)
|
| 663 |
+
assert(label_dist < 0.55)
|
| 664 |
+
|
| 665 |
+
def __iter__(self):
|
| 666 |
+
# Hand-off data using batch_size
|
| 667 |
+
for i in range(int(self.epoch_size/self.batch_size)):
|
| 668 |
+
start = i * self.batch_size
|
| 669 |
+
end = min(start + self.batch_size, self.epoch_size)
|
| 670 |
+
yield self.sample_ind[start:end]
|
| 671 |
+
|
| 672 |
+
def __len__(self):
|
| 673 |
+
# Take care of the last (maybe incomplete) batch
|
| 674 |
+
return (self.epoch_size + self.batch_size - 1) // self.batch_size
|
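For orientation, here is a minimal sketch of how the finetuning entry point above is typically driven end to end. Only the function and constant names come from the files in this commit; the toy sentences, binary labels, 30-token maximum length and the tiny batch/epoch sizes are illustrative assumptions.

# Minimal usage sketch (toy data and hyperparameters are assumptions).
import json
import numpy as np

from torchmoji.finetuning import finetune
from torchmoji.model_def import torchmoji_transfer
from torchmoji.sentence_tokenizer import SentenceTokenizer
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH

with open(VOCAB_PATH, 'r') as f:
    vocabulary = json.load(f)
st = SentenceTokenizer(vocabulary, fixed_length=30)

# Three splits (train, val, test): tokenized inputs plus integer labels.
splits = (['I love this', 'this is terrible'],
          ['pretty good stuff'],
          ['not my favourite'])
texts = [st.tokenize_sentences(s)[0] for s in splits]
labels = [np.array([1, 0]), np.array([1]), np.array([0])]

nb_classes = 2
model = torchmoji_transfer(nb_classes, PRETRAINED_PATH)
# method can be 'last', 'full', 'new' or 'chain-thaw' (see FINETUNING_METHODS);
# 'chain-thaw' runs the layer-by-layer schedule implemented above.
model, score = finetune(model, texts, labels, nb_classes,
                        batch_size=2, method='last', metric='acc',
                        epoch_size=4, nb_epochs=1)
print('test score:', score)
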
torchmoji/global_variables.py
ADDED
|
@@ -0,0 +1,28 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
""" Global variables.
|
| 3 |
+
"""
|
| 4 |
+
import tempfile
|
| 5 |
+
from os.path import abspath, dirname
|
| 6 |
+
|
| 7 |
+
# The ordering of these special tokens matter
|
| 8 |
+
# blank tokens can be used for new purposes
|
| 9 |
+
# Tokenizer should be updated if special token prefix is changed
|
| 10 |
+
SPECIAL_PREFIX = 'CUSTOM_'
|
| 11 |
+
SPECIAL_TOKENS = ['CUSTOM_MASK',
|
| 12 |
+
'CUSTOM_UNKNOWN',
|
| 13 |
+
'CUSTOM_AT',
|
| 14 |
+
'CUSTOM_URL',
|
| 15 |
+
'CUSTOM_NUMBER',
|
| 16 |
+
'CUSTOM_BREAK']
|
| 17 |
+
SPECIAL_TOKENS.extend(['{}BLANK_{}'.format(SPECIAL_PREFIX, i) for i in range(6, 10)])
|
| 18 |
+
|
| 19 |
+
ROOT_PATH = dirname(dirname(abspath(__file__)))
|
| 20 |
+
VOCAB_PATH = '{}/model/vocabulary.json'.format(ROOT_PATH)
|
| 21 |
+
PRETRAINED_PATH = '{}/model/pytorch_model.bin'.format(ROOT_PATH)
|
| 22 |
+
|
| 23 |
+
WEIGHTS_DIR = tempfile.mkdtemp()
|
| 24 |
+
|
| 25 |
+
NB_TOKENS = 50000
|
| 26 |
+
NB_EMOJI_CLASSES = 64
|
| 27 |
+
FINETUNING_METHODS = ['last', 'full', 'new', 'chain-thaw']
|
| 28 |
+
FINETUNING_METRICS = ['acc', 'weighted']
|
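A small sketch of inspecting these constants. The claim that the special tokens sit at the lowest vocabulary indices is an assumption about the bundled vocabulary, consistent with SentenceTokenizer's defaults (masking_value=0, unknown_value=1) and its len(SPECIAL_TOKENS) check shown later in this commit.

from torchmoji.global_variables import (SPECIAL_TOKENS, NB_TOKENS,
                                        FINETUNING_METHODS, FINETUNING_METRICS)

# The ten special tokens are expected to occupy the lowest vocabulary indices.
for index, token in enumerate(SPECIAL_TOKENS):
    print(index, token)

print('regular tokens:', NB_TOKENS - len(SPECIAL_TOKENS))
print('finetuning methods:', FINETUNING_METHODS)
print('finetuning metrics:', FINETUNING_METRICS)
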
torchmoji/lstm.py
ADDED
|
@@ -0,0 +1,357 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
""" Implement a pyTorch LSTM with hard sigmoid reccurent activation functions.
|
| 3 |
+
Adapted from the non-cuda variant of pyTorch LSTM at
|
| 4 |
+
https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import print_function, division
|
| 8 |
+
import math
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from torch.nn import Module
|
| 12 |
+
from torch.nn.parameter import Parameter
|
| 13 |
+
from torch.nn.utils.rnn import PackedSequence
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
class LSTMHardSigmoid(Module):
|
| 17 |
+
|
| 18 |
+
def __init__(self, input_size, hidden_size,
|
| 19 |
+
num_layers=1, bias=True, batch_first=False,
|
| 20 |
+
dropout=0, bidirectional=False):
|
| 21 |
+
super(LSTMHardSigmoid, self).__init__()
|
| 22 |
+
self.input_size = input_size
|
| 23 |
+
self.hidden_size = hidden_size
|
| 24 |
+
self.num_layers = num_layers
|
| 25 |
+
self.bias = bias
|
| 26 |
+
self.batch_first = batch_first
|
| 27 |
+
self.dropout = dropout
|
| 28 |
+
self.dropout_state = {}
|
| 29 |
+
self.bidirectional = bidirectional
|
| 30 |
+
num_directions = 2 if bidirectional else 1
|
| 31 |
+
|
| 32 |
+
gate_size = 4 * hidden_size
|
| 33 |
+
|
| 34 |
+
self._all_weights = []
|
| 35 |
+
for layer in range(num_layers):
|
| 36 |
+
for direction in range(num_directions):
|
| 37 |
+
layer_input_size = input_size if layer == 0 else hidden_size * num_directions
|
| 38 |
+
|
| 39 |
+
w_ih = Parameter(torch.Tensor(gate_size, layer_input_size))
|
| 40 |
+
w_hh = Parameter(torch.Tensor(gate_size, hidden_size))
|
| 41 |
+
b_ih = Parameter(torch.Tensor(gate_size))
|
| 42 |
+
b_hh = Parameter(torch.Tensor(gate_size))
|
| 43 |
+
layer_params = (w_ih, w_hh, b_ih, b_hh)
|
| 44 |
+
|
| 45 |
+
suffix = '_reverse' if direction == 1 else ''
|
| 46 |
+
param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
|
| 47 |
+
if bias:
|
| 48 |
+
param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
|
| 49 |
+
param_names = [x.format(layer, suffix) for x in param_names]
|
| 50 |
+
|
| 51 |
+
for name, param in zip(param_names, layer_params):
|
| 52 |
+
setattr(self, name, param)
|
| 53 |
+
self._all_weights.append(param_names)
|
| 54 |
+
|
| 55 |
+
self.flatten_parameters()
|
| 56 |
+
self.reset_parameters()
|
| 57 |
+
|
| 58 |
+
def flatten_parameters(self):
|
| 59 |
+
"""Resets parameter data pointer so that they can use faster code paths.
|
| 60 |
+
|
| 61 |
+
Right now, this is a no-op since we don't use CUDA acceleration.
|
| 62 |
+
"""
|
| 63 |
+
self._data_ptrs = []
|
| 64 |
+
|
| 65 |
+
def _apply(self, fn):
|
| 66 |
+
ret = super(LSTMHardSigmoid, self)._apply(fn)
|
| 67 |
+
self.flatten_parameters()
|
| 68 |
+
return ret
|
| 69 |
+
|
| 70 |
+
def reset_parameters(self):
|
| 71 |
+
stdv = 1.0 / math.sqrt(self.hidden_size)
|
| 72 |
+
for weight in self.parameters():
|
| 73 |
+
weight.data.uniform_(-stdv, stdv)
|
| 74 |
+
|
| 75 |
+
def forward(self, input, hx=None):
|
| 76 |
+
is_packed = isinstance(input, PackedSequence)
|
| 77 |
+
if is_packed:
|
| 78 |
+
batch_sizes = input.batch_sizes
|
| 79 |
+
input = input.data
|
| 80 |
+
max_batch_size = batch_sizes[0]
|
| 81 |
+
else:
|
| 82 |
+
batch_sizes = None
|
| 83 |
+
max_batch_size = input.size(0) if self.batch_first else input.size(1)
|
| 84 |
+
|
| 85 |
+
if hx is None:
|
| 86 |
+
num_directions = 2 if self.bidirectional else 1
|
| 87 |
+
hx = torch.autograd.Variable(input.data.new(self.num_layers *
|
| 88 |
+
num_directions,
|
| 89 |
+
max_batch_size,
|
| 90 |
+
self.hidden_size).zero_(), requires_grad=False)
|
| 91 |
+
hx = (hx, hx)
|
| 92 |
+
|
| 93 |
+
has_flat_weights = list(p.data.data_ptr() for p in self.parameters()) == self._data_ptrs
|
| 94 |
+
if has_flat_weights:
|
| 95 |
+
first_data = next(self.parameters()).data
|
| 96 |
+
assert first_data.storage().size() == self._param_buf_size
|
| 97 |
+
flat_weight = first_data.new().set_(first_data.storage(), 0, torch.Size([self._param_buf_size]))
|
| 98 |
+
else:
|
| 99 |
+
flat_weight = None
|
| 100 |
+
func = AutogradRNN(
|
| 101 |
+
self.input_size,
|
| 102 |
+
self.hidden_size,
|
| 103 |
+
num_layers=self.num_layers,
|
| 104 |
+
batch_first=self.batch_first,
|
| 105 |
+
dropout=self.dropout,
|
| 106 |
+
train=self.training,
|
| 107 |
+
bidirectional=self.bidirectional,
|
| 108 |
+
batch_sizes=batch_sizes,
|
| 109 |
+
dropout_state=self.dropout_state,
|
| 110 |
+
flat_weight=flat_weight
|
| 111 |
+
)
|
| 112 |
+
output, hidden = func(input, self.all_weights, hx)
|
| 113 |
+
if is_packed:
|
| 114 |
+
output = PackedSequence(output, batch_sizes)
|
| 115 |
+
return output, hidden
|
| 116 |
+
|
| 117 |
+
def __repr__(self):
|
| 118 |
+
s = '{name}({input_size}, {hidden_size}'
|
| 119 |
+
if self.num_layers != 1:
|
| 120 |
+
s += ', num_layers={num_layers}'
|
| 121 |
+
if self.bias is not True:
|
| 122 |
+
s += ', bias={bias}'
|
| 123 |
+
if self.batch_first is not False:
|
| 124 |
+
s += ', batch_first={batch_first}'
|
| 125 |
+
if self.dropout != 0:
|
| 126 |
+
s += ', dropout={dropout}'
|
| 127 |
+
if self.bidirectional is not False:
|
| 128 |
+
s += ', bidirectional={bidirectional}'
|
| 129 |
+
s += ')'
|
| 130 |
+
return s.format(name=self.__class__.__name__, **self.__dict__)
|
| 131 |
+
|
| 132 |
+
def __setstate__(self, d):
|
| 133 |
+
super(LSTMHardSigmoid, self).__setstate__(d)
|
| 134 |
+
self.__dict__.setdefault('_data_ptrs', [])
|
| 135 |
+
if 'all_weights' in d:
|
| 136 |
+
self._all_weights = d['all_weights']
|
| 137 |
+
if isinstance(self._all_weights[0][0], str):
|
| 138 |
+
return
|
| 139 |
+
num_layers = self.num_layers
|
| 140 |
+
num_directions = 2 if self.bidirectional else 1
|
| 141 |
+
self._all_weights = []
|
| 142 |
+
for layer in range(num_layers):
|
| 143 |
+
for direction in range(num_directions):
|
| 144 |
+
suffix = '_reverse' if direction == 1 else ''
|
| 145 |
+
weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
|
| 146 |
+
weights = [x.format(layer, suffix) for x in weights]
|
| 147 |
+
if self.bias:
|
| 148 |
+
self._all_weights += [weights]
|
| 149 |
+
else:
|
| 150 |
+
self._all_weights += [weights[:2]]
|
| 151 |
+
|
| 152 |
+
@property
|
| 153 |
+
def all_weights(self):
|
| 154 |
+
return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]
|
| 155 |
+
|
| 156 |
+
def AutogradRNN(input_size, hidden_size, num_layers=1, batch_first=False,
|
| 157 |
+
dropout=0, train=True, bidirectional=False, batch_sizes=None,
|
| 158 |
+
dropout_state=None, flat_weight=None):
|
| 159 |
+
|
| 160 |
+
cell = LSTMCell
|
| 161 |
+
|
| 162 |
+
if batch_sizes is None:
|
| 163 |
+
rec_factory = Recurrent
|
| 164 |
+
else:
|
| 165 |
+
rec_factory = variable_recurrent_factory(batch_sizes)
|
| 166 |
+
|
| 167 |
+
if bidirectional:
|
| 168 |
+
layer = (rec_factory(cell), rec_factory(cell, reverse=True))
|
| 169 |
+
else:
|
| 170 |
+
layer = (rec_factory(cell),)
|
| 171 |
+
|
| 172 |
+
func = StackedRNN(layer,
|
| 173 |
+
num_layers,
|
| 174 |
+
True,
|
| 175 |
+
dropout=dropout,
|
| 176 |
+
train=train)
|
| 177 |
+
|
| 178 |
+
def forward(input, weight, hidden):
|
| 179 |
+
if batch_first and batch_sizes is None:
|
| 180 |
+
input = input.transpose(0, 1)
|
| 181 |
+
|
| 182 |
+
nexth, output = func(input, hidden, weight)
|
| 183 |
+
|
| 184 |
+
if batch_first and batch_sizes is None:
|
| 185 |
+
output = output.transpose(0, 1)
|
| 186 |
+
|
| 187 |
+
return output, nexth
|
| 188 |
+
|
| 189 |
+
return forward
|
| 190 |
+
|
| 191 |
+
def Recurrent(inner, reverse=False):
|
| 192 |
+
def forward(input, hidden, weight):
|
| 193 |
+
output = []
|
| 194 |
+
steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0))
|
| 195 |
+
for i in steps:
|
| 196 |
+
hidden = inner(input[i], hidden, *weight)
|
| 197 |
+
# hack to handle LSTM
|
| 198 |
+
output.append(hidden[0] if isinstance(hidden, tuple) else hidden)
|
| 199 |
+
|
| 200 |
+
if reverse:
|
| 201 |
+
output.reverse()
|
| 202 |
+
output = torch.cat(output, 0).view(input.size(0), *output[0].size())
|
| 203 |
+
|
| 204 |
+
return hidden, output
|
| 205 |
+
|
| 206 |
+
return forward
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def variable_recurrent_factory(batch_sizes):
|
| 210 |
+
def fac(inner, reverse=False):
|
| 211 |
+
if reverse:
|
| 212 |
+
return VariableRecurrentReverse(batch_sizes, inner)
|
| 213 |
+
else:
|
| 214 |
+
return VariableRecurrent(batch_sizes, inner)
|
| 215 |
+
return fac
|
| 216 |
+
|
| 217 |
+
def VariableRecurrent(batch_sizes, inner):
|
| 218 |
+
def forward(input, hidden, weight):
|
| 219 |
+
output = []
|
| 220 |
+
input_offset = 0
|
| 221 |
+
last_batch_size = batch_sizes[0]
|
| 222 |
+
hiddens = []
|
| 223 |
+
flat_hidden = not isinstance(hidden, tuple)
|
| 224 |
+
if flat_hidden:
|
| 225 |
+
hidden = (hidden,)
|
| 226 |
+
for batch_size in batch_sizes:
|
| 227 |
+
step_input = input[input_offset:input_offset + batch_size]
|
| 228 |
+
input_offset += batch_size
|
| 229 |
+
|
| 230 |
+
dec = last_batch_size - batch_size
|
| 231 |
+
if dec > 0:
|
| 232 |
+
hiddens.append(tuple(h[-dec:] for h in hidden))
|
| 233 |
+
hidden = tuple(h[:-dec] for h in hidden)
|
| 234 |
+
last_batch_size = batch_size
|
| 235 |
+
|
| 236 |
+
if flat_hidden:
|
| 237 |
+
hidden = (inner(step_input, hidden[0], *weight),)
|
| 238 |
+
else:
|
| 239 |
+
hidden = inner(step_input, hidden, *weight)
|
| 240 |
+
|
| 241 |
+
output.append(hidden[0])
|
| 242 |
+
hiddens.append(hidden)
|
| 243 |
+
hiddens.reverse()
|
| 244 |
+
|
| 245 |
+
hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens))
|
| 246 |
+
assert hidden[0].size(0) == batch_sizes[0]
|
| 247 |
+
if flat_hidden:
|
| 248 |
+
hidden = hidden[0]
|
| 249 |
+
output = torch.cat(output, 0)
|
| 250 |
+
|
| 251 |
+
return hidden, output
|
| 252 |
+
|
| 253 |
+
return forward
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def VariableRecurrentReverse(batch_sizes, inner):
|
| 257 |
+
def forward(input, hidden, weight):
|
| 258 |
+
output = []
|
| 259 |
+
input_offset = input.size(0)
|
| 260 |
+
last_batch_size = batch_sizes[-1]
|
| 261 |
+
initial_hidden = hidden
|
| 262 |
+
flat_hidden = not isinstance(hidden, tuple)
|
| 263 |
+
if flat_hidden:
|
| 264 |
+
hidden = (hidden,)
|
| 265 |
+
initial_hidden = (initial_hidden,)
|
| 266 |
+
hidden = tuple(h[:batch_sizes[-1]] for h in hidden)
|
| 267 |
+
for batch_size in reversed(batch_sizes):
|
| 268 |
+
inc = batch_size - last_batch_size
|
| 269 |
+
if inc > 0:
|
| 270 |
+
hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0)
|
| 271 |
+
for h, ih in zip(hidden, initial_hidden))
|
| 272 |
+
last_batch_size = batch_size
|
| 273 |
+
step_input = input[input_offset - batch_size:input_offset]
|
| 274 |
+
input_offset -= batch_size
|
| 275 |
+
|
| 276 |
+
if flat_hidden:
|
| 277 |
+
hidden = (inner(step_input, hidden[0], *weight),)
|
| 278 |
+
else:
|
| 279 |
+
hidden = inner(step_input, hidden, *weight)
|
| 280 |
+
output.append(hidden[0])
|
| 281 |
+
|
| 282 |
+
output.reverse()
|
| 283 |
+
output = torch.cat(output, 0)
|
| 284 |
+
if flat_hidden:
|
| 285 |
+
hidden = hidden[0]
|
| 286 |
+
return hidden, output
|
| 287 |
+
|
| 288 |
+
return forward
|
| 289 |
+
|
| 290 |
+
def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):
|
| 291 |
+
|
| 292 |
+
num_directions = len(inners)
|
| 293 |
+
total_layers = num_layers * num_directions
|
| 294 |
+
|
| 295 |
+
def forward(input, hidden, weight):
|
| 296 |
+
assert(len(weight) == total_layers)
|
| 297 |
+
next_hidden = []
|
| 298 |
+
|
| 299 |
+
if lstm:
|
| 300 |
+
hidden = list(zip(*hidden))
|
| 301 |
+
|
| 302 |
+
for i in range(num_layers):
|
| 303 |
+
all_output = []
|
| 304 |
+
for j, inner in enumerate(inners):
|
| 305 |
+
l = i * num_directions + j
|
| 306 |
+
|
| 307 |
+
hy, output = inner(input, hidden[l], weight[l])
|
| 308 |
+
next_hidden.append(hy)
|
| 309 |
+
all_output.append(output)
|
| 310 |
+
|
| 311 |
+
input = torch.cat(all_output, input.dim() - 1)
|
| 312 |
+
|
| 313 |
+
if dropout != 0 and i < num_layers - 1:
|
| 314 |
+
input = F.dropout(input, p=dropout, training=train, inplace=False)
|
| 315 |
+
|
| 316 |
+
if lstm:
|
| 317 |
+
next_h, next_c = zip(*next_hidden)
|
| 318 |
+
next_hidden = (
|
| 319 |
+
torch.cat(next_h, 0).view(total_layers, *next_h[0].size()),
|
| 320 |
+
torch.cat(next_c, 0).view(total_layers, *next_c[0].size())
|
| 321 |
+
)
|
| 322 |
+
else:
|
| 323 |
+
next_hidden = torch.cat(next_hidden, 0).view(
|
| 324 |
+
total_layers, *next_hidden[0].size())
|
| 325 |
+
|
| 326 |
+
return next_hidden, input
|
| 327 |
+
|
| 328 |
+
return forward
|
| 329 |
+
|
| 330 |
+
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
|
| 331 |
+
"""
|
| 332 |
+
A modified LSTM cell with hard sigmoid activation on the input, forget and output gates.
|
| 333 |
+
"""
|
| 334 |
+
hx, cx = hidden
|
| 335 |
+
gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
|
| 336 |
+
|
| 337 |
+
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
|
| 338 |
+
|
| 339 |
+
ingate = hard_sigmoid(ingate)
|
| 340 |
+
forgetgate = hard_sigmoid(forgetgate)
|
| 341 |
+
cellgate = torch.tanh(cellgate)
|
| 342 |
+
outgate = hard_sigmoid(outgate)
|
| 343 |
+
|
| 344 |
+
cy = (forgetgate * cx) + (ingate * cellgate)
|
| 345 |
+
hy = outgate * torch.tanh(cy)
|
| 346 |
+
|
| 347 |
+
return hy, cy
|
| 348 |
+
|
| 349 |
+
def hard_sigmoid(x):
|
| 350 |
+
"""
|
| 351 |
+
Computes element-wise hard sigmoid of x.
|
| 352 |
+
See e.g. https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279
|
| 353 |
+
"""
|
| 354 |
+
x = (0.2 * x) + 0.5
|
| 355 |
+
x = F.threshold(-x, -1, -1)
|
| 356 |
+
x = F.threshold(-x, 0, 0)
|
| 357 |
+
return x
|
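As a quick sanity check on the module above: hard_sigmoid is just 0.2 * x + 0.5 clamped to [0, 1], and LSTMHardSigmoid follows the nn.LSTM calling convention. A minimal sketch; the tensor shapes and sizes below are arbitrary illustration values.

import torch
from torch.autograd import Variable
from torchmoji.lstm import LSTMHardSigmoid, hard_sigmoid

# hard_sigmoid clamps 0.2 * x + 0.5 into [0, 1].
x = torch.FloatTensor([-10.0, -2.5, 0.0, 2.5, 10.0])
print(hard_sigmoid(x))  # -> 0.0, 0.0, 0.5, 1.0, 1.0

# Same interface as nn.LSTM: input of shape (seq_len, batch, input_size),
# output of shape (seq_len, batch, num_directions * hidden_size),
# plus the (h_n, c_n) state tuple.
lstm = LSTMHardSigmoid(input_size=8, hidden_size=16, bidirectional=True)
inputs = Variable(torch.randn(5, 3, 8))
output, (h_n, c_n) = lstm(inputs)
print(output.size())  # torch.Size([5, 3, 32])
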
torchmoji/model_def.py
ADDED
|
@@ -0,0 +1,311 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
""" Model definition functions and weight loading.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import print_function, division, unicode_literals
|
| 6 |
+
|
| 7 |
+
from os.path import exists
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.autograd import Variable
|
| 12 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence
|
| 13 |
+
|
| 14 |
+
from torchmoji.lstm import LSTMHardSigmoid
|
| 15 |
+
from torchmoji.attlayer import Attention
|
| 16 |
+
from torchmoji.global_variables import NB_TOKENS, NB_EMOJI_CLASSES
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def torchmoji_feature_encoding(weight_path, return_attention=False):
|
| 20 |
+
""" Loads the pretrained torchMoji model for extracting features
|
| 21 |
+
from the penultimate feature layer. In this way, it transforms
|
| 22 |
+
the text into its emotional encoding.
|
| 23 |
+
|
| 24 |
+
# Arguments:
|
| 25 |
+
weight_path: Path to model weights to be loaded.
|
| 26 |
+
return_attention: If true, output will include weight of each input token
|
| 27 |
+
used for the prediction
|
| 28 |
+
|
| 29 |
+
# Returns:
|
| 30 |
+
Pretrained model for encoding text into feature vectors.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
model = TorchMoji(nb_classes=None,
|
| 34 |
+
nb_tokens=NB_TOKENS,
|
| 35 |
+
feature_output=True,
|
| 36 |
+
return_attention=return_attention)
|
| 37 |
+
load_specific_weights(model, weight_path, exclude_names=['output_layer'])
|
| 38 |
+
return model
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def torchmoji_emojis(weight_path, return_attention=False):
|
| 42 |
+
""" Loads the pretrained torchMoji model for extracting features
|
| 43 |
+
from the penultimate feature layer. In this way, it transforms
|
| 44 |
+
the text into its emotional encoding.
|
| 45 |
+
|
| 46 |
+
# Arguments:
|
| 47 |
+
weight_path: Path to model weights to be loaded.
|
| 48 |
+
return_attention: If true, output will include weight of each input token
|
| 49 |
+
used for the prediction
|
| 50 |
+
|
| 51 |
+
# Returns:
|
| 52 |
+
Pretrained model for predicting emoji probabilities from text.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
model = TorchMoji(nb_classes=NB_EMOJI_CLASSES,
|
| 56 |
+
nb_tokens=NB_TOKENS,
|
| 57 |
+
return_attention=return_attention)
|
| 58 |
+
model.load_state_dict(torch.load(weight_path))
|
| 59 |
+
return model
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def torchmoji_transfer(nb_classes, weight_path=None, extend_embedding=0,
|
| 63 |
+
embed_dropout_rate=0.1, final_dropout_rate=0.5):
|
| 64 |
+
""" Loads the pretrained torchMoji model for finetuning/transfer learning.
|
| 65 |
+
Does not load weights for the softmax layer.
|
| 66 |
+
|
| 67 |
+
Note that if you are planning to use class average F1 for evaluation,
|
| 68 |
+
nb_classes should be set to 2 instead of the actual number of classes
|
| 69 |
+
in the dataset, since binary classification will be performed on each
|
| 70 |
+
class individually.
|
| 71 |
+
|
| 72 |
+
Note that for the 'new' method, weight_path should be left as None.
|
| 73 |
+
|
| 74 |
+
# Arguments:
|
| 75 |
+
nb_classes: Number of classes in the dataset.
|
| 76 |
+
weight_path: Path to model weights to be loaded.
|
| 77 |
+
extend_embedding: Number of tokens that have been added to the
|
| 78 |
+
vocabulary on top of NB_TOKENS. If this number is larger than 0,
|
| 79 |
+
the embedding layer's dimensions are adjusted accordingly, with the
|
| 80 |
+
additional weights being set to random values.
|
| 81 |
+
embed_dropout_rate: Dropout rate for the embedding layer.
|
| 82 |
+
final_dropout_rate: Dropout rate for the final Softmax layer.
|
| 83 |
+
|
| 84 |
+
# Returns:
|
| 85 |
+
Model with the given parameters.
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
model = TorchMoji(nb_classes=nb_classes,
|
| 89 |
+
nb_tokens=NB_TOKENS + extend_embedding,
|
| 90 |
+
embed_dropout_rate=embed_dropout_rate,
|
| 91 |
+
final_dropout_rate=final_dropout_rate,
|
| 92 |
+
output_logits=True)
|
| 93 |
+
if weight_path is not None:
|
| 94 |
+
load_specific_weights(model, weight_path,
|
| 95 |
+
exclude_names=['output_layer'],
|
| 96 |
+
extend_embedding=extend_embedding)
|
| 97 |
+
return model
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class TorchMoji(nn.Module):
|
| 101 |
+
def __init__(self, nb_classes, nb_tokens, feature_output=False, output_logits=False,
|
| 102 |
+
embed_dropout_rate=0, final_dropout_rate=0, return_attention=False):
|
| 103 |
+
"""
|
| 104 |
+
torchMoji model.
|
| 105 |
+
IMPORTANT: The model is loaded in evaluation mode by default (self.eval())
|
| 106 |
+
|
| 107 |
+
# Arguments:
|
| 108 |
+
nb_classes: Number of classes in the dataset.
|
| 109 |
+
nb_tokens: Number of tokens in the dataset (i.e. vocabulary size).
|
| 110 |
+
feature_output: If True the model returns the penultimate
|
| 111 |
+
feature vector rather than Softmax probabilities
|
| 112 |
+
(defaults to False).
|
| 113 |
+
output_logits: If True the model returns logits rather than probabilities
|
| 114 |
+
(defaults to False).
|
| 115 |
+
embed_dropout_rate: Dropout rate for the embedding layer.
|
| 116 |
+
final_dropout_rate: Dropout rate for the final Softmax layer.
|
| 117 |
+
return_attention: If True the model also returns attention weights over the sentence
|
| 118 |
+
(defaults to False).
|
| 119 |
+
"""
|
| 120 |
+
super(TorchMoji, self).__init__()
|
| 121 |
+
|
| 122 |
+
embedding_dim = 256
|
| 123 |
+
hidden_size = 512
|
| 124 |
+
attention_size = 4 * hidden_size + embedding_dim
|
| 125 |
+
|
| 126 |
+
self.feature_output = feature_output
|
| 127 |
+
self.embed_dropout_rate = embed_dropout_rate
|
| 128 |
+
self.final_dropout_rate = final_dropout_rate
|
| 129 |
+
self.return_attention = return_attention
|
| 130 |
+
self.hidden_size = hidden_size
|
| 131 |
+
self.output_logits = output_logits
|
| 132 |
+
self.nb_classes = nb_classes
|
| 133 |
+
|
| 134 |
+
self.add_module('embed', nn.Embedding(nb_tokens, embedding_dim))
|
| 135 |
+
# dropout2D: embedding channels are dropped out instead of words
|
| 136 |
+
# many examples in the datasets contain so few words that losing one or more of them can alter the expressed emotion completely
|
| 137 |
+
self.add_module('embed_dropout', nn.Dropout2d(embed_dropout_rate))
|
| 138 |
+
self.add_module('lstm_0', LSTMHardSigmoid(embedding_dim, hidden_size, batch_first=True, bidirectional=True))
|
| 139 |
+
self.add_module('lstm_1', LSTMHardSigmoid(hidden_size*2, hidden_size, batch_first=True, bidirectional=True))
|
| 140 |
+
self.add_module('attention_layer', Attention(attention_size=attention_size, return_attention=return_attention))
|
| 141 |
+
if not feature_output:
|
| 142 |
+
self.add_module('final_dropout', nn.Dropout(final_dropout_rate))
|
| 143 |
+
if output_logits:
|
| 144 |
+
self.add_module('output_layer', nn.Sequential(nn.Linear(attention_size, nb_classes if self.nb_classes > 2 else 1)))
|
| 145 |
+
else:
|
| 146 |
+
self.add_module('output_layer', nn.Sequential(nn.Linear(attention_size, nb_classes if self.nb_classes > 2 else 1),
|
| 147 |
+
nn.Softmax(dim=1) if self.nb_classes > 2 else nn.Sigmoid()))
|
| 148 |
+
self.init_weights()
|
| 149 |
+
# Put model in evaluation mode by default
|
| 150 |
+
self.eval()
|
| 151 |
+
|
| 152 |
+
def init_weights(self):
|
| 153 |
+
"""
|
| 154 |
+
Here we reproduce Keras default initialization weights for consistency with Keras version
|
| 155 |
+
"""
|
| 156 |
+
ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
|
| 157 |
+
hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
|
| 158 |
+
b = (param.data for name, param in self.named_parameters() if 'bias' in name)
|
| 159 |
+
nn.init.uniform_(self.embed.weight.data, a=-0.5, b=0.5)
|
| 160 |
+
for t in ih:
|
| 161 |
+
nn.init.xavier_uniform_(t)
|
| 162 |
+
for t in hh:
|
| 163 |
+
nn.init.orthogonal_(t)
|
| 164 |
+
for t in b:
|
| 165 |
+
nn.init.constant_(t, 0)
|
| 166 |
+
if not self.feature_output:
|
| 167 |
+
nn.init.xavier_uniform_(self.output_layer[0].weight.data)
|
| 168 |
+
|
| 169 |
+
def forward(self, input_seqs):
|
| 170 |
+
""" Forward pass.
|
| 171 |
+
|
| 172 |
+
# Arguments:
|
| 173 |
+
input_seqs: Can be one of Numpy array, Torch.LongTensor, Torch.Variable, Torch.PackedSequence.
|
| 174 |
+
|
| 175 |
+
# Return:
|
| 176 |
+
Same format as input format (except for PackedSequence returned as Variable).
|
| 177 |
+
"""
|
| 178 |
+
# Check whether the input is a torch.LongTensor or a Variable (otherwise assume a Numpy array) and remember the format so the output can match it
|
| 179 |
+
return_numpy = False
|
| 180 |
+
if isinstance(input_seqs, (torch.LongTensor, torch.cuda.LongTensor)):
|
| 181 |
+
input_seqs = Variable(input_seqs)
|
| 182 |
+
elif not isinstance(input_seqs, Variable):
|
| 183 |
+
input_seqs = Variable(torch.from_numpy(input_seqs.astype('int64')).long())
|
| 184 |
+
return_numpy = True
|
| 185 |
+
|
| 186 |
+
# If we don't have a packed input, pack it now
|
| 187 |
+
reorder_output = False
|
| 188 |
+
if not isinstance(input_seqs, PackedSequence):
|
| 189 |
+
ho = self.lstm_0.weight_hh_l0.data.new(2, input_seqs.size()[0], self.hidden_size).zero_()
|
| 190 |
+
co = self.lstm_0.weight_hh_l0.data.new(2, input_seqs.size()[0], self.hidden_size).zero_()
|
| 191 |
+
|
| 192 |
+
# Reorder batch by sequence length
|
| 193 |
+
input_lengths = torch.LongTensor([torch.max(input_seqs[i, :].data.nonzero()) + 1 for i in range(input_seqs.size()[0])])
|
| 194 |
+
input_lengths, perm_idx = input_lengths.sort(0, descending=True)
|
| 195 |
+
input_seqs = input_seqs[perm_idx][:, :input_lengths.max()]
|
| 196 |
+
|
| 197 |
+
# Pack sequence and work on data tensor to reduce embeddings/dropout computations
|
| 198 |
+
packed_input = pack_padded_sequence(input_seqs, input_lengths.cpu().numpy(), batch_first=True)
|
| 199 |
+
reorder_output = True
|
| 200 |
+
else:
|
| 201 |
+
ho = self.lstm_0.weight_hh_l0.data.data.new(2, input_seqs.size()[0], self.hidden_size).zero_()
|
| 202 |
+
co = self.lstm_0.weight_hh_l0.data.data.new(2, input_seqs.size()[0], self.hidden_size).zero_()
|
| 203 |
+
input_lengths = input_seqs.batch_sizes
|
| 204 |
+
packed_input = input_seqs
|
| 205 |
+
|
| 206 |
+
hidden = (Variable(ho, requires_grad=False), Variable(co, requires_grad=False))
|
| 207 |
+
|
| 208 |
+
# Embed with an activation function to bound the values of the embeddings
|
| 209 |
+
x = self.embed(packed_input.data)
|
| 210 |
+
x = nn.Tanh()(x)
|
| 211 |
+
|
| 212 |
+
# pyTorch 2D dropout2d operate on axis 1 which is fine for us
|
| 213 |
+
x = self.embed_dropout(x)
|
| 214 |
+
|
| 215 |
+
# Update packed sequence data for RNN
|
| 216 |
+
packed_input = PackedSequence(x, packed_input.batch_sizes)
|
| 217 |
+
|
| 218 |
+
# skip-connection from embedding to output eases gradient-flow and allows access to lower-level features
|
| 219 |
+
# ordering of the way the merge is done is important for consistency with the pretrained model
|
| 220 |
+
lstm_0_output, _ = self.lstm_0(packed_input, hidden)
|
| 221 |
+
lstm_1_output, _ = self.lstm_1(lstm_0_output, hidden)
|
| 222 |
+
|
| 223 |
+
# Update packed sequence data for attention layer
|
| 224 |
+
packed_input = PackedSequence(torch.cat((lstm_1_output.data,
|
| 225 |
+
lstm_0_output.data,
|
| 226 |
+
packed_input.data), dim=1),
|
| 227 |
+
packed_input.batch_sizes)
|
| 228 |
+
|
| 229 |
+
input_seqs, _ = pad_packed_sequence(packed_input, batch_first=True)
|
| 230 |
+
|
| 231 |
+
x, att_weights = self.attention_layer(input_seqs, input_lengths)
|
| 232 |
+
|
| 233 |
+
# output class probabilities or penultimate feature vector
|
| 234 |
+
if not self.feature_output:
|
| 235 |
+
x = self.final_dropout(x)
|
| 236 |
+
outputs = self.output_layer(x)
|
| 237 |
+
else:
|
| 238 |
+
outputs = x
|
| 239 |
+
|
| 240 |
+
# Reorder output if needed
|
| 241 |
+
if reorder_output:
|
| 242 |
+
reordered = Variable(outputs.data.new(outputs.size()))
|
| 243 |
+
reordered[perm_idx] = outputs
|
| 244 |
+
outputs = reordered
|
| 245 |
+
|
| 246 |
+
# Adapt return format if needed
|
| 247 |
+
if return_numpy:
|
| 248 |
+
outputs = outputs.data.numpy()
|
| 249 |
+
|
| 250 |
+
if self.return_attention:
|
| 251 |
+
return outputs, att_weights
|
| 252 |
+
else:
|
| 253 |
+
return outputs
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def load_specific_weights(model, weight_path, exclude_names=[], extend_embedding=0, verbose=True):
|
| 257 |
+
""" Loads model weights from the given file path, excluding any
|
| 258 |
+
given layers.
|
| 259 |
+
|
| 260 |
+
# Arguments:
|
| 261 |
+
model: Model whose weights should be loaded.
|
| 262 |
+
weight_path: Path to file containing model weights.
|
| 263 |
+
exclude_names: List of layer names whose weights should not be loaded.
|
| 264 |
+
extend_embedding: Number of new words being added to vocabulary.
|
| 265 |
+
verbose: Verbosity flag.
|
| 266 |
+
|
| 267 |
+
# Raises:
|
| 268 |
+
ValueError if the file at weight_path does not exist.
|
| 269 |
+
"""
|
| 270 |
+
if not exists(weight_path):
|
| 271 |
+
raise ValueError('ERROR (load_weights): The weights file at {} does '
|
| 272 |
+
'not exist. Refer to the README for instructions.'
|
| 273 |
+
.format(weight_path))
|
| 274 |
+
|
| 275 |
+
if extend_embedding and 'embed' in exclude_names:
|
| 276 |
+
raise ValueError('ERROR (load_weights): Cannot extend a vocabulary '
|
| 277 |
+
'without loading the embedding weights.')
|
| 278 |
+
|
| 279 |
+
# Copy only weights from the temporary model that are wanted
|
| 280 |
+
# for the specific task (e.g. the Softmax is often ignored)
|
| 281 |
+
weights = torch.load(weight_path)
|
| 282 |
+
for key, weight in weights.items():
|
| 283 |
+
if any(excluded in key for excluded in exclude_names):
|
| 284 |
+
if verbose:
|
| 285 |
+
print('Ignoring weights for {}'.format(key))
|
| 286 |
+
continue
|
| 287 |
+
|
| 288 |
+
try:
|
| 289 |
+
model_w = model.state_dict()[key]
|
| 290 |
+
except KeyError:
|
| 291 |
+
raise KeyError("Weights had parameters {},".format(key)
|
| 292 |
+
+ " but could not find this parameters in model.")
|
| 293 |
+
|
| 294 |
+
if verbose:
|
| 295 |
+
print('Loading weights for {}'.format(key))
|
| 296 |
+
|
| 297 |
+
# extend embedding layer to allow new randomly initialized words
|
| 298 |
+
# if requested. Otherwise, just load the weights for the layer.
|
| 299 |
+
if 'embed' in key and extend_embedding > 0:
|
| 300 |
+
weight = torch.cat((weight, model_w[NB_TOKENS:, :]), dim=0)
|
| 301 |
+
if verbose:
|
| 302 |
+
print('Extended vocabulary for embedding layer ' +
|
| 303 |
+
'from {} to {} tokens.'.format(
|
| 304 |
+
NB_TOKENS, NB_TOKENS + extend_embedding))
|
| 305 |
+
try:
|
| 306 |
+
model_w.copy_(weight)
|
| 307 |
+
except:
|
| 308 |
+
print('While copying the weights named {}, whose dimensions in the model are'
|
| 309 |
+
' {} and whose dimensions in the saved file are {}, ...'.format(
|
| 310 |
+
key, model_w.size(), weight.size()))
|
| 311 |
+
raise
|
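A minimal sketch of the two loading helpers defined above. It assumes the bundled vocabulary and weights referenced in global_variables.py are available and uses an illustrative maximum length of 30 tokens; the example sentence is made up.

import json
import numpy as np

from torchmoji.model_def import torchmoji_emojis, torchmoji_feature_encoding
from torchmoji.sentence_tokenizer import SentenceTokenizer
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH

with open(VOCAB_PATH, 'r') as f:
    vocabulary = json.load(f)
st = SentenceTokenizer(vocabulary, fixed_length=30)
tokenized, _, _ = st.tokenize_sentences(['This is the best day of my life!'])

# Emoji classifier: returns a (1, 64) array of emoji probabilities.
emoji_model = torchmoji_emojis(PRETRAINED_PATH)
probs = emoji_model(tokenized)
print(np.argsort(probs[0])[-5:])  # indices of the five most likely emojis

# Feature encoder: returns the 2304-dimensional penultimate representation
# (2 * 2 * 512 from the two BiLSTMs plus 256 from the embedding skip-connection).
encoder = torchmoji_feature_encoding(PRETRAINED_PATH)
features = encoder(tokenized)
print(features.shape)  # (1, 2304)
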
torchmoji/sentence_tokenizer.py
ADDED
|
@@ -0,0 +1,245 @@
|
# -*- coding: utf-8 -*-
'''
Provides functionality for converting a given list of tokens (words) into
numbers, according to the given vocabulary.
'''
from __future__ import print_function, division, unicode_literals

import numbers
import numpy as np

from torchmoji.create_vocab import extend_vocab, VocabBuilder
from torchmoji.word_generator import WordGenerator
from torchmoji.global_variables import SPECIAL_TOKENS

# import torch

from sklearn.model_selection import train_test_split

from copy import deepcopy

class SentenceTokenizer():
    """ Create numpy array of tokens corresponding to input sentences.
        The vocabulary can include Unicode tokens.
    """
    def __init__(self, vocabulary, fixed_length, custom_wordgen=None,
                 ignore_sentences_with_only_custom=False, masking_value=0,
                 unknown_value=1):
        """ Needs a dictionary as input for the vocabulary.
        """

        if len(vocabulary) > np.iinfo('uint16').max:
            raise ValueError('Dictionary is too big ({} tokens) for the numpy '
                             'datatypes used (max limit={}). Reduce vocabulary'
                             ' or adjust code accordingly!'
                             .format(len(vocabulary), np.iinfo('uint16').max))

        # Shouldn't be able to modify the given vocabulary
        self.vocabulary = deepcopy(vocabulary)
        self.fixed_length = fixed_length
        self.ignore_sentences_with_only_custom = ignore_sentences_with_only_custom
        self.masking_value = masking_value
        self.unknown_value = unknown_value

        # Initialized with an empty stream of sentences that must then be fed
        # to the generator at a later point for reusability.
        # A custom word generator can be used for domain-specific filtering etc
        if custom_wordgen is not None:
            assert custom_wordgen.stream is None
            self.wordgen = custom_wordgen
            self.uses_custom_wordgen = True
        else:
            self.wordgen = WordGenerator(None, allow_unicode_text=True,
                                         ignore_emojis=False,
                                         remove_variation_selectors=True,
                                         break_replacement=True)
            self.uses_custom_wordgen = False

    def tokenize_sentences(self, sentences, reset_stats=True, max_sentences=None):
        """ Converts a given list of sentences into a numpy array according to
            its vocabulary.

        # Arguments:
            sentences: List of sentences to be tokenized.
            reset_stats: Whether the word generator's stats should be reset.
            max_sentences: Maximum number of sentences. Must be set if the
                length cannot be inferred from the input.

        # Returns:
            Numpy array of the tokenized sentences with masking,
            infos,
            stats

        # Raises:
            ValueError: When the maximum length is not set and cannot be inferred.
        """

        if max_sentences is None and not hasattr(sentences, '__len__'):
            raise ValueError('Either you must provide an array with a length '
                             'attribute (e.g. a list) or specify the maximum '
                             'length yourself using `max_sentences`!')
        n_sentences = (max_sentences if max_sentences is not None
                       else len(sentences))

        if self.masking_value == 0:
            tokens = np.zeros((n_sentences, self.fixed_length), dtype='uint16')
        else:
            tokens = (np.ones((n_sentences, self.fixed_length), dtype='uint16')
                      * self.masking_value)

        if reset_stats:
            self.wordgen.reset_stats()

        # With a custom word generator info can be extracted from each
        # sentence (e.g. labels)
        infos = []

        # Return words as strings and then map them to the vocabulary
        self.wordgen.stream = sentences
        next_insert = 0
        n_ignored_unknowns = 0
        for s_words, s_info in self.wordgen:
            s_tokens = self.find_tokens(s_words)

            if (self.ignore_sentences_with_only_custom and
                np.all([True if t < len(SPECIAL_TOKENS)
                        else False for t in s_tokens])):
                n_ignored_unknowns += 1
                continue
            if len(s_tokens) > self.fixed_length:
                s_tokens = s_tokens[:self.fixed_length]
            tokens[next_insert, :len(s_tokens)] = s_tokens
            infos.append(s_info)
            next_insert += 1

        # For standard word generators all sentences should be tokenized;
        # this is not necessarily the case for custom word generators as they
        # may filter the sentences etc.
        if not self.uses_custom_wordgen and not self.ignore_sentences_with_only_custom:
            assert len(sentences) == next_insert
        else:
            # adjust based on actual tokens received
            tokens = tokens[:next_insert]
            infos = infos[:next_insert]

        return tokens, infos, self.wordgen.stats

    def find_tokens(self, words):
        assert len(words) > 0
        tokens = []
        for w in words:
            try:
                tokens.append(self.vocabulary[w])
            except KeyError:
                tokens.append(self.unknown_value)
        return tokens

    def split_train_val_test(self, sentences, info_dicts,
                             split_parameter=[0.7, 0.1, 0.2], extend_with=0):
        """ Splits given sentences into three different datasets: training,
            validation and testing.

        # Arguments:
            sentences: The sentences to be tokenized.
            info_dicts: A list of dicts that contain information about each
                sentence (e.g. a label).
            split_parameter: A parameter for deciding the splits between the
                three different datasets. If instead of being passed three
                values, three lists are passed, then these will be used to
                specify which observations belong to which dataset.
            extend_with: An optional parameter. If > 0 then this is the number
                of tokens added to the vocabulary from this dataset. The
                expanded vocab will be generated using only the training set,
                but is applied to all three sets.

        # Returns:
            List of three lists of tokenized sentences,

            List of three corresponding dictionaries with information,

            How many tokens have been added to the vocab. Make sure to extend
            the embedding layer of the model accordingly.
        """

        # If passed three lists, use those directly
        if isinstance(split_parameter, list) and \
                all(isinstance(x, list) for x in split_parameter) and \
                len(split_parameter) == 3:

            # Helper function to verify provided indices are numbers in range
            def verify_indices(inds):
                return list(filter(lambda i: isinstance(i, numbers.Number)
                                   and i < len(sentences), inds))

            ind_train = verify_indices(split_parameter[0])
            ind_val = verify_indices(split_parameter[1])
            ind_test = verify_indices(split_parameter[2])
        else:
            # Split sentences and dicts
            ind = list(range(len(sentences)))
            ind_train, ind_test = train_test_split(ind, test_size=split_parameter[2])
            ind_train, ind_val = train_test_split(ind_train, test_size=split_parameter[1])

        # Map indices to data
        train = np.array([sentences[x] for x in ind_train])
        test = np.array([sentences[x] for x in ind_test])
        val = np.array([sentences[x] for x in ind_val])

        info_train = np.array([info_dicts[x] for x in ind_train])
        info_test = np.array([info_dicts[x] for x in ind_test])
        info_val = np.array([info_dicts[x] for x in ind_val])

        added = 0
        # Extend vocabulary with training set tokens
        if extend_with > 0:
            wg = WordGenerator(train)
            vb = VocabBuilder(wg)
            vb.count_all_words()
            added = extend_vocab(self.vocabulary, vb, max_tokens=extend_with)

        # Wrap results
        result = [self.tokenize_sentences(s)[0] for s in [train, val, test]]
        result_infos = [info_train, info_val, info_test]
        # if type(result_infos[0][0]) in [np.double, np.float, np.int64, np.int32, np.uint8]:
        #     result_infos = [torch.from_numpy(label).long() for label in result_infos]

        return result, result_infos, added

    def to_sentence(self, sentence_idx):
        """ Converts a tokenized sentence back to a list of words.

        # Arguments:
            sentence_idx: List of numbers, representing a tokenized sentence
                given the current vocabulary.

        # Returns:
            String created by converting all numbers back to words and joined
            together with spaces.
        """
        # Have to recalculate the mappings in case the vocab was extended.
        ind_to_word = {ind: word for word, ind in self.vocabulary.items()}

        sentence_as_list = [ind_to_word[x] for x in sentence_idx]
        cleaned_list = [x for x in sentence_as_list if x != 'CUSTOM_MASK']
        return " ".join(cleaned_list)


def coverage(dataset, verbose=False):
    """ Computes the vocabulary coverage of a given tokenized dataset.

    # Arguments:
        dataset: Tokenized dataset to be checked.
        verbose: Verbosity flag.

    # Returns:
        Coverage, i.e. the fraction of non-masked tokens that are not unknown.
    """
    n_total = np.count_nonzero(dataset)
    n_unknown = np.sum(dataset == 1)
    coverage = 1.0 - float(n_unknown) / n_total

    if verbose:
        print("Unknown words: {}".format(n_unknown))
        print("Total words: {}".format(n_total))
        print("Coverage: {}".format(coverage))
    return coverage
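For reference, a minimal usage sketch of the class above (not part of the uploaded file). The tiny vocabulary is hypothetical; a real run would load the project's pretrained vocabulary (e.g. model/vocabulary.json in the torchMoji repository):

from torchmoji.sentence_tokenizer import SentenceTokenizer, coverage

# Hypothetical toy vocabulary: index 0 is the mask/pad value, index 1 the unknown token.
vocab = {'CUSTOM_MASK': 0, 'CUSTOM_UNKNOWN': 1, 'i': 2, 'love': 3, 'pizza': 4}

st = SentenceTokenizer(vocab, fixed_length=10)
tokens, infos, stats = st.tokenize_sentences([u'I love pizza!', u'pizza loves me'])

print(tokens.shape)                      # (2, 10), uint16, zero-padded
print(st.to_sentence(list(tokens[0])))   # indices mapped back to words; padding is dropped
print(coverage(tokens, verbose=True))    # fraction of non-padding tokens found in the vocab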
torchmoji/tokenizer.py
ADDED
@@ -0,0 +1,162 @@
# -*- coding: utf-8 -*-
'''
Splits up a Unicode string into a list of tokens.
Recognises:
- Abbreviations
- URLs
- Emails
- #hashtags
- @mentions
- emojis
- emoticons (limited support)

Multiple consecutive symbols are also treated as a single token.
'''
from __future__ import absolute_import, division, print_function, unicode_literals

import re

# Basic patterns.
RE_NUM = r'[0-9]+'
RE_WORD = r'[a-zA-Z]+'
RE_WHITESPACE = r'\s+'
RE_ANY = r'.'

# Combined words such as 'red-haired' or 'CUSTOM_TOKEN'
RE_COMB = r'[a-zA-Z]+[-_][a-zA-Z]+'

# English-specific patterns
RE_CONTRACTIONS = RE_WORD + r'\'' + RE_WORD

TITLES = [
    r'Mr\.',
    r'Ms\.',
    r'Mrs\.',
    r'Dr\.',
    r'Prof\.',
    r'mr\.',
    r'ms\.',
    r'mrs\.',
    r'dr\.',
    r'prof\.',
    ]
# Ensure case insensitivity
RE_TITLES = r'|'.join([r'' + t for t in TITLES])

# Symbols have to be created as separate patterns in order to match consecutive
# identical symbols.
SYMBOLS = r'()<!?.,/\'\"-_=\\§|´ˇ°[]<>{}~$^&*;:%+\xa3€`'
RE_SYMBOL = r'|'.join([re.escape(s) + r'+' for s in SYMBOLS])

# Hash symbols and at symbols have to be defined separately in order to not
# clash with hashtags and mentions if there are multiple - i.e.
# ##hello -> ['#', '#hello'] instead of ['##', 'hello']
SPECIAL_SYMBOLS = r'|#+(?=#[a-zA-Z0-9_]+)|@+(?=@[a-zA-Z0-9_]+)|#+|@+'
RE_SYMBOL += SPECIAL_SYMBOLS

RE_ABBREVIATIONS = r'\b(?<!\.)(?:[A-Za-z]\.){2,}'

# Twitter-specific patterns
RE_HASHTAG = r'#[a-zA-Z0-9_]+'
RE_MENTION = r'@[a-zA-Z0-9_]+'

RE_URL = r'(?:https?://|www\.)(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
RE_EMAIL = r'\b[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+\b'

# Emoticons and emojis
RE_HEART = r'(?:<+/?3+)+'
EMOTICONS_START = [
    r'>:',
    r':',
    r'=',
    r';',
    ]
EMOTICONS_MID = [
    r'-',
    r',',
    r'^',
    '\'',
    '\"',
    ]
EMOTICONS_END = [
    r'D',
    r'd',
    r'p',
    r'P',
    r'v',
    r')',
    r'o',
    r'O',
    r'(',
    r'3',
    r'/',
    r'|',
    '\\',
    ]
EMOTICONS_EXTRA = [
    r'-_-',
    r'x_x',
    r'^_^',
    r'o.o',
    r'o_o',
    r'(:',
    r'):',
    r');',
    r'(;',
    ]

RE_EMOTICON = r'|'.join([re.escape(s) for s in EMOTICONS_EXTRA])
for s in EMOTICONS_START:
    for m in EMOTICONS_MID:
        for e in EMOTICONS_END:
            RE_EMOTICON += '|{0}{1}?{2}+'.format(re.escape(s), re.escape(m), re.escape(e))

# requires ucs4 in python2.7 or python3+
# RE_EMOJI = r"""[\U0001F300-\U0001F64F\U0001F680-\U0001F6FF\u2600-\u26FF\u2700-\u27BF]"""
# safe for all python
RE_EMOJI = r"""\ud83c[\udf00-\udfff]|\ud83d[\udc00-\ude4f\ude80-\udeff]|[\u2600-\u26FF\u2700-\u27BF]"""

# List of matched token patterns, ordered from most specific to least specific.
TOKENS = [
    RE_URL,
    RE_EMAIL,
    RE_COMB,
    RE_HASHTAG,
    RE_MENTION,
    RE_HEART,
    RE_EMOTICON,
    RE_CONTRACTIONS,
    RE_TITLES,
    RE_ABBREVIATIONS,
    RE_NUM,
    RE_WORD,
    RE_SYMBOL,
    RE_EMOJI,
    RE_ANY
    ]

# List of ignored token patterns
IGNORED = [
    RE_WHITESPACE
    ]

# Final pattern: whitespace is ignored, every other pattern is wrapped in a
# capturing group so that findall returns the matched tokens.
RE_PATTERN = re.compile(r'|'.join(IGNORED) + r'|(' + r'|'.join(TOKENS) + r')',
                        re.UNICODE)


def tokenize(text):
    '''Splits given input string into a list of tokens.

    # Arguments:
        text: Input string to be tokenized.

    # Returns:
        List of strings (tokens).
    '''
    result = RE_PATTERN.findall(text)

    # Remove empty strings
    result = [t for t in result if t.strip()]
    return result
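For reference, a short sketch (not part of the uploaded file) of what the tokenize function above produces; the output shown in the comment is indicative rather than exact:

from torchmoji.tokenizer import tokenize

print(tokenize(u"@lisa I loooove this!!! :) see http://example.com #nlp"))
# Mentions, hashtags, URLs, emoticons and runs of identical symbols each come
# out as single tokens, roughly:
# ['@lisa', 'I', 'loooove', 'this', '!!!', ':)', 'see', 'http://example.com', '#nlp']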
torchmoji/word_generator.py
ADDED
@@ -0,0 +1,312 @@
# -*- coding: utf-8 -*-
''' Extracts lists of words from a given input to be used for later vocabulary
    generation or for creating tokenized datasets.
    Supports functionality for handling different file types and
    filtering/processing of this input.
'''

from __future__ import division, print_function, unicode_literals

import re
import unicodedata
import numpy as np
from text_unidecode import unidecode

from torchmoji.tokenizer import RE_MENTION, tokenize
from torchmoji.filter_utils import (convert_linebreaks,
                                    convert_nonbreaking_space,
                                    correct_length,
                                    extract_emojis,
                                    mostly_english,
                                    non_english_user,
                                    process_word,
                                    punct_word,
                                    remove_control_chars,
                                    remove_variation_selectors,
                                    separate_emojis_and_text)

try:
    unicode  # Python 2
except NameError:
    unicode = str  # Python 3

# Only catch retweets in the beginning of the tweet as those are the
# automatically added ones.
# We do not want to remove tweets like "Omg.. please RT this!!"
RETWEETS_RE = re.compile(r'^[rR][tT]')

# Use fast and less precise regex for removing tweets with URLs
# It doesn't matter too much if a few tweets with URLs make it through
URLS_RE = re.compile(r'https?://|www\.')

MENTION_RE = re.compile(RE_MENTION)
ALLOWED_CONVERTED_UNICODE_PUNCTUATION = """!"#$'()+,-.:;<=>?@`~"""


class WordGenerator():
    ''' Cleanses input and converts into words. Needs all sentences to be in
        Unicode format. Has subclasses that read sentences differently based on
        file type.

    Takes a generator as input. This can be from e.g. a file.
    unicode_handling in ['ignore_sentence', 'convert_punctuation', 'allow']
    unicode_handling in ['ignore_emoji', 'ignore_sentence', 'allow']
    '''
    def __init__(self, stream, allow_unicode_text=False, ignore_emojis=True,
                 remove_variation_selectors=True, break_replacement=True):
        self.stream = stream
        self.allow_unicode_text = allow_unicode_text
        self.remove_variation_selectors = remove_variation_selectors
        self.ignore_emojis = ignore_emojis
        self.break_replacement = break_replacement
        self.reset_stats()

    def get_words(self, sentence):
        """ Tokenizes a sentence into individual words.
            Converts Unicode punctuation into ASCII if that option is set.
            Ignores sentences with Unicode if that option is set.
            Returns an empty list of words if the sentence has Unicode and
            that is not allowed.
        """

        if not isinstance(sentence, unicode):
            raise ValueError("All sentences should be Unicode-encoded!")
        sentence = sentence.strip().lower()

        if self.break_replacement:
            sentence = convert_linebreaks(sentence)

        if self.remove_variation_selectors:
            sentence = remove_variation_selectors(sentence)

        # Split into words using simple whitespace splitting and convert
        # Unicode. This is done to prevent word splitting issues with
        # twokenize and Unicode
        words = sentence.split()
        converted_words = []
        for w in words:
            accept_sentence, c_w = self.convert_unicode_word(w)
            # Unicode word detected and not allowed
            if not accept_sentence:
                return []
            else:
                converted_words.append(c_w)
        sentence = ' '.join(converted_words)

        words = tokenize(sentence)
        words = [process_word(w) for w in words]
        return words

    def check_ascii(self, word):
        """ Returns whether a word is ASCII """

        try:
            word.decode('ascii')
            return True
        except (UnicodeDecodeError, UnicodeEncodeError, AttributeError):
            return False

    def convert_unicode_punctuation(self, word):
        word_converted_punct = []
        for c in word:
            decoded_c = unidecode(c).lower()
            if len(decoded_c) == 0:
                # Cannot decode to anything reasonable
                word_converted_punct.append(c)
            else:
                # Check if all punctuation and therefore fine
                # to include unidecoded version
                allowed_punct = punct_word(
                    decoded_c,
                    punctuation=ALLOWED_CONVERTED_UNICODE_PUNCTUATION)

                if allowed_punct:
                    word_converted_punct.append(decoded_c)
                else:
                    word_converted_punct.append(c)
        return ''.join(word_converted_punct)

    def convert_unicode_word(self, word):
        """ Converts Unicode words to ASCII using unidecode. If Unicode is not
            allowed (set as a variable during initialization), then only
            punctuation that can be converted to ASCII will be allowed.
        """
        if self.check_ascii(word):
            return True, word

        # First we ensure that the Unicode is normalized so it's
        # always a single character.
        word = unicodedata.normalize("NFKC", word)

        # Convert Unicode punctuation to ASCII equivalent. We want
        # e.g. "\u203c" (double exclamation mark) to be treated the same
        # as "!!" no matter if we allow other Unicode characters or not.
        word = self.convert_unicode_punctuation(word)

        if self.ignore_emojis:
            _, word = separate_emojis_and_text(word)

        # If conversion of punctuation and removal of emojis took care
        # of all the Unicode or if we allow Unicode then everything is fine
        if self.check_ascii(word) or self.allow_unicode_text:
            return True, word
        else:
            # Sometimes we might want to simply ignore Unicode sentences
            # (e.g. for vocabulary creation). This is another way to prevent
            # "pollution" of strange Unicode tokens from low quality datasets
            return False, ''

    def data_preprocess_filtering(self, line, iter_i):
        """ To be overridden with specific preprocessing/filtering behavior
            if desired.

            Returns a boolean of whether the line should be accepted and the
            preprocessed text.

            Runs prior to tokenization.
        """
        return True, line, {}

    def data_postprocess_filtering(self, words, iter_i):
        """ To be overridden with specific postprocessing/filtering behavior
            if desired.

            Returns a boolean of whether the line should be accepted and the
            postprocessed text.

            Runs after tokenization.
        """
        return True, words, {}

    def extract_valid_sentence_words(self, line):
        """ Line may be either a string or a list of strings depending on how
            the stream is being parsed.
            Domain-specific processing and filtering can be done both prior to
            and after tokenization.
            Custom information about the line can be extracted during the
            processing phases and returned as a dict.
        """

        info = {}

        pre_valid, pre_line, pre_info = \
            self.data_preprocess_filtering(line, self.stats['total'])
        info.update(pre_info)
        if not pre_valid:
            self.stats['pretokenization_filtered'] += 1
            return False, [], info

        words = self.get_words(pre_line)
        if len(words) == 0:
            self.stats['unicode_filtered'] += 1
            return False, [], info

        post_valid, post_words, post_info = \
            self.data_postprocess_filtering(words, self.stats['total'])
        info.update(post_info)
        if not post_valid:
            self.stats['posttokenization_filtered'] += 1
        return post_valid, post_words, info

    def generate_array_from_input(self):
        sentences = []
        for words in self:
            sentences.append(words)
        return sentences

    def reset_stats(self):
        self.stats = {'pretokenization_filtered': 0,
                      'unicode_filtered': 0,
                      'posttokenization_filtered': 0,
                      'total': 0,
                      'valid': 0}

    def __iter__(self):
        if self.stream is None:
            raise ValueError("Stream should be set before iterating over it!")

        for line in self.stream:
            valid, words, info = self.extract_valid_sentence_words(line)

            # Words may be filtered away due to unidecode etc.
            # In that case the words should not be passed on.
            if valid and len(words):
                self.stats['valid'] += 1
                yield words, info

            self.stats['total'] += 1


class TweetWordGenerator(WordGenerator):
    ''' Returns np array or generator of ASCII sentences for given tweet input.
        Any file opening/closing should be handled outside of this class.
    '''
    def __init__(self, stream, wanted_emojis=None, english_words=None,
                 non_english_user_set=None, allow_unicode_text=False,
                 ignore_retweets=True, ignore_url_tweets=True,
                 ignore_mention_tweets=False):

        self.wanted_emojis = wanted_emojis
        self.english_words = english_words
        self.non_english_user_set = non_english_user_set
        self.ignore_retweets = ignore_retweets
        self.ignore_url_tweets = ignore_url_tweets
        self.ignore_mention_tweets = ignore_mention_tweets
        WordGenerator.__init__(self, stream,
                               allow_unicode_text=allow_unicode_text)

    def validated_tweet(self, data):
        ''' A bunch of checks to determine whether the tweet is valid.
            Also returns emojis contained by the tweet.
        '''

        # Ordering of validations is important for speed
        # If it passes all checks, then the tweet is validated for usage

        # Skips incomplete tweets
        if len(data) <= 9:
            return False, []

        text = data[9]

        if self.ignore_retweets and RETWEETS_RE.search(text):
            return False, []

        if self.ignore_url_tweets and URLS_RE.search(text):
            return False, []

        if self.ignore_mention_tweets and MENTION_RE.search(text):
            return False, []

        if self.wanted_emojis is not None:
            uniq_emojis = np.unique(extract_emojis(text, self.wanted_emojis))
            if len(uniq_emojis) == 0:
                return False, []
        else:
            uniq_emojis = []

        if self.non_english_user_set is not None and \
                non_english_user(data[1], self.non_english_user_set):
            return False, []
        return True, uniq_emojis

    def data_preprocess_filtering(self, line, iter_i):
        fields = line.strip().split("\t")
        valid, emojis = self.validated_tweet(fields)
        text = fields[9].replace('\\n', '') \
                        .replace('\\r', '') \
                        .replace('&amp', '&') if valid else ''
        return valid, text, {'emojis': emojis}

    def data_postprocess_filtering(self, words, iter_i):
        valid_length = correct_length(words, 1, None)
        valid_english, n_words, n_english = mostly_english(words,
                                                           self.english_words)
        if valid_length and valid_english:
            return True, words, {'length': len(words),
                                 'n_normal_words': n_words,
                                 'n_english': n_english}
        else:
            return False, [], {'length': len(words),
                               'n_normal_words': n_words,
                               'n_english': n_english}
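For reference, a minimal sketch (not part of the uploaded file) of iterating a WordGenerator over an in-memory stream; the sentences are made up. TweetWordGenerator works the same way but expects tab-separated tweet records (with the text in field 9) and applies the extra filters defined above:

from torchmoji.word_generator import WordGenerator

stream = iter([u'Never gonna give you up :)',
               u'Never gonna let you down!'])
wg = WordGenerator(stream, allow_unicode_text=True)

for words, info in wg:
    print(words)     # e.g. ['never', 'gonna', 'give', 'you', 'up', ':)']

print(wg.stats)      # counts of valid/total sentences plus the filtered counters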