"""
This script contains an example how to extend an existent sentence embedding model to new languages.
Given a (monolingual) teacher model you would like to extend to new languages, which is specified in the teacher_model_name
variable. We train a multilingual student model to imitate the teacher model (variable student_model_name)
on multiple languages.
For training, you need parallel sentence data (machine translation training data). You need tab-seperated files (.tsv)
with the first column a sentence in a language understood by the teacher model, e.g. English,
and the further columns contain the according translations for languages you want to extend to.
This scripts downloads automatically the TED2020 corpus: https://github.com/UKPLab/sentence-transformers/blob/master/docs/datasets/TED2020.md
This corpus contains transcripts from
TED and TEDx talks, translated to 100+ languages. For other parallel data, see get_parallel_data_[].py scripts
Further information can be found in our paper:
Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation
https://arxiv.org/abs/2004.09813
"""
from sentence_transformers import SentenceTransformer, LoggingHandler, models, evaluation, losses
from torch.utils.data import DataLoader
from sentence_transformers.datasets import ParallelSentencesDataset
from datetime import datetime
import os
import logging
import sentence_transformers.util
import csv
import gzip
from tqdm.autonotebook import tqdm
import numpy as np
import zipfile
import io

logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

teacher_model_name = 'paraphrase-distilroberta-base-v2'  # Our monolingual teacher model that we want to extend to multiple languages
student_model_name = 'xlm-roberta-base'                  # Multilingual base model we use to imitate the teacher model

max_seq_length = 128                 # Student model max. length for inputs (number of word pieces)
train_batch_size = 64                # Batch size for training
inference_batch_size = 64            # Batch size at inference
max_sentences_per_language = 500000  # Maximum number of parallel sentences for training
train_max_sentence_length = 250      # Maximum length (characters) for parallel training sentences

num_epochs = 5                       # Train for x epochs
num_warmup_steps = 10000             # Warmup steps
num_evaluation_steps = 1000          # Evaluate performance after every xxxx steps
dev_sentences = 1000                 # Number of parallel sentences to be used for development

# Define the language codes you would like to extend the model to
source_languages = set(['en']) # Our teacher model accepts English (en) sentences
target_languages = set(['de', 'es', 'it', 'fr', 'ar', 'tr']) # We want to extend the model to these new languages. For language codes, see the header of the train file
output_path = "output/make-multilingual-"+"-".join(sorted(list(source_languages))+sorted(list(target_languages)))+"-"+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# This function downloads a corpus if it does not exist
def download_corpora(filepaths):
    if not isinstance(filepaths, list):
        filepaths = [filepaths]

    for filepath in filepaths:
        if not os.path.exists(filepath):
            print(filepath, "does not exist. Trying to download from server")
            filename = os.path.basename(filepath)
            url = "https://sbert.net/datasets/" + filename
            sentence_transformers.util.http_get(url, filepath)

# Here we define the train and dev corpora
train_corpus = "datasets/ted2020.tsv.gz"          # Transcripts of TED talks, crawled in 2020
sts_corpus = "datasets/STS2017-extended.zip"      # Extended STS2017 dataset for more languages
parallel_sentences_folder = "parallel-sentences/"

# Check if the files exist. If not, they are downloaded
download_corpora([train_corpus, sts_corpus])

# Create parallel files for the selected language combinations
os.makedirs(parallel_sentences_folder, exist_ok=True)
train_files = []
dev_files = []
files_to_create = []
for source_lang in source_languages:
    for target_lang in target_languages:
        output_filename_train = os.path.join(parallel_sentences_folder, "TED2020-{}-{}-train.tsv.gz".format(source_lang, target_lang))
        output_filename_dev = os.path.join(parallel_sentences_folder, "TED2020-{}-{}-dev.tsv.gz".format(source_lang, target_lang))
        train_files.append(output_filename_train)
        dev_files.append(output_filename_dev)
        if not os.path.exists(output_filename_train) or not os.path.exists(output_filename_dev):
            files_to_create.append({'src_lang': source_lang, 'trg_lang': target_lang,
                                    'fTrain': gzip.open(output_filename_train, 'wt', encoding='utf8'),
                                    'fDev': gzip.open(output_filename_dev, 'wt', encoding='utf8'),
                                    'devCount': 0
                                    })

if len(files_to_create) > 0:
    print("Parallel sentences files {} do not exist. Creating these files now".format(", ".join(map(lambda x: x['src_lang']+"-"+x['trg_lang'], files_to_create))))
    with gzip.open(train_corpus, 'rt', encoding='utf8') as fIn:
        reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
        for line in tqdm(reader, desc="Sentences"):
            for outfile in files_to_create:
                src_text = line[outfile['src_lang']].strip()
                trg_text = line[outfile['trg_lang']].strip()

                if src_text != "" and trg_text != "":
                    # The first dev_sentences pairs go to the dev file, the rest to the train file
                    if outfile['devCount'] < dev_sentences:
                        outfile['devCount'] += 1
                        fOut = outfile['fDev']
                    else:
                        fOut = outfile['fTrain']

                    fOut.write("{}\t{}\n".format(src_text, trg_text))

    for outfile in files_to_create:
        outfile['fTrain'].close()
        outfile['fDev'].close()

######## Start the extension of the teacher model to multiple languages ########
logger.info("Load teacher model")
teacher_model = SentenceTransformer(teacher_model_name)

logger.info("Create student model from scratch")
word_embedding_model = models.Transformer(student_model_name, max_seq_length=max_seq_length)
# Apply mean pooling to get one fixed-sized sentence vector
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
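
# Note: for the MSE-based distillation below, the student's sentence embedding dimension should match
# the teacher's output dimension (here both models produce 768-dimensional vectors). If they differed,
# an additional models.Dense layer could be added to the student to project to the teacher's dimension.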

###### Read Parallel Sentences Dataset ######
train_data = ParallelSentencesDataset(student_model=student_model, teacher_model=teacher_model, batch_size=inference_batch_size, use_embedding_cache=True)
for train_file in train_files:
    train_data.load_data(train_file, max_sentences=max_sentences_per_language, max_sentence_length=train_max_sentence_length)

train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
train_loss = losses.MSELoss(model=student_model)
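
# The distillation objective (see the paper linked above): for each parallel pair, the student is trained so
# that both student(source_sentence) and student(translation) are close, in the MSE sense, to
# teacher(source_sentence). The teacher embeddings are produced by the ParallelSentencesDataset and, with
# use_embedding_cache=True, cached so repeated sentences are not re-encoded.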

#### Evaluate cross-lingual performance on different tasks #####
evaluators = []  # evaluators is a list of different evaluator classes we call periodically

for dev_file in dev_files:
    logger.info("Create evaluator for " + dev_file)
    src_sentences = []
    trg_sentences = []
    with gzip.open(dev_file, 'rt', encoding='utf8') as fIn:
        for line in fIn:
            splits = line.strip().split('\t')
            if splits[0] != "" and splits[1] != "":
                src_sentences.append(splits[0])
                trg_sentences.append(splits[1])

    # Mean Squared Error (MSE) measures the (euclidean) distance between teacher and student embeddings
    dev_mse = evaluation.MSEEvaluator(src_sentences, trg_sentences, name=os.path.basename(dev_file), teacher_model=teacher_model, batch_size=inference_batch_size)
    evaluators.append(dev_mse)

    # TranslationEvaluator computes the embeddings for all parallel sentences. It then checks whether the embedding of source[i] is the closest to target[i] out of all available target sentences
    dev_trans_acc = evaluation.TranslationEvaluator(src_sentences, trg_sentences, name=os.path.basename(dev_file), batch_size=inference_batch_size)
    evaluators.append(dev_trans_acc)

##### Read cross-lingual Semantic Textual Similarity (STS) data ####
all_languages = list(set(list(source_languages)+list(target_languages)))
sts_data = {}

# Open the ZIP file STS2017-extended.zip and check for which language combinations we have STS data
with zipfile.ZipFile(sts_corpus) as zip:
    filelist = zip.namelist()
    sts_files = []

    for i in range(len(all_languages)):
        for j in range(i, len(all_languages)):
            lang1 = all_languages[i]
            lang2 = all_languages[j]
            filepath = 'STS2017-extended/STS.{}-{}.txt'.format(lang1, lang2)
            if filepath not in filelist:
                lang1, lang2 = lang2, lang1
                filepath = 'STS2017-extended/STS.{}-{}.txt'.format(lang1, lang2)

            if filepath in filelist:
                filename = os.path.basename(filepath)
                sts_data[filename] = {'sentences1': [], 'sentences2': [], 'scores': []}

                fIn = zip.open(filepath)
                for line in io.TextIOWrapper(fIn, 'utf8'):
                    sent1, sent2, score = line.strip().split("\t")
                    score = float(score)
                    sts_data[filename]['sentences1'].append(sent1)
                    sts_data[filename]['sentences2'].append(sent2)
                    sts_data[filename]['scores'].append(score)

for filename, data in sts_data.items():
    test_evaluator = evaluation.EmbeddingSimilarityEvaluator(data['sentences1'], data['sentences2'], data['scores'], batch_size=inference_batch_size, name=filename, show_progress_bar=False)
    evaluators.append(test_evaluator)

# Train the model
student_model.fit(train_objectives=[(train_dataloader, train_loss)],
                  evaluator=evaluation.SequentialEvaluator(evaluators, main_score_function=lambda scores: np.mean(scores)),
                  epochs=num_epochs,
                  warmup_steps=num_warmup_steps,
                  evaluation_steps=num_evaluation_steps,
                  output_path=output_path,
                  save_best_model=True,
                  optimizer_params={'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False}
                  )
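
# After training, the best model (save_best_model=True) is stored in output_path and can be loaded like any
# other SentenceTransformer model. Illustrative usage (not executed by this script):
#
#   model = SentenceTransformer(output_path)
#   embeddings = model.encode(["This is a test sentence.", "Dies ist ein Testsatz."])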