"""
This file contains an example of how to make a SentenceTransformer model faster and lighter.

This is achieved by using Knowledge Distillation: we use a well-performing teacher model to train
a fast and light student model. The student model learns to imitate the sentence embeddings
produced by the teacher. We train this on a diverse set of sentences drawn
from SNLI + MultiNLI + Wikipedia.

After the distillation is finished, the student model produces nearly the same embeddings as the
teacher, but it is much faster.

The script implements two options to initialize the student:
Option 1: Train a light transformer model like TinyBERT to imitate the teacher
Option 2: Take the teacher model and keep only certain layers, for example, only 4 layers.

Option 2 usually works better, as we keep most of the weights from the teacher. With Option 1, all
weights in the student have to be trained from scratch.

There is a performance-speed trade-off. However, we found that a student with 4 instead of 12 layers keeps about 99.4%
of the teacher's performance while being 2.3 times faster.
"""
from torch.utils.data import DataLoader
from sentence_transformers import models, losses, evaluation
from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample
from sentence_transformers.datasets import ParallelSentencesDataset
import logging
from datetime import datetime
import os
import gzip
import csv
import random
from sklearn.decomposition import PCA
import torch


#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout


# Teacher model: the model we want to distill into a smaller student model
teacher_model_name = 'stsb-roberta-base-v2'
teacher_model = SentenceTransformer(teacher_model_name)
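# stsb-roberta-base-v2 is based on RoBERTa-base, i.e. a 12-layer transformer, so the layer-reduction
# option below with 4 kept layers corresponds to the 12 -> 4 layer setting mentioned in the docstring.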

output_path = "output/model-distillation-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


use_layer_reduction = True

#There are two options to create a light and fast student model:
if use_layer_reduction:
    # 1) Create a smaller student model by using only some of the teacher layers
    student_model = SentenceTransformer(teacher_model_name)

    # Get the transformer model
    auto_model = student_model._first_module().auto_model

    # Which layers to keep from the teacher model. We spread the kept layers evenly over the original teacher
    #layers_to_keep = [5]
    #layers_to_keep = [3, 7]
    #layers_to_keep = [3, 7, 11]
    layers_to_keep = [1, 4, 7, 10]          #Keep 4 layers from the teacher
    #layers_to_keep = [0, 2, 4, 6, 8, 10]
    #layers_to_keep = [0, 1, 3, 4, 6, 7, 9, 10]

    logging.info("Remove layers from student. Only keep these layers: {}".format(layers_to_keep))
    new_layers = torch.nn.ModuleList([layer_module for i, layer_module in enumerate(auto_model.encoder.layer) if i in layers_to_keep])
    auto_model.encoder.layer = new_layers
    auto_model.config.num_hidden_layers = len(layers_to_keep)
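    # Sanity check: after the surgery above, the student should expose exactly len(layers_to_keep)
    # transformer layers, while all remaining weights are still the teacher's weights.
    assert len(auto_model.encoder.layer) == len(layers_to_keep)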
else:
    # 2) The other option is to train a small model like TinyBERT to imitate the teacher.
    # You can find some small BERT models here: https://huggingface.co/nreimers
    word_embedding_model = models.Transformer('nreimers/TinyBERT_L-4_H-312_v2')
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
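    # models.Pooling with only the embedding dimension given defaults to mean pooling over the token embeddings.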
    student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])



inference_batch_size = 64
train_batch_size = 64



#We use AllNLI as a source of sentences for the distillation
nli_dataset_path = 'datasets/AllNLI.tsv.gz'

#Further, we use sentences extracted from the English Wikipedia for the distillation
wikipedia_dataset_path = 'datasets/wikipedia-en-sentences.txt.gz'

#We use the STS benchmark dataset to see how much performance we lose
sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'


#Download datasets if needed
if not os.path.exists(nli_dataset_path):
    util.http_get('https://sbert.net/datasets/AllNLI.tsv.gz', nli_dataset_path)

if not os.path.exists(wikipedia_dataset_path):
    util.http_get('https://sbert.net/datasets/wikipedia-en-sentences.txt.gz', wikipedia_dataset_path)

if not os.path.exists(sts_dataset_path):
    util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)

#We need sentences to train our distillation. Here, we use sentences from AllNLI and from Wikipedia
train_sentences_nli = set()
dev_sentences_nli = set()

train_sentences_wikipedia = []
dev_sentences_wikipedia = []

# Read AllNLI
with gzip.open(nli_dataset_path, 'rt', encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        if row['split'] == 'dev':
            dev_sentences_nli.add(row['sentence1'])
            dev_sentences_nli.add(row['sentence2'])
        else:
            train_sentences_nli.add(row['sentence1'])
            train_sentences_nli.add(row['sentence2'])

train_sentences_nli = list(train_sentences_nli)
random.shuffle(train_sentences_nli)

dev_sentences_nli = list(dev_sentences_nli)
random.shuffle(dev_sentences_nli)
dev_sentences_nli = dev_sentences_nli[0:5000] #Limit dev sentences to 5k

# Read Wikipedia sentences file
with gzip.open(wikipedia_dataset_path, 'rt', encoding='utf8') as fIn:
    wikipedia_sentences = [line.strip() for line in fIn]

dev_sentences_wikipedia = wikipedia_sentences[0:5000] #Use the first 5k sentences from the Wikipedia file for development
train_sentences_wikipedia = wikipedia_sentences[5000:]


# We use the STS benchmark dataset to measure the performance of the student model in comparison to the teacher model
logging.info("Read STSbenchmark dev dataset")
dev_samples = []
with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        if row['split'] == 'dev':
            score = float(row['score']) / 5.0 #Normalize score to range 0 ... 1
            dev_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=score))

dev_evaluator_sts = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')
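# The EmbeddingSimilarityEvaluator reports how well the cosine similarity between the embeddings
# correlates (Spearman rank correlation) with the gold STS similarity scores.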


logging.info("Teacher Performance:")
dev_evaluator_sts(teacher_model)

# If the student model has fewer dimensions than the teacher, compute a PCA so the teacher embeddings can be projected down
if student_model.get_sentence_embedding_dimension() < teacher_model.get_sentence_embedding_dimension():
    logging.info("Student model has fewer dimensions than the teacher. Compute PCA for down projection")
    pca_sentences = train_sentences_nli[0:20000] + train_sentences_wikipedia[0:20000]
    pca_embeddings = teacher_model.encode(pca_sentences, convert_to_numpy=True)
    pca = PCA(n_components=student_model.get_sentence_embedding_dimension())
    pca.fit(pca_embeddings)

    #Add Dense layer to teacher that projects the embeddings down to the student embedding size
    dense = models.Dense(in_features=teacher_model.get_sentence_embedding_dimension(), out_features=student_model.get_sentence_embedding_dimension(), bias=False, activation_function=torch.nn.Identity())
    dense.linear.weight = torch.nn.Parameter(torch.tensor(pca.components_))
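    # pca.components_ has shape (student_dim, teacher_dim), which matches the (out_features, in_features)
    # layout that torch.nn.Linear expects for its weight matrix, so we can copy it over directly.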
    teacher_model.add_module('dense', dense)

    logging.info("Teacher Performance with {} dimensions:".format(teacher_model.get_sentence_embedding_dimension()))
    dev_evaluator_sts(teacher_model)



# We train the student_model such that it creates sentence embeddings similar to the embeddings from the teacher_model
# For this, we need a large set of sentences. These sentences are embedded using the teacher model,
# and the student tries to mimic these embeddings. It is the same approach as used in: https://arxiv.org/abs/2004.09813
train_data = ParallelSentencesDataset(student_model=student_model, teacher_model=teacher_model, batch_size=inference_batch_size, use_embedding_cache=False)
train_data.add_dataset([[sent] for sent in train_sentences_nli], max_sentence_length=256)
train_data.add_dataset([[sent] for sent in train_sentences_wikipedia], max_sentence_length=256)

train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
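# losses.MSELoss minimizes the mean squared error between the student embeddings and the teacher
# embeddings that the ParallelSentencesDataset attaches to each sentence as its label.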
train_loss = losses.MSELoss(model=student_model)

# We create an evaluator that measures the Mean Squared Error (MSE) between the teacher and the student embeddings
dev_sentences = dev_sentences_nli + dev_sentences_wikipedia
dev_evaluator_mse = evaluation.MSEEvaluator(dev_sentences, dev_sentences, teacher_model=teacher_model)
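# A lower MSE on these held-out sentences means the student embeddings track the teacher more closely.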

# Train the student model to imitate the teacher
student_model.fit(train_objectives=[(train_dataloader, train_loss)],
                  evaluator=evaluation.SequentialEvaluator([dev_evaluator_sts, dev_evaluator_mse]),
                  epochs=1,
                  warmup_steps=1000,
                  evaluation_steps=5000,
                  output_path=output_path,
                  save_best_model=True,
                  optimizer_params={'lr': 1e-4, 'eps': 1e-6, 'correct_bias': False},
                  use_amp=True)
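

# After training, the best checkpoint is stored in output_path (save_best_model=True above).
# Minimal usage sketch of the distilled student; the example sentences below are illustrative only.
distilled_model = SentenceTransformer(output_path)
sketch_embeddings = distilled_model.encode(['The distilled student is used like any other SentenceTransformer model.',
                                            'It produces embeddings close to the teacher, but runs faster.'])
print(sketch_embeddings.shape)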