import spacy
import zstandard as zstd
import json
import typing
import os
from tqdm import tqdm
import multiprocessing
import random
from langdetect import detect
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--input_dir', type=str, help='Path to the input directory')
args = parser.parse_args()
input_dir = args.input_dir

# All outputs below are written under result/; create it so the open(..., 'w')
# calls do not fail on a missing directory.
os.makedirs('result', exist_ok=True)
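
# Example invocation (the script name here is illustrative):
#   python prepare_corpus.py --input_dir /path/to/chunks
# where input_dir holds subdirectories of zstd-compressed JSON-lines shards.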


def is_english(text):
    # langdetect raises LangDetectException on empty or undetectable input,
    # so treat any failure as "not English".
    try:
        return detect(text) == 'en'
    except Exception:
        return False


def process_text(texts, model, out_path, lock):
    # Each worker appends its results to the shared output file, guarded by the lock.
    with open(out_path, 'a', encoding='utf-8') as out_f:
        for text in texts:
            doc = model(text)
            # Count entity mentions by surface form and keep one representative span
            # so that .text, .label_ and .kb_id_ are still available afterwards.
            freq_cnt = {}
            spans = {}
            for e in doc.ents:
                freq_cnt[e.text] = freq_cnt.get(e.text, 0) + 1
                spans[e.text] = e
            if len(freq_cnt) == 0:
                continue
            # Most frequently mentioned entity in this document
            most_freq = spans[max(freq_cnt, key=freq_cnt.get)]
            data = {'text': text, 'main_entity': most_freq.text,
                    'label': most_freq.label_, 'id': most_freq.kb_id_}
            json_data = json.dumps(data)
            with lock:
                out_f.write(json_data + '\n')
                out_f.flush()
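
# Each record written by the workers is one JSON object per input document:
#   {"text": ..., "main_entity": <surface form>, "label": <entity type>, "id": <knowledge-base id>}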


def run_ner_linking(texts: typing.List[str], ner_model_path: str):
    nlp = spacy.load(ner_model_path)
    out_path = 'result/temp_store_data.json'
    # Truncate any previous output; the workers re-open the file in append mode.
    open(out_path, 'w', encoding='utf-8').close()
    lock = multiprocessing.Lock()
    processes = []

    # One worker process per chunk of 1000 texts.
    for i in tqdm(range(0, len(texts), 1000)):
        p = multiprocessing.Process(target=process_text, args=(texts[i:i+1000], nlp, out_path, lock))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    return
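
# Pipeline: (1) pull English Wikipedia documents out of the compressed corpus
# chunks, (2) run NER + entity linking over them to find each document's most
# frequent entity, (3) swap that entity (and its aliases) for another entity of
# the same type to build the corrupted corpus in result/processed_data.json.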
wikipedia_out_path = 'result/wikipedia.json'
subdirectories = [f.path for f in os.scandir(input_dir) if f.is_dir()]
wikipedia_data = []
for sub_dir in subdirectories:
    zst_files = [f for f in os.listdir(sub_dir) if f.endswith('.zst')]
    for file in tqdm(zst_files):
        with open(os.path.join(sub_dir, file), 'rb') as compressed_file:
            decompressor = zstd.ZstdDecompressor()
            with decompressor.stream_reader(compressed_file) as reader:
                decompressed_data = reader.read()
                for line in decompressed_data.splitlines():
                    data = json.loads(line)
                    # Keep only English documents from the Wikipedia subset.
                    if data['meta']['redpajama_set_name'] == 'RedPajamaWikipedia':
                        if is_english(data['text']):
                            wikipedia_data.append(data)

with open(wikipedia_out_path, 'w', encoding='utf-8') as f:
    for data in wikipedia_data:
        json_data = json.dumps(data)
        f.write(json_data + '\n')

# Re-load just the text of the filtered Wikipedia documents and run NER/linking.
wikipedia_data = []
ner_model_path = 'kc-ner-model'
with open(wikipedia_out_path, 'r', encoding='utf-8') as f:
    for line in tqdm(f):
        data = json.loads(line)
        wikipedia_data.append(data['text'])
run_ner_linking(wikipedia_data, ner_model_path)

entity_info_path = 'result/entity_info.json'
with open(entity_info_path, 'r', encoding='utf-8') as f:
    entity_info = json.load(f)
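# entity_info.json is assumed to map each knowledge-base id to a record that
# includes an 'aliases' list; only those aliases are used below.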

all_original_data = []

# Group the detected main entities by their NER label; these per-label pools
# are used to pick replacement entities of the same type.
category = {}
all_data = []
with open('result/temp_store_data.json', 'r', encoding='utf-8') as f:
    for line in f:
        data = json.loads(line)
        all_data.append(data)
        if data['label'] not in category:
            category[data['label']] = []
        category[data['label']].append(data['main_entity'])
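
# For each document, replace every occurrence of its main entity (and that
# entity's aliases) with a randomly chosen entity of the same label, and write
# the modified text to result/processed_data.json.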
with open('result/processed_data.json', 'w', encoding='utf-8') as f:
    for data in tqdm(all_data):
        text = data['text']
        # The main entity plus any aliases recorded for its knowledge-base id.
        main_entity = [data['main_entity']]
        if data['id'] in entity_info:
            main_entity.extend(entity_info[data['id']]['aliases'])
        # Candidate replacements: same-label entities that are not the main
        # entity or one of its aliases.
        candidates = [e for e in category[data['label']] if e not in main_entity]
        if len(candidates) == 0:
            continue
        replaced_entity = random.choice(candidates)
        for entity in main_entity:
            text = text.replace(entity, replaced_entity)
        data = {
            'text': text,
            'original_main_entity': main_entity,
            'replaced_entity': replaced_entity
        }
        json_data = json.dumps(data)
        f.write(json_data + '\n')
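
# Final artifacts: result/wikipedia.json (filtered English Wikipedia docs),
# result/temp_store_data.json (per-document main entity from NER + linking),
# and result/processed_data.json (the entity-swapped text).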