"""
This script downloads the WikiMatrix corpus (https://github.com/facebookresearch/LASER/tree/master/tasks/WikiMatrix)
and creates parallel-sentence TSV files that can be used to extend existing sentence embedding models to new languages.

WikiMatrix contains parallel sentences mined from Wikipedia in various languages.

Further information can be found in our paper:
Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation
https://arxiv.org/abs/2004.09813
"""
import os
import sentence_transformers.util
import gzip
import csv
from tqdm.autonotebook import tqdm

source_languages = set(['en'])                                # Source language(s) of the parallel pairs
target_languages = set(['de', 'es', 'it', 'fr', 'ar', 'tr'])  # New languages we want to extend to

num_dev_sentences = 1000    # Number of sentence pairs reserved for the dev set
threshold = 1.075           # Only use pairs with a LASER similarity score above this threshold

download_url = "https://dl.fbaipublicfiles.com/laser/WikiMatrix/v1/"
download_folder = "../datasets/WikiMatrix/"
parallel_sentences_folder = "parallel-sentences/"

os.makedirs(os.path.dirname(download_folder), exist_ok=True)
os.makedirs(parallel_sentences_folder, exist_ok=True)

for source_lang in source_languages:
    for target_lang in target_languages:
        filename_train = os.path.join(parallel_sentences_folder, "WikiMatrix-{}-{}-train.tsv.gz".format(source_lang, target_lang))
        filename_dev = os.path.join(parallel_sentences_folder, "WikiMatrix-{}-{}-dev.tsv.gz".format(source_lang, target_lang))

        if not os.path.exists(filename_train) and not os.path.exists(filename_dev):
            # WikiMatrix files name the two languages in alphabetical order
            langs_ordered = sorted([source_lang, target_lang])
            wikimatrix_filename = "WikiMatrix.{}-{}.tsv.gz".format(*langs_ordered)
            wikimatrix_filepath = os.path.join(download_folder, wikimatrix_filename)

            if not os.path.exists(wikimatrix_filepath):
                print("Download", download_url + wikimatrix_filename)
                try:
                    sentence_transformers.util.http_get(download_url + wikimatrix_filename, wikimatrix_filepath)
                except Exception:
                    print("Was not able to download", download_url + wikimatrix_filename)
                    continue

            if not os.path.exists(wikimatrix_filepath):
                continue

            train_sentences = []
            dev_sentences = []
            dev_sentences_set = set()
            extract_dev_sentences = True

            with gzip.open(wikimatrix_filepath, 'rt', encoding='utf8') as fIn:
                for line in fIn:
                    score, sent1, sent2 = line.strip().split('\t')
                    sent1 = sent1.strip()
                    sent2 = sent2.strip()
                    score = float(score)

                    # Pairs are sorted by descending score, so we can stop at the first pair below the threshold
                    if score < threshold:
                        break

                    if sent1 == sent2:
                        continue

                    # Make sure the source language is always in the first column
                    if langs_ordered.index(source_lang) == 1:
                        sent1, sent2 = sent2, sent1

                    # Skip sentences that already appear in the dev set to avoid train/dev leakage
                    if sent1 in dev_sentences_set or sent2 in dev_sentences_set:
                        continue

                    if extract_dev_sentences:
                        dev_sentences.append([sent1, sent2])
                        dev_sentences_set.add(sent1)
                        dev_sentences_set.add(sent2)

                        if len(dev_sentences) >= num_dev_sentences:
                            extract_dev_sentences = False
                    else:
                        train_sentences.append([sent1, sent2])

            print("Write", len(dev_sentences), "dev sentences", filename_dev)
            with gzip.open(filename_dev, 'wt', encoding='utf8') as fOut:
                for sents in dev_sentences:
                    fOut.write("\t".join(sents))
                    fOut.write("\n")

            print("Write", len(train_sentences), "train sentences", filename_train)
            with gzip.open(filename_train, 'wt', encoding='utf8') as fOut:
                for sents in train_sentences:
                    fOut.write("\t".join(sents))
                    fOut.write("\n")

print("---DONE---")