ketanmore committed on
Commit 21fbd41 · verified · 1 Parent(s): 9d2d85e

Upload surya-layout-fien-tuneCrossEntropyLoss.ipynb

surya-layout-fien-tuneCrossEntropyLoss.ipynb ADDED
@@ -0,0 +1,328 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Loading Packages"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
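+ "# Route Hugging Face downloads to a shared cache directory (machine-specific path).\n",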
+ "os.environ['HF_HOME'] = '/data2/ketan/orc/HF_Cache'\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "from torch.utils.data import DataLoader\n",
+ "from surya.input.processing import prepare_image_detection\n",
+ "from surya.model.detection.segformer import load_processor, load_model\n",
+ "from datasets import load_dataset\n",
+ "from tqdm import tqdm\n",
+ "from torch.utils.tensorboard import SummaryWriter\n",
+ "import torch.nn.functional as F\n",
+ "import numpy as np\n",
+ "from surya.layout import parallel_get_regions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Initializing The Dataset And Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
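+ "# cuda:3 pins training to one GPU on a multi-GPU host; adjust for your machine.\n",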
+ "device = torch.device(\"cuda:3\" if torch.cuda.is_available() else \"cpu\")\n",
+ "dataset = load_dataset(\"vikp/publaynet_bench\", split=\"train[:100]\")  # You can choose your own dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loaded detection model vikp/surya_layout2 on device cuda with dtype torch.float16\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "'.'"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model = load_model(\"vikp/surya_layout2\").to(device)\n",
+ "model.to(torch.float32)  # train in float32; the model loads in float16 by default\n",
+ "\".\"  # dummy last expression to keep the cell output short"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Helper Functions, Loss Function And Optimizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
+ "log_dir = \"logs\"\n",
+ "checkpoint_dir = \"checkpoints\"\n",
+ "os.makedirs(log_dir, exist_ok=True)\n",
+ "os.makedirs(checkpoint_dir, exist_ok=True)\n",
+ "writer = SummaryWriter(log_dir=log_dir)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def logits_to_mask(logits, labels, bboxes, original_size=(1200, 1200)):\n",
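+ "    # Rasterize each ground-truth box into a per-class binary target mask,\n",
+ "    # scaling coordinates from the original image size down to the logit grid.\n",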
+ "    batch_size, num_classes, height, width = logits.shape\n",
+ "    mask = torch.zeros((batch_size, num_classes, height, width), dtype=torch.float32).to(logits.device)\n",
+ "\n",
+ "    for bbox, class_id in zip(bboxes, labels):\n",
+ "        x_min, y_min, x_max, y_max = bbox\n",
+ "\n",
+ "        # Scale box coordinates into the logit grid.\n",
+ "        x_min = int(x_min * width / original_size[0])\n",
+ "        y_min = int(y_min * height / original_size[1])\n",
+ "        x_max = int(x_max * width / original_size[0])\n",
+ "        y_max = int(y_max * height / original_size[1])\n",
+ "\n",
+ "        # Clamp to the grid bounds.\n",
+ "        x_min = max(0, min(x_min, width - 1))\n",
+ "        y_min = max(0, min(y_min, height - 1))\n",
+ "        x_max = max(0, min(x_max, width - 1))\n",
+ "        y_max = max(0, min(y_max, height - 1))\n",
+ "\n",
+ "        if x_min < x_max and y_min < y_max:\n",
+ "            mask[:, class_id, y_min:y_max, x_min:x_max] = 1.0\n",
+ "        else:\n",
+ "            print(f\"Invalid bounding box after adjustment: {bbox}, adjusted to: {(x_min, y_min, x_max, y_max)}\")\n",
+ "\n",
+ "    return mask\n",
+ "\n",
+ "\n",
+ "def loss_function(logits, mask):\n",
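+ "    # With a float target of the same shape as the logits, CrossEntropyLoss\n",
+ "    # treats the mask as per-pixel class probabilities (PyTorch >= 1.10).\n",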
+ "    loss_fn = torch.nn.CrossEntropyLoss()\n",
+ "    loss = loss_fn(logits, mask)\n",
+ "    return loss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Fine-Tuning Process"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1/5: 100%|██████████| 100/100 [01:46<00:00, 1.07s/it]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Average Loss for Epoch 1: 0.3322\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Epoch 2/5: 100%|██████████| 100/100 [01:51<00:00, 1.11s/it]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Average Loss for Epoch 2: 0.3311\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Epoch 3/5: 100%|██████████| 100/100 [01:51<00:00, 1.12s/it]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Average Loss for Epoch 3: 0.3197\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Epoch 4/5: 100%|██████████| 100/100 [01:42<00:00, 1.03s/it]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Average Loss for Epoch 4: 0.3106\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Epoch 5/5: 100%|██████████| 100/100 [01:46<00:00, 1.06s/it]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Average Loss for Epoch 5: 0.3160\n"
+ ]
+ }
+ ],
+ "source": [
+ "num_epochs = 5\n",
+ "processor = load_processor()  # load once instead of on every training step\n",
+ "\n",
+ "for param in model.parameters():\n",
+ "    param.requires_grad = True\n",
+ "\n",
+ "model.train()\n",
+ "with torch.autograd.set_detect_anomaly(True):\n",
+ "\n",
+ "    for epoch in range(num_epochs):\n",
+ "        avg_loss = 0.0\n",
+ "\n",
+ "        for idx, item in enumerate(tqdm(dataset, desc=f\"Epoch {epoch + 1}/{num_epochs}\")):\n",
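+ "            # Preprocess the page image into the model's expected input tensor.\n",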
+ "            images = [prepare_image_detection(img=item['image'], processor=processor)]\n",
+ "            images = torch.stack(images, dim=0).to(model.dtype).to(model.device)\n",
+ "\n",
+ "            optimizer.zero_grad()\n",
+ "            outputs = model(pixel_values=images)\n",
+ "            logits = outputs.logits\n",
+ "\n",
+ "            bboxes = item['bboxes']\n",
+ "            labels = item['category_ids']\n",
+ "            mask = logits_to_mask(logits, labels, bboxes)\n",
+ "\n",
+ "            logits = logits.to(torch.float32)\n",
+ "            mask = mask.to(torch.float32)\n",
+ "            loss = loss_function(logits, mask)\n",
+ "\n",
+ "            loss.backward()\n",
+ "            optimizer.step()\n",
+ "\n",
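+ "            # Exponential moving average of the loss for smoother logging.\n",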
+ "            avg_loss = 0.9 * avg_loss + 0.1 * loss.item() if idx > 0 else loss.item()\n",
+ "\n",
+ "        writer.add_scalar('Training Loss', avg_loss, epoch + 1)\n",
+ "        print(f\"Average Loss for Epoch {epoch + 1}: {avg_loss:.4f}\")\n",
+ "\n",
+ "        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f\"model_epoch_{epoch + 1}.pth\"))\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Loading The Checkpoint"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<All keys matched successfully>"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
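+ "# weights_only=True limits torch.load to tensor data, avoiding arbitrary-code unpickling.\n",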
+ "checkpoint_path = '/data2/ketan/orc/surya-layout-fine-tune/checkpoints/model_epoch_5.pth'\n",
+ "state_dict = torch.load(checkpoint_path, weights_only=True)\n",
+ "\n",
+ "model.load_state_dict(state_dict)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
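+ "# Move weights to CPU before exporting so the saved artifact is device-agnostic.\n",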
+ "model.to('cpu')\n",
+ "model.save_pretrained(\"fine-tuned-surya-model-layout\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }