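"""Build counterfactual Wikipedia texts from a RedPajama dump.

Pipeline:
  1. Decompress the .zst chunks under --input_dir and keep English documents
     from the 'RedPajamaWikipedia' subset.
  2. Run the spaCy NER / entity-linking model over those texts and record the
     most frequent entity of each document.
  3. Group entities by NER label and rewrite each document by swapping its
     main entity (and known aliases) for another entity of the same label.
"""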
import spacy
import zstandard as zstd
import json
import typing
import os
from tqdm import tqdm
import multiprocessing
import random
from langdetect import detect
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--input_dir', type=str, required=True, help='Path to the input directory')
args = parser.parse_args()
input_dir = args.input_dir
def is_english(text):
    # langdetect raises on empty or undetectable input; treat those as non-English.
    try:
        return detect(text) == 'en'
    except Exception:
        return False
def process_text(texts, model, out_f, lock):
    for text in texts:
        doc = model(text)
        # Count entity mentions by surface form and keep one representative
        # span per form so its label and KB id can be read back later.
        freq_cnt = {}
        spans = {}
        for e in doc.ents:
            freq_cnt[e.text] = freq_cnt.get(e.text, 0) + 1
            spans.setdefault(e.text, e)
        if len(freq_cnt) == 0:
            continue
        # Take the most frequent entity as the document's main entity.
        most_freq = spans[max(freq_cnt, key=freq_cnt.get)]
        data = {'text': text, 'main_entity': most_freq.text, 'label': most_freq.label_, 'id': most_freq.kb_id_}
        json_data = json.dumps(data)
        # Serialise writes so lines from different workers do not interleave.
        with lock:
            out_f.write(json_data + '\n')
            out_f.flush()
def run_ner_linking(texts: typing.List[str], ner_model_path: str):
    nlp = spacy.load(ner_model_path)
    out_f = open('result/temp_store_data.json', 'w', encoding='utf-8')
    lock = multiprocessing.Lock()
    processes = []
    # Each worker handles a slice of 1000 texts; the shared lock serialises
    # writes to the common output file.
    for i in tqdm(range(0, len(texts), 1000)):
        p = multiprocessing.Process(target=process_text, args=(texts[i:i + 1000], nlp, out_f, lock))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
    out_f.close()
    return
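# --- Step 1: extract English Wikipedia documents from the RedPajama .zst chunks ---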
wikipedia_out_path = 'result/wikipedia.json'
os.makedirs('result', exist_ok=True)
subdirectories = [f.path for f in os.scandir(input_dir) if f.is_dir()]
wikipedia_data = []
for sub_dir in subdirectories:
    chunk_dir = sub_dir + '/'
    zst_files = [f for f in os.listdir(chunk_dir) if f.endswith('.zst')]
    for file in tqdm(zst_files):
        with open(chunk_dir + file, 'rb') as compressed_file:
            decompressor = zstd.ZstdDecompressor()
            with decompressor.stream_reader(compressed_file) as reader:
                decompressed_data = reader.read()
                for line in decompressed_data.splitlines():
                    data = json.loads(line)
                    # Keep only English documents from the Wikipedia subset.
                    if data['meta']['redpajama_set_name'] == 'RedPajamaWikipedia':
                        if is_english(data['text']):
                            wikipedia_data.append(data)
with open(wikipedia_out_path, 'w', encoding='utf-8') as f:
    for data in wikipedia_data:
        f.write(json.dumps(data) + '\n')
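# --- Step 2: run NER / entity linking over the extracted texts ---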
wikipedia_data = []
ner_model_path = 'kc-ner-model'
with open(wikipedia_out_path, 'r', encoding='utf-8') as f:
    for line in tqdm(f):
        data = json.loads(line)
        wikipedia_data.append(data['text'])
run_ner_linking(wikipedia_data, ner_model_path)
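# --- Step 3: load entity alias info and group main entities by NER label ---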
entity_info_path = 'result/entity_info.json'
with open(entity_info_path, 'r', encoding='utf-8') as f:
    entity_info = json.load(f)
category = {}
all_data = []
with open('result/temp_store_data.json', 'r', encoding='utf-8') as f:
    for line in f:
        data = json.loads(line)
        all_data.append(data)
        # Group main entities by NER label so replacements stay within the same type.
        if data['label'] not in category:
            category[data['label']] = []
        category[data['label']].append(data['main_entity'])
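# --- Step 4: build counterfactual texts by swapping each document's main entity ---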
with open('result/processed_data.json', 'w', encoding='utf-8') as f:
    for data in tqdm(all_data):
        text = data['text']
        main_entity = [data['main_entity']]
        if data['id'] in entity_info:
            main_entity.extend(entity_info[data['id']]['aliases'])
        # Candidate replacements: same-label entities that are not the main
        # entity or one of its aliases (avoids resampling forever).
        candidates = [e for e in category[data['label']] if e not in main_entity]
        if len(candidates) == 0:
            continue
        replaced_entity = random.choice(candidates)
        # Replace the main entity and all of its aliases with the sampled entity.
        for entity in main_entity:
            text = text.replace(entity, replaced_entity)
        data = {
            'text': text,
            'original_main_entity': main_entity,
            'replaced_entity': replaced_entity
        }
        f.write(json.dumps(data) + '\n')