Nonnormalizable committed on
Commit 250d2de · 1 Parent(s): 9f48354

Train on just the training set. Automatic model card.

Files changed (2)
  1. Finetune BERT.ipynb +543 -301
  2. tasks/text.py +3 -10
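
In brief: the notebook now holds out 20% of the challenge data as a test split, trains only on the remaining 80%, reports train/test loss and accuracy each epoch, and builds and pushes a model card automatically. A minimal sketch of the new split, using the values that appear in the diff below (test_size=0.2, seed=42):

from datasets import load_dataset

# Sketch of the split introduced in this commit; values match the notebook diff.
hf_dataset = load_dataset("quotaclimat/frugalaichallenge-text-train")
train_test = hf_dataset["train"].train_test_split(test_size=0.2, seed=42)
train_dataset = train_test["train"]  # used for fine-tuning
test_dataset = train_test["test"]    # held out for the per-epoch metrics
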
Finetune BERT.ipynb CHANGED
@@ -1,16 +1,24 @@
1
  {
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
  "execution_count": 1,
6
  "id": "73e72549-69f2-46b5-b0f5-655777139972",
7
  "metadata": {
8
  "execution": {
9
- "iopub.execute_input": "2025-01-17T18:17:50.964659Z",
10
- "iopub.status.busy": "2025-01-17T18:17:50.964450Z",
11
- "iopub.status.idle": "2025-01-17T18:17:53.646932Z",
12
- "shell.execute_reply": "2025-01-17T18:17:53.646697Z",
13
- "shell.execute_reply.started": "2025-01-17T18:17:50.964637Z"
14
  }
15
  },
16
  "outputs": [],
@@ -20,9 +28,15 @@
20
  "import torch\n",
21
  "from torch import nn\n",
22
  "from transformers import BertTokenizer, BertModel\n",
23
- "from huggingface_hub import PyTorchModelHubMixin, notebook_login\n",
24
- "from torch.utils.data import Dataset, DataLoader\n",
25
- "from datasets import load_dataset"
26
  ]
27
  },
28
  {
@@ -31,11 +45,11 @@
31
  "id": "07e0787e-c72b-41f3-baba-43cef3f8d6f8",
32
  "metadata": {
33
  "execution": {
34
- "iopub.execute_input": "2025-01-17T18:17:53.648499Z",
35
- "iopub.status.busy": "2025-01-17T18:17:53.648417Z",
36
- "iopub.status.idle": "2025-01-17T18:17:53.650284Z",
37
- "shell.execute_reply": "2025-01-17T18:17:53.650113Z",
38
- "shell.execute_reply.started": "2025-01-17T18:17:53.648489Z"
39
  }
40
  },
41
  "outputs": [],
@@ -43,17 +57,25 @@
43
  "notebook_login(new_session=False)"
44
  ]
45
  },
46
  {
47
  "cell_type": "code",
48
- "execution_count": 11,
49
  "id": "d4b79fb9-5e70-4600-8885-94bc0a6e917c",
50
  "metadata": {
51
  "execution": {
52
- "iopub.execute_input": "2025-01-17T18:35:15.421761Z",
53
- "iopub.status.busy": "2025-01-17T18:35:15.421353Z",
54
- "iopub.status.idle": "2025-01-17T18:35:15.433782Z",
55
- "shell.execute_reply": "2025-01-17T18:35:15.433001Z",
56
- "shell.execute_reply.started": "2025-01-17T18:35:15.421734Z"
57
  }
58
  },
59
  "outputs": [],
@@ -63,6 +85,41 @@
63
  " print(time_str, x)\n",
64
  "\n",
65
  "\n",
66
  "class BertClassifier(nn.Module, PyTorchModelHubMixin):\n",
67
  " def __init__(self, num_labels=8, bert_variety=\"bert-base-uncased\"):\n",
68
  " super().__init__()\n",
@@ -98,12 +155,12 @@
98
  " return len(self.labels)\n",
99
  "\n",
100
  "\n",
101
- "def train_model(model, train_dataloader, device, num_epochs):\n",
102
  " optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
103
  " criterion = nn.CrossEntropyLoss()\n",
104
  " model.train()\n",
105
  "\n",
106
- " my_print(\"Starting epoch 1.\")\n",
107
  " for epoch in range(num_epochs):\n",
108
  " total_loss = 0\n",
109
  " for batch in train_dataloader:\n",
@@ -121,7 +178,7 @@
121
  "\n",
122
  " total_loss += loss.item()\n",
123
  " avg_loss = total_loss / len(train_dataloader)\n",
124
- " my_print(f\"Epoch {epoch+1}/{num_epochs} done, Average Loss: {avg_loss:0.4f}\")"
125
  ]
126
  },
127
  {
@@ -130,11 +187,11 @@
130
  "id": "07131bce-23ad-4787-8622-cce401f3e5ce",
131
  "metadata": {
132
  "execution": {
133
- "iopub.execute_input": "2025-01-17T18:17:57.885732Z",
134
- "iopub.status.busy": "2025-01-17T18:17:57.884455Z",
135
- "iopub.status.idle": "2025-01-17T18:17:57.919509Z",
136
- "shell.execute_reply": "2025-01-17T18:17:57.919081Z",
137
- "shell.execute_reply.started": "2025-01-17T18:17:57.885667Z"
138
  }
139
  },
140
  "outputs": [],
@@ -154,11 +211,11 @@
154
  "id": "695bc080-bbd7-4937-af5b-50db1c936500",
155
  "metadata": {
156
  "execution": {
157
- "iopub.execute_input": "2025-01-17T18:17:58.556031Z",
158
- "iopub.status.busy": "2025-01-17T18:17:58.555349Z",
159
- "iopub.status.idle": "2025-01-17T18:17:58.564519Z",
160
- "shell.execute_reply": "2025-01-17T18:17:58.563640Z",
161
- "shell.execute_reply.started": "2025-01-17T18:17:58.555979Z"
162
  }
163
  },
164
  "outputs": [],
@@ -171,10 +228,19 @@
171
  " batch_size=32,\n",
172
  "):\n",
173
  " hf_dataset = load_dataset(\"quotaclimat/frugalaichallenge-text-train\")\n",
174
  " if not max_dataset_size == \"full\" and max_dataset_size < len(hf_dataset[\"train\"]):\n",
175
- " train_dataset = hf_dataset[\"train\"][:max_dataset_size]\n",
 
176
  " else:\n",
177
- " train_dataset = hf_dataset[\"train\"]\n",
 
178
  "\n",
179
  " tokenizer = BertTokenizer.from_pretrained(bert_variety, max_length=max_length)\n",
180
  " model = BertClassifier(bert_variety=bert_variety)\n",
@@ -187,29 +253,64 @@
187
  " device = torch.device(\"cpu\")\n",
188
  " model.to(device)\n",
189
  "\n",
190
- " dataset = TextDataset(\n",
191
  " train_dataset[\"quote\"],\n",
192
  " train_dataset[\"label\"],\n",
193
  " tokenizer=tokenizer,\n",
194
  " max_length=max_length,\n",
195
  " )\n",
196
- " dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
197
  "\n",
198
- " train_model(model, dataloader, device, num_epochs=num_epochs)\n",
199
  " return model, tokenizer"
200
  ]
201
  },
202
  {
203
  "cell_type": "code",
204
- "execution_count": 19,
205
  "id": "792fd13f-e7cc-4d90-832d-c0da15e193cd",
206
  "metadata": {
207
  "execution": {
208
- "iopub.execute_input": "2025-01-17T15:22:41.286449Z",
209
- "iopub.status.busy": "2025-01-17T15:22:41.285811Z",
210
- "iopub.status.idle": "2025-01-17T15:24:35.507909Z",
211
- "shell.execute_reply": "2025-01-17T15:24:35.506587Z",
212
- "shell.execute_reply.started": "2025-01-17T15:22:41.286404Z"
213
  }
214
  },
215
  "outputs": [
@@ -217,16 +318,16 @@
217
  "name": "stdout",
218
  "output_type": "stream",
219
  "text": [
220
- "2025-01-17 07:22:44 Starting epoch 1.\n",
221
- "2025-01-17 07:23:21 Epoch 1/3 done, Average Loss: 1.8129\n",
222
- "2025-01-17 07:23:58 Epoch 2/3 done, Average Loss: 1.3089\n",
223
- "2025-01-17 07:24:35 Epoch 3/3 done, Average Loss: 0.8916\n"
224
  ]
225
  }
226
  ],
227
  "source": [
228
  "model, tokenizer = run_training(\n",
229
- " max_dataset_size=16 * 100,\n",
230
  " bert_variety=\"bert-base-uncased\",\n",
231
  " max_length=128,\n",
232
  " num_epochs=3,\n",
@@ -236,15 +337,15 @@
236
  },
237
  {
238
  "cell_type": "code",
239
- "execution_count": 21,
240
  "id": "0aedfcca-843e-4f4c-8062-3e4625161bcc",
241
  "metadata": {
242
  "execution": {
243
- "iopub.execute_input": "2025-01-17T15:24:46.754460Z",
244
- "iopub.status.busy": "2025-01-17T15:24:46.753753Z",
245
- "iopub.status.idle": "2025-01-17T15:24:47.249458Z",
246
- "shell.execute_reply": "2025-01-17T15:24:47.249207Z",
247
- "shell.execute_reply.started": "2025-01-17T15:24:46.754391Z"
248
  }
249
  },
250
  "outputs": [
@@ -252,7 +353,7 @@
252
  "name": "stdout",
253
  "output_type": "stream",
254
  "text": [
255
- "2025-01-17 07:24:47 Predictions: tensor([0, 1, 3, 6, 2, 3, 6], device='mps:0')\n"
256
  ]
257
  }
258
  ],
@@ -283,38 +384,11 @@
283
  ]
284
  },
285
  {
286
- "cell_type": "code",
287
- "execution_count": 7,
288
- "id": "881b738e-2392-4b7e-a0de-a0bad572ddfa",
289
- "metadata": {
290
- "execution": {
291
- "iopub.execute_input": "2025-01-17T04:47:17.334399Z",
292
- "iopub.status.busy": "2025-01-17T04:47:17.334287Z",
293
- "iopub.status.idle": "2025-01-17T04:50:59.116389Z",
294
- "shell.execute_reply": "2025-01-17T04:50:59.115528Z",
295
- "shell.execute_reply.started": "2025-01-17T04:47:17.334390Z"
296
- }
297
- },
298
- "outputs": [
299
- {
300
- "name": "stdout",
301
- "output_type": "stream",
302
- "text": [
303
- "2025-01-16 20:47:23 Starting epoch 1.\n",
304
- "2025-01-16 20:48:35 Epoch 1/3 done, Average Loss: 1.4272\n",
305
- "2025-01-16 20:49:46 Epoch 2/3 done, Average Loss: 0.8694\n",
306
- "2025-01-16 20:50:59 Epoch 3/3 done, Average Loss: 0.5774\n"
307
- ]
308
- }
309
- ],
310
  "source": [
311
- "model, tokenizer = run_training(\n",
312
- " max_dataset_size=\"full\",\n",
313
- " bert_variety=\"bert-base-uncased\",\n",
314
- " max_length=64,\n",
315
- " num_epochs=3,\n",
316
- " batch_size=32,\n",
317
- ")"
318
  ]
319
  },
320
  {
@@ -323,11 +397,11 @@
323
  "id": "1d29336e-7f88-4127-afdf-2fe043e310e1",
324
  "metadata": {
325
  "execution": {
326
- "iopub.execute_input": "2025-01-17T04:50:59.118025Z",
327
- "iopub.status.busy": "2025-01-17T04:50:59.117838Z",
328
- "iopub.status.idle": "2025-01-17T04:58:02.423121Z",
329
- "shell.execute_reply": "2025-01-17T04:58:02.421532Z",
330
- "shell.execute_reply.started": "2025-01-17T04:50:59.118005Z"
331
  }
332
  },
333
  "outputs": [
@@ -335,10 +409,10 @@
335
  "name": "stdout",
336
  "output_type": "stream",
337
  "text": [
338
- "2025-01-16 20:51:04 Starting epoch 1.\n",
339
- "2025-01-16 20:53:20 Epoch 1/3 done, Average Loss: 1.4107\n",
340
- "2025-01-16 20:55:41 Epoch 2/3 done, Average Loss: 0.8491\n",
341
- "2025-01-16 20:58:02 Epoch 3/3 done, Average Loss: 0.5359\n"
342
  ]
343
  }
344
  ],
@@ -358,11 +432,11 @@
358
  "id": "461b8f57-0c52-403a-bb69-3bc192b323bf",
359
  "metadata": {
360
  "execution": {
361
- "iopub.execute_input": "2025-01-17T04:58:02.426159Z",
362
- "iopub.status.busy": "2025-01-17T04:58:02.425896Z",
363
- "iopub.status.idle": "2025-01-17T05:05:36.903446Z",
364
- "shell.execute_reply": "2025-01-17T05:05:36.901961Z",
365
- "shell.execute_reply.started": "2025-01-17T04:58:02.426132Z"
366
  }
367
  },
368
  "outputs": [
@@ -370,10 +444,10 @@
370
  "name": "stdout",
371
  "output_type": "stream",
372
  "text": [
373
- "2025-01-16 20:58:08 Starting epoch 1.\n",
374
- "2025-01-16 21:00:38 Epoch 1/3 done, Average Loss: 1.2946\n",
375
- "2025-01-16 21:03:07 Epoch 2/3 done, Average Loss: 0.7425\n",
376
- "2025-01-16 21:05:36 Epoch 3/3 done, Average Loss: 0.4126\n"
377
  ]
378
  }
379
  ],
@@ -389,15 +463,15 @@
389
  },
390
  {
391
  "cell_type": "code",
392
- "execution_count": 12,
393
  "id": "28354e8c-886a-4523-8968-8c688c13f6a3",
394
  "metadata": {
395
  "execution": {
396
- "iopub.execute_input": "2025-01-17T18:35:15.434902Z",
397
- "iopub.status.busy": "2025-01-17T18:35:15.434668Z",
398
- "iopub.status.idle": "2025-01-17T18:50:43.167167Z",
399
- "shell.execute_reply": "2025-01-17T18:50:43.166720Z",
400
- "shell.execute_reply.started": "2025-01-17T18:35:15.434880Z"
401
  }
402
  },
403
  "outputs": [
@@ -405,10 +479,10 @@
405
  "name": "stdout",
406
  "output_type": "stream",
407
  "text": [
408
- "2025-01-17 10:35:20 Starting epoch 1.\n",
409
- "2025-01-17 10:40:29 Epoch 1/3 done, Average Loss: 1.2876\n",
410
- "2025-01-17 10:45:37 Epoch 2/3 done, Average Loss: 0.7289\n",
411
- "2025-01-17 10:50:43 Epoch 3/3 done, Average Loss: 0.3990\n"
412
  ]
413
  }
414
  ],
@@ -432,50 +506,300 @@
432
  },
433
  {
434
  "cell_type": "code",
435
- "execution_count": 6,
436
- "id": "ac5f412c-a745-4327-9303-acf4c5b1efcd",
437
  "metadata": {
438
  "execution": {
439
- "iopub.execute_input": "2025-01-17T18:19:11.590514Z",
440
- "iopub.status.busy": "2025-01-17T18:19:11.589753Z",
441
- "iopub.status.idle": "2025-01-17T18:26:45.645104Z",
442
- "shell.execute_reply": "2025-01-17T18:26:45.644631Z",
443
- "shell.execute_reply.started": "2025-01-17T18:19:11.590428Z"
444
- }
 
445
  },
446
  "outputs": [
447
  {
448
  "name": "stdout",
449
  "output_type": "stream",
450
  "text": [
451
- "2025-01-17 10:19:17 Starting epoch 1.\n",
452
- "2025-01-17 10:21:47 Epoch 1/3 done, Average Loss: 1.2608\n",
453
- "2025-01-17 10:24:16 Epoch 2/3 done, Average Loss: 0.7134\n",
454
- "2025-01-17 10:26:45 Epoch 3/3 done, Average Loss: 0.3931\n"
455
  ]
456
  }
457
  ],
458
  "source": [
459
- "model_final, tokenizer_final = run_training(\n",
460
- " max_dataset_size=\"full\",\n",
461
- " bert_variety=\"bert-base-uncased\",\n",
462
- " max_length=128,\n",
463
- " num_epochs=3,\n",
464
- " batch_size=16,\n",
465
- ")"
466
  ]
467
  },
468
  {
469
  "cell_type": "code",
470
- "execution_count": 7,
471
  "id": "e3b099c6-6b98-473b-8797-5032213b9fcb",
472
  "metadata": {
473
  "execution": {
474
- "iopub.execute_input": "2025-01-17T18:26:45.646178Z",
475
- "iopub.status.busy": "2025-01-17T18:26:45.646081Z",
476
- "iopub.status.idle": "2025-01-17T18:26:45.722052Z",
477
- "shell.execute_reply": "2025-01-17T18:26:45.721803Z",
478
- "shell.execute_reply.started": "2025-01-17T18:26:45.646168Z"
479
  }
480
  },
481
  "outputs": [
@@ -483,7 +807,7 @@
483
  "name": "stdout",
484
  "output_type": "stream",
485
  "text": [
486
- "2025-01-17 10:26:45 Predictions: tensor([0, 0, 3, 1, 2, 4, 6], device='mps:0')\n"
487
  ]
488
  }
489
  ],
@@ -515,22 +839,22 @@
515
  },
516
  {
517
  "cell_type": "code",
518
- "execution_count": 10,
519
  "id": "befb94b5-88bf-40fc-8b26-cf373d1256e0",
520
  "metadata": {
521
  "execution": {
522
- "iopub.execute_input": "2025-01-17T18:32:40.094019Z",
523
- "iopub.status.busy": "2025-01-17T18:32:40.093429Z",
524
- "iopub.status.idle": "2025-01-17T18:35:15.419578Z",
525
- "shell.execute_reply": "2025-01-17T18:35:15.418848Z",
526
- "shell.execute_reply.started": "2025-01-17T18:32:40.093970Z"
527
  }
528
  },
529
  "outputs": [
530
  {
531
  "data": {
532
  "application/vnd.jupyter.widget-view+json": {
533
- "model_id": "7dd2d0eb08624920b345ca85712f0169",
534
  "version_major": 2,
535
  "version_minor": 0
536
  },
@@ -544,10 +868,10 @@
544
  {
545
  "data": {
546
  "text/plain": [
547
- "CommitInfo(commit_url='https://huggingface.co/Nonnormalizable/frugal-ai-text-bert-base/commit/bd94aa1344798fcf671ddd5f8a7bd4f4dc0b20c4', commit_message='Push model using huggingface_hub.', commit_description='', oid='bd94aa1344798fcf671ddd5f8a7bd4f4dc0b20c4', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Nonnormalizable/frugal-ai-text-bert-base', endpoint='https://huggingface.co', repo_type='model', repo_id='Nonnormalizable/frugal-ai-text-bert-base'), pr_revision=None, pr_num=None)"
548
  ]
549
  },
550
- "execution_count": 10,
551
  "metadata": {},
552
  "output_type": "execute_result"
553
  }
@@ -558,51 +882,66 @@
558
  },
559
  {
560
  "cell_type": "code",
561
- "execution_count": 9,
562
  "id": "251ef9ee-8ba3-495f-8fe6-a93aa63168ce",
563
  "metadata": {
564
  "execution": {
565
- "iopub.execute_input": "2025-01-17T18:31:37.682978Z",
566
- "iopub.status.busy": "2025-01-17T18:31:37.682009Z",
567
- "iopub.status.idle": "2025-01-17T18:31:39.578706Z",
568
- "shell.execute_reply": "2025-01-17T18:31:39.577664Z",
569
- "shell.execute_reply.started": "2025-01-17T18:31:37.682910Z"
570
  }
571
  },
572
  "outputs": [
573
  {
574
  "data": {
575
- "application/vnd.jupyter.widget-view+json": {
576
- "model_id": "b62ae26d30534f8fa6057824124e9c95",
577
- "version_major": 2,
578
- "version_minor": 0
579
- },
580
  "text/plain": [
581
- "README.md: 0%| | 0.00/320 [00:00<?, ?B/s]"
582
  ]
583
  },
 
584
  "metadata": {},
585
- "output_type": "display_data"
586
- },
587
  {
588
  "data": {
589
  "text/plain": [
590
- "CommitInfo(commit_url='https://huggingface.co/Nonnormalizable/frugal-ai-text-bert-base/commit/9814436ad5f77cd8c607aa5dba9b67e7983e8ca7', commit_message='Upload tokenizer', commit_description='', oid='9814436ad5f77cd8c607aa5dba9b67e7983e8ca7', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Nonnormalizable/frugal-ai-text-bert-base', endpoint='https://huggingface.co', repo_type='model', repo_id='Nonnormalizable/frugal-ai-text-bert-base'), pr_revision=None, pr_num=None)"
591
  ]
592
  },
593
- "execution_count": 9,
594
  "metadata": {},
595
  "output_type": "execute_result"
596
  }
597
  ],
598
  "source": [
599
- "tokenizer_final.push_to_hub(\"frugal-ai-text-bert-base\")"
600
  ]
601
  },
602
  {
603
  "cell_type": "code",
604
  "execution_count": null,
605
- "id": "863d3553-89a6-4188-a8d0-eaa0b6bccb6c",
606
  "metadata": {},
607
  "outputs": [],
608
  "source": []
@@ -629,71 +968,41 @@
629
  "widgets": {
630
  "application/vnd.jupyter.widget-state+json": {
631
  "state": {
632
- "25776d7aede3476da6f33fc15fe300c8": {
633
- "model_module": "@jupyter-widgets/controls",
634
- "model_module_version": "2.0.0",
635
- "model_name": "ProgressStyleModel",
636
- "state": {
637
- "description_width": ""
638
- }
639
- },
640
- "3a03347251c644bd9b5f58bac49ba2b7": {
641
  "model_module": "@jupyter-widgets/base",
642
  "model_module_version": "2.0.0",
643
  "model_name": "LayoutModel",
644
  "state": {}
645
  },
646
- "3f7dd449d7f84420a836adb899c3b374": {
647
- "model_module": "@jupyter-widgets/controls",
648
- "model_module_version": "2.0.0",
649
- "model_name": "HTMLStyleModel",
650
- "state": {
651
- "description_width": "",
652
- "font_size": null,
653
- "text_color": null
654
- }
655
- },
656
- "47f3b8da36704934acf81f357a9da6c3": {
657
- "model_module": "@jupyter-widgets/controls",
658
- "model_module_version": "2.0.0",
659
- "model_name": "FloatProgressModel",
660
- "state": {
661
- "bar_style": "success",
662
- "layout": "IPY_MODEL_ae0e1835546645cd85915a133bd0b578",
663
- "max": 437977072,
664
- "style": "IPY_MODEL_25776d7aede3476da6f33fc15fe300c8",
665
- "value": 437977072
666
- }
667
- },
668
- "4eff913c8c554820b957c2192d04a8cd": {
669
- "model_module": "@jupyter-widgets/controls",
670
- "model_module_version": "2.0.0",
671
- "model_name": "HTMLModel",
672
- "state": {
673
- "layout": "IPY_MODEL_54b8a0d455794f8881e6d9ceddcac787",
674
- "style": "IPY_MODEL_3f7dd449d7f84420a836adb899c3b374",
675
- "value": " 438M/438M [02:32&lt;00:00, 3.02MB/s]"
676
- }
677
- },
678
- "54b8a0d455794f8881e6d9ceddcac787": {
679
  "model_module": "@jupyter-widgets/base",
680
  "model_module_version": "2.0.0",
681
  "model_name": "LayoutModel",
682
  "state": {}
683
  },
684
- "5c96c3617819467d9fb70aa3b716106e": {
685
  "model_module": "@jupyter-widgets/base",
686
  "model_module_version": "2.0.0",
687
  "model_name": "LayoutModel",
688
  "state": {}
689
  },
690
- "62f9a837c04142b5a2fd66097be6fb6e": {
691
  "model_module": "@jupyter-widgets/base",
692
  "model_module_version": "2.0.0",
693
  "model_name": "LayoutModel",
694
  "state": {}
695
  },
696
- "68c0e93ffde14a40b3599dff15512174": {
697
  "model_module": "@jupyter-widgets/controls",
698
  "model_module_version": "2.0.0",
699
  "model_name": "HTMLStyleModel",
@@ -703,32 +1012,17 @@
703
  "text_color": null
704
  }
705
  },
706
- "6f679b19e9824e1cac8545d7244ec83a": {
707
- "model_module": "@jupyter-widgets/controls",
708
- "model_module_version": "2.0.0",
709
- "model_name": "FloatProgressModel",
710
- "state": {
711
- "bar_style": "success",
712
- "layout": "IPY_MODEL_9785d5bb51544986b4c51b63a39d46cf",
713
- "max": 320,
714
- "style": "IPY_MODEL_88bc5db626a242af8879201d263d9eef",
715
- "value": 320
716
- }
717
- },
718
- "7dd2d0eb08624920b345ca85712f0169": {
719
  "model_module": "@jupyter-widgets/controls",
720
  "model_module_version": "2.0.0",
721
- "model_name": "HBoxModel",
722
  "state": {
723
- "children": [
724
- "IPY_MODEL_bdca6adbcf2347729287c1d2dc44fa2e",
725
- "IPY_MODEL_47f3b8da36704934acf81f357a9da6c3",
726
- "IPY_MODEL_4eff913c8c554820b957c2192d04a8cd"
727
- ],
728
- "layout": "IPY_MODEL_3a03347251c644bd9b5f58bac49ba2b7"
729
  }
730
  },
731
- "88bc5db626a242af8879201d263d9eef": {
732
  "model_module": "@jupyter-widgets/controls",
733
  "model_module_version": "2.0.0",
734
  "model_name": "ProgressStyleModel",
@@ -736,58 +1030,7 @@
736
  "description_width": ""
737
  }
738
  },
739
- "9396575ac43b4832bb12e246801a2316": {
740
- "model_module": "@jupyter-widgets/controls",
741
- "model_module_version": "2.0.0",
742
- "model_name": "HTMLModel",
743
- "state": {
744
- "layout": "IPY_MODEL_c16752a4cf734193accaae9835d55aab",
745
- "style": "IPY_MODEL_c1b70a1ce9d149cf87169838a18f2e58",
746
- "value": "README.md: 100%"
747
- }
748
- },
749
- "9785d5bb51544986b4c51b63a39d46cf": {
750
- "model_module": "@jupyter-widgets/base",
751
- "model_module_version": "2.0.0",
752
- "model_name": "LayoutModel",
753
- "state": {}
754
- },
755
- "ae0e1835546645cd85915a133bd0b578": {
756
- "model_module": "@jupyter-widgets/base",
757
- "model_module_version": "2.0.0",
758
- "model_name": "LayoutModel",
759
- "state": {}
760
- },
761
- "b62ae26d30534f8fa6057824124e9c95": {
762
- "model_module": "@jupyter-widgets/controls",
763
- "model_module_version": "2.0.0",
764
- "model_name": "HBoxModel",
765
- "state": {
766
- "children": [
767
- "IPY_MODEL_9396575ac43b4832bb12e246801a2316",
768
- "IPY_MODEL_6f679b19e9824e1cac8545d7244ec83a",
769
- "IPY_MODEL_ce85ada4df3c41e9a9b35b7401cd1883"
770
- ],
771
- "layout": "IPY_MODEL_62f9a837c04142b5a2fd66097be6fb6e"
772
- }
773
- },
774
- "bdca6adbcf2347729287c1d2dc44fa2e": {
775
- "model_module": "@jupyter-widgets/controls",
776
- "model_module_version": "2.0.0",
777
- "model_name": "HTMLModel",
778
- "state": {
779
- "layout": "IPY_MODEL_5c96c3617819467d9fb70aa3b716106e",
780
- "style": "IPY_MODEL_c18dc3ed330d4d97a0c9d7dba32a9217",
781
- "value": "model.safetensors: 100%"
782
- }
783
- },
784
- "c16752a4cf734193accaae9835d55aab": {
785
- "model_module": "@jupyter-widgets/base",
786
- "model_module_version": "2.0.0",
787
- "model_name": "LayoutModel",
788
- "state": {}
789
- },
790
- "c18dc3ed330d4d97a0c9d7dba32a9217": {
791
  "model_module": "@jupyter-widgets/controls",
792
  "model_module_version": "2.0.0",
793
  "model_name": "HTMLStyleModel",
@@ -797,31 +1040,30 @@
797
  "text_color": null
798
  }
799
  },
800
- "c1b70a1ce9d149cf87169838a18f2e58": {
801
  "model_module": "@jupyter-widgets/controls",
802
  "model_module_version": "2.0.0",
803
- "model_name": "HTMLStyleModel",
804
  "state": {
805
- "description_width": "",
806
- "font_size": null,
807
- "text_color": null
 
 
808
  }
809
  },
810
- "ce85ada4df3c41e9a9b35b7401cd1883": {
811
  "model_module": "@jupyter-widgets/controls",
812
  "model_module_version": "2.0.0",
813
- "model_name": "HTMLModel",
814
  "state": {
815
- "layout": "IPY_MODEL_dae692ab00184ab190368530f21dcad9",
816
- "style": "IPY_MODEL_68c0e93ffde14a40b3599dff15512174",
817
- "value": " 320/320 [00:00&lt;00:00, 21.4kB/s]"
818
  }
819
- },
820
- "dae692ab00184ab190368530f21dcad9": {
821
- "model_module": "@jupyter-widgets/base",
822
- "model_module_version": "2.0.0",
823
- "model_name": "LayoutModel",
824
- "state": {}
825
  }
826
  },
827
  "version_major": 2,
 
1
  {
2
  "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "33faae25-af36-4781-bf8f-2084ddc96a52",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Setup"
9
+ ]
10
+ },
11
  {
12
  "cell_type": "code",
13
  "execution_count": 1,
14
  "id": "73e72549-69f2-46b5-b0f5-655777139972",
15
  "metadata": {
16
  "execution": {
17
+ "iopub.execute_input": "2025-01-20T20:17:03.803583Z",
18
+ "iopub.status.busy": "2025-01-20T20:17:03.803051Z",
19
+ "iopub.status.idle": "2025-01-20T20:17:06.786959Z",
20
+ "shell.execute_reply": "2025-01-20T20:17:06.786718Z",
21
+ "shell.execute_reply.started": "2025-01-20T20:17:03.803542Z"
22
  }
23
  },
24
  "outputs": [],
 
28
  "import torch\n",
29
  "from torch import nn\n",
30
  "from transformers import BertTokenizer, BertModel\n",
31
+ "from huggingface_hub import (\n",
32
+ " PyTorchModelHubMixin,\n",
33
+ " notebook_login,\n",
34
+ " ModelCard,\n",
35
+ " ModelCardData,\n",
36
+ " EvalResult,\n",
37
+ ")\n",
38
+ "from datasets import DatasetDict, load_dataset\n",
39
+ "from torch.utils.data import Dataset, DataLoader"
40
  ]
41
  },
42
  {
 
45
  "id": "07e0787e-c72b-41f3-baba-43cef3f8d6f8",
46
  "metadata": {
47
  "execution": {
48
+ "iopub.execute_input": "2025-01-20T20:17:06.787691Z",
49
+ "iopub.status.busy": "2025-01-20T20:17:06.787547Z",
50
+ "iopub.status.idle": "2025-01-20T20:17:06.789420Z",
51
+ "shell.execute_reply": "2025-01-20T20:17:06.789211Z",
52
+ "shell.execute_reply.started": "2025-01-20T20:17:06.787682Z"
53
  }
54
  },
55
  "outputs": [],
 
57
  "notebook_login(new_session=False)"
58
  ]
59
  },
60
+ {
61
+ "cell_type": "markdown",
62
+ "id": "a919d72c-8d10-4275-a2ca-4ead295f41a8",
63
+ "metadata": {},
64
+ "source": [
65
+ "# Functions"
66
+ ]
67
+ },
68
  {
69
  "cell_type": "code",
70
+ "execution_count": 3,
71
  "id": "d4b79fb9-5e70-4600-8885-94bc0a6e917c",
72
  "metadata": {
73
  "execution": {
74
+ "iopub.execute_input": "2025-01-20T20:17:06.789829Z",
75
+ "iopub.status.busy": "2025-01-20T20:17:06.789761Z",
76
+ "iopub.status.idle": "2025-01-20T20:17:06.794443Z",
77
+ "shell.execute_reply": "2025-01-20T20:17:06.794260Z",
78
+ "shell.execute_reply.started": "2025-01-20T20:17:06.789822Z"
79
  }
80
  },
81
  "outputs": [],
 
85
  " print(time_str, x)\n",
86
  "\n",
87
  "\n",
88
+ "def model_metrics(model, dataloader):\n",
89
+ " criterion = nn.CrossEntropyLoss()\n",
90
+ " model.eval()\n",
91
+ " with torch.no_grad():\n",
92
+ " total_loss = 0\n",
93
+ " total_correct = 0\n",
94
+ " total_length = 0\n",
95
+ " for batch in dataloader:\n",
96
+ " input_ids = batch[\"input_ids\"].to(device)\n",
97
+ " attention_mask = batch[\"attention_mask\"].to(device)\n",
98
+ " labels = batch[\"labels\"].to(device)\n",
99
+ "\n",
100
+ " outputs = model(input_ids, attention_mask)\n",
101
+ " loss = criterion(outputs, labels)\n",
102
+ " predictions_cpu = torch.argmax(outputs, dim=1).cpu().numpy()\n",
103
+ " labels_cpu = labels.cpu().numpy()\n",
104
+ " correct_count = (predictions_cpu == labels_cpu).sum()\n",
105
+ "\n",
106
+ " total_loss += loss.item()\n",
107
+ " total_correct += correct_count\n",
108
+ " total_length += len(labels_cpu)\n",
109
+ " avg_loss = total_loss / len(dataloader)\n",
110
+ " avg_acc = total_correct / total_length\n",
111
+ " model.train()\n",
112
+ " return avg_loss, avg_acc\n",
113
+ "\n",
114
+ "\n",
115
+ "def print_model_status(epoch, num_epochs, model, train_dataloader, test_dataloader):\n",
116
+ " train_loss, train_acc = model_metrics(model, train_dataloader)\n",
117
+ " test_loss, test_acc = model_metrics(model, test_dataloader)\n",
118
+ " loss_str = f\"Loss: Train {train_loss:0.3f}, Test {test_loss:0.3f}\"\n",
119
+ " acc_str = f\"Acc: Train {train_acc:0.3f}, Test {test_acc:0.3f}\"\n",
120
+ " my_print(f\"Epoch {epoch+1}/{num_epochs} done. {loss_str}; and {acc_str}\")\n",
121
+ "\n",
122
+ "\n",
123
  "class BertClassifier(nn.Module, PyTorchModelHubMixin):\n",
124
  " def __init__(self, num_labels=8, bert_variety=\"bert-base-uncased\"):\n",
125
  " super().__init__()\n",
 
155
  " return len(self.labels)\n",
156
  "\n",
157
  "\n",
158
+ "def train_model(model, train_dataloader, test_dataloader, device, num_epochs):\n",
159
  " optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
160
  " criterion = nn.CrossEntropyLoss()\n",
161
  " model.train()\n",
162
  "\n",
163
+ " print_model_status(-1, num_epochs, model, train_dataloader, test_dataloader)\n",
164
  " for epoch in range(num_epochs):\n",
165
  " total_loss = 0\n",
166
  " for batch in train_dataloader:\n",
 
178
  "\n",
179
  " total_loss += loss.item()\n",
180
  " avg_loss = total_loss / len(train_dataloader)\n",
181
+ " print_model_status(epoch, num_epochs, model, train_dataloader, test_dataloader)"
182
  ]
183
  },
184
  {
 
187
  "id": "07131bce-23ad-4787-8622-cce401f3e5ce",
188
  "metadata": {
189
  "execution": {
190
+ "iopub.execute_input": "2025-01-20T20:17:06.795335Z",
191
+ "iopub.status.busy": "2025-01-20T20:17:06.795239Z",
192
+ "iopub.status.idle": "2025-01-20T20:17:06.821293Z",
193
+ "shell.execute_reply": "2025-01-20T20:17:06.821061Z",
194
+ "shell.execute_reply.started": "2025-01-20T20:17:06.795328Z"
195
  }
196
  },
197
  "outputs": [],
 
211
  "id": "695bc080-bbd7-4937-af5b-50db1c936500",
212
  "metadata": {
213
  "execution": {
214
+ "iopub.execute_input": "2025-01-20T20:17:06.821637Z",
215
+ "iopub.status.busy": "2025-01-20T20:17:06.821569Z",
216
+ "iopub.status.idle": "2025-01-20T20:17:06.824265Z",
217
+ "shell.execute_reply": "2025-01-20T20:17:06.824082Z",
218
+ "shell.execute_reply.started": "2025-01-20T20:17:06.821630Z"
219
  }
220
  },
221
  "outputs": [],
 
228
  " batch_size=32,\n",
229
  "):\n",
230
  " hf_dataset = load_dataset(\"quotaclimat/frugalaichallenge-text-train\")\n",
231
+ " test_size = 0.2\n",
232
+ " test_seed = 42\n",
233
+ " train_test = hf_dataset[\"train\"].train_test_split(\n",
234
+ " test_size=test_size, seed=test_seed\n",
235
+ " )\n",
236
+ " train_dataset = train_test[\"train\"]\n",
237
+ " test_dataset = train_test[\"test\"]\n",
238
  " if not max_dataset_size == \"full\" and max_dataset_size < len(hf_dataset[\"train\"]):\n",
239
+ " train_dataset = train_dataset[:max_dataset_size]\n",
240
+ " test_dataset = test_dataset[:max_dataset_size]\n",
241
  " else:\n",
242
+ " train_dataset = train_dataset\n",
243
+ " test_dataset = test_dataset\n",
244
  "\n",
245
  " tokenizer = BertTokenizer.from_pretrained(bert_variety, max_length=max_length)\n",
246
  " model = BertClassifier(bert_variety=bert_variety)\n",
 
253
  " device = torch.device(\"cpu\")\n",
254
  " model.to(device)\n",
255
  "\n",
256
+ " text_dataset_train = TextDataset(\n",
257
  " train_dataset[\"quote\"],\n",
258
  " train_dataset[\"label\"],\n",
259
  " tokenizer=tokenizer,\n",
260
  " max_length=max_length,\n",
261
  " )\n",
262
+ " text_dataset_test = TextDataset(\n",
263
+ " test_dataset[\"quote\"],\n",
264
+ " test_dataset[\"label\"],\n",
265
+ " tokenizer=tokenizer,\n",
266
+ " max_length=max_length,\n",
267
+ " )\n",
268
+ " dataloader_train = DataLoader(\n",
269
+ " text_dataset_train, batch_size=batch_size, shuffle=True\n",
270
+ " )\n",
271
+ " dataloader_test = DataLoader(\n",
272
+ " text_dataset_test, batch_size=batch_size, shuffle=False\n",
273
+ " )\n",
274
  "\n",
275
+ " train_model(model, dataloader_train, dataloader_test, device, num_epochs=num_epochs)\n",
276
  " return model, tokenizer"
277
  ]
278
  },
279
+ {
280
+ "cell_type": "markdown",
281
+ "id": "5af751f3-1fc4-4540-ae25-638db9d33c67",
282
+ "metadata": {},
283
+ "source": [
284
+ "# Exploration"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "markdown",
289
+ "id": "a847135f-ce86-46a1-9c61-3459a847cb29",
290
+ "metadata": {
291
+ "execution": {
292
+ "iopub.execute_input": "2025-01-20T19:13:05.482383Z",
293
+ "iopub.status.busy": "2025-01-20T19:13:05.481449Z",
294
+ "iopub.status.idle": "2025-01-20T19:13:05.487546Z",
295
+ "shell.execute_reply": "2025-01-20T19:13:05.486557Z",
296
+ "shell.execute_reply.started": "2025-01-20T19:13:05.482339Z"
297
+ }
298
+ },
299
+ "source": [
300
+ "## Check if runs"
301
+ ]
302
+ },
303
  {
304
  "cell_type": "code",
305
+ "execution_count": 6,
306
  "id": "792fd13f-e7cc-4d90-832d-c0da15e193cd",
307
  "metadata": {
308
  "execution": {
309
+ "iopub.execute_input": "2025-01-20T20:17:06.824513Z",
310
+ "iopub.status.busy": "2025-01-20T20:17:06.824457Z",
311
+ "iopub.status.idle": "2025-01-20T20:17:14.130284Z",
312
+ "shell.execute_reply": "2025-01-20T20:17:14.129964Z",
313
+ "shell.execute_reply.started": "2025-01-20T20:17:06.824506Z"
314
  }
315
  },
316
  "outputs": [
 
318
  "name": "stdout",
319
  "output_type": "stream",
320
  "text": [
321
+ "2025-01-20 12:17:10 Epoch 0/3 done. Loss: Train 2.111, Test 2.247; and Acc: Train 0.281, Test 0.156\n",
322
+ "2025-01-20 12:17:11 Epoch 1/3 done. Loss: Train 2.026, Test 2.222; and Acc: Train 0.344, Test 0.156\n",
323
+ "2025-01-20 12:17:12 Epoch 2/3 done. Loss: Train 1.943, Test 2.194; and Acc: Train 0.312, Test 0.156\n",
324
+ "2025-01-20 12:17:14 Epoch 3/3 done. Loss: Train 1.859, Test 2.159; and Acc: Train 0.344, Test 0.156\n"
325
  ]
326
  }
327
  ],
328
  "source": [
329
  "model, tokenizer = run_training(\n",
330
+ " max_dataset_size=16 * 2,\n",
331
  " bert_variety=\"bert-base-uncased\",\n",
332
  " max_length=128,\n",
333
  " num_epochs=3,\n",
 
337
  },
338
  {
339
  "cell_type": "code",
340
+ "execution_count": 7,
341
  "id": "0aedfcca-843e-4f4c-8062-3e4625161bcc",
342
  "metadata": {
343
  "execution": {
344
+ "iopub.execute_input": "2025-01-20T20:17:14.130879Z",
345
+ "iopub.status.busy": "2025-01-20T20:17:14.130792Z",
346
+ "iopub.status.idle": "2025-01-20T20:17:14.193695Z",
347
+ "shell.execute_reply": "2025-01-20T20:17:14.193466Z",
348
+ "shell.execute_reply.started": "2025-01-20T20:17:14.130869Z"
349
  }
350
  },
351
  "outputs": [
 
353
  "name": "stdout",
354
  "output_type": "stream",
355
  "text": [
356
+ "2025-01-20 12:17:14 Predictions: tensor([4, 1, 1, 1, 3, 1, 1], device='mps:0')\n"
357
  ]
358
  }
359
  ],
 
384
  ]
385
  },
386
  {
387
+ "cell_type": "markdown",
388
+ "id": "0c3ea938-dd87-4673-b1d6-f06c70b19455",
389
+ "metadata": {},
 
390
  "source": [
391
+ "## Hyperparameters"
392
  ]
393
  },
394
  {
 
397
  "id": "1d29336e-7f88-4127-afdf-2fe043e310e1",
398
  "metadata": {
399
  "execution": {
400
+ "iopub.execute_input": "2025-01-20T20:17:14.194160Z",
401
+ "iopub.status.busy": "2025-01-20T20:17:14.194076Z",
402
+ "iopub.status.idle": "2025-01-20T20:25:46.660251Z",
403
+ "shell.execute_reply": "2025-01-20T20:25:46.659652Z",
404
+ "shell.execute_reply.started": "2025-01-20T20:17:14.194152Z"
405
  }
406
  },
407
  "outputs": [
 
409
  "name": "stdout",
410
  "output_type": "stream",
411
  "text": [
412
+ "2025-01-20 12:18:02 Epoch 0/3 done. Loss: Train 2.106, Test 2.091; and Acc: Train 0.118, Test 0.135\n",
413
+ "2025-01-20 12:20:37 Epoch 1/3 done. Loss: Train 0.989, Test 1.114; and Acc: Train 0.647, Test 0.603\n",
414
+ "2025-01-20 12:23:12 Epoch 2/3 done. Loss: Train 0.584, Test 0.928; and Acc: Train 0.825, Test 0.669\n",
415
+ "2025-01-20 12:25:46 Epoch 3/3 done. Loss: Train 0.313, Test 0.950; and Acc: Train 0.913, Test 0.683\n"
416
  ]
417
  }
418
  ],
 
432
  "id": "461b8f57-0c52-403a-bb69-3bc192b323bf",
433
  "metadata": {
434
  "execution": {
435
+ "iopub.execute_input": "2025-01-20T20:25:46.661264Z",
436
+ "iopub.status.busy": "2025-01-20T20:25:46.661132Z",
437
+ "iopub.status.idle": "2025-01-20T20:34:54.221239Z",
438
+ "shell.execute_reply": "2025-01-20T20:34:54.220590Z",
439
+ "shell.execute_reply.started": "2025-01-20T20:25:46.661249Z"
440
  }
441
  },
442
  "outputs": [
 
444
  "name": "stdout",
445
  "output_type": "stream",
446
  "text": [
447
+ "2025-01-20 12:26:34 Epoch 0/3 done. Loss: Train 2.174, Test 2.168; and Acc: Train 0.096, Test 0.094\n",
448
+ "2025-01-20 12:29:21 Epoch 1/3 done. Loss: Train 0.878, Test 1.033; and Acc: Train 0.712, Test 0.653\n",
449
+ "2025-01-20 12:32:07 Epoch 2/3 done. Loss: Train 0.458, Test 0.906; and Acc: Train 0.869, Test 0.678\n",
450
+ "2025-01-20 12:34:54 Epoch 3/3 done. Loss: Train 0.218, Test 0.959; and Acc: Train 0.944, Test 0.695\n"
451
  ]
452
  }
453
  ],
 
463
  },
464
  {
465
  "cell_type": "code",
466
+ "execution_count": 10,
467
  "id": "28354e8c-886a-4523-8968-8c688c13f6a3",
468
  "metadata": {
469
  "execution": {
470
+ "iopub.execute_input": "2025-01-20T20:34:54.224989Z",
471
+ "iopub.status.busy": "2025-01-20T20:34:54.224772Z",
472
+ "iopub.status.idle": "2025-01-20T20:54:07.531338Z",
473
+ "shell.execute_reply": "2025-01-20T20:54:07.530559Z",
474
+ "shell.execute_reply.started": "2025-01-20T20:34:54.224968Z"
475
  }
476
  },
477
  "outputs": [
 
479
  "name": "stdout",
480
  "output_type": "stream",
481
  "text": [
482
+ "2025-01-20 12:36:37 Epoch 0/3 done. Loss: Train 2.122, Test 2.127; and Acc: Train 0.122, Test 0.118\n",
483
+ "2025-01-20 12:42:26 Epoch 1/3 done. Loss: Train 0.779, Test 0.978; and Acc: Train 0.748, Test 0.652\n",
484
+ "2025-01-20 12:48:16 Epoch 2/3 done. Loss: Train 0.391, Test 0.884; and Acc: Train 0.897, Test 0.696\n",
485
+ "2025-01-20 12:54:07 Epoch 3/3 done. Loss: Train 0.154, Test 0.978; and Acc: Train 0.959, Test 0.705\n"
486
  ]
487
  }
488
  ],
 
506
  },
507
  {
508
  "cell_type": "code",
509
+ "execution_count": 14,
510
+ "id": "ec2516f9-79f2-4ae1-ab9a-9a51a7a50587",
511
  "metadata": {
512
  "execution": {
513
+ "iopub.execute_input": "2025-01-20T22:10:34.055595Z",
514
+ "iopub.status.busy": "2025-01-20T22:10:34.054690Z",
515
+ "iopub.status.idle": "2025-01-20T22:10:34.083784Z",
516
+ "shell.execute_reply": "2025-01-20T22:10:34.083448Z",
517
+ "shell.execute_reply.started": "2025-01-20T22:10:34.055529Z"
518
+ },
519
+ "scrolled": true
520
  },
521
  "outputs": [
522
  {
523
  "name": "stdout",
524
  "output_type": "stream",
525
  "text": [
526
+ "---\n",
527
+ "base_model: google-bert/bert-base-uncased\n",
528
+ "datasets:\n",
529
+ "- QuotaClimat/frugalaichallenge-text-train\n",
530
+ "language:\n",
531
+ "- en\n",
532
+ "license: apache-2.0\n",
533
+ "model_name: frugal-ai-text-bert-base\n",
534
+ "pipeline_tag: text-classification\n",
535
+ "tags:\n",
536
+ "- model_hub_mixin\n",
537
+ "- pytorch_model_hub_mixin\n",
538
+ "- climate\n",
539
+ "---\n",
540
+ "\n",
541
+ "# Model Card for Model ID\n",
542
+ "\n",
543
+ "<!-- Provide a quick summary of what the model is/does. -->\n",
544
+ "\n",
545
+ "Classify text into 8 categories of climate misinformation.\n",
546
+ "\n",
547
+ "## Model Details\n",
548
+ "\n",
549
+ "### Model Description\n",
550
+ "\n",
551
+ "<!-- Provide a longer summary of what this model is. -->\n",
552
+ "\n",
553
+ "Fine trained BERT for classifying climate information as part of the Frugal AI Challenge, for submission to https://huggingface.co/frugal-ai-challenge and scoring on accuracy and efficiency. Trainied on only the non-evaluation 80% of the data, so it's (non-cheating) score will be lower.\n",
554
+ "\n",
555
+ "- **Developed by:** Andre Bach\n",
556
+ "- **Funded by [optional]:** N/A\n",
557
+ "- **Shared by [optional]:** Andre Bach\n",
558
+ "- **Model type:** Text classification\n",
559
+ "- **Language(s) (NLP):** ['en']\n",
560
+ "- **License:** apache-2.0\n",
561
+ "- **Finetuned from model [optional]:** google-bert/bert-base-uncased\n",
562
+ "\n",
563
+ "### Model Sources [optional]\n",
564
+ "\n",
565
+ "<!-- Provide the basic links for the model. -->\n",
566
+ "\n",
567
+ "- **Repository:** frugal-ai-text-bert-base\n",
568
+ "- **Paper [optional]:** [More Information Needed]\n",
569
+ "- **Demo [optional]:** [More Information Needed]\n",
570
+ "\n",
571
+ "## Uses\n",
572
+ "\n",
573
+ "<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->\n",
574
+ "\n",
575
+ "### Direct Use\n",
576
+ "\n",
577
+ "<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->\n",
578
+ "\n",
579
+ "[More Information Needed]\n",
580
+ "\n",
581
+ "### Downstream Use [optional]\n",
582
+ "\n",
583
+ "<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->\n",
584
+ "\n",
585
+ "[More Information Needed]\n",
586
+ "\n",
587
+ "### Out-of-Scope Use\n",
588
+ "\n",
589
+ "<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->\n",
590
+ "\n",
591
+ "[More Information Needed]\n",
592
+ "\n",
593
+ "## Bias, Risks, and Limitations\n",
594
+ "\n",
595
+ "<!-- This section is meant to convey both technical and sociotechnical limitations. -->\n",
596
+ "\n",
597
+ "[More Information Needed]\n",
598
+ "\n",
599
+ "### Recommendations\n",
600
+ "\n",
601
+ "<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->\n",
602
+ "\n",
603
+ "Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.\n",
604
+ "\n",
605
+ "## How to Get Started with the Model\n",
606
+ "\n",
607
+ "Use the code below to get started with the model.\n",
608
+ "\n",
609
+ "[More Information Needed]\n",
610
+ "\n",
611
+ "## Training Details\n",
612
+ "\n",
613
+ "### Training Data\n",
614
+ "\n",
615
+ "<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->\n",
616
+ "\n",
617
+ "[More Information Needed]\n",
618
+ "\n",
619
+ "### Training Procedure\n",
620
+ "\n",
621
+ "<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->\n",
622
+ "\n",
623
+ "#### Preprocessing [optional]\n",
624
+ "\n",
625
+ "[More Information Needed]\n",
626
+ "\n",
627
+ "\n",
628
+ "#### Training Hyperparameters\n",
629
+ "\n",
630
+ "- **Training regime:** {'max_dataset_size': 'full', 'bert_variety': 'bert-base-uncased', 'max_length': 256, 'num_epochs': 3, 'batch_size': 16} <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->\n",
631
+ "\n",
632
+ "#### Speeds, Sizes, Times [optional]\n",
633
+ "\n",
634
+ "<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->\n",
635
+ "\n",
636
+ "[More Information Needed]\n",
637
+ "\n",
638
+ "## Evaluation\n",
639
+ "\n",
640
+ "<!-- This section describes the evaluation protocols and provides the results. -->\n",
641
+ "\n",
642
+ "### Testing Data, Factors & Metrics\n",
643
+ "\n",
644
+ "#### Testing Data\n",
645
+ "\n",
646
+ "<!-- This should link to a Dataset Card if possible. -->\n",
647
+ "\n",
648
+ "[More Information Needed]\n",
649
+ "\n",
650
+ "#### Factors\n",
651
+ "\n",
652
+ "<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->\n",
653
+ "\n",
654
+ "[More Information Needed]\n",
655
+ "\n",
656
+ "#### Metrics\n",
657
+ "\n",
658
+ "<!-- These are the evaluation metrics being used, ideally with a description of why. -->\n",
659
+ "\n",
660
+ "{'loss_train': 0.154, 'loss_test': 0.978, 'acc_train': 0.959, 'acc_test': 0.705}\n",
661
+ "\n",
662
+ "### Results\n",
663
+ "\n",
664
+ "[More Information Needed]\n",
665
+ "\n",
666
+ "#### Summary\n",
667
+ "\n",
668
+ "\n",
669
+ "\n",
670
+ "## Model Examination [optional]\n",
671
+ "\n",
672
+ "<!-- Relevant interpretability work for the model goes here -->\n",
673
+ "\n",
674
+ "[More Information Needed]\n",
675
+ "\n",
676
+ "## Environmental Impact\n",
677
+ "\n",
678
+ "<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->\n",
679
+ "\n",
680
+ "Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).\n",
681
+ "\n",
682
+ "- **Hardware Type:** [More Information Needed]\n",
683
+ "- **Hours used:** [More Information Needed]\n",
684
+ "- **Cloud Provider:** [More Information Needed]\n",
685
+ "- **Compute Region:** [More Information Needed]\n",
686
+ "- **Carbon Emitted:** [More Information Needed]\n",
687
+ "\n",
688
+ "## Technical Specifications [optional]\n",
689
+ "\n",
690
+ "### Model Architecture and Objective\n",
691
+ "\n",
692
+ "[More Information Needed]\n",
693
+ "\n",
694
+ "### Compute Infrastructure\n",
695
+ "\n",
696
+ "[More Information Needed]\n",
697
+ "\n",
698
+ "#### Hardware\n",
699
+ "\n",
700
+ "[More Information Needed]\n",
701
+ "\n",
702
+ "#### Software\n",
703
+ "\n",
704
+ "[More Information Needed]\n",
705
+ "\n",
706
+ "## Citation [optional]\n",
707
+ "\n",
708
+ "<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->\n",
709
+ "\n",
710
+ "**BibTeX:**\n",
711
+ "\n",
712
+ "[More Information Needed]\n",
713
+ "\n",
714
+ "**APA:**\n",
715
+ "\n",
716
+ "[More Information Needed]\n",
717
+ "\n",
718
+ "## Glossary [optional]\n",
719
+ "\n",
720
+ "<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->\n",
721
+ "\n",
722
+ "[More Information Needed]\n",
723
+ "\n",
724
+ "## More Information [optional]\n",
725
+ "\n",
726
+ "[More Information Needed]\n",
727
+ "\n",
728
+ "## Model Card Authors [optional]\n",
729
+ "\n",
730
+ "[More Information Needed]\n",
731
+ "\n",
732
+ "## Model Card Contact\n",
733
+ "\n",
734
+ "[More Information Needed]\n"
735
  ]
736
  }
737
  ],
738
  "source": [
739
+ "model_and_repo_name = \"frugal-ai-text-bert-base\"\n",
740
+ "card_data = ModelCardData(\n",
741
+ " model_name=model_and_repo_name,\n",
742
+ " base_model=\"google-bert/bert-base-uncased\",\n",
743
+ " license=\"apache-2.0\",\n",
744
+ " language=[\"en\"],\n",
745
+ " datasets=[\"QuotaClimat/frugalaichallenge-text-train\"],\n",
746
+ " tags=[\"model_hub_mixin\", \"pytorch_model_hub_mixin\", \"climate\"],\n",
747
+ " pipeline_tag=\"text-classification\",\n",
748
+ ")\n",
749
+ "card = ModelCard.from_template(\n",
750
+ " card_data,\n",
751
+ " model_summary=\"Classify text into 8 categories of climate misinformation.\",\n",
752
+ " model_description=\"Fine trained BERT for classifying climate information as part of the Frugal AI Challenge, for submission to https://huggingface.co/frugal-ai-challenge and scoring on accuracy and efficiency. Trainied on only the non-evaluation 80% of the data, so it's (non-cheating) score will be lower.\",\n",
753
+ " developers=\"Andre Bach\",\n",
754
+ " funded_by=\"N/A\",\n",
755
+ " shared_by=\"Andre Bach\",\n",
756
+ " model_type=\"Text classification\",\n",
757
+ " repo=model_and_repo_name,\n",
758
+ " training_regime=dict(\n",
759
+ " max_dataset_size=\"full\",\n",
760
+ " bert_variety=\"bert-base-uncased\",\n",
761
+ " max_length=256,\n",
762
+ " num_epochs=3,\n",
763
+ " batch_size=16,\n",
764
+ " ),\n",
765
+ " testing_metrics=dict(\n",
766
+ " loss_train=0.154, loss_test=0.978, acc_train=0.959, acc_test=0.705\n",
767
+ " ),\n",
768
+ ")\n",
769
+ "# print(card_data.to_yaml())\n",
770
+ "print(card)"
771
  ]
772
  },
773
  {
774
  "cell_type": "code",
775
+ "execution_count": 17,
776
+ "id": "29d3bbf9-ab2a-48e2-a550-e16da5025720",
777
+ "metadata": {
778
+ "execution": {
779
+ "iopub.execute_input": "2025-01-20T22:11:59.827681Z",
780
+ "iopub.status.busy": "2025-01-20T22:11:59.827001Z",
781
+ "iopub.status.idle": "2025-01-20T22:11:59.831852Z",
782
+ "shell.execute_reply": "2025-01-20T22:11:59.831047Z",
783
+ "shell.execute_reply.started": "2025-01-20T22:11:59.827635Z"
784
+ }
785
+ },
786
+ "outputs": [],
787
+ "source": [
788
+ "model_final = model\n",
789
+ "tokenizer_final = tokenizer"
790
+ ]
791
+ },
792
+ {
793
+ "cell_type": "code",
794
+ "execution_count": 18,
795
  "id": "e3b099c6-6b98-473b-8797-5032213b9fcb",
796
  "metadata": {
797
  "execution": {
798
+ "iopub.execute_input": "2025-01-20T22:12:00.576369Z",
799
+ "iopub.status.busy": "2025-01-20T22:12:00.575421Z",
800
+ "iopub.status.idle": "2025-01-20T22:12:01.065512Z",
801
+ "shell.execute_reply": "2025-01-20T22:12:01.065237Z",
802
+ "shell.execute_reply.started": "2025-01-20T22:12:00.576294Z"
803
  }
804
  },
805
  "outputs": [
 
807
  "name": "stdout",
808
  "output_type": "stream",
809
  "text": [
810
+ "2025-01-20 14:12:01 Predictions: tensor([0, 0, 3, 6, 2, 4, 6], device='mps:0')\n"
811
  ]
812
  }
813
  ],
 
839
  },
840
  {
841
  "cell_type": "code",
842
+ "execution_count": 19,
843
  "id": "befb94b5-88bf-40fc-8b26-cf373d1256e0",
844
  "metadata": {
845
  "execution": {
846
+ "iopub.execute_input": "2025-01-20T22:12:15.099356Z",
847
+ "iopub.status.busy": "2025-01-20T22:12:15.098818Z",
848
+ "iopub.status.idle": "2025-01-20T22:12:33.175760Z",
849
+ "shell.execute_reply": "2025-01-20T22:12:33.174719Z",
850
+ "shell.execute_reply.started": "2025-01-20T22:12:15.099315Z"
851
  }
852
  },
853
  "outputs": [
854
  {
855
  "data": {
856
  "application/vnd.jupyter.widget-view+json": {
857
+ "model_id": "fbc09ae2c5614831a2fb02fa48a44fd1",
858
  "version_major": 2,
859
  "version_minor": 0
860
  },
 
868
  {
869
  "data": {
870
  "text/plain": [
871
+ "CommitInfo(commit_url='https://huggingface.co/Nonnormalizable/frugal-ai-text-bert-base/commit/bdc2daf80d9647566ef56297f2cdc32f898170df', commit_message='Push model using huggingface_hub.', commit_description='', oid='bdc2daf80d9647566ef56297f2cdc32f898170df', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Nonnormalizable/frugal-ai-text-bert-base', endpoint='https://huggingface.co', repo_type='model', repo_id='Nonnormalizable/frugal-ai-text-bert-base'), pr_revision=None, pr_num=None)"
872
  ]
873
  },
874
+ "execution_count": 19,
875
  "metadata": {},
876
  "output_type": "execute_result"
877
  }
 
882
  },
883
  {
884
  "cell_type": "code",
885
+ "execution_count": 20,
886
  "id": "251ef9ee-8ba3-495f-8fe6-a93aa63168ce",
887
  "metadata": {
888
  "execution": {
889
+ "iopub.execute_input": "2025-01-20T22:12:33.178424Z",
890
+ "iopub.status.busy": "2025-01-20T22:12:33.178028Z",
891
+ "iopub.status.idle": "2025-01-20T22:12:34.321979Z",
892
+ "shell.execute_reply": "2025-01-20T22:12:34.320974Z",
893
+ "shell.execute_reply.started": "2025-01-20T22:12:33.178397Z"
894
  }
895
  },
896
  "outputs": [
897
  {
898
  "data": {
899
  "text/plain": [
900
+ "CommitInfo(commit_url='https://huggingface.co/Nonnormalizable/frugal-ai-text-bert-base/commit/9081285a20fa0d62c5c1580aa17884de2b3bc236', commit_message='Upload tokenizer', commit_description='', oid='9081285a20fa0d62c5c1580aa17884de2b3bc236', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Nonnormalizable/frugal-ai-text-bert-base', endpoint='https://huggingface.co', repo_type='model', repo_id='Nonnormalizable/frugal-ai-text-bert-base'), pr_revision=None, pr_num=None)"
901
  ]
902
  },
903
+ "execution_count": 20,
904
  "metadata": {},
905
+ "output_type": "execute_result"
906
+ }
907
+ ],
908
+ "source": [
909
+ "tokenizer_final.push_to_hub(\"frugal-ai-text-bert-base\")"
910
+ ]
911
+ },
912
+ {
913
+ "cell_type": "code",
914
+ "execution_count": 21,
915
+ "id": "863d3553-89a6-4188-a8d0-eaa0b6bccb6c",
916
+ "metadata": {
917
+ "execution": {
918
+ "iopub.execute_input": "2025-01-20T22:12:34.324003Z",
919
+ "iopub.status.busy": "2025-01-20T22:12:34.323725Z",
920
+ "iopub.status.idle": "2025-01-20T22:12:35.350962Z",
921
+ "shell.execute_reply": "2025-01-20T22:12:35.350482Z",
922
+ "shell.execute_reply.started": "2025-01-20T22:12:34.323976Z"
923
+ }
924
+ },
925
+ "outputs": [
926
  {
927
  "data": {
928
  "text/plain": [
929
+ "CommitInfo(commit_url='https://huggingface.co/Nonnormalizable/frugal-ai-text-bert-base/commit/b3078a95ea36d71c1d1bf0d153e069b83f74bddf', commit_message='Upload README.md with huggingface_hub', commit_description='', oid='b3078a95ea36d71c1d1bf0d153e069b83f74bddf', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Nonnormalizable/frugal-ai-text-bert-base', endpoint='https://huggingface.co', repo_type='model', repo_id='Nonnormalizable/frugal-ai-text-bert-base'), pr_revision=None, pr_num=None)"
930
  ]
931
  },
932
+ "execution_count": 21,
933
  "metadata": {},
934
  "output_type": "execute_result"
935
  }
936
  ],
937
  "source": [
938
+ "card.push_to_hub(\"Nonnormalizable/frugal-ai-text-bert-base\")"
939
  ]
940
  },
941
  {
942
  "cell_type": "code",
943
  "execution_count": null,
944
+ "id": "2c22cc30-7578-4aad-b7db-1ffe4954c46c",
945
  "metadata": {},
946
  "outputs": [],
947
  "source": []
 
968
  "widgets": {
969
  "application/vnd.jupyter.widget-state+json": {
970
  "state": {
971
+ "47fba054bcbc4563934b6d25ea787e43": {
972
  "model_module": "@jupyter-widgets/base",
973
  "model_module_version": "2.0.0",
974
  "model_name": "LayoutModel",
975
  "state": {}
976
  },
977
+ "5cdf8fe39a634d048f2140b3af85165f": {
978
  "model_module": "@jupyter-widgets/base",
979
  "model_module_version": "2.0.0",
980
  "model_name": "LayoutModel",
981
  "state": {}
982
  },
983
+ "6a6b93c568744ed48ba6c58f84c3d59a": {
984
  "model_module": "@jupyter-widgets/base",
985
  "model_module_version": "2.0.0",
986
  "model_name": "LayoutModel",
987
  "state": {}
988
  },
989
+ "802b81b278a34a1a9ed480ca2ae299a0": {
990
+ "model_module": "@jupyter-widgets/controls",
991
+ "model_module_version": "2.0.0",
992
+ "model_name": "HTMLModel",
993
+ "state": {
994
+ "layout": "IPY_MODEL_47fba054bcbc4563934b6d25ea787e43",
995
+ "style": "IPY_MODEL_cab10a06b0064a4f876d47bbd5dda288",
996
+ "value": "model.safetensors: 100%"
997
+ }
998
+ },
999
+ "80984aaf16ce41ce839cc4bd5c0ea202": {
1000
  "model_module": "@jupyter-widgets/base",
1001
  "model_module_version": "2.0.0",
1002
  "model_name": "LayoutModel",
1003
  "state": {}
1004
  },
1005
+ "87a62c5c11cc43649d6ce177ab39f244": {
1006
  "model_module": "@jupyter-widgets/controls",
1007
  "model_module_version": "2.0.0",
1008
  "model_name": "HTMLStyleModel",
 
1012
  "text_color": null
1013
  }
1014
  },
1015
+ "8b033d0c246145a082c43e73d1377035": {
 
1016
  "model_module": "@jupyter-widgets/controls",
1017
  "model_module_version": "2.0.0",
1018
+ "model_name": "HTMLModel",
1019
  "state": {
1020
+ "layout": "IPY_MODEL_5cdf8fe39a634d048f2140b3af85165f",
1021
+ "style": "IPY_MODEL_87a62c5c11cc43649d6ce177ab39f244",
1022
+ "value": " 438M/438M [00:15&lt;00:00, 22.9MB/s]"
1023
  }
1024
  },
1025
+ "c5eebb3e916e4c59864d29582ab336bf": {
1026
  "model_module": "@jupyter-widgets/controls",
1027
  "model_module_version": "2.0.0",
1028
  "model_name": "ProgressStyleModel",
 
1030
  "description_width": ""
1031
  }
1032
  },
1033
+ "cab10a06b0064a4f876d47bbd5dda288": {
1034
  "model_module": "@jupyter-widgets/controls",
1035
  "model_module_version": "2.0.0",
1036
  "model_name": "HTMLStyleModel",
 
1040
  "text_color": null
1041
  }
1042
  },
1043
+ "d83e79effc3542f49c38928463bb41ec": {
1044
  "model_module": "@jupyter-widgets/controls",
1045
  "model_module_version": "2.0.0",
1046
+ "model_name": "FloatProgressModel",
1047
  "state": {
1048
+ "bar_style": "success",
1049
+ "layout": "IPY_MODEL_6a6b93c568744ed48ba6c58f84c3d59a",
1050
+ "max": 437977072,
1051
+ "style": "IPY_MODEL_c5eebb3e916e4c59864d29582ab336bf",
1052
+ "value": 437977072
1053
  }
1054
  },
1055
+ "fbc09ae2c5614831a2fb02fa48a44fd1": {
1056
  "model_module": "@jupyter-widgets/controls",
1057
  "model_module_version": "2.0.0",
1058
+ "model_name": "HBoxModel",
1059
  "state": {
1060
+ "children": [
1061
+ "IPY_MODEL_802b81b278a34a1a9ed480ca2ae299a0",
1062
+ "IPY_MODEL_d83e79effc3542f49c38928463bb41ec",
1063
+ "IPY_MODEL_8b033d0c246145a082c43e73d1377035"
1064
+ ],
1065
+ "layout": "IPY_MODEL_80984aaf16ce41ce839cc4bd5c0ea202"
1066
  }
1067
  }
1068
  },
1069
  "version_major": 2,
tasks/text.py CHANGED
@@ -70,20 +70,13 @@ def bert_model(test_dataset: dict, model_type: str):
70
  return predictions
71
 
72
 
73
- @router.post("/text-bert-base", tags=["Text Task"])
74
- async def evauate_text_model_1(request: TextEvaluationRequest):
75
- return evaluate_text(request, model_type="bert-base")
76
-
77
-
78
- @router.post("/text-baseline", tags=["Text Task"])
79
- async def evauate_text_model_2(request: TextEvaluationRequest):
80
- return evaluate_text(request, model_type="baseline")
81
-
82
-
83
  @router.post(ROUTE, tags=["Text Task"])
84
  async def evaluate_text(
85
  request: TextEvaluationRequest,
86
  model_type: str = "bert-base",
 
  ):
88
  """
89
  Evaluate text classification for climate disinformation detection.
 
70
  return predictions
71
 
72
 
73
  @router.post(ROUTE, tags=["Text Task"])
74
  async def evaluate_text(
75
  request: TextEvaluationRequest,
76
  model_type: str = "bert-base",
77
+ # This should be an API query parameter, but it looks like the submission repo
78
+ # https://huggingface.co/spaces/frugal-ai-challenge/submission-portal
79
+ # is built in a way to not accept any other endpoints or parameters.
80
  ):
81
  """
82
  Evaluate text classification for climate disinformation detection.
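
For reference, after this change tasks/text.py exposes a single POST route with model_type as a defaulted argument instead of one route per model. A self-contained sketch under stated assumptions (the request model and ROUTE value here are stand-ins for the ones defined elsewhere in the repo):

from fastapi import APIRouter
from pydantic import BaseModel

router = APIRouter()
ROUTE = "/text"  # stand-in; the real ROUTE constant is defined elsewhere in the repo


class TextEvaluationRequest(BaseModel):
    # Stand-in for the repo's request model, only to keep the sketch self-contained.
    test_dataset_name: str = "QuotaClimat/frugalaichallenge-text-train"


@router.post(ROUTE, tags=["Text Task"])
async def evaluate_text(
    request: TextEvaluationRequest,
    model_type: str = "bert-base",
    # The submission portal only calls this one fixed endpoint, so model_type stays a
    # defaulted argument rather than a route-per-model setup.
):
    """Evaluate text classification for climate disinformation detection."""
    return {"model_type": model_type, "dataset": request.test_dataset_name}
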