import sys
sys.path.extend(['.', '..'])
import os
import re

import torch
import pandas as pd
import numpy as np
import ujson
from rich import progress
import pyarrow.parquet as pq

from model.infer import ChatBot
from logger import Logger
from config import PROJECT_ROOT, InferConfig
from utils.raw_data_process import delete_file

log = Logger('data_process', save2file=True, file_name=PROJECT_ROOT + '/logs/raw_data_process.log')

def process_alpaca_gpt4_data(max_len: int = 512) -> None:
    '''
    Process the high-quality (chosen) responses for reward modeling.
    Dataset: https://huggingface.co/datasets/c-s-ale/alpaca-gpt4-data-zh
    '''
    read_file = PROJECT_ROOT + '/data/raw_data/alpaca_gpt4_data_zh.json'
    save_file = PROJECT_ROOT + '/data/alpaca_gpt4_data_zh.json'

    max_len += 8  # leave room for special tokens such as [EOS]

    my_data = []
    with open(read_file, 'r', encoding='utf-8') as f:
        data = ujson.load(f)
        print('length of {} is {}'.format(read_file, len(data)))

        for item in progress.track(data):
            prompt = item['instruction']
            inputs = item['input']
            response = item['output']

            if len(response) > max_len: continue  # skip overly long responses

            if len(inputs.strip()) > 0:
                prompt = f"{prompt},{inputs}"

            if len(prompt) > max_len: continue
            if len(prompt) == 0 or len(response) == 0: continue

            my_data.append(
                {
                    'prompt': prompt,
                    'chosen': response
                }
            )

    print('length of {} is {}'.format(save_file, len(my_data)))
    with open(save_file, 'w', encoding='utf-8') as f:
        ujson.dump(my_data, f, indent=4, ensure_ascii=False)

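# Each record written by process_alpaca_gpt4_data has the shape
# {'prompt': str, 'chosen': str}; the rejected response for each pair is
# filled in later by generate_alpaca_gpt4_reject_response below.
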
def generate_alpaca_gpt4_reject_response(groups_cnt: int = 50000, max_len: int = 320, batch_size: int = 32) -> None:
    '''Generate less-satisfactory (rejected) responses with the model itself.
    '''
    print('load model...')

    # load config
    infer_config = InferConfig()
    chatbot = ChatBot(infer_config)
    model = chatbot.model
    tokenizer = chatbot.tokenizer
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    finetune_file = PROJECT_ROOT + '/data/alpaca_gpt4_data_zh.json'
    save_rw_json_file = PROJECT_ROOT + '/data/my_dpo_alpaca_gpt4_data_zh.json'
    # save_rw_parquet_file = PROJECT_ROOT + '/data/my_rlhf_dataset.parquet'

    data = []
    with open(finetune_file, 'r', encoding='utf-8') as f:
        data = ujson.load(f)

    log.info('length of {} is {}'.format(finetune_file, len(data)), save_to_file=True)

    model_outs = []
    batch_prompt = []
    process_item = []
    for i, item in progress.track(enumerate(data), total=len(data)):
        # the model's own greedy output serves as the rejected answer
        batch_prompt.append(f"{item['prompt']}[EOS]")
        process_item.append(item)

        if i % 500 == 0:
            print('process {} items.'.format(i))

        if len(batch_prompt) >= batch_size or i == len(data) - 1:
            encoded = tokenizer.batch_encode_plus(batch_prompt, truncation=False, padding=True)
            with torch.no_grad():
                input_ids = torch.LongTensor(encoded.input_ids).to(device)
                attention_mask = torch.LongTensor(encoded.attention_mask).to(device)

                outputs = model.my_generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_seq_len=infer_config.max_seq_len,
                    search_type='greedy',
                )

            outputs = tokenizer.batch_decode(outputs.cpu().numpy(), clean_up_tokenization_spaces=True, skip_special_tokens=True)
            model_outs.extend(outputs)
            batch_prompt = []

            # checkpoint roughly every 2000 processed items; `% 2000 == 0`
            # rarely holds when appending in batches, so test the remainder
            # against the batch size, and use `j` so the outer loop variable
            # `i` is not shadowed
            if len(model_outs) % 2000 < batch_size:
                for j in range(len(model_outs)):
                    process_item[j]['reject'] = model_outs[j]
                try:
                    with open(PROJECT_ROOT + '/data/outs.ckp.json', 'w', encoding='utf-8') as f:
                        ujson.dump(process_item, f, indent=4, ensure_ascii=False)
                except Exception as e:
                    print(e)

    for i in range(len(model_outs)):
        process_item[i]['reject'] = model_outs[i]

    with open(save_rw_json_file, 'w', encoding='utf-8') as f:
        ujson.dump(process_item, f, indent=4, ensure_ascii=False)

    # df = pd.DataFrame(data)
    # write_single_parquet_file(save_rw_parquet_file, df)

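# A minimal helper sketch (not part of the original pipeline) for inspecting
# the periodic checkpoint written by generate_alpaca_gpt4_reject_response;
# the path mirrors the one used inside that function.
def load_generate_checkpoint() -> list:
    '''Return the items saved in the last generation checkpoint, if any.'''
    ckp_file = PROJECT_ROOT + '/data/outs.ckp.json'
    if not os.path.exists(ckp_file):
        return []
    with open(ckp_file, 'r', encoding='utf-8') as f:
        return ujson.load(f)
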
def replace_line(s: str) -> str:
    '''Replace an escaped newline with a real one, i.e. turn the
    two-character sequence \\n into an actual newline character.
    '''
    return re.sub('\\\\n', '\n', s)
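# Example: replace_line('hello\\nworld') returns 'hello\nworld', i.e. the
# literal backslash + 'n' pair becomes an actual line break.
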
def merge_rlhf_data(max_len: int = 512) -> None:
    '''
    Merge the RLHF/DPO preference data into a single file.
    Datasets: https://huggingface.co/datasets/Skepsun/huozi_rlhf_data_json
              https://huggingface.co/datasets/beyond/rlhf-reward-single-round-trans_chinese
    '''
    my_data = []

    read_files = [
        PROJECT_ROOT + '/data/raw_data/huozi_rlhf_data.json',
        PROJECT_ROOT + '/data/my_dpo_alpaca_gpt4_data_zh.json',
    ]
    save_file = PROJECT_ROOT + '/data/my_dpo_data.json'
    if os.path.exists(save_file):
        assert delete_file(save_file)

    max_len += 8  # for eos token

    for read_file in read_files:
        items = []
        with open(read_file, 'r', encoding='utf-8') as f:
            items = ujson.load(f)

        for item in progress.track(items):
            prompt, chosen, reject = item['prompt'], item['chosen'], item['reject']

            if len(prompt) > max_len or len(chosen) > max_len or len(reject) > max_len:
                continue

            # also drop pairs where the rejected and chosen answers are
            # identical, i.e. reject.strip() == chosen.strip()
            if len(prompt) == 0 or len(chosen) == 0 or len(reject) == 0 or reject.strip() == chosen.strip():
                continue

            my_data.append({
                'prompt': replace_line(prompt),
                'chosen': replace_line(chosen),
                'rejected': replace_line(reject),
            })

    read_files = [
        PROJECT_ROOT + '/data/raw_data/train-00000-of-00001-789dc5dece0f1fc1.parquet',
        PROJECT_ROOT + '/data/raw_data/test-00000-of-00001-8ecd46436fadcf7f.parquet',
    ]
    for read_file in read_files:
        pf = pq.read_table(read_file)
        for prompt, chosen, rejected in progress.track(zip(pf['prompt'], pf['chosen'], pf['rejected']), total=pf.num_rows):
            prompt, chosen, rejected = prompt.as_py(), chosen.as_py(), rejected.as_py()

            if len(prompt) > max_len or len(chosen) > max_len or len(rejected) > max_len:
                continue
            if len(prompt) == 0 or len(chosen) == 0 or len(rejected) == 0 or rejected.strip() == chosen.strip():
                continue

            my_data.append({
                'prompt': replace_line(prompt),
                'chosen': replace_line(chosen),
                'rejected': replace_line(rejected),
            })

    print('length of {} is {}'.format(save_file, len(my_data)))
    with open(save_file, 'w', encoding='utf-8') as f:
        ujson.dump(my_data, f, indent=4, ensure_ascii=False)

def split_train_eval_dataset() -> None:
    '''Split the merged DPO data into train and eval sets (99% / 1%).
    '''
    rw_json_file = PROJECT_ROOT + '/data/my_dpo_data.json'
    train_file = PROJECT_ROOT + '/data/my_dpo_train.json'
    eval_file = PROJECT_ROOT + '/data/my_dpo_eval.json'

    data = []
    with open(rw_json_file, 'r', encoding='utf-8') as f:
        data = ujson.load(f)

    np.random.shuffle(data)
    split_idx = int(len(data) * 0.99)

    train_data = data[0: split_idx]
    eval_data = data[split_idx:]

    log.info('train size: {}, eval size: {}'.format(len(train_data), len(eval_data)), save_to_file=True)

    with open(train_file, 'w', encoding='utf-8') as f:
        ujson.dump(train_data, f, indent=4, ensure_ascii=False)

    with open(eval_file, 'w', encoding='utf-8') as f:
        ujson.dump(eval_data, f, indent=4, ensure_ascii=False)

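# A minimal sketch, assuming the Hugging Face `datasets` package is installed
# (it is not imported at the top of this script), of how the two split files
# could be loaded for DPO training:
def load_dpo_splits():
    '''Load the train/eval splits written by split_train_eval_dataset.'''
    from datasets import load_dataset
    return load_dataset(
        'json',
        data_files={
            'train': PROJECT_ROOT + '/data/my_dpo_train.json',
            'validation': PROJECT_ROOT + '/data/my_dpo_eval.json',
        },
    )
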
if __name__ == '__main__':
    # 1. process the chosen texts
    # process_alpaca_gpt4_data()

    # 2. generate the rejected texts
    # generate_alpaca_gpt4_reject_response()

    # 3. merge the datasets
    merge_rlhf_data()

    # 4. split train and eval dataset
    # split_train_eval_dataset()