{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "ed9bad4c-b546-43cd-b11d-39da03e3b2fc", "metadata": { "execution": { "iopub.execute_input": "2023-11-25T03:08:25.222203Z", "iopub.status.busy": "2023-11-25T03:08:25.221934Z", "iopub.status.idle": "2023-11-25T03:09:12.123983Z", "shell.execute_reply": "2023-11-25T03:09:12.123211Z", "shell.execute_reply.started": "2023-11-25T03:08:25.222184Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cpu:\n", "Collecting pandas\n", " Downloading pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)\n", " |████████████████████████████████| 12.4 MB 9.3 MB/s \n", "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /opt/pytorch/lib/python3.8/site-packages (from pandas) (2.8.2)\n", "Requirement already satisfied: numpy>=1.20.3 in /opt/pytorch/lib/python3.8/site-packages (from pandas) (1.21.6)\n", "Requirement already satisfied: pytz>=2020.1 in /opt/pytorch/lib/python3.8/site-packages (from pandas) (2023.3)\n", "Collecting tzdata>=2022.1\n", " Downloading tzdata-2023.3-py2.py3-none-any.whl (341 kB)\n", " |████████████████████████████████| 341 kB 89.1 MB/s \n", "\u001b[?25hRequirement already satisfied: six>=1.5 in /opt/pytorch/lib/python3.8/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n", "Installing collected packages: tzdata, pandas\n", "Successfully installed pandas-2.0.3 tzdata-2023.3\n", "Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cpu:\n", "Collecting scikit-learn\n", " Downloading scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.1 MB)\n", " |████████████████████████████████| 11.1 MB 9.1 MB/s \n", "\u001b[?25hCollecting threadpoolctl>=2.0.0\n", " Downloading threadpoolctl-3.2.0-py3-none-any.whl (15 kB)\n", "Collecting joblib>=1.1.1\n", " Downloading joblib-1.3.2-py3-none-any.whl (302 kB)\n", " |████████████████████████████████| 302 kB 71.4 MB/s \n", "\u001b[?25hCollecting scipy>=1.5.0\n", " Downloading scipy-1.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.5 MB)\n", " |████████████████████████████████| 34.5 MB 70.3 MB/s \n", "\u001b[?25hRequirement already satisfied: numpy<2.0,>=1.17.3 in /opt/pytorch/lib/python3.8/site-packages (from scikit-learn) (1.21.6)\n", "Installing collected packages: threadpoolctl, scipy, joblib, scikit-learn\n", "Successfully installed joblib-1.3.2 scikit-learn-1.3.2 scipy-1.10.1 threadpoolctl-3.2.0\n", "Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cpu:\n", "Collecting datasets\n", " Downloading datasets-2.15.0-py3-none-any.whl (521 kB)\n", " |████████████████████████████████| 521 kB 8.7 MB/s \n", "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /opt/pytorch/lib/python3.8/site-packages (from datasets) (1.21.6)\n", "Collecting fsspec[http]<=2023.10.0,>=2023.1.0\n", " Downloading fsspec-2023.10.0-py3-none-any.whl (166 kB)\n", " |████████████████████████████████| 166 kB 31.8 MB/s \n", "\u001b[?25hCollecting pyarrow-hotfix\n", " Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)\n", "Collecting dill<0.3.8,>=0.3.0\n", " Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n", " |████████████████████████████████| 115 kB 29.8 MB/s \n", "\u001b[?25hRequirement already satisfied: requests>=2.19.0 in /opt/pytorch/lib/python3.8/site-packages (from datasets) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in /opt/pytorch/lib/python3.8/site-packages (from datasets) (4.65.0)\n", "Requirement already satisfied: pandas in /opt/pytorch/lib/python3.8/site-packages (from datasets) (2.0.3)\n", "Collecting huggingface-hub>=0.18.0\n", " Downloading huggingface_hub-0.19.4-py3-none-any.whl (311 kB)\n", " |████████████████████████████████| 311 kB 35.9 MB/s \n", "\u001b[?25hRequirement already satisfied: packaging in /opt/pytorch/lib/python3.8/site-packages (from datasets) (23.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /opt/pytorch/lib/python3.8/site-packages (from datasets) (5.4.1)\n", "Collecting xxhash\n", " Downloading xxhash-3.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", " |████████████████████████████████| 194 kB 51.7 MB/s \n", "\u001b[?25hCollecting pyarrow>=8.0.0\n", " Downloading pyarrow-14.0.1-cp38-cp38-manylinux_2_28_x86_64.whl (38.1 MB)\n", " |████████████████████████████████| 38.1 MB 88.8 MB/s \n", "\u001b[?25hCollecting multiprocess\n", " Downloading multiprocess-0.70.15-py38-none-any.whl (132 kB)\n", " |████████████████████████████████| 132 kB 63.6 MB/s \n", "\u001b[?25hCollecting aiohttp\n", " Downloading aiohttp-3.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", " |████████████████████████████████| 1.3 MB 29.4 MB/s \n", "\u001b[?25hCollecting async-timeout<5.0,>=4.0\n", " Downloading async_timeout-4.0.3-py3-none-any.whl (5.7 kB)\n", "Collecting multidict<7.0,>=4.5\n", " Downloading multidict-6.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (121 kB)\n", " |████████████████████████████████| 121 kB 68.5 MB/s \n", "\u001b[?25hCollecting aiosignal>=1.1.2\n", " Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)\n", "Collecting frozenlist>=1.1.1\n", " Downloading frozenlist-1.4.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (220 kB)\n", " |████████████████████████████████| 220 kB 73.2 MB/s \n", "\u001b[?25hRequirement already satisfied: attrs>=17.3.0 in /opt/pytorch/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n", "Collecting yarl<2.0,>=1.0\n", " Downloading yarl-1.9.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (307 kB)\n", " |████████████████████████████████| 307 kB 17.4 MB/s \n", "\u001b[?25hRequirement already satisfied: filelock in /opt/pytorch/lib/python3.8/site-packages (from huggingface-hub>=0.18.0->datasets) (3.12.2)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/pytorch/lib/python3.8/site-packages (from huggingface-hub>=0.18.0->datasets) (4.7.1)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/pytorch/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (1.26.16)\n", "Requirement already satisfied: idna<4,>=2.5 in /opt/pytorch/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (3.4)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/pytorch/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (2023.5.7)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/pytorch/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (3.1.0)\n", "Requirement already satisfied: pytz>=2020.1 in /opt/pytorch/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/pytorch/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: tzdata>=2022.1 in /opt/pytorch/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n", "Requirement already satisfied: six>=1.5 in /opt/pytorch/lib/python3.8/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", "Installing collected packages: multidict, frozenlist, yarl, async-timeout, aiosignal, fsspec, dill, aiohttp, xxhash, pyarrow-hotfix, pyarrow, multiprocess, huggingface-hub, datasets\n", "Successfully installed aiohttp-3.9.0 aiosignal-1.3.1 async-timeout-4.0.3 datasets-2.15.0 dill-0.3.7 frozenlist-1.4.0 fsspec-2023.10.0 huggingface-hub-0.19.4 multidict-6.0.4 multiprocess-0.70.15 pyarrow-14.0.1 pyarrow-hotfix-0.6 xxhash-3.4.1 yarl-1.9.3\n", "Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cpu:\n", "Collecting transformers\n", " Downloading transformers-4.35.2-py3-none-any.whl (7.9 MB)\n", " |████████████████████████████████| 7.9 MB 12.9 MB/s \n", "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /opt/pytorch/lib/python3.8/site-packages (from transformers) (1.21.6)\n", "Requirement already satisfied: pyyaml>=5.1 in /opt/pytorch/lib/python3.8/site-packages (from transformers) (5.4.1)\n", "Collecting regex!=2019.12.17\n", " Downloading regex-2023.10.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (776 kB)\n", " |████████████████████████████████| 776 kB 48.5 MB/s \n", "\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /opt/pytorch/lib/python3.8/site-packages (from transformers) (4.65.0)\n", "Collecting safetensors>=0.3.1\n", " Downloading safetensors-0.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", " |████████████████████████████████| 1.3 MB 87.5 MB/s \n", "\u001b[?25hRequirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /opt/pytorch/lib/python3.8/site-packages (from transformers) (0.19.4)\n", "Requirement already satisfied: packaging>=20.0 in /opt/pytorch/lib/python3.8/site-packages (from transformers) (23.1)\n", "Collecting tokenizers<0.19,>=0.14\n", " Downloading tokenizers-0.15.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)\n", " |████████████████████████████████| 3.8 MB 123.4 MB/s \n", "\u001b[?25hRequirement already satisfied: requests in /opt/pytorch/lib/python3.8/site-packages (from transformers) (2.31.0)\n", "Requirement already satisfied: filelock in /opt/pytorch/lib/python3.8/site-packages (from transformers) (3.12.2)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/pytorch/lib/python3.8/site-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.7.1)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /opt/pytorch/lib/python3.8/site-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (2023.10.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/pytorch/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/pytorch/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/pytorch/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n", "Requirement already satisfied: idna<4,>=2.5 in /opt/pytorch/lib/python3.8/site-packages (from requests->transformers) (3.4)\n", "Installing collected packages: tokenizers, safetensors, regex, transformers\n", "Successfully installed regex-2023.10.3 safetensors-0.4.0 tokenizers-0.15.0 transformers-4.35.2\n", "Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cpu:\n", "Requirement already satisfied: transformers[torch] in /opt/pytorch/lib/python3.8/site-packages (4.35.2)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in /opt/pytorch/lib/python3.8/site-packages (from transformers[torch]) (0.15.0)\n", "Requirement already satisfied: numpy>=1.17 in /opt/pytorch/lib/python3.8/site-packages (from transformers[torch]) (1.21.6)\n", "Requirement already satisfied: requests in /opt/pytorch/lib/python3.8/site-packages (from transformers[torch]) (2.31.0)\n", "Requirement already satisfied: filelock in /opt/pytorch/lib/python3.8/site-packages (from transformers[torch]) (3.12.2)\n", "Requirement already satisfied: tqdm>=4.27 in /opt/pytorch/lib/python3.8/site-packages (from transformers[torch]) (4.65.0)\n", "Requirement already satisfied: regex!=2019.12.17 in /opt/pytorch/lib/python3.8/site-packages (from transformers[torch]) (2023.10.3)\n", "Requirement already satisfied: packaging>=20.0 in /opt/pytorch/lib/python3.8/site-packages (from transformers[torch]) (23.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /opt/pytorch/lib/python3.8/site-packages (from transformers[torch]) (5.4.1)\n", "Requirement already satisfied: safetensors>=0.3.1 in /opt/pytorch/lib/python3.8/site-packages (from transformers[torch]) (0.4.0)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /opt/pytorch/lib/python3.8/site-packages (from transformers[torch]) (0.19.4)\n", "Collecting accelerate>=0.20.3\n", " Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)\n", " |████████████████████████████████| 261 kB 9.6 MB/s \n", "\u001b[?25hRequirement already satisfied: torch!=1.12.0,>=1.10 in /opt/pytorch/lib/python3.8/site-packages (from transformers[torch]) (2.0.1+cpu)\n", "Requirement already satisfied: psutil in /opt/pytorch/lib/python3.8/site-packages (from accelerate>=0.20.3->transformers[torch]) (5.9.5)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /opt/pytorch/lib/python3.8/site-packages (from huggingface-hub<1.0,>=0.16.4->transformers[torch]) (2023.10.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/pytorch/lib/python3.8/site-packages (from huggingface-hub<1.0,>=0.16.4->transformers[torch]) (4.7.1)\n", "Requirement already satisfied: sympy in /opt/pytorch/lib/python3.8/site-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (1.12)\n", "Requirement already satisfied: networkx in /opt/pytorch/lib/python3.8/site-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (3.1)\n", "Requirement already satisfied: jinja2 in /opt/pytorch/lib/python3.8/site-packages (from torch!=1.12.0,>=1.10->transformers[torch]) (3.1.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/pytorch/lib/python3.8/site-packages (from requests->transformers[torch]) (3.1.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/pytorch/lib/python3.8/site-packages (from requests->transformers[torch]) (2023.5.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/pytorch/lib/python3.8/site-packages (from requests->transformers[torch]) (1.26.16)\n", "Requirement already satisfied: idna<4,>=2.5 in /opt/pytorch/lib/python3.8/site-packages (from requests->transformers[torch]) (3.4)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /opt/pytorch/lib/python3.8/site-packages (from jinja2->torch!=1.12.0,>=1.10->transformers[torch]) (2.1.3)\n", "Requirement already satisfied: mpmath>=0.19 in /opt/pytorch/lib/python3.8/site-packages (from sympy->torch!=1.12.0,>=1.10->transformers[torch]) (1.3.0)\n", "Installing collected packages: accelerate\n", "Successfully installed accelerate-0.24.1\n", "Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cpu:\n", "Requirement already satisfied: accelerate in /opt/pytorch/lib/python3.8/site-packages (0.24.1)\n", "Requirement already satisfied: psutil in /opt/pytorch/lib/python3.8/site-packages (from accelerate) (5.9.5)\n", "Requirement already satisfied: huggingface-hub in /opt/pytorch/lib/python3.8/site-packages (from accelerate) (0.19.4)\n", "Requirement already satisfied: torch>=1.10.0 in /opt/pytorch/lib/python3.8/site-packages (from accelerate) (2.0.1+cpu)\n", "Requirement already satisfied: packaging>=20.0 in /opt/pytorch/lib/python3.8/site-packages (from accelerate) (23.1)\n", "Requirement already satisfied: pyyaml in /opt/pytorch/lib/python3.8/site-packages (from accelerate) (5.4.1)\n", "Requirement already satisfied: numpy>=1.17 in /opt/pytorch/lib/python3.8/site-packages (from accelerate) (1.21.6)\n", "Requirement already satisfied: sympy in /opt/pytorch/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (1.12)\n", "Requirement already satisfied: networkx in /opt/pytorch/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (3.1)\n", "Requirement already satisfied: jinja2 in /opt/pytorch/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (3.1.2)\n", "Requirement already satisfied: filelock in /opt/pytorch/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (3.12.2)\n", "Requirement already satisfied: typing-extensions in /opt/pytorch/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.7.1)\n", "Requirement already satisfied: requests in /opt/pytorch/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /opt/pytorch/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\n", "Requirement already satisfied: tqdm>=4.42.1 in /opt/pytorch/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /opt/pytorch/lib/python3.8/site-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/pytorch/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\n", "Requirement already satisfied: idna<4,>=2.5 in /opt/pytorch/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/pytorch/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/pytorch/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\n", "Requirement already satisfied: mpmath>=0.19 in /opt/pytorch/lib/python3.8/site-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n" ] } ], "source": [ "! pip install pandas\n", "! pip install scikit-learn\n", "! pip install datasets\n", "! pip install transformers\n", "! pip install transformers[torch]\n", "! pip install accelerate -U" ] }, { "cell_type": "code", "execution_count": 2, "id": "fed20656-1f48-40d6-93e2-53aec7de522e", "metadata": { "execution": { "iopub.execute_input": "2023-11-25T03:09:12.125910Z", "iopub.status.busy": "2023-11-25T03:09:12.125577Z", "iopub.status.idle": "2023-11-25T03:09:18.835267Z", "shell.execute_reply": "2023-11-25T03:09:18.834607Z", "shell.execute_reply.started": "2023-11-25T03:09:12.125891Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bcd589e60cf34ea9a3336f439162493e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config.json: 0%| | 0.00/768 [00:00 token (equiv. to [CLS])\n", " x = self.dropout(x)\n", " x = self.dense(x)\n", " x = torch.nn.functional.relu(x)\n", " x = self.dropout(x)\n", " x = self.out_proj(x)\n", " return x\n", "\n", "def preprocess_data(df):\n", " ## rename columns\n", " df = df.rename(columns={'Comment': 'text', 'Emotion': 'label'})\n", "\n", " ## remove rows with missing values\n", " df = df.dropna()\n", " df['text'] = df['text'].str.replace('\\t', ' ') # Remove extra spaces - this line replaces any occurrence of two or more spaces with a single spac\n", " df['text'] = df['text'].str.replace(' +', ' ', regex=True) # Remove extra spaces - this line replaces any occurrence of two or more spaces with a single space\n", " df['text'] = df['text'].str.strip() # Remove extra spaces - this line replaces any occurrence of two or more spaces with a single space\n", "\n", " df['label'] = df['label'].str.replace('\\t', ' ') # Remove extra spaces - this line replaces any occurrence of two or more spaces with a single spac\n", " df['label'] = df['label'].str.replace(' +', ' ', regex=True) # Remove extra spaces - this line replaces any occurrence of two or more spaces with a single space\n", " df['label'] = df['label'].str.strip() # Remove extra spaces - this line replaces any occurrence of two or more spaces with a single space \n", "\n", " return df\n", "\n", "def encode_label(df):\n", " le = LabelEncoder()\n", " df['label'] = le.fit_transform(df['label'])\n", " label_mapping = {label: index for index, label in enumerate(le.classes_)}\n", " df['label'].map(label_mapping)\n", " return df\n", "\n", "def generate_dataset(df, test_size=0.2):\n", " \"\"\"\n", " Convert to transformers dataset and split into train and test\n", " \"\"\"\n", " dataset = Dataset.from_pandas(df)\n", " ds = dataset.train_test_split(test_size=test_size)\n", " return ds\n", "\n", "def tokenize(batch):\n", " return tokenizer(batch['text'], padding='max_length', truncation=True)\n", "\n", "\n", "def compute_metrics(pred):\n", " from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n", " labels = pred.label_ids\n", " preds = pred.predictions.argmax(-1)\n", " precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')\n", " acc = accuracy_score(labels, preds)\n", " return {\n", " 'accuracy': acc,\n", " 'f1': f1,\n", " 'precision': precision,\n", " 'recall': recall\n", " }\n", "\n", "# Define model and training arguments\n", "model_name = \"cardiffnlp/twitter-roberta-base-emotion\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "config = RobertaConfig.from_pretrained(model_name, num_labels=3) # Set the number of labels to 3\n", "model = RobertaForSequenceClassification.from_pretrained(model_name, config=config, ignore_mismatched_sizes=True)\n", "model.classifier = NewClassificationHead(config)\n", "\n", "df = pd.read_csv('Emotion_classify_Data.csv')\n", "df = preprocess_data(df)\n", "df = encode_label(df)\n", "ds = generate_dataset(df)\n", "ds = ds.map(tokenize, batched=True)\n", "\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "f3dd5334-f8b4-4f0d-b696-939f2d5174ba", "metadata": { "execution": { "iopub.execute_input": "2023-11-25T03:09:18.836520Z", "iopub.status.busy": "2023-11-25T03:09:18.836241Z", "iopub.status.idle": "2023-11-25T03:09:18.845692Z", "shell.execute_reply": "2023-11-25T03:09:18.844909Z", "shell.execute_reply.started": "2023-11-25T03:09:18.836502Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/pytorch/lib/python3.8/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n" ] } ], "source": [ "# Freeze all layers first\n", "for param in model.parameters():\n", " param.requires_grad = False\n", "\n", "# Unfreeze the classifier layer\n", "for param in model.classifier.parameters():\n", " param.requires_grad = True\n", "\n", "\n", "# Define different learning rates\n", "head_lr = 3e-4 # Higher learning rate for the head\n", "base_lr = head_lr/5 # Lower learning rate for the base layers\n", "\n", "# Group parameters and set learning rates\n", "optimizer_grouped_parameters = [\n", " {'params': model.classifier.parameters(), 'lr': head_lr},\n", " {'params': [p for n, p in model.named_parameters() if 'classifier' not in n], 'lr': base_lr}\n", "]\n", "\n", "optimizer = AdamW(optimizer_grouped_parameters)" ] }, { "cell_type": "code", "execution_count": 4, "id": "882c5342-a82a-4e5a-b0ad-eaaa4978831f", "metadata": { "execution": { "iopub.execute_input": "2023-11-25T03:09:18.847637Z", "iopub.status.busy": "2023-11-25T03:09:18.847285Z", "iopub.status.idle": "2023-11-25T03:09:18.862687Z", "shell.execute_reply": "2023-11-25T03:09:18.862118Z", "shell.execute_reply.started": "2023-11-25T03:09:18.847619Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" ] } ], "source": [ "training_args = TrainingArguments(\n", " output_dir='./results', \n", " num_train_epochs=10, \n", " per_device_train_batch_size=16, \n", " per_device_eval_batch_size=64, \n", " warmup_steps=500, \n", " weight_decay=0.01, \n", " logging_dir='./logs',\n", " save_strategy=\"no\",\n", ")\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=ds['train'],\n", " eval_dataset=ds['test'],\n", " tokenizer=tokenizer,\n", " optimizers=(optimizer, None), # No need to pass a learning rate scheduler if you're managing learning rates manually,\n", " compute_metrics=compute_metrics\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "id": "19f8b2f1-d03b-42c2-a0a1-2475f2dfde37", "metadata": { "execution": { "iopub.execute_input": "2023-11-25T03:09:18.864992Z", "iopub.status.busy": "2023-11-25T03:09:18.864819Z", "iopub.status.idle": "2023-11-25T03:17:56.086914Z", "shell.execute_reply": "2023-11-25T03:17:56.085959Z", "shell.execute_reply.started": "2023-11-25T03:09:18.864977Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [2970/2970 08:36, Epoch 10/10]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
5000.678100
10000.537700
15000.514900
20000.474500
25000.450500

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=2970, training_loss=0.516797270598235, metrics={'train_runtime': 517.0884, 'train_samples_per_second': 91.841, 'train_steps_per_second': 5.744, 'total_flos': 1128914327325078.0, 'train_loss': 0.516797270598235, 'epoch': 10.0})" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ " trainer.train()" ] }, { "cell_type": "code", "execution_count": 6, "id": "208a5c13-31c7-4a03-b9a8-18146a265f73", "metadata": { "execution": { "iopub.execute_input": "2023-11-25T03:17:56.091526Z", "iopub.status.busy": "2023-11-25T03:17:56.091204Z", "iopub.status.idle": "2023-11-25T03:18:09.179984Z", "shell.execute_reply": "2023-11-25T03:18:09.179279Z", "shell.execute_reply.started": "2023-11-25T03:17:56.091497Z" } }, "outputs": [ { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [19/19 18:00]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'eval_loss': 0.4612630307674408,\n", " 'eval_accuracy': 0.8181818181818182,\n", " 'eval_f1': 0.8180812962482343,\n", " 'eval_precision': 0.8186808374254468,\n", " 'eval_recall': 0.8181818181818182,\n", " 'eval_runtime': 13.0807,\n", " 'eval_samples_per_second': 90.821,\n", " 'eval_steps_per_second': 1.453,\n", " 'epoch': 10.0}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluate()" ] }, { "cell_type": "code", "execution_count": 7, "id": "d90661bf-e22b-4dbf-980b-1c8ff69f625c", "metadata": { "execution": { "iopub.execute_input": "2023-11-25T03:18:09.184420Z", "iopub.status.busy": "2023-11-25T03:18:09.184230Z", "iopub.status.idle": "2023-11-25T03:35:26.344692Z", "shell.execute_reply": "2023-11-25T03:35:26.344122Z", "shell.execute_reply.started": "2023-11-25T03:18:09.184402Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/pytorch/lib/python3.8/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [1485/1485 17:16, Epoch 5/5]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
5000.253200
10000.105000

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=1485, training_loss=0.13263646263867515, metrics={'train_runtime': 1037.0165, 'train_samples_per_second': 22.897, 'train_steps_per_second': 1.432, 'total_flos': 563885457261714.0, 'train_loss': 0.13263646263867515, 'epoch': 5.0})" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "for param in model.parameters():\n", " param.requires_grad = True\n", "\n", " \n", "head_lr = 1e-4 # Slightly lower learning rate for the head\n", "base_lr = 5e-6 # Much lower learning rate for the base layers\n", "\n", "optimizer_grouped_parameters = [\n", " {'params': model.classifier.parameters(), 'lr': head_lr},\n", " {'params': [p for n, p in model.named_parameters() if 'classifier' not in n], 'lr': base_lr}\n", "]\n", "\n", "optimizer = AdamW(optimizer_grouped_parameters)\n", "\n", "training_args.num_train_epochs = 5 # Set the number of additional epochs\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": 10, "id": "e4502600-7091-4a8a-83b6-5af5e249b7ca", "metadata": { "execution": { "iopub.execute_input": "2023-11-25T03:35:44.942721Z", "iopub.status.busy": "2023-11-25T03:35:44.942333Z", "iopub.status.idle": "2023-11-25T03:35:57.188045Z", "shell.execute_reply": "2023-11-25T03:35:57.187245Z", "shell.execute_reply.started": "2023-11-25T03:35:44.942703Z" } }, "outputs": [ { "data": { "text/plain": [ "{'eval_loss': 0.2423660308122635,\n", " 'eval_accuracy': 0.9671717171717171,\n", " 'eval_f1': 0.9671861840444216,\n", " 'eval_precision': 0.9672086987568536,\n", " 'eval_recall': 0.9671717171717171,\n", " 'eval_runtime': 12.2384,\n", " 'eval_samples_per_second': 97.071,\n", " 'eval_steps_per_second': 1.552,\n", " 'epoch': 5.0}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluate()" ] }, { "cell_type": "code", "execution_count": 13, "id": "190ff835-a7a2-465f-994d-73adb75950a3", "metadata": { "execution": { "iopub.execute_input": "2023-11-25T03:39:58.975250Z", "iopub.status.busy": "2023-11-25T03:39:58.974521Z", "iopub.status.idle": "2023-11-25T03:39:59.367917Z", "shell.execute_reply": "2023-11-25T03:39:59.367402Z", "shell.execute_reply.started": "2023-11-25T03:39:58.975230Z" } }, "outputs": [ { "data": { "text/plain": [ "('transferLearningResults/tokenizer_config.json',\n", " 'transferLearningResults/special_tokens_map.json',\n", " 'transferLearningResults/vocab.json',\n", " 'transferLearningResults/merges.txt',\n", " 'transferLearningResults/added_tokens.json',\n", " 'transferLearningResults/tokenizer.json')" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.save_pretrained('transferLearningResults')\n", "tokenizer.save_pretrained('transferLearningResults')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }