File size: 9,154 Bytes
2359bda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
"""
This file contains an example how to make a SentenceTransformer model faster and lighter.
This is achieved by using Knowledge Distillation: We use a well working teacher model to train
a fast and light student model. The student model learns to imitate the produced
sentence embeddings from the teacher. We train this on a diverse set of sentences we got
from SNLI + MultiNLI + Wikipedia.
After the distillation is finished, the student model produces nearly the same embeddings as the
teacher, however, it will be much faster.
The script implements two options to initialize the student:
Option 1: Train a light transformer model like TinyBERT to imitate the teacher
Option 2: We take the teacher model and keep only certain layers, for example, only 4 layers.
Option 2 usually works better, as we keep most of the weights from the teacher. In Option 1, we have to tune all
weights in the student from scratch.
There is a performance - speed trade-off. However, we found that a student with 4 instead of 12 layers keeps about 99.4%
of the teacher performance, while being 2.3 times faster.
"""
from torch.utils.data import DataLoader
from sentence_transformers import models, losses, evaluation
from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample
from sentence_transformers.datasets import ParallelSentencesDataset
import logging
from datetime import datetime
import os
import gzip
import csv
import random
from sklearn.decomposition import PCA
import torch
#### Configure logging: timestamped INFO-level records emitted through LoggingHandler
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)
#### /logging configuration
# Teacher model: the model whose sentence embeddings the smaller student learns to reproduce
teacher_model_name = 'stsb-roberta-base-v2'
teacher_model = SentenceTransformer(teacher_model_name)

# The output directory name carries the timestamp of the script start
output_path = "output/model-distillation-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Selects which of the two student initializations below is used
use_layer_reduction = True

if not use_layer_reduction:
    # Option 1 of the module docstring: train a small transformer such as TinyBERT to imitate the teacher.
    # More small BERT models are listed at https://huggingface.co/nreimers
    word_embedding_model = models.Transformer('nreimers/TinyBERT_L-4_H-312_v2')
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
else:
    # Option 2 of the module docstring: start from a copy of the teacher and keep only some of its layers
    student_model = SentenceTransformer(teacher_model_name)

    # The underlying transformer network of the student
    auto_model = student_model._first_module().auto_model

    # Zero-based positions of the encoder layers to keep, spread evenly over the teacher.
    # Alternative selections:
    #   [5]
    #   [3, 7]
    #   [3, 7, 11]
    #   [0, 2, 4, 6, 8, 10]
    #   [0, 1, 3, 4, 6, 7, 9, 10]
    layers_to_keep = [1, 4, 7, 10]  # Keep 4 layers from the teacher

    logging.info("Remove layers from student. Only keep these layers: {}".format(layers_to_keep))
    kept_layers = [
        layer_module
        for position, layer_module in enumerate(auto_model.encoder.layer)
        if position in layers_to_keep
    ]
    auto_model.encoder.layer = torch.nn.ModuleList(kept_layers)
    # Keep the model config in sync with the reduced encoder depth
    auto_model.config.num_hidden_layers = len(layers_to_keep)
# Batch size handed to the ParallelSentencesDataset, and batch size of the training DataLoader
inference_batch_size = 64
train_batch_size = 64

# AllNLI serves as a source of sentences for the distillation
nli_dataset_path = 'datasets/AllNLI.tsv.gz'

# Sentences extracted from the English Wikipedia are used as further training sentences
wikipedia_dataset_path = 'datasets/wikipedia-en-sentences.txt.gz'

# The STS benchmark dataset shows how much performance we lose compared to the teacher
sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'

# Download every dataset that is not yet present at its local path
dataset_sources = (
    ('https://sbert.net/datasets/AllNLI.tsv.gz', nli_dataset_path),
    ('https://sbert.net/datasets/wikipedia-en-sentences.txt.gz', wikipedia_dataset_path),
    ('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path),
)
for dataset_url, local_path in dataset_sources:
    if not os.path.exists(local_path):
        util.http_get(dataset_url, local_path)
# The distillation needs plain sentences. They are taken from AllNLI and from Wikipedia.
# The NLI sentences are collected in sets, so a sentence that occurs in several rows is stored once.
train_sentences_nli = set()
dev_sentences_nli = set()

# Read AllNLI: rows of the 'dev' split go to the dev sentences, every other row to the train sentences
with gzip.open(nli_dataset_path, 'rt', encoding='utf8') as fIn:
    for row in csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE):
        target_sentences = dev_sentences_nli if row['split'] == 'dev' else train_sentences_nli
        target_sentences.update((row['sentence1'], row['sentence2']))

# Turn the sets into shuffled lists
train_sentences_nli = list(train_sentences_nli)
random.shuffle(train_sentences_nli)

