ketanmore commited on
Commit
51ed342
·
verified ·
1 Parent(s): e831476

Delete temp_test.ipynb

Browse files
Files changed (1) hide show
  1. temp_test.ipynb +0 -1642
temp_test.ipynb DELETED
@@ -1,1642 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "import os\n",
10
- "os.environ['HF_HOME'] = '/data2/ketan/orc/HF_Cache'\n",
11
- "\n",
12
- "import torch\n",
13
- "import torch.nn as nn\n",
14
- "import torch.optim as optim\n",
15
- "from torch.utils.data import DataLoader\n",
16
- "from transformers import SegformerConfig\n",
17
- "from surya.model.detection.segformer import SegformerForRegressionMask\n",
18
- "from surya.input.processing import prepare_image_detection\n",
19
- "from surya.model.detection.segformer import load_processor , load_model\n",
20
- "from datasets import load_dataset\n",
21
- "from tqdm import tqdm\n",
22
- "from torch.utils.tensorboard import SummaryWriter\n",
23
- "\n",
24
- "import torch.nn.functional as F\n",
25
- "import numpy as np \n",
26
- "from PIL import ImageDraw, ImageFont\n",
27
- "from surya.layout import parallel_get_regions\n",
28
- "import cv2"
29
- ]
30
- },
31
- {
32
- "cell_type": "code",
33
- "execution_count": 2,
34
- "metadata": {},
35
- "outputs": [
36
- {
37
- "name": "stdout",
38
- "output_type": "stream",
39
- "text": [
40
- "Loaded detection model vikp/surya_layout2 on device cuda with dtype torch.float16\n"
41
- ]
42
- }
43
- ],
44
- "source": [
45
- "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
46
- "\n",
47
- "\n",
48
- "dataset = load_dataset(\"vikp/publaynet_bench\", split=\"train[:100]\")\n",
49
- "train_loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x)\n",
50
- "\n",
51
- "\n",
52
- "model = load_model(\"vikp/surya_layout2\")"
53
- ]
54
- },
55
- {
56
- "cell_type": "code",
57
- "execution_count": 17,
58
- "metadata": {},
59
- "outputs": [],
60
- "source": [
61
- "\n",
62
- "optimizer = optim.Adam(model.parameters(), lr=0.00001)\n",
63
- "\n",
64
- "# Logging and Checkpoints\n",
65
- "log_dir = \"logs\"\n",
66
- "checkpoint_dir = \"checkpoints\"\n",
67
- "os.makedirs(log_dir, exist_ok=True)\n",
68
- "os.makedirs(checkpoint_dir, exist_ok=True)\n",
69
- "writer = SummaryWriter(log_dir=log_dir)\n",
70
- "\n",
71
- "def calculate_iou(box1, box2):\n",
72
- " box1 = torch.tensor(box1, dtype=torch.float32, requires_grad=True) if not isinstance(box1, torch.Tensor) else box1\n",
73
- " box2 = torch.tensor(box2, dtype=torch.float32, requires_grad=True) if not isinstance(box2, torch.Tensor) else box2\n",
74
- " \n",
75
- " x_min = torch.max(box1[0], box2[0])\n",
76
- " y_min = torch.max(box1[1], box2[1])\n",
77
- " x_max = torch.min(box1[2], box2[2])\n",
78
- " y_max = torch.min(box1[3], box2[3])\n",
79
- " \n",
80
- " intersection = torch.clamp(x_max - x_min, min=0) * torch.clamp(y_max - y_min, min=0)\n",
81
- " area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])\n",
82
- " area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])\n",
83
- " union = area1 + area2 - intersection\n",
84
- " \n",
85
- " iou = intersection / union if union > 0 else torch.tensor(0.0, requires_grad=True)\n",
86
- " \n",
87
- " return iou\n",
88
- "\n",
89
- "def pair_boxes(pred_boxes, target_boxes):\n",
90
- " pred_boxes = [torch.tensor(box, dtype=torch.float32, requires_grad=True) for box in pred_boxes]\n",
91
- " target_boxes = [torch.tensor(box, dtype=torch.float32, requires_grad=True) for box in target_boxes]\n",
92
- " \n",
93
- " matched_pred_boxes = []\n",
94
- " matched_target_boxes = []\n",
95
- " \n",
96
- " for target in target_boxes:\n",
97
- " best_iou = 0\n",
98
- " best_pred = None\n",
99
- " for pred in pred_boxes:\n",
100
- " iou = calculate_iou(pred, target)\n",
101
- " if iou > best_iou:\n",
102
- " best_iou = iou\n",
103
- " best_pred = pred\n",
104
- " \n",
105
- " if best_pred is not None:\n",
106
- " matched_pred_boxes.append(best_pred)\n",
107
- " matched_target_boxes.append(target)\n",
108
- " pred_boxes = [p for p in pred_boxes if not torch.equal(p, best_pred)]\n",
109
- "\n",
110
- " return matched_pred_boxes, matched_target_boxes\n",
111
- "\n",
112
- "def smooth_l1_loss(pred_boxes, target_boxes, beta=1.0):\n",
113
- " matched_pred_boxes, matched_target_boxes = pair_boxes(pred_boxes, target_boxes)\n",
114
- " \n",
115
- " if len(matched_pred_boxes) == 0:\n",
116
- " return torch.tensor(0.0, requires_grad=True)\n",
117
- " \n",
118
- " diff = torch.abs(torch.stack(matched_pred_boxes) - torch.stack(matched_target_boxes))\n",
119
- " loss = torch.where(diff < beta, 0.5 * (diff ** 2) / beta, diff - 0.5 * beta)\n",
120
- " return loss.mean()\n",
121
- "\n",
122
- "def iou_loss(pred_boxes, target_boxes):\n",
123
- " matched_pred_boxes, matched_target_boxes = pair_boxes(pred_boxes, target_boxes)\n",
124
- " \n",
125
- " if len(matched_pred_boxes) == 0:\n",
126
- " return torch.tensor(1.0, requires_grad=True)\n",
127
- " \n",
128
- " ious = [calculate_iou(pred, target) for pred, target in zip(matched_pred_boxes, matched_target_boxes)]\n",
129
- " return 1 - torch.mean(torch.tensor(ious, requires_grad=True))\n",
130
- "\n",
131
- "\n",
132
- "def logits_to_bboxes(logits,image) :\n",
133
- " correct_shape = (300, 300) \n",
134
- " logits_temp = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False)\n",
135
- " logits_temp = logits_temp.cpu().detach().numpy().astype(np.float32)\n",
136
- "\n",
137
- " heatmap_count = logits_temp.shape[1]\n",
138
- " heatmaps = [logits_temp[i][k] for i in range(logits_temp.shape[0]) for k in range(heatmap_count)]\n",
139
- " regions = parallel_get_regions(heatmaps=heatmaps, orig_size=image.size, id2label=model.config.id2label)\n",
140
- "\n",
141
- " final_bboxes = []\n",
142
- " for i in regions.bboxes :\n",
143
- " final_bboxes.append(i.bbox)\n",
144
- " return final_bboxes\n"
145
- ]
146
- },
147
- {
148
- "cell_type": "code",
149
- "execution_count": 18,
150
- "metadata": {},
151
- "outputs": [
152
- {
153
- "name": "stderr",
154
- "output_type": "stream",
155
- "text": [
156
- "Epoch 1/1: 1%| | 1/100 [00:00<00:41, 2.41it/s]"
157
- ]
158
- },
159
- {
160
- "name": "stdout",
161
- "output_type": "stream",
162
- "text": [
163
- "Epoch [1/1], Step [1/100], Moving Avg Loss: 16.4575\n"
164
- ]
165
- },
166
- {
167
- "name": "stderr",
168
- "output_type": "stream",
169
- "text": [
170
- "Epoch 1/1: 2%|▏ | 2/100 [00:00<00:41, 2.34it/s]"
171
- ]
172
- },
173
- {
174
- "name": "stdout",
175
- "output_type": "stream",
176
- "text": [
177
- "Epoch [1/1], Step [2/100], Moving Avg Loss: 15.4938\n"
178
- ]
179
- },
180
- {
181
- "name": "stderr",
182
- "output_type": "stream",
183
- "text": [
184
- "Epoch 1/1: 3%|▎ | 3/100 [00:01<00:40, 2.42it/s]"
185
- ]
186
- },
187
- {
188
- "name": "stdout",
189
- "output_type": "stream",
190
- "text": [
191
- "Epoch [1/1], Step [3/100], Moving Avg Loss: 18.9512\n"
192
- ]
193
- },
194
- {
195
- "name": "stderr",
196
- "output_type": "stream",
197
- "text": [
198
- "Epoch 1/1: 4%|▍ | 4/100 [00:01<00:39, 2.40it/s]"
199
- ]
200
- },
201
- {
202
- "name": "stdout",
203
- "output_type": "stream",
204
- "text": [
205
- "Epoch [1/1], Step [4/100], Moving Avg Loss: 18.3995\n"
206
- ]
207
- },
208
- {
209
- "name": "stderr",
210
- "output_type": "stream",
211
- "text": [
212
- "Epoch 1/1: 5%|▌ | 5/100 [00:02<00:39, 2.43it/s]"
213
- ]
214
- },
215
- {
216
- "name": "stdout",
217
- "output_type": "stream",
218
- "text": [
219
- "Epoch [1/1], Step [5/100], Moving Avg Loss: 20.1250\n"
220
- ]
221
- },
222
- {
223
- "name": "stderr",
224
- "output_type": "stream",
225
- "text": [
226
- "Epoch 1/1: 6%|▌ | 6/100 [00:02<00:39, 2.38it/s]"
227
- ]
228
- },
229
- {
230
- "name": "stdout",
231
- "output_type": "stream",
232
- "text": [
233
- "Epoch [1/1], Step [6/100], Moving Avg Loss: 18.9854\n"
234
- ]
235
- },
236
- {
237
- "name": "stderr",
238
- "output_type": "stream",
239
- "text": [
240
- "Epoch 1/1: 7%|▋ | 7/100 [00:02<00:38, 2.40it/s]"
241
- ]
242
- },
243
- {
244
- "name": "stdout",
245
- "output_type": "stream",
246
- "text": [
247
- "Epoch [1/1], Step [7/100], Moving Avg Loss: 18.4753\n"
248
- ]
249
- },
250
- {
251
- "name": "stderr",
252
- "output_type": "stream",
253
- "text": [
254
- "Epoch 1/1: 8%|▊ | 8/100 [00:03<00:40, 2.27it/s]"
255
- ]
256
- },
257
- {
258
- "name": "stdout",
259
- "output_type": "stream",
260
- "text": [
261
- "Epoch [1/1], Step [8/100], Moving Avg Loss: 17.0382\n"
262
- ]
263
- },
264
- {
265
- "name": "stderr",
266
- "output_type": "stream",
267
- "text": [
268
- "Epoch 1/1: 9%|▉ | 9/100 [00:03<00:41, 2.20it/s]"
269
- ]
270
- },
271
- {
272
- "name": "stdout",
273
- "output_type": "stream",
274
- "text": [
275
- "Epoch [1/1], Step [9/100], Moving Avg Loss: 17.7276\n"
276
- ]
277
- },
278
- {
279
- "name": "stderr",
280
- "output_type": "stream",
281
- "text": [
282
- "Epoch 1/1: 10%|█ | 10/100 [00:04<00:40, 2.23it/s]"
283
- ]
284
- },
285
- {
286
- "name": "stdout",
287
- "output_type": "stream",
288
- "text": [
289
- "Epoch [1/1], Step [10/100], Moving Avg Loss: 19.5423\n"
290
- ]
291
- },
292
- {
293
- "name": "stderr",
294
- "output_type": "stream",
295
- "text": [
296
- "Epoch 1/1: 11%|█ | 11/100 [00:04<00:39, 2.27it/s]"
297
- ]
298
- },
299
- {
300
- "name": "stdout",
301
- "output_type": "stream",
302
- "text": [
303
- "Epoch [1/1], Step [11/100], Moving Avg Loss: 18.4347\n"
304
- ]
305
- },
306
- {
307
- "name": "stderr",
308
- "output_type": "stream",
309
- "text": [
310
- "Epoch 1/1: 12%|█▏ | 12/100 [00:05<00:40, 2.20it/s]"
311
- ]
312
- },
313
- {
314
- "name": "stdout",
315
- "output_type": "stream",
316
- "text": [
317
- "Epoch [1/1], Step [12/100], Moving Avg Loss: 17.3114\n"
318
- ]
319
- },
320
- {
321
- "name": "stderr",
322
- "output_type": "stream",
323
- "text": [
324
- "Epoch 1/1: 13%|█▎ | 13/100 [00:05<00:39, 2.21it/s]"
325
- ]
326
- },
327
- {
328
- "name": "stdout",
329
- "output_type": "stream",
330
- "text": [
331
- "Epoch [1/1], Step [13/100], Moving Avg Loss: 16.2870\n"
332
- ]
333
- },
334
- {
335
- "name": "stderr",
336
- "output_type": "stream",
337
- "text": [
338
- "Epoch 1/1: 14%|█▍ | 14/100 [00:06<00:37, 2.29it/s]"
339
- ]
340
- },
341
- {
342
- "name": "stdout",
343
- "output_type": "stream",
344
- "text": [
345
- "Epoch [1/1], Step [14/100], Moving Avg Loss: 21.5170\n"
346
- ]
347
- },
348
- {
349
- "name": "stderr",
350
- "output_type": "stream",
351
- "text": [
352
- "Epoch 1/1: 15%|█▌ | 15/100 [00:06<00:36, 2.30it/s]"
353
- ]
354
- },
355
- {
356
- "name": "stdout",
357
- "output_type": "stream",
358
- "text": [
359
- "Epoch [1/1], Step [15/100], Moving Avg Loss: 21.3559\n"
360
- ]
361
- },
362
- {
363
- "name": "stderr",
364
- "output_type": "stream",
365
- "text": [
366
- "Epoch 1/1: 16%|█▌ | 16/100 [00:06<00:36, 2.29it/s]"
367
- ]
368
- },
369
- {
370
- "name": "stdout",
371
- "output_type": "stream",
372
- "text": [
373
- "Epoch [1/1], Step [16/100], Moving Avg Loss: 19.8276\n"
374
- ]
375
- },
376
- {
377
- "name": "stderr",
378
- "output_type": "stream",
379
- "text": [
380
- "Epoch 1/1: 17%|█▋ | 17/100 [00:07<00:35, 2.36it/s]"
381
- ]
382
- },
383
- {
384
- "name": "stdout",
385
- "output_type": "stream",
386
- "text": [
387
- "Epoch [1/1], Step [17/100], Moving Avg Loss: 18.9123\n"
388
- ]
389
- },
390
- {
391
- "name": "stderr",
392
- "output_type": "stream",
393
- "text": [
394
- "Epoch 1/1: 18%|█▊ | 18/100 [00:07<00:34, 2.38it/s]"
395
- ]
396
- },
397
- {
398
- "name": "stdout",
399
- "output_type": "stream",
400
- "text": [
401
- "Epoch [1/1], Step [18/100], Moving Avg Loss: 23.7418\n"
402
- ]
403
- },
404
- {
405
- "name": "stderr",
406
- "output_type": "stream",
407
- "text": [
408
- "Epoch 1/1: 19%|█▉ | 19/100 [00:08<00:34, 2.36it/s]"
409
- ]
410
- },
411
- {
412
- "name": "stdout",
413
- "output_type": "stream",
414
- "text": [
415
- "Epoch [1/1], Step [19/100], Moving Avg Loss: 22.2312\n"
416
- ]
417
- },
418
- {
419
- "name": "stderr",
420
- "output_type": "stream",
421
- "text": [
422
- "Epoch 1/1: 20%|██ | 20/100 [00:08<00:34, 2.35it/s]"
423
- ]
424
- },
425
- {
426
- "name": "stdout",
427
- "output_type": "stream",
428
- "text": [
429
- "Epoch [1/1], Step [20/100], Moving Avg Loss: 21.1758\n"
430
- ]
431
- },
432
- {
433
- "name": "stderr",
434
- "output_type": "stream",
435
- "text": [
436
- "Epoch 1/1: 21%|██ | 21/100 [00:09<00:32, 2.40it/s]"
437
- ]
438
- },
439
- {
440
- "name": "stdout",
441
- "output_type": "stream",
442
- "text": [
443
- "Epoch [1/1], Step [21/100], Moving Avg Loss: 24.8048\n"
444
- ]
445
- },
446
- {
447
- "name": "stderr",
448
- "output_type": "stream",
449
- "text": [
450
- "Epoch 1/1: 22%|██▏ | 22/100 [00:09<00:32, 2.40it/s]"
451
- ]
452
- },
453
- {
454
- "name": "stdout",
455
- "output_type": "stream",
456
- "text": [
457
- "Epoch [1/1], Step [22/100], Moving Avg Loss: 27.5316\n"
458
- ]
459
- },
460
- {
461
- "name": "stderr",
462
- "output_type": "stream",
463
- "text": [
464
- "Epoch 1/1: 23%|██▎ | 23/100 [00:09<00:32, 2.39it/s]"
465
- ]
466
- },
467
- {
468
- "name": "stdout",
469
- "output_type": "stream",
470
- "text": [
471
- "Epoch [1/1], Step [23/100], Moving Avg Loss: 27.4807\n"
472
- ]
473
- },
474
- {
475
- "name": "stderr",
476
- "output_type": "stream",
477
- "text": [
478
- "Epoch 1/1: 24%|██▍ | 24/100 [00:10<00:31, 2.41it/s]"
479
- ]
480
- },
481
- {
482
- "name": "stdout",
483
- "output_type": "stream",
484
- "text": [
485
- "Epoch [1/1], Step [24/100], Moving Avg Loss: 25.2076\n"
486
- ]
487
- },
488
- {
489
- "name": "stderr",
490
- "output_type": "stream",
491
- "text": [
492
- "Epoch 1/1: 25%|██▌ | 25/100 [00:10<00:30, 2.42it/s]"
493
- ]
494
- },
495
- {
496
- "name": "stdout",
497
- "output_type": "stream",
498
- "text": [
499
- "Epoch [1/1], Step [25/100], Moving Avg Loss: 23.2897\n"
500
- ]
501
- },
502
- {
503
- "name": "stderr",
504
- "output_type": "stream",
505
- "text": [
506
- "Epoch 1/1: 26%|██▌ | 26/100 [00:11<00:30, 2.40it/s]"
507
- ]
508
- },
509
- {
510
- "name": "stdout",
511
- "output_type": "stream",
512
- "text": [
513
- "Epoch [1/1], Step [26/100], Moving Avg Loss: 22.1549\n"
514
- ]
515
- },
516
- {
517
- "name": "stderr",
518
- "output_type": "stream",
519
- "text": [
520
- "Epoch 1/1: 27%|██▋ | 27/100 [00:11<00:29, 2.43it/s]"
521
- ]
522
- },
523
- {
524
- "name": "stdout",
525
- "output_type": "stream",
526
- "text": [
527
- "Epoch [1/1], Step [27/100], Moving Avg Loss: 21.9602\n"
528
- ]
529
- },
530
- {
531
- "name": "stderr",
532
- "output_type": "stream",
533
- "text": [
534
- "Epoch 1/1: 28%|██▊ | 28/100 [00:11<00:29, 2.43it/s]"
535
- ]
536
- },
537
- {
538
- "name": "stdout",
539
- "output_type": "stream",
540
- "text": [
541
- "Epoch [1/1], Step [28/100], Moving Avg Loss: 23.2106\n"
542
- ]
543
- },
544
- {
545
- "name": "stderr",
546
- "output_type": "stream",
547
- "text": [
548
- "Epoch 1/1: 29%|██▉ | 29/100 [00:12<00:30, 2.33it/s]"
549
- ]
550
- },
551
- {
552
- "name": "stdout",
553
- "output_type": "stream",
554
- "text": [
555
- "Epoch [1/1], Step [29/100], Moving Avg Loss: 21.3036\n"
556
- ]
557
- },
558
- {
559
- "name": "stderr",
560
- "output_type": "stream",
561
- "text": [
562
- "Epoch 1/1: 30%|███ | 30/100 [00:12<00:30, 2.31it/s]"
563
- ]
564
- },
565
- {
566
- "name": "stdout",
567
- "output_type": "stream",
568
- "text": [
569
- "Epoch [1/1], Step [30/100], Moving Avg Loss: 22.1421\n"
570
- ]
571
- },
572
- {
573
- "name": "stderr",
574
- "output_type": "stream",
575
- "text": [
576
- "Epoch 1/1: 31%|███ | 31/100 [00:13<00:29, 2.35it/s]"
577
- ]
578
- },
579
- {
580
- "name": "stdout",
581
- "output_type": "stream",
582
- "text": [
583
- "Epoch [1/1], Step [31/100], Moving Avg Loss: 27.1543\n"
584
- ]
585
- },
586
- {
587
- "name": "stderr",
588
- "output_type": "stream",
589
- "text": [
590
- "Epoch 1/1: 32%|███▏ | 32/100 [00:13<00:29, 2.34it/s]"
591
- ]
592
- },
593
- {
594
- "name": "stdout",
595
- "output_type": "stream",
596
- "text": [
597
- "Epoch [1/1], Step [32/100], Moving Avg Loss: 27.6630\n"
598
- ]
599
- },
600
- {
601
- "name": "stderr",
602
- "output_type": "stream",
603
- "text": [
604
- "Epoch 1/1: 33%|███▎ | 33/100 [00:14<00:28, 2.39it/s]"
605
- ]
606
- },
607
- {
608
- "name": "stdout",
609
- "output_type": "stream",
610
- "text": [
611
- "Epoch [1/1], Step [33/100], Moving Avg Loss: 25.8453\n"
612
- ]
613
- },
614
- {
615
- "name": "stderr",
616
- "output_type": "stream",
617
- "text": [
618
- "Epoch 1/1: 34%|███▍ | 34/100 [00:14<00:27, 2.41it/s]"
619
- ]
620
- },
621
- {
622
- "name": "stdout",
623
- "output_type": "stream",
624
- "text": [
625
- "Epoch [1/1], Step [34/100], Moving Avg Loss: 27.6460\n"
626
- ]
627
- },
628
- {
629
- "name": "stderr",
630
- "output_type": "stream",
631
- "text": [
632
- "Epoch 1/1: 35%|███▌ | 35/100 [00:14<00:26, 2.44it/s]"
633
- ]
634
- },
635
- {
636
- "name": "stdout",
637
- "output_type": "stream",
638
- "text": [
639
- "Epoch [1/1], Step [35/100], Moving Avg Loss: 25.1319\n"
640
- ]
641
- },
642
- {
643
- "name": "stderr",
644
- "output_type": "stream",
645
- "text": [
646
- "Epoch 1/1: 36%|███▌ | 36/100 [00:15<00:26, 2.45it/s]"
647
- ]
648
- },
649
- {
650
- "name": "stdout",
651
- "output_type": "stream",
652
- "text": [
653
- "Epoch [1/1], Step [36/100], Moving Avg Loss: 25.8555\n"
654
- ]
655
- },
656
- {
657
- "name": "stderr",
658
- "output_type": "stream",
659
- "text": [
660
- "Epoch 1/1: 37%|███▋ | 37/100 [00:15<00:25, 2.45it/s]"
661
- ]
662
- },
663
- {
664
- "name": "stdout",
665
- "output_type": "stream",
666
- "text": [
667
- "Epoch [1/1], Step [37/100], Moving Avg Loss: 28.9348\n"
668
- ]
669
- },
670
- {
671
- "name": "stderr",
672
- "output_type": "stream",
673
- "text": [
674
- "Epoch 1/1: 38%|███▊ | 38/100 [00:16<00:25, 2.42it/s]"
675
- ]
676
- },
677
- {
678
- "name": "stdout",
679
- "output_type": "stream",
680
- "text": [
681
- "Epoch [1/1], Step [38/100], Moving Avg Loss: 28.3364\n"
682
- ]
683
- },
684
- {
685
- "name": "stderr",
686
- "output_type": "stream",
687
- "text": [
688
- "Epoch 1/1: 39%|███▉ | 39/100 [00:16<00:24, 2.45it/s]"
689
- ]
690
- },
691
- {
692
- "name": "stdout",
693
- "output_type": "stream",
694
- "text": [
695
- "Epoch [1/1], Step [39/100], Moving Avg Loss: 26.0808\n"
696
- ]
697
- },
698
- {
699
- "name": "stderr",
700
- "output_type": "stream",
701
- "text": [
702
- "Epoch 1/1: 40%|████ | 40/100 [00:16<00:24, 2.43it/s]"
703
- ]
704
- },
705
- {
706
- "name": "stdout",
707
- "output_type": "stream",
708
- "text": [
709
- "Epoch [1/1], Step [40/100], Moving Avg Loss: 26.2237\n"
710
- ]
711
- },
712
- {
713
- "name": "stderr",
714
- "output_type": "stream",
715
- "text": [
716
- "Epoch 1/1: 41%|████ | 41/100 [00:17<00:24, 2.43it/s]"
717
- ]
718
- },
719
- {
720
- "name": "stdout",
721
- "output_type": "stream",
722
- "text": [
723
- "Epoch [1/1], Step [41/100], Moving Avg Loss: 26.2728\n"
724
- ]
725
- },
726
- {
727
- "name": "stderr",
728
- "output_type": "stream",
729
- "text": [
730
- "Epoch 1/1: 42%|████▏ | 42/100 [00:17<00:23, 2.45it/s]"
731
- ]
732
- },
733
- {
734
- "name": "stdout",
735
- "output_type": "stream",
736
- "text": [
737
- "Epoch [1/1], Step [42/100], Moving Avg Loss: 26.7065\n"
738
- ]
739
- },
740
- {
741
- "name": "stderr",
742
- "output_type": "stream",
743
- "text": [
744
- "Epoch 1/1: 43%|████▎ | 43/100 [00:18<00:23, 2.39it/s]"
745
- ]
746
- },
747
- {
748
- "name": "stdout",
749
- "output_type": "stream",
750
- "text": [
751
- "Epoch [1/1], Step [43/100], Moving Avg Loss: 24.7438\n"
752
- ]
753
- },
754
- {
755
- "name": "stderr",
756
- "output_type": "stream",
757
- "text": [
758
- "Epoch 1/1: 44%|████▍ | 44/100 [00:18<00:23, 2.41it/s]"
759
- ]
760
- },
761
- {
762
- "name": "stdout",
763
- "output_type": "stream",
764
- "text": [
765
- "Epoch [1/1], Step [44/100], Moving Avg Loss: 26.6885\n"
766
- ]
767
- },
768
- {
769
- "name": "stderr",
770
- "output_type": "stream",
771
- "text": [
772
- "Epoch 1/1: 45%|████▌ | 45/100 [00:18<00:22, 2.41it/s]"
773
- ]
774
- },
775
- {
776
- "name": "stdout",
777
- "output_type": "stream",
778
- "text": [
779
- "Epoch [1/1], Step [45/100], Moving Avg Loss: 27.7764\n"
780
- ]
781
- },
782
- {
783
- "name": "stderr",
784
- "output_type": "stream",
785
- "text": [
786
- "Epoch 1/1: 46%|████▌ | 46/100 [00:19<00:22, 2.44it/s]"
787
- ]
788
- },
789
- {
790
- "name": "stdout",
791
- "output_type": "stream",
792
- "text": [
793
- "Epoch [1/1], Step [46/100], Moving Avg Loss: 25.7708\n"
794
- ]
795
- },
796
- {
797
- "name": "stderr",
798
- "output_type": "stream",
799
- "text": [
800
- "Epoch 1/1: 47%|████▋ | 47/100 [00:19<00:22, 2.36it/s]"
801
- ]
802
- },
803
- {
804
- "name": "stdout",
805
- "output_type": "stream",
806
- "text": [
807
- "Epoch [1/1], Step [47/100], Moving Avg Loss: 23.6295\n"
808
- ]
809
- },
810
- {
811
- "name": "stderr",
812
- "output_type": "stream",
813
- "text": [
814
- "Epoch 1/1: 48%|████▊ | 48/100 [00:20<00:21, 2.40it/s]"
815
- ]
816
- },
817
- {
818
- "name": "stdout",
819
- "output_type": "stream",
820
- "text": [
821
- "Epoch [1/1], Step [48/100], Moving Avg Loss: 23.5793\n"
822
- ]
823
- },
824
- {
825
- "name": "stderr",
826
- "output_type": "stream",
827
- "text": [
828
- "Epoch 1/1: 49%|████▉ | 49/100 [00:20<00:20, 2.44it/s]"
829
- ]
830
- },
831
- {
832
- "name": "stdout",
833
- "output_type": "stream",
834
- "text": [
835
- "Epoch [1/1], Step [49/100], Moving Avg Loss: 22.1319\n"
836
- ]
837
- },
838
- {
839
- "name": "stderr",
840
- "output_type": "stream",
841
- "text": [
842
- "Epoch 1/1: 50%|█████ | 50/100 [00:21<00:20, 2.46it/s]"
843
- ]
844
- },
845
- {
846
- "name": "stdout",
847
- "output_type": "stream",
848
- "text": [
849
- "Epoch [1/1], Step [50/100], Moving Avg Loss: 21.5178\n"
850
- ]
851
- },
852
- {
853
- "name": "stderr",
854
- "output_type": "stream",
855
- "text": [
856
- "Epoch 1/1: 51%|█████ | 51/100 [00:21<00:19, 2.48it/s]"
857
- ]
858
- },
859
- {
860
- "name": "stdout",
861
- "output_type": "stream",
862
- "text": [
863
- "Epoch [1/1], Step [51/100], Moving Avg Loss: 22.2565\n"
864
- ]
865
- },
866
- {
867
- "name": "stderr",
868
- "output_type": "stream",
869
- "text": [
870
- "Epoch 1/1: 52%|█████▏ | 52/100 [00:21<00:19, 2.45it/s]"
871
- ]
872
- },
873
- {
874
- "name": "stdout",
875
- "output_type": "stream",
876
- "text": [
877
- "Epoch [1/1], Step [52/100], Moving Avg Loss: 24.8366\n"
878
- ]
879
- },
880
- {
881
- "name": "stderr",
882
- "output_type": "stream",
883
- "text": [
884
- "Epoch 1/1: 53%|█████▎ | 53/100 [00:22<00:19, 2.35it/s]"
885
- ]
886
- },
887
- {
888
- "name": "stdout",
889
- "output_type": "stream",
890
- "text": [
891
- "Epoch [1/1], Step [53/100], Moving Avg Loss: 23.3091\n"
892
- ]
893
- },
894
- {
895
- "name": "stderr",
896
- "output_type": "stream",
897
- "text": [
898
- "Epoch 1/1: 54%|█████▍ | 54/100 [00:22<00:19, 2.35it/s]"
899
- ]
900
- },
901
- {
902
- "name": "stdout",
903
- "output_type": "stream",
904
- "text": [
905
- "Epoch [1/1], Step [54/100], Moving Avg Loss: 22.1764\n"
906
- ]
907
- },
908
- {
909
- "name": "stderr",
910
- "output_type": "stream",
911
- "text": [
912
- "Epoch 1/1: 55%|█████▌ | 55/100 [00:23<00:19, 2.32it/s]"
913
- ]
914
- },
915
- {
916
- "name": "stdout",
917
- "output_type": "stream",
918
- "text": [
919
- "Epoch [1/1], Step [55/100], Moving Avg Loss: 22.5117\n"
920
- ]
921
- },
922
- {
923
- "name": "stderr",
924
- "output_type": "stream",
925
- "text": [
926
- "Epoch 1/1: 56%|█████▌ | 56/100 [00:23<00:18, 2.34it/s]"
927
- ]
928
- },
929
- {
930
- "name": "stdout",
931
- "output_type": "stream",
932
- "text": [
933
- "Epoch [1/1], Step [56/100], Moving Avg Loss: 23.7047\n"
934
- ]
935
- },
936
- {
937
- "name": "stderr",
938
- "output_type": "stream",
939
- "text": [
940
- "Epoch 1/1: 57%|█████▋ | 57/100 [00:24<00:18, 2.38it/s]"
941
- ]
942
- },
943
- {
944
- "name": "stdout",
945
- "output_type": "stream",
946
- "text": [
947
- "Epoch [1/1], Step [57/100], Moving Avg Loss: 24.7985\n"
948
- ]
949
- },
950
- {
951
- "name": "stderr",
952
- "output_type": "stream",
953
- "text": [
954
- "Epoch 1/1: 58%|█████▊ | 58/100 [00:24<00:17, 2.37it/s]"
955
- ]
956
- },
957
- {
958
- "name": "stdout",
959
- "output_type": "stream",
960
- "text": [
961
- "Epoch [1/1], Step [58/100], Moving Avg Loss: 25.7531\n"
962
- ]
963
- },
964
- {
965
- "name": "stderr",
966
- "output_type": "stream",
967
- "text": [
968
- "Epoch 1/1: 59%|█████▉ | 59/100 [00:24<00:17, 2.36it/s]"
969
- ]
970
- },
971
- {
972
- "name": "stdout",
973
- "output_type": "stream",
974
- "text": [
975
- "Epoch [1/1], Step [59/100], Moving Avg Loss: 24.8322\n"
976
- ]
977
- },
978
- {
979
- "name": "stderr",
980
- "output_type": "stream",
981
- "text": [
982
- "Epoch 1/1: 60%|██████ | 60/100 [00:25<00:17, 2.32it/s]"
983
- ]
984
- },
985
- {
986
- "name": "stdout",
987
- "output_type": "stream",
988
- "text": [
989
- "Epoch [1/1], Step [60/100], Moving Avg Loss: 24.5820\n"
990
- ]
991
- },
992
- {
993
- "name": "stderr",
994
- "output_type": "stream",
995
- "text": [
996
- "Epoch 1/1: 61%|██████ | 61/100 [00:25<00:16, 2.37it/s]"
997
- ]
998
- },
999
- {
1000
- "name": "stdout",
1001
- "output_type": "stream",
1002
- "text": [
1003
- "Epoch [1/1], Step [61/100], Moving Avg Loss: 29.7474\n"
1004
- ]
1005
- },
1006
- {
1007
- "name": "stderr",
1008
- "output_type": "stream",
1009
- "text": [
1010
- "Epoch 1/1: 62%|██████▏ | 62/100 [00:26<00:15, 2.39it/s]"
1011
- ]
1012
- },
1013
- {
1014
- "name": "stdout",
1015
- "output_type": "stream",
1016
- "text": [
1017
- "Epoch [1/1], Step [62/100], Moving Avg Loss: 30.3602\n"
1018
- ]
1019
- },
1020
- {
1021
- "name": "stderr",
1022
- "output_type": "stream",
1023
- "text": [
1024
- "Epoch 1/1: 63%|██████▎ | 63/100 [00:26<00:15, 2.42it/s]"
1025
- ]
1026
- },
1027
- {
1028
- "name": "stdout",
1029
- "output_type": "stream",
1030
- "text": [
1031
- "Epoch [1/1], Step [63/100], Moving Avg Loss: 29.7396\n"
1032
- ]
1033
- },
1034
- {
1035
- "name": "stderr",
1036
- "output_type": "stream",
1037
- "text": [
1038
- "Epoch 1/1: 64%|██████▍ | 64/100 [00:26<00:14, 2.44it/s]"
1039
- ]
1040
- },
1041
- {
1042
- "name": "stdout",
1043
- "output_type": "stream",
1044
- "text": [
1045
- "Epoch [1/1], Step [64/100], Moving Avg Loss: 27.3900\n"
1046
- ]
1047
- },
1048
- {
1049
- "name": "stderr",
1050
- "output_type": "stream",
1051
- "text": [
1052
- "Epoch 1/1: 65%|██████▌ | 65/100 [00:27<00:14, 2.44it/s]"
1053
- ]
1054
- },
1055
- {
1056
- "name": "stdout",
1057
- "output_type": "stream",
1058
- "text": [
1059
- "Epoch [1/1], Step [65/100], Moving Avg Loss: 25.9465\n"
1060
- ]
1061
- },
1062
- {
1063
- "name": "stderr",
1064
- "output_type": "stream",
1065
- "text": [
1066
- "Epoch 1/1: 66%|██████▌ | 66/100 [00:27<00:13, 2.44it/s]"
1067
- ]
1068
- },
1069
- {
1070
- "name": "stdout",
1071
- "output_type": "stream",
1072
- "text": [
1073
- "Epoch [1/1], Step [66/100], Moving Avg Loss: 24.0045\n"
1074
- ]
1075
- },
1076
- {
1077
- "name": "stderr",
1078
- "output_type": "stream",
1079
- "text": [
1080
- "Epoch 1/1: 67%|██████▋ | 67/100 [00:28<00:13, 2.42it/s]"
1081
- ]
1082
- },
1083
- {
1084
- "name": "stdout",
1085
- "output_type": "stream",
1086
- "text": [
1087
- "Epoch [1/1], Step [67/100], Moving Avg Loss: 22.4123\n"
1088
- ]
1089
- },
1090
- {
1091
- "name": "stderr",
1092
- "output_type": "stream",
1093
- "text": [
1094
- "Epoch 1/1: 68%|██████▊ | 68/100 [00:28<00:13, 2.39it/s]"
1095
- ]
1096
- },
1097
- {
1098
- "name": "stdout",
1099
- "output_type": "stream",
1100
- "text": [
1101
- "Epoch [1/1], Step [68/100], Moving Avg Loss: 21.4466\n"
1102
- ]
1103
- },
1104
- {
1105
- "name": "stderr",
1106
- "output_type": "stream",
1107
- "text": [
1108
- "Epoch 1/1: 69%|██████▉ | 69/100 [00:28<00:12, 2.44it/s]"
1109
- ]
1110
- },
1111
- {
1112
- "name": "stdout",
1113
- "output_type": "stream",
1114
- "text": [
1115
- "Epoch [1/1], Step [69/100], Moving Avg Loss: 21.5559\n"
1116
- ]
1117
- },
1118
- {
1119
- "name": "stderr",
1120
- "output_type": "stream",
1121
- "text": [
1122
- "Epoch 1/1: 70%|███████ | 70/100 [00:29<00:12, 2.47it/s]"
1123
- ]
1124
- },
1125
- {
1126
- "name": "stdout",
1127
- "output_type": "stream",
1128
- "text": [
1129
- "Epoch [1/1], Step [70/100], Moving Avg Loss: 20.5609\n"
1130
- ]
1131
- },
1132
- {
1133
- "name": "stderr",
1134
- "output_type": "stream",
1135
- "text": [
1136
- "Epoch 1/1: 71%|███████ | 71/100 [00:29<00:11, 2.48it/s]"
1137
- ]
1138
- },
1139
- {
1140
- "name": "stdout",
1141
- "output_type": "stream",
1142
- "text": [
1143
- "Epoch [1/1], Step [71/100], Moving Avg Loss: 19.9970\n"
1144
- ]
1145
- },
1146
- {
1147
- "name": "stderr",
1148
- "output_type": "stream",
1149
- "text": [
1150
- "Epoch 1/1: 72%|███████▏ | 72/100 [00:30<00:11, 2.39it/s]"
1151
- ]
1152
- },
1153
- {
1154
- "name": "stdout",
1155
- "output_type": "stream",
1156
- "text": [
1157
- "Epoch [1/1], Step [72/100], Moving Avg Loss: 20.6782\n"
1158
- ]
1159
- },
1160
- {
1161
- "name": "stderr",
1162
- "output_type": "stream",
1163
- "text": [
1164
- "Epoch 1/1: 73%|███████▎ | 73/100 [00:30<00:11, 2.41it/s]"
1165
- ]
1166
- },
1167
- {
1168
- "name": "stdout",
1169
- "output_type": "stream",
1170
- "text": [
1171
- "Epoch [1/1], Step [73/100], Moving Avg Loss: 23.7722\n"
1172
- ]
1173
- },
1174
- {
1175
- "name": "stderr",
1176
- "output_type": "stream",
1177
- "text": [
1178
- "Epoch 1/1: 74%|███████▍ | 74/100 [00:31<00:10, 2.40it/s]"
1179
- ]
1180
- },
1181
- {
1182
- "name": "stdout",
1183
- "output_type": "stream",
1184
- "text": [
1185
- "Epoch [1/1], Step [74/100], Moving Avg Loss: 22.6156\n"
1186
- ]
1187
- },
1188
- {
1189
- "name": "stderr",
1190
- "output_type": "stream",
1191
- "text": [
1192
- "Epoch 1/1: 75%|███████▌ | 75/100 [00:31<00:10, 2.44it/s]"
1193
- ]
1194
- },
1195
- {
1196
- "name": "stdout",
1197
- "output_type": "stream",
1198
- "text": [
1199
- "Epoch [1/1], Step [75/100], Moving Avg Loss: 27.7204\n"
1200
- ]
1201
- },
1202
- {
1203
- "name": "stderr",
1204
- "output_type": "stream",
1205
- "text": [
1206
- "Epoch 1/1: 76%|███████▌ | 76/100 [00:31<00:09, 2.42it/s]"
1207
- ]
1208
- },
1209
- {
1210
- "name": "stdout",
1211
- "output_type": "stream",
1212
- "text": [
1213
- "Epoch [1/1], Step [76/100], Moving Avg Loss: 27.3355\n"
1214
- ]
1215
- },
1216
- {
1217
- "name": "stderr",
1218
- "output_type": "stream",
1219
- "text": [
1220
- "Epoch 1/1: 77%|███████▋ | 77/100 [00:32<00:09, 2.43it/s]"
1221
- ]
1222
- },
1223
- {
1224
- "name": "stdout",
1225
- "output_type": "stream",
1226
- "text": [
1227
- "Epoch [1/1], Step [77/100], Moving Avg Loss: 26.1804\n"
1228
- ]
1229
- },
1230
- {
1231
- "name": "stderr",
1232
- "output_type": "stream",
1233
- "text": [
1234
- "Epoch 1/1: 78%|███████▊ | 78/100 [00:32<00:09, 2.39it/s]"
1235
- ]
1236
- },
1237
- {
1238
- "name": "stdout",
1239
- "output_type": "stream",
1240
- "text": [
1241
- "Epoch [1/1], Step [78/100], Moving Avg Loss: 25.3216\n"
1242
- ]
1243
- },
1244
- {
1245
- "name": "stderr",
1246
- "output_type": "stream",
1247
- "text": [
1248
- "Epoch 1/1: 79%|███████▉ | 79/100 [00:33<00:08, 2.43it/s]"
1249
- ]
1250
- },
1251
- {
1252
- "name": "stdout",
1253
- "output_type": "stream",
1254
- "text": [
1255
- "Epoch [1/1], Step [79/100], Moving Avg Loss: 27.5742\n"
1256
- ]
1257
- },
1258
- {
1259
- "name": "stderr",
1260
- "output_type": "stream",
1261
- "text": [
1262
- "Epoch 1/1: 80%|████████ | 80/100 [00:33<00:08, 2.41it/s]"
1263
- ]
1264
- },
1265
- {
1266
- "name": "stdout",
1267
- "output_type": "stream",
1268
- "text": [
1269
- "Epoch [1/1], Step [80/100], Moving Avg Loss: 27.5931\n"
1270
- ]
1271
- },
1272
- {
1273
- "name": "stderr",
1274
- "output_type": "stream",
1275
- "text": [
1276
- "Epoch 1/1: 81%|████████ | 81/100 [00:33<00:08, 2.34it/s]"
1277
- ]
1278
- },
1279
- {
1280
- "name": "stdout",
1281
- "output_type": "stream",
1282
- "text": [
1283
- "Epoch [1/1], Step [81/100], Moving Avg Loss: 25.5491\n"
1284
- ]
1285
- },
1286
- {
1287
- "name": "stderr",
1288
- "output_type": "stream",
1289
- "text": [
1290
- "Epoch 1/1: 82%|████████▏ | 82/100 [00:34<00:07, 2.33it/s]"
1291
- ]
1292
- },
1293
- {
1294
- "name": "stdout",
1295
- "output_type": "stream",
1296
- "text": [
1297
- "Epoch [1/1], Step [82/100], Moving Avg Loss: 24.0114\n"
1298
- ]
1299
- },
1300
- {
1301
- "name": "stderr",
1302
- "output_type": "stream",
1303
- "text": [
1304
- "Epoch 1/1: 83%|████████▎ | 83/100 [00:34<00:07, 2.33it/s]"
1305
- ]
1306
- },
1307
- {
1308
- "name": "stdout",
1309
- "output_type": "stream",
1310
- "text": [
1311
- "Epoch [1/1], Step [83/100], Moving Avg Loss: 22.3863\n"
1312
- ]
1313
- },
1314
- {
1315
- "name": "stderr",
1316
- "output_type": "stream",
1317
- "text": [
1318
- "Epoch 1/1: 84%|████████▍ | 84/100 [00:35<00:06, 2.36it/s]"
1319
- ]
1320
- },
1321
- {
1322
- "name": "stdout",
1323
- "output_type": "stream",
1324
- "text": [
1325
- "Epoch [1/1], Step [84/100], Moving Avg Loss: 23.4298\n"
1326
- ]
1327
- },
1328
- {
1329
- "name": "stderr",
1330
- "output_type": "stream",
1331
- "text": [
1332
- "Epoch 1/1: 85%|████████▌ | 85/100 [00:35<00:06, 2.34it/s]"
1333
- ]
1334
- },
1335
- {
1336
- "name": "stdout",
1337
- "output_type": "stream",
1338
- "text": [
1339
- "Epoch [1/1], Step [85/100], Moving Avg Loss: 21.6505\n"
1340
- ]
1341
- },
1342
- {
1343
- "name": "stderr",
1344
- "output_type": "stream",
1345
- "text": [
1346
- "Epoch 1/1: 86%|████████▌ | 86/100 [00:36<00:06, 2.28it/s]"
1347
- ]
1348
- },
1349
- {
1350
- "name": "stdout",
1351
- "output_type": "stream",
1352
- "text": [
1353
- "Epoch [1/1], Step [86/100], Moving Avg Loss: 22.6546\n"
1354
- ]
1355
- },
1356
- {
1357
- "name": "stderr",
1358
- "output_type": "stream",
1359
- "text": [
1360
- "Epoch 1/1: 87%|████████▋ | 87/100 [00:36<00:05, 2.35it/s]"
1361
- ]
1362
- },
1363
- {
1364
- "name": "stdout",
1365
- "output_type": "stream",
1366
- "text": [
1367
- "Epoch [1/1], Step [87/100], Moving Avg Loss: 21.1156\n"
1368
- ]
1369
- },
1370
- {
1371
- "name": "stderr",
1372
- "output_type": "stream",
1373
- "text": [
1374
- "Epoch 1/1: 88%|████████▊ | 88/100 [00:36<00:05, 2.38it/s]"
1375
- ]
1376
- },
1377
- {
1378
- "name": "stdout",
1379
- "output_type": "stream",
1380
- "text": [
1381
- "Epoch [1/1], Step [88/100], Moving Avg Loss: 23.9993\n"
1382
- ]
1383
- },
1384
- {
1385
- "name": "stderr",
1386
- "output_type": "stream",
1387
- "text": [
1388
- "Epoch 1/1: 89%|████████▉ | 89/100 [00:37<00:04, 2.39it/s]"
1389
- ]
1390
- },
1391
- {
1392
- "name": "stdout",
1393
- "output_type": "stream",
1394
- "text": [
1395
- "Epoch [1/1], Step [89/100], Moving Avg Loss: 24.9765\n"
1396
- ]
1397
- },
1398
- {
1399
- "name": "stderr",
1400
- "output_type": "stream",
1401
- "text": [
1402
- "Epoch 1/1: 90%|█████████ | 90/100 [00:37<00:04, 2.42it/s]"
1403
- ]
1404
- },
1405
- {
1406
- "name": "stdout",
1407
- "output_type": "stream",
1408
- "text": [
1409
- "Epoch [1/1], Step [90/100], Moving Avg Loss: 28.3430\n"
1410
- ]
1411
- },
1412
- {
1413
- "name": "stderr",
1414
- "output_type": "stream",
1415
- "text": [
1416
- "Epoch 1/1: 91%|█████████ | 91/100 [00:38<00:03, 2.39it/s]"
1417
- ]
1418
- },
1419
- {
1420
- "name": "stdout",
1421
- "output_type": "stream",
1422
- "text": [
1423
- "Epoch [1/1], Step [91/100], Moving Avg Loss: 28.5874\n"
1424
- ]
1425
- },
1426
- {
1427
- "name": "stderr",
1428
- "output_type": "stream",
1429
- "text": [
1430
- "Epoch 1/1: 92%|█████████▏| 92/100 [00:38<00:03, 2.33it/s]"
1431
- ]
1432
- },
1433
- {
1434
- "name": "stdout",
1435
- "output_type": "stream",
1436
- "text": [
1437
- "Epoch [1/1], Step [92/100], Moving Avg Loss: 27.0662\n"
1438
- ]
1439
- },
1440
- {
1441
- "name": "stderr",
1442
- "output_type": "stream",
1443
- "text": [
1444
- "Epoch 1/1: 93%|█████████▎| 93/100 [00:39<00:03, 2.32it/s]"
1445
- ]
1446
- },
1447
- {
1448
- "name": "stdout",
1449
- "output_type": "stream",
1450
- "text": [
1451
- "Epoch [1/1], Step [93/100], Moving Avg Loss: 29.0707\n"
1452
- ]
1453
- },
1454
- {
1455
- "name": "stderr",
1456
- "output_type": "stream",
1457
- "text": [
1458
- "Epoch 1/1: 94%|█████████▍| 94/100 [00:39<00:02, 2.23it/s]"
1459
- ]
1460
- },
1461
- {
1462
- "name": "stdout",
1463
- "output_type": "stream",
1464
- "text": [
1465
- "Epoch [1/1], Step [94/100], Moving Avg Loss: 26.7228\n"
1466
- ]
1467
- },
1468
- {
1469
- "name": "stderr",
1470
- "output_type": "stream",
1471
- "text": [
1472
- "Epoch 1/1: 95%|█████████▌| 95/100 [00:40<00:02, 2.24it/s]"
1473
- ]
1474
- },
1475
- {
1476
- "name": "stdout",
1477
- "output_type": "stream",
1478
- "text": [
1479
- "Epoch [1/1], Step [95/100], Moving Avg Loss: 24.9785\n"
1480
- ]
1481
- },
1482
- {
1483
- "name": "stderr",
1484
- "output_type": "stream",
1485
- "text": [
1486
- "Epoch 1/1: 96%|█████████▌| 96/100 [00:40<00:01, 2.28it/s]"
1487
- ]
1488
- },
1489
- {
1490
- "name": "stdout",
1491
- "output_type": "stream",
1492
- "text": [
1493
- "Epoch [1/1], Step [96/100], Moving Avg Loss: 28.0284\n"
1494
- ]
1495
- },
1496
- {
1497
- "name": "stderr",
1498
- "output_type": "stream",
1499
- "text": [
1500
- "Epoch 1/1: 97%|█████████▋| 97/100 [00:40<00:01, 2.32it/s]"
1501
- ]
1502
- },
1503
- {
1504
- "name": "stdout",
1505
- "output_type": "stream",
1506
- "text": [
1507
- "Epoch [1/1], Step [97/100], Moving Avg Loss: 25.9050\n"
1508
- ]
1509
- },
1510
- {
1511
- "name": "stderr",
1512
- "output_type": "stream",
1513
- "text": [
1514
- "Epoch 1/1: 98%|█████████▊| 98/100 [00:41<00:00, 2.36it/s]"
1515
- ]
1516
- },
1517
- {
1518
- "name": "stdout",
1519
- "output_type": "stream",
1520
- "text": [
1521
- "Epoch [1/1], Step [98/100], Moving Avg Loss: 26.5735\n"
1522
- ]
1523
- },
1524
- {
1525
- "name": "stderr",
1526
- "output_type": "stream",
1527
- "text": [
1528
- "Epoch 1/1: 99%|█████████▉| 99/100 [00:41<00:00, 2.39it/s]"
1529
- ]
1530
- },
1531
- {
1532
- "name": "stdout",
1533
- "output_type": "stream",
1534
- "text": [
1535
- "Epoch [1/1], Step [99/100], Moving Avg Loss: 24.9826\n"
1536
- ]
1537
- },
1538
- {
1539
- "name": "stderr",
1540
- "output_type": "stream",
1541
- "text": [
1542
- "Epoch 1/1: 100%|██████████| 100/100 [00:42<00:00, 2.38it/s]"
1543
- ]
1544
- },
1545
- {
1546
- "name": "stdout",
1547
- "output_type": "stream",
1548
- "text": [
1549
- "Epoch [1/1], Step [100/100], Moving Avg Loss: 23.1838\n",
1550
- "Average Loss for Epoch 1: 24.4323\n"
1551
- ]
1552
- },
1553
- {
1554
- "name": "stderr",
1555
- "output_type": "stream",
1556
- "text": [
1557
- "\n"
1558
- ]
1559
- }
1560
- ],
1561
- "source": [
1562
- "num_epochs = 1\n",
1563
- "for epoch in range(num_epochs):\n",
1564
- " model.train()\n",
1565
- " running_loss = 0.0\n",
1566
- " avg_loss = 0.0\n",
1567
- "\n",
1568
- " for idx, item in enumerate(tqdm(dataset, desc=f\"Epoch {epoch + 1}/{num_epochs}\")):\n",
1569
- "\n",
1570
- " images = [prepare_image_detection(img=item['image'], processor=load_processor())]\n",
1571
- " images = torch.stack(images, dim=0).to(model.dtype).to(model.device)\n",
1572
- " optimizer.zero_grad()\n",
1573
- " outputs = model(pixel_values=images)\n",
1574
- " predicted_boxes = logits_to_bboxes(outputs.logits, item['image'])\n",
1575
- " target_boxes = item['bboxes']\n",
1576
- "\n",
1577
- " smooth_l1 = smooth_l1_loss(predicted_boxes, target_boxes)\n",
1578
- " iou = iou_loss(predicted_boxes, target_boxes)\n",
1579
- " loss = 0.5 * smooth_l1 + 0.5 * iou\n",
1580
- "\n",
1581
- " loss.backward()\n",
1582
- " optimizer.step()\n",
1583
- " running_loss += loss.item()\n",
1584
- "\n",
1585
- " # Update moving average of the loss\n",
1586
- " avg_loss = 0.9 * avg_loss + 0.1 * loss.item() if idx > 0 else loss.item()\n",
1587
- "\n",
1588
- " # Print moving average loss\n",
1589
- " print(f\"Epoch [{epoch + 1}/{num_epochs}], Step [{idx + 1}/{len(dataset)}], Moving Avg Loss: {avg_loss:.4f}\")\n",
1590
- "\n",
1591
- " avg_loss = running_loss / len(dataset)\n",
1592
- " writer.add_scalar('Training Loss', avg_loss, epoch + 1)\n",
1593
- " print(f\"Average Loss for Epoch {epoch + 1}: {avg_loss:.4f}\")\n",
1594
- "\n",
1595
- " torch.save(model.state_dict(), os.path.join(checkpoint_dir, f\"model_epoch_{epoch + 1}.pth\"))"
1596
- ]
1597
- },
1598
- {
1599
- "cell_type": "code",
1600
- "execution_count": null,
1601
- "metadata": {},
1602
- "outputs": [],
1603
- "source": [
1604
- "checkpoint_path = '/data2/ketan/orc/surya-layout-fine-tune/checkpoints/model_epoch_3.pth' \n",
1605
- "state_dict = torch.load(checkpoint_path,weights_only=True)\n",
1606
- "\n",
1607
- "model.load_state_dict(state_dict)"
1608
- ]
1609
- },
1610
- {
1611
- "cell_type": "code",
1612
- "execution_count": null,
1613
- "metadata": {},
1614
- "outputs": [],
1615
- "source": [
1616
- "model.to('cpu')\n",
1617
- "model.save_pretrained(\"fine-tuned-surya-model-layout\")"
1618
- ]
1619
- }
1620
- ],
1621
- "metadata": {
1622
- "kernelspec": {
1623
- "display_name": "Python 3",
1624
- "language": "python",
1625
- "name": "python3"
1626
- },
1627
- "language_info": {
1628
- "codemirror_mode": {
1629
- "name": "ipython",
1630
- "version": 3
1631
- },
1632
- "file_extension": ".py",
1633
- "mimetype": "text/x-python",
1634
- "name": "python",
1635
- "nbconvert_exporter": "python",
1636
- "pygments_lexer": "ipython3",
1637
- "version": "3.10.14"
1638
- }
1639
- },
1640
- "nbformat": 4,
1641
- "nbformat_minor": 2
1642
- }