|
import sys
|
|
from tqdm import tqdm
|
|
from typing import List, Tuple
|
|
|
|
|
|
def remove_large_sentences(src_path: str, tgt_path: str) -> Tuple[int, List[str], List[str]]:
|
|
"""
|
|
Removes large sentences from a parallel dataset of source and target data.
|
|
|
|
Args:
|
|
src_path (str): path to the file containing the source language data.
|
|
tgt_path (str): path to the file containing the target language data.
|
|
|
|
Returns:
|
|
Tuple[int, List[str], List[str]]: a tuple of
|
|
- an integer representing the number of sentences removed
|
|
- a list of strings containing the source language data after removing large sentences
|
|
- a list of strings containing the target language data after removing large sentences
|
|
"""
|
|
count = 0
|
|
new_src_lines, new_tgt_lines = [], []
|
|
|
|
src_num_lines = sum(1 for line in open(src_path, "r", encoding="utf-8"))
|
|
tgt_num_lines = sum(1 for line in open(tgt_path, "r", encoding="utf-8"))
|
|
assert src_num_lines == tgt_num_lines
|
|
|
|
with open(src_path, encoding="utf-8") as f1, open(tgt_path, encoding="utf-8") as f2:
|
|
for src_line, tgt_line in tqdm(zip(f1, f2), total=src_num_lines):
|
|
src_tokens = src_line.strip().split(" ")
|
|
tgt_tokens = tgt_line.strip().split(" ")
|
|
|
|
if len(src_tokens) > 200 or len(tgt_tokens) > 200:
|
|
count += 1
|
|
continue
|
|
|
|
new_src_lines.append(src_line)
|
|
new_tgt_lines.append(tgt_line)
|
|
|
|
return count, new_src_lines, new_tgt_lines
|
|
|
|
|
|
def create_txt(out_file: str, lines: List[str]):
|
|
"""
|
|
Creates a text file and writes the given list of lines to file.
|
|
|
|
Args:
|
|
out_file (str): path to the output file to be created.
|
|
lines (List[str]): a list of strings to be written to the output file.
|
|
"""
|
|
add_newline = not "\n" in lines[0]
|
|
outfile = open("{}".format(out_file), "w", encoding="utf-8")
|
|
for line in lines:
|
|
if add_newline:
|
|
outfile.write(line + "\n")
|
|
else:
|
|
outfile.write(line)
|
|
outfile.close()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
src_path = sys.argv[1]
|
|
tgt_path = sys.argv[2]
|
|
new_src_path = sys.argv[3]
|
|
new_tgt_path = sys.argv[4]
|
|
|
|
count, new_src_lines, new_tgt_lines = remove_large_sentences(src_path, tgt_path)
|
|
print(f"{count} lines removed due to seq_len > 200")
|
|
create_txt(new_src_path, new_src_lines)
|
|
create_txt(new_tgt_path, new_tgt_lines)
|
|
|