Nonnormalizable committed on
Commit 7bc734f · 1 Parent(s): 250d2de

Training bert-tiny. More integration with model card data.

Files changed (1)
  1. Finetune BERT.ipynb +163 -462
Finetune BERT.ipynb CHANGED
@@ -14,11 +14,11 @@
14
  "id": "73e72549-69f2-46b5-b0f5-655777139972",
15
  "metadata": {
16
  "execution": {
17
- "iopub.execute_input": "2025-01-20T20:17:03.803583Z",
18
- "iopub.status.busy": "2025-01-20T20:17:03.803051Z",
19
- "iopub.status.idle": "2025-01-20T20:17:06.786959Z",
20
- "shell.execute_reply": "2025-01-20T20:17:06.786718Z",
21
- "shell.execute_reply.started": "2025-01-20T20:17:03.803542Z"
22
  }
23
  },
24
  "outputs": [],
@@ -45,11 +45,11 @@
45
  "id": "07e0787e-c72b-41f3-baba-43cef3f8d6f8",
46
  "metadata": {
47
  "execution": {
48
- "iopub.execute_input": "2025-01-20T20:17:06.787691Z",
49
- "iopub.status.busy": "2025-01-20T20:17:06.787547Z",
50
- "iopub.status.idle": "2025-01-20T20:17:06.789420Z",
51
- "shell.execute_reply": "2025-01-20T20:17:06.789211Z",
52
- "shell.execute_reply.started": "2025-01-20T20:17:06.787682Z"
53
  }
54
  },
55
  "outputs": [],
@@ -71,11 +71,11 @@
71
  "id": "d4b79fb9-5e70-4600-8885-94bc0a6e917c",
72
  "metadata": {
73
  "execution": {
74
- "iopub.execute_input": "2025-01-20T20:17:06.789829Z",
75
- "iopub.status.busy": "2025-01-20T20:17:06.789761Z",
76
- "iopub.status.idle": "2025-01-20T20:17:06.794443Z",
77
- "shell.execute_reply": "2025-01-20T20:17:06.794260Z",
78
- "shell.execute_reply.started": "2025-01-20T20:17:06.789822Z"
79
  }
80
  },
81
  "outputs": [],
@@ -109,7 +109,7 @@
109
  " avg_loss = total_loss / len(dataloader)\n",
110
  " avg_acc = total_correct / total_length\n",
111
  " model.train()\n",
112
- " return avg_loss, avg_acc\n",
113
  "\n",
114
  "\n",
115
  "def print_model_status(epoch, num_epochs, model, train_dataloader, test_dataloader):\n",
@@ -117,7 +117,14 @@
117
  " test_loss, test_acc = model_metrics(model, test_dataloader)\n",
118
  " loss_str = f\"Loss: Train {train_loss:0.3f}, Test {test_loss:0.3f}\"\n",
119
  " acc_str = f\"Acc: Train {train_acc:0.3f}, Test {test_acc:0.3f}\"\n",
120
- " my_print(f\"Epoch {epoch+1}/{num_epochs} done. {loss_str}; and {acc_str}\")\n",
121
  "\n",
122
  "\n",
123
  "class BertClassifier(nn.Module, PyTorchModelHubMixin):\n",
@@ -136,7 +143,7 @@
136
  "\n",
137
  "\n",
138
  "class TextDataset(Dataset):\n",
139
- " def __init__(self, texts, labels, tokenizer, max_length=512):\n",
140
  " self.encodings = tokenizer(\n",
141
  " texts,\n",
142
  " truncation=True,\n",
@@ -160,7 +167,7 @@
160
  " criterion = nn.CrossEntropyLoss()\n",
161
  " model.train()\n",
162
  "\n",
163
- " print_model_status(-1, num_epochs, model, train_dataloader, test_dataloader)\n",
164
  " for epoch in range(num_epochs):\n",
165
  " total_loss = 0\n",
166
  " for batch in train_dataloader:\n",
@@ -178,7 +185,10 @@
178
  "\n",
179
  " total_loss += loss.item()\n",
180
  " avg_loss = total_loss / len(train_dataloader)\n",
181
- " print_model_status(epoch, num_epochs, model, train_dataloader, test_dataloader)"
182
  ]
183
  },
184
  {
@@ -187,11 +197,11 @@
187
  "id": "07131bce-23ad-4787-8622-cce401f3e5ce",
188
  "metadata": {
189
  "execution": {
190
- "iopub.execute_input": "2025-01-20T20:17:06.795335Z",
191
- "iopub.status.busy": "2025-01-20T20:17:06.795239Z",
192
- "iopub.status.idle": "2025-01-20T20:17:06.821293Z",
193
- "shell.execute_reply": "2025-01-20T20:17:06.821061Z",
194
- "shell.execute_reply.started": "2025-01-20T20:17:06.795328Z"
195
  }
196
  },
197
  "outputs": [],
@@ -211,11 +221,11 @@
211
  "id": "695bc080-bbd7-4937-af5b-50db1c936500",
212
  "metadata": {
213
  "execution": {
214
- "iopub.execute_input": "2025-01-20T20:17:06.821637Z",
215
- "iopub.status.busy": "2025-01-20T20:17:06.821569Z",
216
- "iopub.status.idle": "2025-01-20T20:17:06.824265Z",
217
- "shell.execute_reply": "2025-01-20T20:17:06.824082Z",
218
- "shell.execute_reply.started": "2025-01-20T20:17:06.821630Z"
219
  }
220
  },
221
  "outputs": [],
@@ -223,10 +233,17 @@
223
  "def run_training(\n",
224
  " max_dataset_size=16 * 200,\n",
225
  " bert_variety=\"bert-base-uncased\",\n",
226
- " max_length=200,\n",
227
  " num_epochs=3,\n",
228
  " batch_size=32,\n",
229
  "):\n",
230
  " hf_dataset = load_dataset(\"quotaclimat/frugalaichallenge-text-train\")\n",
231
  " test_size = 0.2\n",
232
  " test_seed = 42\n",
@@ -272,8 +289,10 @@
272
  " text_dataset_test, batch_size=batch_size, shuffle=False\n",
273
  " )\n",
274
  "\n",
275
- " train_model(model, dataloader_train, dataloader_test, device, num_epochs=num_epochs)\n",
276
- " return model, tokenizer"
277
  ]
278
  },
279
  {
@@ -302,61 +321,57 @@
302
  },
303
  {
304
  "cell_type": "code",
305
- "execution_count": 6,
306
- "id": "792fd13f-e7cc-4d90-832d-c0da15e193cd",
307
  "metadata": {
308
  "execution": {
309
- "iopub.execute_input": "2025-01-20T20:17:06.824513Z",
310
- "iopub.status.busy": "2025-01-20T20:17:06.824457Z",
311
- "iopub.status.idle": "2025-01-20T20:17:14.130284Z",
312
- "shell.execute_reply": "2025-01-20T20:17:14.129964Z",
313
- "shell.execute_reply.started": "2025-01-20T20:17:06.824506Z"
314
  }
315
  },
316
- "outputs": [
317
- {
318
- "name": "stdout",
319
- "output_type": "stream",
320
- "text": [
321
- "2025-01-20 12:17:10 Epoch 0/3 done. Loss: Train 2.111, Test 2.247; and Acc: Train 0.281, Test 0.156\n",
322
- "2025-01-20 12:17:11 Epoch 1/3 done. Loss: Train 2.026, Test 2.222; and Acc: Train 0.344, Test 0.156\n",
323
- "2025-01-20 12:17:12 Epoch 2/3 done. Loss: Train 1.943, Test 2.194; and Acc: Train 0.312, Test 0.156\n",
324
- "2025-01-20 12:17:14 Epoch 3/3 done. Loss: Train 1.859, Test 2.159; and Acc: Train 0.344, Test 0.156\n"
325
- ]
326
- }
327
- ],
328
  "source": [
329
- "model, tokenizer = run_training(\n",
330
- " max_dataset_size=16 * 2,\n",
331
- " bert_variety=\"bert-base-uncased\",\n",
332
  " max_length=128,\n",
333
- " num_epochs=3,\n",
334
  " batch_size=32,\n",
335
  ")"
336
  ]
337
  },
338
  {
339
  "cell_type": "code",
340
- "execution_count": 7,
341
  "id": "0aedfcca-843e-4f4c-8062-3e4625161bcc",
342
  "metadata": {
343
- "execution": {
344
- "iopub.execute_input": "2025-01-20T20:17:14.130879Z",
345
- "iopub.status.busy": "2025-01-20T20:17:14.130792Z",
346
- "iopub.status.idle": "2025-01-20T20:17:14.193695Z",
347
- "shell.execute_reply": "2025-01-20T20:17:14.193466Z",
348
- "shell.execute_reply.started": "2025-01-20T20:17:14.130869Z"
349
- }
350
  },
351
- "outputs": [
352
- {
353
- "name": "stdout",
354
- "output_type": "stream",
355
- "text": [
356
- "2025-01-20 12:17:14 Predictions: tensor([4, 1, 1, 1, 3, 1, 1], device='mps:0')\n"
357
- ]
358
- }
359
- ],
360
  "source": [
361
  "model.eval()\n",
362
  "test_text = [\n",
@@ -373,6 +388,7 @@
373
  " truncation=True,\n",
374
  " padding=True,\n",
375
  " return_tensors=\"pt\",\n",
376
  ")\n",
377
  "\n",
378
  "with torch.no_grad():\n",
@@ -392,86 +408,66 @@
392
  ]
393
  },
394
  {
395
- "cell_type": "code",
396
- "execution_count": 8,
397
- "id": "1d29336e-7f88-4127-afdf-2fe043e310e1",
398
- "metadata": {
399
- "execution": {
400
- "iopub.execute_input": "2025-01-20T20:17:14.194160Z",
401
- "iopub.status.busy": "2025-01-20T20:17:14.194076Z",
402
- "iopub.status.idle": "2025-01-20T20:25:46.660251Z",
403
- "shell.execute_reply": "2025-01-20T20:25:46.659652Z",
404
- "shell.execute_reply.started": "2025-01-20T20:17:14.194152Z"
405
- }
406
- },
407
- "outputs": [
408
- {
409
- "name": "stdout",
410
- "output_type": "stream",
411
- "text": [
412
- "2025-01-20 12:18:02 Epoch 0/3 done. Loss: Train 2.106, Test 2.091; and Acc: Train 0.118, Test 0.135\n",
413
- "2025-01-20 12:20:37 Epoch 1/3 done. Loss: Train 0.989, Test 1.114; and Acc: Train 0.647, Test 0.603\n",
414
- "2025-01-20 12:23:12 Epoch 2/3 done. Loss: Train 0.584, Test 0.928; and Acc: Train 0.825, Test 0.669\n",
415
- "2025-01-20 12:25:46 Epoch 3/3 done. Loss: Train 0.313, Test 0.950; and Acc: Train 0.913, Test 0.683\n"
416
- ]
417
- }
418
- ],
419
  "source": [
420
- "model, tokenizer = run_training(\n",
421
- " max_dataset_size=\"full\",\n",
422
- " bert_variety=\"bert-base-uncased\",\n",
423
- " max_length=128,\n",
424
- " num_epochs=3,\n",
425
- " batch_size=32,\n",
426
- ")"
427
  ]
428
  },
429
  {
430
  "cell_type": "code",
431
- "execution_count": 9,
432
- "id": "461b8f57-0c52-403a-bb69-3bc192b323bf",
433
  "metadata": {
434
  "execution": {
435
- "iopub.execute_input": "2025-01-20T20:25:46.661264Z",
436
- "iopub.status.busy": "2025-01-20T20:25:46.661132Z",
437
- "iopub.status.idle": "2025-01-20T20:34:54.221239Z",
438
- "shell.execute_reply": "2025-01-20T20:34:54.220590Z",
439
- "shell.execute_reply.started": "2025-01-20T20:25:46.661249Z"
440
  }
441
  },
442
- "outputs": [
443
- {
444
- "name": "stdout",
445
- "output_type": "stream",
446
- "text": [
447
- "2025-01-20 12:26:34 Epoch 0/3 done. Loss: Train 2.174, Test 2.168; and Acc: Train 0.096, Test 0.094\n",
448
- "2025-01-20 12:29:21 Epoch 1/3 done. Loss: Train 0.878, Test 1.033; and Acc: Train 0.712, Test 0.653\n",
449
- "2025-01-20 12:32:07 Epoch 2/3 done. Loss: Train 0.458, Test 0.906; and Acc: Train 0.869, Test 0.678\n",
450
- "2025-01-20 12:34:54 Epoch 3/3 done. Loss: Train 0.218, Test 0.959; and Acc: Train 0.944, Test 0.695\n"
451
- ]
452
- }
453
- ],
454
  "source": [
455
- "model, tokenizer = run_training(\n",
456
  " max_dataset_size=\"full\",\n",
457
- " bert_variety=\"bert-base-uncased\",\n",
458
- " max_length=128,\n",
459
- " num_epochs=3,\n",
460
  " batch_size=16,\n",
461
  ")"
462
  ]
463
  },
464
  {
465
  "cell_type": "code",
466
- "execution_count": 10,
467
  "id": "28354e8c-886a-4523-8968-8c688c13f6a3",
468
  "metadata": {
469
  "execution": {
470
- "iopub.execute_input": "2025-01-20T20:34:54.224989Z",
471
- "iopub.status.busy": "2025-01-20T20:34:54.224772Z",
472
- "iopub.status.idle": "2025-01-20T20:54:07.531338Z",
473
- "shell.execute_reply": "2025-01-20T20:54:07.530559Z",
474
- "shell.execute_reply.started": "2025-01-20T20:34:54.224968Z"
475
  }
476
  },
477
  "outputs": [
@@ -479,20 +475,29 @@
479
  "name": "stdout",
480
  "output_type": "stream",
481
  "text": [
482
- "2025-01-20 12:36:37 Epoch 0/3 done. Loss: Train 2.122, Test 2.127; and Acc: Train 0.122, Test 0.118\n",
483
- "2025-01-20 12:42:26 Epoch 1/3 done. Loss: Train 0.779, Test 0.978; and Acc: Train 0.748, Test 0.652\n",
484
- "2025-01-20 12:48:16 Epoch 2/3 done. Loss: Train 0.391, Test 0.884; and Acc: Train 0.897, Test 0.696\n",
485
- "2025-01-20 12:54:07 Epoch 3/3 done. Loss: Train 0.154, Test 0.978; and Acc: Train 0.959, Test 0.705\n"
486
  ]
487
  }
488
  ],
489
  "source": [
490
- "model, tokenizer = run_training(\n",
491
- " max_dataset_size=\"full\",\n",
492
- " bert_variety=\"bert-base-uncased\",\n",
493
- " max_length=256,\n",
494
- " num_epochs=3,\n",
495
- " batch_size=16,\n",
496
  ")"
497
  ]
498
  },
@@ -506,240 +511,33 @@
506
  },
507
  {
508
  "cell_type": "code",
509
- "execution_count": 14,
510
  "id": "ec2516f9-79f2-4ae1-ab9a-9a51a7a50587",
511
  "metadata": {
512
  "execution": {
513
- "iopub.execute_input": "2025-01-20T22:10:34.055595Z",
514
- "iopub.status.busy": "2025-01-20T22:10:34.054690Z",
515
- "iopub.status.idle": "2025-01-20T22:10:34.083784Z",
516
- "shell.execute_reply": "2025-01-20T22:10:34.083448Z",
517
- "shell.execute_reply.started": "2025-01-20T22:10:34.055529Z"
518
  },
519
  "scrolled": true
520
  },
521
  "outputs": [
522
  {
523
- "name": "stdout",
524
- "output_type": "stream",
525
- "text": [
526
- "---\n",
527
- "base_model: google-bert/bert-base-uncased\n",
528
- "datasets:\n",
529
- "- QuotaClimat/frugalaichallenge-text-train\n",
530
- "language:\n",
531
- "- en\n",
532
- "license: apache-2.0\n",
533
- "model_name: frugal-ai-text-bert-base\n",
534
- "pipeline_tag: text-classification\n",
535
- "tags:\n",
536
- "- model_hub_mixin\n",
537
- "- pytorch_model_hub_mixin\n",
538
- "- climate\n",
539
- "---\n",
540
- "\n",
541
- "# Model Card for Model ID\n",
542
- "\n",
543
- "<!-- Provide a quick summary of what the model is/does. -->\n",
544
- "\n",
545
- "Classify text into 8 categories of climate misinformation.\n",
546
- "\n",
547
- "## Model Details\n",
548
- "\n",
549
- "### Model Description\n",
550
- "\n",
551
- "<!-- Provide a longer summary of what this model is. -->\n",
552
- "\n",
553
- "Fine trained BERT for classifying climate information as part of the Frugal AI Challenge, for submission to https://huggingface.co/frugal-ai-challenge and scoring on accuracy and efficiency. Trainied on only the non-evaluation 80% of the data, so it's (non-cheating) score will be lower.\n",
554
- "\n",
555
- "- **Developed by:** Andre Bach\n",
556
- "- **Funded by [optional]:** N/A\n",
557
- "- **Shared by [optional]:** Andre Bach\n",
558
- "- **Model type:** Text classification\n",
559
- "- **Language(s) (NLP):** ['en']\n",
560
- "- **License:** apache-2.0\n",
561
- "- **Finetuned from model [optional]:** google-bert/bert-base-uncased\n",
562
- "\n",
563
- "### Model Sources [optional]\n",
564
- "\n",
565
- "<!-- Provide the basic links for the model. -->\n",
566
- "\n",
567
- "- **Repository:** frugal-ai-text-bert-base\n",
568
- "- **Paper [optional]:** [More Information Needed]\n",
569
- "- **Demo [optional]:** [More Information Needed]\n",
570
- "\n",
571
- "## Uses\n",
572
- "\n",
573
- "<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->\n",
574
- "\n",
575
- "### Direct Use\n",
576
- "\n",
577
- "<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->\n",
578
- "\n",
579
- "[More Information Needed]\n",
580
- "\n",
581
- "### Downstream Use [optional]\n",
582
- "\n",
583
- "<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->\n",
584
- "\n",
585
- "[More Information Needed]\n",
586
- "\n",
587
- "### Out-of-Scope Use\n",
588
- "\n",
589
- "<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->\n",
590
- "\n",
591
- "[More Information Needed]\n",
592
- "\n",
593
- "## Bias, Risks, and Limitations\n",
594
- "\n",
595
- "<!-- This section is meant to convey both technical and sociotechnical limitations. -->\n",
596
- "\n",
597
- "[More Information Needed]\n",
598
- "\n",
599
- "### Recommendations\n",
600
- "\n",
601
- "<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->\n",
602
- "\n",
603
- "Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.\n",
604
- "\n",
605
- "## How to Get Started with the Model\n",
606
- "\n",
607
- "Use the code below to get started with the model.\n",
608
- "\n",
609
- "[More Information Needed]\n",
610
- "\n",
611
- "## Training Details\n",
612
- "\n",
613
- "### Training Data\n",
614
- "\n",
615
- "<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->\n",
616
- "\n",
617
- "[More Information Needed]\n",
618
- "\n",
619
- "### Training Procedure\n",
620
- "\n",
621
- "<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->\n",
622
- "\n",
623
- "#### Preprocessing [optional]\n",
624
- "\n",
625
- "[More Information Needed]\n",
626
- "\n",
627
- "\n",
628
- "#### Training Hyperparameters\n",
629
- "\n",
630
- "- **Training regime:** {'max_dataset_size': 'full', 'bert_variety': 'bert-base-uncased', 'max_length': 256, 'num_epochs': 3, 'batch_size': 16} <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->\n",
631
- "\n",
632
- "#### Speeds, Sizes, Times [optional]\n",
633
- "\n",
634
- "<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->\n",
635
- "\n",
636
- "[More Information Needed]\n",
637
- "\n",
638
- "## Evaluation\n",
639
- "\n",
640
- "<!-- This section describes the evaluation protocols and provides the results. -->\n",
641
- "\n",
642
- "### Testing Data, Factors & Metrics\n",
643
- "\n",
644
- "#### Testing Data\n",
645
- "\n",
646
- "<!-- This should link to a Dataset Card if possible. -->\n",
647
- "\n",
648
- "[More Information Needed]\n",
649
- "\n",
650
- "#### Factors\n",
651
- "\n",
652
- "<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->\n",
653
- "\n",
654
- "[More Information Needed]\n",
655
- "\n",
656
- "#### Metrics\n",
657
- "\n",
658
- "<!-- These are the evaluation metrics being used, ideally with a description of why. -->\n",
659
- "\n",
660
- "{'loss_train': 0.154, 'loss_test': 0.978, 'acc_train': 0.959, 'acc_test': 0.705}\n",
661
- "\n",
662
- "### Results\n",
663
- "\n",
664
- "[More Information Needed]\n",
665
- "\n",
666
- "#### Summary\n",
667
- "\n",
668
- "\n",
669
- "\n",
670
- "## Model Examination [optional]\n",
671
- "\n",
672
- "<!-- Relevant interpretability work for the model goes here -->\n",
673
- "\n",
674
- "[More Information Needed]\n",
675
- "\n",
676
- "## Environmental Impact\n",
677
- "\n",
678
- "<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->\n",
679
- "\n",
680
- "Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).\n",
681
- "\n",
682
- "- **Hardware Type:** [More Information Needed]\n",
683
- "- **Hours used:** [More Information Needed]\n",
684
- "- **Cloud Provider:** [More Information Needed]\n",
685
- "- **Compute Region:** [More Information Needed]\n",
686
- "- **Carbon Emitted:** [More Information Needed]\n",
687
- "\n",
688
- "## Technical Specifications [optional]\n",
689
- "\n",
690
- "### Model Architecture and Objective\n",
691
- "\n",
692
- "[More Information Needed]\n",
693
- "\n",
694
- "### Compute Infrastructure\n",
695
- "\n",
696
- "[More Information Needed]\n",
697
- "\n",
698
- "#### Hardware\n",
699
- "\n",
700
- "[More Information Needed]\n",
701
- "\n",
702
- "#### Software\n",
703
- "\n",
704
- "[More Information Needed]\n",
705
- "\n",
706
- "## Citation [optional]\n",
707
- "\n",
708
- "<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->\n",
709
- "\n",
710
- "**BibTeX:**\n",
711
- "\n",
712
- "[More Information Needed]\n",
713
- "\n",
714
- "**APA:**\n",
715
- "\n",
716
- "[More Information Needed]\n",
717
- "\n",
718
- "## Glossary [optional]\n",
719
- "\n",
720
- "<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->\n",
721
- "\n",
722
- "[More Information Needed]\n",
723
- "\n",
724
- "## More Information [optional]\n",
725
- "\n",
726
- "[More Information Needed]\n",
727
- "\n",
728
- "## Model Card Authors [optional]\n",
729
- "\n",
730
- "[More Information Needed]\n",
731
- "\n",
732
- "## Model Card Contact\n",
733
- "\n",
734
- "[More Information Needed]\n"
735
  ]
736
  }
737
  ],
738
  "source": [
739
- "model_and_repo_name = \"frugal-ai-text-bert-base\"\n",
740
  "card_data = ModelCardData(\n",
741
  " model_name=model_and_repo_name,\n",
742
- " base_model=\"google-bert/bert-base-uncased\",\n",
743
  " license=\"apache-2.0\",\n",
744
  " language=[\"en\"],\n",
745
  " datasets=[\"QuotaClimat/frugalaichallenge-text-train\"],\n",
@@ -827,6 +625,7 @@
827
  " truncation=True,\n",
828
  " padding=True,\n",
829
  " return_tensors=\"pt\",\n",
830
  ")\n",
831
  "\n",
832
  "with torch.no_grad():\n",
@@ -967,105 +766,7 @@
967
  },
968
  "widgets": {
969
  "application/vnd.jupyter.widget-state+json": {
970
- "state": {
971
- "47fba054bcbc4563934b6d25ea787e43": {
972
- "model_module": "@jupyter-widgets/base",
973
- "model_module_version": "2.0.0",
974
- "model_name": "LayoutModel",
975
- "state": {}
976
- },
977
- "5cdf8fe39a634d048f2140b3af85165f": {
978
- "model_module": "@jupyter-widgets/base",
979
- "model_module_version": "2.0.0",
980
- "model_name": "LayoutModel",
981
- "state": {}
982
- },
983
- "6a6b93c568744ed48ba6c58f84c3d59a": {
984
- "model_module": "@jupyter-widgets/base",
985
- "model_module_version": "2.0.0",
986
- "model_name": "LayoutModel",
987
- "state": {}
988
- },
989
- "802b81b278a34a1a9ed480ca2ae299a0": {
990
- "model_module": "@jupyter-widgets/controls",
991
- "model_module_version": "2.0.0",
992
- "model_name": "HTMLModel",
993
- "state": {
994
- "layout": "IPY_MODEL_47fba054bcbc4563934b6d25ea787e43",
995
- "style": "IPY_MODEL_cab10a06b0064a4f876d47bbd5dda288",
996
- "value": "model.safetensors: 100%"
997
- }
998
- },
999
- "80984aaf16ce41ce839cc4bd5c0ea202": {
1000
- "model_module": "@jupyter-widgets/base",
1001
- "model_module_version": "2.0.0",
1002
- "model_name": "LayoutModel",
1003
- "state": {}
1004
- },
1005
- "87a62c5c11cc43649d6ce177ab39f244": {
1006
- "model_module": "@jupyter-widgets/controls",
1007
- "model_module_version": "2.0.0",
1008
- "model_name": "HTMLStyleModel",
1009
- "state": {
1010
- "description_width": "",
1011
- "font_size": null,
1012
- "text_color": null
1013
- }
1014
- },
1015
- "8b033d0c246145a082c43e73d1377035": {
1016
- "model_module": "@jupyter-widgets/controls",
1017
- "model_module_version": "2.0.0",
1018
- "model_name": "HTMLModel",
1019
- "state": {
1020
- "layout": "IPY_MODEL_5cdf8fe39a634d048f2140b3af85165f",
1021
- "style": "IPY_MODEL_87a62c5c11cc43649d6ce177ab39f244",
1022
- "value": " 438M/438M [00:15&lt;00:00, 22.9MB/s]"
1023
- }
1024
- },
1025
- "c5eebb3e916e4c59864d29582ab336bf": {
1026
- "model_module": "@jupyter-widgets/controls",
1027
- "model_module_version": "2.0.0",
1028
- "model_name": "ProgressStyleModel",
1029
- "state": {
1030
- "description_width": ""
1031
- }
1032
- },
1033
- "cab10a06b0064a4f876d47bbd5dda288": {
1034
- "model_module": "@jupyter-widgets/controls",
1035
- "model_module_version": "2.0.0",
1036
- "model_name": "HTMLStyleModel",
1037
- "state": {
1038
- "description_width": "",
1039
- "font_size": null,
1040
- "text_color": null
1041
- }
1042
- },
1043
- "d83e79effc3542f49c38928463bb41ec": {
1044
- "model_module": "@jupyter-widgets/controls",
1045
- "model_module_version": "2.0.0",
1046
- "model_name": "FloatProgressModel",
1047
- "state": {
1048
- "bar_style": "success",
1049
- "layout": "IPY_MODEL_6a6b93c568744ed48ba6c58f84c3d59a",
1050
- "max": 437977072,
1051
- "style": "IPY_MODEL_c5eebb3e916e4c59864d29582ab336bf",
1052
- "value": 437977072
1053
- }
1054
- },
1055
- "fbc09ae2c5614831a2fb02fa48a44fd1": {
1056
- "model_module": "@jupyter-widgets/controls",
1057
- "model_module_version": "2.0.0",
1058
- "model_name": "HBoxModel",
1059
- "state": {
1060
- "children": [
1061
- "IPY_MODEL_802b81b278a34a1a9ed480ca2ae299a0",
1062
- "IPY_MODEL_d83e79effc3542f49c38928463bb41ec",
1063
- "IPY_MODEL_8b033d0c246145a082c43e73d1377035"
1064
- ],
1065
- "layout": "IPY_MODEL_80984aaf16ce41ce839cc4bd5c0ea202"
1066
- }
1067
- }
1068
- },
1069
  "version_major": 2,
1070
  "version_minor": 0
1071
  }
 
14
  "id": "73e72549-69f2-46b5-b0f5-655777139972",
15
  "metadata": {
16
  "execution": {
17
+ "iopub.execute_input": "2025-01-21T19:25:48.302003Z",
18
+ "iopub.status.busy": "2025-01-21T19:25:48.301808Z",
19
+ "iopub.status.idle": "2025-01-21T19:25:50.698806Z",
20
+ "shell.execute_reply": "2025-01-21T19:25:50.698535Z",
21
+ "shell.execute_reply.started": "2025-01-21T19:25:48.301982Z"
22
  }
23
  },
24
  "outputs": [],
 
45
  "id": "07e0787e-c72b-41f3-baba-43cef3f8d6f8",
46
  "metadata": {
47
  "execution": {
48
+ "iopub.execute_input": "2025-01-21T19:25:50.699344Z",
49
+ "iopub.status.busy": "2025-01-21T19:25:50.699200Z",
50
+ "iopub.status.idle": "2025-01-21T19:25:50.701241Z",
51
+ "shell.execute_reply": "2025-01-21T19:25:50.700993Z",
52
+ "shell.execute_reply.started": "2025-01-21T19:25:50.699335Z"
53
  }
54
  },
55
  "outputs": [],
 
71
  "id": "d4b79fb9-5e70-4600-8885-94bc0a6e917c",
72
  "metadata": {
73
  "execution": {
74
+ "iopub.execute_input": "2025-01-21T19:25:50.701789Z",
75
+ "iopub.status.busy": "2025-01-21T19:25:50.701708Z",
76
+ "iopub.status.idle": "2025-01-21T19:25:50.707095Z",
77
+ "shell.execute_reply": "2025-01-21T19:25:50.706788Z",
78
+ "shell.execute_reply.started": "2025-01-21T19:25:50.701781Z"
79
  }
80
  },
81
  "outputs": [],
 
109
  " avg_loss = total_loss / len(dataloader)\n",
110
  " avg_acc = total_correct / total_length\n",
111
  " model.train()\n",
112
+ " return float(avg_loss), float(avg_acc)\n",
113
  "\n",
114
  "\n",
115
  "def print_model_status(epoch, num_epochs, model, train_dataloader, test_dataloader):\n",
 
117
  " test_loss, test_acc = model_metrics(model, test_dataloader)\n",
118
  " loss_str = f\"Loss: Train {train_loss:0.3f}, Test {test_loss:0.3f}\"\n",
119
  " acc_str = f\"Acc: Train {train_acc:0.3f}, Test {test_acc:0.3f}\"\n",
120
+ " my_print(f\"Epoch {epoch+1:2}/{num_epochs} done. {loss_str}; and {acc_str}\")\n",
121
+ " metrics = dict(\n",
122
+ " train_loss=train_loss,\n",
123
+ " train_acc=train_acc,\n",
124
+ " test_loss=test_loss,\n",
125
+ " test_acc=test_acc,\n",
126
+ " )\n",
127
+ " return metrics\n",
128
  "\n",
129
  "\n",
130
  "class BertClassifier(nn.Module, PyTorchModelHubMixin):\n",
 
143
  "\n",
144
  "\n",
145
  "class TextDataset(Dataset):\n",
146
+ " def __init__(self, texts, labels, tokenizer, max_length=256):\n",
147
  " self.encodings = tokenizer(\n",
148
  " texts,\n",
149
  " truncation=True,\n",
 
167
  " criterion = nn.CrossEntropyLoss()\n",
168
  " model.train()\n",
169
  "\n",
170
+ " _ = print_model_status(-1, num_epochs, model, train_dataloader, test_dataloader)\n",
171
  " for epoch in range(num_epochs):\n",
172
  " total_loss = 0\n",
173
  " for batch in train_dataloader:\n",
 
185
  "\n",
186
  " total_loss += loss.item()\n",
187
  " avg_loss = total_loss / len(train_dataloader)\n",
188
+ " metrics = print_model_status(\n",
189
+ " epoch, num_epochs, model, train_dataloader, test_dataloader\n",
190
+ " )\n",
191
+ " return metrics"
192
  ]
193
  },
194
  {
 
197
  "id": "07131bce-23ad-4787-8622-cce401f3e5ce",
198
  "metadata": {
199
  "execution": {
200
+ "iopub.execute_input": "2025-01-21T19:25:50.707655Z",
201
+ "iopub.status.busy": "2025-01-21T19:25:50.707519Z",
202
+ "iopub.status.idle": "2025-01-21T19:25:50.718311Z",
203
+ "shell.execute_reply": "2025-01-21T19:25:50.718037Z",
204
+ "shell.execute_reply.started": "2025-01-21T19:25:50.707646Z"
205
  }
206
  },
207
  "outputs": [],
 
221
  "id": "695bc080-bbd7-4937-af5b-50db1c936500",
222
  "metadata": {
223
  "execution": {
224
+ "iopub.execute_input": "2025-01-21T19:25:50.718754Z",
225
+ "iopub.status.busy": "2025-01-21T19:25:50.718677Z",
226
+ "iopub.status.idle": "2025-01-21T19:25:50.721834Z",
227
+ "shell.execute_reply": "2025-01-21T19:25:50.721583Z",
228
+ "shell.execute_reply.started": "2025-01-21T19:25:50.718746Z"
229
  }
230
  },
231
  "outputs": [],
 
233
  "def run_training(\n",
234
  " max_dataset_size=16 * 200,\n",
235
  " bert_variety=\"bert-base-uncased\",\n",
236
+ " max_length=256,\n",
237
  " num_epochs=3,\n",
238
  " batch_size=32,\n",
239
  "):\n",
240
+ " training_regime = dict(\n",
241
+ " max_dataset_size=max_dataset_size,\n",
242
+ " bert_variety=bert_variety,\n",
243
+ " max_length=max_length,\n",
244
+ " num_epochs=num_epochs,\n",
245
+ " batch_size=batch_size,\n",
246
+ " )\n",
247
  " hf_dataset = load_dataset(\"quotaclimat/frugalaichallenge-text-train\")\n",
248
  " test_size = 0.2\n",
249
  " test_seed = 42\n",
 
289
  " text_dataset_test, batch_size=batch_size, shuffle=False\n",
290
  " )\n",
291
  "\n",
292
+ " metrics = train_model(\n",
293
+ " model, dataloader_train, dataloader_test, device, num_epochs=num_epochs\n",
294
+ " )\n",
295
+ " return model, tokenizer, training_regime, metrics"
296
  ]
297
  },
298
  {
 
321
  },
322
  {
323
  "cell_type": "code",
324
+ "execution_count": null,
325
+ "id": "34a7c310-c486-4db1-b94d-4363c3d3df5b",
326
  "metadata": {
327
  "execution": {
328
+ "iopub.execute_input": "2025-01-21T19:25:50.724036Z",
329
+ "iopub.status.busy": "2025-01-21T19:25:50.723968Z"
330
  }
331
  },
332
+ "outputs": [],
333
  "source": [
334
+ "model, tokenizer, regime, metrics = run_training(\n",
335
+ " max_dataset_size=16 * 10,\n",
336
+ " bert_variety=\"google/bert_uncased_L-2_H-128_A-2\",\n",
337
  " max_length=128,\n",
338
+ " num_epochs=4,\n",
339
  " batch_size=32,\n",
340
  ")"
341
  ]
342
  },
343
  {
344
  "cell_type": "code",
345
+ "execution_count": null,
346
+ "id": "32abaa1b-11f4-4793-97b8-36bb2dc29d56",
347
+ "metadata": {},
348
+ "outputs": [],
349
+ "source": [
350
+ "regime"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": null,
356
+ "id": "fe108690-bcc1-4667-9f8e-907a1a8ac2ec",
357
+ "metadata": {},
358
+ "outputs": [],
359
+ "source": [
360
+ "metrics"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": null,
366
  "id": "0aedfcca-843e-4f4c-8062-3e4625161bcc",
367
  "metadata": {
368
+ "editable": true,
369
+ "slideshow": {
370
+ "slide_type": ""
371
+ },
372
+ "tags": []
373
  },
374
+ "outputs": [],
375
  "source": [
376
  "model.eval()\n",
377
  "test_text = [\n",
 
388
  " truncation=True,\n",
389
  " padding=True,\n",
390
  " return_tensors=\"pt\",\n",
391
+ " max_length=256,\n",
392
  ")\n",
393
  "\n",
394
  "with torch.no_grad():\n",
 
408
  ]
409
  },
410
  {
411
+ "cell_type": "markdown",
412
+ "id": "6264418d-10ef-4eca-b188-2b6b7f487797",
413
+ "metadata": {},
414
  "source": [
415
+ "Overall top performance per model. Machine: bert-base is using an Nvidia 1xL40S, no inference time cleaverness attempted.\n",
416
+ "\n",
417
+ "[accidentally cheating bert-base by trainging on full dataset](https://huggingface.co/datasets/frugal-ai-challenge/public-leaderboard-text/blob/main/submissions/Nonnormalizable_20250117_220350.json):\\\n",
418
+ "acc 0.954, energy 0.736 Wh, emissions 0.272 gco2eq\n",
419
+ "\n",
420
+ "[bert-base some hp tuning](https://huggingface.co/datasets/frugal-ai-challenge/public-leaderboard-text/blob/main/submissions/Nonnormalizable_20250120_231350.json):\\\n",
421
+ "acc 0.707, energy 0.803 Wh, emissions 0.296 gco2eq\n"
422
+ ]
423
+ },
424
+ {
425
+ "cell_type": "markdown",
426
+ "id": "df067c27-9d58-49fc-860d-ba79e5512013",
427
+ "metadata": {},
428
+ "source": [
429
+ "Looking at bert-tiny.\n",
430
+ "Scanning max_length and batch_size with num_epochs set to 3, looks like we want 256 and 16. That gets us\\\n",
431
+ "`2025-01-21 10:18:56 Epoch 3/3 done. Loss: Train 1.368, Test 1.432; and Acc: Train 0.499, Test 0.477`.\n",
432
+ "\n",
433
+ "Then looking at num_epochs, we saturate test set performance at 15 (~3 min), giving e.g.\\\n",
434
+ "`2025-01-21 10:38:30 Epoch 15/20 done. Loss: Train 0.553, Test 1.157; and Acc: Train 0.833, Test 0.595`"
435
  ]
436
  },
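For reference, a minimal sketch of that scan, assuming `run_training` as defined in this notebook; the grid values and the `results` bookkeeping are illustrative, not the exact runs reported above:

```python
# Hypothetical sketch of the max_length / batch_size scan (num_epochs=3).
# Assumes run_training() from this notebook; grid values are examples.
results = {}
for max_length in (128, 256):
    for batch_size in (16, 32):
        model, tokenizer, regime, metrics = run_training(
            max_dataset_size="full",
            bert_variety="google/bert_uncased_L-2_H-128_A-2",
            max_length=max_length,
            num_epochs=3,
            batch_size=batch_size,
        )
        # metrics carries train_loss, train_acc, test_loss, test_acc.
        results[(max_length, batch_size)] = metrics["test_acc"]

best_max_length, best_batch_size = max(results, key=results.get)
```

With 256 and 16 selected, a longer num_epochs run (as in the cells below) probes where test accuracy saturates.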
437
  {
438
  "cell_type": "code",
439
+ "execution_count": 32,
440
+ "id": "37794952-703c-466c-9d26-ee6cb2834246",
441
  "metadata": {
442
  "execution": {
443
+ "iopub.execute_input": "2025-01-21T18:35:29.897653Z",
444
+ "iopub.status.busy": "2025-01-21T18:35:29.897020Z",
445
+ "iopub.status.idle": "2025-01-21T18:35:29.901748Z",
446
+ "shell.execute_reply": "2025-01-21T18:35:29.901032Z",
447
+ "shell.execute_reply.started": "2025-01-21T18:35:29.897609Z"
448
  }
449
  },
450
+ "outputs": [],
451
  "source": [
452
+ "static_hyperparams = dict(\n",
453
  " max_dataset_size=\"full\",\n",
454
+ " bert_variety=\"google/bert_uncased_L-2_H-128_A-2\",\n",
455
+ " max_length=256,\n",
456
  " batch_size=16,\n",
457
  ")"
458
  ]
459
  },
460
  {
461
  "cell_type": "code",
462
+ "execution_count": 34,
463
  "id": "28354e8c-886a-4523-8968-8c688c13f6a3",
464
  "metadata": {
465
  "execution": {
466
+ "iopub.execute_input": "2025-01-21T18:42:35.614137Z",
467
+ "iopub.status.busy": "2025-01-21T18:42:35.613694Z",
468
+ "iopub.status.idle": "2025-01-21T18:45:35.341816Z",
469
+ "shell.execute_reply": "2025-01-21T18:45:35.341535Z",
470
+ "shell.execute_reply.started": "2025-01-21T18:42:35.614111Z"
471
  }
472
  },
473
  "outputs": [
 
475
  "name": "stdout",
476
  "output_type": "stream",
477
  "text": [
478
+ "2025-01-21 10:43:44 Epoch 0/15 done. Loss: Train 2.177, Test 2.172; and Acc: Train 0.063, Test 0.071\n",
479
+ "2025-01-21 10:43:52 Epoch 1/15 done. Loss: Train 1.786, Test 1.823; and Acc: Train 0.383, Test 0.354\n",
480
+ "2025-01-21 10:44:00 Epoch 2/15 done. Loss: Train 1.579, Test 1.628; and Acc: Train 0.465, Test 0.436\n",
481
+ "2025-01-21 10:44:07 Epoch 3/15 done. Loss: Train 1.431, Test 1.498; and Acc: Train 0.510, Test 0.484\n",
482
+ "2025-01-21 10:44:14 Epoch 4/15 done. Loss: Train 1.304, Test 1.402; and Acc: Train 0.555, Test 0.515\n",
483
+ "2025-01-21 10:44:22 Epoch 5/15 done. Loss: Train 1.212, Test 1.339; and Acc: Train 0.585, Test 0.535\n",
484
+ "2025-01-21 10:44:29 Epoch 6/15 done. Loss: Train 1.128, Test 1.288; and Acc: Train 0.611, Test 0.546\n",
485
+ "2025-01-21 10:44:36 Epoch 7/15 done. Loss: Train 1.039, Test 1.241; and Acc: Train 0.643, Test 0.559\n",
486
+ "2025-01-21 10:44:44 Epoch 8/15 done. Loss: Train 1.003, Test 1.236; and Acc: Train 0.665, Test 0.555\n",
487
+ "2025-01-21 10:44:51 Epoch 9/15 done. Loss: Train 0.897, Test 1.183; and Acc: Train 0.708, Test 0.568\n",
488
+ "2025-01-21 10:44:58 Epoch 10/15 done. Loss: Train 0.852, Test 1.187; and Acc: Train 0.724, Test 0.572\n",
489
+ "2025-01-21 10:45:06 Epoch 11/15 done. Loss: Train 0.769, Test 1.154; and Acc: Train 0.755, Test 0.581\n",
490
+ "2025-01-21 10:45:13 Epoch 12/15 done. Loss: Train 0.764, Test 1.197; and Acc: Train 0.752, Test 0.573\n",
491
+ "2025-01-21 10:45:20 Epoch 13/15 done. Loss: Train 0.660, Test 1.153; and Acc: Train 0.797, Test 0.590\n",
492
+ "2025-01-21 10:45:28 Epoch 14/15 done. Loss: Train 0.588, Test 1.143; and Acc: Train 0.820, Test 0.594\n",
493
+ "2025-01-21 10:45:35 Epoch 15/15 done. Loss: Train 0.579, Test 1.200; and Acc: Train 0.822, Test 0.575\n"
494
  ]
495
  }
496
  ],
497
  "source": [
498
+ "model, tokenizer, training_regime, testing_metrics = run_training(\n",
499
+ " **static_hyperparams,\n",
500
+ " num_epochs=15,\n",
501
  ")"
502
  ]
503
  },
 
511
  },
512
  {
513
  "cell_type": "code",
514
+ "execution_count": 35,
515
  "id": "ec2516f9-79f2-4ae1-ab9a-9a51a7a50587",
516
  "metadata": {
517
  "execution": {
518
+ "iopub.execute_input": "2025-01-21T18:57:29.278360Z",
519
+ "iopub.status.busy": "2025-01-21T18:57:29.276985Z",
520
+ "iopub.status.idle": "2025-01-21T18:57:29.289810Z",
521
+ "shell.execute_reply": "2025-01-21T18:57:29.288574Z",
522
+ "shell.execute_reply.started": "2025-01-21T18:57:29.278315Z"
523
  },
524
  "scrolled": true
525
  },
526
  "outputs": [
527
  {
528
+ "ename": "SyntaxError",
529
+ "evalue": "invalid syntax. Perhaps you forgot a comma? (3495586751.py, line 4)",
530
+ "output_type": "error",
531
+ "traceback": [
532
+ "\u001b[0;36m Cell \u001b[0;32mIn[35], line 4\u001b[0;36m\u001b[0m\n\u001b[0;31m base_model=static_hyperparams[],\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax. Perhaps you forgot a comma?\n"
533
  ]
534
  }
535
  ],
536
  "source": [
537
+ "model_and_repo_name = \"frugal-ai-text-bert-tiny\"\n",
538
  "card_data = ModelCardData(\n",
539
  " model_name=model_and_repo_name,\n",
540
+ " base_model=static_hyperparams[\"bert_variety\"],\n",
541
  " license=\"apache-2.0\",\n",
542
  " language=[\"en\"],\n",
543
  " datasets=[\"QuotaClimat/frugalaichallenge-text-train\"],\n",
 
625
  " truncation=True,\n",
626
  " padding=True,\n",
627
  " return_tensors=\"pt\",\n",
628
+ " max_length=256,\n",
629
  ")\n",
630
  "\n",
631
  "with torch.no_grad():\n",
 
766
  },
767
  "widgets": {
768
  "application/vnd.jupyter.widget-state+json": {
769
+ "state": {},
770
  "version_major": 2,
771
  "version_minor": 0
772
  }