|
import regex as re
|
|
import os
|
|
import sys
|
|
from collections import defaultdict
|
|
from tqdm import tqdm
|
|
|
|
def remove_overlaps(in_data_dir: str, out_data_dir: str, benchmark_dir: str):
|
|
"""
|
|
Removes overlapping sentences between train dataset and dev/test dataset from the
|
|
input directory and writes de-duplicated train data to the specified output directory.
|
|
|
|
Args:
|
|
in_data_dir (str): path of the directory containing train data for each language pair.
|
|
out_data_dir (str): path of the directory where the de-duplicated train data will be written for each language pair.
|
|
benchmark_dir (str): path of the directory containing the language-wise monolingual side of dev/test set.
|
|
"""
|
|
|
|
devtest_normalized = defaultdict(set)
|
|
for lang in os.listdir(benchmark_dir):
|
|
fname = os.path.join(benchmark_dir, lang)
|
|
|
|
with open(fname, "r") as f:
|
|
sents = [sent for sent in f.read().split("\n") if sent.strip()]
|
|
sents = [re.sub(" +", " ", sent).replace("\n", "").strip() for sent in sents]
|
|
sents = [re.sub(" +", " ", re.sub(r"[^\w\s]", "", x)).lower() for x in sents]
|
|
devtest_normalized[lang] = set(sents)
|
|
|
|
|
|
pairs = sorted(os.listdir(in_data_dir))
|
|
for pair in pairs:
|
|
print(pair)
|
|
src_lang, tgt_lang = pair.split("-")
|
|
|
|
src_infname = os.path.join(in_data_dir, pair, f"train.{src_lang}")
|
|
tgt_infname = os.path.join(in_data_dir, pair, f"train.{tgt_lang}")
|
|
|
|
src_outfname = os.path.join(out_data_dir, pair, f"train.{src_lang}")
|
|
tgt_outfname = os.path.join(out_data_dir, pair, f"train.{tgt_lang}")
|
|
|
|
os.makedirs(os.path.join(out_data_dir, pair), exist_ok=True)
|
|
|
|
|
|
with open(src_infname, 'r', encoding='utf-8') as src_infile, \
|
|
open(tgt_infname, 'r', encoding='utf-8') as tgt_infile, \
|
|
open(src_outfname, 'w', encoding='utf-8') as src_outfile, \
|
|
open(tgt_outfname, 'w', encoding='utf-8') as tgt_outfile:
|
|
|
|
for src_line, tgt_line in tqdm(zip(src_infile, tgt_infile)):
|
|
src_line = re.sub(" +", " ", src_line).replace("\n", "").strip()
|
|
tgt_line = re.sub(" +", " ", tgt_line).replace("\n", "").strip()
|
|
|
|
src_line_normalized = re.sub(" +", " ", re.sub(r"[^\w\s]", "", src_line)).lower()
|
|
tgt_line_normalized = re.sub(" +", " ", re.sub(r"[^\w\s]", "", tgt_line)).lower()
|
|
if src_line_normalized in devtest_normalized[src_lang] or tgt_line_normalized in devtest_normalized[tgt_lang]:
|
|
continue
|
|
|
|
src_outfile.write(src_line + "\n")
|
|
tgt_outfile.write(tgt_line + "\n")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
in_data_dir = sys.argv[1]
|
|
out_data_dir = sys.argv[2]
|
|
benchmark_dir = sys.argv[3]
|
|
|
|
os.makedirs(out_data_dir, exist_ok=True)
|
|
|
|
remove_overlaps(in_data_dir, out_data_dir, benchmark_dir)
|
|
|