ketanmore commited on
Commit
5e62e22
·
verified ·
1 Parent(s): 51ed342

Upload layout-fine-tune.ipynb

Browse files
Files changed (1) hide show
  1. layout-fine-tune.ipynb +187 -0
layout-fine-tune.ipynb ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Loading Packages"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import os\n",
17
+ "import torch\n",
18
+ "import torch.nn as nn\n",
19
+ "import torch.optim as optim\n",
20
+ "from torch.utils.data import DataLoader\n",
21
+ "# from transformers import SegformerConfig\n",
22
+ "# from surya.model.detection.segformer import SegformerForRegressionMask\n",
23
+ "from surya.input.processing import prepare_image_detection\n",
24
+ "from surya.model.detection.segformer import load_processor , load_model\n",
25
+ "from datasets import load_dataset\n",
26
+ "from tqdm import tqdm\n",
27
+ "from torch.utils.tensorboard import SummaryWriter\n",
28
+ "import torch.nn.functional as F\n",
29
+ "import numpy as np \n",
30
+ "from surya.layout import parallel_get_regions"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "markdown",
35
+ "metadata": {},
36
+ "source": [
37
+ "# Initializing The Dataset And Model"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
47
+ "dataset = load_dataset(\"vikp/publaynet_bench\", split=\"train[:100]\") # You can choose you own dataset\n",
48
+ "model = load_model(\"vikp/surya_layout2\") "
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "markdown",
53
+ "metadata": {},
54
+ "source": [
55
+ "# Helper Functions, Loss Function And Optimizer"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "\n",
65
+ "optimizer = optim.Adam(model.parameters(), lr=0.00001)\n",
66
+ "log_dir = \"logs\"\n",
67
+ "checkpoint_dir = \"checkpoints\"\n",
68
+ "os.makedirs(log_dir, exist_ok=True)\n",
69
+ "os.makedirs(checkpoint_dir, exist_ok=True)\n",
70
+ "writer = SummaryWriter(log_dir=log_dir)\n",
71
+ "\n",
72
+ "def logits_to_bboxes(logits,image) : # This function is useful for converting the mask into bounding boxes.(The model does not provide bounding boxes.)\n",
73
+ " correct_shape = (300, 300) \n",
74
+ " logits_temp = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False)\n",
75
+ " logits_temp = logits_temp.cpu().detach().numpy().astype(np.float32)\n",
76
+ "\n",
77
+ " heatmap_count = logits_temp.shape[1]\n",
78
+ " heatmaps = [logits_temp[i][k] for i in range(logits_temp.shape[0]) for k in range(heatmap_count)]\n",
79
+ " regions = parallel_get_regions(heatmaps=heatmaps, orig_size=image.size, id2label=model.config.id2label)\n",
80
+ "\n",
81
+ " final_bboxes = []\n",
82
+ " for i in regions.bboxes :\n",
83
+ " final_bboxes.append(i.bbox)\n",
84
+ " return final_bboxes\n",
85
+ "\n",
86
+ "\n",
87
+ "def loss_function(): # This model does not have inbuild loss function, So we have to define it according to our dataset and the Requirements.\n",
88
+ " pass"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "markdown",
93
+ "metadata": {},
94
+ "source": [
95
+ "# Fine-Tuning Process"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "num_epochs = 5\n",
105
+ "for epoch in range(num_epochs):\n",
106
+ " model.train()\n",
107
+ " running_loss = 0.0\n",
108
+ " avg_loss = 0.0\n",
109
+ "\n",
110
+ " for idx, item in enumerate(tqdm(dataset, desc=f\"Epoch {epoch + 1}/{num_epochs}\")):\n",
111
+ "\n",
112
+ " images = [prepare_image_detection(img=item['image'], processor=load_processor())]\n",
113
+ " images = torch.stack(images, dim=0).to(model.dtype).to(model.device)\n",
114
+ " \n",
115
+ " optimizer.zero_grad()\n",
116
+ " outputs = model(pixel_values=images)\n",
117
+ "\n",
118
+ " predicted_boxes = logits_to_bboxes(outputs.logits, item['image'])\n",
119
+ " target_boxes = item['bboxes']\n",
120
+ "\n",
121
+ " loss = loss_function(predicted_boxes,target_boxes)\n",
122
+ "\n",
123
+ " loss.backward()\n",
124
+ " optimizer.step()\n",
125
+ " running_loss += loss.item()\n",
126
+ "\n",
127
+ " avg_loss = 0.9 * avg_loss + 0.1 * loss.item() if idx > 0 else loss.item()\n",
128
+ "\n",
129
+ " avg_loss = running_loss / len(dataset)\n",
130
+ " writer.add_scalar('Training Loss', avg_loss, epoch + 1)\n",
131
+ " print(f\"Average Loss for Epoch {epoch + 1}: {avg_loss:.4f}\")\n",
132
+ "\n",
133
+ " torch.save(model.state_dict(), os.path.join(checkpoint_dir, f\"model_epoch_{epoch + 1}.pth\"))"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "markdown",
138
+ "metadata": {},
139
+ "source": [
140
+ "# Loading The Checkpoint "
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "checkpoint_path = 'checkpoints/model_epoch_350.pth' \n",
150
+ "state_dict = torch.load(checkpoint_path,weights_only=True)\n",
151
+ "\n",
152
+ "model.load_state_dict(state_dict)"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": null,
158
+ "metadata": {},
159
+ "outputs": [],
160
+ "source": [
161
+ "model.to('cpu')\n",
162
+ "model.save_pretrained(\"fine-tuned-surya-model-layout\")"
163
+ ]
164
+ }
165
+ ],
166
+ "metadata": {
167
+ "kernelspec": {
168
+ "display_name": "Python 3",
169
+ "language": "python",
170
+ "name": "python3"
171
+ },
172
+ "language_info": {
173
+ "codemirror_mode": {
174
+ "name": "ipython",
175
+ "version": 3
176
+ },
177
+ "file_extension": ".py",
178
+ "mimetype": "text/x-python",
179
+ "name": "python",
180
+ "nbconvert_exporter": "python",
181
+ "pygments_lexer": "ipython3",
182
+ "version": "3.10.14"
183
+ }
184
+ },
185
+ "nbformat": 4,
186
+ "nbformat_minor": 2
187
+ }