{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Download the C-Eval dataset\n", "\n", "```bash\n", "mkdir ceval-data\n", "cd ceval-data\n", "wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip\n", "unzip ceval-exam.zip -d ceval-exam\n", "wget -P ceval-exam https://raw.githubusercontent.com/hkust-nlp/ceval/main/subject_mapping.json\n", "```" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dev\n", "subject_mapping.json\n", "test\n", "val\n" ] } ], "source": [ "! ls ceval-exam" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os, re\n", "import ujson\n", "import torch\n", "import pandas as pd\n", "from tqdm import tqdm\n", "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n", "from transformers.generation.configuration_utils import GenerationConfig\n", "from transformers.generation.utils import LogitsProcessorList, InfNanRemoveLogitsProcessor" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "ceval_dir = './ceval-exam'\n", "result_save_dir = './result'\n", "model_dir = '../model_save/dpo'  # model files are in the parent directory; use the DPO-tuned model\n", "\n", "if not os.path.exists(result_save_dir):\n", "    os.mkdir(result_save_dir)" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "subject_files = os.listdir(f\"{ceval_dir}/val\")\n", "subjects = [subject_file.replace('_val.csv', '') for subject_file in subject_files]\n", "\n", "subject_mapping = {}\n", "with open('./ceval-exam/subject_mapping.json', 'r', encoding='utf-8') as f:\n", "    subject_mapping = ujson.load(f)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "During the SFT stage, many samples with an `input` field were removed from this project's training data, and the model was not fine-tuned for question answering. Given a question directly, it explains the keywords in the question instead of answering. The C-Eval test therefore predicts the 'A', 'B', 'C', 'D' tokens.\n", "> However, sometimes, especially in zero-shot testing and with models that have not been instruction-tuned, the model may not understand the instruction well, or may not answer the question at all. In that case we recommend directly computing the probability that the next predicted token is \"A\", \"B\", \"C\", or \"D\" and taking the option with the highest probability as the answer.  \n", "> This is a constrained-decoding approach, and the official MMLU test code uses it. Note that this probability-based method does not apply to chain-of-thought testing.\n", "\n", "See: [How to test on C-Eval](https://github.com/hkust-nlp/ceval/blob/main/README_zh.md#如何在C-Eval上测试)\n", "\n", "Evaluation mode: zero-shot (chatbot mode)  \n", "The dev set is meant for few-shot prompting and is not used for now." ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def format_prompt(df: pd.Series) -> str:\n", "    '''\n", "    Format the 'question', 'A', 'B', 'C', 'D' fields of a row into a prompt.\n", "    '''\n", "    # Chinese prompt, matching the language of the C-Eval questions:\n", "    # 'Answer the single-choice question with just the letter A, B, C or D. Question: ... Answer options: ...'\n", "    prompt = f\"请回答单选题,回答字母A、B、C、D即可。问题:\\n{df['question']}\\n答案选项:\\n\"\n", "    for col in ['A', 'B', 'C', 'D']:\n", "        prompt += f\"{col}:{df[col]}\\n\"\n", "\n", "    return prompt" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['Accountant', '注册会计师', 'Other']" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "subject_mapping['accountant']" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 52/52 [00:00<00:00, 617.74it/s]\n" ] } ], "source": [ "do_test = False\n", "all_eval_items = []\n", "for i, subject_name in tqdm(enumerate(subjects), total=len(subjects)):\n", "    val_file = f\"{ceval_dir}/val/{subject_name}_val.csv\"\n", "    test_file = f\"{ceval_dir}/test/{subject_name}_test.csv\"\n", "\n", "    val_df = pd.read_csv(test_file) if do_test else pd.read_csv(val_file)\n", "\n", "    for idx, row in val_df.iterrows():\n", "        question = format_prompt(row)\n", "        answer = row['answer'] if 'answer' in val_df.columns else ''\n", "\n", "        item = {\n", "            'subject_en': subject_mapping[subject_name][0],\n", "            'subject_zh': subject_mapping[subject_name][1],\n", "            'category': subject_mapping[subject_name][2],  # one of STEM, Social Science, Humanities, Other\n", "            'question': question,\n", "            'answer': answer,\n", "        }\n", "\n", "        all_eval_items.append(item)" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": {
"text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "<thead>\n", "<tr style=\"text-align: right;\"><th></th><th>subject_en</th><th>subject_zh</th><th>category</th><th>question</th><th>answer</th></tr>\n", "</thead>\n", "<tbody>\n",
"<tr><th>0</th><td>Accountant</td><td>注册会计师</td><td>Other</td><td>请回答单选题,回答字母A、B、C、D即可。问题:\\n下列关于税法基本原则的表述中,不正确的是...</td><td>D</td></tr>\n",
"<tr><th>1</th><td>Accountant</td><td>注册会计师</td><td>Other</td><td>请回答单选题,回答字母A、B、C、D即可。问题:\\n甲公司是国内一家领先的新媒体、通信及移动...</td><td>C</td></tr>\n",
"<tr><th>2</th><td>Accountant</td><td>注册会计师</td><td>Other</td><td>请回答单选题,回答字母A、B、C、D即可。问题:\\n根据我国《印花税暂行条例》的规定,下列各...</td><td>D</td></tr>\n",
"<tr><th>3</th><td>Accountant</td><td>注册会计师</td><td>Other</td><td>请回答单选题,回答字母A、B、C、D即可。问题:\\n税务行政复议的申请人可以在得知税务机关作...</td><td>A</td></tr>\n",
"<tr><th>4</th><td>Accountant</td><td>注册会计师</td><td>Other</td><td>请回答单选题,回答字母A、B、C、D即可。问题:\\n关于战略管理表述错误的是____。\\n答...</td><td>C</td></tr>\n",
"</tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " subject_en subject_zh category \\\n", "0 Accountant 注册会计师 Other \n", "1 Accountant 注册会计师 Other \n", "2 Accountant 注册会计师 Other \n", "3 Accountant 注册会计师 Other \n", "4 Accountant 注册会计师 Other \n", "\n", " question answer \n", "0 请回答单选题,回答字母A、B、C、D即可。问题:\\n下列关于税法基本原则的表述中,不正确的是... D \n", "1 请回答单选题,回答字母A、B、C、D即可。问题:\\n甲公司是国内一家领先的新媒体、通信及移动... C \n", "2 请回答单选题,回答字母A、B、C、D即可。问题:\\n根据我国《印花税暂行条例》的规定,下列各... D \n", "3 请回答单选题,回答字母A、B、C、D即可。问题:\\n税务行政复议的申请人可以在得知税务机关作... A \n", "4 请回答单选题,回答字母A、B、C、D即可。问题:\\n关于战略管理表述错误的是____。\\n答... C " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "eval_df = pd.DataFrame(all_eval_items)\n", "eval_df.head(5)" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[872, 873, 884, 886]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the model\n", "tokenizer = AutoTokenizer.from_pretrained(model_dir)\n", "model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)\n", "\n", "generation_config = GenerationConfig()\n", "generation_config.remove_invalid_values = True  # automatically adds InfNanRemoveLogitsProcessor\n", "generation_config.eos_token_id = tokenizer.eos_token_id\n", "generation_config.pad_token_id = tokenizer.pad_token_id\n", "# for t5, set decoder_start_token_id = pad_token_id\n", "generation_config.decoder_start_token_id = tokenizer.pad_token_id\n", "generation_config.max_new_tokens = 16\n", "generation_config.num_beams = 1\n", "generation_config.do_sample = False  # greedy search\n", "\n", "choices = ['A', 'B', 'C', 'D']\n", "choices_ids = [tokenizer.convert_tokens_to_ids(c) for c in choices]\n", "choices_ids" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1346/1346 [00:20<00:00, 64.11it/s]\n" ] } ], "source": [ "batch_size = 32\n", "batch_data, batch_answers = [], []\n", "n = len(eval_df)\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "model.to(device)\n", "model.eval()\n", "\n", "for idx, row in tqdm(eval_df.iterrows(), total=n):\n", "    batch_data.append(row['question'])\n", "\n", "    # Run the model once a batch is full or the last row is reached\n", "    if len(batch_data) == batch_size or idx == n - 1:\n", "        torch.cuda.empty_cache()\n", "\n", "        encode_ids = tokenizer(batch_data, padding=True)\n", "        input_ids, attention_mask = torch.LongTensor(encode_ids['input_ids']), torch.LongTensor(encode_ids['attention_mask'])\n", "\n", "        outputs = model.generate(\n", "            input_ids=input_ids.to(device),\n", "            attention_mask=attention_mask.to(device),\n", "            generation_config=generation_config,\n", "            return_dict_in_generate=True,\n", "            output_scores=True,\n", "        )\n", "\n", "        scores = torch.stack(outputs['scores'], dim=1)\n", "        scores = torch.softmax(scores, dim=2)\n", "        scores = scores[..., 0, choices_ids]  # probabilities of A/B/C/D at the first generated token\n", "        choices_index = torch.argmax(scores, dim=1)\n", "\n", "        for i in choices_index:\n", "            batch_answers.append(choices[i])\n", "\n", "        batch_data = []" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "eval_df.insert(loc=5, column='model_predict', value=batch_answers)\n", "val_df = eval_df.copy(deep=True)" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "val_df['is_correct'] = val_df['model_predict'] == val_df['answer']\n", "val_df['is_correct'] = val_df['is_correct'].astype(pd.Int16Dtype())" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "<thead>\n", "<tr style=\"text-align: right;\"><th></th><th>subject_en</th><th>subject_zh</th><th>category</th><th>question</th><th>answer</th><th>model_predict</th><th>is_correct</th></tr>\n", "</thead>\n", "<tbody>\n",
"<tr><th>0</th><td>Accountant</td><td>注册会计师</td><td>Other</td><td>请回答单选题,回答字母A、B、C、D即可。问题:\\n下列关于税法基本原则的表述中,不正确的是...</td><td>D</td><td>A</td><td>0</td></tr>\n",
"<tr><th>1</th><td>Accountant</td><td>注册会计师</td><td>Other</td><td>请回答单选题,回答字母A、B、C、D即可。问题:\\n甲公司是国内一家领先的新媒体、通信及移动...</td><td>C</td><td>A</td><td>0</td></tr>\n",
"<tr><th>2</th><td>Accountant</td><td>注册会计师</td><td>Other</td><td>请回答单选题,回答字母A、B、C、D即可。问题:\\n根据我国《印花税暂行条例》的规定,下列各...</td><td>D</td><td>A</td><td>0</td></tr>\n",
"</tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " subject_en subject_zh category \\\n", "0 Accountant 注册会计师 Other \n", "1 Accountant 注册会计师 Other \n", "2 Accountant 注册会计师 Other \n", "\n", " question answer model_predict \\\n", "0 请回答单选题,回答字母A、B、C、D即可。问题:\\n下列关于税法基本原则的表述中,不正确的是... D A \n", "1 请回答单选题,回答字母A、B、C、D即可。问题:\\n甲公司是国内一家领先的新媒体、通信及移动... C A \n", "2 请回答单选题,回答字母A、B、C、D即可。问题:\\n根据我国《印花税暂行条例》的规定,下列各... D A \n", "\n", " is_correct \n", "0 0 \n", "1 0 \n", "2 0 " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val_df.head(3)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "<thead>\n", "<tr style=\"text-align: right;\"><th></th><th>is_correct</th></tr>\n", "<tr><th>category</th><th></th></tr>\n", "</thead>\n", "<tbody>\n",
"<tr><th>Humanities</th><td>63</td></tr>\n",
"<tr><th>Other</th><td>89</td></tr>\n",
"<tr><th>STEM</th><td>89</td></tr>\n",
"<tr><th>Social Science</th><td>72</td></tr>\n",
"</tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " is_correct\n", "category \n", "Humanities 63\n", "Other 89\n", "STEM 89\n", "Social Science 72" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Count the correct answers per category\n", "final_df = val_df.groupby('category')[['is_correct']].sum()\n", "final_df" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "<thead>\n", "<tr style=\"text-align: right;\"><th></th><th>is_correct</th><th>question_count</th><th>accuracy</th></tr>\n", "<tr><th>category</th><th></th><th></th><th></th></tr>\n", "</thead>\n", "<tbody>\n",
"<tr><th>Humanities</th><td>63</td><td>257</td><td>24.51%</td></tr>\n",
"<tr><th>Other</th><td>89</td><td>384</td><td>23.18%</td></tr>\n",
"<tr><th>STEM</th><td>89</td><td>430</td><td>20.70%</td></tr>\n",
"<tr><th>Social Science</th><td>72</td><td>275</td><td>26.18%</td></tr>\n",
"</tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " is_correct question_count accuracy\n", "category \n", "Humanities 63 257 24.51%\n", "Other 89 384 23.18%\n", "STEM 89 430 20.70%\n", "Social Science 72 275 26.18%" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "final_df['question_count'] = val_df.groupby('category')['question'].count()\n", "final_df['accuracy'] = final_df['is_correct'] / final_df['question_count']\n", "final_df['accuracy'] = final_df['accuracy'].apply(lambda x: format(x, '.2%'))\n", "final_df" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }