ketanmore commited on
Commit
e831476
Β·
verified Β·
1 Parent(s): d80f76d

Upload temp_test.ipynb

Browse files
Files changed (1) hide show
  1. temp_test.ipynb +1642 -0
temp_test.ipynb ADDED
@@ -0,0 +1,1642 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "os.environ['HF_HOME'] = '/data2/ketan/orc/HF_Cache'\n",
11
+ "\n",
12
+ "import torch\n",
13
+ "import torch.nn as nn\n",
14
+ "import torch.optim as optim\n",
15
+ "from torch.utils.data import DataLoader\n",
16
+ "from transformers import SegformerConfig\n",
17
+ "from surya.model.detection.segformer import SegformerForRegressionMask\n",
18
+ "from surya.input.processing import prepare_image_detection\n",
19
+ "from surya.model.detection.segformer import load_processor , load_model\n",
20
+ "from datasets import load_dataset\n",
21
+ "from tqdm import tqdm\n",
22
+ "from torch.utils.tensorboard import SummaryWriter\n",
23
+ "\n",
24
+ "import torch.nn.functional as F\n",
25
+ "import numpy as np \n",
26
+ "from PIL import ImageDraw, ImageFont\n",
27
+ "from surya.layout import parallel_get_regions\n",
28
+ "import cv2"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 2,
34
+ "metadata": {},
35
+ "outputs": [
36
+ {
37
+ "name": "stdout",
38
+ "output_type": "stream",
39
+ "text": [
40
+ "Loaded detection model vikp/surya_layout2 on device cuda with dtype torch.float16\n"
41
+ ]
42
+ }
43
+ ],
44
+ "source": [
45
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
46
+ "\n",
47
+ "\n",
48
+ "dataset = load_dataset(\"vikp/publaynet_bench\", split=\"train[:100]\")\n",
49
+ "train_loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x)\n",
50
+ "\n",
51
+ "\n",
52
+ "model = load_model(\"vikp/surya_layout2\")"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": 17,
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "\n",
62
+ "optimizer = optim.Adam(model.parameters(), lr=0.00001)\n",
63
+ "\n",
64
+ "# Logging and Checkpoints\n",
65
+ "log_dir = \"logs\"\n",
66
+ "checkpoint_dir = \"checkpoints\"\n",
67
+ "os.makedirs(log_dir, exist_ok=True)\n",
68
+ "os.makedirs(checkpoint_dir, exist_ok=True)\n",
69
+ "writer = SummaryWriter(log_dir=log_dir)\n",
70
+ "\n",
71
+ "def calculate_iou(box1, box2):\n",
72
+ " box1 = torch.tensor(box1, dtype=torch.float32, requires_grad=True) if not isinstance(box1, torch.Tensor) else box1\n",
73
+ " box2 = torch.tensor(box2, dtype=torch.float32, requires_grad=True) if not isinstance(box2, torch.Tensor) else box2\n",
74
+ " \n",
75
+ " x_min = torch.max(box1[0], box2[0])\n",
76
+ " y_min = torch.max(box1[1], box2[1])\n",
77
+ " x_max = torch.min(box1[2], box2[2])\n",
78
+ " y_max = torch.min(box1[3], box2[3])\n",
79
+ " \n",
80
+ " intersection = torch.clamp(x_max - x_min, min=0) * torch.clamp(y_max - y_min, min=0)\n",
81
+ " area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])\n",
82
+ " area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])\n",
83
+ " union = area1 + area2 - intersection\n",
84
+ " \n",
85
+ " iou = intersection / union if union > 0 else torch.tensor(0.0, requires_grad=True)\n",
86
+ " \n",
87
+ " return iou\n",
88
+ "\n",
89
+ "def pair_boxes(pred_boxes, target_boxes):\n",
90
+ " pred_boxes = [torch.tensor(box, dtype=torch.float32, requires_grad=True) for box in pred_boxes]\n",
91
+ " target_boxes = [torch.tensor(box, dtype=torch.float32, requires_grad=True) for box in target_boxes]\n",
92
+ " \n",
93
+ " matched_pred_boxes = []\n",
94
+ " matched_target_boxes = []\n",
95
+ " \n",
96
+ " for target in target_boxes:\n",
97
+ " best_iou = 0\n",
98
+ " best_pred = None\n",
99
+ " for pred in pred_boxes:\n",
100
+ " iou = calculate_iou(pred, target)\n",
101
+ " if iou > best_iou:\n",
102
+ " best_iou = iou\n",
103
+ " best_pred = pred\n",
104
+ " \n",
105
+ " if best_pred is not None:\n",
106
+ " matched_pred_boxes.append(best_pred)\n",
107
+ " matched_target_boxes.append(target)\n",
108
+ " pred_boxes = [p for p in pred_boxes if not torch.equal(p, best_pred)]\n",
109
+ "\n",
110
+ " return matched_pred_boxes, matched_target_boxes\n",
111
+ "\n",
112
+ "def smooth_l1_loss(pred_boxes, target_boxes, beta=1.0):\n",
113
+ " matched_pred_boxes, matched_target_boxes = pair_boxes(pred_boxes, target_boxes)\n",
114
+ " \n",
115
+ " if len(matched_pred_boxes) == 0:\n",
116
+ " return torch.tensor(0.0, requires_grad=True)\n",
117
+ " \n",
118
+ " diff = torch.abs(torch.stack(matched_pred_boxes) - torch.stack(matched_target_boxes))\n",
119
+ " loss = torch.where(diff < beta, 0.5 * (diff ** 2) / beta, diff - 0.5 * beta)\n",
120
+ " return loss.mean()\n",
121
+ "\n",
122
+ "def iou_loss(pred_boxes, target_boxes):\n",
123
+ " matched_pred_boxes, matched_target_boxes = pair_boxes(pred_boxes, target_boxes)\n",
124
+ " \n",
125
+ " if len(matched_pred_boxes) == 0:\n",
126
+ " return torch.tensor(1.0, requires_grad=True)\n",
127
+ " \n",
128
+ " ious = [calculate_iou(pred, target) for pred, target in zip(matched_pred_boxes, matched_target_boxes)]\n",
129
+ " return 1 - torch.mean(torch.tensor(ious, requires_grad=True))\n",
130
+ "\n",
131
+ "\n",
132
+ "def logits_to_bboxes(logits,image) :\n",
133
+ " correct_shape = (300, 300) \n",
134
+ " logits_temp = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False)\n",
135
+ " logits_temp = logits_temp.cpu().detach().numpy().astype(np.float32)\n",
136
+ "\n",
137
+ " heatmap_count = logits_temp.shape[1]\n",
138
+ " heatmaps = [logits_temp[i][k] for i in range(logits_temp.shape[0]) for k in range(heatmap_count)]\n",
139
+ " regions = parallel_get_regions(heatmaps=heatmaps, orig_size=image.size, id2label=model.config.id2label)\n",
140
+ "\n",
141
+ " final_bboxes = []\n",
142
+ " for i in regions.bboxes :\n",
143
+ " final_bboxes.append(i.bbox)\n",
144
+ " return final_bboxes\n"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": 18,
150
+ "metadata": {},
151
+ "outputs": [
152
+ {
153
+ "name": "stderr",
154
+ "output_type": "stream",
155
+ "text": [
156
+ "Epoch 1/1: 1%| | 1/100 [00:00<00:41, 2.41it/s]"
157
+ ]
158
+ },
159
+ {
160
+ "name": "stdout",
161
+ "output_type": "stream",
162
+ "text": [
163
+ "Epoch [1/1], Step [1/100], Moving Avg Loss: 16.4575\n"
164
+ ]
165
+ },
166
+ {
167
+ "name": "stderr",
168
+ "output_type": "stream",
169
+ "text": [
170
+ "Epoch 1/1: 2%|▏ | 2/100 [00:00<00:41, 2.34it/s]"
171
+ ]
172
+ },
173
+ {
174
+ "name": "stdout",
175
+ "output_type": "stream",
176
+ "text": [
177
+ "Epoch [1/1], Step [2/100], Moving Avg Loss: 15.4938\n"
178
+ ]
179
+ },
180
+ {
181
+ "name": "stderr",
182
+ "output_type": "stream",
183
+ "text": [
184
+ "Epoch 1/1: 3%|β–Ž | 3/100 [00:01<00:40, 2.42it/s]"
185
+ ]
186
+ },
187
+ {
188
+ "name": "stdout",
189
+ "output_type": "stream",
190
+ "text": [
191
+ "Epoch [1/1], Step [3/100], Moving Avg Loss: 18.9512\n"
192
+ ]
193
+ },
194
+ {
195
+ "name": "stderr",
196
+ "output_type": "stream",
197
+ "text": [
198
+ "Epoch 1/1: 4%|▍ | 4/100 [00:01<00:39, 2.40it/s]"
199
+ ]
200
+ },
201
+ {
202
+ "name": "stdout",
203
+ "output_type": "stream",
204
+ "text": [
205
+ "Epoch [1/1], Step [4/100], Moving Avg Loss: 18.3995\n"
206
+ ]
207
+ },
208
+ {
209
+ "name": "stderr",
210
+ "output_type": "stream",
211
+ "text": [
212
+ "Epoch 1/1: 5%|β–Œ | 5/100 [00:02<00:39, 2.43it/s]"
213
+ ]
214
+ },
215
+ {
216
+ "name": "stdout",
217
+ "output_type": "stream",
218
+ "text": [
219
+ "Epoch [1/1], Step [5/100], Moving Avg Loss: 20.1250\n"
220
+ ]
221
+ },
222
+ {
223
+ "name": "stderr",
224
+ "output_type": "stream",
225
+ "text": [
226
+ "Epoch 1/1: 6%|β–Œ | 6/100 [00:02<00:39, 2.38it/s]"
227
+ ]
228
+ },
229
+ {
230
+ "name": "stdout",
231
+ "output_type": "stream",
232
+ "text": [
233
+ "Epoch [1/1], Step [6/100], Moving Avg Loss: 18.9854\n"
234
+ ]
235
+ },
236
+ {
237
+ "name": "stderr",
238
+ "output_type": "stream",
239
+ "text": [
240
+ "Epoch 1/1: 7%|β–‹ | 7/100 [00:02<00:38, 2.40it/s]"
241
+ ]
242
+ },
243
+ {
244
+ "name": "stdout",
245
+ "output_type": "stream",
246
+ "text": [
247
+ "Epoch [1/1], Step [7/100], Moving Avg Loss: 18.4753\n"
248
+ ]
249
+ },
250
+ {
251
+ "name": "stderr",
252
+ "output_type": "stream",
253
+ "text": [
254
+ "Epoch 1/1: 8%|β–Š | 8/100 [00:03<00:40, 2.27it/s]"
255
+ ]
256
+ },
257
+ {
258
+ "name": "stdout",
259
+ "output_type": "stream",
260
+ "text": [
261
+ "Epoch [1/1], Step [8/100], Moving Avg Loss: 17.0382\n"
262
+ ]
263
+ },
264
+ {
265
+ "name": "stderr",
266
+ "output_type": "stream",
267
+ "text": [
268
+ "Epoch 1/1: 9%|β–‰ | 9/100 [00:03<00:41, 2.20it/s]"
269
+ ]
270
+ },
271
+ {
272
+ "name": "stdout",
273
+ "output_type": "stream",
274
+ "text": [
275
+ "Epoch [1/1], Step [9/100], Moving Avg Loss: 17.7276\n"
276
+ ]
277
+ },
278
+ {
279
+ "name": "stderr",
280
+ "output_type": "stream",
281
+ "text": [
282
+ "Epoch 1/1: 10%|β–ˆ | 10/100 [00:04<00:40, 2.23it/s]"
283
+ ]
284
+ },
285
+ {
286
+ "name": "stdout",
287
+ "output_type": "stream",
288
+ "text": [
289
+ "Epoch [1/1], Step [10/100], Moving Avg Loss: 19.5423\n"
290
+ ]
291
+ },
292
+ {
293
+ "name": "stderr",
294
+ "output_type": "stream",
295
+ "text": [
296
+ "Epoch 1/1: 11%|β–ˆ | 11/100 [00:04<00:39, 2.27it/s]"
297
+ ]
298
+ },
299
+ {
300
+ "name": "stdout",
301
+ "output_type": "stream",
302
+ "text": [
303
+ "Epoch [1/1], Step [11/100], Moving Avg Loss: 18.4347\n"
304
+ ]
305
+ },
306
+ {
307
+ "name": "stderr",
308
+ "output_type": "stream",
309
+ "text": [
310
+ "Epoch 1/1: 12%|β–ˆβ– | 12/100 [00:05<00:40, 2.20it/s]"
311
+ ]
312
+ },
313
+ {
314
+ "name": "stdout",
315
+ "output_type": "stream",
316
+ "text": [
317
+ "Epoch [1/1], Step [12/100], Moving Avg Loss: 17.3114\n"
318
+ ]
319
+ },
320
+ {
321
+ "name": "stderr",
322
+ "output_type": "stream",
323
+ "text": [
324
+ "Epoch 1/1: 13%|β–ˆβ–Ž | 13/100 [00:05<00:39, 2.21it/s]"
325
+ ]
326
+ },
327
+ {
328
+ "name": "stdout",
329
+ "output_type": "stream",
330
+ "text": [
331
+ "Epoch [1/1], Step [13/100], Moving Avg Loss: 16.2870\n"
332
+ ]
333
+ },
334
+ {
335
+ "name": "stderr",
336
+ "output_type": "stream",
337
+ "text": [
338
+ "Epoch 1/1: 14%|β–ˆβ– | 14/100 [00:06<00:37, 2.29it/s]"
339
+ ]
340
+ },
341
+ {
342
+ "name": "stdout",
343
+ "output_type": "stream",
344
+ "text": [
345
+ "Epoch [1/1], Step [14/100], Moving Avg Loss: 21.5170\n"
346
+ ]
347
+ },
348
+ {
349
+ "name": "stderr",
350
+ "output_type": "stream",
351
+ "text": [
352
+ "Epoch 1/1: 15%|β–ˆβ–Œ | 15/100 [00:06<00:36, 2.30it/s]"
353
+ ]
354
+ },
355
+ {
356
+ "name": "stdout",
357
+ "output_type": "stream",
358
+ "text": [
359
+ "Epoch [1/1], Step [15/100], Moving Avg Loss: 21.3559\n"
360
+ ]
361
+ },
362
+ {
363
+ "name": "stderr",
364
+ "output_type": "stream",
365
+ "text": [
366
+ "Epoch 1/1: 16%|β–ˆβ–Œ | 16/100 [00:06<00:36, 2.29it/s]"
367
+ ]
368
+ },
369
+ {
370
+ "name": "stdout",
371
+ "output_type": "stream",
372
+ "text": [
373
+ "Epoch [1/1], Step [16/100], Moving Avg Loss: 19.8276\n"
374
+ ]
375
+ },
376
+ {
377
+ "name": "stderr",
378
+ "output_type": "stream",
379
+ "text": [
380
+ "Epoch 1/1: 17%|β–ˆβ–‹ | 17/100 [00:07<00:35, 2.36it/s]"
381
+ ]
382
+ },
383
+ {
384
+ "name": "stdout",
385
+ "output_type": "stream",
386
+ "text": [
387
+ "Epoch [1/1], Step [17/100], Moving Avg Loss: 18.9123\n"
388
+ ]
389
+ },
390
+ {
391
+ "name": "stderr",
392
+ "output_type": "stream",
393
+ "text": [
394
+ "Epoch 1/1: 18%|β–ˆβ–Š | 18/100 [00:07<00:34, 2.38it/s]"
395
+ ]
396
+ },
397
+ {
398
+ "name": "stdout",
399
+ "output_type": "stream",
400
+ "text": [
401
+ "Epoch [1/1], Step [18/100], Moving Avg Loss: 23.7418\n"
402
+ ]
403
+ },
404
+ {
405
+ "name": "stderr",
406
+ "output_type": "stream",
407
+ "text": [
408
+ "Epoch 1/1: 19%|β–ˆβ–‰ | 19/100 [00:08<00:34, 2.36it/s]"
409
+ ]
410
+ },
411
+ {
412
+ "name": "stdout",
413
+ "output_type": "stream",
414
+ "text": [
415
+ "Epoch [1/1], Step [19/100], Moving Avg Loss: 22.2312\n"
416
+ ]
417
+ },
418
+ {
419
+ "name": "stderr",
420
+ "output_type": "stream",
421
+ "text": [
422
+ "Epoch 1/1: 20%|β–ˆβ–ˆ | 20/100 [00:08<00:34, 2.35it/s]"
423
+ ]
424
+ },
425
+ {
426
+ "name": "stdout",
427
+ "output_type": "stream",
428
+ "text": [
429
+ "Epoch [1/1], Step [20/100], Moving Avg Loss: 21.1758\n"
430
+ ]
431
+ },
432
+ {
433
+ "name": "stderr",
434
+ "output_type": "stream",
435
+ "text": [
436
+ "Epoch 1/1: 21%|β–ˆβ–ˆ | 21/100 [00:09<00:32, 2.40it/s]"
437
+ ]
438
+ },
439
+ {
440
+ "name": "stdout",
441
+ "output_type": "stream",
442
+ "text": [
443
+ "Epoch [1/1], Step [21/100], Moving Avg Loss: 24.8048\n"
444
+ ]
445
+ },
446
+ {
447
+ "name": "stderr",
448
+ "output_type": "stream",
449
+ "text": [
450
+ "Epoch 1/1: 22%|β–ˆβ–ˆβ– | 22/100 [00:09<00:32, 2.40it/s]"
451
+ ]
452
+ },
453
+ {
454
+ "name": "stdout",
455
+ "output_type": "stream",
456
+ "text": [
457
+ "Epoch [1/1], Step [22/100], Moving Avg Loss: 27.5316\n"
458
+ ]
459
+ },
460
+ {
461
+ "name": "stderr",
462
+ "output_type": "stream",
463
+ "text": [
464
+ "Epoch 1/1: 23%|β–ˆβ–ˆβ–Ž | 23/100 [00:09<00:32, 2.39it/s]"
465
+ ]
466
+ },
467
+ {
468
+ "name": "stdout",
469
+ "output_type": "stream",
470
+ "text": [
471
+ "Epoch [1/1], Step [23/100], Moving Avg Loss: 27.4807\n"
472
+ ]
473
+ },
474
+ {
475
+ "name": "stderr",
476
+ "output_type": "stream",
477
+ "text": [
478
+ "Epoch 1/1: 24%|β–ˆβ–ˆβ– | 24/100 [00:10<00:31, 2.41it/s]"
479
+ ]
480
+ },
481
+ {
482
+ "name": "stdout",
483
+ "output_type": "stream",
484
+ "text": [
485
+ "Epoch [1/1], Step [24/100], Moving Avg Loss: 25.2076\n"
486
+ ]
487
+ },
488
+ {
489
+ "name": "stderr",
490
+ "output_type": "stream",
491
+ "text": [
492
+ "Epoch 1/1: 25%|β–ˆβ–ˆβ–Œ | 25/100 [00:10<00:30, 2.42it/s]"
493
+ ]
494
+ },
495
+ {
496
+ "name": "stdout",
497
+ "output_type": "stream",
498
+ "text": [
499
+ "Epoch [1/1], Step [25/100], Moving Avg Loss: 23.2897\n"
500
+ ]
501
+ },
502
+ {
503
+ "name": "stderr",
504
+ "output_type": "stream",
505
+ "text": [
506
+ "Epoch 1/1: 26%|β–ˆβ–ˆβ–Œ | 26/100 [00:11<00:30, 2.40it/s]"
507
+ ]
508
+ },
509
+ {
510
+ "name": "stdout",
511
+ "output_type": "stream",
512
+ "text": [
513
+ "Epoch [1/1], Step [26/100], Moving Avg Loss: 22.1549\n"
514
+ ]
515
+ },
516
+ {
517
+ "name": "stderr",
518
+ "output_type": "stream",
519
+ "text": [
520
+ "Epoch 1/1: 27%|β–ˆβ–ˆβ–‹ | 27/100 [00:11<00:29, 2.43it/s]"
521
+ ]
522
+ },
523
+ {
524
+ "name": "stdout",
525
+ "output_type": "stream",
526
+ "text": [
527
+ "Epoch [1/1], Step [27/100], Moving Avg Loss: 21.9602\n"
528
+ ]
529
+ },
530
+ {
531
+ "name": "stderr",
532
+ "output_type": "stream",
533
+ "text": [
534
+ "Epoch 1/1: 28%|β–ˆβ–ˆβ–Š | 28/100 [00:11<00:29, 2.43it/s]"
535
+ ]
536
+ },
537
+ {
538
+ "name": "stdout",
539
+ "output_type": "stream",
540
+ "text": [
541
+ "Epoch [1/1], Step [28/100], Moving Avg Loss: 23.2106\n"
542
+ ]
543
+ },
544
+ {
545
+ "name": "stderr",
546
+ "output_type": "stream",
547
+ "text": [
548
+ "Epoch 1/1: 29%|β–ˆβ–ˆβ–‰ | 29/100 [00:12<00:30, 2.33it/s]"
549
+ ]
550
+ },
551
+ {
552
+ "name": "stdout",
553
+ "output_type": "stream",
554
+ "text": [
555
+ "Epoch [1/1], Step [29/100], Moving Avg Loss: 21.3036\n"
556
+ ]
557
+ },
558
+ {
559
+ "name": "stderr",
560
+ "output_type": "stream",
561
+ "text": [
562
+ "Epoch 1/1: 30%|β–ˆβ–ˆβ–ˆ | 30/100 [00:12<00:30, 2.31it/s]"
563
+ ]
564
+ },
565
+ {
566
+ "name": "stdout",
567
+ "output_type": "stream",
568
+ "text": [
569
+ "Epoch [1/1], Step [30/100], Moving Avg Loss: 22.1421\n"
570
+ ]
571
+ },
572
+ {
573
+ "name": "stderr",
574
+ "output_type": "stream",
575
+ "text": [
576
+ "Epoch 1/1: 31%|β–ˆβ–ˆβ–ˆ | 31/100 [00:13<00:29, 2.35it/s]"
577
+ ]
578
+ },
579
+ {
580
+ "name": "stdout",
581
+ "output_type": "stream",
582
+ "text": [
583
+ "Epoch [1/1], Step [31/100], Moving Avg Loss: 27.1543\n"
584
+ ]
585
+ },
586
+ {
587
+ "name": "stderr",
588
+ "output_type": "stream",
589
+ "text": [
590
+ "Epoch 1/1: 32%|β–ˆβ–ˆβ–ˆβ– | 32/100 [00:13<00:29, 2.34it/s]"
591
+ ]
592
+ },
593
+ {
594
+ "name": "stdout",
595
+ "output_type": "stream",
596
+ "text": [
597
+ "Epoch [1/1], Step [32/100], Moving Avg Loss: 27.6630\n"
598
+ ]
599
+ },
600
+ {
601
+ "name": "stderr",
602
+ "output_type": "stream",
603
+ "text": [
604
+ "Epoch 1/1: 33%|β–ˆβ–ˆβ–ˆβ–Ž | 33/100 [00:14<00:28, 2.39it/s]"
605
+ ]
606
+ },
607
+ {
608
+ "name": "stdout",
609
+ "output_type": "stream",
610
+ "text": [
611
+ "Epoch [1/1], Step [33/100], Moving Avg Loss: 25.8453\n"
612
+ ]
613
+ },
614
+ {
615
+ "name": "stderr",
616
+ "output_type": "stream",
617
+ "text": [
618
+ "Epoch 1/1: 34%|β–ˆβ–ˆβ–ˆβ– | 34/100 [00:14<00:27, 2.41it/s]"
619
+ ]
620
+ },
621
+ {
622
+ "name": "stdout",
623
+ "output_type": "stream",
624
+ "text": [
625
+ "Epoch [1/1], Step [34/100], Moving Avg Loss: 27.6460\n"
626
+ ]
627
+ },
628
+ {
629
+ "name": "stderr",
630
+ "output_type": "stream",
631
+ "text": [
632
+ "Epoch 1/1: 35%|β–ˆβ–ˆβ–ˆβ–Œ | 35/100 [00:14<00:26, 2.44it/s]"
633
+ ]
634
+ },
635
+ {
636
+ "name": "stdout",
637
+ "output_type": "stream",
638
+ "text": [
639
+ "Epoch [1/1], Step [35/100], Moving Avg Loss: 25.1319\n"
640
+ ]
641
+ },
642
+ {
643
+ "name": "stderr",
644
+ "output_type": "stream",
645
+ "text": [
646
+ "Epoch 1/1: 36%|β–ˆβ–ˆβ–ˆβ–Œ | 36/100 [00:15<00:26, 2.45it/s]"
647
+ ]
648
+ },
649
+ {
650
+ "name": "stdout",
651
+ "output_type": "stream",
652
+ "text": [
653
+ "Epoch [1/1], Step [36/100], Moving Avg Loss: 25.8555\n"
654
+ ]
655
+ },
656
+ {
657
+ "name": "stderr",
658
+ "output_type": "stream",
659
+ "text": [
660
+ "Epoch 1/1: 37%|β–ˆβ–ˆβ–ˆβ–‹ | 37/100 [00:15<00:25, 2.45it/s]"
661
+ ]
662
+ },
663
+ {
664
+ "name": "stdout",
665
+ "output_type": "stream",
666
+ "text": [
667
+ "Epoch [1/1], Step [37/100], Moving Avg Loss: 28.9348\n"
668
+ ]
669
+ },
670
+ {
671
+ "name": "stderr",
672
+ "output_type": "stream",
673
+ "text": [
674
+ "Epoch 1/1: 38%|β–ˆβ–ˆβ–ˆβ–Š | 38/100 [00:16<00:25, 2.42it/s]"
675
+ ]
676
+ },
677
+ {
678
+ "name": "stdout",
679
+ "output_type": "stream",
680
+ "text": [
681
+ "Epoch [1/1], Step [38/100], Moving Avg Loss: 28.3364\n"
682
+ ]
683
+ },
684
+ {
685
+ "name": "stderr",
686
+ "output_type": "stream",
687
+ "text": [
688
+ "Epoch 1/1: 39%|β–ˆβ–ˆβ–ˆβ–‰ | 39/100 [00:16<00:24, 2.45it/s]"
689
+ ]
690
+ },
691
+ {
692
+ "name": "stdout",
693
+ "output_type": "stream",
694
+ "text": [
695
+ "Epoch [1/1], Step [39/100], Moving Avg Loss: 26.0808\n"
696
+ ]
697
+ },
698
+ {
699
+ "name": "stderr",
700
+ "output_type": "stream",
701
+ "text": [
702
+ "Epoch 1/1: 40%|β–ˆβ–ˆβ–ˆβ–ˆ | 40/100 [00:16<00:24, 2.43it/s]"
703
+ ]
704
+ },
705
+ {
706
+ "name": "stdout",
707
+ "output_type": "stream",
708
+ "text": [
709
+ "Epoch [1/1], Step [40/100], Moving Avg Loss: 26.2237\n"
710
+ ]
711
+ },
712
+ {
713
+ "name": "stderr",
714
+ "output_type": "stream",
715
+ "text": [
716
+ "Epoch 1/1: 41%|β–ˆβ–ˆβ–ˆβ–ˆ | 41/100 [00:17<00:24, 2.43it/s]"
717
+ ]
718
+ },
719
+ {
720
+ "name": "stdout",
721
+ "output_type": "stream",
722
+ "text": [
723
+ "Epoch [1/1], Step [41/100], Moving Avg Loss: 26.2728\n"
724
+ ]
725
+ },
726
+ {
727
+ "name": "stderr",
728
+ "output_type": "stream",
729
+ "text": [
730
+ "Epoch 1/1: 42%|β–ˆβ–ˆβ–ˆβ–ˆβ– | 42/100 [00:17<00:23, 2.45it/s]"
731
+ ]
732
+ },
733
+ {
734
+ "name": "stdout",
735
+ "output_type": "stream",
736
+ "text": [
737
+ "Epoch [1/1], Step [42/100], Moving Avg Loss: 26.7065\n"
738
+ ]
739
+ },
740
+ {
741
+ "name": "stderr",
742
+ "output_type": "stream",
743
+ "text": [
744
+ "Epoch 1/1: 43%|β–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 43/100 [00:18<00:23, 2.39it/s]"
745
+ ]
746
+ },
747
+ {
748
+ "name": "stdout",
749
+ "output_type": "stream",
750
+ "text": [
751
+ "Epoch [1/1], Step [43/100], Moving Avg Loss: 24.7438\n"
752
+ ]
753
+ },
754
+ {
755
+ "name": "stderr",
756
+ "output_type": "stream",
757
+ "text": [
758
+ "Epoch 1/1: 44%|β–ˆβ–ˆβ–ˆβ–ˆβ– | 44/100 [00:18<00:23, 2.41it/s]"
759
+ ]
760
+ },
761
+ {
762
+ "name": "stdout",
763
+ "output_type": "stream",
764
+ "text": [
765
+ "Epoch [1/1], Step [44/100], Moving Avg Loss: 26.6885\n"
766
+ ]
767
+ },
768
+ {
769
+ "name": "stderr",
770
+ "output_type": "stream",
771
+ "text": [
772
+ "Epoch 1/1: 45%|β–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 45/100 [00:18<00:22, 2.41it/s]"
773
+ ]
774
+ },
775
+ {
776
+ "name": "stdout",
777
+ "output_type": "stream",
778
+ "text": [
779
+ "Epoch [1/1], Step [45/100], Moving Avg Loss: 27.7764\n"
780
+ ]
781
+ },
782
+ {
783
+ "name": "stderr",
784
+ "output_type": "stream",
785
+ "text": [
786
+ "Epoch 1/1: 46%|β–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 46/100 [00:19<00:22, 2.44it/s]"
787
+ ]
788
+ },
789
+ {
790
+ "name": "stdout",
791
+ "output_type": "stream",
792
+ "text": [
793
+ "Epoch [1/1], Step [46/100], Moving Avg Loss: 25.7708\n"
794
+ ]
795
+ },
796
+ {
797
+ "name": "stderr",
798
+ "output_type": "stream",
799
+ "text": [
800
+ "Epoch 1/1: 47%|β–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 47/100 [00:19<00:22, 2.36it/s]"
801
+ ]
802
+ },
803
+ {
804
+ "name": "stdout",
805
+ "output_type": "stream",
806
+ "text": [
807
+ "Epoch [1/1], Step [47/100], Moving Avg Loss: 23.6295\n"
808
+ ]
809
+ },
810
+ {
811
+ "name": "stderr",
812
+ "output_type": "stream",
813
+ "text": [
814
+ "Epoch 1/1: 48%|β–ˆβ–ˆβ–ˆβ–ˆβ–Š | 48/100 [00:20<00:21, 2.40it/s]"
815
+ ]
816
+ },
817
+ {
818
+ "name": "stdout",
819
+ "output_type": "stream",
820
+ "text": [
821
+ "Epoch [1/1], Step [48/100], Moving Avg Loss: 23.5793\n"
822
+ ]
823
+ },
824
+ {
825
+ "name": "stderr",
826
+ "output_type": "stream",
827
+ "text": [
828
+ "Epoch 1/1: 49%|β–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 49/100 [00:20<00:20, 2.44it/s]"
829
+ ]
830
+ },
831
+ {
832
+ "name": "stdout",
833
+ "output_type": "stream",
834
+ "text": [
835
+ "Epoch [1/1], Step [49/100], Moving Avg Loss: 22.1319\n"
836
+ ]
837
+ },
838
+ {
839
+ "name": "stderr",
840
+ "output_type": "stream",
841
+ "text": [
842
+ "Epoch 1/1: 50%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 50/100 [00:21<00:20, 2.46it/s]"
843
+ ]
844
+ },
845
+ {
846
+ "name": "stdout",
847
+ "output_type": "stream",
848
+ "text": [
849
+ "Epoch [1/1], Step [50/100], Moving Avg Loss: 21.5178\n"
850
+ ]
851
+ },
852
+ {
853
+ "name": "stderr",
854
+ "output_type": "stream",
855
+ "text": [
856
+ "Epoch 1/1: 51%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 51/100 [00:21<00:19, 2.48it/s]"
857
+ ]
858
+ },
859
+ {
860
+ "name": "stdout",
861
+ "output_type": "stream",
862
+ "text": [
863
+ "Epoch [1/1], Step [51/100], Moving Avg Loss: 22.2565\n"
864
+ ]
865
+ },
866
+ {
867
+ "name": "stderr",
868
+ "output_type": "stream",
869
+ "text": [
870
+ "Epoch 1/1: 52%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 52/100 [00:21<00:19, 2.45it/s]"
871
+ ]
872
+ },
873
+ {
874
+ "name": "stdout",
875
+ "output_type": "stream",
876
+ "text": [
877
+ "Epoch [1/1], Step [52/100], Moving Avg Loss: 24.8366\n"
878
+ ]
879
+ },
880
+ {
881
+ "name": "stderr",
882
+ "output_type": "stream",
883
+ "text": [
884
+ "Epoch 1/1: 53%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 53/100 [00:22<00:19, 2.35it/s]"
885
+ ]
886
+ },
887
+ {
888
+ "name": "stdout",
889
+ "output_type": "stream",
890
+ "text": [
891
+ "Epoch [1/1], Step [53/100], Moving Avg Loss: 23.3091\n"
892
+ ]
893
+ },
894
+ {
895
+ "name": "stderr",
896
+ "output_type": "stream",
897
+ "text": [
898
+ "Epoch 1/1: 54%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 54/100 [00:22<00:19, 2.35it/s]"
899
+ ]
900
+ },
901
+ {
902
+ "name": "stdout",
903
+ "output_type": "stream",
904
+ "text": [
905
+ "Epoch [1/1], Step [54/100], Moving Avg Loss: 22.1764\n"
906
+ ]
907
+ },
908
+ {
909
+ "name": "stderr",
910
+ "output_type": "stream",
911
+ "text": [
912
+ "Epoch 1/1: 55%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 55/100 [00:23<00:19, 2.32it/s]"
913
+ ]
914
+ },
915
+ {
916
+ "name": "stdout",
917
+ "output_type": "stream",
918
+ "text": [
919
+ "Epoch [1/1], Step [55/100], Moving Avg Loss: 22.5117\n"
920
+ ]
921
+ },
922
+ {
923
+ "name": "stderr",
924
+ "output_type": "stream",
925
+ "text": [
926
+ "Epoch 1/1: 56%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 56/100 [00:23<00:18, 2.34it/s]"
927
+ ]
928
+ },
929
+ {
930
+ "name": "stdout",
931
+ "output_type": "stream",
932
+ "text": [
933
+ "Epoch [1/1], Step [56/100], Moving Avg Loss: 23.7047\n"
934
+ ]
935
+ },
936
+ {
937
+ "name": "stderr",
938
+ "output_type": "stream",
939
+ "text": [
940
+ "Epoch 1/1: 57%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 57/100 [00:24<00:18, 2.38it/s]"
941
+ ]
942
+ },
943
+ {
944
+ "name": "stdout",
945
+ "output_type": "stream",
946
+ "text": [
947
+ "Epoch [1/1], Step [57/100], Moving Avg Loss: 24.7985\n"
948
+ ]
949
+ },
950
+ {
951
+ "name": "stderr",
952
+ "output_type": "stream",
953
+ "text": [
954
+ "Epoch 1/1: 58%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 58/100 [00:24<00:17, 2.37it/s]"
955
+ ]
956
+ },
957
+ {
958
+ "name": "stdout",
959
+ "output_type": "stream",
960
+ "text": [
961
+ "Epoch [1/1], Step [58/100], Moving Avg Loss: 25.7531\n"
962
+ ]
963
+ },
964
+ {
965
+ "name": "stderr",
966
+ "output_type": "stream",
967
+ "text": [
968
+ "Epoch 1/1: 59%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 59/100 [00:24<00:17, 2.36it/s]"
969
+ ]
970
+ },
971
+ {
972
+ "name": "stdout",
973
+ "output_type": "stream",
974
+ "text": [
975
+ "Epoch [1/1], Step [59/100], Moving Avg Loss: 24.8322\n"
976
+ ]
977
+ },
978
+ {
979
+ "name": "stderr",
980
+ "output_type": "stream",
981
+ "text": [
982
+ "Epoch 1/1: 60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 60/100 [00:25<00:17, 2.32it/s]"
983
+ ]
984
+ },
985
+ {
986
+ "name": "stdout",
987
+ "output_type": "stream",
988
+ "text": [
989
+ "Epoch [1/1], Step [60/100], Moving Avg Loss: 24.5820\n"
990
+ ]
991
+ },
992
+ {
993
+ "name": "stderr",
994
+ "output_type": "stream",
995
+ "text": [
996
+ "Epoch 1/1: 61%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 61/100 [00:25<00:16, 2.37it/s]"
997
+ ]
998
+ },
999
+ {
1000
+ "name": "stdout",
1001
+ "output_type": "stream",
1002
+ "text": [
1003
+ "Epoch [1/1], Step [61/100], Moving Avg Loss: 29.7474\n"
1004
+ ]
1005
+ },
1006
+ {
1007
+ "name": "stderr",
1008
+ "output_type": "stream",
1009
+ "text": [
1010
+ "Epoch 1/1: 62%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 62/100 [00:26<00:15, 2.39it/s]"
1011
+ ]
1012
+ },
1013
+ {
1014
+ "name": "stdout",
1015
+ "output_type": "stream",
1016
+ "text": [
1017
+ "Epoch [1/1], Step [62/100], Moving Avg Loss: 30.3602\n"
1018
+ ]
1019
+ },
1020
+ {
1021
+ "name": "stderr",
1022
+ "output_type": "stream",
1023
+ "text": [
1024
+ "Epoch 1/1: 63%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 63/100 [00:26<00:15, 2.42it/s]"
1025
+ ]
1026
+ },
1027
+ {
1028
+ "name": "stdout",
1029
+ "output_type": "stream",
1030
+ "text": [
1031
+ "Epoch [1/1], Step [63/100], Moving Avg Loss: 29.7396\n"
1032
+ ]
1033
+ },
1034
+ {
1035
+ "name": "stderr",
1036
+ "output_type": "stream",
1037
+ "text": [
1038
+ "Epoch 1/1: 64%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 64/100 [00:26<00:14, 2.44it/s]"
1039
+ ]
1040
+ },
1041
+ {
1042
+ "name": "stdout",
1043
+ "output_type": "stream",
1044
+ "text": [
1045
+ "Epoch [1/1], Step [64/100], Moving Avg Loss: 27.3900\n"
1046
+ ]
1047
+ },
1048
+ {
1049
+ "name": "stderr",
1050
+ "output_type": "stream",
1051
+ "text": [
1052
+ "Epoch 1/1: 65%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 65/100 [00:27<00:14, 2.44it/s]"
1053
+ ]
1054
+ },
1055
+ {
1056
+ "name": "stdout",
1057
+ "output_type": "stream",
1058
+ "text": [
1059
+ "Epoch [1/1], Step [65/100], Moving Avg Loss: 25.9465\n"
1060
+ ]
1061
+ },
1062
+ {
1063
+ "name": "stderr",
1064
+ "output_type": "stream",
1065
+ "text": [
1066
+ "Epoch 1/1: 66%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 66/100 [00:27<00:13, 2.44it/s]"
1067
+ ]
1068
+ },
1069
+ {
1070
+ "name": "stdout",
1071
+ "output_type": "stream",
1072
+ "text": [
1073
+ "Epoch [1/1], Step [66/100], Moving Avg Loss: 24.0045\n"
1074
+ ]
1075
+ },
1076
+ {
1077
+ "name": "stderr",
1078
+ "output_type": "stream",
1079
+ "text": [
1080
+ "Epoch 1/1: 67%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 67/100 [00:28<00:13, 2.42it/s]"
1081
+ ]
1082
+ },
1083
+ {
1084
+ "name": "stdout",
1085
+ "output_type": "stream",
1086
+ "text": [
1087
+ "Epoch [1/1], Step [67/100], Moving Avg Loss: 22.4123\n"
1088
+ ]
1089
+ },
1090
+ {
1091
+ "name": "stderr",
1092
+ "output_type": "stream",
1093
+ "text": [
1094
+ "Epoch 1/1: 68%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 68/100 [00:28<00:13, 2.39it/s]"
1095
+ ]
1096
+ },
1097
+ {
1098
+ "name": "stdout",
1099
+ "output_type": "stream",
1100
+ "text": [
1101
+ "Epoch [1/1], Step [68/100], Moving Avg Loss: 21.4466\n"
1102
+ ]
1103
+ },
1104
+ {
1105
+ "name": "stderr",
1106
+ "output_type": "stream",
1107
+ "text": [
1108
+ "Epoch 1/1: 69%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 69/100 [00:28<00:12, 2.44it/s]"
1109
+ ]
1110
+ },
1111
+ {
1112
+ "name": "stdout",
1113
+ "output_type": "stream",
1114
+ "text": [
1115
+ "Epoch [1/1], Step [69/100], Moving Avg Loss: 21.5559\n"
1116
+ ]
1117
+ },
1118
+ {
1119
+ "name": "stderr",
1120
+ "output_type": "stream",
1121
+ "text": [
1122
+ "Epoch 1/1: 70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 70/100 [00:29<00:12, 2.47it/s]"
1123
+ ]
1124
+ },
1125
+ {
1126
+ "name": "stdout",
1127
+ "output_type": "stream",
1128
+ "text": [
1129
+ "Epoch [1/1], Step [70/100], Moving Avg Loss: 20.5609\n"
1130
+ ]
1131
+ },
1132
+ {
1133
+ "name": "stderr",
1134
+ "output_type": "stream",
1135
+ "text": [
1136
+ "Epoch 1/1: 71%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 71/100 [00:29<00:11, 2.48it/s]"
1137
+ ]
1138
+ },
1139
+ {
1140
+ "name": "stdout",
1141
+ "output_type": "stream",
1142
+ "text": [
1143
+ "Epoch [1/1], Step [71/100], Moving Avg Loss: 19.9970\n"
1144
+ ]
1145
+ },
1146
+ {
1147
+ "name": "stderr",
1148
+ "output_type": "stream",
1149
+ "text": [
1150
+ "Epoch 1/1: 72%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 72/100 [00:30<00:11, 2.39it/s]"
1151
+ ]
1152
+ },
1153
+ {
1154
+ "name": "stdout",
1155
+ "output_type": "stream",
1156
+ "text": [
1157
+ "Epoch [1/1], Step [72/100], Moving Avg Loss: 20.6782\n"
1158
+ ]
1159
+ },
1160
+ {
1161
+ "name": "stderr",
1162
+ "output_type": "stream",
1163
+ "text": [
1164
+ "Epoch 1/1: 73%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 73/100 [00:30<00:11, 2.41it/s]"
1165
+ ]
1166
+ },
1167
+ {
1168
+ "name": "stdout",
1169
+ "output_type": "stream",
1170
+ "text": [
1171
+ "Epoch [1/1], Step [73/100], Moving Avg Loss: 23.7722\n"
1172
+ ]
1173
+ },
1174
+ {
1175
+ "name": "stderr",
1176
+ "output_type": "stream",
1177
+ "text": [
1178
+ "Epoch 1/1: 74%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 74/100 [00:31<00:10, 2.40it/s]"
1179
+ ]
1180
+ },
1181
+ {
1182
+ "name": "stdout",
1183
+ "output_type": "stream",
1184
+ "text": [
1185
+ "Epoch [1/1], Step [74/100], Moving Avg Loss: 22.6156\n"
1186
+ ]
1187
+ },
1188
+ {
1189
+ "name": "stderr",
1190
+ "output_type": "stream",
1191
+ "text": [
1192
+ "Epoch 1/1: 75%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 75/100 [00:31<00:10, 2.44it/s]"
1193
+ ]
1194
+ },
1195
+ {
1196
+ "name": "stdout",
1197
+ "output_type": "stream",
1198
+ "text": [
1199
+ "Epoch [1/1], Step [75/100], Moving Avg Loss: 27.7204\n"
1200
+ ]
1201
+ },
1202
+ {
1203
+ "name": "stderr",
1204
+ "output_type": "stream",
1205
+ "text": [
1206
+ "Epoch 1/1: 76%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 76/100 [00:31<00:09, 2.42it/s]"
1207
+ ]
1208
+ },
1209
+ {
1210
+ "name": "stdout",
1211
+ "output_type": "stream",
1212
+ "text": [
1213
+ "Epoch [1/1], Step [76/100], Moving Avg Loss: 27.3355\n"
1214
+ ]
1215
+ },
1216
+ {
1217
+ "name": "stderr",
1218
+ "output_type": "stream",
1219
+ "text": [
1220
+ "Epoch 1/1: 77%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 77/100 [00:32<00:09, 2.43it/s]"
1221
+ ]
1222
+ },
1223
+ {
1224
+ "name": "stdout",
1225
+ "output_type": "stream",
1226
+ "text": [
1227
+ "Epoch [1/1], Step [77/100], Moving Avg Loss: 26.1804\n"
1228
+ ]
1229
+ },
1230
+ {
1231
+ "name": "stderr",
1232
+ "output_type": "stream",
1233
+ "text": [
1234
+ "Epoch 1/1: 78%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 78/100 [00:32<00:09, 2.39it/s]"
1235
+ ]
1236
+ },
1237
+ {
1238
+ "name": "stdout",
1239
+ "output_type": "stream",
1240
+ "text": [
1241
+ "Epoch [1/1], Step [78/100], Moving Avg Loss: 25.3216\n"
1242
+ ]
1243
+ },
1244
+ {
1245
+ "name": "stderr",
1246
+ "output_type": "stream",
1247
+ "text": [
1248
+ "Epoch 1/1: 79%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 79/100 [00:33<00:08, 2.43it/s]"
1249
+ ]
1250
+ },
1251
+ {
1252
+ "name": "stdout",
1253
+ "output_type": "stream",
1254
+ "text": [
1255
+ "Epoch [1/1], Step [79/100], Moving Avg Loss: 27.5742\n"
1256
+ ]
1257
+ },
1258
+ {
1259
+ "name": "stderr",
1260
+ "output_type": "stream",
1261
+ "text": [
1262
+ "Epoch 1/1: 80%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 80/100 [00:33<00:08, 2.41it/s]"
1263
+ ]
1264
+ },
1265
+ {
1266
+ "name": "stdout",
1267
+ "output_type": "stream",
1268
+ "text": [
1269
+ "Epoch [1/1], Step [80/100], Moving Avg Loss: 27.5931\n"
1270
+ ]
1271
+ },
1272
+ {
1273
+ "name": "stderr",
1274
+ "output_type": "stream",
1275
+ "text": [
1276
+ "Epoch 1/1: 81%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 81/100 [00:33<00:08, 2.34it/s]"
1277
+ ]
1278
+ },
1279
+ {
1280
+ "name": "stdout",
1281
+ "output_type": "stream",
1282
+ "text": [
1283
+ "Epoch [1/1], Step [81/100], Moving Avg Loss: 25.5491\n"
1284
+ ]
1285
+ },
1286
+ {
1287
+ "name": "stderr",
1288
+ "output_type": "stream",
1289
+ "text": [
1290
+ "Epoch 1/1: 82%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 82/100 [00:34<00:07, 2.33it/s]"
1291
+ ]
1292
+ },
1293
+ {
1294
+ "name": "stdout",
1295
+ "output_type": "stream",
1296
+ "text": [
1297
+ "Epoch [1/1], Step [82/100], Moving Avg Loss: 24.0114\n"
1298
+ ]
1299
+ },
1300
+ {
1301
+ "name": "stderr",
1302
+ "output_type": "stream",
1303
+ "text": [
1304
+ "Epoch 1/1: 83%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 83/100 [00:34<00:07, 2.33it/s]"
1305
+ ]
1306
+ },
1307
+ {
1308
+ "name": "stdout",
1309
+ "output_type": "stream",
1310
+ "text": [
1311
+ "Epoch [1/1], Step [83/100], Moving Avg Loss: 22.3863\n"
1312
+ ]
1313
+ },
1314
+ {
1315
+ "name": "stderr",
1316
+ "output_type": "stream",
1317
+ "text": [
1318
+ "Epoch 1/1: 84%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 84/100 [00:35<00:06, 2.36it/s]"
1319
+ ]
1320
+ },
1321
+ {
1322
+ "name": "stdout",
1323
+ "output_type": "stream",
1324
+ "text": [
1325
+ "Epoch [1/1], Step [84/100], Moving Avg Loss: 23.4298\n"
1326
+ ]
1327
+ },
1328
+ {
1329
+ "name": "stderr",
1330
+ "output_type": "stream",
1331
+ "text": [
1332
+ "Epoch 1/1: 85%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 85/100 [00:35<00:06, 2.34it/s]"
1333
+ ]
1334
+ },
1335
+ {
1336
+ "name": "stdout",
1337
+ "output_type": "stream",
1338
+ "text": [
1339
+ "Epoch [1/1], Step [85/100], Moving Avg Loss: 21.6505\n"
1340
+ ]
1341
+ },
1342
+ {
1343
+ "name": "stderr",
1344
+ "output_type": "stream",
1345
+ "text": [
1346
+ "Epoch 1/1: 86%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 86/100 [00:36<00:06, 2.28it/s]"
1347
+ ]
1348
+ },
1349
+ {
1350
+ "name": "stdout",
1351
+ "output_type": "stream",
1352
+ "text": [
1353
+ "Epoch [1/1], Step [86/100], Moving Avg Loss: 22.6546\n"
1354
+ ]
1355
+ },
1356
+ {
1357
+ "name": "stderr",
1358
+ "output_type": "stream",
1359
+ "text": [
1360
+ "Epoch 1/1: 87%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 87/100 [00:36<00:05, 2.35it/s]"
1361
+ ]
1362
+ },
1363
+ {
1364
+ "name": "stdout",
1365
+ "output_type": "stream",
1366
+ "text": [
1367
+ "Epoch [1/1], Step [87/100], Moving Avg Loss: 21.1156\n"
1368
+ ]
1369
+ },
1370
+ {
1371
+ "name": "stderr",
1372
+ "output_type": "stream",
1373
+ "text": [
1374
+ "Epoch 1/1: 88%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 88/100 [00:36<00:05, 2.38it/s]"
1375
+ ]
1376
+ },
1377
+ {
1378
+ "name": "stdout",
1379
+ "output_type": "stream",
1380
+ "text": [
1381
+ "Epoch [1/1], Step [88/100], Moving Avg Loss: 23.9993\n"
1382
+ ]
1383
+ },
1384
+ {
1385
+ "name": "stderr",
1386
+ "output_type": "stream",
1387
+ "text": [
1388
+ "Epoch 1/1: 89%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰ | 89/100 [00:37<00:04, 2.39it/s]"
1389
+ ]
1390
+ },
1391
+ {
1392
+ "name": "stdout",
1393
+ "output_type": "stream",
1394
+ "text": [
1395
+ "Epoch [1/1], Step [89/100], Moving Avg Loss: 24.9765\n"
1396
+ ]
1397
+ },
1398
+ {
1399
+ "name": "stderr",
1400
+ "output_type": "stream",
1401
+ "text": [
1402
+ "Epoch 1/1: 90%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 90/100 [00:37<00:04, 2.42it/s]"
1403
+ ]
1404
+ },
1405
+ {
1406
+ "name": "stdout",
1407
+ "output_type": "stream",
1408
+ "text": [
1409
+ "Epoch [1/1], Step [90/100], Moving Avg Loss: 28.3430\n"
1410
+ ]
1411
+ },
1412
+ {
1413
+ "name": "stderr",
1414
+ "output_type": "stream",
1415
+ "text": [
1416
+ "Epoch 1/1: 91%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 91/100 [00:38<00:03, 2.39it/s]"
1417
+ ]
1418
+ },
1419
+ {
1420
+ "name": "stdout",
1421
+ "output_type": "stream",
1422
+ "text": [
1423
+ "Epoch [1/1], Step [91/100], Moving Avg Loss: 28.5874\n"
1424
+ ]
1425
+ },
1426
+ {
1427
+ "name": "stderr",
1428
+ "output_type": "stream",
1429
+ "text": [
1430
+ "Epoch 1/1: 92%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 92/100 [00:38<00:03, 2.33it/s]"
1431
+ ]
1432
+ },
1433
+ {
1434
+ "name": "stdout",
1435
+ "output_type": "stream",
1436
+ "text": [
1437
+ "Epoch [1/1], Step [92/100], Moving Avg Loss: 27.0662\n"
1438
+ ]
1439
+ },
1440
+ {
1441
+ "name": "stderr",
1442
+ "output_type": "stream",
1443
+ "text": [
1444
+ "Epoch 1/1: 93%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž| 93/100 [00:39<00:03, 2.32it/s]"
1445
+ ]
1446
+ },
1447
+ {
1448
+ "name": "stdout",
1449
+ "output_type": "stream",
1450
+ "text": [
1451
+ "Epoch [1/1], Step [93/100], Moving Avg Loss: 29.0707\n"
1452
+ ]
1453
+ },
1454
+ {
1455
+ "name": "stderr",
1456
+ "output_type": "stream",
1457
+ "text": [
1458
+ "Epoch 1/1: 94%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 94/100 [00:39<00:02, 2.23it/s]"
1459
+ ]
1460
+ },
1461
+ {
1462
+ "name": "stdout",
1463
+ "output_type": "stream",
1464
+ "text": [
1465
+ "Epoch [1/1], Step [94/100], Moving Avg Loss: 26.7228\n"
1466
+ ]
1467
+ },
1468
+ {
1469
+ "name": "stderr",
1470
+ "output_type": "stream",
1471
+ "text": [
1472
+ "Epoch 1/1: 95%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 95/100 [00:40<00:02, 2.24it/s]"
1473
+ ]
1474
+ },
1475
+ {
1476
+ "name": "stdout",
1477
+ "output_type": "stream",
1478
+ "text": [
1479
+ "Epoch [1/1], Step [95/100], Moving Avg Loss: 24.9785\n"
1480
+ ]
1481
+ },
1482
+ {
1483
+ "name": "stderr",
1484
+ "output_type": "stream",
1485
+ "text": [
1486
+ "Epoch 1/1: 96%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ| 96/100 [00:40<00:01, 2.28it/s]"
1487
+ ]
1488
+ },
1489
+ {
1490
+ "name": "stdout",
1491
+ "output_type": "stream",
1492
+ "text": [
1493
+ "Epoch [1/1], Step [96/100], Moving Avg Loss: 28.0284\n"
1494
+ ]
1495
+ },
1496
+ {
1497
+ "name": "stderr",
1498
+ "output_type": "stream",
1499
+ "text": [
1500
+ "Epoch 1/1: 97%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹| 97/100 [00:40<00:01, 2.32it/s]"
1501
+ ]
1502
+ },
1503
+ {
1504
+ "name": "stdout",
1505
+ "output_type": "stream",
1506
+ "text": [
1507
+ "Epoch [1/1], Step [97/100], Moving Avg Loss: 25.9050\n"
1508
+ ]
1509
+ },
1510
+ {
1511
+ "name": "stderr",
1512
+ "output_type": "stream",
1513
+ "text": [
1514
+ "Epoch 1/1: 98%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š| 98/100 [00:41<00:00, 2.36it/s]"
1515
+ ]
1516
+ },
1517
+ {
1518
+ "name": "stdout",
1519
+ "output_type": "stream",
1520
+ "text": [
1521
+ "Epoch [1/1], Step [98/100], Moving Avg Loss: 26.5735\n"
1522
+ ]
1523
+ },
1524
+ {
1525
+ "name": "stderr",
1526
+ "output_type": "stream",
1527
+ "text": [
1528
+ "Epoch 1/1: 99%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰| 99/100 [00:41<00:00, 2.39it/s]"
1529
+ ]
1530
+ },
1531
+ {
1532
+ "name": "stdout",
1533
+ "output_type": "stream",
1534
+ "text": [
1535
+ "Epoch [1/1], Step [99/100], Moving Avg Loss: 24.9826\n"
1536
+ ]
1537
+ },
1538
+ {
1539
+ "name": "stderr",
1540
+ "output_type": "stream",
1541
+ "text": [
1542
+ "Epoch 1/1: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 100/100 [00:42<00:00, 2.38it/s]"
1543
+ ]
1544
+ },
1545
+ {
1546
+ "name": "stdout",
1547
+ "output_type": "stream",
1548
+ "text": [
1549
+ "Epoch [1/1], Step [100/100], Moving Avg Loss: 23.1838\n",
1550
+ "Average Loss for Epoch 1: 24.4323\n"
1551
+ ]
1552
+ },
1553
+ {
1554
+ "name": "stderr",
1555
+ "output_type": "stream",
1556
+ "text": [
1557
+ "\n"
1558
+ ]
1559
+ }
1560
+ ],
1561
+ "source": [
1562
+ "num_epochs = 1\n",
1563
+ "for epoch in range(num_epochs):\n",
1564
+ " model.train()\n",
1565
+ " running_loss = 0.0\n",
1566
+ " avg_loss = 0.0\n",
1567
+ "\n",
1568
+ " for idx, item in enumerate(tqdm(dataset, desc=f\"Epoch {epoch + 1}/{num_epochs}\")):\n",
1569
+ "\n",
1570
+ " images = [prepare_image_detection(img=item['image'], processor=load_processor())]\n",
1571
+ " images = torch.stack(images, dim=0).to(model.dtype).to(model.device)\n",
1572
+ " optimizer.zero_grad()\n",
1573
+ " outputs = model(pixel_values=images)\n",
1574
+ " predicted_boxes = logits_to_bboxes(outputs.logits, item['image'])\n",
1575
+ " target_boxes = item['bboxes']\n",
1576
+ "\n",
1577
+ " smooth_l1 = smooth_l1_loss(predicted_boxes, target_boxes)\n",
1578
+ " iou = iou_loss(predicted_boxes, target_boxes)\n",
1579
+ " loss = 0.5 * smooth_l1 + 0.5 * iou\n",
1580
+ "\n",
1581
+ " loss.backward()\n",
1582
+ " optimizer.step()\n",
1583
+ " running_loss += loss.item()\n",
1584
+ "\n",
1585
+ " # Update moving average of the loss\n",
1586
+ " avg_loss = 0.9 * avg_loss + 0.1 * loss.item() if idx > 0 else loss.item()\n",
1587
+ "\n",
1588
+ " # Print moving average loss\n",
1589
+ " print(f\"Epoch [{epoch + 1}/{num_epochs}], Step [{idx + 1}/{len(dataset)}], Moving Avg Loss: {avg_loss:.4f}\")\n",
1590
+ "\n",
1591
+ " avg_loss = running_loss / len(dataset)\n",
1592
+ " writer.add_scalar('Training Loss', avg_loss, epoch + 1)\n",
1593
+ " print(f\"Average Loss for Epoch {epoch + 1}: {avg_loss:.4f}\")\n",
1594
+ "\n",
1595
+ " torch.save(model.state_dict(), os.path.join(checkpoint_dir, f\"model_epoch_{epoch + 1}.pth\"))"
1596
+ ]
1597
+ },
1598
+ {
1599
+ "cell_type": "code",
1600
+ "execution_count": null,
1601
+ "metadata": {},
1602
+ "outputs": [],
1603
+ "source": [
1604
+ "checkpoint_path = '/data2/ketan/orc/surya-layout-fine-tune/checkpoints/model_epoch_3.pth' \n",
1605
+ "state_dict = torch.load(checkpoint_path,weights_only=True)\n",
1606
+ "\n",
1607
+ "model.load_state_dict(state_dict)"
1608
+ ]
1609
+ },
1610
+ {
1611
+ "cell_type": "code",
1612
+ "execution_count": null,
1613
+ "metadata": {},
1614
+ "outputs": [],
1615
+ "source": [
1616
+ "model.to('cpu')\n",
1617
+ "model.save_pretrained(\"fine-tuned-surya-model-layout\")"
1618
+ ]
1619
+ }
1620
+ ],
1621
+ "metadata": {
1622
+ "kernelspec": {
1623
+ "display_name": "Python 3",
1624
+ "language": "python",
1625
+ "name": "python3"
1626
+ },
1627
+ "language_info": {
1628
+ "codemirror_mode": {
1629
+ "name": "ipython",
1630
+ "version": 3
1631
+ },
1632
+ "file_extension": ".py",
1633
+ "mimetype": "text/x-python",
1634
+ "name": "python",
1635
+ "nbconvert_exporter": "python",
1636
+ "pygments_lexer": "ipython3",
1637
+ "version": "3.10.14"
1638
+ }
1639
+ },
1640
+ "nbformat": 4,
1641
+ "nbformat_minor": 2
1642
+ }