{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# coding=utf-8\n", "from typing import Dict\n", "import time\n", "import pandas as pd\n", "\n", "import torch\n", "from datasets import Dataset, load_dataset\n", "from transformers import PreTrainedTokenizerFast, Seq2SeqTrainer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments\n", "from transformers.generation.configuration_utils import GenerationConfig" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import sys, os\n", "# add the project root (two directories up) to sys.path so that model/, config and utils/ are importable\n", "root = os.path.realpath('.').replace('\\\\','/').split('/')[0: -2]\n", "root = '/'.join(root)\n", "if root not in sys.path:\n", "    sys.path.append(root)\n", "\n", "from model.chat_model import TextToTextModel\n", "from config import SFTconfig, InferConfig, T5ModelConfig\n", "from utils.functions import get_T5_config\n", "\n", "os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def get_dataset(file: str, split: str, encode_fn: callable, encode_args: dict, cache_dir: str='.cache') -> Dataset:\n", "    \"\"\"\n", "    Load a JSON dataset and encode each prompt/response pair with encode_fn.\n", "    \"\"\"\n", "    dataset = load_dataset('json', data_files=file, split=split, cache_dir=cache_dir)\n", "\n", "    def merge_prompt_and_responses(sample: dict) -> Dict[str, str]:\n", "        # append an [EOS] token to mark the end of the sequence; generate() relies on it to stop\n", "        prompt = encode_fn(f\"{sample['prompt']}[EOS]\", **encode_args)\n", "        response = encode_fn(f\"{sample['response']}[EOS]\", **encode_args)\n", "        return {\n", "            'input_ids': prompt.input_ids,\n", "            'labels': response.input_ids,\n", "        }\n", "\n", "    dataset = dataset.map(merge_prompt_and_responses)\n", "    return dataset" ] },
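{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional sanity check of the encoding used above (a minimal sketch; it assumes\n", "# SFTconfig().tokenizer_dir points at a trained tokenizer that defines [EOS] and a\n", "# pad token, which is not guaranteed here). Before padding, input_ids should end\n", "# with the [EOS] id.\n", "check_tokenizer = PreTrainedTokenizerFast.from_pretrained(SFTconfig().tokenizer_dir)\n", "check_args = {'truncation': False, 'padding': 'max_length'}\n", "encoded = check_tokenizer.encode_plus('请抽取出给定句子中的所有三元组。[EOS]', **check_args)\n", "print(check_tokenizer.eos_token_id, len(encoded.input_ids), encoded.input_ids[:16])" ] },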
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def sft_train(config: SFTconfig) -> Seq2SeqTrainer:\n", "\n", "    # step 1: load the tokenizer\n", "    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)\n", "\n", "    # step 2: load the pre-trained model\n", "    model = None\n", "    if os.path.isdir(config.finetune_from_ckp_file):\n", "        # a directory: load with from_pretrained\n", "        model = TextToTextModel.from_pretrained(config.finetune_from_ckp_file)\n", "    else:\n", "        # a single checkpoint file: build the model, then load_state_dict\n", "        t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)\n", "        model = TextToTextModel(t5_config)\n", "        model.load_state_dict(torch.load(config.finetune_from_ckp_file, map_location='cpu'))  # map to CPU to avoid device-mismatch errors\n", "\n", "    # step 3: load the dataset\n", "    encode_args = {\n", "        'truncation': False,\n", "        'padding': 'max_length',\n", "    }\n", "\n", "    dataset = get_dataset(file=config.sft_train_file, encode_fn=tokenizer.encode_plus, encode_args=encode_args, split=\"train\")\n", "\n", "    # step 4: define the training arguments\n", "    # T5 is a sequence-to-sequence model, so use Seq2SeqTrainingArguments, DataCollatorForSeq2Seq and Seq2SeqTrainer;\n", "    # the SFT utilities on the Hugging Face site target (causal) language models instead\n", "    generation_config = GenerationConfig()\n", "    generation_config.remove_invalid_values = True\n", "    generation_config.eos_token_id = tokenizer.eos_token_id\n", "    generation_config.pad_token_id = tokenizer.pad_token_id\n", "    generation_config.decoder_start_token_id = tokenizer.pad_token_id\n", "    generation_config.max_new_tokens = 320\n", "    generation_config.repetition_penalty = 1.5\n", "    generation_config.num_beams = 1      # greedy search\n", "    generation_config.do_sample = False  # greedy search\n", "\n", "    training_args = Seq2SeqTrainingArguments(\n", "        output_dir=config.output_dir,\n", "        per_device_train_batch_size=config.batch_size,\n", "        auto_find_batch_size=True,  # shrink the batch size automatically to avoid OOM\n", "        gradient_accumulation_steps=config.gradient_accumulation_steps,\n", "        learning_rate=config.learning_rate,\n", "        logging_steps=config.logging_steps,\n", "        num_train_epochs=config.num_train_epochs,\n", "        optim=\"adafactor\",\n", "        report_to='tensorboard',\n", "        log_level='info',\n", "        save_steps=config.save_steps,\n", "        save_total_limit=3,\n", "        fp16=config.fp16,\n", "        logging_first_step=config.logging_first_step,\n", "        warmup_steps=config.warmup_steps,\n", "        seed=config.seed,\n", "        generation_config=generation_config,\n", "    )\n", "\n", "    # step 5: initialize the data collator\n", "    collator = DataCollatorForSeq2Seq(tokenizer, max_length=config.max_seq_len)\n", "\n", "    # step 6: define the trainer\n", "    trainer = Seq2SeqTrainer(\n", "        model=model,\n", "        args=training_args,\n", "        train_dataset=dataset,\n", "        eval_dataset=dataset,\n", "        tokenizer=tokenizer,\n", "        data_collator=collator,\n", "    )\n", "\n", "    # step 7: train\n", "    trainer.train(\n", "        # resume_from_checkpoint=True\n", "    )\n", "\n", "    loss_log = pd.DataFrame(trainer.state.log_history)\n", "    log_dir = './logs'\n", "    if not os.path.exists(log_dir):\n", "        os.mkdir(log_dir)\n", "    loss_log.to_csv(f\"{log_dir}/ie_task_finetune_log_{time.strftime('%Y%m%d-%H%M')}.csv\")\n", "\n", "    # step 8: save the model\n", "    trainer.save_model(config.output_dir)\n", "\n", "    return trainer" ] },
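{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Why DataCollatorForSeq2Seq: it pads 'input_ids' with the tokenizer's pad_token_id but\n", "# pads 'labels' with -100 (its label_pad_token_id default), which the cross-entropy loss\n", "# ignores. A minimal illustration with toy ids (again assuming the tokenizer loaded from\n", "# SFTconfig().tokenizer_dir has a pad token):\n", "demo_tokenizer = PreTrainedTokenizerFast.from_pretrained(SFTconfig().tokenizer_dir)\n", "demo_collator = DataCollatorForSeq2Seq(demo_tokenizer, max_length=512)\n", "demo_batch = demo_collator([\n", "    {'input_ids': [5, 6, 7], 'labels': [8, 9]},\n", "    {'input_ids': [5], 'labels': [8, 9, 10, 11]},\n", "])\n", "print(demo_batch['input_ids'])  # shorter row padded with pad_token_id\n", "print(demo_batch['labels'])     # shorter row padded with -100" ] },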
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "config = SFTconfig()\n", "config.finetune_from_ckp_file = InferConfig().model_dir\n", "config.sft_train_file = './data/my_train.json'\n", "config.output_dir = './model_save/ie_task'\n", "config.max_seq_len = 512\n", "config.batch_size = 16\n", "config.gradient_accumulation_steps = 4\n", "config.logging_steps = 20\n", "config.learning_rate = 5e-5\n", "config.num_train_epochs = 6\n", "config.save_steps = 3000\n", "config.warmup_steps = 1000\n", "print(config)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trainer = sft_train(config)" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys, os\n", "root = os.path.realpath('.').replace('\\\\','/').split('/')[0: -2]\n", "root = '/'.join(root)\n", "if root not in sys.path:\n", "    sys.path.append(root)\n", "import ujson, torch\n", "from rich import progress\n", "\n", "from model.infer import ChatBot\n", "from config import InferConfig\n", "from utils.functions import f1_p_r_compute\n", "inf_conf = InferConfig()\n", "inf_conf.model_dir = './model_save/ie_task/'\n", "bot = ChatBot(infer_config=inf_conf)" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[(傅淑云,民族,汉族),(傅淑云,出生地,上海),(傅淑云,出生日期,1915年)]\n" ] } ], "source": [ "ret = bot.chat('请抽取出给定句子中的所有三元组。给定句子:傅淑云,女,汉族,1915年出生,上海人')\n", "print(ret)" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[('傅淑云', '民族', '汉族'), ('傅淑云', '出生地', '上海'), ('傅淑云', '出生日期', '1915年')]\n" ] } ], "source": [ "def text_to_spo_list(sentence: str) -> list:\n", "    '''\n", "    Convert model output into a list of SPO triples; time complexity O(n).\n", "    '''\n", "    spo_list = []\n", "    sentence = sentence.replace(',', ',').replace('(', '(').replace(')', ')')  # normalize full-width punctuation\n", "\n", "    cur_txt, cur_spo, started = '', [], False\n", "    for char in sentence:\n", "        if char not in '[](),':\n", "            cur_txt += char\n", "        elif char == '(':\n", "            started = True\n", "            cur_txt, cur_spo = '', []\n", "        elif char == ',' and started and len(cur_txt) > 0 and len(cur_spo) < 3:\n", "            cur_spo.append(cur_txt)\n", "            cur_txt = ''\n", "        elif char == ')' and started and len(cur_txt) > 0 and len(cur_spo) == 2:\n", "            cur_spo.append(cur_txt)\n", "            spo_list.append(tuple(cur_spo))\n", "            cur_spo = []\n", "            cur_txt = ''\n", "            started = False\n", "    return spo_list\n", "print(text_to_spo_list(ret))" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "test_data = []\n", "with open('./data/test.json', 'r', encoding='utf-8') as f:\n", "    test_data = ujson.load(f)" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'prompt': '请抽取出给定句子中的所有三元组。给定句子:查尔斯·阿兰基斯(charles aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部',\n", "  'response': '[(查尔斯·阿兰基斯,出生地,圣地亚哥),(查尔斯·阿兰基斯,出生日期,1989年4月17日)]'},\n", " {'prompt': '请抽取出给定句子中的所有三元组。给定句子:《离开》是由张宇谱曲,演唱',\n", "  'response': '[(离开,歌手,张宇),(离开,作曲,张宇)]'}]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_data[0:2]" ] },
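{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A few hand-written checks of text_to_spo_list before the full evaluation run,\n", "# covering full-width punctuation and an incomplete triple.\n", "assert text_to_spo_list('[(离开,歌手,张宇)]') == [('离开', '歌手', '张宇')]\n", "assert text_to_spo_list('[(离开,歌手,张宇),(离开,作曲,张宇)]') == [('离开', '歌手', '张宇'), ('离开', '作曲', '张宇')]\n", "assert text_to_spo_list('[]') == []\n", "assert text_to_spo_list('[(只有两个,元素)]') == []  # triples with fewer than 3 parts are dropped\n", "print('text_to_spo_list sanity checks passed')" ] },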
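{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# f1_p_r_compute is imported from utils.functions and its implementation is not shown in\n", "# this notebook. As a reference for the evaluation below, here is a minimal sketch of the\n", "# usual micro-averaged precision/recall/F1 over sets of triples; this is an assumption\n", "# about its behavior, not the project's actual implementation.\n", "def micro_f1_p_r(pred_lists: list, target_lists: list) -> tuple:\n", "    pred_n, target_n, correct_n = 0, 0, 0\n", "    for preds, targets in zip(pred_lists, target_lists):\n", "        pred_n += len(preds)\n", "        target_n += len(targets)\n", "        correct_n += len(set(preds) & set(targets))  # exact-match triples\n", "    p = correct_n / pred_n if pred_n else 0.0\n", "    r = correct_n / target_n if target_n else 0.0\n", "    f1 = 2 * p * r / (p + r) if p + r else 0.0\n", "    return f1, p, r" ] },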
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "prompt_buffer, batch_size, n = [], 32, len(test_data)\n", "target_spo_list, predict_spo_list = [], []\n", "for i, item in progress.track(enumerate(test_data), total=n):\n", "    prompt_buffer.append(item['prompt'])\n", "    target_spo_list.append(\n", "        text_to_spo_list(item['response'])\n", "    )\n", "\n", "    if len(prompt_buffer) == batch_size or i == n - 1:\n", "        torch.cuda.empty_cache()\n", "        model_pred = bot.chat(prompt_buffer)\n", "        model_pred = [text_to_spo_list(item) for item in model_pred]\n", "        predict_spo_list.extend(model_pred)\n", "        prompt_buffer = []" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[('查尔斯·阿兰基斯', '出生地', '圣地亚哥'), ('查尔斯·阿兰基斯', '出生日期', '1989年4月17日')], [('离开', '歌手', '张宇'), ('离开', '作曲', '张宇')]] \n", "\n", "\n", " [[('查尔斯·阿兰基斯', '国籍', '智利'), ('查尔斯·阿兰基斯', '出生地', '智利圣地亚哥'), ('查尔斯·阿兰基斯', '出生日期', '1989年4月17日')], [('离开', '歌手', '张宇'), ('离开', '作曲', '张宇')]]\n" ] } ], "source": [ "print(target_spo_list[0:2], '\\n\\n\\n', predict_spo_list[0:2])" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "21636 21636\n" ] } ], "source": [ "print(len(predict_spo_list), len(target_spo_list))" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "f1: 0.74, precision: 0.75, recall: 0.73\n" ] } ], "source": [ "f1, p, r = f1_p_r_compute(predict_spo_list, target_spo_list)\n", "print(f\"f1: {f1:.2f}, precision: {p:.2f}, recall: {r:.2f}\")" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['你好,有什么我可以帮你的吗?',\n", " '[(江苏省赣榆海洋经济开发区,成立日期,2003年1月28日)]',\n", " '南方地区气候干燥,气候寒冷,冬季寒冷,夏季炎热,冬季寒冷的原因很多,可能是由于全球气候变暖导致的。\\n南方气候的变化可以引起天气的变化,例如气温下降、降雨增多、冷空气南下等。南方气候的变化可以促进气候的稳定,有利于经济发展和经济繁荣。\\n此外,南方地区的气候也可能受到自然灾害的影响,例如台风、台风、暴雨等,这些自然灾害会对南方气候产生影响。\\n总之,南方气候的变化是一个复杂的过程,需要综合考虑多方面因素,才能应对。']" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# also test the model's general chat ability\n", "bot.chat(['你好', '请抽取出给定句子中的所有三元组。给定句子:江苏省赣榆海洋经济开发区位于赣榆区青口镇临海而建,2003年1月28日,经江苏省人民政府《关于同意设立赣榆海洋经济开发区的批复》(苏政复〔2003〕14号)文件批准为全省首家省级海洋经济开发区,','如何看待最近南方天气突然变冷?'])" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }