{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Read before you start:\n", "\n", "This notebook is a test demo of all the LLMs we trained during phase 2: Multi-Task Instruction Tuning.\n", "\n", "- LLMs: Llama2, Falcon, BLOOM, ChatGLM2, Qwen, MPT\n", "- Tasks: Sentiment Analysis, Headline Classification, Named Entity Recognition, Relation Extraction\n", "\n", "All the models and instruction data samples used here are also available in our Hugging Face organization: https://huggingface.co/FinGPT\n", "\n", "Models trained in phases 1 and 3 are not provided, as the multi-task (MT) models cover most of their capabilities. Feel free to train your own models for the tasks you need.\n", "\n", "Because the financial tasks and datasets we used have limited diversity, the models may not respond correctly to out-of-scope instructions. We will examine their generalization ability further in future work." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# First, choose whether to load data/models from Hugging Face or from local storage\n", "\n", "FROM_REMOTE = False" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2023-10-15 20:44:54,084] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" ] } ], "source": [ "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "from peft import PeftModel\n", "from utils import *" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import logging\n", "# Suppress warnings during inference\n", "logging.getLogger(\"transformers\").setLevel(logging.ERROR)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "demo_tasks = [\n", " 'Financial Sentiment Analysis',\n", " 'Financial Relation Extraction',\n", " 'Financial Headline Classification',\n", " 'Financial Named Entity Recognition',\n", "]\n", 
"demo_inputs = [\n", " \"Glaxo's ViiV Healthcare Signs China Manufacturing Deal With Desano\",\n", " \"Wednesday, July 8, 2015 10:30AM IST (5:00AM GMT) Rimini Street Comment on Oracle Litigation Las Vegas, United States Rimini Street, Inc., the leading independent provider of enterprise software support for SAP AG’s (NYSE:SAP) Business Suite and BusinessObjects software and Oracle Corporation’s (NYSE:ORCL) Siebel , PeopleSoft , JD Edwards , E-Business Suite , Oracle Database , Hyperion and Oracle Retail software, today issued a statement on the Oracle litigation.\",\n", " 'april gold down 20 cents to settle at $1,116.10/oz',\n", " 'Subject to the terms and conditions of this Agreement , Bank agrees to lend to Borrower , from time to time prior to the Commitment Termination Date , equipment advances ( each an \" Equipment Advance \" and collectively the \" Equipment Advances \").',\n", "]\n", "demo_instructions = [\n", " 'What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.',\n", " 'Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \"relation1: word1, word2; relation2: word3, word4\". Options: product/material produced, manufacturer, distributed by, industry, position held, original broadcaster, owned by, founded by, distribution format, headquarters location, stock exchange, currency, parent organization, chief executive officer, director/manager, owner of, operator, member of, employer, chairperson, platform, subsidiary, legal form, publisher, developer, brand, business division, location of formation, creator.',\n", " 'Does the news headline talk about price in the past? 
Please choose an answer from {Yes/No}.',\n", " 'Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.',\n", "]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def load_model(base_model, peft_model, from_remote=False):\n", "    # Resolve the base model name to a Hugging Face ID or a local path\n", "    model_name = parse_model_name(base_model, from_remote)\n", "\n", "    model = AutoModelForCausalLM.from_pretrained(\n", "        model_name, trust_remote_code=True,\n", "        device_map=\"auto\",\n", "    )\n", "    model.model_parallel = True\n", "\n", "    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n", "\n", "    # Left-pad so that generation continues directly from the end of the prompt\n", "    tokenizer.padding_side = \"left\"\n", "    if base_model == 'qwen':\n", "        # Set Qwen's EOS and pad tokens explicitly\n", "        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')\n", "        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|extra_0|>')\n", "    # Add a dedicated pad token when none exists or when it coincides with EOS\n", "    if not tokenizer.pad_token or tokenizer.pad_token_id == tokenizer.eos_token_id:\n", "        tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n", "        model.resize_token_embeddings(len(tokenizer))\n", "\n", "    # Attach the PEFT adapter and switch to inference mode\n", "    model = PeftModel.from_pretrained(model, peft_model)\n", "    model = model.eval()\n", "    return model, tokenizer\n", "\n", "\n", "def test_demo(model, tokenizer):\n", "\n", "    # input_text avoids shadowing the built-in input()\n", "    for task_name, input_text, instruction in zip(demo_tasks, demo_inputs, demo_instructions):\n", "        prompt = 'Instruction: {instruction}\\nInput: {input}\\nAnswer: '.format(\n", "            input=input_text,\n", "            instruction=instruction\n", "        )\n", "        # truncation=True is required for max_length to take effect\n", "        inputs = tokenizer(\n", "            prompt, return_tensors='pt',\n", "            padding=True, truncation=True, max_length=512,\n", "            return_token_type_ids=False\n", "        )\n", "        inputs = {key: value.to(model.device) for key, value in inputs.items()}\n", "        # Greedy decoding (do_sample=False) keeps the demo output deterministic\n", "        res = model.generate(\n", "            **inputs, max_length=512, do_sample=False,\n", "            eos_token_id=tokenizer.eos_token_id\n", "        )\n", "        output = tokenizer.decode(res[0], skip_special_tokens=True)\n", "        print(f\"\\n==== {task_name} ====\\n\")\n", "        print(output)\n", " " ] 
}, { "cell_type": "markdown", "metadata": {}, "source": [ "# Llama2-7B" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "application/json": { "ascii": false, "bar_format": null, "colour": null, "elapsed": 0.006228446960449219, "initial": 0, "n": 0, "ncols": null, "nrows": null, "postfix": null, "prefix": "Loading checkpoint shards", "rate": null, "total": 2, "unit": "it", "unit_divisor": 1000, "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { "model_id": "0d58aff745fb486780792c86384fe702", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00