# Source: Hugging Face repo upload (user fangshengren, commit f4fac26, "Upload 59 files").
# Original page chrome ("raw / history blame / 5.04 kB") was scrape residue, not code.
import ujson
import codecs
import re
from rich import progress
import numpy as np
def process_all_50_schemas(raw_schemas_file: str='./data/all_50_schemas', save_schemas_file: str=None) -> list[str]:
    '''
    Collect the de-duplicated predicate (relation) list used in prompts.

    Each line of `raw_schemas_file` is a JSON object with a 'predicate'
    field.  Returns the unique predicates as a sorted list; if
    `save_schemas_file` is given, the list is also written there as JSON.
    '''
    with codecs.open(raw_schemas_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # Deduplicate via a set comprehension; sort so the result (and the
    # prompt built from it) is deterministic across runs -- plain
    # list(set(...)) order varies between interpreter invocations.
    scheme_list = sorted({ujson.loads(line)['predicate'] for line in lines})

    if save_schemas_file:
        with codecs.open(save_schemas_file, 'w', encoding='utf-8') as f:
            # Bug fix: dump the list itself.  The old code dumped
            # f"{scheme_list}", i.e. one JSON *string* containing the
            # Python repr of the list, not a JSON array.
            ujson.dump(scheme_list, f, indent=4, ensure_ascii=False)

    return scheme_list
def process_spo_list(text: str, spo_list: list, repair_song: bool=False):
    '''
    Normalize a raw spo_list into [{'s': subject, 'p': predicate, 'o': object}, ...].

    Subjects/objects are lower-cased and stripped of surrounding 《》 marks;
    a subject that is a unique substring of a 《...》 title in `text` is
    expanded to the full title (e.g. text has 《造梦者---dreamer》 but the
    annotation only says 造梦者).  Triples whose subject or object cannot be
    located in `text` are dropped.  With `repair_song=True`, triples for the
    predicates 歌手/作词/作曲 whose subject only ever appears as an album
    (never as a song) are filtered out.
    '''
    # All names wrapped in Chinese book-title marks found in the text.
    titled_names = [m.strip() for m in re.findall('《([^《》]*?)》', text)]

    songs, albums = [], []
    cleaned = []
    for triple in spo_list:
        # Normalize both entities: drop stray title marks, trim, lower-case.
        subj = triple['subject'].strip('《》').strip().lower()
        obj = triple['object'].strip('《》').strip().lower()
        pred = triple['predicate']

        # Prefer the full 《...》 title when the annotated subject is a
        # unique substring of it.
        for full_name in titled_names:
            if subj in full_name and text.count(subj) == 1:
                subj = full_name

        # Record song/album pairs before the validity check, matching the
        # original bookkeeping order.
        if repair_song and pred == '所属专辑':
            songs.append(subj)
            albums.append(obj)

        # Discard triples whose entities cannot be found in the text.
        if subj not in text or obj not in text:
            continue
        cleaned.append({'s': subj, 'p': pred, 'o': obj})

    if not repair_song:
        return cleaned

    # A subject that only ever appears as an album cannot be the performer,
    # lyricist or composer -- drop those triples.
    suspect_preds = ('歌手', '作词', '作曲')
    return [
        t for t in cleaned
        if not (t['p'] in suspect_preds and t['s'] in albums and t['s'] not in songs)
    ]
def process_data(raw_data_file: str, train_file_name: str, dev_file_name: str, keep_max_length: int=512, repair_song: bool=True, dev_size: int=1000) -> None:
    '''
    Convert the raw dataset into prompt:response pairs and write them to disk.

    Each input line is a JSON object with 'text' and 'spo_list'.  The prompt
    asks for all triples in the sentence; the response is the compact string
    form of the cleaned triple list.  Samples that are too long or have no
    locatable triples are dropped.  When `dev_file_name` is not None,
    `dev_size` samples are randomly split off into a dev file and the rest
    go to `train_file_name`.
    '''
    with codecs.open(raw_data_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    my_raw_data = []
    # NOTE(review): `schemas` is computed but never used below -- it was
    # presumably meant to be embedded in the prompt.  Kept so the schema
    # file is still read/validated; confirm intent before removing.
    schemas = process_all_50_schemas('./data/all_50_schemas')
    schemas = f"[{','.join(schemas)}]"

    for line in progress.track(lines, total=len(lines)):
        item = ujson.decode(line)
        text = f"请抽取出给定句子中的所有三元组。给定句子:{item['text'].lower()}"
        spo_list = process_spo_list(item['text'].lower(), item['spo_list'], repair_song=repair_song)
        spo = f"{[(t['s'], t['p'], t['o']) for t in spo_list]}"

        # Drop overlong samples and samples with no usable entity triples.
        if len(text) > keep_max_length or len(spo) > keep_max_length or len(spo_list) == 0:
            continue
        my_raw_data.append({
            'prompt': text,
            # Strip quotes and spaces so the response is a compact [(s,p,o),...] string.
            'response': spo.replace('\'', '').replace(' ', ''),
        })

    dev_data = []  # renamed from the original 'dev_date' typo (local only)
    if dev_file_name is not None:
        # Unseeded random split; runs are not reproducible -- seed
        # np.random upstream if determinism is required.
        dev_index = set(np.random.choice(range(len(my_raw_data)), size=dev_size, replace=False))
        assert len(dev_index) == dev_size
        train_data = [x for i, x in enumerate(my_raw_data) if i not in dev_index]
        dev_data = [x for i, x in enumerate(my_raw_data) if i in dev_index]
        with codecs.open(dev_file_name, 'w', encoding='utf-8') as f:
            ujson.dump(dev_data, f, indent=4, ensure_ascii=False)
        my_raw_data = train_data

    print(f'length of train data {len(my_raw_data)}, length of eval data {len(dev_data)}')
    with codecs.open(train_file_name, 'w', encoding='utf-8') as f:
        ujson.dump(my_raw_data, f, indent=4, ensure_ascii=False)
if __name__ == '__main__':
    # Input/output paths for the training split.
    raw_data_file = './data/train_data.json'
    train_file = './data/my_train.json'
    dev_file = './data/my_eval.json'
    # Save the de-duplicated relation schema list alongside the data.
    process_all_50_schemas('./data/all_50_schemas', './data/my_schemas.txt')
    # Build the train/eval splits from the raw training data.
    process_data(raw_data_file, train_file, dev_file, keep_max_length=512, dev_size=1000)
    # Use the dataset's public dev_data as the test set (no dev split here,
    # hence dev_file_name=None).
    process_data('./data/dev_data.json', train_file_name='./data/test.json', dev_file_name=None, keep_max_length=512, dev_size=1000)