Nonnormalizable committed
Commit ab18efc · 1 parent: 3ec6adb

Point the submission at my first BERT model on HF.

Files changed (2)
  1. Finetune BERT.ipynb +402 -83
  2. tasks/text.py +41 -3
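
Together, the two files move the /text task off the baseline: the notebook's `BertClassifier` now mixes in `PyTorchModelHubMixin` so the fine-tuned weights can be pushed to the Hub, and `tasks/text.py` points inference at that repo. A minimal sketch of the push side, assuming the class and hyperparameters shown in the hunks below (the forward body is a plausible reconstruction; only fragments of it appear in the diff):

```python
# Sketch of the change, not a verbatim copy of the notebook.
from torch import nn
from transformers import BertModel
from huggingface_hub import PyTorchModelHubMixin

class BertClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, num_labels=8, bert_variety='bert-base-uncased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_variety)
        self.dropout = nn.Dropout(0.05)
        self.classifier = nn.Linear(self.bert.pooler.dense.out_features, num_labels)

    def forward(self, input_ids, attention_mask):
        # Pool the [CLS] representation, regularize, project to the 8 labels.
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(outputs.pooler_output)
        return self.classifier(pooled)

model = BertClassifier()
# ... fine-tune as in run_training() ...
model.push_to_hub('frugal-ai-text-bert-base')  # repo name used later in the notebook
```

`push_to_hub` needs an authenticated session, which is why the notebook now calls `notebook_login(new_session=False)` near the top.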
Finetune BERT.ipynb CHANGED
@@ -6,11 +6,11 @@
6
  "id": "73e72549-69f2-46b5-b0f5-655777139972",
7
  "metadata": {
8
  "execution": {
9
- "iopub.execute_input": "2025-01-17T04:45:37.715126Z",
10
- "iopub.status.busy": "2025-01-17T04:45:37.714808Z",
11
- "iopub.status.idle": "2025-01-17T04:45:41.232154Z",
12
- "shell.execute_reply": "2025-01-17T04:45:41.231851Z",
13
- "shell.execute_reply.started": "2025-01-17T04:45:37.715090Z"
14
  }
15
  },
16
  "outputs": [],
@@ -20,6 +20,7 @@
20
  "import torch\n",
21
  "from torch import nn\n",
22
  "from transformers import BertTokenizer, BertModel\n",
 
23
  "from torch.utils.data import Dataset, DataLoader\n",
24
  "from datasets import load_dataset"
25
  ]
@@ -27,14 +28,32 @@
27
  {
28
  "cell_type": "code",
29
  "execution_count": 2,
30
  "id": "d4b79fb9-5e70-4600-8885-94bc0a6e917c",
31
  "metadata": {
32
  "execution": {
33
- "iopub.execute_input": "2025-01-17T04:45:41.232694Z",
34
- "iopub.status.busy": "2025-01-17T04:45:41.232554Z",
35
- "iopub.status.idle": "2025-01-17T04:45:41.236434Z",
36
- "shell.execute_reply": "2025-01-17T04:45:41.236218Z",
37
- "shell.execute_reply.started": "2025-01-17T04:45:41.232685Z"
38
  }
39
  },
40
  "outputs": [],
@@ -43,12 +62,12 @@
43
  " time_str = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
44
  " print(time_str, x)\n",
45
  "\n",
46
- "class BertClassifier(nn.Module):\n",
47
- " def __init__(self, num_classes: int = 8, bert_variety='bert-base-uncased'):\n",
48
  " super().__init__()\n",
49
  " self.bert = BertModel.from_pretrained(bert_variety)\n",
50
  " self.dropout = nn.Dropout(0.05)\n",
51
- " self.classifier = nn.Linear(self.bert.pooler.dense.out_features, num_classes)\n",
52
  "\n",
53
  " def forward(self, input_ids, attention_mask):\n",
54
  " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
@@ -58,7 +77,7 @@
58
  " return logits\n",
59
  "\n",
60
  "class TextDataset(Dataset):\n",
61
- " def __init__(self, texts, labels, tokenizer, max_length=200):\n",
62
  " self.encodings = tokenizer(\n",
63
  " texts,\n",
64
  " truncation=True,\n",
@@ -104,15 +123,15 @@
104
  },
105
  {
106
  "cell_type": "code",
107
- "execution_count": 3,
108
  "id": "07131bce-23ad-4787-8622-cce401f3e5ce",
109
  "metadata": {
110
  "execution": {
111
- "iopub.execute_input": "2025-01-17T04:45:41.237451Z",
112
- "iopub.status.busy": "2025-01-17T04:45:41.237358Z",
113
- "iopub.status.idle": "2025-01-17T04:45:41.252075Z",
114
- "shell.execute_reply": "2025-01-17T04:45:41.251851Z",
115
- "shell.execute_reply.started": "2025-01-17T04:45:41.237443Z"
116
  }
117
  },
118
  "outputs": [],
@@ -128,15 +147,15 @@
128
  },
129
  {
130
  "cell_type": "code",
131
- "execution_count": 4,
132
  "id": "695bc080-bbd7-4937-af5b-50db1c936500",
133
  "metadata": {
134
  "execution": {
135
- "iopub.execute_input": "2025-01-17T04:45:41.252581Z",
136
- "iopub.status.busy": "2025-01-17T04:45:41.252476Z",
137
- "iopub.status.idle": "2025-01-17T04:45:41.255279Z",
138
- "shell.execute_reply": "2025-01-17T04:45:41.255045Z",
139
- "shell.execute_reply.started": "2025-01-17T04:45:41.252572Z"
140
  }
141
  },
142
  "outputs": [],
@@ -179,15 +198,15 @@
179
  },
180
  {
181
  "cell_type": "code",
182
- "execution_count": 5,
183
  "id": "792fd13f-e7cc-4d90-832d-c0da15e193cd",
184
  "metadata": {
185
  "execution": {
186
- "iopub.execute_input": "2025-01-17T04:45:41.255750Z",
187
- "iopub.status.busy": "2025-01-17T04:45:41.255661Z",
188
- "iopub.status.idle": "2025-01-17T04:47:17.151654Z",
189
- "shell.execute_reply": "2025-01-17T04:47:17.149076Z",
190
- "shell.execute_reply.started": "2025-01-17T04:45:41.255742Z"
191
  }
192
  },
193
  "outputs": [
@@ -195,18 +214,18 @@
195
  "name": "stdout",
196
  "output_type": "stream",
197
  "text": [
198
- "2025-01-16 20:45:45 Starting epoch 1.\n",
199
- "2025-01-16 20:46:15 Epoch 1/3 done, Average Loss: 1.9223\n",
200
- "2025-01-16 20:46:46 Epoch 2/3 done, Average Loss: 1.6052\n",
201
- "2025-01-16 20:47:17 Epoch 3/3 done, Average Loss: 1.2876\n"
202
  ]
203
  }
204
  ],
205
  "source": [
206
  "model, tokenizer = run_training(\n",
207
- " max_dataset_size=16 * 50,\n",
208
  " bert_variety='bert-base-uncased',\n",
209
- " max_length=200,\n",
210
  " num_epochs=3,\n",
211
  " batch_size=32,\n",
212
  ")"
@@ -214,15 +233,15 @@
214
  },
215
  {
216
  "cell_type": "code",
217
- "execution_count": 6,
218
  "id": "0aedfcca-843e-4f4c-8062-3e4625161bcc",
219
  "metadata": {
220
  "execution": {
221
- "iopub.execute_input": "2025-01-17T04:47:17.158101Z",
222
- "iopub.status.busy": "2025-01-17T04:47:17.157305Z",
223
- "iopub.status.idle": "2025-01-17T04:47:17.333568Z",
224
- "shell.execute_reply": "2025-01-17T04:47:17.333317Z",
225
- "shell.execute_reply.started": "2025-01-17T04:47:17.157437Z"
226
  }
227
  },
228
  "outputs": [
@@ -230,7 +249,7 @@
230
  "name": "stdout",
231
  "output_type": "stream",
232
  "text": [
233
- "2025-01-16 20:47:17 Predictions: tensor([6, 1, 1, 6, 1, 6, 6], device='mps:0')\n"
234
  ]
235
  }
236
  ],
@@ -367,15 +386,15 @@
367
  },
368
  {
369
  "cell_type": "code",
370
- "execution_count": 10,
371
  "id": "28354e8c-886a-4523-8968-8c688c13f6a3",
372
  "metadata": {
373
  "execution": {
374
- "iopub.execute_input": "2025-01-17T05:05:36.905668Z",
375
- "iopub.status.busy": "2025-01-17T05:05:36.905353Z",
376
- "iopub.status.idle": "2025-01-17T05:21:10.045463Z",
377
- "shell.execute_reply": "2025-01-17T05:21:10.044788Z",
378
- "shell.execute_reply.started": "2025-01-17T05:05:36.905630Z"
379
  }
380
  },
381
  "outputs": [
@@ -383,10 +402,10 @@
383
  "name": "stdout",
384
  "output_type": "stream",
385
  "text": [
386
- "2025-01-16 21:05:43 Starting epoch 1.\n",
387
- "2025-01-16 21:10:53 Epoch 1/3 done, Average Loss: 1.3415\n",
388
- "2025-01-16 21:16:02 Epoch 2/3 done, Average Loss: 0.7216\n",
389
- "2025-01-16 21:21:10 Epoch 3/3 done, Average Loss: 0.3978\n"
390
  ]
391
  }
392
  ],
@@ -400,17 +419,60 @@
400
  ")"
401
  ]
402
  },
403
  {
404
  "cell_type": "code",
405
- "execution_count": 11,
406
  "id": "e3b099c6-6b98-473b-8797-5032213b9fcb",
407
  "metadata": {
408
  "execution": {
409
- "iopub.execute_input": "2025-01-17T05:21:10.059844Z",
410
- "iopub.status.busy": "2025-01-17T05:21:10.058980Z",
411
- "iopub.status.idle": "2025-01-17T05:21:10.164116Z",
412
- "shell.execute_reply": "2025-01-17T05:21:10.163826Z",
413
- "shell.execute_reply.started": "2025-01-17T05:21:10.059552Z"
414
  }
415
  },
416
  "outputs": [
@@ -418,12 +480,12 @@
418
  "name": "stdout",
419
  "output_type": "stream",
420
  "text": [
421
- "2025-01-16 21:21:10 Predictions: tensor([0, 0, 3, 6, 2, 4, 6], device='mps:0')\n"
422
  ]
423
  }
424
  ],
425
  "source": [
426
- "model.eval()\n",
427
  "test_text = [\n",
428
  " 'This was a great experience!', # 0_not_relevant\n",
429
  " 'My favorite hike is Laguna de los Tres.', # 0_not_relevant\n",
@@ -433,7 +495,7 @@
433
  " 'Solar panels emit bad vibes.', # 4_solutions_harmful_unnecessary\n",
434
  " 'All those so-called scientists are Democrats.', # 6_proponents_biased\n",
435
  "]\n",
436
- "test_encoding = tokenizer(\n",
437
  " test_text,\n",
438
  " truncation=True,\n",
439
  " padding=True,\n",
@@ -443,46 +505,101 @@
443
  "with torch.no_grad():\n",
444
  " test_input_ids = test_encoding['input_ids'].to(device)\n",
445
  " test_attention_mask = test_encoding['attention_mask'].to(device)\n",
446
- " outputs = model(test_input_ids, test_attention_mask)\n",
447
  " predictions = torch.argmax(outputs, dim=1)\n",
448
  " my_print(f'Predictions: {predictions}')"
449
  ]
450
  },
451
  {
452
  "cell_type": "code",
453
- "execution_count": 12,
454
  "id": "befb94b5-88bf-40fc-8b26-cf373d1256e0",
455
  "metadata": {
456
  "execution": {
457
- "iopub.execute_input": "2025-01-17T05:27:58.042752Z",
458
- "iopub.status.busy": "2025-01-17T05:27:58.042151Z",
459
- "iopub.status.idle": "2025-01-17T05:27:58.454054Z",
460
- "shell.execute_reply": "2025-01-17T05:27:58.453644Z",
461
- "shell.execute_reply.started": "2025-01-17T05:27:58.042662Z"
462
  }
463
  },
464
  "outputs": [
465
  {
466
- "ename": "AttributeError",
467
- "evalue": "'BertClassifier' object has no attribute 'push_to_hub'",
468
- "output_type": "error",
469
- "traceback": [
470
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
471
- "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
472
- "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpush_to_hub\u001b[49m()\n",
473
- "File \u001b[0;32m~/miniconda3/envs/py313/lib/python3.13/site-packages/torch/nn/modules/module.py:1931\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1929\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1930\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1931\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\n\u001b[1;32m 1932\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1933\u001b[0m )\n",
474
- "\u001b[0;31mAttributeError\u001b[0m: 'BertClassifier' object has no attribute 'push_to_hub'"
475
- ]
476
  }
477
  ],
478
  "source": [
479
- "model.push_to_hub()"
480
  ]
481
  },
482
  {
483
  "cell_type": "code",
484
- "execution_count": null,
485
  "id": "251ef9ee-8ba3-495f-8fe6-a93aa63168ce",
486
  "metadata": {},
487
  "outputs": [],
488
  "source": []
@@ -505,6 +622,208 @@
505
  "nbconvert_exporter": "python",
506
  "pygments_lexer": "ipython3",
507
  "version": "3.13.1"
508
  }
509
  },
510
  "nbformat": 4,
 
6
  "id": "73e72549-69f2-46b5-b0f5-655777139972",
7
  "metadata": {
8
  "execution": {
9
+ "iopub.execute_input": "2025-01-17T18:17:50.964659Z",
10
+ "iopub.status.busy": "2025-01-17T18:17:50.964450Z",
11
+ "iopub.status.idle": "2025-01-17T18:17:53.646932Z",
12
+ "shell.execute_reply": "2025-01-17T18:17:53.646697Z",
13
+ "shell.execute_reply.started": "2025-01-17T18:17:50.964637Z"
14
  }
15
  },
16
  "outputs": [],
 
20
  "import torch\n",
21
  "from torch import nn\n",
22
  "from transformers import BertTokenizer, BertModel\n",
23
+ "from huggingface_hub import PyTorchModelHubMixin, notebook_login\n",
24
  "from torch.utils.data import Dataset, DataLoader\n",
25
  "from datasets import load_dataset"
26
  ]
 
28
  {
29
  "cell_type": "code",
30
  "execution_count": 2,
31
+ "id": "07e0787e-c72b-41f3-baba-43cef3f8d6f8",
32
+ "metadata": {
33
+ "execution": {
34
+ "iopub.execute_input": "2025-01-17T18:17:53.648499Z",
35
+ "iopub.status.busy": "2025-01-17T18:17:53.648417Z",
36
+ "iopub.status.idle": "2025-01-17T18:17:53.650284Z",
37
+ "shell.execute_reply": "2025-01-17T18:17:53.650113Z",
38
+ "shell.execute_reply.started": "2025-01-17T18:17:53.648489Z"
39
+ }
40
+ },
41
+ "outputs": [],
42
+ "source": [
43
+ "notebook_login(new_session=False)"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": 11,
49
  "id": "d4b79fb9-5e70-4600-8885-94bc0a6e917c",
50
  "metadata": {
51
  "execution": {
52
+ "iopub.execute_input": "2025-01-17T18:35:15.421761Z",
53
+ "iopub.status.busy": "2025-01-17T18:35:15.421353Z",
54
+ "iopub.status.idle": "2025-01-17T18:35:15.433782Z",
55
+ "shell.execute_reply": "2025-01-17T18:35:15.433001Z",
56
+ "shell.execute_reply.started": "2025-01-17T18:35:15.421734Z"
57
  }
58
  },
59
  "outputs": [],
 
62
  " time_str = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
63
  " print(time_str, x)\n",
64
  "\n",
65
+ "class BertClassifier(nn.Module, PyTorchModelHubMixin):\n",
66
+ " def __init__(self, num_labels=8, bert_variety='bert-base-uncased'):\n",
67
  " super().__init__()\n",
68
  " self.bert = BertModel.from_pretrained(bert_variety)\n",
69
  " self.dropout = nn.Dropout(0.05)\n",
70
+ " self.classifier = nn.Linear(self.bert.pooler.dense.out_features, num_labels)\n",
71
  "\n",
72
  " def forward(self, input_ids, attention_mask):\n",
73
  " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
 
77
  " return logits\n",
78
  "\n",
79
  "class TextDataset(Dataset):\n",
80
+ " def __init__(self, texts, labels, tokenizer, max_length=512):\n",
81
  " self.encodings = tokenizer(\n",
82
  " texts,\n",
83
  " truncation=True,\n",
 
123
  },
124
  {
125
  "cell_type": "code",
126
+ "execution_count": 4,
127
  "id": "07131bce-23ad-4787-8622-cce401f3e5ce",
128
  "metadata": {
129
  "execution": {
130
+ "iopub.execute_input": "2025-01-17T18:17:57.885732Z",
131
+ "iopub.status.busy": "2025-01-17T18:17:57.884455Z",
132
+ "iopub.status.idle": "2025-01-17T18:17:57.919509Z",
133
+ "shell.execute_reply": "2025-01-17T18:17:57.919081Z",
134
+ "shell.execute_reply.started": "2025-01-17T18:17:57.885667Z"
135
  }
136
  },
137
  "outputs": [],
 
147
  },
148
  {
149
  "cell_type": "code",
150
+ "execution_count": 5,
151
  "id": "695bc080-bbd7-4937-af5b-50db1c936500",
152
  "metadata": {
153
  "execution": {
154
+ "iopub.execute_input": "2025-01-17T18:17:58.556031Z",
155
+ "iopub.status.busy": "2025-01-17T18:17:58.555349Z",
156
+ "iopub.status.idle": "2025-01-17T18:17:58.564519Z",
157
+ "shell.execute_reply": "2025-01-17T18:17:58.563640Z",
158
+ "shell.execute_reply.started": "2025-01-17T18:17:58.555979Z"
159
  }
160
  },
161
  "outputs": [],
 
198
  },
199
  {
200
  "cell_type": "code",
201
+ "execution_count": 19,
202
  "id": "792fd13f-e7cc-4d90-832d-c0da15e193cd",
203
  "metadata": {
204
  "execution": {
205
+ "iopub.execute_input": "2025-01-17T15:22:41.286449Z",
206
+ "iopub.status.busy": "2025-01-17T15:22:41.285811Z",
207
+ "iopub.status.idle": "2025-01-17T15:24:35.507909Z",
208
+ "shell.execute_reply": "2025-01-17T15:24:35.506587Z",
209
+ "shell.execute_reply.started": "2025-01-17T15:22:41.286404Z"
210
  }
211
  },
212
  "outputs": [
 
214
  "name": "stdout",
215
  "output_type": "stream",
216
  "text": [
217
+ "2025-01-17 07:22:44 Starting epoch 1.\n",
218
+ "2025-01-17 07:23:21 Epoch 1/3 done, Average Loss: 1.8129\n",
219
+ "2025-01-17 07:23:58 Epoch 2/3 done, Average Loss: 1.3089\n",
220
+ "2025-01-17 07:24:35 Epoch 3/3 done, Average Loss: 0.8916\n"
221
  ]
222
  }
223
  ],
224
  "source": [
225
  "model, tokenizer = run_training(\n",
226
+ " max_dataset_size=16 * 100,\n",
227
  " bert_variety='bert-base-uncased',\n",
228
+ " max_length=128,\n",
229
  " num_epochs=3,\n",
230
  " batch_size=32,\n",
231
  ")"
 
233
  },
234
  {
235
  "cell_type": "code",
236
+ "execution_count": 21,
237
  "id": "0aedfcca-843e-4f4c-8062-3e4625161bcc",
238
  "metadata": {
239
  "execution": {
240
+ "iopub.execute_input": "2025-01-17T15:24:46.754460Z",
241
+ "iopub.status.busy": "2025-01-17T15:24:46.753753Z",
242
+ "iopub.status.idle": "2025-01-17T15:24:47.249458Z",
243
+ "shell.execute_reply": "2025-01-17T15:24:47.249207Z",
244
+ "shell.execute_reply.started": "2025-01-17T15:24:46.754391Z"
245
  }
246
  },
247
  "outputs": [
 
249
  "name": "stdout",
250
  "output_type": "stream",
251
  "text": [
252
+ "2025-01-17 07:24:47 Predictions: tensor([0, 1, 3, 6, 2, 3, 6], device='mps:0')\n"
253
  ]
254
  }
255
  ],
 
386
  },
387
  {
388
  "cell_type": "code",
389
+ "execution_count": 12,
390
  "id": "28354e8c-886a-4523-8968-8c688c13f6a3",
391
  "metadata": {
392
  "execution": {
393
+ "iopub.execute_input": "2025-01-17T18:35:15.434902Z",
394
+ "iopub.status.busy": "2025-01-17T18:35:15.434668Z",
395
+ "iopub.status.idle": "2025-01-17T18:50:43.167167Z",
396
+ "shell.execute_reply": "2025-01-17T18:50:43.166720Z",
397
+ "shell.execute_reply.started": "2025-01-17T18:35:15.434880Z"
398
  }
399
  },
400
  "outputs": [
 
402
  "name": "stdout",
403
  "output_type": "stream",
404
  "text": [
405
+ "2025-01-17 10:35:20 Starting epoch 1.\n",
406
+ "2025-01-17 10:40:29 Epoch 1/3 done, Average Loss: 1.2876\n",
407
+ "2025-01-17 10:45:37 Epoch 2/3 done, Average Loss: 0.7289\n",
408
+ "2025-01-17 10:50:43 Epoch 3/3 done, Average Loss: 0.3990\n"
409
  ]
410
  }
411
  ],
 
419
  ")"
420
  ]
421
  },
422
+ {
423
+ "cell_type": "markdown",
424
+ "id": "982ba556-c589-4cbb-b392-614942a64ab3",
425
+ "metadata": {},
426
+ "source": [
427
+ "# Model to upload"
428
+ ]
429
+ },
430
  {
431
  "cell_type": "code",
432
+ "execution_count": 6,
433
+ "id": "ac5f412c-a745-4327-9303-acf4c5b1efcd",
434
+ "metadata": {
435
+ "execution": {
436
+ "iopub.execute_input": "2025-01-17T18:19:11.590514Z",
437
+ "iopub.status.busy": "2025-01-17T18:19:11.589753Z",
438
+ "iopub.status.idle": "2025-01-17T18:26:45.645104Z",
439
+ "shell.execute_reply": "2025-01-17T18:26:45.644631Z",
440
+ "shell.execute_reply.started": "2025-01-17T18:19:11.590428Z"
441
+ }
442
+ },
443
+ "outputs": [
444
+ {
445
+ "name": "stdout",
446
+ "output_type": "stream",
447
+ "text": [
448
+ "2025-01-17 10:19:17 Starting epoch 1.\n",
449
+ "2025-01-17 10:21:47 Epoch 1/3 done, Average Loss: 1.2608\n",
450
+ "2025-01-17 10:24:16 Epoch 2/3 done, Average Loss: 0.7134\n",
451
+ "2025-01-17 10:26:45 Epoch 3/3 done, Average Loss: 0.3931\n"
452
+ ]
453
+ }
454
+ ],
455
+ "source": [
456
+ "model_final, tokenizer_final = run_training(\n",
457
+ " max_dataset_size='full',\n",
458
+ " bert_variety='bert-base-uncased',\n",
459
+ " max_length=128,\n",
460
+ " num_epochs=3,\n",
461
+ " batch_size=16,\n",
462
+ ")"
463
+ ]
464
+ },
465
+ {
466
+ "cell_type": "code",
467
+ "execution_count": 7,
468
  "id": "e3b099c6-6b98-473b-8797-5032213b9fcb",
469
  "metadata": {
470
  "execution": {
471
+ "iopub.execute_input": "2025-01-17T18:26:45.646178Z",
472
+ "iopub.status.busy": "2025-01-17T18:26:45.646081Z",
473
+ "iopub.status.idle": "2025-01-17T18:26:45.722052Z",
474
+ "shell.execute_reply": "2025-01-17T18:26:45.721803Z",
475
+ "shell.execute_reply.started": "2025-01-17T18:26:45.646168Z"
476
  }
477
  },
478
  "outputs": [
 
480
  "name": "stdout",
481
  "output_type": "stream",
482
  "text": [
483
+ "2025-01-17 10:26:45 Predictions: tensor([0, 0, 3, 1, 2, 4, 6], device='mps:0')\n"
484
  ]
485
  }
486
  ],
487
  "source": [
488
+ "model_final.eval()\n",
489
  "test_text = [\n",
490
  " 'This was a great experience!', # 0_not_relevant\n",
491
  " 'My favorite hike is Laguna de los Tres.', # 0_not_relevant\n",
 
495
  " 'Solar panels emit bad vibes.', # 4_solutions_harmful_unnecessary\n",
496
  " 'All those so-called scientists are Democrats.', # 6_proponents_biased\n",
497
  "]\n",
498
+ "test_encoding = tokenizer_final(\n",
499
  " test_text,\n",
500
  " truncation=True,\n",
501
  " padding=True,\n",
 
505
  "with torch.no_grad():\n",
506
  " test_input_ids = test_encoding['input_ids'].to(device)\n",
507
  " test_attention_mask = test_encoding['attention_mask'].to(device)\n",
508
+ " outputs = model_final(test_input_ids, test_attention_mask)\n",
509
  " predictions = torch.argmax(outputs, dim=1)\n",
510
  " my_print(f'Predictions: {predictions}')"
511
  ]
512
  },
513
  {
514
  "cell_type": "code",
515
+ "execution_count": 10,
516
  "id": "befb94b5-88bf-40fc-8b26-cf373d1256e0",
517
  "metadata": {
518
  "execution": {
519
+ "iopub.execute_input": "2025-01-17T18:32:40.094019Z",
520
+ "iopub.status.busy": "2025-01-17T18:32:40.093429Z",
521
+ "iopub.status.idle": "2025-01-17T18:35:15.419578Z",
522
+ "shell.execute_reply": "2025-01-17T18:35:15.418848Z",
523
+ "shell.execute_reply.started": "2025-01-17T18:32:40.093970Z"
524
  }
525
  },
526
  "outputs": [
527
  {
528
+ "data": {
529
+ "application/vnd.jupyter.widget-view+json": {
530
+ "model_id": "7dd2d0eb08624920b345ca85712f0169",
531
+ "version_major": 2,
532
+ "version_minor": 0
533
+ },
534
+ "text/plain": [
535
+ "model.safetensors: 0%| | 0.00/438M [00:00<?, ?B/s]"
536
+ ]
537
+ },
538
+ "metadata": {},
539
+ "output_type": "display_data"
540
+ },
541
+ {
542
+ "data": {
543
+ "text/plain": [
544
+ "CommitInfo(commit_url='https://huggingface.co/Nonnormalizable/frugal-ai-text-bert-base/commit/bd94aa1344798fcf671ddd5f8a7bd4f4dc0b20c4', commit_message='Push model using huggingface_hub.', commit_description='', oid='bd94aa1344798fcf671ddd5f8a7bd4f4dc0b20c4', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Nonnormalizable/frugal-ai-text-bert-base', endpoint='https://huggingface.co', repo_type='model', repo_id='Nonnormalizable/frugal-ai-text-bert-base'), pr_revision=None, pr_num=None)"
545
+ ]
546
+ },
547
+ "execution_count": 10,
548
+ "metadata": {},
549
+ "output_type": "execute_result"
550
  }
551
  ],
552
  "source": [
553
+ "model_final.push_to_hub('frugal-ai-text-bert-base')"
554
  ]
555
  },
556
  {
557
  "cell_type": "code",
558
+ "execution_count": 9,
559
  "id": "251ef9ee-8ba3-495f-8fe6-a93aa63168ce",
560
+ "metadata": {
561
+ "execution": {
562
+ "iopub.execute_input": "2025-01-17T18:31:37.682978Z",
563
+ "iopub.status.busy": "2025-01-17T18:31:37.682009Z",
564
+ "iopub.status.idle": "2025-01-17T18:31:39.578706Z",
565
+ "shell.execute_reply": "2025-01-17T18:31:39.577664Z",
566
+ "shell.execute_reply.started": "2025-01-17T18:31:37.682910Z"
567
+ }
568
+ },
569
+ "outputs": [
570
+ {
571
+ "data": {
572
+ "application/vnd.jupyter.widget-view+json": {
573
+ "model_id": "b62ae26d30534f8fa6057824124e9c95",
574
+ "version_major": 2,
575
+ "version_minor": 0
576
+ },
577
+ "text/plain": [
578
+ "README.md: 0%| | 0.00/320 [00:00<?, ?B/s]"
579
+ ]
580
+ },
581
+ "metadata": {},
582
+ "output_type": "display_data"
583
+ },
584
+ {
585
+ "data": {
586
+ "text/plain": [
587
+ "CommitInfo(commit_url='https://huggingface.co/Nonnormalizable/frugal-ai-text-bert-base/commit/9814436ad5f77cd8c607aa5dba9b67e7983e8ca7', commit_message='Upload tokenizer', commit_description='', oid='9814436ad5f77cd8c607aa5dba9b67e7983e8ca7', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Nonnormalizable/frugal-ai-text-bert-base', endpoint='https://huggingface.co', repo_type='model', repo_id='Nonnormalizable/frugal-ai-text-bert-base'), pr_revision=None, pr_num=None)"
588
+ ]
589
+ },
590
+ "execution_count": 9,
591
+ "metadata": {},
592
+ "output_type": "execute_result"
593
+ }
594
+ ],
595
+ "source": [
596
+ "tokenizer_final.push_to_hub('frugal-ai-text-bert-base')"
597
+ ]
598
+ },
599
+ {
600
+ "cell_type": "code",
601
+ "execution_count": null,
602
+ "id": "863d3553-89a6-4188-a8d0-eaa0b6bccb6c",
603
  "metadata": {},
604
  "outputs": [],
605
  "source": []
 
622
  "nbconvert_exporter": "python",
623
  "pygments_lexer": "ipython3",
624
  "version": "3.13.1"
625
+ },
626
+ "widgets": {
627
+ "application/vnd.jupyter.widget-state+json": {
628
+ "state": {
629
+ "25776d7aede3476da6f33fc15fe300c8": {
630
+ "model_module": "@jupyter-widgets/controls",
631
+ "model_module_version": "2.0.0",
632
+ "model_name": "ProgressStyleModel",
633
+ "state": {
634
+ "description_width": ""
635
+ }
636
+ },
637
+ "3a03347251c644bd9b5f58bac49ba2b7": {
638
+ "model_module": "@jupyter-widgets/base",
639
+ "model_module_version": "2.0.0",
640
+ "model_name": "LayoutModel",
641
+ "state": {}
642
+ },
643
+ "3f7dd449d7f84420a836adb899c3b374": {
644
+ "model_module": "@jupyter-widgets/controls",
645
+ "model_module_version": "2.0.0",
646
+ "model_name": "HTMLStyleModel",
647
+ "state": {
648
+ "description_width": "",
649
+ "font_size": null,
650
+ "text_color": null
651
+ }
652
+ },
653
+ "47f3b8da36704934acf81f357a9da6c3": {
654
+ "model_module": "@jupyter-widgets/controls",
655
+ "model_module_version": "2.0.0",
656
+ "model_name": "FloatProgressModel",
657
+ "state": {
658
+ "bar_style": "success",
659
+ "layout": "IPY_MODEL_ae0e1835546645cd85915a133bd0b578",
660
+ "max": 437977072,
661
+ "style": "IPY_MODEL_25776d7aede3476da6f33fc15fe300c8",
662
+ "value": 437977072
663
+ }
664
+ },
665
+ "4eff913c8c554820b957c2192d04a8cd": {
666
+ "model_module": "@jupyter-widgets/controls",
667
+ "model_module_version": "2.0.0",
668
+ "model_name": "HTMLModel",
669
+ "state": {
670
+ "layout": "IPY_MODEL_54b8a0d455794f8881e6d9ceddcac787",
671
+ "style": "IPY_MODEL_3f7dd449d7f84420a836adb899c3b374",
672
+ "value": " 438M/438M [02:32&lt;00:00, 3.02MB/s]"
673
+ }
674
+ },
675
+ "54b8a0d455794f8881e6d9ceddcac787": {
676
+ "model_module": "@jupyter-widgets/base",
677
+ "model_module_version": "2.0.0",
678
+ "model_name": "LayoutModel",
679
+ "state": {}
680
+ },
681
+ "5c96c3617819467d9fb70aa3b716106e": {
682
+ "model_module": "@jupyter-widgets/base",
683
+ "model_module_version": "2.0.0",
684
+ "model_name": "LayoutModel",
685
+ "state": {}
686
+ },
687
+ "62f9a837c04142b5a2fd66097be6fb6e": {
688
+ "model_module": "@jupyter-widgets/base",
689
+ "model_module_version": "2.0.0",
690
+ "model_name": "LayoutModel",
691
+ "state": {}
692
+ },
693
+ "68c0e93ffde14a40b3599dff15512174": {
694
+ "model_module": "@jupyter-widgets/controls",
695
+ "model_module_version": "2.0.0",
696
+ "model_name": "HTMLStyleModel",
697
+ "state": {
698
+ "description_width": "",
699
+ "font_size": null,
700
+ "text_color": null
701
+ }
702
+ },
703
+ "6f679b19e9824e1cac8545d7244ec83a": {
704
+ "model_module": "@jupyter-widgets/controls",
705
+ "model_module_version": "2.0.0",
706
+ "model_name": "FloatProgressModel",
707
+ "state": {
708
+ "bar_style": "success",
709
+ "layout": "IPY_MODEL_9785d5bb51544986b4c51b63a39d46cf",
710
+ "max": 320,
711
+ "style": "IPY_MODEL_88bc5db626a242af8879201d263d9eef",
712
+ "value": 320
713
+ }
714
+ },
715
+ "7dd2d0eb08624920b345ca85712f0169": {
716
+ "model_module": "@jupyter-widgets/controls",
717
+ "model_module_version": "2.0.0",
718
+ "model_name": "HBoxModel",
719
+ "state": {
720
+ "children": [
721
+ "IPY_MODEL_bdca6adbcf2347729287c1d2dc44fa2e",
722
+ "IPY_MODEL_47f3b8da36704934acf81f357a9da6c3",
723
+ "IPY_MODEL_4eff913c8c554820b957c2192d04a8cd"
724
+ ],
725
+ "layout": "IPY_MODEL_3a03347251c644bd9b5f58bac49ba2b7"
726
+ }
727
+ },
728
+ "88bc5db626a242af8879201d263d9eef": {
729
+ "model_module": "@jupyter-widgets/controls",
730
+ "model_module_version": "2.0.0",
731
+ "model_name": "ProgressStyleModel",
732
+ "state": {
733
+ "description_width": ""
734
+ }
735
+ },
736
+ "9396575ac43b4832bb12e246801a2316": {
737
+ "model_module": "@jupyter-widgets/controls",
738
+ "model_module_version": "2.0.0",
739
+ "model_name": "HTMLModel",
740
+ "state": {
741
+ "layout": "IPY_MODEL_c16752a4cf734193accaae9835d55aab",
742
+ "style": "IPY_MODEL_c1b70a1ce9d149cf87169838a18f2e58",
743
+ "value": "README.md: 100%"
744
+ }
745
+ },
746
+ "9785d5bb51544986b4c51b63a39d46cf": {
747
+ "model_module": "@jupyter-widgets/base",
748
+ "model_module_version": "2.0.0",
749
+ "model_name": "LayoutModel",
750
+ "state": {}
751
+ },
752
+ "ae0e1835546645cd85915a133bd0b578": {
753
+ "model_module": "@jupyter-widgets/base",
754
+ "model_module_version": "2.0.0",
755
+ "model_name": "LayoutModel",
756
+ "state": {}
757
+ },
758
+ "b62ae26d30534f8fa6057824124e9c95": {
759
+ "model_module": "@jupyter-widgets/controls",
760
+ "model_module_version": "2.0.0",
761
+ "model_name": "HBoxModel",
762
+ "state": {
763
+ "children": [
764
+ "IPY_MODEL_9396575ac43b4832bb12e246801a2316",
765
+ "IPY_MODEL_6f679b19e9824e1cac8545d7244ec83a",
766
+ "IPY_MODEL_ce85ada4df3c41e9a9b35b7401cd1883"
767
+ ],
768
+ "layout": "IPY_MODEL_62f9a837c04142b5a2fd66097be6fb6e"
769
+ }
770
+ },
771
+ "bdca6adbcf2347729287c1d2dc44fa2e": {
772
+ "model_module": "@jupyter-widgets/controls",
773
+ "model_module_version": "2.0.0",
774
+ "model_name": "HTMLModel",
775
+ "state": {
776
+ "layout": "IPY_MODEL_5c96c3617819467d9fb70aa3b716106e",
777
+ "style": "IPY_MODEL_c18dc3ed330d4d97a0c9d7dba32a9217",
778
+ "value": "model.safetensors: 100%"
779
+ }
780
+ },
781
+ "c16752a4cf734193accaae9835d55aab": {
782
+ "model_module": "@jupyter-widgets/base",
783
+ "model_module_version": "2.0.0",
784
+ "model_name": "LayoutModel",
785
+ "state": {}
786
+ },
787
+ "c18dc3ed330d4d97a0c9d7dba32a9217": {
788
+ "model_module": "@jupyter-widgets/controls",
789
+ "model_module_version": "2.0.0",
790
+ "model_name": "HTMLStyleModel",
791
+ "state": {
792
+ "description_width": "",
793
+ "font_size": null,
794
+ "text_color": null
795
+ }
796
+ },
797
+ "c1b70a1ce9d149cf87169838a18f2e58": {
798
+ "model_module": "@jupyter-widgets/controls",
799
+ "model_module_version": "2.0.0",
800
+ "model_name": "HTMLStyleModel",
801
+ "state": {
802
+ "description_width": "",
803
+ "font_size": null,
804
+ "text_color": null
805
+ }
806
+ },
807
+ "ce85ada4df3c41e9a9b35b7401cd1883": {
808
+ "model_module": "@jupyter-widgets/controls",
809
+ "model_module_version": "2.0.0",
810
+ "model_name": "HTMLModel",
811
+ "state": {
812
+ "layout": "IPY_MODEL_dae692ab00184ab190368530f21dcad9",
813
+ "style": "IPY_MODEL_68c0e93ffde14a40b3599dff15512174",
814
+ "value": " 320/320 [00:00&lt;00:00, 21.4kB/s]"
815
+ }
816
+ },
817
+ "dae692ab00184ab190368530f21dcad9": {
818
+ "model_module": "@jupyter-widgets/base",
819
+ "model_module_version": "2.0.0",
820
+ "model_name": "LayoutModel",
821
+ "state": {}
822
+ }
823
+ },
824
+ "version_major": 2,
825
+ "version_minor": 0
826
+ }
827
  }
828
  },
829
  "nbformat": 4,
tasks/text.py CHANGED
@@ -3,15 +3,18 @@ from datetime import datetime
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
  import random
6
 
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
  router = APIRouter()
11
 
12
- DESCRIPTION = "Most common class baseline"
13
  ROUTE = "/text"
14
 
 
15
  def baseline_model(dataset_length: int):
16
  # Make random predictions (placeholder for actual model inference)
17
  #predictions = [random.randint(0, 7) for _ in range(dataset_length)]
@@ -22,6 +25,40 @@ def baseline_model(dataset_length: int):
22
  return predictions
23
 
24
 
25
  @router.post(ROUTE, tags=["Text Task"],
26
  description=DESCRIPTION)
27
  async def evaluate_text(request: TextEvaluationRequest):
@@ -67,8 +104,9 @@ async def evaluate_text(request: TextEvaluationRequest):
67
  #--------------------------------------------------------------------------------------------
68
 
69
  true_labels = test_dataset["label"]
70
- predictions = baseline_model(len(true_labels))
71
-
 
72
  #--------------------------------------------------------------------------------------------
73
  # YOUR MODEL INFERENCE STOPS HERE
74
  #--------------------------------------------------------------------------------------------
 
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
  import random
6
+ import torch
7
+ from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
8
 
9
  from .utils.evaluation import TextEvaluationRequest
10
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
11
 
12
  router = APIRouter()
13
 
14
+ DESCRIPTION = "bert base finetuned"
15
  ROUTE = "/text"
16
 
17
+
18
  def baseline_model(dataset_length: int):
19
  # Make random predictions (placeholder for actual model inference)
20
  #predictions = [random.randint(0, 7) for _ in range(dataset_length)]
 
25
  return predictions
26
 
27
 
28
+ def bert_model(test_dataset):
29
+ print('Starting my code block.')
30
+ texts = test_dataset["quote"]
31
+
32
+ model_repo = 'Nonnormalizable/frugal-ai-text-bert-base'
33
+ config = AutoConfig.from_pretrained(model_repo)
34
+ model = AutoModelForSequenceClassification.from_pretrained(model_repo)
35
+ tokenizer = AutoTokenizer.from_pretrained(model_repo)
36
+
37
+ if torch.cuda.is_available():
38
+ device = torch.device('cuda')
39
+ else:
40
+ device = torch.device('cpu')
41
+ print('device:', device)
42
+ test_encoding = tokenizer(
43
+ texts,
44
+ truncation=True,
45
+ padding=True,
46
+ return_tensors='pt',
47
+ )
48
+
49
+ model.eval()
50
+ with torch.no_grad():
51
+ test_input_ids = test_encoding['input_ids'].to(device)
52
+ test_attention_mask = test_encoding['attention_mask'].to(device)
53
+ print('Starting model run.')
54
+ outputs = model(test_input_ids, test_attention_mask)
55
+ print('End of model run.')
56
+ predictions = torch.argmax(outputs.logits, dim=1)
57
+
58
+ print('End of my code block.')
59
+ return predictions
60
+
61
+
62
  @router.post(ROUTE, tags=["Text Task"],
63
  description=DESCRIPTION)
64
  async def evaluate_text(request: TextEvaluationRequest):
 
104
  #--------------------------------------------------------------------------------------------
105
 
106
  true_labels = test_dataset["label"]
107
+ #predictions = baseline_model(len(true_labels))
108
+ predictions = bert_model(test_dataset)
109
+
110
  #--------------------------------------------------------------------------------------------
111
  # YOUR MODEL INFERENCE STOPS HERE
112
  #--------------------------------------------------------------------------------------------
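
One way to exercise the new code path before submitting is to run `bert_model` on a tiny hand-made split. The snippet below is a hypothetical smoke test, not part of this commit; `bert_model` is the function added above, while the two quotes and their labels are made up for illustration:

```python
# Hypothetical smoke test for bert_model() in tasks/text.py; not part of this commit.
from datasets import Dataset
from sklearn.metrics import accuracy_score

from tasks.text import bert_model  # assumes the package and its dependencies are importable

tiny = Dataset.from_dict({
    "quote": [
        "My favorite hike is Laguna de los Tres.",  # 0_not_relevant
        "Solar panels emit bad vibes.",             # 4_solutions_harmful_unnecessary
    ],
    "label": [0, 4],
})
preds = bert_model(tiny)  # downloads Nonnormalizable/frugal-ai-text-bert-base, runs one forward pass
print("accuracy:", accuracy_score(tiny["label"], preds.cpu().numpy()))
```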