"""
This script translates the queries of the MS MARCO dataset into the given target language.

For machine translation, we use EasyNMT: https://github.com/UKPLab/EasyNMT
You can install it via: pip install easynmt

Usage:
python translate_queries.py [target_language]
"""

import logging
import os
import sys
import tarfile

from sentence_transformers import LoggingHandler, util
from easynmt import EasyNMT

# Emit timestamped log records of level INFO and above through the
# LoggingHandler imported from sentence_transformers.
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)

# The target language is the only command-line argument. Exit with a usage
# hint instead of a bare IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit("Usage: python {} [target_language]".format(sys.argv[0]))

target_lang = sys.argv[1]
output_folder = 'multilingual-data'
data_folder = '../msmarco-data'

# Translations are written to (and, on a restart, resumed from) this
# language-pair specific TSV file.
output_filename = os.path.join(output_folder, 'train_queries.en-{}.tsv'.format(target_lang))
os.makedirs(output_folder, exist_ok=True)

# Resume support: gather the first tab-separated column (the query id) of
# every line already present in the output file, if that file exists.
translated_qids = set()
if os.path.exists(output_filename):
    with open(output_filename, 'r', encoding='utf8') as fIn:
        translated_qids.update(row.strip().split("\t")[0] for row in fIn)

os.makedirs(data_folder, exist_ok=True)

# Collect the ids of all training queries that still need a translation.
# The dict values are placeholders (None) until the query text is read.
train_queries = {}
qrels_train = os.path.join(data_folder, 'qrels.train.tsv')
if not os.path.exists(qrels_train):
    util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/qrels.train.tsv', qrels_train)

# Explicit encoding, consistent with the other file reads in this script.
with open(qrels_train, 'r', encoding='utf8') as fIn:
    for line in fIn:
        line = line.strip()
        if not line:
            # Tolerate blank lines instead of failing the 4-way unpack below.
            continue
        # Each record has four whitespace-separated columns; only the first
        # one (the query id) is needed here.
        qid, _, _, _ = line.split()
        if qid not in translated_qids:
            train_queries[qid] = None

# Make sure queries.train.tsv is present in the data folder: fetch the
# archive when it has not been downloaded yet, then unpack it.
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
queries_available = os.path.exists(queries_filepath)
if not queries_available:
    tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
    archive_cached = os.path.exists(tar_filepath)
    if not archive_cached:
        logging.info("Download queries.tar.gz")
        util.http_get(
            'https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz',
            tar_filepath,
        )

    # NOTE(review): extractall() unpacks every member of the archive and
    # trusts the member paths it contains - confirm the source is trusted.
    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

# Fill in the query text for every id that was registered in train_queries.
# Each line is expected to hold exactly two tab-separated fields: id and text.
with open(queries_filepath, 'r', encoding='utf8') as fIn:
    for raw_line in fIn:
        qid, query = raw_line.strip().split("\t")
        if qid not in train_queries:
            continue
        train_queries[qid] = query.strip()

# Keep only the ids for which a query text was found, and build the list of
# texts in the same order so both lists can be iterated in lockstep.
qids = [query_id for query_id, text in train_queries.items() if text is not None]
queries = [train_queries[query_id] for query_id in qids]

# Translation model used for all queries.
translation_model = EasyNMT('opus-mt')

start_message = "Start translation of {} queries.".format(len(queries))
print(start_message)
print("This can take a while. But you can stop this script at any point")

# Append to the output file so that an interrupted run can be resumed.
# Mode 'a' also creates the file when it does not exist yet, so no separate
# 'w' case is needed.
with open(output_filename, 'a', encoding='utf8') as fOut:
    translations = translation_model.translate_stream(
        queries,
        source_lang='en',
        target_lang=target_lang,
        beam_size=2,
        perform_sentence_splitting=False,
        chunk_size=256,
        batch_size=64,
    )
    for qid, translated_query in zip(qids, translations):
        # One row per query: "<qid>\t<translation>". The format string used to
        # have three placeholders for two arguments, which raised IndexError on
        # the first write. Tabs inside the translation are replaced so that
        # they cannot introduce extra columns.
        fOut.write("{}\t{}\n".format(qid, translated_query.replace("\t", " ")))
        # Flush after every row so that stopping the script loses as little
        # finished work as possible.
        fOut.flush()