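# Builds entity-replacement data from RedPajama Wikipedia chunks:
#   1. filter English Wikipedia documents out of the zstd-compressed chunks,
#   2. run a spaCy NER / entity-linking model over them,
#   3. swap each document's most frequent entity for a random entity that
#      carries the same NER label.
# Assumed invocation (the script name is illustrative; the original carries
# no usage notes):
#   python prepare_entity_data.py --input_dir /path/to/redpajama_chunks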
import spacy
import zstandard as zstd
import json
import typing
import os
from tqdm import tqdm
import multiprocessing
import random
from langdetect import detect, LangDetectException
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--input_dir', type=str, required=True,
                    help='Path to the directory containing the RedPajama chunk subdirectories')
args = parser.parse_args()
input_dir = args.input_dir

os.makedirs('result', exist_ok=True)  # every output below is written under result/


def is_english(text):
    """Return True if langdetect identifies the text as English."""
    try:
        return detect(text) == 'en'
    except LangDetectException:
        # langdetect raises on empty or otherwise undetectable input
        return False

def process_text(texts, model, out_f, lock):
    """Tag a batch of texts and append one JSON line per text, recording the
    most frequent entity, its NER label, and its knowledge-base id."""
    for text in texts:
        doc = model(text)
        freq_cnt = {}  # entity surface form -> [mention count, representative span]
        for e in doc.ents:
            if e.text not in freq_cnt:
                freq_cnt[e.text] = [0, e]
            freq_cnt[e.text][0] += 1
        if len(freq_cnt) == 0:
            continue
        # keep the span of the most frequently mentioned entity
        most_freq = max(freq_cnt.values(), key=lambda x: x[0])[1]
        data = {'text': text, 'main_entity': most_freq.text,
                'label': most_freq.label_, 'id': most_freq.kb_id_}
        json_data = json.dumps(data)
        with lock:
            out_f.write(json_data + '\n')
            out_f.flush()
            
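# Fan the texts out to worker processes in 1,000-document batches; every
# batch runs as its own concurrent process. Sharing the loaded spaCy
# pipeline, the open output file, and the lock with the children relies on
# fork-style process creation (the default on Linux); on spawn-based
# platforms each worker would need to load the model itself.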
def run_ner_linking(texts: typing.List[str], ner_model_path: str):
    nlp = spacy.load(ner_model_path)
    out_f = open('result/temp_store_data.json', 'w', encoding='utf-8')
    lock = multiprocessing.Lock()
    processes = []

    for i in tqdm(range(0, len(texts), 1000)):
        p = multiprocessing.Process(target=process_text,
                                    args=(texts[i:i + 1000], nlp, out_f, lock))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    out_f.close()

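# Step 1: pull English Wikipedia documents out of the zstd-compressed
# RedPajama chunks and stage them as one JSON object per line.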
wikipedia_out_path = 'result/wikipedia.json'
subdirectories = [f.path for f in os.scandir(input_dir) if f.is_dir()]
wikipedia_data = []
for sub_dir in subdirectories:
    zst_files = [f for f in os.listdir(sub_dir) if f.endswith('.zst')]
    for file in tqdm(zst_files):
        # decompress a whole chunk into memory, then parse it line by line
        with open(os.path.join(sub_dir, file), 'rb') as compressed_file:
            decompressor = zstd.ZstdDecompressor()
            with decompressor.stream_reader(compressed_file) as reader:
                decompressed_data = reader.read()
        for line in decompressed_data.splitlines():
            data = json.loads(line)
            # keep only English documents from the Wikipedia subset
            if data['meta']['redpajama_set_name'] == 'RedPajamaWikipedia':
                if is_english(data['text']):
                    wikipedia_data.append(data)
with open(wikipedia_out_path, 'w', encoding='utf-8') as f:
    for data in wikipedia_data:
        f.write(json.dumps(data) + '\n')
        
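# Step 2: reload the staged documents and run NER + entity linking over
# their texts, writing one record per document to temp_store_data.json.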
wikipedia_data = []
ner_model_path = 'kc-ner-model'
with open(wikipedia_out_path, 'r', encoding='utf-8') as f:
    for line in tqdm(f):
        data = json.loads(line)
        wikipedia_data.append(data['text'])
run_ner_linking(wikipedia_data, ner_model_path)

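# Step 3: build the entity-replacement dataset. entity_info.json is produced
# outside this script and is assumed to map entity KB ids to metadata that
# includes an 'aliases' list; `category` groups every observed main entity
# under its NER label so that replacements preserve the label.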
entity_info_path = 'result/entity_info.json'
with open(entity_info_path, 'r', encoding='utf-8') as f:
    entity_info = json.load(f)

category = {}  # NER label -> list of main-entity surface forms seen with it
all_data = []
with open('result/temp_store_data.json', 'r', encoding='utf-8') as f:
    for line in f:
        data = json.loads(line)
        all_data.append(data)
        if data['label'] not in category:
            category[data['label']] = []
        category[data['label']].append(data['main_entity'])
        
with open('result/processed_data.json', 'w', encoding='utf-8') as f:
    for data in tqdm(all_data):
        text = data['text']
        main_entity = [data['main_entity']]
        if data['id'] in entity_info:
            main_entity.extend(entity_info[data['id']]['aliases'])
        # candidate replacements: same-label entities that are neither the
        # main entity nor one of its aliases (keeping duplicates preserves
        # the frequency weighting of the original rejection sampling)
        candidates = [e for e in category[data['label']] if e not in main_entity]
        if not candidates:
            continue
        replaced_entity = random.choice(candidates)
        # swap every surface form of the main entity for the sampled entity
        for entity in main_entity:
            text = text.replace(entity, replaced_entity)
        data = {
            'text': text,
            'original_main_entity': main_entity,
            'replaced_entity': replaced_entity
        }
        f.write(json.dumps(data) + '\n')