File size: 2,172 Bytes
2359bda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
Given a tab-separated file (.tsv) with parallel sentences, where the second column is the translation of the sentence in the first column, for example, in the format:
src1    trg1
src2    trg2
...

where trg_i is the translation of src_i.

Given src_i, the TranslationEvaluator checks which trg_j has the highest similarity using cosine similarity. If i == j, we assume
a match, i.e., the correct translation has been found for src_i out of all possible target sentences.

It then computes an accuracy over all possible source sentences src_i. Equivalently, it computes also the accuracy for the other direction.

A high accuracy score indicates that the model is able to find the correct translation out of a large pool of sentences.

Usage:
python [model_name_or_path] [parallel-file1] [parallel-file2] ...

For example:
python distiluse-base-multilingual-cased  TED2020-en-de.tsv.gz

See the training_multilingual/get_parallel_data_...py scripts for getting parallel sentence data from different sources
"""

from sentence_transformers import SentenceTransformer, evaluation, LoggingHandler
import sys
import gzip
import os
import logging


# Configure root logging so evaluation progress is printed to the console.
# LoggingHandler comes from sentence_transformers (imported above).
LOG_FORMAT = '%(asctime)s - %(message)s'
LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

logging.basicConfig(
    format=LOG_FORMAT,
    datefmt=LOG_DATE_FORMAT,
    level=logging.INFO,
    handlers=[LoggingHandler()],
)

logger = logging.getLogger(__name__)

# --- Command-line arguments --------------------------------------------------
# argv[1]  : model name or path understood by SentenceTransformer
# argv[2:] : one or more .tsv / .tsv.gz files with parallel sentences
#
# Fail with a clear usage message instead of an IndexError traceback when the
# model argument is missing (matches the usage shown in the module docstring).
if len(sys.argv) < 2:
    print("Usage: python {} [model_name_or_path] [parallel-file1] [parallel-file2] ...".format(sys.argv[0]), file=sys.stderr)
    sys.exit(1)

model_name = sys.argv[1]
filepaths = sys.argv[2:]
inference_batch_size = 32  # batch size used when encoding sentences for evaluation

model = SentenceTransformer(model_name)


def _load_parallel_sentences(path):
    """Read a tab-separated parallel corpus (optionally gzip-compressed).

    Each line is expected to hold at least two tab-separated columns:
    source sentence and its translation. Lines with fewer than two
    columns are skipped.

    Returns:
        A pair of equal-length lists ``(src_sentences, trg_sentences)``.
    """
    src_sentences = []
    trg_sentences = []
    # gzip.open in 'rt' mode yields text lines just like the builtin open.
    open_fn = gzip.open if path.endswith('.gz') else open
    with open_fn(path, 'rt', encoding='utf8') as f_in:
        for line in f_in:
            columns = line.strip().split('\t')
            if len(columns) >= 2:
                src_sentences.append(columns[0])
                trg_sentences.append(columns[1])
    return src_sentences, trg_sentences


# Evaluate the model on every corpus given on the command line.
for filepath in filepaths:
    src_sentences, trg_sentences = _load_parallel_sentences(filepath)

    # Lazy %-style args: formatting happens only if the INFO level is enabled.
    logger.info("%s: %d sentence pairs", os.path.basename(filepath), len(src_sentences))

    dev_trans_acc = evaluation.TranslationEvaluator(
        src_sentences,
        trg_sentences,
        name=os.path.basename(filepath),
        batch_size=inference_batch_size,
    )
    dev_trans_acc(model)