|
|
|
|
|
|
|
|
|
|
|
|
|
import itertools |
|
import os |
|
import csv |
|
from collections import defaultdict |
|
from six.moves import zip |
|
import io |
|
import wget |
|
import sys |
|
|
|
from subprocess import check_call, check_output |
|
|
|
|
|
# Helper tooling lives under ./utils; the Moses scripts are used for
# detokenization (see de_tok below).
CWD = os.getcwd()

UTILS = f"{CWD}/utils"

MOSES = f"{UTILS}/mosesdecoder"

# Root directory for downloads/extraction/output; must be supplied by the
# caller via the environment.
WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None)

if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip():
    # Fixed malformed message: stray trailing quote and "Exitting" typo.
    print('please specify your working directory root in OS environment variable WORKDIR_ROOT. Exiting...')
    sys.exit(-1)

# Path to the Moses perl detokenizer script.
detok_cmd = f'{MOSES}/scripts/tokenizer/detokenizer.perl'
|
|
|
|
|
def call(cmd):
    """Echo *cmd* and execute it through the shell.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    print(f"Executing: {cmd}")
    # shell=True is required: callers pass full shell command lines
    # (redirections, quoting) rather than argv lists.
    check_call(cmd, shell=True)
|
|
|
class MultiLingualAlignedCorpusReader(object):
    """Reader for the multilingual aligned TED talks dataset.

    Loads the train/test/dev splits from per-split ``all_talks_<split>.tsv``
    files, optionally prefixes every source sentence with a target-language
    token (``__<lang>__``), and drops sentence pairs that are empty or carry
    the ``NULL`` placeholder.
    """

    def __init__(self, corpus_path, delimiter='\t',
                 target_token=True, bilingual=True, corpus_type='file',
                 lang_dict=None,
                 eval_lang_dict=None, zero_shot=False,
                 detok=True,
                 ):
        """
        Args:
            corpus_path: directory containing ``all_talks_<split>.tsv`` files.
            delimiter: column delimiter of the corpus files.
            target_token: prepend ``__<target_lang>__`` to source sentences.
            bilingual: read the cartesian product of source/target languages.
            corpus_type: only ``'file'`` (one tsv per split) is implemented.
            lang_dict: ``{'source': [...], 'target': [...]}`` language lists;
                defaults to ``{'source': ['fr'], 'target': ['en']}``.
            eval_lang_dict: language lists used for dev/test in zero-shot mode.
            zero_shot: pair languages positionally instead of by product.
            detok: run the Moses detokenizer after saving files.
        """
        # Avoid the mutable-default-argument pitfall: the original used a
        # dict literal as the default; build it per-instance instead.
        if lang_dict is None:
            lang_dict = {'source': ['fr'], 'target': ['en']}

        self.empty_line_flag = 'NULL'
        self.corpus_path = corpus_path
        self.delimiter = delimiter
        self.bilingual = bilingual
        self.lang_dict = lang_dict
        self.lang_set = set()
        self.target_token = target_token
        self.zero_shot = zero_shot
        self.eval_lang_dict = eval_lang_dict
        self.corpus_type = corpus_type
        self.detok = detok

        for langs in self.lang_dict.values():
            self.lang_set.update(langs)

        # Eagerly read every split; read_file()/save_file() serve from here.
        self.data = dict()
        self.data['train'] = self.read_aligned_corpus(split_type='train')
        self.data['test'] = self.read_aligned_corpus(split_type='test')
        self.data['dev'] = self.read_aligned_corpus(split_type='dev')

    def read_data(self, file_loc_):
        """Return the stripped lines of the text file at *file_loc_*.

        Note: the original wrapped ``line.strip()`` in ``except IndexError``,
        which can never fire; the dead handler has been removed.
        """
        with io.open(file_loc_, 'r', encoding='utf8') as fp:
            return [line.strip() for line in fp]

    def filter_text(self, dict_):
        """Drop aligned pairs whose source or target is empty or NULL-marked.

        When a target token has been prepended, field_index=1 skips it before
        inspecting the source sentence.
        """
        field_index = 1 if self.target_token else 0
        data_dict = defaultdict(list)
        for sent1, sent2 in zip(dict_['source'], dict_['target']):
            # Slicing cannot raise IndexError, so no fallback is needed.
            src_sent = ' '.join(sent1.split()[field_index:])
            if self.empty_line_flag in src_sent or len(src_sent) == 0:
                continue
            if self.empty_line_flag in sent2 or len(sent2) == 0:
                continue
            data_dict['source'].append(sent1)
            data_dict['target'].append(sent2)
        return data_dict

    def read_file(self, split_type, data_type):
        """Return sentences for *split_type* ('train'/'test'/'dev') and *data_type* ('source'/'target')."""
        return self.data[split_type][data_type]

    def save_file(self, path_, split_type, data_type, lang):
        """Write one sentence per line to the '.tok.' variant of *path_*, then optionally detokenize."""
        tok_file = tok_file_name(path_, lang)
        with io.open(tok_file, 'w', encoding='utf8') as fp:
            for line in self.data[split_type][data_type]:
                fp.write(line + '\n')
        if self.detok:
            de_tok(tok_file, lang)

    def add_target_token(self, list_, lang_id):
        """Return a copy of *list_* with every sentence prefixed by ``__<lang_id>__``."""
        token = '__' + lang_id + '__'
        return [token + ' ' + sent for sent in list_]

    def read_from_single_file(self, path_, s_lang, t_lang):
        """Read the *s_lang*/*t_lang* columns of a per-split tsv file.

        Returns:
            (source_sentences, target_sentences); source sentences carry the
            target-language token when ``self.target_token`` is set.
        """
        data_dict = defaultdict(list)
        with io.open(path_, 'r', encoding='utf8') as fp:
            reader = csv.DictReader(fp, delimiter='\t', quoting=csv.QUOTE_NONE)
            for row in reader:
                data_dict['source'].append(row[s_lang])
                data_dict['target'].append(row[t_lang])

        if self.target_token:
            data_dict['source'] = self.add_target_token(data_dict['source'], t_lang)

        return data_dict['source'], data_dict['target']

    def read_aligned_corpus(self, split_type='train'):
        """Collect and filter all aligned sentence pairs for one split."""
        data_dict = defaultdict(list)
        iterable = []
        s_list = []
        t_list = []

        if self.zero_shot:
            # Zero-shot pairs languages positionally; dev/test use the
            # evaluation language lists.
            if split_type == "train":
                iterable = zip(self.lang_dict['source'], self.lang_dict['target'])
            else:
                iterable = zip(self.eval_lang_dict['source'], self.eval_lang_dict['target'])
        elif self.bilingual:
            iterable = itertools.product(self.lang_dict['source'], self.lang_dict['target'])

        for s_lang, t_lang in iterable:
            if s_lang == t_lang:
                continue
            if self.corpus_type == 'file':
                split_type_file_path = os.path.join(
                    self.corpus_path, "all_talks_{}.tsv".format(split_type))
                s_list, t_list = self.read_from_single_file(split_type_file_path,
                                                            s_lang=s_lang,
                                                            t_lang=t_lang)
            data_dict['source'] += s_list
            data_dict['target'] += t_list
        return self.filter_text(data_dict)
|
|
|
|
|
def read_langs(corpus_path):
    """Return the language codes present in the extracted TED dev tsv header.

    Reads ``<corpus_path>/extracted/all_talks_dev.tsv`` and returns every
    column name except ``talk_name``.
    """
    split_type_file_path = os.path.join(corpus_path, 'extracted',
                                        "all_talks_dev.tsv")
    with io.open(split_type_file_path, 'r', encoding='utf8') as fp:
        reader = csv.DictReader(fp, delimiter='\t', quoting=csv.QUOTE_NONE)
        # The original did next(reader).keys(): that consumes the first DATA
        # row just to reach the column names and raises StopIteration on a
        # header-only file. DictReader.fieldnames is the direct, safe way.
        return [k for k in reader.fieldnames if k != 'talk_name']
|
|
|
def extra_english(corpus_path, split):
    """Dump the English column of ``all_talks_<split>.tsv`` to ``all_talks_<split>.en`` and detokenize it."""
    in_path = os.path.join(corpus_path, f"all_talks_{split}.tsv")
    out_path = os.path.join(corpus_path, f"all_talks_{split}.en")
    with io.open(in_path, 'r', encoding='utf8') as fp, io.open(out_path, 'w', encoding='utf8') as fw:
        for row in csv.DictReader(fp, delimiter='\t', quoting=csv.QUOTE_NONE):
            fw.write(row['en'] + '\n')
    de_tok(out_path, 'en')
|
|
|
|
|
|
|
def tok_file_name(filename, lang):
    """Insert a ``tok`` component before the last extension of *filename*.

    ``train.fr-en.fr`` -> ``train.fr-en.tok.fr``; a name without any dot
    gets the component prepended (``name`` -> ``tok.name``). *lang* is
    accepted for interface compatibility but unused.
    """
    head, sep, tail = filename.rpartition('.')
    if not sep:
        # No extension: matches the original split/insert(-1) behavior.
        return 'tok.' + filename
    return f'{head}.tok.{tail}'
|
|
|
def de_tok(tok_file, lang):
    """Run the Moses detokenizer over *tok_file*.

    The output path is *tok_file* with its ``.tok.`` component removed; only
    the first two characters of *lang* are passed to the perl script.
    """
    de_tok_file = tok_file.replace('.tok.', '.')
    cmd = f'perl {detok_cmd} -l {lang[:2]} < {tok_file} > {de_tok_file}'
    call(cmd)
|
|
|
def extra_bitex(
    ted_data_path,
    lsrc_lang,
    ltrg_lang,
    target_token,
    output_data_path,
):
    """Extract one bilingual direction from the TED tsv files.

    Reads the aligned corpus for ``lsrc_lang -> ltrg_lang`` and writes
    train/test/valid source and target files into *output_data_path*,
    detokenizing each one.
    """
    def get_ted_lang(lang):
        # Map long-form codes to the 5-char codes TED uses for a few
        # locales; everything else collapses to the first two characters.
        long_langs = ['pt-br', 'zh-cn', 'zh-tw', 'fr-ca']
        if lang[:5] in long_langs:
            return lang[:5]
        if lang[:4] == 'calv':
            return lang[:5]
        if lang in ['pt_BR', 'zh_CN', 'zh_TW', 'fr_CA']:
            return lang.lower().replace('_', '-')
        return lang[:2]

    src_lang = get_ted_lang(lsrc_lang)
    trg_lang = get_ted_lang(ltrg_lang)

    reader = MultiLingualAlignedCorpusReader(
        corpus_path=ted_data_path,
        lang_dict={'source': [src_lang], 'target': [trg_lang]},
        target_token=target_token,
        corpus_type='file',
        eval_lang_dict={'source': [src_lang], 'target': [trg_lang]},
        zero_shot=False,
        bilingual=True,
    )

    os.makedirs(output_data_path, exist_ok=True)
    lsrc_lang = lsrc_lang.replace('-', '_')
    ltrg_lang = ltrg_lang.replace('-', '_')

    # Internal split 'dev' is written out under the 'valid' prefix.
    for split_type, prefix in (('train', 'train'), ('test', 'test'), ('dev', 'valid')):
        reader.save_file(f"{output_data_path}/{prefix}.{lsrc_lang}-{ltrg_lang}.{lsrc_lang}",
                         split_type=split_type, data_type='source', lang=src_lang)
        reader.save_file(f"{output_data_path}/{prefix}.{lsrc_lang}-{ltrg_lang}.{ltrg_lang}",
                         split_type=split_type, data_type='target', lang=trg_lang)
|
|
|
|
|
def bar_custom(current, total, width=80):
    """Progress-bar callback for ``wget.download``.

    Prints a carriage-return-terminated progress line. *width* is accepted
    for wget's callback interface but unused.
    """
    # Guard against total <= 0: wget reports 0/unknown content length for
    # some servers, which would raise ZeroDivisionError here.
    pct = current / total * 100 if total > 0 else 0
    print("Downloading: %d%% [%d / %d] Ks" % (pct, current / 1000, total / 1000), end='\r')
|
|
|
|
|
def download_and_extract(download_to, extract_to):
    """Download the TED talks tarball (if absent) and extract it (if absent).

    Skips the download when the tarball already exists under *download_to*
    and skips extraction when ``all_talks_train.tsv`` is already present
    under *extract_to*.
    """
    url = 'http://phontron.com/data/ted_talks.tar.gz'
    filename = f"{download_to}/ted_talks.tar.gz"
    if os.path.exists(filename):
        # Fixed: the literal placeholder "(unknown)" replaced the intended
        # {filename} interpolation here and in the tar command below.
        print(f'{filename} has already been downloaded so skip')
    else:
        filename = wget.download(url, filename, bar=bar_custom)
    if os.path.exists(f'{extract_to}/all_talks_train.tsv'):
        print(f'Already extracted so skip')
    else:
        extract_cmd = f'tar xzfv "{filename}" -C "{extract_to}"'
        call(extract_cmd)
|
|
|
|
|
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    # Data root; defaults to the WORKDIR_ROOT env var validated at import time.
    parser.add_argument('--ted_data_path', type=str, default=WORKDIR_ROOT, required=False)
    parser.add_argument(
        '--direction-list',
        type=str,
        # Comma-separated "src-tgt" direction pairs. NOTE(review): 'ka_GE-en_XX'
        # appears twice in the default, and since the default is a non-empty
        # string, the `is not None` fallback below is unreachable unless the
        # flag is passed explicitly — confirm both are intended.
        default=(
            "bn_IN-en_XX,he_IL-en_XX,fa_IR-en_XX,id_ID-en_XX,sv_SE-en_XX,pt_XX-en_XX,ka_GE-en_XX,ka_GE-en_XX,th_TH-en_XX,"
            "mr_IN-en_XX,hr_HR-en_XX,uk_UA-en_XX,az_AZ-en_XX,mk_MK-en_XX,gl_ES-en_XX,sl_SI-en_XX,mn_MN-en_XX,"
        ),
        required=False)
    parser.add_argument('--target-token', action='store_true', default=False)
    parser.add_argument('--extract-all-english', action='store_true', default=False)

    args = parser.parse_args()

    import sys
    import json

    ted_data_path = args.ted_data_path

    # Layout under the data root: downloads/ holds the tarball, extracted/
    # the tsv files, ML50/raw the generated bitext.
    download_to = f'{ted_data_path}/downloads'
    extract_to = f'{ted_data_path}/extracted'

    output_path = f'{ted_data_path}/ML50/raw'
    os.makedirs(download_to, exist_ok=True)
    os.makedirs(extract_to, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)
    download_and_extract(download_to, extract_to)

    if args.extract_all_english:
        # NOTE(review): the tsv files are extracted into extract_to, so this
        # call probably should receive extract_to rather than ted_data_path
        # (extra_bitex below gets extract_to) — confirm against the data layout.
        for split in ['train', 'dev', 'test']:
            extra_english(ted_data_path, split)
        exit(0)
    if args.direction_list is not None:
        directions = args.direction_list.strip().split(',')
        # split('-', 1) keeps codes like 'en_XX' intact; the `if d` filter
        # drops empty entries produced by the default's trailing comma.
        directions = [tuple(d.strip().split('-', 1)) for d in directions if d]
    else:
        # Fall back to English-centric directions for every language found
        # in the corpus header.
        langs = read_langs(ted_data_path)
        directions = [('en', tgt) for tgt in langs if tgt != 'en']
    print(f'num directions={len(directions)}: {directions}')

    # Extract and save bitext for each requested direction.
    for src_lang, trg_lang in directions:
        print('--working on {}-{}'.format(src_lang, trg_lang))
        extra_bitex(
            extract_to,
            src_lang,
            trg_lang,
            target_token=args.target_token,
            output_data_path=output_path
        )
|
|