File size: 69,208 Bytes
0719c14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "import torchmetrics\n",
    "\n",
    "from datasets import load_dataset, load_metric\n",
    "\n",
    "from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation\n",
    "\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader, Dataset, random_split\n",
    "\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SemanticSegmentationDataset(Dataset):\n",
    "    \"\"\"Image segmentation datasets.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, \n",
    "        dataset: torch.utils.data.dataset.Subset, \n",
    "        feature_extractor = SegformerFeatureExtractor(reduce_labels=True),\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Initialize the dataset with the given feature extractor and split.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        dataset : torch.utils.data.dataset.Subset\n",
    "            The dataset to use.\n",
    "        feature_extractor : FeatureExtractor, optional\n",
    "            The feature extractor to use. The default is SegformerFeatureExtractor(reduce_labels=True).\n",
    "        \"\"\"\n",
    "        self.dataset = dataset\n",
    "        self.feature_extractor = feature_extractor\n",
    "        self.length = len(self.dataset)\n",
    "        print(f\"Loaded {self.length} samples.\")\n",
    "\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of samples in the dataset.\"\"\"\n",
    "        return self.length\n",
    "\n",
    "\n",
    "    def __getitem__(self, index: int):\n",
    "        \"\"\"Get the sample at the given index.\"\"\"\n",
    "        image = self.dataset[index][\"pixel_values\"]\n",
    "        label = self.dataset[index][\"label\"]\n",
    "\n",
    "        encoded_inputs = self.feature_extractor(image, label, return_tensors=\"pt\")\n",
    "\n",
    "        for k, v in encoded_inputs.items():\n",
    "            encoded_inputs[k].squeeze_() # remove batch dimension\n",
    "\n",
    "        return encoded_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32\n",
    "HUB_DIR = \"segments/sidewalk-semantic\"\n",
    "EPOCHS = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration segments--sidewalk-semantic-2-f89d0845be9cadc9\n",
      "Reusing dataset parquet (/home/chainyo/.cache/huggingface/datasets/segments___parquet/segments--sidewalk-semantic-2-f89d0845be9cadc9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 800 samples.\n",
      "Loaded 200 samples.\n"
     ]
    }
   ],
   "source": [
    "dataset = load_dataset(HUB_DIR, split=\"train\")\n",
    "\n",
    "train_dataset, val_dataset = random_split(dataset, [int(0.8 * len(dataset)), len(dataset) - int(0.8 * len(dataset))])\n",
    "train_dataset = SemanticSegmentationDataset(train_dataset)\n",
    "val_dataset = SemanticSegmentationDataset(val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pixel_values torch.Size([32, 3, 512, 512])\n",
      "labels torch.Size([32, 512, 512])\n"
     ]
    }
   ],
   "source": [
    "batch = next(iter(train_dataloader))\n",
    "\n",
    "for k, v in batch.items():\n",
    "    print(k, v.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# class SidewalkSegmentationModel(pl.LightningModule):\n",
    "#     def __init__(self, num_classes: int, learning_rate: float = 6e-5):\n",
    "#         super().__init__()\n",
    "#         self.model = SegformerForSemanticSegmentation.from_pretrained(\n",
    "#             \"nvidia/mit-b0\", num_labels=num_classes, id2label=id2label, label2id=label2id,\n",
    "#         )\n",
    "#         self.learning_rate = learning_rate\n",
    "#         self.metric = load_metric(\"mean_iou\")\n",
    "\n",
    "    \n",
    "#     def forward(self, *args, **kwargs):\n",
    "#         return self.model(*args, **kwargs)\n",
    "\n",
    "    \n",
    "#     def training_step(self, batch, batch_idx):\n",
    "#         pixel_values = batch[\"pixel_values\"]\n",
    "#         labels = batch[\"labels\"]\n",
    "\n",
    "#         outputs = self(pixel_values=pixel_values, labels=labels)\n",
    "#         loss, logits = outputs.loss, outputs.logits\n",
    "\n",
    "    \n",
    "#     def configure_optimizers(self) -> torch.optim.AdamW:\n",
    "#         \"\"\"\n",
    "#         Configure the optimizer.\n",
    "#         Returns\n",
    "#         -------\n",
    "#         torch.optim.AdamW\n",
    "#             Optimizer for the model\n",
    "#         \"\"\"\n",
    "#         return torch.optim.AdamW(model.parameters(), lr=self.learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0: 'unlabeled', 1: 'flat-road', 2: 'flat-sidewalk', 3: 'flat-crosswalk', 4: 'flat-cyclinglane', 5: 'flat-parkingdriveway', 6: 'flat-railtrack', 7: 'flat-curb', 8: 'human-person', 9: 'human-rider', 10: 'vehicle-car', 11: 'vehicle-truck', 12: 'vehicle-bus', 13: 'vehicle-tramtrain', 14: 'vehicle-motorcycle', 15: 'vehicle-bicycle', 16: 'vehicle-caravan', 17: 'vehicle-cartrailer', 18: 'construction-building', 19: 'construction-door', 20: 'construction-wall', 21: 'construction-fenceguardrail', 22: 'construction-bridge', 23: 'construction-tunnel', 24: 'construction-stairs', 25: 'object-pole', 26: 'object-trafficsign', 27: 'object-trafficlight', 28: 'nature-vegetation', 29: 'nature-terrain', 30: 'sky', 31: 'void-ground', 32: 'void-dynamic', 33: 'void-static', 34: 'void-unclear'}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.weight', 'classifier.bias']\n",
      "- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.classifier.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.linear_c.1.proj.weight', 'decode_head.classifier.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.batch_norm.bias', 'decode_head.linear_c.2.proj.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "id2label_file = json.load(open(\"id2label.json\", \"r\"))\n",
    "id2label = {int(k): v for k, v in id2label_file.items()}\n",
    "print(id2label)\n",
    "label2id = {v: k for k, v in id2label_file.items()}\n",
    "num_labels = len(id2label)\n",
    "\n",
    "model = SegformerForSemanticSegmentation.from_pretrained(\n",
    "    \"nvidia/mit-b0\", num_labels=num_labels, id2label=id2label, label2id=label2id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = load_metric(\"mean_iou\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "091f8e1587f64625a4bbbf04f13a840e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/chainyo/.cache/huggingface/modules/datasets_modules/metrics/mean_iou/d4add40cf977cdd73590b5873fa830f3f13adb678f6777a29fb07b7c81d14342/mean_iou.py:259: RuntimeWarning: invalid value encountered in true_divide\n",
      "  acc = total_area_intersect / total_area_label\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/200 Batch 0/25 Loss 3.5735 Metrics {'mean_iou': 0.005477220659286406, 'mean_accuracy': 0.03572801337234697, 'overall_accuracy': 0.0286087999162653, 'per_category_iou': array([2.25093889e-02, 1.39043249e-03, 7.61652063e-03, 1.08658706e-02,\n",
      "       1.00475237e-02, 0.00000000e+00, 2.03193907e-04, 1.38439962e-03,\n",
      "       3.06388194e-05, 2.21495741e-03, 2.81951515e-05, 0.00000000e+00,\n",
      "       0.00000000e+00, 1.13529929e-04, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 7.05787860e-03, 8.48350742e-03, 2.38476991e-03,\n",
      "       1.76276224e-03, 2.17775404e-03, 0.00000000e+00, 1.04358386e-03,\n",
      "       4.58714414e-04, 6.41413308e-05, 1.11431787e-04, 7.99247870e-02,\n",
      "       8.28409255e-03, 2.07198485e-03, 6.12199013e-04, 6.53992687e-03,\n",
      "       1.42611461e-02, 5.93919660e-05, 0.00000000e+00]), 'per_category_accuracy': array([2.49789663e-02, 1.40417891e-03, 5.70403118e-02, 1.19771803e-02,\n",
      "       1.35858556e-02,            nan, 2.43287079e-04, 1.42133234e-02,\n",
      "       9.06344411e-03, 2.86141687e-03, 1.51057402e-03,            nan,\n",
      "                  nan, 1.54786781e-04, 0.00000000e+00, 0.00000000e+00,\n",
      "                  nan, 7.59912501e-03, 1.08030661e-01, 7.46983779e-03,\n",
      "       3.63612647e-03, 2.03376823e-01,            nan, 1.34638923e-02,\n",
      "       4.23971037e-03, 2.44969379e-03, 2.61780105e-02, 1.44916888e-01,\n",
      "       1.09119869e-02, 2.12033947e-03, 5.77414410e-02, 8.69919527e-02,\n",
      "       1.99881736e-01, 2.00708383e-02,            nan])}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8b14f2d3fe2044bf8940697dfce1f2d9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/200 Batch 0/25 Loss 2.3349 Metrics {'mean_iou': 0.08249196196840773, 'mean_accuracy': 0.12242777101928329, 'overall_accuracy': 0.5340892205419953, 'per_category_iou': array([3.12479410e-01, 6.01364321e-01, 1.11562075e-02, 2.09116315e-02,\n",
      "       1.95069293e-02, 0.00000000e+00, 1.09903909e-04, 1.13759900e-03,\n",
      "       1.31593453e-03, 4.16974851e-01, 9.43479560e-04, 0.00000000e+00,\n",
      "       0.00000000e+00, 2.71726824e-05, 7.02479634e-05, 1.71339744e-06,\n",
      "       1.01240638e-04, 3.64420274e-01, 5.22553547e-03, 1.01179313e-03,\n",
      "       5.40034112e-03, 6.14980455e-04, 0.00000000e+00, 1.37384839e-03,\n",
      "       1.58339239e-03, 2.28333709e-05, 4.11671538e-05, 5.37431493e-01,\n",
      "       6.31855647e-02, 4.88071958e-01, 5.19480641e-03, 4.67615947e-03,\n",
      "       2.28171525e-02, 4.67269054e-05, 0.00000000e+00]), 'per_category_accuracy': array([4.06453404e-01, 7.78244778e-01, 2.33772733e-02, 2.33628638e-02,\n",
      "       2.53283555e-02,            nan, 1.11846810e-04, 3.42729695e-03,\n",
      "       9.61392116e-03, 6.59758916e-01, 1.07169352e-02, 0.00000000e+00,\n",
      "       0.00000000e+00, 1.71831147e-04, 1.20609014e-04, 1.32753642e-04,\n",
      "       7.55138223e-03, 5.08304753e-01, 2.45514876e-02, 1.15429656e-03,\n",
      "       6.00626887e-03, 1.89040602e-02, 0.00000000e+00, 3.03657284e-03,\n",
      "       3.07788162e-03, 1.93985207e-04, 1.35743374e-03, 8.37114885e-01,\n",
      "       7.86780250e-02, 5.18267366e-01, 1.14241257e-02, 1.51054789e-02,\n",
      "       6.31875993e-02, 1.38005816e-03,            nan])}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1bcd35da7d4a43a0958776974bfa7918",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/chainyo/.cache/huggingface/modules/datasets_modules/metrics/mean_iou/d4add40cf977cdd73590b5873fa830f3f13adb678f6777a29fb07b7c81d14342/mean_iou.py:258: RuntimeWarning: invalid value encountered in true_divide\n",
      "  iou = total_area_intersect / total_area_union\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/200 Batch 0/25 Loss 1.9441 Metrics {'mean_iou': 0.1115145961827685, 'mean_accuracy': 0.1571059701593332, 'overall_accuracy': 0.6731429879739427, 'per_category_iou': array([4.45592658e-01, 7.27905738e-01, 3.53301085e-05, 1.04372074e-02,\n",
      "       7.95093029e-03,            nan, 1.38316668e-06, 0.00000000e+00,\n",
      "       0.00000000e+00, 4.80050947e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 4.95525330e-01, 0.00000000e+00, 1.11859236e-05,\n",
      "       8.15231791e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       2.30751037e-06, 0.00000000e+00, 0.00000000e+00, 6.24028521e-01,\n",
      "       1.05303497e-01, 7.82230424e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       9.05398708e-04, 0.00000000e+00,            nan]), 'per_category_accuracy': array([7.33079709e-01, 9.03717795e-01, 3.57307878e-05, 1.05455168e-02,\n",
      "       8.21302511e-03,            nan, 1.38319984e-06, 0.00000000e+00,\n",
      "       0.00000000e+00, 8.08957376e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 8.13886465e-01, 0.00000000e+00, 1.11898068e-05,\n",
      "       8.19780315e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       2.30767678e-06, 0.00000000e+00, 0.00000000e+00, 9.49630788e-01,\n",
      "       1.14320143e-01, 8.41172201e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       9.22565695e-04, 0.00000000e+00,            nan])}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f3ef333a6137481180b4ec59efd1673e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/200 Batch 0/25 Loss 1.7086 Metrics {'mean_iou': 0.13337084618347653, 'mean_accuracy': 0.17885172041247846, 'overall_accuracy': 0.7166871222915292, 'per_category_iou': array([0.49122674, 0.77341676, 0.        , 0.1407711 , 0.00553198,\n",
      "              nan, 0.        , 0.        , 0.        , 0.55152752,\n",
      "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
      "       0.        , 0.        , 0.51584332, 0.        , 0.        ,\n",
      "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
      "       0.        , 0.        , 0.68907817, 0.42107346, 0.81276888,\n",
      "       0.        , 0.        , 0.        , 0.        ,        nan]), 'per_category_accuracy': array([0.82412545, 0.91420412, 0.        , 0.14242155, 0.00561311,\n",
      "              nan, 0.        , 0.        , 0.        , 0.84759725,\n",
      "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
      "       0.        , 0.        , 0.85361568, 0.        , 0.        ,\n",
      "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
      "       0.        , 0.        , 0.93567573, 0.47723124, 0.90162263,\n",
      "       0.        , 0.        , 0.        , 0.        ,        nan])}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "384340072f794e70bcbbc74dacc50c21",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4/200 Batch 0/25 Loss 1.4155 Metrics {'mean_iou': 0.15564168770290954, 'mean_accuracy': 0.20019284726893685, 'overall_accuracy': 0.7574203051314757, 'per_category_iou': array([5.71505637e-01, 8.01399129e-01, 0.00000000e+00, 4.92403441e-01,\n",
      "       5.41554099e-03,            nan, 2.31577543e-07, 0.00000000e+00,\n",
      "       0.00000000e+00, 5.86567051e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 5.51951880e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.24133440e-01,\n",
      "       5.67064833e-01, 8.35733013e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       1.49737193e-06, 0.00000000e+00,            nan]), 'per_category_accuracy': array([8.47391615e-01, 9.31621860e-01, 0.00000000e+00, 5.38104498e-01,\n",
      "       5.46801709e-03,            nan, 2.31577972e-07, 0.00000000e+00,\n",
      "       0.00000000e+00, 8.70377721e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 8.79001612e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.26588036e-01,\n",
      "       6.90579673e-01, 9.17229200e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       1.49737361e-06, 0.00000000e+00,            nan])}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "428849c8d6ce414887a0661d3d4625da",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5/200 Batch 0/25 Loss 1.2700 Metrics {'mean_iou': 0.16264865782481297, 'mean_accuracy': 0.20757696789528982, 'overall_accuracy': 0.7689239961674764, 'per_category_iou': array([0.58414589, 0.81498541, 0.        , 0.57467831, 0.00689438,\n",
      "              nan, 0.        , 0.        , 0.        , 0.61111453,\n",
      "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
      "       0.        , 0.        , 0.55374014, 0.        , 0.        ,\n",
      "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
      "       0.        , 0.        , 0.74374917, 0.6398338 , 0.83826408,\n",
      "       0.        , 0.        , 0.        , 0.        ,        nan]), 'per_category_accuracy': array([0.86273737, 0.93244115, 0.        , 0.66435561, 0.00696605,\n",
      "              nan, 0.        , 0.        , 0.        , 0.87784737,\n",
      "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
      "       0.        , 0.        , 0.87298997, 0.        , 0.        ,\n",
      "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
      "       0.        , 0.        , 0.92712481, 0.78796366, 0.91761394,\n",
      "       0.        , 0.        , 0.        , 0.        ,        nan])}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1577eed162ec4162a69493365c950329",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6/200 Batch 0/25 Loss 1.1401 Metrics {'mean_iou': 0.16748948766093116, 'mean_accuracy': 0.21181445128118748, 'overall_accuracy': 0.7795023598828317, 'per_category_iou': array([6.09680179e-01, 8.31918538e-01, 0.00000000e+00, 6.34236889e-01,\n",
      "       1.24257235e-02,            nan, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 6.35402026e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 5.64656796e-01, 0.00000000e+00, 3.65611521e-06,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.53253595e-01,\n",
      "       6.35493134e-01, 8.50082556e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00,            nan]), 'per_category_accuracy': array([8.81516256e-01, 9.41665624e-01, 0.00000000e+00, 7.39806944e-01,\n",
      "       1.26589939e-02,            nan, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 8.90136686e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 8.94297220e-01, 0.00000000e+00, 3.65611521e-06,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.26219754e-01,\n",
      "       7.75692895e-01, 9.27878865e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00,            nan])}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "169ae321c5ac4a4aab54a816c05035f2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7/200 Batch 0/25 Loss 1.0940 Metrics {'mean_iou': 0.17137856945919164, 'mean_accuracy': 0.21495584433690398, 'overall_accuracy': 0.7881396456781556, 'per_category_iou': array([6.32120417e-01, 8.39814384e-01, 0.00000000e+00, 6.53276797e-01,\n",
      "       2.36145329e-02,            nan, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 6.44513271e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 5.70917228e-01, 0.00000000e+00, 5.49670921e-05,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.63117039e-01,\n",
      "       6.70860701e-01, 8.57197972e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       5.48363063e-06, 0.00000000e+00,            nan]), 'per_category_accuracy': array([9.08405231e-01, 9.45289137e-01, 0.00000000e+00, 7.43501061e-01,\n",
      "       2.43273947e-02,            nan, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 8.96271845e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 8.93162955e-01, 0.00000000e+00, 5.49677426e-05,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.25997998e-01,\n",
      "       8.26946192e-01, 9.29580599e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       5.48370681e-06, 0.00000000e+00,            nan])}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "544f85de17a04a3ca74001e2fbbc6bc9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8/200 Batch 0/25 Loss 0.9463 Metrics {'mean_iou': 0.17478409568512296, 'mean_accuracy': 0.21773480038212398, 'overall_accuracy': 0.7927706422623841, 'per_category_iou': array([6.29223165e-01, 8.42392089e-01, 3.52355169e-04, 6.80396607e-01,\n",
      "       5.21568283e-02,            nan, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 6.55000862e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 5.75092138e-01, 0.00000000e+00, 7.12203670e-04,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       5.25684195e-05, 0.00000000e+00, 0.00000000e+00, 7.79574322e-01,\n",
      "       6.87965551e-01, 8.64953847e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       2.62043422e-06, 0.00000000e+00,            nan]), 'per_category_accuracy': array([9.02830276e-01, 9.46588697e-01, 3.52355621e-04, 7.80701400e-01,\n",
      "       5.56946042e-02,            nan, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 8.98895404e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 8.98825852e-01, 0.00000000e+00, 7.12358991e-04,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       5.25687464e-05, 0.00000000e+00, 0.00000000e+00, 9.33269035e-01,\n",
      "       8.30703600e-01, 9.36619641e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       2.62043913e-06, 0.00000000e+00,            nan])}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a88b7f1699074168b624caaad5466071",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9/200 Batch 0/25 Loss 1.1287 Metrics {'mean_iou': 0.17933414157617142, 'mean_accuracy': 0.22211677037195932, 'overall_accuracy': 0.7957460527936768, 'per_category_iou': array([6.42468748e-01, 8.51322031e-01, 4.38690416e-02, 6.83716408e-01,\n",
      "       1.05816211e-01,            nan, 3.89710724e-04, 0.00000000e+00,\n",
      "       0.00000000e+00, 6.68283305e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 5.86580529e-01, 0.00000000e+00, 1.00559462e-03,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       1.15812990e-04, 0.00000000e+00, 0.00000000e+00, 7.77224737e-01,\n",
      "       6.87868711e-01, 8.69336423e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       2.94092022e-05, 0.00000000e+00,            nan]), 'per_category_accuracy': array([9.05123309e-01, 9.50701912e-01, 4.38913906e-02, 7.96283141e-01,\n",
      "       1.19598233e-01,            nan, 3.89758329e-04, 0.00000000e+00,\n",
      "       0.00000000e+00, 8.97971821e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       0.00000000e+00, 9.09473869e-01, 0.00000000e+00, 1.00721614e-03,\n",
      "       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,\n",
      "       1.15831636e-04, 0.00000000e+00, 0.00000000e+00, 9.30939519e-01,\n",
      "       8.38613121e-01, 9.35714875e-01, 0.00000000e+00, 0.00000000e+00,\n",
      "       2.94260342e-05, 0.00000000e+00,            nan])}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "569b4a3ebdd24be7811709fc8d588e30",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 10'\u001b[0m in \u001b[0;36m<cell line: 7>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000022?line=24'>25</a>\u001b[0m     metric\u001b[39m.\u001b[39madd_batch(predictions\u001b[39m=\u001b[39mpredicted\u001b[39m.\u001b[39mdetach()\u001b[39m.\u001b[39mcpu()\u001b[39m.\u001b[39mnumpy(), references\u001b[39m=\u001b[39mlabels\u001b[39m.\u001b[39mdetach()\u001b[39m.\u001b[39mcpu()\u001b[39m.\u001b[39mnumpy())\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000022?line=26'>27</a>\u001b[0m \u001b[39mif\u001b[39;00m index \u001b[39m%\u001b[39m \u001b[39m100\u001b[39m \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000022?line=27'>28</a>\u001b[0m     metrics \u001b[39m=\u001b[39m metric\u001b[39m.\u001b[39;49mcompute(num_labels\u001b[39m=\u001b[39;49mnum_labels, ignore_index\u001b[39m=\u001b[39;49m\u001b[39m255\u001b[39;49m, reduce_labels\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m)\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000022?line=28'>29</a>\u001b[0m     \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mEpoch \u001b[39m\u001b[39m{\u001b[39;00mepoch\u001b[39m}\u001b[39;00m\u001b[39m/\u001b[39m\u001b[39m{\u001b[39;00mEPOCHS\u001b[39m}\u001b[39;00m\u001b[39m Batch \u001b[39m\u001b[39m{\u001b[39;00mindex\u001b[39m}\u001b[39;00m\u001b[39m/\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mlen\u001b[39m(train_dataloader)\u001b[39m}\u001b[39;00m\u001b[39m Loss \u001b[39m\u001b[39m{\u001b[39;00mloss\u001b[39m.\u001b[39mitem()\u001b[39m:\u001b[39;00m\u001b[39m.4f\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m Metrics 
\u001b[39m\u001b[39m{\u001b[39;00mmetrics\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py:428\u001b[0m, in \u001b[0;36mMetric.compute\u001b[0;34m(self, predictions, references, **kwargs)\u001b[0m\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=424'>425</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprocess_id \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=425'>426</a>\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdata\u001b[39m.\u001b[39mset_format(\u001b[39mtype\u001b[39m\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39minfo\u001b[39m.\u001b[39mformat)\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=427'>428</a>\u001b[0m     inputs \u001b[39m=\u001b[39m {input_name: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdata[input_name] \u001b[39mfor\u001b[39;00m input_name \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfeatures}\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=428'>429</a>\u001b[0m     \u001b[39mwith\u001b[39;00m temp_seed(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mseed):\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=429'>430</a>\u001b[0m         output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compute(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39minputs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mcompute_kwargs)\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py:428\u001b[0m, in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=424'>425</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprocess_id \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=425'>426</a>\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdata\u001b[39m.\u001b[39mset_format(\u001b[39mtype\u001b[39m\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39minfo\u001b[39m.\u001b[39mformat)\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=427'>428</a>\u001b[0m     inputs \u001b[39m=\u001b[39m {input_name: \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdata[input_name] \u001b[39mfor\u001b[39;00m input_name \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfeatures}\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=428'>429</a>\u001b[0m     \u001b[39mwith\u001b[39;00m temp_seed(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mseed):\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/metric.py?line=429'>430</a>\u001b[0m         output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_compute(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39minputs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mcompute_kwargs)\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py:2124\u001b[0m, in \u001b[0;36mDataset.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m   <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2121'>2122</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__getitem__\u001b[39m(\u001b[39mself\u001b[39m, key):  \u001b[39m# noqa: F811\u001b[39;00m\n\u001b[1;32m   <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2122'>2123</a>\u001b[0m     \u001b[39m\"\"\"Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools).\"\"\"\u001b[39;00m\n\u001b[0;32m-> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2123'>2124</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_getitem(\n\u001b[1;32m   <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2124'>2125</a>\u001b[0m         key,\n\u001b[1;32m   <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2125'>2126</a>\u001b[0m     )\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py:2109\u001b[0m, in \u001b[0;36mDataset._getitem\u001b[0;34m(self, key, decoded, **kwargs)\u001b[0m\n\u001b[1;32m   <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2106'>2107</a>\u001b[0m formatter \u001b[39m=\u001b[39m get_formatter(format_type, features\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfeatures, decoded\u001b[39m=\u001b[39mdecoded, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mformat_kwargs)\n\u001b[1;32m   <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2107'>2108</a>\u001b[0m pa_subtable \u001b[39m=\u001b[39m query_table(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_data, key, indices\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_indices \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_indices \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mNone\u001b[39;00m)\n\u001b[0;32m-> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2108'>2109</a>\u001b[0m formatted_output \u001b[39m=\u001b[39m format_table(\n\u001b[1;32m   <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2109'>2110</a>\u001b[0m     pa_subtable, key, formatter\u001b[39m=\u001b[39;49mformatter, format_columns\u001b[39m=\u001b[39;49mformat_columns, output_all_columns\u001b[39m=\u001b[39;49moutput_all_columns\n\u001b[1;32m   <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2110'>2111</a>\u001b[0m )\n\u001b[1;32m   <a 
href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/arrow_dataset.py?line=2111'>2112</a>\u001b[0m \u001b[39mreturn\u001b[39;00m formatted_output\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py:532\u001b[0m, in \u001b[0;36mformat_table\u001b[0;34m(table, key, formatter, format_columns, output_all_columns)\u001b[0m\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=529'>530</a>\u001b[0m python_formatter \u001b[39m=\u001b[39m PythonFormatter(features\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m)\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=530'>531</a>\u001b[0m \u001b[39mif\u001b[39;00m format_columns \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=531'>532</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m formatter(pa_table, query_type\u001b[39m=\u001b[39;49mquery_type)\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=532'>533</a>\u001b[0m \u001b[39melif\u001b[39;00m query_type \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mcolumn\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=533'>534</a>\u001b[0m     \u001b[39mif\u001b[39;00m key \u001b[39min\u001b[39;00m format_columns:\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py:283\u001b[0m, in \u001b[0;36mFormatter.__call__\u001b[0;34m(self, pa_table, query_type)\u001b[0m\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=280'>281</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mformat_row(pa_table)\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=281'>282</a>\u001b[0m \u001b[39melif\u001b[39;00m query_type \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mcolumn\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=282'>283</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mformat_column(pa_table)\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=283'>284</a>\u001b[0m \u001b[39melif\u001b[39;00m query_type \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mbatch\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=284'>285</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mformat_batch(pa_table)\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py:316\u001b[0m, in \u001b[0;36mPythonFormatter.format_column\u001b[0;34m(self, pa_table)\u001b[0m\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=314'>315</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mformat_column\u001b[39m(\u001b[39mself\u001b[39m, pa_table: pa\u001b[39m.\u001b[39mTable) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mlist\u001b[39m:\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=315'>316</a>\u001b[0m     column \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpython_arrow_extractor()\u001b[39m.\u001b[39;49mextract_column(pa_table)\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=316'>317</a>\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdecoded:\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=317'>318</a>\u001b[0m         column \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpython_features_decoder\u001b[39m.\u001b[39mdecode_column(column, pa_table\u001b[39m.\u001b[39mcolumn_names[\u001b[39m0\u001b[39m])\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py:143\u001b[0m, in \u001b[0;36mPythonArrowExtractor.extract_column\u001b[0;34m(self, pa_table)\u001b[0m\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=141'>142</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mextract_column\u001b[39m(\u001b[39mself\u001b[39m, pa_table: pa\u001b[39m.\u001b[39mTable) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mlist\u001b[39m:\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/datasets/formatting/formatting.py?line=142'>143</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m pa_table\u001b[39m.\u001b[39;49mcolumn(\u001b[39m0\u001b[39;49m)\u001b[39m.\u001b[39;49mto_pylist()\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Optimise every model parameter with AdamW at a fixed learning rate.\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=6e-5)\n",
    "\n",
    "# Use the GPU when one is visible, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "model.train()\n",
    "for epoch in range(EPOCHS):\n",
    "    for index, batch in enumerate(tqdm(train_dataloader)):\n",
    "        # Move the current batch onto the same device as the model.\n",
    "        pixel_values = batch[\"pixel_values\"].to(device)\n",
    "        labels = batch[\"labels\"].to(device)\n",
    "\n",
    "        # Clear stale gradients, run the forward pass, then read loss and logits from its output.\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(pixel_values=pixel_values, labels=labels)\n",
    "        loss = outputs.loss\n",
    "        logits = outputs.logits\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Metric bookkeeping needs no gradients.\n",
    "        with torch.no_grad():\n",
    "            # Resize the logits to the label resolution before taking the argmax over dim 1.\n",
    "            upsampled_logits = nn.functional.interpolate(\n",
    "                logits, size=labels.shape[-2:], mode=\"bilinear\", align_corners=False\n",
    "            )\n",
    "            predicted = upsampled_logits.argmax(dim=1)\n",
    "            metric.add_batch(\n",
    "                predictions=predicted.detach().cpu().numpy(),\n",
    "                references=labels.detach().cpu().numpy(),\n",
    "            )\n",
    "\n",
    "        # Report on every batch whose index is a multiple of 100 (batch 0 included).\n",
    "        if index % 100 == 0:\n",
    "            metrics = metric.compute(num_labels=num_labels, ignore_index=255, reduce_labels=False)\n",
    "            print(f\"Epoch {epoch}/{EPOCHS} Batch {index}/{len(train_dataloader)} Loss {loss.item():.4f} Metrics {metrics}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration segments--sidewalk-semantic-2-f89d0845be9cadc9\n",
      "Reusing dataset parquet (/home/chainyo/.cache/huggingface/datasets/segments___parquet/segments--sidewalk-semantic-2-f89d0845be9cadc9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Feature extractor created with label reduction switched on.\n",
    "tokenizer = SegformerFeatureExtractor(reduce_labels=True)\n",
    "\n",
    "# Load the train split and print how many samples it holds.\n",
    "dataset = load_dataset(HUB_DIR, split=\"train\")\n",
    "length = len(dataset)\n",
    "print(length)\n",
    "\n",
    "# Encode all images together with their segmentation maps as PyTorch tensors.\n",
    "encoded_dataset = tokenizer(\n",
    "    images=dataset[\"pixel_values\"],\n",
    "    segmentation_maps=dataset[\"label\"],\n",
    "    return_tensors=\"pt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "pixel_values = encoded_dataset[\"pixel_values\"]\n",
    "labels = encoded_dataset[\"labels\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Tensor"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(pixel_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, Dataset, random_split, Subset\n",
    "\n",
    "class SegmentationDataset(Dataset):\n",
    "    \"\"\"Map-style dataset pairing encoded images with their segmentation maps.\n",
    "\n",
    "    Args:\n",
    "        pixel_values: Encoded images; sample ``i`` is ``pixel_values[i]``.\n",
    "        labels: Segmentation maps; sample ``i`` is ``labels[i]``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If both tensors do not have the same size along dimension 0.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, pixel_values: torch.Tensor, labels: torch.Tensor):\n",
    "        # Validate before storing anything; unlike `assert`, this check is not\n",
    "        # stripped when Python runs with -O.\n",
    "        if pixel_values.shape[0] != labels.shape[0]:\n",
    "            raise ValueError(\n",
    "                \"pixel_values and labels must hold the same number of samples, \"\n",
    "                f\"got {pixel_values.shape[0]} and {labels.shape[0]}\"\n",
    "            )\n",
    "        self.pixel_values = pixel_values\n",
    "        self.labels = labels\n",
    "        self.length = pixel_values.shape[0]\n",
    "        print(f\"Created dataset with {self.length} samples\")\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of samples.\"\"\"\n",
    "        return self.length\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return sample ``index`` as a ``BatchFeature`` with ``pixel_values`` and ``labels``.\"\"\"\n",
    "        image = self.pixel_values[index]\n",
    "        label = self.labels[index]\n",
    "        return BatchFeature({\"pixel_values\": image, \"labels\": label})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Created dataset with 1000 samples\n"
     ]
    }
   ],
   "source": [
    "segmentation_dataset = SegmentationDataset(pixel_values, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = segmentation_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'pixel_values': tensor([[[ 0.0912, -0.1828, -0.1143,  ..., -0.5253, -0.5424, -0.6623],\n",
       "         [-0.0116, -0.2342, -0.1486,  ..., -0.5424, -0.6109, -0.7137],\n",
       "         [ 0.0398, -0.1828, -0.1314,  ..., -0.5082, -0.6281, -0.7650],\n",
       "         ...,\n",
       "         [ 1.1529,  1.3927,  1.0331,  ...,  0.4166,  0.3481,  0.3309],\n",
       "         [ 0.9474,  1.1358,  1.3070,  ...,  0.5022,  0.3652,  0.3994],\n",
       "         [ 0.6049,  1.3413,  1.1358,  ...,  1.2728,  0.6563,  0.8104]],\n",
       "\n",
       "        [[ 0.3102, -0.2150, -0.3200,  ..., -0.3901, -0.4426, -0.5651],\n",
       "         [ 0.2227, -0.2850, -0.3725,  ..., -0.4076, -0.5126, -0.6176],\n",
       "         [ 0.2577, -0.2325, -0.3550,  ..., -0.3725, -0.5301, -0.6702],\n",
       "         ...,\n",
       "         [ 1.1506,  1.3957,  1.0280,  ...,  0.4678,  0.3978,  0.3803],\n",
       "         [ 0.9405,  1.1331,  1.3081,  ...,  0.5553,  0.4153,  0.4503],\n",
       "         [ 0.5903,  1.3431,  1.1331,  ...,  1.3431,  0.7129,  0.8704]],\n",
       "\n",
       "        [[ 0.4788, -0.0267, -0.1312,  ...,  0.0431, -0.0441, -0.2010],\n",
       "         [ 0.3916, -0.0790, -0.1835,  ...,  0.0256, -0.1138, -0.2532],\n",
       "         [ 0.4265, -0.0267, -0.1661,  ...,  0.0605, -0.1312, -0.3055],\n",
       "         ...,\n",
       "         [ 1.2805,  1.5245,  1.1585,  ...,  0.6356,  0.5659,  0.5485],\n",
       "         [ 1.0714,  1.2631,  1.4374,  ...,  0.7228,  0.5834,  0.6182],\n",
       "         [ 0.7228,  1.4722,  1.2631,  ...,  1.5071,  0.8797,  1.0365]]]), 'labels': tensor([[17, 17, 17,  ..., 17, 17, 17],\n",
       "        [17, 17, 17,  ..., 17, 17, 17],\n",
       "        [17, 17, 17,  ..., 17, 17, 17],\n",
       "        ...,\n",
       "        [ 1,  1,  1,  ...,  1,  1,  1],\n",
       "        [ 1,  1,  1,  ...,  1,  1,  1],\n",
       "        [ 1,  1,  1,  ...,  1,  1,  1]])}"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0,  1,  2,  4,  6,  9, 17, 18, 19, 23, 24, 25, 27, 31, 32])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "segmentation_dataset[0][\"labels\"].squeeze().unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the sample positions 80/20. The validation size is derived by subtraction so the\n",
    "# two sizes always sum to `length`, which `random_split` requires.\n",
    "indices = np.arange(length)\n",
    "train_size = int(length * 0.8)\n",
    "train_indices, val_indices = random_split(indices, [train_size, length - train_size])\n",
    "\n",
    "# SegmentationDataset takes (pixel_values, labels), so reuse the dataset built from the\n",
    "# encoded tensors and select the rows of each split through Subset.\n",
    "train_dataset = Subset(segmentation_dataset, [int(i) for i in train_indices])\n",
    "val_dataset = Subset(segmentation_dataset, [int(i) for i in val_indices])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "list indices must be integers or slices, not str",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 14'\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000018?line=0'>1</a>\u001b[0m train_dataset[\u001b[39m0\u001b[39;49m]\n",
      "\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 12'\u001b[0m in \u001b[0;36mSegmentationDataset.__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=11'>12</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__getitem__\u001b[39m(\u001b[39mself\u001b[39m, index):\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=12'>13</a>\u001b[0m     image \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[\u001b[39m\"\u001b[39;49m\u001b[39mpixel_values\u001b[39;49m\u001b[39m\"\u001b[39;49m][index]\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=13'>14</a>\u001b[0m     label \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39m\"\u001b[39m\u001b[39mlabel\u001b[39m\u001b[39m\"\u001b[39m][index]\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=15'>16</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m image, label\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py:471\u001b[0m, in \u001b[0;36mSubset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=468'>469</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(idx, \u001b[39mlist\u001b[39m):\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=469'>470</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindices[i] \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m idx]]\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=470'>471</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mindices[idx]]\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py:471\u001b[0m, in \u001b[0;36mSubset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=468'>469</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(idx, \u001b[39mlist\u001b[39m):\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=469'>470</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindices[i] \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m idx]]\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=470'>471</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mindices[idx]]\n",
      "\u001b[0;31mTypeError\u001b[0m: list indices must be integers or slices, not str"
     ]
    }
   ],
   "source": [
    "train_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both loaders use batches of two samples; only the training loader shuffles.\n",
    "train_dataloader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)\n",
    "valid_dataloader = DataLoader(dataset=val_dataset, batch_size=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "list indices must be integers or slices, not str",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 15'\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=0'>1</a>\u001b[0m batch \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39;49m(\u001b[39miter\u001b[39;49m(train_dataloader))\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py:530\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=527'>528</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=528'>529</a>\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset()\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=529'>530</a>\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=530'>531</a>\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=531'>532</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=532'>533</a>\u001b[0m         \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m    <a 
href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=533'>534</a>\u001b[0m         \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py:570\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=567'>568</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=568'>569</a>\u001b[0m     index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index()  \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=569'>570</a>\u001b[0m     data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index)  \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=570'>571</a>\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py?line=571'>572</a>\u001b[0m         data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data)\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:49\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m     <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=46'>47</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mfetch\u001b[39m(\u001b[39mself\u001b[39m, possibly_batched_index):\n\u001b[1;32m     <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=47'>48</a>\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mauto_collation:\n\u001b[0;32m---> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=48'>49</a>\u001b[0m         data \u001b[39m=\u001b[39m [\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[idx] \u001b[39mfor\u001b[39;00m idx \u001b[39min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m     <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=49'>50</a>\u001b[0m     \u001b[39melse\u001b[39;00m:\n\u001b[1;32m     <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=50'>51</a>\u001b[0m         data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py:49\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=46'>47</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mfetch\u001b[39m(\u001b[39mself\u001b[39m, possibly_batched_index):\n\u001b[1;32m     <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=47'>48</a>\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mauto_collation:\n\u001b[0;32m---> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=48'>49</a>\u001b[0m         data \u001b[39m=\u001b[39m [\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[idx] \u001b[39mfor\u001b[39;00m idx \u001b[39min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m     <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=49'>50</a>\u001b[0m     \u001b[39melse\u001b[39;00m:\n\u001b[1;32m     <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py?line=50'>51</a>\u001b[0m         data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n",
      "\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 12'\u001b[0m in \u001b[0;36mSegmentationDataset.__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=11'>12</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__getitem__\u001b[39m(\u001b[39mself\u001b[39m, index):\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=12'>13</a>\u001b[0m     image \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset[\u001b[39m\"\u001b[39;49m\u001b[39mpixel_values\u001b[39;49m\u001b[39m\"\u001b[39;49m][index]\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=13'>14</a>\u001b[0m     label \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39m\"\u001b[39m\u001b[39mlabel\u001b[39m\u001b[39m\"\u001b[39m][index]\n\u001b[1;32m     <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000013?line=15'>16</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m image, label\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py:471\u001b[0m, in \u001b[0;36mSubset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=468'>469</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(idx, \u001b[39mlist\u001b[39m):\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=469'>470</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindices[i] \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m idx]]\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=470'>471</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mindices[idx]]\n",
      "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py:471\u001b[0m, in \u001b[0;36mSubset.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=468'>469</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(idx, \u001b[39mlist\u001b[39m):\n\u001b[1;32m    <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=469'>470</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mindices[i] \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m idx]]\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/torch/utils/data/dataset.py?line=470'>471</a>\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mindices[idx]]\n",
      "\u001b[0;31mTypeError\u001b[0m: list indices must be integers or slices, not str"
     ]
    }
   ],
   "source": [
    "batch = next(iter(train_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "19a4294f02aad727716e8ed0f765e04171ea0ecb4c129d0b0eebf53be4c3a095"
  },
  "kernelspec": {
   "display_name": "Python 3.8.13 ('segformer')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}