Spaces:
Runtime error
Runtime error
import ujson | |
import codecs | |
import re | |
from rich import progress | |
import numpy as np | |
def process_all_50_schemas(raw_schemas_file: str='./data/all_50_schemas', save_schemas_file: str=None) -> list[str]:
    '''
    Collect the distinct relation (predicate) names listed in the schema file.

    Every non-blank line of `raw_schemas_file` is parsed as one JSON object
    and its 'predicate' value is collected.

    Args:
        raw_schemas_file: path of the schema file, one JSON object per line.
        save_schemas_file: optional output path; when truthy, the collected
            list is also written to this file.

    Returns:
        The distinct predicates, in order of first appearance in the file.
    '''
    # A dict preserves insertion order, so de-duplicating through its keys
    # gives the same result on every run. list(set(...)) does not: str hashing
    # is randomized per interpreter process, so the order would vary.
    seen = {}
    with codecs.open(raw_schemas_file, 'r', encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # failing inside the JSON parser.
            if not line.strip():
                continue
            seen[ujson.loads(line)['predicate']] = None
    scheme_list = list(seen)

    if save_schemas_file:
        with codecs.open(save_schemas_file, 'w', encoding='utf-8') as f:
            # NOTE(review): this writes the Python repr of the list wrapped in
            # a single JSON string, not a JSON array. Kept unchanged so that
            # any existing reader of this file keeps working.
            ujson.dump(f"{scheme_list}", f, indent=4, ensure_ascii=False)

    return scheme_list
def process_spo_list(text: str, spo_list: list, repair_song: bool=False) -> list[dict]:
    '''
    Normalize annotated triples into {'s': subject, 'p': predicate, 'o': object} dicts.

    Subjects and objects are stripped of surrounding book-title marks (《》)
    and whitespace, then lower-cased. A triple is dropped when its subject or
    object is empty or does not occur in `text`.

    Args:
        text: the sentence the triples were annotated on. Matching is
            case-sensitive while entities are lower-cased here, so the caller
            is expected to pass lower-cased text.
        spo_list: dicts with 'subject', 'predicate' and 'object' string values.
        repair_song: when True, additionally drop '歌手' / '作词' / '作曲'
            triples whose subject was seen as an album (object of a '所属专辑'
            triple) but never as a song (subject of a '所属专辑' triple).

    Returns:
        The cleaned triples, in their original order.
    '''
    # Every title written between book-title marks in the sentence.
    some_name = [n.strip() for n in re.findall('《([^《》]*?)》', text)]

    # Subjects (songs) and objects (albums) of the '所属专辑' triples.
    song = set()
    album = set()

    new_spo_list = []
    for spo in spo_list:
        # Repair the annotation: remove surrounding book-title marks.
        s = spo['subject'].strip('《》').strip().lower()
        o = spo['object'].strip('《》').strip().lower()
        p = spo['predicate']

        # If the subject is part of a title found in the text, prefer the full
        # title, e.g. text '《造梦者---dreamer》' annotated with subject '造梦者'.
        # Only done when the subject occurs exactly once, so the match is
        # unambiguous.
        for name in some_name:
            if s in name and text.count(s) == 1:
                s = name

        if repair_song and p == '所属专辑':
            song.add(s)
            album.add(o)

        # Drop triples whose entities cannot be located in the text. The
        # explicit emptiness check matters: an empty string is "found" at
        # index 0 of any text, so it would otherwise slip through.
        if not s or not o or s not in text or o not in text:
            continue

        new_spo_list.append({'s': s, 'p': p, 'o': o})

    if not repair_song:
        return new_spo_list

    # Performer / lyricist / composer relations belong to songs; remove the
    # ones that were attached to an album name instead.
    ps = ('歌手', '作词', '作曲')
    return [
        spo for spo in new_spo_list
        if not (spo['p'] in ps and spo['s'] in album and spo['s'] not in song)
    ]
def process_data(raw_data_file: str, train_file_name: str, dev_file_name: str, keep_max_length: int=512, repair_song: bool=True, dev_size: int=1000) -> None:
    '''
    Convert the raw annotated corpus into prompt/response samples saved as JSON.

    Every non-blank line of `raw_data_file` is parsed as a JSON object with a
    'text' and a 'spo_list' entry. The text is lower-cased, its triples are
    cleaned with `process_spo_list`, and the sample is kept only if it still
    has at least one triple and neither the prompt nor the response string is
    longer than `keep_max_length` characters.

    Args:
        raw_data_file: path of the raw corpus, one JSON object per line.
        train_file_name: output path of the training samples.
        dev_file_name: output path of the held-out samples; None disables the
            split and writes every sample to `train_file_name`.
        keep_max_length: maximum allowed length of the prompt and of the response.
        repair_song: forwarded to `process_spo_list`.
        dev_size: number of samples drawn at random (without replacement) for
            the held-out set; ignored when `dev_file_name` is None.

    Raises:
        ValueError: if a split is requested and `dev_size` is negative or
            larger than the number of kept samples.
    '''
    with codecs.open(raw_data_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    my_raw_data = []
    for line in progress.track(lines, total=len(lines)):
        # Tolerate blank lines instead of failing inside the JSON parser.
        if not line.strip():
            continue

        tmp = ujson.loads(line)
        sentence = tmp['text'].lower()
        text = f"请抽取出给定句子中的所有三元组。给定句子:{sentence}"

        spo_list = process_spo_list(sentence, tmp['spo_list'], repair_song=repair_song)
        spo = f"{[(item['s'], item['p'], item['o']) for item in spo_list]}"

        # Drop samples that are too long or have no usable triple left.
        if len(text) > keep_max_length or len(spo) > keep_max_length or not spo_list:
            continue

        my_raw_data.append({
            'prompt': text,
            # The response is the list repr with quotes and spaces removed.
            'response': spo.replace('\'', '').replace(' ', ''),
        })

    dev_data = []
    if dev_file_name is not None:
        if not 0 <= dev_size <= len(my_raw_data):
            raise ValueError(
                f'dev_size must be between 0 and {len(my_raw_data)} '
                f'(the number of kept samples), got {dev_size}'
            )

        # Sampling without replacement yields exactly dev_size distinct indices.
        dev_index = set(
            np.random.choice(range(0, len(my_raw_data)), size=dev_size, replace=False).tolist()
        )

        train_data = [x for i, x in enumerate(my_raw_data) if i not in dev_index]
        dev_data = [x for i, x in enumerate(my_raw_data) if i in dev_index]

        with codecs.open(dev_file_name, 'w', encoding='utf-8') as f:
            ujson.dump(dev_data, f, indent=4, ensure_ascii=False)

        my_raw_data = train_data

    print(f'length of train data {len(my_raw_data)}, length of eval data {len(dev_data)}')

    with codecs.open(train_file_name, 'w', encoding='utf-8') as f:
        ujson.dump(my_raw_data, f, indent=4, ensure_ascii=False)
if __name__ == '__main__':
    # Export the de-duplicated relation list.
    process_all_50_schemas(
        raw_schemas_file='./data/all_50_schemas',
        save_schemas_file='./data/my_schemas.txt',
    )

    # Build the training set and a randomly held-out evaluation set from the
    # raw training corpus.
    process_data(
        raw_data_file='./data/train_data.json',
        train_file_name='./data/my_train.json',
        dev_file_name='./data/my_eval.json',
        keep_max_length=512,
        dev_size=1000,
    )

    # The dev_data published with the dataset is used as the test set, so no
    # evaluation split is taken from it.
    process_data(
        raw_data_file='./data/dev_data.json',
        train_file_name='./data/test.json',
        dev_file_name=None,
        keep_max_length=512,
        dev_size=1000,
    )