import ujson
import codecs
import re
from rich import progress
import numpy as np 


def process_all_50_schemas(raw_schemas_file: str='./data/all_50_schemas', save_schemas_file: str=None) -> list[str]:
    '''
    Collect the distinct relation (predicate) names used in the prompt.
    '''
    lines = []
    with codecs.open(raw_schemas_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    
    scheme_list = []
    for line in lines:
        item = ujson.loads(line)
        scheme_list.append(
            item['predicate']
        ) 
    
    scheme_list = list(set(scheme_list))
    
    if save_schemas_file:
        with codecs.open(save_schemas_file, 'w', encoding='utf-8') as f:
            ujson.dump(scheme_list, f, indent=4, ensure_ascii=False)

    return scheme_list
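
# A minimal sketch of the expected input and output. The exact field layout of
# './data/all_50_schemas' is an assumption here (one JSON object per line with at
# least a 'predicate' key); only 'predicate' is actually read:
#
#   {"object_type": "人物", "predicate": "歌手", "subject_type": "歌曲"}
#
#   process_all_50_schemas()  ->  ['歌手', ...]   # de-duplicated, order not guaranteed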

def process_spo_list(text: str, spo_list: list, repair_song: bool=False):
    '''
    Normalize spo_list into records of the form {'s': subject, 'p': predicate, 'o': object};
    triples whose subject or object cannot be found in text are dropped.
    '''
    new_spo_list = []

    # find all names enclosed in Chinese book-title marks 《...》
    some_name = re.findall('《([^《》]*?)》', text)
    some_name = [n.strip() for n in some_name]
    
    # songs and albums, collected for the repair_song fix below
    song = []
    album = []
    for spo in spo_list:

        # fix annotation errors in s/o: strip surrounding 《》 marks and whitespace, lowercase
        s = spo['subject'].strip('《》').strip().lower()
        o = spo['object'].strip('《》').strip().lower()
        p = spo['predicate']
        
        # if s is a substring of a bracketed name, prefer the full name found by the regex,
        # e.g. text is '《造梦者---dreamer》' but the annotated subject is only '造梦者'
        for name in some_name:
            if s in name and text.count(s) == 1:
                s = name
        
        if repair_song:
            # '所属专辑' (belongs-to-album): record the song (subject) and its album (object)
            if p == '所属专辑':
                song.append(s)
                album.append(o)

        temp = {'s': s, 'p': p, 'o': o}

        # if the subject or object cannot be found in the text, drop this sample
        if text.find(s) == -1 or text.find(o) == -1:
            continue

        new_spo_list.append(temp)
    
    if repair_song:
        # drop 歌手(singer)/作词(lyricist)/作曲(composer) triples whose subject only
        # appears as an album name and never as a song
        ret_spo_list = []
        ps = ['歌手', '作词', '作曲']

        for spo in new_spo_list:
            s, p, o = spo['s'], spo['p'], spo['o']
            if p in ps and s in album and s not in song:
                continue
            ret_spo_list.append(spo)

        return ret_spo_list

    return new_spo_list
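
# A worked example on hypothetical data, showing how a triple is normalized
# (the sample text and annotation below are made up for illustration):
#
#   text = '《造梦者---dreamer》是由某人演唱的歌曲'
#   spo_list = [{'subject': '造梦者', 'predicate': '歌手', 'object': '某人'}]
#
#   process_spo_list(text, spo_list)
#   # -> [{'s': '造梦者---dreamer', 'p': '歌手', 'o': '某人'}]
#
# The subject is expanded to the full 《...》 name found in the text; any triple
# whose subject or object does not appear in the text would be dropped.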


def process_data(raw_data_file: str, train_file_name: str, dev_file_name: str, keep_max_length: int=512, repair_song: bool=True, dev_size: int=1000) -> None:
    '''
    Convert the raw data into prompt/response records.
    '''
    lines = []
    with codecs.open(raw_data_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    my_raw_data = []

    # build the full predicate list for reference (not currently inserted into the prompt text)
    schemas = process_all_50_schemas('./data/all_50_schemas')
    schemas = f"[{','.join(schemas)}]"
    for i, line in progress.track(enumerate(lines), total=len(lines)):

        tmp = ujson.decode(line)
        # prompt: "please extract all triples from the given sentence: <text>"
        text = f"请抽取出给定句子中的所有三元组。给定句子:{tmp['text'].lower()}"

        # response: stringified list of (subject, predicate, object) tuples
        spo_list = process_spo_list(tmp['text'].lower(), tmp['spo_list'], repair_song=repair_song)
        spo = f"{[(item['s'], item['p'], item['o']) for item in spo_list]}"
        # skip samples that are too long or that have no valid triples left
        if len(text) > keep_max_length or len(spo) > keep_max_length or len(spo_list) == 0:
            continue

        my_raw_data.append({
                'prompt': text, 
                'response':spo.replace('\'','').replace(' ', ''),
            })


    dev_data = []
    if dev_file_name is not None:
        # randomly sample dev_size examples (without replacement) for the eval split
        dev_index = np.random.choice(range(0, len(my_raw_data)), size=dev_size, replace=False)
        dev_index = set(dev_index)
        assert len(dev_index) == dev_size

        train_data = [x for i, x in enumerate(my_raw_data) if i not in dev_index]
        dev_data = [x for i, x in enumerate(my_raw_data) if i in dev_index]
        
        with codecs.open(dev_file_name, 'w', encoding='utf-8') as f:
            ujson.dump(dev_data, f, indent=4, ensure_ascii=False)
        
        my_raw_data = train_data

    print(f'length of train data {len(my_raw_data)}, length of eval data {len(dev_data)}')
    
    with codecs.open(train_file_name, 'w', encoding='utf-8') as f:
        ujson.dump(my_raw_data, f, indent=4, ensure_ascii=False)
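
# Roughly what one record in the saved train/eval JSON looks like (hypothetical
# sample; note that single quotes and spaces are stripped from the response):
#
#   {
#       "prompt": "请抽取出给定句子中的所有三元组。给定句子:《造梦者---dreamer》是由某人演唱的歌曲",
#       "response": "[(造梦者---dreamer,歌手,某人)]"
#   }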

if __name__ == '__main__':
    raw_data_file = './data/train_data.json'
    train_file = './data/my_train.json'
    dev_file = './data/my_eval.json'

    process_all_50_schemas('./data/all_50_schemas', './data/my_schemas.txt')

    process_data(raw_data_file, train_file, dev_file, keep_max_length=512, dev_size=1000)

    # use the dataset's public dev_data split as the test set
    process_data('./data/dev_data.json', train_file_name='./data/test.json', dev_file_name=None, keep_max_length=512, dev_size=1000)