File size: 46,185 Bytes
1904ee8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "070a4097-7a17-409f-af5d-3d0cf43926ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM\n",
    "from huggingface_hub import list_repo_refs\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "100ec138-f7c1-4d8f-b7e0-eb715f320fdc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"mnoukhov/pythia410m-tldr-sft\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dbc9a2db-2c16-4e8f-bd2a-213ddc5d139d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.add_special_tokens({\"pad_token\": \"<|padding|>\"}) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "03788af8-6733-492f-84e3-fd358bb88ffd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.pad_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "576d3fda-7902-43d7-b4b1-3054f6192b11",
   "metadata": {},
   "outputs": [],
   "source": [
    "example_text = \"hello my name is mr hello\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "c73ddb0c-1551-4b12-82d8-26d3742d6f57",
   "metadata": {},
   "outputs": [],
   "source": [
    "toks = tokenizer(example_text + tokenizer.eos_token, padding=\"max_length\", max_length=7, truncation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "8904af15-4d27-4718-b53a-060ae65173a9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [25521, 619, 1416, 310, 278, 83, 23120], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "toks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "8fcf7c83-e8df-457b-9eab-1b1ed2145a76",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(toks['attention_mask'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ef1dddf6-1d26-4950-910a-c40b2cc394c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model_name = \"vwxyzjn/EleutherAI_pythia-1b-deduped__sft__tldr\"\n",
    "base_model_revision = \"sft__55513__1706646024\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "bb0df32c-9d90-4ab0-a87d-0ff6ecab03b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"/home/toolkit/trl_results/mnoukhov/EleutherAI_pythia-1b-deduped__sft__tldr_dpo_costa_1b_fp16.yml_3d94f50_b9ff2_merged/main\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "3ae77b2a-3132-4dd1-903b-35f28b7e7e5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_model = AutoModelForCausalLM.from_pretrained(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "08c1d05d-44a4-4859-9d54-48e7a3cd1da7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "12749e76749a40469d7732dc23e0f1dc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/4.05G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/mnoukhov/EleutherAI_pythia-1b-deduped__sft__tldr_dpo_costa_1b_fp16.yml_3d94f50_b9ff2_merged/commit/cd8f4bf53ab02881549cb73b6271005b2e8c3be6', commit_message='Upload GPTNeoXForCausalLM', commit_description='', oid='cd8f4bf53ab02881549cb73b6271005b2e8c3be6', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "base_model.push_to_hub(\"mnoukhov/EleutherAI_pythia-1b-deduped__sft__tldr_dpo_costa_1b_fp16.yml_3d94f50_b9ff2_merged\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9ef8927b-f908-460f-adba-54508b133ae0",
   "metadata": {},
   "outputs": [],
   "source": [
    "adapter_repo = \"mnoukhov/EleutherAI_pythia-1b-deduped__sft__tldr_dpo_1b_fp16.yml_24e9f83\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cb7336d2-a4ac-4607-83ae-e7e1e0b1665d",
   "metadata": {},
   "outputs": [],
   "source": [
    "refs = list_repo_refs(adapter_repo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2ab002af-7f3b-41b1-a8ad-f7c2296bd68f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f4c8e90f4fba4589a00ec3ee75dc7505",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "adapter_config.json:   0%|          | 0.00/706 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1eb21476d51f44858c32c59e72d70105",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "adapter_model.safetensors:   0%|          | 0.00/18.5M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "014e10f937374412aa10524d1a4d7a8f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/4.05G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step2324\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7c55865b8a7f4c6795b22c0a68b702a6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "adapter_model.safetensors:   0%|          | 0.00/18.5M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "58cb851b46974ae5a3cb066717520f8d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/4.05G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step1743\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "742410c4a20d4a09b10c0c96c8977df5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "adapter_model.safetensors:   0%|          | 0.00/18.5M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "48117e9ab89248038fa8a76ca9a191db",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/4.05G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step1162\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9b9cef19d78c465f900e345ac44acae6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "adapter_model.safetensors:   0%|          | 0.00/18.5M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "08878fc4eb1a407bb6238d4bec9e2817",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/4.05G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step581\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f39fc2322c504c4b9d6b601bbcbbb923",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "adapter_model.safetensors:   0%|          | 0.00/18.5M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step1\n"
     ]
    }
   ],
   "source": [
    "for branch in refs.branches:\n",
    "    if branch.name == \"main\":\n",
    "        continue\n",
    "\n",
    "    model = PeftModelForCausalLM.from_pretrained(base_model, adapter_repo, revision=branch.name)\n",
    "    merged = model.merge_and_unload()\n",
    "    merged.push_to_hub(f\"{adapter_repo}_merged\", revision=branch.name)\n",
    "    print(branch.name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "24627996-2bc2-4944-a36c-0d86108a82c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, builder, load_from_disk\n",
    "builder.has_sufficient_disk_space = lambda needed_bytes, directory=\".\": True  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ab8916ed-d39b-4d64-b287-ea4569567005",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = load_from_disk(\"/home/toolkit/trl_results/vwxyzjn_summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1706381144/vwxyzjn_EleutherAI_pythia-1b-deduped__dpo__tldr\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6ee65d83-872d-4d96-9c81-be53f2fc54c1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'?'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds['generations_dpo__55513__1707379566'][0][-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a11a3760-515b-4a02-9053-853aa3b06fd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "ppo_ds = load_from_disk(\"vwxyzjn_summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1706381144/vwxyzjn_EleutherAI_pythia-1b-deduped__ppo_left_padding_new_nowhiten_reward__tldr\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "5d0c3c4f-71b1-46b0-abdb-036e1bd49a26",
   "metadata": {},
   "outputs": [],
   "source": [
    "text = ppo_ds[\"generations_ppo_left_padding_new_nowhiten_reward__55513__1709671967\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8d2ec316-db2b-481b-9e25-82b2dd363772",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/pythia-6.9b-deduped\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1fedd4e0-a0a5-4499-9561-605e5adc8d88",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.encode('<|padding|>')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "42b8260f-19a7-42e1-b809-a24deff3699c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "032ad7febe1b4eb9899d22e5d44d23a0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading readme:   0%|          | 0.00/456 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "954fb6b000ac4b29b0c9033f242aac73",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/122M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3e544b20f15d48f59e901fbaf896a24d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/6.54M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62d3170267d742ceaf6bdad2a2cef5ae",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/160800 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ce7cff29f9c042949acad2dcec3ddd6e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating test split:   0%|          | 0/8552 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "ds = load_dataset(\"sophiex/hh-rlhf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "df1ccb5e-7206-45e7-a449-76b64fda72ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a9abf38ffb184ba4a4995450a4413bf2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map (num_proc=16):   0%|          | 0/160800 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8e0af258d31742998176207df5cac540",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map (num_proc=16):   0%|          | 0/8552 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tokds = ds.map(lambda x: tokenizer(x['prompt'] + x['chosen']), num_proc=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2e72f7f3-b047-4eab-99a7-cc08d19efeba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "99c7615c05da46d6be5c68ecfba3e748",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/160800 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c9b61731ac524d8c8ad1a44e47bb12b2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/8552 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tokds = tokds.map(lambda x: {\"length\": len(x['input_ids'])})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "413e3eb3-ad2f-4f71-9f27-894c4942be4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a4c42a89-88dd-4f3d-82cb-1fd7ecb60815",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<seaborn.axisgrid.FacetGrid at 0x7f8abec580d0>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAekAAAHpCAYAAACmzsSXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAz8UlEQVR4nO3deXhV1b3/8U+mE3IgI5gEymAAi6CgBRXTwSKkBEqrXrnPNZYKrYDFG6hAq5Rbq5b2Xvg5gAgRbCvEai3qfRRbsCCEQSwBMSUVUHkcsHgrCZWckzBlXr8/6NnNycAQTrJXct6v5zmPOXuv7Hz3Jp5P9t5rrxVhjDECAADWiXS7AAAA0DxCGgAASxHSAABYipAGAMBShDQAAJYipAEAsBQhDQCApQjp82CMUUVFhXikHADQngjp83D8+HElJibq+PHjbpcCAAgjhDQAAJYipAEAsBQhDQCApQhpAAAsRUgDAGApQhoAAEsR0gAAWIqQBgDAUoQ0AACWIqQBALAUIQ0AgKUIaQAALEVIAwBgKUIaAABLEdIAAFiKkAYAwFKENAAAliKkAQCwFCENAIClCOkOxhijsrIyGWPcLgUA0MYI6Q7G5/Mp59G18vl8bpcCAGhjhHQHFOONd7sEAEA7IKRdxuVrAEBLCGmXcfkaANASQtoCZ7t8zZk2AIQvQtpynGkDQPgipDsAOooBQHgipAEAsBQhDQCApQhpAAAsRUgDAGApQhoAAEsR0gAAWIqQ7uAY7AQAOi9CuoNjsBMA6LwI6U6AwU4AoHMipAEAsBQhDQCApQhpAAAsRUgDAGApQhoAAEsR0gAAWIqQBgDAUoQ0AACWIqQBALAUIQ0AgKUIaQAALEVIAwBgKUIaAABLEdKWYX5oAECANSG9aNEiRUREaPbs2c6yyspK5ebmqnv37urWrZsmTpyo0tLSoO87fPiwJkyYIK/Xq9TUVN17772qra0NarNt2zYNHz5csbGxGjhwoPLz89thj1qH+aEBAAFWhPSePXv01FNPadiwYUHL58yZoz/+8Y966aWXtH37dn322We69dZbnfV1dXWaMGGCqqurtXPnTj3zzDPKz8/XAw884LQ5dOiQJkyYoBtvvFHFxcWaPXu2pk2bpo0bN7bb/l0o5ocGAEgWhPSJEyc0adIk/frXv1ZycrKzvLy8XE8//bQWL16s0aNHa8SIEVq9erV27typXbt2SZJef/11vfvuu3ruued09dVXa/z48frFL36hvLw8VVdXS5JWrlypjIwMPfbYYxo8eLBmzpypf//3f9eSJUtarKmqqkoVFRVBLwAA2pvrIZ2bm6sJEyYoKysraHlRUZFqamqCll9++eXq27evCgsLJUmFhYUaOnSo0tLSnDbZ2dmqqKjQgQMHnDaNt52dne1sozkLFy5UYmKi8+rTp89F7ycAABfK1ZBes2aN/vKXv2jhwoVN1pWUlMjj8SgpKSloeVpamkpKSpw2DQM6sD6w7mxtKioqdPr06Wbrmj9/vsrLy53Xp59+2qr9AwDgYkS79YM//fRT3XPPPdq0aZO6dOniVhnNio2NVWxsrNtlAADCnGtn0kVFRTp69KiGDx+u6OhoRUdHa/v27XriiScUHR2ttLQ0VVdXy+/3B31faWmp0tPTJUnp6elNensH3p+rTUJCguLi4tpo7wAAuHiuhfSYMWO0b98+FRcXO69rrrlGkyZNcr6OiYlRQUGB8z0HDx7U4cOHlZmZKUnKzMzUvn37dPToUafNpk2blJCQoCFDhjhtGm4j0CawDQAAbOXa5e74+HhdeeWVQcu6du2q7t27O8unTp2quXPnKiUlRQkJCZo1a5YyMzN1/fXXS5LGjh2rIUOG6I477tDDDz+skpIS3X///crNzXUuV8+YMUPLly/XfffdpzvvvFNbtmzRiy++qPXr17fvDgMAcIFcC+nzsWTJEkVGRmrixImqqqpSdna2nnzySWd9VFSU1q1bp7vvvluZmZnq2rWrpkyZogULFjhtMjIytH79es2ZM0dLly5V79699Zvf/EbZ2dlu7BIAAOfNqpDetm1b0PsuXbooLy9PeXl5LX5Pv3799Nprr511u6NGjdLevXtDUSIAAO3G9eekAQBA8whpAAAsRUh3IMYYJt4AgDBCSHcgPp9P05avV12jWb4AAJ0TId3BRMcxQxYAhAurenfj/DS87G2McbkaAEBbIaQ7oJrKk8p9do+iY6K0NGe42+UAANoIl7s7KI83XjHehCbLjTEqKyvjDBsAOgFCupPx+XzKeXQtvcABoBMgpDuhGC+dywCgMyCkAQCwFCENAIClCGkAACxFSAMAYClCGgAASxHSAABYihHHOgBmvwKA8ERIdwA1p88MA1pfc1p1dcyABQDhgsvdHYTHG68YZsACgLBCSAMAYClCGgAASxHSHZgxRn6/3+0yAABthJDuwGpOndCc/B2qq6UzGQB0RoR0Bxft7ep2CQCANkJIAwBgKUIaAABLMZiJhRqOMGaMcbkaAIBbCGkLBUYYi4qO1ILxA9wuBwDgEi53W8rjjZcUeab3NkOBAkBYIqQtR+9tAAhfhDQAAJYipAEAsBQdxzoBeoMDQOdESHcCNZVneoNHx0Rpac7woHWBAE9OTlZERIRLFQIAWoPL3RYIBOmxY8dUVlbWqm14vPGK8SY0We7z+ZTz6FrnTBsA0HFwJm2BwHPR9TWnVVlRHvIe3THe+JBuDwDQPghpS3i88aqrjlYtM1oBAP6Jy90AAFiKkAYAwFKENAAAliKkAQCwFCHdiRhj5Pf73S4DABAihHQnUnPqxJlZs+ghDgCdAiHdyTBrFgB0HoQ0AACWIqQBALAUIQ0AgKUIaQAALEVIAwBgKUIaAABLEdIAAFiKkAYAwFKENAAAliKkAQCwFCENAIClCGkAACxFSAMAYClCGgAASxHSAABYKtrtAhB6xhj5fD7n6+bWJScnKyIiwo3yAADniTPpTqim8qRyn92jySsK5Pf7g9b5fD7lPLrWCXEAgL04k3ZRwzPeUPN44xUV0/w/b4w3vk1+JgAgtDiTdpHP59O05etVV1frdikAAAsR0i6LjuOsFgDQPEIaAABLEdIAAFiKkAYAwFKENAAAliKkAQCwFCENAIClCOlOzBjTZMQxAEDHQUh3YjWnTmhO/g7V1TJYCgB0RIR0Jxft7ep2CQCAViKkAQCwFCENAIClCGkAACxFSAMAYClCGgAASxHSAABYipAGAMBShDQAAJYipAEAsBQhDQCApQhpAAAsRUgDAGApQjrMGWNUVlYmY4zbpQAAGiGkw5zP51POo2vl8/ncLgUA0IirIb1ixQoNGzZMCQkJSkhIUGZmpv70pz856ysrK5Wbm6vu3burW7dumjhxokpLS4O2cfjwYU2YMEFer1epqam69957Vdto/uRt27Zp+PDhio2N1cCBA5Wfn98eu9dhxHjj3S4BANAMV0O6d+/eWrRokYqKivT2229r9OjRuvnmm3XgwAFJ0pw5c/THP/5RL730krZv367PPvtMt956q/P9dXV1mjBhgqqrq7Vz504988wzys/P1wMPPOC0OXTokCZMmKAbb7xRxcXFmj17tqZNm6aNGze2+/4CAHAhot384d/+9reD3v/3f/+3VqxYoV27dql37956+umn9fzzz2v06NGSpNWrV2vw4MHatWuXrr/+er3++ut69913tXnzZqWlpenqq6/WL37xC82bN08PPfSQPB6PVq5cqYyMDD322GOSpMGDB+vNN9/UkiVLlJ2d3WxdVVVVqqqqct5XVFS00REAAKBl1tyTrqur05o1a3Ty5EllZmaqqKhINTU1ysrKctpcfvnl6tu3rwoLCyVJhYWFGjp0qNLS0pw22dnZqqiocM7GCwsLg7YRaBPYRnMWLlyoxMRE59WnT59Q7ioAAOfF9ZDet2+funXrptjYWM2YMUOvvPKKhgwZopKSEnk8HiUlJQW1T0tLU0lJiSSppKQkKKAD6wPrztamoqJCp0+fbram+fPnq7y83Hl9+umnodhVAAAuiKuXuyVp0KBBKi4uVnl5uf73f/9XU6ZM0fbt212tKTY2VrGxsa7WAACA6yHt8Xg0cOBASdKIESO0Z88eLV26VLfddpuqq6vl9/uDzqZLS0uVnp4uSUpPT9dbb70VtL1A7++GbRr3CC8tLVVCQoLi4uLaarcAALhorl/ubqy+vl5VVVUaMWKEYmJiVFBQ4Kw7ePCgDh8+rMzMTElSZmam9u3bp6NHjzptNm3apISEBA0ZMsRp03AbgTaBbQAAYCtXz6Tnz5+v8ePHq2/fvjp+/Lief/55bdu2TRs3blRiYqKmTp2quXPnKiUlRQkJCZo1a5YyMzN1/fXXS5LGjh2rIUOG6I477tDDDz+skpIS3X///crNzXUuV8+YMUPLly/XfffdpzvvvFNbtmzRiy++qPXr17u56wAAnJOrIX306FFNnjxZR44cUWJiooYNG6aNGzfqG9/4hiRpyZIlioyM1MSJE1VVVaXs7Gw9+eSTzvdHRUVp3bp1uvvuu5WZmamuXbtqypQpWrBggdMmIyND69ev15w5c7R06VL17t1bv/nNb1p8/AoAAFu4GtJPP/30Wdd36dJFeXl5ysvLa7FNv3799Nprr511O6NGjdLevXtbVSMAAG6x7p40AAA4w/Xe3Wh7xhj5fD7nvwCAjoGQDgM1lSeV++we1decVmVFubzd0879TQAA1xHSYcLjjVdddXSTGcIAAPbinjQAAJYipAEAsBQhDQCApQhpAAAsRUgDAGApQhoXxBijsrIyGWPcLgUAOj1CGhfE5/Mp59G1DIoCAO2AkMYFi/HGu10CAIQFQhoAAEsR0gAAWIphQcNQw4k26AAGAPYipMNQYMKN6JgoLc0Z7nY5AIAWENJhyuONV1QM//wAYDPuSQMAYClCGgAASxHSAABYipAGAMBShDSCMDY3ANiDkEYQxuYGAHsQ0miCsbkBwA6ENAAAliKkAQCwFCENAIClCGkAACxFSAMAYClCGgAASxHSYcwYI7/f73YZAIAWMFdhGKs5dUJz8t9XXEqqM3gJI40BgD0I6TAX7e2qmsqTyn12j6JjorQ0Z7jbJQEA/omQhiTJ441XVAy/DgBgE+5JAwBgKUIaAABLEdIAAFiKkAYAwFKENAAAliKkAQCwFCENAIClCGkAACxFSAMAYClCGgAASxHSAABYipAGAMBSrQrp/v3769ixY02W+/1+9e/f/6KLgp2MMc6UlgCAtteqkP7kk09UV1fXZHlVVZX+/ve/X3RRsJPP59O05etVV1vrdikAEBYuaG7CP/zhD87XGzduVGJiovO+rq5OBQUFuvTSS0NWHOwTHRfvdgkAEDYuKKRvueUWSVJERISmTJkStC4mJkaXXnqpHnvssZAVh/ZljJHf73e7DADAP11QSNfX10uSMjIytGfPHvXo0aNNioI7ak6d0Jz89+XtnuZ2KQAAXWBIBxw6dCjUdcAS0d6ubpcAAPinVoW0JBUUFKigoEBHjx51zrADVq1addGFwQ6BHt3JyclulwIAYadVvbt//vOfa+zYsSooKNDnn38un88X9ELn4fP5lPPoWv5dAcAFrTqTXrlypfLz83XHHXeEuh5YKMZLj24AcEOrQrq6ulpf/vKXQ10LLNFw0BJjjMvVAED4atXl7mnTpun5558PdS2wRE3lSeU+u0eTVxTwSBYAuKhVZ9KVlZX61a9+pc2bN2vYsGGKiYkJWr948eKQFAf3eLzxioppdb9CAEAItOpT+J133tHVV18tSdq/f3/QuoiIiIsuCgAAtDKkt27dGuo6AABAI0xVCQCApVp1Jn3jjTee9bL2li1bWl0QAAA4o1UhHbgfHVBTU6Pi4mLt37+/ycQbAACgdVoV0kuWLGl2+UMPPaQTJ05cVEEAAOCMkN6T/u53v8u43QAAhEhIQ7qwsFBdunQJ5SYBAAhbrbrcfeuttwa9N8boyJEjevvtt/Wzn/0sJIUBABDuWhXSiYmJQe8jIyM1aNAgLViwQGPHjg1JYQAAhLtWhfTq1atDXQcAAGjkogZnLioq0nvvvSdJuuKKK/SlL30pJEUBAIBWhvTRo0eVk5Ojbdu2KSkpSZLk9/t14403as2aNbrkkktCWSMAAGGpVb27Z82apePHj+vAgQMqKytTWVmZ9u/fr4qKCv3whz8MdY0AAISlVp1Jb9iwQZs3b9bgwYOdZUOGDFFeXh4dxwAACJFWnUnX19c3mUNakmJiYlRfX3/RRQEAgFaG9OjRo3XPPffos88+c5b9/e9/15w5czRmzJiQFQcAQDhrVUgvX75cFRUVuvTSSzVgwAANGDBAGRkZqqio0LJly0JdIwAAYalV96T79Omjv/zlL9q8ebPef/99SdLgwYOVlZUV0uIAAAhnF3QmvWXLFg0ZMkQVFRWKiIjQN77xDc2aNUuzZs3StddeqyuuuEI7duxoq1oBAAgrFxTSjz/+uKZPn66EhIQm6xITE/WDH/xAixcvDllxAACEswsK6b/+9a8aN25ci+vHjh2roqKiiy4KdjDGyO/3u10GAIStCwrp0tLSZh+9CoiOjtY//vGPiy4Kdqg5dUJz8neorrbW7VIAICxdUEh/4Qtf0P79+1tc/84776hnz54XXRTsEe3t6nYJABC2Liikv/nNb+pnP/uZKisrm6w7ffq0HnzwQX3rW98KWXEAAISzC3oE6/7779fLL7+sL37xi5o5c6YGDRokSXr//feVl5enuro6/fSnP22TQgEACDcXFNJpaWnauXOn7r77bs2fP1/GGElSRESEsrOzlZeXp7S0tDYpFO4xxsjn87ldBgCEnQsezKRfv3567bXX5PP59OGHH8oYo8suu0zJycltUR8sUFN5UrnP7lF9zWnV1dUqyu2CACBMtGrEMUlKTk7WtddeG8paYDGPN1511dGqPc4ZNQC0l1aN3R0qCxcu1LXXXqv4+Hilpqbqlltu0cGDB4PaVFZWKjc3V927d1e3bt00ceJElZaWBrU5fPiwJkyYIK/Xq9TUVN17772qbfTY0LZt2zR8+HDFxsZq4MCBys/Pb+vdAwDgorga0tu3b1dubq527dqlTZs2qaamRmPHjtXJkyedNnPmzNEf//hHvfTSS9q+fbs+++wz3Xrrrc76uro6TZgwQdXV1dq5c6eeeeYZ5efn64EHHnDaHDp0SBMmTNCNN96o4uJizZ49W9OmTdPGjRvbdX8BALgQrb7cHQobNmwIep+fn6/U1FQVFRXphhtuUHl5uZ5++mk9//zzGj16tCRp9erVGjx4sHbt2qXrr79er7/+ut59911t3rxZaWlpuvrqq/WLX/xC8+bN00MPPSSPx6OVK1cqIyNDjz32mKQzk4G8+eabWrJkibKzs5vUVVVVpaqqKud9RUVFGx4FAACa5+qZdGPl5eWSpJSUFElSUVGRampqgmbXuvzyy9W3b18VFhZKkgoLCzV06NCgXuXZ2dmqqKjQgQMHnDaNZ+jKzs52ttHYwoULlZiY6Lz69OkTup0EAOA8WRPS9fX1mj17tr7yla/oyiuvlCSVlJTI4/EoKSkpqG1aWppKSkqcNo0f+wq8P1ebiooKnT59ukkt8+fPV3l5ufP69NNPQ7KPAABcCFcvdzeUm5ur/fv3680333S7FMXGxio2NtbtMqzV8Lnp5ORkRUREuFwRAHROVpxJz5w5U+vWrdPWrVvVu3dvZ3l6erqqq6ubzMRUWlqq9PR0p03j3t6B9+dqk5CQoLi4uFDvTqcXeG568ooCBjkBgDbkakgbYzRz5ky98sor2rJlizIyMoLWjxgxQjExMSooKHCWHTx4UIcPH1ZmZqYkKTMzU/v27dPRo0edNps2bVJCQoKGDBnitGm4jUCbwDZw4TzeeMV4m84rDgAIHVcvd+fm5ur555/Xq6++qvj4eOcecmJiouLi4pSYmKipU6dq7ty5SklJUUJCgmbNmqXMzExdf/31ks7MYT1kyBDdcccdevjhh1VSUqL7779fubm5ziXrGTNmaPny5brvvvt05513asuWLXrxxRe1fv161/YdAIBzcfVMesWKFSovL9eoUaPUs2dP5/XCCy84bZYsWaJvfetbmjhxom644Qalp6fr5ZdfdtZHRUVp3bp1ioqKUmZmpr773e9q8uTJWrBggdMmIyND69ev16ZNm3TVVVfpscce029+85tmH78CAMAWrp5JByboOJsuXbooLy9PeXl5LbYJjCd+NqNGjdLevXsvuEYAANxiRccxdFzGGJWVlZ3XH1wAgAtDSOOi+P1+5Ty6ll7eANAGCGlctBhvvNslAECnREgDAGApQhoAAEsR0gAAWIqQBgDAUoQ0AACWIqQBALAUIQ0AgKUIaQAALEVIAwBgKUIaAABLEdIAAFiKkAYAwFKENAAAliKk0WrGGPn9frfLAIBOi5BGq9WcOqE5+TtUV1vrdikA0CkR0rgo0d6ubpcAAJ0WIQ0AgKUIaQAALEVII6SMMSorK5Mxxu1SAKDDI6QRUj6fTzmPrpXP53O7FADo8AhphFyMN97tEgCgUyCkAQCwFCENAIClCGkAACxFSAMAYClCGgAAS0W7XQA6PmOM88gVz0cDQOgQ0rhoNZUnlfvsHkXHRGlpznC3ywGAToOQRkh4vPGKiuHXCQBCiXvSAABYipAGAMBShDQAAJYipAEAsBQhDQCApQhphIwxRn6/3+0yAKDTIKQRMjWnTmhO/g7V1da6XQoAdAqENEIq2tvV7RIAoNMgpAEAsBQhjTZljFFZWRljegNAKxDSaFM+n085j651JuAAAJw/QhptLsYb73YJANAhEdIAAFiKkAYAwFKENAAAliKkAQCwFCENAIClCGkAACwV7XYB6HyMMc5z0QxiAgCtR0gj5GoqTyr32T2KjonS0pzhbpcDAB0WIY024fHGKyqGXy8AuBjck0a7YixvADh/hDTaFWN5A8D543qkCwIdq8I1qBjLGwDODyHtAp/Pp8krClR96oTq6mrdLgcAYCkud7skxpugmLjOfUZpjJHf73e7DADosAhptJmaUyc0J3+H6mq5WgAArUFIo01Fe7u6XQIAdFiENAAAliKkAQCwFCENAIClCGkAACxFSAMAYClCGgAASxHScBUTbgBAywhpuIoJNwCgZYQ0XMeEGwDQPEIaAABLMQsW2lxgas7A1wCA80NIo83VVJ5U7rN7FB0TpaU5w90uBwA6DEIa7cLjjVdkdBRTVwLABeCeNNoNU1cCwIUhpNGumLoSAM4fIQ0AgKUIaQAALEVIAwBgKUIaAABLEdIAAFiKkAYAwFIMZoJ2xzChAHB+CGm0O4YJBYDzw+VuuMLjjVeMN6HJcmOMysrKOMMGABHSsIzP51POo2udy+EAEM4IabjGGNPshBsx3vj2LwYALERIwzVMuAEAZ+dqSL/xxhv69re/rV69eikiIkJr164NWm+M0QMPPKCePXsqLi5OWVlZ+uCDD4LalJWVadKkSUpISFBSUpKmTp2qEydOBLV555139LWvfU1dunRRnz599PDDD7f1ruE8MeEGALTM1ZA+efKkrrrqKuXl5TW7/uGHH9YTTzyhlStXavfu3eratauys7NVWVnptJk0aZIOHDigTZs2ad26dXrjjTd01113OesrKio0duxY9evXT0VFRXrkkUf00EMP6Ve/+lWb7x8AABfD1Uewxo8fr/Hjxze7zhijxx9/XPfff79uvvlmSdJvf/tbpaWlae3atcrJydF7772nDRs2aM+ePbrmmmskScuWLdM3v/lNPfroo+rVq5d+97vfqbq6WqtWrZLH49EVV1yh4uJiLV68OCjMG6qqqlJVVZXzvqKiIsR7DgDAuVl7T/rQoUMqKSlRVlaWsywxMVEjR45UYWGhJKmwsFBJSUlOQEtSVlaWIiMjtXv3bqfNDTfcII/H47TJzs7WwYMHW+xBvHDhQiUmJjqvPn36tMUuAgBwVtaGdElJiSQpLS0taHlaWpqzrqSkRKmpqUHro6OjlZKSEtSmuW00/BmNzZ8/X+Xl5c7r008/vfgdAgDgAjHiWDNiY2MVGxvrdhkAgDBn7Zl0enq6JKm0tDRoeWlpqbMuPT1dR48eDVpfW1ursrKyoDbNbaPhz4C7AmN5M9IYAASzNqQzMjKUnp6ugoICZ1lFRYV2796tzMxMSVJmZqb8fr+KioqcNlu2bFF9fb1GjhzptHnjjTdUU1PjtNm0aZMGDRqk5OTkdtobnE1gLO/JKwqaHdwEAMKVqyF94sQJFRcXq7i4WNKZzmLFxcU6fPiwIiIiNHv2bP3yl7/UH/7wB+3bt0+TJ09Wr169dMstt0iSBg8erHHjxmn69Ol666239Oc//1kzZ85UTk6OevXqJUn6zne+I4/Ho6lTp+rAgQN64YUXtHTpUs2dO9elvUZzWhrLGwDCmav3pN9++23deOONzvtAcE6ZMkX5+fm67777dPLkSd11113y+/366le/qg0bNqhLly7O9/zud7/TzJkzNWbMGEVGRmrixIl64oknnPWJiYl6/fXXlZubqxEjRqhHjx564IEHWnz8CgAAW7ga0qNGjTrrPciIiAgtWLBACxYsaLFNSkqKnn/++bP+nGHDhmnHjh2trhMAADdYe08aAIBwR0jDGi3NigUA4YqQhjWYFQsAghHSsAqzYgHAvxDSsJoxhkFOAIQtQhrWaTgCWVlZmXIeXdviZCgA0JkxdjesExiBLDomSktzhivGG+92SQDgCkIaVvJ44xUVw68ngPDG5W4AACxFSAMAYClCGtZicBMA4Y6QhrUY3ARAuCOkYTUGNwEQzghpAAAsRUgDAGApQhoAAEsR0uhQGMsbQDghpNGh+Hw+xvIGEDYYdxHWC0y4EfgvY3kDCBeENKwXmHCjvua0KivK5e2e5nZJANAuCGl0CB5vvOqqo1XLwCYAwgj3pAEAsBQhDQCApQhpAAAsRUgDAGApQhoAAEsR0gAAWIpHsNDhNBzcxBijiIgIpaSkKCIiwu3SACCkCGl0OI0HN4mJ66qXfjJRKSkpbpcGACFFSKNDaji4SUxcN7fLAYA2wT1pAAAsRUijU2EqSwCdCSGNToWpLAF0JtyTRocX6O0d+JqpLAF0FoQ0OrxAb+/omCgtzRnudjkAEDKENDoFjzdeUTH8OgPoXLgnjU7DGCO/3+92GQAQMoQ0Oo2aUyc0J3+H6mpr3S4FAEKCkEanEu3t6nYJABAyhDQAAJYipAEAsBQhDQCApQhpdGqNhwll2FAAHQkhjU6t8TChDBsKoCNh9Ad0OucaJjQ6rpuzPjk5WREREe1eIwCcD0IanU5Lw4QGwrvm9Jn1UdGRWpozXMnJyUpJSSGsAViHy93olDzeeMV4E4KW+Xw+TVu+XnV1tfJ44yVF6s6l6/Qf/+9lLn8DsBIhjU6ruWFCo+MaXfr2dlWMt1s7VgUA54/L3ei0zgwT+r7iUlI5UwbQIRHS6NSivV2de9T1NadVV8e43gA6Di53Iyx4vPGKaXSpGwBsR0gDAGApQhoAAEsR0gAAWIqQBgDAUoQ0AACW4hEshL3AcKGBmbEYIhSALQhphL2Gz1HXVtfo1zPGqH///gQ1ANdxuRtQg+eoIyN011NbGKEMgBUIaaCRhmN5G2NUVlbmXAoHgPZESANn4fP5lPPoWs6sAbiCkAbOIcbLcKIA3EHHMaCRQG/vwNcA4BZCGmgk0Ns7OiZKS3OGu10OgDBGSAPN8HjjFRXD/x4A3MU9aQAALEVIAy0wxsjv9zdZxiNZANoLIQ20oObUCc3J36Hamhr5fD6VlZWprKyMR7IAtBtuugFnEe3t6nQki4qO1ILxA3gkC0C74UwaOA8eb7ykSM3J36G62lpnOZe/AbQlQhq4ANHerkHvGZEMQFsipNtZw4Ey0LEF/i25/A2grRDS7czn82na8vWqq6k9d2NYKRDOH3/8saYtXx/UsYzL3gBCiY5jLoiO48yrI2s4/3SEJy5ohLJnZoxWRESEkpOTmY8awEXjTBpoBWf+6Qbvo+Pi9cknnwTdo6ZjGYCLQUgDIRJ4rjrSE+cso2MZgIvB5W4ghBr3/pak6LhuTkhzGRzAhSCkgRALdCwL/LfmdMv3rANtCG8AzSGkgRBr2LGssqJc0d6uivPGKzI6Sp988ol+8r/FWjltlJKTk2WM0e2Pvao1P75FKSkpbpcOwDKENNAGPN541VVHq7bB6GRn7lm/r2hv1/MaZpSzbAB0HAPaUeCedcNhRgPPWR87dkyff/65Pv/8cx07dkwff/yxch5d60zsUVZWpvr6+qDe4vQeBzo3zqQBFzWcwCNweVySvN3TzzyHHdNFn3zyiR7c8LGMMVowfoBzuTwpKUk+n0//ueoNLpcDnRQhDVig8eXxwPvTx32ak79DSX0uU131ac3J3+FcLg+EelxKqtN7PCkpSX6/n0vkQCfB5W7Acg0f62p4uTwmLj7oTHzyigIdOnRItz3yij7++GMugwOdACENdAKBEc/8fr8UEancZ/fojic36+OPPw4Kau5hAx0LIQ10EoERz+rqap2OadNXFuijjz7SsWPHnM5nDc+0G3dEA2AX7kkDnUiTEc8iI3Tn0nXydk93HvkKnGkH3jfsiGaMce5lB4K74b3t5OTkM2frYvQ0oD2EVUjn5eXpkUceUUlJia666iotW7ZM1113ndtlAW0q2tv1nx3R/tXxLK7R+4Yd0QI9yxv3NK+trtEjOSOcnuZP3D5CGRkZTmgHepsHBAI/JSWl1WHecP51/ihAOAqbkH7hhRc0d+5crVy5UiNHjtTjjz+u7OxsHTx4UKmpqW6XB7SLxmfaDTuiBXqXt9TTvLY2uKf59JUFQaG9YPwA/fi5Pys2MfVfo6118erXM8Y4Z+mSWjxTb+693+/X7Bf2yhijpTnDlZSUFLQ+8EeAJOeyfXNn/saYoO1HRESc9Y8K6czkKIHe8oHvDcUfHMnJyU69DbfZ3OA1Df9Iadxz/1zt+aOmcwibkF68eLGmT5+u73//+5KklStXav369Vq1apV+8pOfBLWtqqpSVVWV8768/MwZRUVFxUXXUVFRoUr/P2SMUX1NpSqPn9lmZGQU73lvxftztT3lK3Xez1i2Vok9M1RfU6kZy/YpJrarorqcUn1NpWqrz7zufGL9mfanjkuSvImXXND7wPZv/+UzTdbHdOmqJ2eMkyTNWLZWdXV1zvq6+notvD1TP//DftWcPhm0/ajoKD1405Wa99utiu2W0uz2Zv3qdf3ytpHO99fV1+vJGeOUlJSk1vD7/Zr1q9e17K6xTr2RMV2cbTZcH/gZfr9fc5/9syTpwZuu1P0v7HbWn6v94ju+0upacXahHJMgPj7+7H9MmTBQVVVloqKizCuvvBK0fPLkyeamm25q0v7BBx80knjx4sWLF682fZWXl581v8LiTPrzzz9XXV2d0tLSgpanpaXp/fffb9J+/vz5mjt3rvM+0AO2e/furb58VFFRoT59+ujTTz9VQkJCq7bR3jpizVLHrLsj1ixRd3vqiDVLHbPu9qw5Pr75sfsDwiKkL1RsbKxiY2ODloXqslFCQkKH+UUN6Ig1Sx2z7o5Ys0Td7akj1ix1zLptqDksnpPu0aOHoqKiVFpaGrS8tLRU6enpLlUFAMDZhUVIezwejRgxQgUFBc6y+vp6FRQUKDMz08XKAABoWdhc7p47d66mTJmia665Rtddd50ef/xxnTx50unt3dZiY2P14IMPNrmMbrOOWLPUMevuiDVL1N2eOmLNUses26aaI4wJn/EAly9f7gxmcvXVV+uJJ57QyJEj3S4LAIBmhVVIAwDQkYTFPWkAADoiQhoAAEsR0gAAWIqQBgDAUoR0O8jLy9Oll16qLl26aOTIkXrrrbdcq2XhwoW69tprFR8fr9TUVN1yyy06ePBgUJtRo0Y5MwUFXjNmzAhqc/jwYU2YMEFer1epqam69957nVmT2sJDDz3UpKbLL7/cWV9ZWanc3Fx1795d3bp108SJE5sMXtPeNV966aVNao6IiFBubq4ke47zG2+8oW9/+9vq1auXIiIitHbt2qD1xhg98MAD6tmzp+Li4pSVlaUPPvggqE1ZWZkmTZqkhIQEJSUlaerUqTpx4kRQm3feeUdf+9rX1KVLF/Xp00cPP/xwm9VdU1OjefPmaejQoeratat69eqlyZMn67PPPgvaRnP/RosWLWqzus91rL/3ve81qWfcuHFBbWw71pKa/T2PiIjQI4884rRp72N9Pp91ofrc2LZtm4YPH67Y2FgNHDhQ+fn5ra67iRDNYYEWrFmzxng8HrNq1Spz4MABM336dJOUlGRKS0tdqSc7O9usXr3a7N+/3xQXF5tvfvObpm/fvubEiRNOm69//etm+vTp5siRI86r4SDwtbW15sorrzRZWVlm79695rXXXjM9evQw8+fPb7O6H3zwQXPFFVcE1fSPf/zDWT9jxgzTp08fU1BQYN5++21z/fXXmy9/+cuu1nz06NGgejdt2mQkma1btxpj7DnOr732mvnpT39qXn75ZSOpyUQ0ixYtMomJiWbt2rXmr3/9q7nppptMRkaGOX36tNNm3Lhx5qqrrjK7du0yO3bsMAMHDjS33367s768vNykpaWZSZMmmf3795vf//73Ji4uzjz11FNtUrff7zdZWVnmhRdeMO+//74pLCw01113nRkxYkTQNvr162cWLFgQ9G/Q8P+FUNd9rmM9ZcoUM27cuKB6ysrKgtrYdqyNMUH1HjlyxKxatcpERESYjz76yGnT3sf6fD7rQvG58fHHHxuv12vmzp1r3n33XbNs2TITFRVlNmzY0Kq6GyOk29h1111ncnNznfd1dXWmV69eZuHChS5W9S9Hjx41ksz27dudZV//+tfNPffc0+L3vPbaayYyMtKUlJQ4y1asWGESEhJMVVVVm9T54IMPmquuuqrZdX6/38TExJiXXnrJWfbee+8ZSaawsNC1mhu75557zIABA0x9fb0xxs7j3PgDuL6+3qSnp5tHHnnEWeb3+01sbKz5/e9/b4wx5t133zWSzJ49e5w2f/rTn0xERIT5+9//bowx5sknnzTJyclBdc+bN88MGjSoTepuzltvvWUkmb/97W/Osn79+pklS5a0+D1tWXdLIX3zzTe3+D0d5VjffPPNZvTo0UHL3DzWxjT9rAvV58Z9991nrrjiiqCfddttt5ns7OyQ1M3l7jZUXV2toqIiZWVlOcsiIyOVlZWlwsJCFyv7l8Bc2Y3nR/3d736nHj166Morr9T8+fN16tQpZ11hYaGGDh0aNKtYdna2KioqdODAgTar9YMPPlCvXr3Uv39/TZo0SYcPH5YkFRUVqaamJug4X3755erbt69znN2qOaC6ulrPPfec7rzzzqCZ1Gw8zg0dOnRIJSUlQcc2MTFRI0eODDq2SUlJuuaaa5w2WVlZioyM1O7du502N9xwgzweT9C+HDx4UD6fr132pby8XBEREU0my1m0aJG6d++uL33pS3rkkUeCLmW6Ufe2bduUmpqqQYMG6e6779axY8eC6rH9WJeWlmr9+vWaOnVqk3VuHuvGn3Wh+twoLCwM2kagTag+48NmWFA3XOgUme2tvr5es2fP1le+8hVdeeWVzvLvfOc76tevn3r16qV33nlH8+bN08GDB/Xyyy9LkkpKSprdp8C6tjBy5Ejl5+dr0KBBOnLkiH7+85/ra1/7mvbv36+SkhJ5PJ4mH75paWlOPW7U3NDatWvl9/v1ve99z1lm43FuLPBzmquj4bFNTU0NWh8dHa2UlJSgNhkZGU22EViXnJzcJvUHVFZWat68ebr99tuDZjX64Q9/qOHDhyslJUU7d+7U/PnzdeTIES1evNiVuseNG6dbb71VGRkZ+uijj/Rf//VfGj9+vAoLCxUVFdUhjvUzzzyj+Ph43XrrrUHL3TzWzX3Whepzo6U2FRUVOn36tOLi4lpdt0RIh7Xc3Fzt379fb775ZtDyu+66y/l66NCh6tmzp8aMGaOPPvpIAwYMaO8yJUnjx493vh42bJhGjhypfv366cUXX7zo/wnaw9NPP63x48erV69ezjIbj3NnVFNTo//4j/+QMUYrVqwIWtdw3vhhw4bJ4/HoBz/4gRYuXOjKuM05OTnO10OHDtWwYcM0YMAAbdu2TWPGjGn3elpj1apVmjRpkrp06RK03M1j3dJnXUfA5e42ZPMUmTNnztS6deu0detW9e7d+6xtA+Obf/jhh5Kk9PT0ZvcpsK49JCUl6Ytf/KI+/PBDpaenq7q6Wn6/v0lNgXrcrPlvf/ubNm/erGnTpp21nY3HOfBzzvY7nJ6erqNHjwatr62tVVlZmevHPxDQf/vb37Rp06Zzzg08cuRI1dbW6pNPPnFqc/PfoH///urRo0fQ74Stx1qSduzYoYMHD57zd11qv2Pd0mddqD43WmqTkJAQkhMIQroN2ThFpjFGM2fO1CuvvKItW7Y0ubzUnOLiYklSz549JUmZmZnat29f0IdF4ANwyJAhbVJ3YydOnNBHH32knj17asSIEYqJiQk6zgcPHtThw4ed4+xmzatXr1ZqaqomTJhw1nY2HueMjAylp6cHHduKigrt3r076Nj6/X4VFRU5bbZs2aL6+nrnD4/MzEy98cYbqqmpCdqXQYMGtdnl10BAf/DBB9q8ebO6d+9+zu8pLi5WZGSkc0nZjbob+r//+z8dO3Ys6HfCxmMd8PTTT2vEiBG66qqrztm2rY/1uT7rQvW5kZmZGbSNQJuQfcaHpPsZWrRmzRoTGxtr8vPzzbvvvmvuuusuk5SUFNRbsD3dfffdJjEx0Wzbti3oUYhTp04ZY4z58MMPzYIFC8zbb79tDh06ZF599VXTv39/c8MNNzjbCDyWMHbsWFNcXGw2bNhgLrnkkjZ9nOlHP/qR2bZtmzl06JD585//bLKyskyPHj3M0aNHjTFnHqXo27ev2bJli3n77bdNZmamyczMdLVmY8705u/bt6+ZN29e0HKbjvPx48fN3r17zd69e40ks3jxYrN3716nF/SiRYtMUlKSefXVV80777xjbr755mYfwfrSl75kdu/ebd58801z2WWXBT0W5Pf7TVpamrnjjjvM/v37zZo1a4zX672ox4LOVnd1dbW56aabTO/evU1xcXHQ73qgV+7OnTvNkiVLTHFxsfnoo4/Mc889Zy655BIzefLkNqv7bDUfP37c/PjHPzaFhYXm0KFDZvPmzWb48OHmsssuM5WVlc42bDvWAeXl5cbr9ZoVK1Y0+X43jvW5PuuMCc3nRuARrHvvvde89957Ji8vj0ewOpply5aZvn37Go/HY6677jqza9cu12qR1Oxr9erVxhhjDh8+bG644QaTkpJiYmNjzcCBA829994b9PyuMcZ88sknZvz48SYuLs706NHD/OhHPzI1NTVtVvdtt91mevbsaTwej/nCF75gbrvtNvPhhx8660+fPm3+8z//0yQnJxuv12v+7d/+zRw5csTVmo0xZuPGjUaSOXjwYNBym47z1q1bm/2dmDJlijHmzGNYP/vZz0xaWpqJjY01Y8aMabI/x44dM7fffrvp1q2bSUhIMN///vfN8ePHg9r89a9/NV/96ldNbGys+cIXvmAWLVrUZnUfOnSoxd/1wHPqRUVFZuTIkSYxMdF06dLFDB482PzP//xPUCCGuu6z1Xzq1CkzduxYc8kll5iYmBjTr18/M3369CZ/0Nt2rAOeeuopExcXZ/x+f5Pvd+NYn+uzzpjQfW5s3brVXH311cbj8Zj+/fsH/YyLxVSVAABYinvSAABYipAGAMBShDQAAJYipAEAsBQhDQCApQhpAAAsRUgDAGApQhoAAEsR0gAAWIqQBgDAUoQ0AACW+v/XfUaUz6/OiAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 500x500 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "sns.displot(tokds[\"train\"][\"length\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d11597f9-0441-440c-8214-b9d8b2df6f79",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "46d3909d41c649acb800d4bf00197951",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map (num_proc=16):   0%|          | 0/160800 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e886faa17c774740a2058a5dd8e0673d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map (num_proc=16):   0%|          | 0/8552 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tokds = ds.map(lambda x: tokenizer(x['prompt']), num_proc=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "84290aac-1c4e-4d29-89bd-318cf2c9daf3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eb0406bdb9884fcc826630224f2d1a8a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/160800 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "50580c27e575445bb239783adee19f90",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/8552 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tokds = tokds.map(lambda x: {\"prompt_length\": len(x['input_ids'])})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "44d2f307-118b-493d-b626-97490e2bc4aa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "588d062fd6c2489da6f57b287c66d6e6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Filter (num_proc=16):   0%|          | 0/160800 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bc9327dc6ed2467597a56b4655aca9a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Filter (num_proc=16):   0%|          | 0/8552 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "filttokds = tokds.filter(lambda x: x[\"prompt_length\"] > 1024, num_proc=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "2b6d57f7-40b7-4417-88bc-83c63b22f153",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "31"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(filttokds[\"test\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78d391fa-9a57-446b-9007-fe64ef8fc735",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokds = ds.map(lambda x: tokenizer(x['prompt']), num_proc=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "176dbd05-67c5-45a6-b891-1237deb7d6c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = load_dataset(\"mnoukhov/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1706381144\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "9eab3eaa-55ed-4279-96d6-3c189266ba86",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "911df020ac294d9ca2b360e1a0be3f93",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Filter:   0%|          | 0/116722 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "ds[\"train\"] = ds[\"train\"].filter(lambda x: x[\"has_comparison\"] == True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "9cc44838-af6d-4e3d-be4e-49436900f469",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4041ceb7527a4e638b138fc12897c35e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "51b19ac739f14b31a5820a9f077cf177",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating parquet from Arrow format:   0%|          | 0/10 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d7e0c946845940bb8de920287d982cd6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "acd8997214f448ca8cc665b4fd3b1af6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating parquet from Arrow format:   0%|          | 0/7 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cae753400174470f8c97d61fbb557202",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2e6161edbdc24eca8fc89781a3b511e8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating parquet from Arrow format:   0%|          | 0/7 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/toolkit/.local/lib/python3.11/site-packages/huggingface_hub/file_download.py:983: UserWarning: Not enough free disk space to download the file. The expected file size is: 0.00 MB. The target location /home/toolkit/huggingface/hub only has 0.00 MB free disk space.\n",
      "  warnings.warn(\n",
      "/home/toolkit/.local/lib/python3.11/site-packages/huggingface_hub/file_download.py:983: UserWarning: Not enough free disk space to download the file. The expected file size is: 0.00 MB. The target location /home/toolkit/huggingface/hub/datasets--mnoukhov--summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1706381144_labelled/blobs only has 0.00 MB free disk space.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d6504b2578bf42a3810c24e04da166e1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md:   0%|          | 0.00/1.17k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/datasets/mnoukhov/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1706381144_labelled/commit/a873a0b902f97283fb440254b724da8257439c33', commit_message='Upload dataset', commit_description='', oid='a873a0b902f97283fb440254b724da8257439c33', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds.push_to_hub(\"mnoukhov/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1706381144_labelled\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "557d07cf-781c-4c9a-8499-2bd4d076e98d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'subreddit', 'title', 'post', 'summary', 'query_token', 'query', 'reference_response', 'reference_response_token', 'reference_response_token_len', 'query_reference_response', 'query_reference_response_token', 'query_reference_response_token_response_label', 'query_reference_response_token_len', 'has_comparison'],\n",
       "        num_rows: 9504\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['id', 'subreddit', 'title', 'post', 'summary', 'query_token', 'query', 'reference_response', 'reference_response_token', 'reference_response_token_len', 'query_reference_response', 'query_reference_response_token', 'query_reference_response_token_response_label', 'query_reference_response_token_len', 'has_comparison'],\n",
       "        num_rows: 6447\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['id', 'subreddit', 'title', 'post', 'summary', 'query_token', 'query', 'reference_response', 'reference_response_token', 'reference_response_token_len', 'query_reference_response', 'query_reference_response_token', 'query_reference_response_token_response_label', 'query_reference_response_token_len', 'has_comparison'],\n",
       "        num_rows: 6553\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7ad1cea-5ac6-4d57-9aff-6783ea61fb13",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}