{ "cells": [ { "cell_type": "markdown", "id": "33faae25-af36-4781-bf8f-2084ddc96a52", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "code", "execution_count": 1, "id": "73e72549-69f2-46b5-b0f5-655777139972", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T20:17:03.803583Z", "iopub.status.busy": "2025-01-20T20:17:03.803051Z", "iopub.status.idle": "2025-01-20T20:17:06.786959Z", "shell.execute_reply": "2025-01-20T20:17:06.786718Z", "shell.execute_reply.started": "2025-01-20T20:17:03.803542Z" } }, "outputs": [], "source": [ "from datetime import datetime\n", "import numpy as np\n", "import torch\n", "from torch import nn\n", "from transformers import BertTokenizer, BertModel\n", "from huggingface_hub import (\n", " PyTorchModelHubMixin,\n", " notebook_login,\n", " ModelCard,\n", " ModelCardData,\n", " EvalResult,\n", ")\n", "from datasets import DatasetDict, load_dataset\n", "from torch.utils.data import Dataset, DataLoader" ] }, { "cell_type": "code", "execution_count": 2, "id": "07e0787e-c72b-41f3-baba-43cef3f8d6f8", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T20:17:06.787691Z", "iopub.status.busy": "2025-01-20T20:17:06.787547Z", "iopub.status.idle": "2025-01-20T20:17:06.789420Z", "shell.execute_reply": "2025-01-20T20:17:06.789211Z", "shell.execute_reply.started": "2025-01-20T20:17:06.787682Z" } }, "outputs": [], "source": [ "notebook_login(new_session=False)" ] }, { "cell_type": "markdown", "id": "a919d72c-8d10-4275-a2ca-4ead295f41a8", "metadata": {}, "source": [ "# Functions" ] }, { "cell_type": "code", "execution_count": 3, "id": "d4b79fb9-5e70-4600-8885-94bc0a6e917c", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T20:17:06.789829Z", "iopub.status.busy": "2025-01-20T20:17:06.789761Z", "iopub.status.idle": "2025-01-20T20:17:06.794443Z", "shell.execute_reply": "2025-01-20T20:17:06.794260Z", "shell.execute_reply.started": "2025-01-20T20:17:06.789822Z" } }, "outputs": [], "source": [ "def my_print(x):\n", " time_str = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", " print(time_str, x)\n", "\n", "\n", "def model_metrics(model, dataloader):\n", " criterion = nn.CrossEntropyLoss()\n", " model.eval()\n", " with torch.no_grad():\n", " total_loss = 0\n", " total_correct = 0\n", " total_length = 0\n", " for batch in dataloader:\n", " input_ids = batch[\"input_ids\"].to(device)\n", " attention_mask = batch[\"attention_mask\"].to(device)\n", " labels = batch[\"labels\"].to(device)\n", "\n", " outputs = model(input_ids, attention_mask)\n", " loss = criterion(outputs, labels)\n", " predictions_cpu = torch.argmax(outputs, dim=1).cpu().numpy()\n", " labels_cpu = labels.cpu().numpy()\n", " correct_count = (predictions_cpu == labels_cpu).sum()\n", "\n", " total_loss += loss.item()\n", " total_correct += correct_count\n", " total_length += len(labels_cpu)\n", " avg_loss = total_loss / len(dataloader)\n", " avg_acc = total_correct / total_length\n", " model.train()\n", " return avg_loss, avg_acc\n", "\n", "\n", "def print_model_status(epoch, num_epochs, model, train_dataloader, test_dataloader):\n", " train_loss, train_acc = model_metrics(model, train_dataloader)\n", " test_loss, test_acc = model_metrics(model, test_dataloader)\n", " loss_str = f\"Loss: Train {train_loss:0.3f}, Test {test_loss:0.3f}\"\n", " acc_str = f\"Acc: Train {train_acc:0.3f}, Test {test_acc:0.3f}\"\n", " my_print(f\"Epoch {epoch+1}/{num_epochs} done. 
{loss_str}; and {acc_str}\")\n", "\n", "\n", "class BertClassifier(nn.Module, PyTorchModelHubMixin):\n", " def __init__(self, num_labels=8, bert_variety=\"bert-base-uncased\"):\n", " super().__init__()\n", " self.bert = BertModel.from_pretrained(bert_variety)\n", " self.dropout = nn.Dropout(0.05)\n", " self.classifier = nn.Linear(self.bert.pooler.dense.out_features, num_labels)\n", "\n", " def forward(self, input_ids, attention_mask):\n", " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n", " pooled_output = outputs.pooler_output\n", " pooled_output = self.dropout(pooled_output)\n", " logits = self.classifier(pooled_output)\n", " return logits\n", "\n", "\n", "class TextDataset(Dataset):\n", " def __init__(self, texts, labels, tokenizer, max_length=512):\n", " self.encodings = tokenizer(\n", " texts,\n", " truncation=True,\n", " padding=True,\n", " max_length=max_length,\n", " return_tensors=\"pt\",\n", " )\n", " self.labels = torch.tensor([int(l[0]) for l in labels])\n", "\n", " def __getitem__(self, idx):\n", " item = {key: val[idx] for key, val in self.encodings.items()}\n", " item[\"labels\"] = self.labels[idx]\n", " return item\n", "\n", " def __len__(self) -> int:\n", " return len(self.labels)\n", "\n", "\n", "def train_model(model, train_dataloader, test_dataloader, device, num_epochs):\n", " optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)\n", " criterion = nn.CrossEntropyLoss()\n", " model.train()\n", "\n", " print_model_status(-1, num_epochs, model, train_dataloader, test_dataloader)\n", " for epoch in range(num_epochs):\n", " total_loss = 0\n", " for batch in train_dataloader:\n", " optimizer.zero_grad()\n", "\n", " input_ids = batch[\"input_ids\"].to(device)\n", " attention_mask = batch[\"attention_mask\"].to(device)\n", " labels = batch[\"labels\"].to(device)\n", "\n", " outputs = model(input_ids, attention_mask)\n", " loss = criterion(outputs, labels)\n", "\n", " loss.backward()\n", " optimizer.step()\n", "\n", " total_loss += loss.item()\n", " avg_loss = total_loss / len(train_dataloader)\n", " print_model_status(epoch, num_epochs, model, train_dataloader, test_dataloader)" ] }, { "cell_type": "code", "execution_count": 4, "id": "07131bce-23ad-4787-8622-cce401f3e5ce", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T20:17:06.795335Z", "iopub.status.busy": "2025-01-20T20:17:06.795239Z", "iopub.status.idle": "2025-01-20T20:17:06.821293Z", "shell.execute_reply": "2025-01-20T20:17:06.821061Z", "shell.execute_reply.started": "2025-01-20T20:17:06.795328Z" } }, "outputs": [], "source": [ "if torch.backends.mps.is_available():\n", " device = torch.device(\"mps\")\n", " torch.mps.empty_cache()\n", "elif torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", "else:\n", " device = torch.device(\"cpu\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "695bc080-bbd7-4937-af5b-50db1c936500", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T20:17:06.821637Z", "iopub.status.busy": "2025-01-20T20:17:06.821569Z", "iopub.status.idle": "2025-01-20T20:17:06.824265Z", "shell.execute_reply": "2025-01-20T20:17:06.824082Z", "shell.execute_reply.started": "2025-01-20T20:17:06.821630Z" } }, "outputs": [], "source": [ "def run_training(\n", " max_dataset_size=16 * 200,\n", " bert_variety=\"bert-base-uncased\",\n", " max_length=200,\n", " num_epochs=3,\n", " batch_size=32,\n", "):\n", " hf_dataset = load_dataset(\"quotaclimat/frugalaichallenge-text-train\")\n", " test_size = 0.2\n", " test_seed = 42\n", " train_test 
= hf_dataset[\"train\"].train_test_split(\n", " test_size=test_size, seed=test_seed\n", " )\n", " train_dataset = train_test[\"train\"]\n", " test_dataset = train_test[\"test\"]\n", " if not max_dataset_size == \"full\" and max_dataset_size < len(hf_dataset[\"train\"]):\n", " train_dataset = train_dataset[:max_dataset_size]\n", " test_dataset = test_dataset[:max_dataset_size]\n", " else:\n", " train_dataset = train_dataset\n", " test_dataset = test_dataset\n", "\n", " tokenizer = BertTokenizer.from_pretrained(bert_variety, max_length=max_length)\n", " model = BertClassifier(bert_variety=bert_variety)\n", " if torch.backends.mps.is_available():\n", " device = torch.device(\"mps\")\n", " torch.mps.empty_cache()\n", " elif torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", " else:\n", " device = torch.device(\"cpu\")\n", " model.to(device)\n", "\n", " text_dataset_train = TextDataset(\n", " train_dataset[\"quote\"],\n", " train_dataset[\"label\"],\n", " tokenizer=tokenizer,\n", " max_length=max_length,\n", " )\n", " text_dataset_test = TextDataset(\n", " test_dataset[\"quote\"],\n", " test_dataset[\"label\"],\n", " tokenizer=tokenizer,\n", " max_length=max_length,\n", " )\n", " dataloader_train = DataLoader(\n", " text_dataset_train, batch_size=batch_size, shuffle=True\n", " )\n", " dataloader_test = DataLoader(\n", " text_dataset_test, batch_size=batch_size, shuffle=False\n", " )\n", "\n", " train_model(model, dataloader_train, dataloader_test, device, num_epochs=num_epochs)\n", " return model, tokenizer" ] }, { "cell_type": "markdown", "id": "5af751f3-1fc4-4540-ae25-638db9d33c67", "metadata": {}, "source": [ "# Exploration" ] }, { "cell_type": "markdown", "id": "a847135f-ce86-46a1-9c61-3459a847cb29", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T19:13:05.482383Z", "iopub.status.busy": "2025-01-20T19:13:05.481449Z", "iopub.status.idle": "2025-01-20T19:13:05.487546Z", "shell.execute_reply": "2025-01-20T19:13:05.486557Z", "shell.execute_reply.started": "2025-01-20T19:13:05.482339Z" } }, "source": [ "## Check if runs" ] }, { "cell_type": "code", "execution_count": 6, "id": "792fd13f-e7cc-4d90-832d-c0da15e193cd", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T20:17:06.824513Z", "iopub.status.busy": "2025-01-20T20:17:06.824457Z", "iopub.status.idle": "2025-01-20T20:17:14.130284Z", "shell.execute_reply": "2025-01-20T20:17:14.129964Z", "shell.execute_reply.started": "2025-01-20T20:17:06.824506Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-01-20 12:17:10 Epoch 0/3 done. Loss: Train 2.111, Test 2.247; and Acc: Train 0.281, Test 0.156\n", "2025-01-20 12:17:11 Epoch 1/3 done. Loss: Train 2.026, Test 2.222; and Acc: Train 0.344, Test 0.156\n", "2025-01-20 12:17:12 Epoch 2/3 done. Loss: Train 1.943, Test 2.194; and Acc: Train 0.312, Test 0.156\n", "2025-01-20 12:17:14 Epoch 3/3 done. 
Loss: Train 1.859, Test 2.159; and Acc: Train 0.344, Test 0.156\n" ] } ], "source": [ "model, tokenizer = run_training(\n", " max_dataset_size=16 * 2,\n", " bert_variety=\"bert-base-uncased\",\n", " max_length=128,\n", " num_epochs=3,\n", " batch_size=32,\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "id": "0aedfcca-843e-4f4c-8062-3e4625161bcc", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T20:17:14.130879Z", "iopub.status.busy": "2025-01-20T20:17:14.130792Z", "iopub.status.idle": "2025-01-20T20:17:14.193695Z", "shell.execute_reply": "2025-01-20T20:17:14.193466Z", "shell.execute_reply.started": "2025-01-20T20:17:14.130869Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-01-20 12:17:14 Predictions: tensor([4, 1, 1, 1, 3, 1, 1], device='mps:0')\n" ] } ], "source": [ "model.eval()\n", "test_text = [\n", " \"This was a great experience!\", # 0_not_relevant\n", " \"My favorite hike is Laguna de los Tres.\", # 0_not_relevant\n", " \"Crops will grow great in Finland if it's warmer there.\", # 3_not_bad\n", " \"Climate change is fake.\", # 1_not_happening\n", " \"The apparent warming is caused by solar cycles.\", # 2_not_human\n", " \"Solar panels emit bad vibes.\", # 4_solutions_harmful_unnecessary\n", " \"All those so-called scientists are Democrats.\", # 6_proponents_biased\n", "]\n", "test_encoding = tokenizer(\n", " test_text,\n", " truncation=True,\n", " padding=True,\n", " return_tensors=\"pt\",\n", ")\n", "\n", "with torch.no_grad():\n", " test_input_ids = test_encoding[\"input_ids\"].to(device)\n", " test_attention_mask = test_encoding[\"attention_mask\"].to(device)\n", " outputs = model(test_input_ids, test_attention_mask)\n", " predictions = torch.argmax(outputs, dim=1)\n", " my_print(f\"Predictions: {predictions}\")" ] }, { "cell_type": "markdown", "id": "0c3ea938-dd87-4673-b1d6-f06c70b19455", "metadata": {}, "source": [ "## Hyperparameters" ] }, { "cell_type": "code", "execution_count": 8, "id": "1d29336e-7f88-4127-afdf-2fe043e310e1", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T20:17:14.194160Z", "iopub.status.busy": "2025-01-20T20:17:14.194076Z", "iopub.status.idle": "2025-01-20T20:25:46.660251Z", "shell.execute_reply": "2025-01-20T20:25:46.659652Z", "shell.execute_reply.started": "2025-01-20T20:17:14.194152Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-01-20 12:18:02 Epoch 0/3 done. Loss: Train 2.106, Test 2.091; and Acc: Train 0.118, Test 0.135\n", "2025-01-20 12:20:37 Epoch 1/3 done. Loss: Train 0.989, Test 1.114; and Acc: Train 0.647, Test 0.603\n", "2025-01-20 12:23:12 Epoch 2/3 done. Loss: Train 0.584, Test 0.928; and Acc: Train 0.825, Test 0.669\n", "2025-01-20 12:25:46 Epoch 3/3 done. Loss: Train 0.313, Test 0.950; and Acc: Train 0.913, Test 0.683\n" ] } ], "source": [ "model, tokenizer = run_training(\n", " max_dataset_size=\"full\",\n", " bert_variety=\"bert-base-uncased\",\n", " max_length=128,\n", " num_epochs=3,\n", " batch_size=32,\n", ")" ] }, { "cell_type": "code", "execution_count": 9, "id": "461b8f57-0c52-403a-bb69-3bc192b323bf", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T20:25:46.661264Z", "iopub.status.busy": "2025-01-20T20:25:46.661132Z", "iopub.status.idle": "2025-01-20T20:34:54.221239Z", "shell.execute_reply": "2025-01-20T20:34:54.220590Z", "shell.execute_reply.started": "2025-01-20T20:25:46.661249Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-01-20 12:26:34 Epoch 0/3 done. 
Loss: Train 2.174, Test 2.168; and Acc: Train 0.096, Test 0.094\n", "2025-01-20 12:29:21 Epoch 1/3 done. Loss: Train 0.878, Test 1.033; and Acc: Train 0.712, Test 0.653\n", "2025-01-20 12:32:07 Epoch 2/3 done. Loss: Train 0.458, Test 0.906; and Acc: Train 0.869, Test 0.678\n", "2025-01-20 12:34:54 Epoch 3/3 done. Loss: Train 0.218, Test 0.959; and Acc: Train 0.944, Test 0.695\n" ] } ], "source": [ "model, tokenizer = run_training(\n", " max_dataset_size=\"full\",\n", " bert_variety=\"bert-base-uncased\",\n", " max_length=128,\n", " num_epochs=3,\n", " batch_size=16,\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "id": "28354e8c-886a-4523-8968-8c688c13f6a3", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T20:34:54.224989Z", "iopub.status.busy": "2025-01-20T20:34:54.224772Z", "iopub.status.idle": "2025-01-20T20:54:07.531338Z", "shell.execute_reply": "2025-01-20T20:54:07.530559Z", "shell.execute_reply.started": "2025-01-20T20:34:54.224968Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-01-20 12:36:37 Epoch 0/3 done. Loss: Train 2.122, Test 2.127; and Acc: Train 0.122, Test 0.118\n", "2025-01-20 12:42:26 Epoch 1/3 done. Loss: Train 0.779, Test 0.978; and Acc: Train 0.748, Test 0.652\n", "2025-01-20 12:48:16 Epoch 2/3 done. Loss: Train 0.391, Test 0.884; and Acc: Train 0.897, Test 0.696\n", "2025-01-20 12:54:07 Epoch 3/3 done. Loss: Train 0.154, Test 0.978; and Acc: Train 0.959, Test 0.705\n" ] } ], "source": [ "model, tokenizer = run_training(\n", " max_dataset_size=\"full\",\n", " bert_variety=\"bert-base-uncased\",\n", " max_length=256,\n", " num_epochs=3,\n", " batch_size=16,\n", ")" ] }, { "cell_type": "markdown", "id": "982ba556-c589-4cbb-b392-614942a64ab3", "metadata": {}, "source": [ "# Model to upload" ] }, { "cell_type": "code", "execution_count": 14, "id": "ec2516f9-79f2-4ae1-ab9a-9a51a7a50587", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T22:10:34.055595Z", "iopub.status.busy": "2025-01-20T22:10:34.054690Z", "iopub.status.idle": "2025-01-20T22:10:34.083784Z", "shell.execute_reply": "2025-01-20T22:10:34.083448Z", "shell.execute_reply.started": "2025-01-20T22:10:34.055529Z" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "---\n", "base_model: google-bert/bert-base-uncased\n", "datasets:\n", "- QuotaClimat/frugalaichallenge-text-train\n", "language:\n", "- en\n", "license: apache-2.0\n", "model_name: frugal-ai-text-bert-base\n", "pipeline_tag: text-classification\n", "tags:\n", "- model_hub_mixin\n", "- pytorch_model_hub_mixin\n", "- climate\n", "---\n", "\n", "# Model Card for Model ID\n", "\n", "\n", "\n", "Classify text into 8 categories of climate misinformation.\n", "\n", "## Model Details\n", "\n", "### Model Description\n", "\n", "\n", "\n", "Fine trained BERT for classifying climate information as part of the Frugal AI Challenge, for submission to https://huggingface.co/frugal-ai-challenge and scoring on accuracy and efficiency. 
Trained on only the non-evaluation 80% of the data, so its (non-cheating) score will be lower.\n", "\n", "- **Developed by:** Andre Bach\n", "- **Funded by [optional]:** N/A\n", "- **Shared by [optional]:** Andre Bach\n", "- **Model type:** Text classification\n", "- **Language(s) (NLP):** ['en']\n", "- **License:** apache-2.0\n", "- **Finetuned from model [optional]:** google-bert/bert-base-uncased\n", "\n", "### Model Sources [optional]\n", "\n", "\n", "\n", "- **Repository:** frugal-ai-text-bert-base\n", "- **Paper [optional]:** [More Information Needed]\n", "- **Demo [optional]:** [More Information Needed]\n", "\n", "## Uses\n", "\n", "\n", "\n", "### Direct Use\n", "\n", "\n", "\n", "[More Information Needed]\n", "\n", "### Downstream Use [optional]\n", "\n", "\n", "\n", "[More Information Needed]\n", "\n", "### Out-of-Scope Use\n", "\n", "\n", "\n", "[More Information Needed]\n", "\n", "## Bias, Risks, and Limitations\n", "\n", "\n", "\n", "[More Information Needed]\n", "\n", "### Recommendations\n", "\n", "\n", "\n", "Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.\n", "\n", "## How to Get Started with the Model\n", "\n", "Use the code below to get started with the model.\n", "\n", "[More Information Needed]\n", "\n", "## Training Details\n", "\n", "### Training Data\n", "\n", "\n", "\n", "[More Information Needed]\n", "\n", "### Training Procedure\n", "\n", "\n", "\n", "#### Preprocessing [optional]\n", "\n", "[More Information Needed]\n", "\n", "\n", "#### Training Hyperparameters\n", "\n", "- **Training regime:** {'max_dataset_size': 'full', 'bert_variety': 'bert-base-uncased', 'max_length': 256, 'num_epochs': 3, 'batch_size': 16} \n", "\n", "#### Speeds, Sizes, Times [optional]\n", "\n", "\n", "\n", "[More Information Needed]\n", "\n", "## Evaluation\n", "\n", "\n", "\n", "### Testing Data, Factors & Metrics\n", "\n", "#### Testing Data\n", "\n", "\n", "\n", "[More Information Needed]\n", "\n", "#### Factors\n", "\n", "\n", "\n", "[More Information Needed]\n", "\n", "#### Metrics\n", "\n", "\n", "\n", "{'loss_train': 0.154, 'loss_test': 0.978, 'acc_train': 0.959, 'acc_test': 0.705}\n", "\n", "### Results\n", "\n", "[More Information Needed]\n", "\n", "#### Summary\n", "\n", "\n", "\n", "## Model Examination [optional]\n", "\n", "\n", "\n", "[More Information Needed]\n", "\n", "## Environmental Impact\n", "\n", "\n", "\n", "Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. 
(2019)](https://arxiv.org/abs/1910.09700).\n", "\n", "- **Hardware Type:** [More Information Needed]\n", "- **Hours used:** [More Information Needed]\n", "- **Cloud Provider:** [More Information Needed]\n", "- **Compute Region:** [More Information Needed]\n", "- **Carbon Emitted:** [More Information Needed]\n", "\n", "## Technical Specifications [optional]\n", "\n", "### Model Architecture and Objective\n", "\n", "[More Information Needed]\n", "\n", "### Compute Infrastructure\n", "\n", "[More Information Needed]\n", "\n", "#### Hardware\n", "\n", "[More Information Needed]\n", "\n", "#### Software\n", "\n", "[More Information Needed]\n", "\n", "## Citation [optional]\n", "\n", "\n", "\n", "**BibTeX:**\n", "\n", "[More Information Needed]\n", "\n", "**APA:**\n", "\n", "[More Information Needed]\n", "\n", "## Glossary [optional]\n", "\n", "\n", "\n", "[More Information Needed]\n", "\n", "## More Information [optional]\n", "\n", "[More Information Needed]\n", "\n", "## Model Card Authors [optional]\n", "\n", "[More Information Needed]\n", "\n", "## Model Card Contact\n", "\n", "[More Information Needed]\n" ] } ], "source": [ "model_and_repo_name = \"frugal-ai-text-bert-base\"\n", "card_data = ModelCardData(\n", " model_name=model_and_repo_name,\n", " base_model=\"google-bert/bert-base-uncased\",\n", " license=\"apache-2.0\",\n", " language=[\"en\"],\n", " datasets=[\"QuotaClimat/frugalaichallenge-text-train\"],\n", " tags=[\"model_hub_mixin\", \"pytorch_model_hub_mixin\", \"climate\"],\n", " pipeline_tag=\"text-classification\",\n", ")\n", "card = ModelCard.from_template(\n", " card_data,\n", " model_summary=\"Classify text into 8 categories of climate misinformation.\",\n", " model_description=\"Fine trained BERT for classifying climate information as part of the Frugal AI Challenge, for submission to https://huggingface.co/frugal-ai-challenge and scoring on accuracy and efficiency. 
Trained on only the non-evaluation 80% of the data, so its (non-cheating) score will be lower.\",\n", " developers=\"Andre Bach\",\n", " funded_by=\"N/A\",\n", " shared_by=\"Andre Bach\",\n", " model_type=\"Text classification\",\n", " repo=model_and_repo_name,\n", " training_regime=dict(\n", " max_dataset_size=\"full\",\n", " bert_variety=\"bert-base-uncased\",\n", " max_length=256,\n", " num_epochs=3,\n", " batch_size=16,\n", " ),\n", " testing_metrics=dict(\n", " loss_train=0.154, loss_test=0.978, acc_train=0.959, acc_test=0.705\n", " ),\n", ")\n", "# print(card_data.to_yaml())\n", "print(card)" ] }, { "cell_type": "code", "execution_count": 17, "id": "29d3bbf9-ab2a-48e2-a550-e16da5025720", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T22:11:59.827681Z", "iopub.status.busy": "2025-01-20T22:11:59.827001Z", "iopub.status.idle": "2025-01-20T22:11:59.831852Z", "shell.execute_reply": "2025-01-20T22:11:59.831047Z", "shell.execute_reply.started": "2025-01-20T22:11:59.827635Z" } }, "outputs": [], "source": [ "model_final = model\n", "tokenizer_final = tokenizer" ] }, { "cell_type": "code", "execution_count": 18, "id": "e3b099c6-6b98-473b-8797-5032213b9fcb", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T22:12:00.576369Z", "iopub.status.busy": "2025-01-20T22:12:00.575421Z", "iopub.status.idle": "2025-01-20T22:12:01.065512Z", "shell.execute_reply": "2025-01-20T22:12:01.065237Z", "shell.execute_reply.started": "2025-01-20T22:12:00.576294Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-01-20 14:12:01 Predictions: tensor([0, 0, 3, 6, 2, 4, 6], device='mps:0')\n" ] } ], "source": [ "model_final.eval()\n", "test_text = [\n", " \"This was a great experience!\", # 0_not_relevant\n", " \"My favorite hike is Laguna de los Tres.\", # 0_not_relevant\n", " \"Crops will grow great in Finland if it's warmer there.\", # 3_not_bad\n", " \"Climate change is fake.\", # 1_not_happening\n", " \"The apparent warming is caused by solar cycles.\", # 2_not_human\n", " \"Solar panels emit bad vibes.\", # 4_solutions_harmful_unnecessary\n", " \"All those so-called scientists are Democrats.\", # 6_proponents_biased\n", "]\n", "test_encoding = tokenizer_final(\n", " test_text,\n", " truncation=True,\n", " padding=True,\n", " return_tensors=\"pt\",\n", ")\n", "\n", "with torch.no_grad():\n", " test_input_ids = test_encoding[\"input_ids\"].to(device)\n", " test_attention_mask = test_encoding[\"attention_mask\"].to(device)\n", " outputs = model_final(test_input_ids, test_attention_mask)\n", " predictions = torch.argmax(outputs, dim=1)\n", " my_print(f\"Predictions: {predictions}\")" ] }, { "cell_type": "code", "execution_count": 19, "id": "befb94b5-88bf-40fc-8b26-cf373d1256e0", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T22:12:15.099356Z", "iopub.status.busy": "2025-01-20T22:12:15.098818Z", "iopub.status.idle": "2025-01-20T22:12:33.175760Z", "shell.execute_reply": "2025-01-20T22:12:33.174719Z", "shell.execute_reply.started": "2025-01-20T22:12:15.099315Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fbc09ae2c5614831a2fb02fa48a44fd1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors: 0%| | 0.00/438M [00:00