{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 下载c-eavl数据集\n",
"\n",
"```bash\n",
"mkdir ceval-data\n",
"cd ceval-data\n",
"wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip \n",
"unzip ceval-exam.zip -d ceval-exam\n",
"wget https://github.com/hkust-nlp/ceval/blob/main/subject_mapping.json\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dev\n",
"subject_mapping.json\n",
"test\n",
"val\n"
]
}
],
"source": [
"! ls ceval-exam"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os, re\n",
"import ujson\n",
"import torch\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
"from transformers.generation.configuration_utils import GenerationConfig\n",
"from transformers.generation.utils import LogitsProcessorList, InfNanRemoveLogitsProcessor"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"ceval_dir = './ceval-exam'\n",
"result_save_dir = './result'\n",
"model_dir = '../model_save/dpo' # 模型文件在上一层目录,使用dpo后的模型\n",
"\n",
"if not os.path.exists(result_save_dir):\n",
" os.mkdir(result_save_dir)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"subject_files = os.listdir(f\"{ceval_dir}/val\")\n",
"subjects = [subjetc.replace('_val.csv', '') for subjetc in subject_files]\n",
"\n",
"subject_mapping = {}\n",
"with open('./ceval-exam/subject_mapping.json', 'r', encoding='utf-8') as f:\n",
" subject_mapping = ujson.load(f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"由于本项目的模型在sft阶段删除了很多带input的数据,且没有针对问题回答做微调,直接输入问题会解释问题中提到的关键词。所以c-eval测试使用预测 'A'、'B'、'C'、'D' token的方式。\n",
"> 然而有时候,特别是零样本测试和面对没有做过指令微调的模型时,模型可能无法很好的理解指令,甚至有时不会回答问题。这种情况下我们推荐直接计算下一个预测token等于\"A\", \"B\", \"C\", \"D\"的概率,然后以概率最大的选项作为答案 \n",
"> -- 这是一种受限解码生成的方法,MMLU的官方测试代码中是使用了这种方法进行测试。注意这种概率方法对思维链的测试不适用。\n",
"\n",
"见: [如何在C-Eval上测试](https://github.com/hkust-nlp/ceval/blob/main/README_zh.md#如何在C-Eval上测试)\n",
"\n",
"评测模式:zero-shot模式(chatbot/对话机器人模式) \n",
"dev数据集用来做few-shot,暂时不用"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def format_prompt(df: pd.Series) -> str:\n",
" '''\n",
" 将df中的 'question', 'A', 'B', 'C', 'D',格式化为问题\n",
" '''\n",
" prompt = f\"请回答单选题,回答字母A、B、C、D即可。问题:\\n{df['question']}\\n答案选项:\\n\"\n",
" for col in ['A', 'B', 'C', 'D']:\n",
" prompt += f\"{col}:{df[col]}\\n\"\n",
" \n",
" return prompt"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Accountant', '注册会计师', 'Other']"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"subject_mapping['accountant']"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 52/52 [00:00<00:00, 617.74it/s]\n"
]
}
],
"source": [
"do_test = False\n",
"all_eval_items = []\n",
"for i, subject_name in tqdm(enumerate(subjects), total=len(subjects)):\n",
" val_file = f\"{ceval_dir}/val/{subject_name}_val.csv\"\n",
" test_file = f\"{ceval_dir}/test/{subject_name}_test.csv\"\n",
"\n",
" val_df = pd.read_csv(test_file) if do_test else pd.read_csv(val_file)\n",
" \n",
" for idx, row in val_df.iterrows():\n",
" quesuton = format_prompt(row)\n",
" answer = row['answer'] if 'answer' in val_df.columns else '' \n",
"\n",
" item = {\n",
" 'subject_en': subject_mapping[subject_name][0],\n",
" 'subject_zh': subject_mapping[subject_name][1],\n",
" 'category': subject_mapping[subject_name][2], # 类别(STEM,Social Science,Humanities,Other四选一)\n",
" 'question': quesuton,\n",
" 'answer':answer,\n",
" }\n",
" \n",
" all_eval_items.append(item)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" subject_en | \n",
" subject_zh | \n",
" category | \n",
" question | \n",
" answer | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Accountant | \n",
" 注册会计师 | \n",
" Other | \n",
" 请回答单选题,回答字母A、B、C、D即可。问题:\\n下列关于税法基本原则的表述中,不正确的是... | \n",
" D | \n",
"
\n",
" \n",
" 1 | \n",
" Accountant | \n",
" 注册会计师 | \n",
" Other | \n",
" 请回答单选题,回答字母A、B、C、D即可。问题:\\n甲公司是国内一家领先的新媒体、通信及移动... | \n",
" C | \n",
"
\n",
" \n",
" 2 | \n",
" Accountant | \n",
" 注册会计师 | \n",
" Other | \n",
" 请回答单选题,回答字母A、B、C、D即可。问题:\\n根据我国《印花税暂行条例》的规定,下列各... | \n",
" D | \n",
"
\n",
" \n",
" 3 | \n",
" Accountant | \n",
" 注册会计师 | \n",
" Other | \n",
" 请回答单选题,回答字母A、B、C、D即可。问题:\\n税务行政复议的申请人可以在得知税务机关作... | \n",
" A | \n",
"
\n",
" \n",
" 4 | \n",
" Accountant | \n",
" 注册会计师 | \n",
" Other | \n",
" 请回答单选题,回答字母A、B、C、D即可。问题:\\n关于战略管理表述错误的是____。\\n答... | \n",
" C | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" subject_en subject_zh category \\\n",
"0 Accountant 注册会计师 Other \n",
"1 Accountant 注册会计师 Other \n",
"2 Accountant 注册会计师 Other \n",
"3 Accountant 注册会计师 Other \n",
"4 Accountant 注册会计师 Other \n",
"\n",
" question answer \n",
"0 请回答单选题,回答字母A、B、C、D即可。问题:\\n下列关于税法基本原则的表述中,不正确的是... D \n",
"1 请回答单选题,回答字母A、B、C、D即可。问题:\\n甲公司是国内一家领先的新媒体、通信及移动... C \n",
"2 请回答单选题,回答字母A、B、C、D即可。问题:\\n根据我国《印花税暂行条例》的规定,下列各... D \n",
"3 请回答单选题,回答字母A、B、C、D即可。问题:\\n税务行政复议的申请人可以在得知税务机关作... A \n",
"4 请回答单选题,回答字母A、B、C、D即可。问题:\\n关于战略管理表述错误的是____。\\n答... C "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eval_df = pd.DataFrame(all_eval_items)\n",
"eval_df.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[872, 873, 884, 886]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 加载模型\n",
"tokenizer = AutoTokenizer.from_pretrained(model_dir)\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)\n",
"\n",
"generation_config = GenerationConfig()\n",
"generation_config.remove_invalid_values = True # 自动添加InfNanRemoveLogitsProcessor\n",
"generation_config.eos_token_id = tokenizer.eos_token_id\n",
"generation_config.pad_token_id = tokenizer.pad_token_id\n",
"# for t5, set decoder_start_token_id = pad_token_id\n",
"generation_config.decoder_start_token_id = tokenizer.pad_token_id \n",
"generation_config.max_new_tokens = 16\n",
"generation_config.num_beams = 1\n",
"generation_config.do_sample = False # greedy search\n",
"\n",
"choices = ['A', 'B', 'C', 'D']\n",
"choices_ids = [tokenizer.convert_tokens_to_ids(c) for c in choices]\n",
"choices_ids"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1346/1346 [00:20<00:00, 64.11it/s]\n"
]
}
],
"source": [
"batch_size = 32\n",
"batch_data, batch_answers = [], []\n",
"n = len(eval_df)\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"model.to(device)\n",
"model.eval()\n",
"\n",
"for idx, row in tqdm(eval_df.iterrows(), total=n):\n",
" batch_data.append(row['question'])\n",
" \n",
" if len(batch_data) == batch_size or idx == n - 1:\n",
" torch.cuda.empty_cache()\n",
" \n",
" encode_ids = tokenizer(batch_data, padding=True)\n",
" input_ids, attention_mask = torch.LongTensor(encode_ids['input_ids']), torch.LongTensor(encode_ids['attention_mask'])\n",
" \n",
" outputs = model.generate(\n",
" input_ids=input_ids.to(device),\n",
" attention_mask=attention_mask.to(device),\n",
" generation_config=generation_config,\n",
" return_dict_in_generate=True,\n",
" output_scores=True,\n",
" )\n",
"\n",
" scores = torch.stack(outputs['scores'], dim=1)\n",
" scores = torch.softmax(scores, dim=2)\n",
" scores = scores[..., 0, choices_ids] #取第一个字符的ABCD概率\n",
" choices_index = torch.argmax(scores, dim=1)\n",
" \n",
" for i in choices_index:\n",
" batch_answers.append(choices[i])\n",
"\n",
" batch_data = []"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"eval_df.insert(loc=5, column='model_predict', value=batch_answers)\n",
"val_df = eval_df.copy(deep=True)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"val_df['is_correct'] = val_df['model_predict'] == val_df['answer']\n",
"val_df['is_correct'] = val_df['is_correct'].astype(pd.Int16Dtype())"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" subject_en | \n",
" subject_zh | \n",
" category | \n",
" question | \n",
" answer | \n",
" model_predict | \n",
" is_correct | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Accountant | \n",
" 注册会计师 | \n",
" Other | \n",
" 请回答单选题,回答字母A、B、C、D即可。问题:\\n下列关于税法基本原则的表述中,不正确的是... | \n",
" D | \n",
" A | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" Accountant | \n",
" 注册会计师 | \n",
" Other | \n",
" 请回答单选题,回答字母A、B、C、D即可。问题:\\n甲公司是国内一家领先的新媒体、通信及移动... | \n",
" C | \n",
" A | \n",
" 0 | \n",
"
\n",
" \n",
" 2 | \n",
" Accountant | \n",
" 注册会计师 | \n",
" Other | \n",
" 请回答单选题,回答字母A、B、C、D即可。问题:\\n根据我国《印花税暂行条例》的规定,下列各... | \n",
" D | \n",
" A | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" subject_en subject_zh category \\\n",
"0 Accountant 注册会计师 Other \n",
"1 Accountant 注册会计师 Other \n",
"2 Accountant 注册会计师 Other \n",
"\n",
" question answer model_predict \\\n",
"0 请回答单选题,回答字母A、B、C、D即可。问题:\\n下列关于税法基本原则的表述中,不正确的是... D A \n",
"1 请回答单选题,回答字母A、B、C、D即可。问题:\\n甲公司是国内一家领先的新媒体、通信及移动... C A \n",
"2 请回答单选题,回答字母A、B、C、D即可。问题:\\n根据我国《印花税暂行条例》的规定,下列各... D A \n",
"\n",
" is_correct \n",
"0 0 \n",
"1 0 \n",
"2 0 "
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"val_df.head(3)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" is_correct | \n",
"
\n",
" \n",
" category | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" Humanities | \n",
" 63 | \n",
"
\n",
" \n",
" Other | \n",
" 89 | \n",
"
\n",
" \n",
" STEM | \n",
" 89 | \n",
"
\n",
" \n",
" Social Science | \n",
" 72 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" is_correct\n",
"category \n",
"Humanities 63\n",
"Other 89\n",
"STEM 89\n",
"Social Science 72"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"final_df = val_df.groupby('category').sum('is_correct')\n",
"final_df"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" is_correct | \n",
" question_count | \n",
" accuracy | \n",
"
\n",
" \n",
" category | \n",
" | \n",
" | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" Humanities | \n",
" 63 | \n",
" 257 | \n",
" 24.51% | \n",
"
\n",
" \n",
" Other | \n",
" 89 | \n",
" 384 | \n",
" 23.18% | \n",
"
\n",
" \n",
" STEM | \n",
" 89 | \n",
" 430 | \n",
" 20.70% | \n",
"
\n",
" \n",
" Social Science | \n",
" 72 | \n",
" 275 | \n",
" 26.18% | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" is_correct question_count accuracy\n",
"category \n",
"Humanities 63 257 24.51%\n",
"Other 89 384 23.18%\n",
"STEM 89 430 20.70%\n",
"Social Science 72 275 26.18%"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"final_df['question_count'] = val_df.groupby('category').count()['question']\n",
"final_df['accuracy'] = final_df['is_correct'] / final_df['question_count']\n",
"final_df['accuracy'] = final_df['accuracy'] .apply(lambda x: format(x, '.2%'))\n",
"final_df"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py310",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}