Commit
·
f718d63
1
Parent(s):
6c2e610
TextDataset bug
Browse files- Finetune BERT.ipynb +73 -60
- tasks/text.py +1 -0
Finetune BERT.ipynb
CHANGED
@@ -10,15 +10,15 @@
|
|
10 |
},
|
11 |
{
|
12 |
"cell_type": "code",
|
13 |
-
"execution_count":
|
14 |
"id": "73e72549-69f2-46b5-b0f5-655777139972",
|
15 |
"metadata": {
|
16 |
"execution": {
|
17 |
-
"iopub.execute_input": "2025-01-24T18:
|
18 |
-
"iopub.status.busy": "2025-01-24T18:
|
19 |
-
"iopub.status.idle": "2025-01-24T18:
|
20 |
-
"shell.execute_reply": "2025-01-24T18:
|
21 |
-
"shell.execute_reply.started": "2025-01-24T18:
|
22 |
}
|
23 |
},
|
24 |
"outputs": [],
|
@@ -45,11 +45,11 @@
|
|
45 |
"id": "07e0787e-c72b-41f3-baba-43cef3f8d6f8",
|
46 |
"metadata": {
|
47 |
"execution": {
|
48 |
-
"iopub.execute_input": "2025-01-
|
49 |
-
"iopub.status.busy": "2025-01-
|
50 |
-
"iopub.status.idle": "2025-01-
|
51 |
-
"shell.execute_reply": "2025-01-
|
52 |
-
"shell.execute_reply.started": "2025-01-
|
53 |
}
|
54 |
},
|
55 |
"outputs": [],
|
@@ -67,15 +67,15 @@
|
|
67 |
},
|
68 |
{
|
69 |
"cell_type": "code",
|
70 |
-
"execution_count":
|
71 |
"id": "d4b79fb9-5e70-4600-8885-94bc0a6e917c",
|
72 |
"metadata": {
|
73 |
"execution": {
|
74 |
-
"iopub.execute_input": "2025-01-
|
75 |
-
"iopub.status.busy": "2025-01-
|
76 |
-
"iopub.status.idle": "2025-01-
|
77 |
-
"shell.execute_reply": "2025-01-
|
78 |
-
"shell.execute_reply.started": "2025-01-
|
79 |
}
|
80 |
},
|
81 |
"outputs": [],
|
@@ -146,6 +146,7 @@
|
|
146 |
"\n",
|
147 |
"class TextDataset(Dataset):\n",
|
148 |
" def __init__(self, texts, labels, tokenizer, max_length=256):\n",
|
|
|
149 |
" self.encodings = tokenizer(\n",
|
150 |
" texts,\n",
|
151 |
" truncation=True,\n",
|
@@ -195,15 +196,15 @@
|
|
195 |
},
|
196 |
{
|
197 |
"cell_type": "code",
|
198 |
-
"execution_count":
|
199 |
"id": "07131bce-23ad-4787-8622-cce401f3e5ce",
|
200 |
"metadata": {
|
201 |
"execution": {
|
202 |
-
"iopub.execute_input": "2025-01-
|
203 |
-
"iopub.status.busy": "2025-01-
|
204 |
-
"iopub.status.idle": "2025-01-
|
205 |
-
"shell.execute_reply": "2025-01-
|
206 |
-
"shell.execute_reply.started": "2025-01-
|
207 |
}
|
208 |
},
|
209 |
"outputs": [],
|
@@ -219,15 +220,15 @@
|
|
219 |
},
|
220 |
{
|
221 |
"cell_type": "code",
|
222 |
-
"execution_count":
|
223 |
"id": "695bc080-bbd7-4937-af5b-50db1c936500",
|
224 |
"metadata": {
|
225 |
"execution": {
|
226 |
-
"iopub.execute_input": "2025-01-
|
227 |
-
"iopub.status.busy": "2025-01-
|
228 |
-
"iopub.status.idle": "2025-01-
|
229 |
-
"shell.execute_reply": "2025-01-
|
230 |
-
"shell.execute_reply.started": "2025-01-
|
231 |
}
|
232 |
},
|
233 |
"outputs": [],
|
@@ -307,15 +308,15 @@
|
|
307 |
},
|
308 |
{
|
309 |
"cell_type": "code",
|
310 |
-
"execution_count":
|
311 |
"id": "11890d3b-8bcb-4a9b-b421-5431081cca39",
|
312 |
"metadata": {
|
313 |
"execution": {
|
314 |
-
"iopub.execute_input": "2025-01-
|
315 |
-
"iopub.status.busy": "2025-01-
|
316 |
-
"iopub.status.idle": "2025-01-
|
317 |
-
"shell.execute_reply": "2025-01-
|
318 |
-
"shell.execute_reply.started": "2025-01-
|
319 |
}
|
320 |
},
|
321 |
"outputs": [],
|
@@ -342,15 +343,15 @@
|
|
342 |
},
|
343 |
{
|
344 |
"cell_type": "code",
|
345 |
-
"execution_count":
|
346 |
"id": "34a7c310-c486-4db1-b94d-4363c3d3df5b",
|
347 |
"metadata": {
|
348 |
"execution": {
|
349 |
-
"iopub.execute_input": "2025-01-
|
350 |
-
"iopub.status.busy": "2025-01-
|
351 |
-
"iopub.status.idle": "2025-01-
|
352 |
-
"shell.execute_reply": "2025-01-
|
353 |
-
"shell.execute_reply.started": "2025-01-
|
354 |
}
|
355 |
},
|
356 |
"outputs": [
|
@@ -358,10 +359,23 @@
|
|
358 |
"name": "stdout",
|
359 |
"output_type": "stream",
|
360 |
"text": [
|
361 |
-
"
|
362 |
-
"
|
363 |
-
|
364 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
365 |
]
|
366 |
}
|
367 |
],
|
@@ -377,31 +391,22 @@
|
|
377 |
},
|
378 |
{
|
379 |
"cell_type": "code",
|
380 |
-
"execution_count":
|
381 |
"id": "0aedfcca-843e-4f4c-8062-3e4625161bcc",
|
382 |
"metadata": {
|
383 |
"editable": true,
|
384 |
"execution": {
|
385 |
-
"iopub.
|
386 |
-
"iopub.status.
|
387 |
-
"
|
388 |
-
"shell.execute_reply": "2025-01-
|
389 |
-
"shell.execute_reply.started": "2025-01-22T18:19:33.994628Z"
|
390 |
},
|
391 |
"slideshow": {
|
392 |
"slide_type": ""
|
393 |
},
|
394 |
"tags": []
|
395 |
},
|
396 |
-
"outputs": [
|
397 |
-
{
|
398 |
-
"name": "stdout",
|
399 |
-
"output_type": "stream",
|
400 |
-
"text": [
|
401 |
-
"2025-01-22 13:19:34 Predictions: tensor([0, 0, 3, 6, 2, 4, 6], device='mps:0')\n"
|
402 |
-
]
|
403 |
-
}
|
404 |
-
],
|
405 |
"source": [
|
406 |
"model.eval()\n",
|
407 |
"test_text = [\n",
|
@@ -429,6 +434,14 @@
|
|
429 |
" my_print(f\"Predictions: {predictions}\")"
|
430 |
]
|
431 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
{
|
433 |
"cell_type": "markdown",
|
434 |
"id": "0c3ea938-dd87-4673-b1d6-f06c70b19455",
|
|
|
10 |
},
|
11 |
{
|
12 |
"cell_type": "code",
|
13 |
+
"execution_count": 1,
|
14 |
"id": "73e72549-69f2-46b5-b0f5-655777139972",
|
15 |
"metadata": {
|
16 |
"execution": {
|
17 |
+
"iopub.execute_input": "2025-01-24T18:21:58.280871Z",
|
18 |
+
"iopub.status.busy": "2025-01-24T18:21:58.280785Z",
|
19 |
+
"iopub.status.idle": "2025-01-24T18:22:01.627392Z",
|
20 |
+
"shell.execute_reply": "2025-01-24T18:22:01.627134Z",
|
21 |
+
"shell.execute_reply.started": "2025-01-24T18:21:58.280861Z"
|
22 |
}
|
23 |
},
|
24 |
"outputs": [],
|
|
|
45 |
"id": "07e0787e-c72b-41f3-baba-43cef3f8d6f8",
|
46 |
"metadata": {
|
47 |
"execution": {
|
48 |
+
"iopub.execute_input": "2025-01-24T18:22:01.628023Z",
|
49 |
+
"iopub.status.busy": "2025-01-24T18:22:01.627838Z",
|
50 |
+
"iopub.status.idle": "2025-01-24T18:22:01.629825Z",
|
51 |
+
"shell.execute_reply": "2025-01-24T18:22:01.629635Z",
|
52 |
+
"shell.execute_reply.started": "2025-01-24T18:22:01.628013Z"
|
53 |
}
|
54 |
},
|
55 |
"outputs": [],
|
|
|
67 |
},
|
68 |
{
|
69 |
"cell_type": "code",
|
70 |
+
"execution_count": 12,
|
71 |
"id": "d4b79fb9-5e70-4600-8885-94bc0a6e917c",
|
72 |
"metadata": {
|
73 |
"execution": {
|
74 |
+
"iopub.execute_input": "2025-01-24T18:23:58.768682Z",
|
75 |
+
"iopub.status.busy": "2025-01-24T18:23:58.768083Z",
|
76 |
+
"iopub.status.idle": "2025-01-24T18:23:58.787548Z",
|
77 |
+
"shell.execute_reply": "2025-01-24T18:23:58.786993Z",
|
78 |
+
"shell.execute_reply.started": "2025-01-24T18:23:58.768631Z"
|
79 |
}
|
80 |
},
|
81 |
"outputs": [],
|
|
|
146 |
"\n",
|
147 |
"class TextDataset(Dataset):\n",
|
148 |
" def __init__(self, texts, labels, tokenizer, max_length=256):\n",
|
149 |
+
" self.texts = texts\n",
|
150 |
" self.encodings = tokenizer(\n",
|
151 |
" texts,\n",
|
152 |
" truncation=True,\n",
|
|
|
196 |
},
|
197 |
{
|
198 |
"cell_type": "code",
|
199 |
+
"execution_count": 13,
|
200 |
"id": "07131bce-23ad-4787-8622-cce401f3e5ce",
|
201 |
"metadata": {
|
202 |
"execution": {
|
203 |
+
"iopub.execute_input": "2025-01-24T18:23:59.127835Z",
|
204 |
+
"iopub.status.busy": "2025-01-24T18:23:59.126787Z",
|
205 |
+
"iopub.status.idle": "2025-01-24T18:23:59.136440Z",
|
206 |
+
"shell.execute_reply": "2025-01-24T18:23:59.135267Z",
|
207 |
+
"shell.execute_reply.started": "2025-01-24T18:23:59.127791Z"
|
208 |
}
|
209 |
},
|
210 |
"outputs": [],
|
|
|
220 |
},
|
221 |
{
|
222 |
"cell_type": "code",
|
223 |
+
"execution_count": 14,
|
224 |
"id": "695bc080-bbd7-4937-af5b-50db1c936500",
|
225 |
"metadata": {
|
226 |
"execution": {
|
227 |
+
"iopub.execute_input": "2025-01-24T18:23:59.442432Z",
|
228 |
+
"iopub.status.busy": "2025-01-24T18:23:59.441786Z",
|
229 |
+
"iopub.status.idle": "2025-01-24T18:23:59.453218Z",
|
230 |
+
"shell.execute_reply": "2025-01-24T18:23:59.452473Z",
|
231 |
+
"shell.execute_reply.started": "2025-01-24T18:23:59.442367Z"
|
232 |
}
|
233 |
},
|
234 |
"outputs": [],
|
|
|
308 |
},
|
309 |
{
|
310 |
"cell_type": "code",
|
311 |
+
"execution_count": 15,
|
312 |
"id": "11890d3b-8bcb-4a9b-b421-5431081cca39",
|
313 |
"metadata": {
|
314 |
"execution": {
|
315 |
+
"iopub.execute_input": "2025-01-24T18:24:00.153856Z",
|
316 |
+
"iopub.status.busy": "2025-01-24T18:24:00.153044Z",
|
317 |
+
"iopub.status.idle": "2025-01-24T18:24:00.158876Z",
|
318 |
+
"shell.execute_reply": "2025-01-24T18:24:00.157762Z",
|
319 |
+
"shell.execute_reply.started": "2025-01-24T18:24:00.153804Z"
|
320 |
}
|
321 |
},
|
322 |
"outputs": [],
|
|
|
343 |
},
|
344 |
{
|
345 |
"cell_type": "code",
|
346 |
+
"execution_count": 16,
|
347 |
"id": "34a7c310-c486-4db1-b94d-4363c3d3df5b",
|
348 |
"metadata": {
|
349 |
"execution": {
|
350 |
+
"iopub.execute_input": "2025-01-24T18:24:00.721937Z",
|
351 |
+
"iopub.status.busy": "2025-01-24T18:24:00.721190Z",
|
352 |
+
"iopub.status.idle": "2025-01-24T18:24:06.157768Z",
|
353 |
+
"shell.execute_reply": "2025-01-24T18:24:06.157299Z",
|
354 |
+
"shell.execute_reply.started": "2025-01-24T18:24:00.721894Z"
|
355 |
}
|
356 |
},
|
357 |
"outputs": [
|
|
|
359 |
"name": "stdout",
|
360 |
"output_type": "stream",
|
361 |
"text": [
|
362 |
+
"4872 1219\n",
|
363 |
+
"8 8\n"
|
364 |
+
]
|
365 |
+
},
|
366 |
+
{
|
367 |
+
"ename": "KeyboardInterrupt",
|
368 |
+
"evalue": "",
|
369 |
+
"output_type": "error",
|
370 |
+
"traceback": [
|
371 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
372 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
373 |
+
"Cell \u001b[0;32mIn[16], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m model, tokenizer, regime, metrics \u001b[38;5;241m=\u001b[39m \u001b[43mrun_training\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_dataset_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m16\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mbert_variety\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbase_model_repo\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m128\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m16\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m)\u001b[49m\n",
|
374 |
+
"Cell \u001b[0;32mIn[14], line 62\u001b[0m, in \u001b[0;36mrun_training\u001b[0;34m(max_dataset_size, bert_variety, max_length, num_epochs, batch_size)\u001b[0m\n\u001b[1;32m 55\u001b[0m dataloader_train \u001b[38;5;241m=\u001b[39m DataLoader(\n\u001b[1;32m 56\u001b[0m text_dataset_train, batch_size\u001b[38;5;241m=\u001b[39mbatch_size, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 57\u001b[0m )\n\u001b[1;32m 58\u001b[0m dataloader_test \u001b[38;5;241m=\u001b[39m DataLoader(\n\u001b[1;32m 59\u001b[0m text_dataset_test, batch_size\u001b[38;5;241m=\u001b[39mbatch_size, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 60\u001b[0m )\n\u001b[0;32m---> 62\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 63\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataloader_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataloader_test\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_epochs\u001b[49m\n\u001b[1;32m 64\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m model, tokenizer, training_regime, metrics\n",
|
375 |
+
"Cell \u001b[0;32mIn[12], line 91\u001b[0m, in \u001b[0;36mtrain_model\u001b[0;34m(model, train_dataloader, test_dataloader, device, num_epochs)\u001b[0m\n\u001b[1;32m 88\u001b[0m criterion \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mCrossEntropyLoss()\n\u001b[1;32m 89\u001b[0m model\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[0;32m---> 91\u001b[0m _ \u001b[38;5;241m=\u001b[39m \u001b[43mprint_model_status\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_dataloader\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(num_epochs):\n\u001b[1;32m 93\u001b[0m total_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n",
|
376 |
+
"Cell \u001b[0;32mIn[12], line 34\u001b[0m, in \u001b[0;36mprint_model_status\u001b[0;34m(epoch, num_epochs, model, train_dataloader, test_dataloader)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprint_model_status\u001b[39m(epoch, num_epochs, model, train_dataloader, test_dataloader):\n\u001b[0;32m---> 34\u001b[0m train_loss, train_acc \u001b[38;5;241m=\u001b[39m \u001b[43mmodel_metrics\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 35\u001b[0m test_loss, test_acc \u001b[38;5;241m=\u001b[39m model_metrics(model, test_dataloader)\n\u001b[1;32m 36\u001b[0m loss_str \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLoss: Train \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_loss\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m0.3f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Test \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtest_loss\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m0.3f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n",
|
377 |
+
"Cell \u001b[0;32mIn[12], line 20\u001b[0m, in \u001b[0;36mmodel_metrics\u001b[0;34m(model, dataloader)\u001b[0m\n\u001b[1;32m 18\u001b[0m outputs \u001b[38;5;241m=\u001b[39m model(input_ids, attention_mask)\n\u001b[1;32m 19\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(outputs, labels)\n\u001b[0;32m---> 20\u001b[0m predictions_cpu \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43margmax\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcpu\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mnumpy()\n\u001b[1;32m 21\u001b[0m labels_cpu \u001b[38;5;241m=\u001b[39m labels\u001b[38;5;241m.\u001b[39mcpu()\u001b[38;5;241m.\u001b[39mnumpy()\n\u001b[1;32m 22\u001b[0m correct_count \u001b[38;5;241m=\u001b[39m (predictions_cpu \u001b[38;5;241m==\u001b[39m labels_cpu)\u001b[38;5;241m.\u001b[39msum()\n",
|
378 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
379 |
]
|
380 |
}
|
381 |
],
|
|
|
391 |
},
|
392 |
{
|
393 |
"cell_type": "code",
|
394 |
+
"execution_count": null,
|
395 |
"id": "0aedfcca-843e-4f4c-8062-3e4625161bcc",
|
396 |
"metadata": {
|
397 |
"editable": true,
|
398 |
"execution": {
|
399 |
+
"iopub.status.busy": "2025-01-24T18:24:06.157956Z",
|
400 |
+
"iopub.status.idle": "2025-01-24T18:24:06.158060Z",
|
401 |
+
"shell.execute_reply": "2025-01-24T18:24:06.158008Z",
|
402 |
+
"shell.execute_reply.started": "2025-01-24T18:24:06.158002Z"
|
|
|
403 |
},
|
404 |
"slideshow": {
|
405 |
"slide_type": ""
|
406 |
},
|
407 |
"tags": []
|
408 |
},
|
409 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
"source": [
|
411 |
"model.eval()\n",
|
412 |
"test_text = [\n",
|
|
|
434 |
" my_print(f\"Predictions: {predictions}\")"
|
435 |
]
|
436 |
},
|
437 |
+
{
|
438 |
+
"cell_type": "code",
|
439 |
+
"execution_count": null,
|
440 |
+
"id": "1201bf29-5040-4317-be30-77bec0bfe5b4",
|
441 |
+
"metadata": {},
|
442 |
+
"outputs": [],
|
443 |
+
"source": []
|
444 |
+
},
|
445 |
{
|
446 |
"cell_type": "markdown",
|
447 |
"id": "0c3ea938-dd87-4673-b1d6-f06c70b19455",
|
tasks/text.py
CHANGED
@@ -27,6 +27,7 @@ ROUTE = "/text"
|
|
27 |
|
28 |
class TextDataset(Dataset):
|
29 |
def __init__(self, texts, tokenizer, max_length=256):
|
|
|
30 |
self.encodings = tokenizer(
|
31 |
texts,
|
32 |
truncation=True,
|
|
|
27 |
|
28 |
class TextDataset(Dataset):
|
29 |
def __init__(self, texts, tokenizer, max_length=256):
|
30 |
+
self.texts = texts
|
31 |
self.encodings = tokenizer(
|
32 |
texts,
|
33 |
truncation=True,
|