{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "959444f6-bb64-49f4-b537-d764680219ca", "metadata": {}, "outputs": [], "source": [ "#!pip install -U bitsandbytes\n", "#!pip install -U transformers\n", "#!pip install -U accelerate\n", "#!pip install -U peft\n", "#!pip install -U trl" ] }, { "cell_type": "code", "execution_count": 2, "id": "74ca2d22-ee78-4f31-8769-efd3dae9c46c", "metadata": {}, "outputs": [], "source": [ "#!huggingface-cli whoami" ] }, { "cell_type": "code", "execution_count": 3, "id": "682a7a96-c5b8-4595-b495-99e13b88a844", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import os\n", "from tqdm import tqdm\n", "import bitsandbytes as bnb\n", "import torch\n", "import torch.nn as nn\n", "import transformers\n", "from datasets import Dataset\n", "from peft import LoraConfig, PeftConfig\n", "from trl import SFTTrainer\n", "from trl import setup_chat_format\n", "from transformers import (AutoModelForCausalLM, \n", " AutoTokenizer, \n", " BitsAndBytesConfig, \n", " TrainingArguments, \n", " pipeline, \n", " logging)\n", "from sklearn.metrics import (accuracy_score, \n", " classification_report, \n", " confusion_matrix)\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "code", "execution_count": 4, "id": "5cbaba37-4deb-4a60-b4a7-dacd3c75c62b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
counthate_speech_countoffensive_language_countneither_countclasstweet
030032!!! RT @mayasolovely: As a woman you shouldn't...
130301!!!!! RT @mleew17: boy dats cold...tyga dwn ba...
230301!!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby...
330211!!!!!!!!! RT @C_G_Anderson: @viva_based she lo...
460601!!!!!!!!!!!!! RT @ShenikaRoberts: The shit you...
\n", "
" ], "text/plain": [ " count hate_speech_count offensive_language_count neither_count class \\\n", "0 3 0 0 3 2 \n", "1 3 0 3 0 1 \n", "2 3 0 3 0 1 \n", "3 3 0 2 1 1 \n", "4 6 0 6 0 1 \n", "\n", " tweet \n", "0 !!! RT @mayasolovely: As a woman you shouldn't... \n", "1 !!!!! RT @mleew17: boy dats cold...tyga dwn ba... \n", "2 !!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby... \n", "3 !!!!!!!!! RT @C_G_Anderson: @viva_based she lo... \n", "4 !!!!!!!!!!!!! RT @ShenikaRoberts: The shit you... " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datasets import load_dataset\n", "\n", "df = pd.read_parquet(\"hf://datasets/tdavidson/hate_speech_offensive/data/train-00000-of-00001.parquet\")\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 5, "id": "90b89c75-ab77-42ee-b50c-466d6cc9de96", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_11705/3716630465.py:2: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise in a future error of pandas. Value '[1 1 1 ... 1 1 2]' has dtype incompatible with int64, please explicitly cast to a compatible dtype first.\n", " df.loc[:,'label'] = df.loc[:,'label'].replace(0,'Hate')\n" ] } ], "source": [ "df = df.rename(columns={\"class\": \"label\",\"tweet\": \"text\"}).sample(frac=1, random_state=85).reset_index(drop=True).head(3000)\n", "df.loc[:,'label'] = df.loc[:,'label'].replace(0,'Hate')\n", "df.loc[:,'label'] = df.loc[:,'label'].replace(1,'Offensive')\n", "df.loc[:,'label'] = df.loc[:,'label'].replace(2,'Normal')\n", "# Split the DataFrame\n", "train_size = 0.8\n", "eval_size = 0.1\n", "\n", "# Calculate sizes\n", "train_end = int(train_size * len(df))\n", "eval_end = train_end + int(eval_size * len(df))\n", "\n", "# Split the data\n", "X_train = df[:train_end]\n", "X_eval = df[train_end:eval_end]\n", "X_test = df[eval_end:]\n", "# Define the prompt generation functions\n", "def generate_prompt(data_point):\n", " return f\"\"\"\n", " Classify the text into Hatespeech, Offensive, Normal and return the answer as the corresponding label.\n", "text: {data_point[\"text\"]}\n", "label: {data_point[\"label\"]}\"\"\".strip()\n", "\n", "def generate_test_prompt(data_point):\n", " return f\"\"\"\n", " Classify the text into Hatespeech, Offensive, Normal and return the answer as the corresponding label.\n", " text: {data_point[\"text\"]}\n", " label: \"\"\".strip()\n", "\n", "# Generate prompts for training and evaluation data\n", "X_train.loc[:,'text'] = X_train.apply(generate_prompt, axis=1)\n", "X_eval.loc[:,'text'] = X_eval.apply(generate_prompt, axis=1)\n", "\n", "# Generate test prompts and extract true labels\n", "y_true = X_test.loc[:,'label']\n", "X_test = pd.DataFrame(X_test.apply(generate_test_prompt, axis=1), columns=[\"text\"])" ] }, { "cell_type": "code", "execution_count": 6, "id": "25859362-af09-43fd-93ef-b981159a7dba", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "label\n", "Offensive 1877\n", "Normal 391\n", "Hate 132\n", "Name: count, dtype: int64" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.label.value_counts()" ] }, { "cell_type": "code", "execution_count": 7, "id": "12010dd7-0885-4296-9d8e-dae9d2addc30", "metadata": {}, "outputs": [], "source": [ "train_data = Dataset.from_pandas(X_train[[\"text\"]])\n", "eval_data = Dataset.from_pandas(X_eval[[\"text\"]])" ] }, { "cell_type": "code", "execution_count": 8, "id": 
"e8d7feef-d881-4e90-86d6-a4b30ac040f8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Classify the text into Hatespeech, Offensive, Normal and return the answer as the corresponding label.\\ntext: @kieffer_jason bitch u a thot oh fake ass nigga box up hoe u not bout nothing\\nlabel: Offensive'" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data['text'][2000]" ] }, { "cell_type": "code", "execution_count": 9, "id": "6e8f818f-2c29-4637-93d0-750144b3f599", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b36e9ec950eb41909ae7159673c5a774", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /home/marco/wandb/run-20250112_154819-ixq2tt7q" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run /home/marco/llama-3.2-3B-instruct-offensive-classification-2 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/marcoor-universit-t-klagenfurt/huggingface" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/marcoor-universit-t-klagenfurt/huggingface/runs/ixq2tt7q" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/home/marco/.config/jupyterlab-desktop/jlab_server/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:632: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n", " return fn(*args, **kwargs)\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [300/300 09:53, Epoch 1/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining LossValidation Loss
602.0116002.002830
1201.8539001.961909
1802.0888001.939240
2401.9231001.927367
3002.0890001.924164

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/home/marco/.config/jupyterlab-desktop/jlab_server/lib/python3.12/site-packages/peft/utils/other.py:716: UserWarning: Unable to fetch remote file due to the following error (ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: ebbfa21b-df77-43ef-bddd-b95ec07a63f8)') - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-3B-Instruct.\n", " warnings.warn(\n", "/home/marco/.config/jupyterlab-desktop/jlab_server/lib/python3.12/site-packages/peft/utils/save_and_load.py:246: UserWarning: Could not find a config file in meta-llama/Llama-3.2-3B-Instruct - will assume that the vocabulary was not modified.\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=300, training_loss=2.0926066251595814, metrics={'train_runtime': 596.9063, 'train_samples_per_second': 4.021, 'train_steps_per_second': 0.503, 'total_flos': 2216727844706304.0, 'train_loss': 2.0926066251595814, 'epoch': 1.0})" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": 16, "id": "d50bf803-dedd-44bf-8016-6d80b5726803", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('/home/marco/llama-3.2-3B-instruct-offensive-classification-2/tokenizer_config.json',\n", " '/home/marco/llama-3.2-3B-instruct-offensive-classification-2/special_tokens_map.json',\n", " '/home/marco/llama-3.2-3B-instruct-offensive-classification-2/tokenizer.json')" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.save_model(output_dir)\n", "tokenizer.save_pretrained(output_dir)" ] }, { "cell_type": "code", "execution_count": 17, "id": "37e30c0e-dd90-4400-aaab-b3526ac4a7ab", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/300 [00:00