# NOTE(review): removed web-page scrape residue (user avatar caption, commit
# message fragment, commit hash) that preceded the module docstring and made
# the file syntactically invalid Python.
"""
This script translates the queries in the MS MARCO dataset to the defined target languages.
For machine translation, we use EasyNMT: https://github.com/UKPLab/EasyNMT
You can install it via: pip install easynmt
Usage:
python translate_queries [target_language]
"""
import os
from sentence_transformers import LoggingHandler, util
import logging
import tarfile
from easynmt import EasyNMT
import sys
# Configure INFO-level logging to stdout, using the sentence-transformers
# LoggingHandler for formatting.
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    handlers=[LoggingHandler()],
)

# The target language is the first CLI argument (see module docstring).
target_lang = sys.argv[1]

# Translations are written into a local folder; the raw MS MARCO files live
# one directory level up.
output_folder = 'multilingual-data'
data_folder = '../msmarco-data'

os.makedirs(output_folder, exist_ok=True)
output_filename = os.path.join(output_folder, 'train_queries.en-{}.tsv'.format(target_lang))
# Resume support: if a previous run already wrote (part of) the output file,
# collect the query ids it contains so we can skip re-translating them.
translated_qids = set()
if os.path.exists(output_filename):
    with open(output_filename, 'r', encoding='utf8') as prev_file:
        translated_qids = {row.strip().split("\t")[0] for row in prev_file}
### Read the MS MARCO training qrels: every qid with a relevant passage that
### has not been translated yet becomes a candidate for translation.
os.makedirs(data_folder, exist_ok=True)

train_queries = {}
qrels_train = os.path.join(data_folder, 'qrels.train.tsv')
if not os.path.exists(qrels_train):
    util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/qrels.train.tsv', qrels_train)

with open(qrels_train) as qrels_file:
    for row in qrels_file:
        qid, _, pid, _ = row.strip().split()
        if qid in translated_qids:
            continue  # already translated in a previous run
        train_queries[qid] = None  # query text is filled in below
# Read all training queries, downloading and extracting the archive first if
# needed, and attach the query text to the qids collected from the qrels.
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
if not os.path.exists(queries_filepath):
    tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download queries.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath)
    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

with open(queries_filepath, 'r', encoding='utf8') as queries_file:
    for row in queries_file:
        qid, query = row.strip().split("\t")
        if qid in train_queries:
            train_queries[qid] = query.strip()

# Keep only qids whose text was actually found, preserving insertion order so
# qids[i] corresponds to queries[i].
qids = []
queries = []
for qid, text in train_queries.items():
    if text is not None:
        qids.append(qid)
        queries.append(text)
# Define our translation model (Helsinki-NLP OPUS-MT models via EasyNMT).
translation_model = EasyNMT('opus-mt')

print("Start translation of {} queries.".format(len(queries)))
print("This can take a while. But you can stop this script at any point")

# Append when resuming a previous run, otherwise start a fresh file.
with open(output_filename, 'a' if os.path.exists(output_filename) else 'w', encoding='utf8') as fOut:
    # translate_stream yields translations in input order, so zip pairs each
    # translation with its qid and original English query.
    for qid, query, translated_query in zip(qids, queries, translation_model.translate_stream(queries, source_lang='en', target_lang=target_lang, beam_size=2, perform_sentence_splitting=False, chunk_size=256, batch_size=64)):
        # BUGFIX: the format string has three placeholders but was given only
        # two arguments, raising IndexError on the very first write. The
        # intended TSV row is: qid, original query, translated query — with
        # tabs sanitized in both text columns so the TSV stays well-formed.
        fOut.write("{}\t{}\t{}\n".format(qid, query.replace("\t", " "), translated_query.replace("\t", " ")))
        # Flush per line so an interrupted run loses (at most) the current row.
        fOut.flush()