dev_sentences_nli = list(dev_sentences_nli)
random.shuffle(dev_sentences_nli)
dev_sentences_nli = dev_sentences_nli[:5000]  # Limit dev sentences to 5k

# Read the Wikipedia file, one whitespace-stripped sentence per line
with gzip.open(wikipedia_dataset_path, 'rt', encoding='utf8') as fIn:
    wikipedia_sentences = [line.strip() for line in fIn]

# The first 5k Wikipedia sentences are used for development, the remainder for training
dev_sentences_wikipedia = wikipedia_sentences[:5000]
train_sentences_wikipedia = wikipedia_sentences[5000:]
# The STS benchmark dev split measures the performance of the student model in comparison to the teacher model
logging.info("Read STSbenchmark dev dataset")
with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
    dev_samples = [
        InputExample(
            texts=[row['sentence1'], row['sentence2']],
            label=float(row['score']) / 5.0,  # Normalize score to range 0 ... 1
        )
        for row in csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
        if row['split'] == 'dev'
    ]

dev_evaluator_sts = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')

# Evaluate the unmodified teacher first, as the reference for the student
logging.info("Teacher Performance:")
dev_evaluator_sts(teacher_model)
# If the student produces embeddings with fewer dimensions than the teacher, the teacher
# embeddings are projected down to the student size with a PCA.
student_dim = student_model.get_sentence_embedding_dimension()
teacher_dim = teacher_model.get_sentence_embedding_dimension()

if student_dim < teacher_dim:
    logging.info("Student model has fewer dimensions than the teacher. Compute PCA for down projection")

    # Fit the PCA on teacher embeddings of up to 20k NLI and up to 20k Wikipedia training sentences
    pca_sentences = train_sentences_nli[:20000] + train_sentences_wikipedia[:20000]
    pca_embeddings = teacher_model.encode(pca_sentences, convert_to_numpy=True)
    pca = PCA(n_components=student_dim)
    pca.fit(pca_embeddings)

    # Add a bias-free Dense layer with identity activation to the teacher. Its weight matrix is
    # set to the PCA components, so it projects the embeddings down to the student embedding size.
    dense = models.Dense(
        in_features=teacher_dim,
        out_features=student_dim,
        bias=False,
        activation_function=torch.nn.Identity(),
    )
    dense.linear.weight = torch.nn.Parameter(torch.tensor(pca.components_))
    teacher_model.add_module('dense', dense)

    # Query the dimension again: the teacher now ends with the projection layer
    logging.info("Teacher Performance with {} dimensions:".format(teacher_model.get_sentence_embedding_dimension()))
    dev_evaluator_sts(teacher_model)
# The student_model is trained to create sentence embeddings similar to those of the teacher_model.
# This needs a large set of sentences: the teacher embeds them and the student tries to mimic
# these embeddings. It is the same approach as used in: https://arxiv.org/abs/2004.09813
train_data = ParallelSentencesDataset(
    student_model=student_model,
    teacher_model=teacher_model,
    batch_size=inference_batch_size,
    use_embedding_cache=False,
)

# Every dataset entry is a one-element list holding a single sentence
for corpus in (train_sentences_nli, train_sentences_wikipedia):
    train_data.add_dataset([[sentence] for sentence in corpus], max_sentence_length=256)

train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
train_loss = losses.MSELoss(model=student_model)

# Evaluator that measures the Mean Squared Error (MSE) between the teacher and the student
# embeddings. The same dev sentences are passed as both source and target sentences.
dev_sentences = dev_sentences_nli + dev_sentences_wikipedia
dev_evaluator_mse = evaluation.MSEEvaluator(dev_sentences, dev_sentences, teacher_model=teacher_model)
# Train the student model to imitate the teacher. During training the student is evaluated
# with the STS evaluator followed by the MSE evaluator (every 5000 steps), and the best
# model is saved to output_path.
#
# NOTE: optimizer_params used to contain 'correct_bias': False. That keyword belongs only to
# the AdamW implementation that shipped with transformers. fit() forwards optimizer_params
# to the optimizer constructor, and the torch.optim.AdamW optimizer that current
# sentence-transformers releases use by default has no such argument, so passing it aborts
# the training with a TypeError. Only 'lr' and 'eps' are passed, which both accept.
student_model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluation.SequentialEvaluator([dev_evaluator_sts, dev_evaluator_mse]),
    epochs=1,
    warmup_steps=1000,
    evaluation_steps=5000,
    output_path=output_path,
    save_best_model=True,
    optimizer_params={'lr': 1e-4, 'eps': 1e-6},
    use_amp=True,
)
|