"""
This script translates the queries in the MS MARCO dataset to the defined target languages.

For machine translation, we use EasyNMT: https://github.com/UKPLab/EasyNMT
You can install it via: pip install easynmt

Usage:
python translate_queries.py [target_language]
"""
import os
from sentence_transformers import LoggingHandler, util
import logging
import tarfile
from easynmt import EasyNMT
import sys

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

if len(sys.argv) < 2:
    print("Usage: python translate_queries.py [target_language]")
    sys.exit(1)

target_lang = sys.argv[1]
output_folder = 'multilingual-data'
data_folder = '../msmarco-data'

output_filename = os.path.join(output_folder, 'train_queries.en-{}.tsv'.format(target_lang))
os.makedirs(output_folder, exist_ok=True)


## Does the output file exist? If so, read it so we can resume the translation
translated_qids = set()
if os.path.exists(output_filename):
    with open(output_filename, 'r', encoding='utf8') as fIn:
        for line in fIn:
            splits = line.strip().split("\t")
            translated_qids.add(splits[0])

### Now we read the MS MARCO dataset
os.makedirs(data_folder, exist_ok=True)

# Read qrels file for relevant positives per query
train_queries = {}
qrels_train = os.path.join(data_folder, 'qrels.train.tsv')
if not os.path.exists(qrels_train):
    util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/qrels.train.tsv', qrels_train)

with open(qrels_train) as fIn:
    for line in fIn:
        qid, _, pid, _ = line.strip().split()
        if qid not in translated_qids:
            train_queries[qid] = None

# Read all queries
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
if not os.path.exists(queries_filepath):
    tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download queries.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)


with open(queries_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        if qid in train_queries:
            train_queries[qid] = query.strip()


qids = [qid for qid in train_queries if train_queries[qid] is not None]
queries = [train_queries[qid] for qid in qids]

# Define our translation model
translation_model = EasyNMT('opus-mt')

print("Start translation of {} queries.".format(len(queries)))
print("This can take a while, but you can stop this script at any point and resume later.")


# Open in append mode ('a' creates the file if it does not exist yet)
with open(output_filename, 'a', encoding='utf8') as fOut:
    translations = translation_model.translate_stream(queries,
                                                      source_lang='en',
                                                      target_lang=target_lang,
                                                      beam_size=2,
                                                      perform_sentence_splitting=False,
                                                      chunk_size=256,
                                                      batch_size=64)
    for qid, query, translated_query in zip(qids, queries, translations):
        # Write qid, original query, and translation; strip tabs so the TSV stays well-formed
        fOut.write("{}\t{}\t{}\n".format(qid, query, translated_query.replace("\t", " ")))
        fOut.flush()