File size: 42,012 Bytes
d90b3a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
# Copyright (c) 2024, EleutherAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import subprocess
from dataclasses import dataclass

try:
    from .template import NeoXArgsTemplate
except ImportError:
    from template import NeoXArgsTemplate

try:
    from typing import List, Literal, Union, Optional, Any
except ImportError:
    from typing_extensions import List, Literal, Union, Optional


ATTENTION_TYPE_CHOICES = [
    "global",
    "local",
    "sparse_fixed",
    "sparse_variable",
    "bigbird",
    "bslongformer",
    "gmlp",
    "amlp",
    "flash",
    "rwkv",
    "mamba",
]


def get_git_commit_hash():
    """Gets the git commit hash of your current repo (if it exists)"""
    try:
        git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
        git_hash = git_hash.decode()
    except (subprocess.CalledProcessError, FileNotFoundError):
        git_hash = None
    return git_hash


@dataclass
class NeoXArgsParallelism(NeoXArgsTemplate):
    """
    Parallelism Arguments
    """

    pipe_parallel_size: int = 0
    """
    Number of pipeline parallel stages. Disable with 0.
    """

    model_parallel_size: int = 1
    """
    Size of the model parallelism.
    """

    pipe_partition_method: str = "type:transformer|mlp"
    """
    method used to distribute model layers across pipeline stages. Choose from "parameters", which balances the number
    of parameters on each pipeline stage, "uniform", which naively balances the number of layers per stage, or
    "type:[regex]", which balances layers whose class names match [regex]
    """

    world_size: int = None
    """
    Total world size (i.e number of gpus in cluster). Configured post-launch using distributed launcher
    """

    is_pipe_parallel: bool = False
    """
    flag to determine whether pipeline parallelism is on - shouldn't be set by user, is automatically determined
    according to pipeline parallel size.
    """

    sequence_parallel: bool = False
    """
    flag to determine whether Megatron-style Sequence Parallelism (https://arxiv.org/abs/2205.05198)
    (Layernorm inputs and activations are sharded across model parallel group) will be used. Has no effect when model_parallel_size is 1.
    **Set by user, in contrast to neox_args.is_pipe_parallel.**
    """

    expert_interval: int = 2
    """
    Have one MoE layer every expert_interval layers
    """


@dataclass
class NeoXArgsModel(NeoXArgsTemplate):
    """
    Model Arguments
    """

    precision: Literal["fp16", "fp32", "bfloat16"] = None
    """
    description of the used precision, either one of fp16 or fp32 (and in the future bf16).
    """

    num_layers: int = None
    """
    Number of transformer layers.
    """

    hidden_size: int = None
    """
    Transformer hidden size.
    """

    intermediate_size: int = None
    """
    Transformer intermediate size. Default = 4h
    """

    mlp_multiple_of: int = 1
    """
    force mlp size to be a multiple of this value
    """

    expansion_factor: float = None
    """
    Transformer intermediate size. Default = 4
    """

    num_attention_heads: int = None
    """
    Number of transformer attention heads.

    If num_kv_heads is set, will control only number of query heads.
    """

    num_kv_heads: int = None
    """
    Number of transformer key/value attention heads.

    If set to None or the same value as num_attention_heads, will perform multi-head attention (MHA).
    If set to < num_attention_heads but > 1, will perform grouped-query attention (GQA) (https://arxiv.org/pdf/2305.13245.pdf)
    If set to 1, will perform multi-query attention.

    Must be < num_attention_heads and divide num_attention_heads evenly.
    """

    seq_length: int = None
    """
    Maximum sequence length to process.
    """

    sliding_window_width: int = None
    """
    Width of the attention sliding window. Only supported with Flash Attention 2.
    """

    max_position_embeddings: int = None
    """
    Maximum number of position embeddings to use. This is the size of position embedding.
    """

    norm: Literal[
        "layernorm", "rmsnorm", "scalenorm", "te_rmsnorm", "te_layernorm"
    ] = "layernorm"
    """
    Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm", "te_rmsnorm", "te_layernorm".
    """

    layernorm_fusion: bool = False
    """
    Use fused layer norm kernel (if `norm` is `layernorm`).
    """

    rmsnorm_fusion: bool = False
    """
    Use fused RMS norm kernel (if `norm` is `rmsnorm`).
    """

    use_qk_layernorm: bool = False
    """
    Use QK Normalization
    """

    layernorm_epsilon: float = 1.0e-5
    """
    Layer norm epsilon.
    """

    rms_norm_epsilon: float = 1.0e-8
    """
    Root mean squared norm epsilon
    """

    scalenorm_epsilon: float = 1.0e-8
    """
    Scalenorm epsilon
    """

    pos_emb: Literal[
        "learned", "rotary", "sinusoidal", "rpe", "alibi", "none"
    ] = "learned"
    """
    Type of positional embedding to use - choose from 'learned', 'rotary', 'sinusoidal', 'rpe', 'none'
    """

    rpe_num_buckets: int = 32
    """
    T5 relative positional encoding number of buckets, default 32.
    """

    rpe_max_distance: int = 128
    """
    T5 relative positional encoding max distance, default 128.
    """

    opt_pos_emb_offset: int = 0
    """
    Learned position embedding offset (only used by OPT, where it should be set to 2).
    """

    no_weight_tying: bool = False
    """
    Disables weight tying between embedding weights and final Linear layer
    """

    attention_config: list = None

    """
    Attention configuration for gpt-neox

    The first item in the list specifies the attention type(s), and should be a list of strings. The second item
    specifies the number of times to repeat those attention types in the full list.

    attention type choices:  [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba", "rwkv"]

    So a 12 layer network with only global attention could be specified like:
        [[[`global`], 12]]

    or a 12 layer network with alternating global / local like:
        [[[`global`, `local`], 6]]

    If none is specified, this defaults to
        [[[`global`], n_layers]]
    """

    sparsity_config: dict = None

    """
    Sparsity configuration dict as defined in https://www.deepspeed.ai/docs/config-json/#sparse-attention

    Note that since neox is autoregressive, attention is always "unidirectional" and `horizontal_global_attention` is
    always false.

    The main difference between our sparsity config and deepspeed's is that `mode` is ignored - since it is instead
    specified in attention_config defining each layer.

    An example config is given below:
          "sparse_attention": {
            "block": 16,
            "different_layout_per_head": true,
            "num_local_blocks": 4,
            "num_global_blocks": 1,
            "num_different_global_patterns": 4,
            "num_random_blocks": 0,
            "local_window_blocks": [4],
            "global_block_indices": [0],
            "global_block_end_indices": None,
            "num_sliding_window_blocks": 3
          }
    """

    num_unique_layers: int = None
    """
    Number of unique transformer layers. num-layers should be divisible by this value. Currently only has an effect when pipe_parallel_size=0.
    """

    param_sharing_style: str = "grouped"
    """
    Ordering of the shared parameters. For example, for a num-layers=4 and --num-unique-layers=2, we will have the following ordering for two unique layers 1 and 2-: grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].
    """

    make_vocab_size_divisible_by: int = 128
    """
    Pad the vocab size to be divisible by this value. This is added for computational efficiency reasons.
    """

    activation: Literal[
        "gelu",
        "geglu",
        "relu",
        "softsign",
        "swish",
        "mish",
        "silu",
        "reglu",
        "swiglu",
        "bilinear",
        "glu",
    ] = "gelu"
    """
    Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu", "reglu", "swiglu", "bilinear", "glu"]
    """

    use_flashattn_swiglu: bool = False
    """
    Use flash attention's version of swiglu
    """

    scaled_upper_triang_masked_softmax_fusion: bool = False
    """
    Enable fusion of query_key_value_scaling time (upper diagonal) masking and softmax.
    """

    scaled_masked_softmax_fusion: bool = False
    """
    Enable fusion of query_key_value_scaling general masking and softmax.
    """

    bias_gelu_fusion: bool = False
    """
    Enable bias and gelu fusion.
    """

    bias_dropout_fusion: bool = False
    """
    Enable bias and dropout fusion.
    """

    rope_fusion: bool = False
    """
    Enable rotary embedding fusion.
    """

    fp16_lm_cross_entropy: bool = False
    """
    Move the cross entropy unreduced loss calculation for lm head to fp16.
    """

    init_method_std: float = 0.02
    """
    Standard deviation of the zero mean normal distribution used for weight initialization.
    """

    apply_query_key_layer_scaling: bool = False
    """
    Scale Q * K^T by 1 / layer-number. If this flag is set, then it will automatically set attention-softmax-in-fp32 to true
    """

    use_cpu_initialization: bool = False
    """
    If set, affine parallel weights initialization uses CPU
    """

    attention_softmax_in_fp32: bool = False
    """
    Run attention masking and softmax in fp32.
    """

    rotary_pct: float = 1.0
    """
    pct of hidden dims to apply rotary positional embedding to
    """

    rotary_emb_base: int = 10000
    """
    Base for rotary positional embedding
    """

    rotary_save_freqs_buffer: bool = False
    """
    Used to control whether the `inv_freqs` buffer in rotary embeddings
    will be stored in checkpoints (persistent=True) or not.

    Defaults to false, but is left configurable to maintain backward-compatibility
    with GPT-NeoX checkpoints that were trained with this flag.
    """

    init_method: Literal[
        "normal",
        "scaled_normal",
        "orthogonal",
        "scaled_orthogonal",
        "xavier_uniform",
        "xavier_normal",
        "wang_init",
        "small_init",
        "single_residual_scaled_normal",
    ] = "normal"
    """
    Init function used on all layers except ff residual outputs - choose from
    ["normal", "scaled_normal", "orthogonal", "scaled_orthogonal", "xavier_uniform", "xavier_normal", "wang_init", "small_init"]
    """

    output_layer_init_method: Literal[
        "normal",
        "scaled_normal",
        "orthogonal",
        "scaled_orthogonal",
        "xavier_uniform",
        "xavier_normal",
        "wang_init",
        "small_init",
        "single_residual_scaled_normal",
    ] = "scaled_normal"
    """
    Init function used for ff residual outputs - choose from
    ["normal", "scaled_normal", "orthogonal", "scaled_orthogonal", "xavier_uniform", "xavier_normal", "wang_init", "small_init"]
    """

    gmlp_attn_dim: int = 64
    """
    the dimension of the single head self attention in gmlp model (not used in gpt models).
    If None - gmlp model doesn't use attention.
    """

    gpt_j_residual: bool = False
    """
    If false, we use the conventional residual path:
      x = x + attn(ln1(x))
      x = x + mlp(ln2(x))
    Otherwise, we use the residual path from GPT-J, which offers a slight speedup:
      x = ln(x)
      x = x + attn(x) + mlp(x)
    """

    gpt_j_tied: bool = False
    """
    If false, we use
      x = x + attn(ln1(x)) + mlp(ln2(x))
    Otherwise, we tie the layer norms
      y = ln(x)
      x = x + attn(y) + mlp(y)
    """

    use_bias_in_norms: bool = True
    """
    If false, norms (e.g. LayerNorm) will not have bias terms
    """
    use_bias_in_attn_linear: bool = True
    """
    If false, attn_linear (e.g. QKVO) will not have bias terms
    """
    use_bias_in_mlp: bool = True
    """
    If false, mlps will not have bias terms
    """

    soft_prompt_tuning: dict = None
    """
    Dictionary configuring the soft prompt tuning parameters.
    If enabled, will train *only* the soft prompt, and freezes the rest of the model.
    parameters in the dict are:
        'enabled': bool = True # enables soft prompting
        'num_tokens': int = 10 # length of the soft prompt in tokens
        'init_string': str = '' # if provided, initialize the soft prompt with the word embeddings of this string
        'init_range': float = 0.5 # if no init string is provided, initialize the soft prompt with a uniform distribution between -init_range and init_rang
    """

    mamba_selective_scan_fusion: bool = False
    """
    Enable fused kernels for Mamba selective scan.
    """

    mamba_causal_conv_fusion: bool = False
    """
    Enable fused kernels for Mamba causal Conv1d.
    """

    mamba_inner_func_fusion: bool = False
    """
    Enable fused inner operator for Mamba. (Supersedes conv. and selective scan fusion flags, requires each of those kernels to be installed.)
    """

    mamba_selective_fp32_params: bool = True
    """
    Keep selected parameters in fp32 for Mamba (A and D).
    Requires https://github.com/EleutherAI/DeeperSpeed/pull/61 .
    """

    mamba_use_bias_in_conv: bool = True
    """
    If false, conv1d in mamba block will not have bias term
    """

    mamba_use_bias_in_linears: bool = False
    """
    Enable bias terms in mamba block up- and down- projections (in_proj and out_proj).
    """

    # Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905)
    output_layer_parallelism: Literal["column"] = "column"

    """
    Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
    """

    dim_att: int = None
    """
    Total dimension of the attention mechanism for RWKV. If not set, defaults to hidden_size.
    """

    head_size: int = None
    """
    Size of each attention head for RWKV. Calculated as dim_att // num_attention_heads.
    """

    ffn_dim: int = None
    """
    Dimension of the feed-forward network for RWKV. If not set, calculated based on hidden_size and expansion_factor.
    """


@dataclass
class NeoXArgsOptimizer(NeoXArgsTemplate):
    """
    Optimizer Arguments
    """

    optimizer_type: Literal[
        "adam",
        "onebitadam",
        "cpu_adam",
        "cpu_torch_adam",
        "sm3",
        "madgrad_wd",
        "sgd",
        "lion",
    ] = "adam"
    """
    Type of optimizer to use. Choose from ['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd', 'lion']
    NOTE: sgd will use MuSGD from Mup. Mup must be enabled for this optimizer.
    """

    use_bnb_optimizer: bool = False
    """
    Whether to enable the bitsandbytes optimizers
    """

    zero_stage: Union[int, List[int], Literal["all"]] = None
    """
    Zero Optimizer stage
    """

    zero_reduce_scatter: bool = None
    """
    Zero: Uses reduce or reduce scatter instead of allreduce to average gradients
    """

    zero_contiguous_gradients: bool = None
    """
    Zero: Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during backward pass. Only useful when running very large models.
    """

    zero_reduce_bucket_size: int = None
    """
    Zero: Number of elements reduced/allreduced at a time. Limits the memory required for the allgather for large model sizes
    """

    zero_allgather_bucket_size: int = None
    """
    Zero: Number of elements allgathered at a time. Limits the memory required for the allgather for large model sizes
    """

    lr: float = None
    """
    Max Learning rate during training
    """


@dataclass
class NeoXArgsLRScheduler(NeoXArgsTemplate):
    """
    LR Scheduler Arguments
    """

    lr_decay_style: Literal["constant", "linear", "cosine", "exponential"] = "linear"
    """
    Learning rate decay function. Choose from 'constant', 'linear', 'cosine', 'exponential'.
    """

    lr_decay_iters: int = None
    """
    Number of iterations to decay learning rate over, If None defaults to
    --train-iters or the equivalent inferred valued from train_epochs.
    """

    lr_decay_fraction: float = None
    """
    Effective fraction of training over which to decay lr, overrides lr_decay_iters, useful when specifying train_epochs
    """

    min_lr: float = 0.0
    """
    Minimum value for learning rate. The scheduler clips values below this threshold.
    """

    warmup: float = 0.01
    """
    Percentage of total iterations to warmup on (.01 = 1 percent of all training iters).
    """

    override_lr_scheduler: bool = False
    """
    Reset the values of the scheduler (learning rate,warmup iterations, minimum learning rate, maximum number of iterations, and decay style from input arguments and ignore values from checkpoints. Note that all the above values will be reset.
    """

    use_checkpoint_lr_scheduler: bool = False
    """
    Use checkpoint to set the values of the scheduler (learning rate, warmup iterations, minimum learning rate, maximum number of iterations, and decay style from checkpoint and ignore input arguments.
    """


@dataclass
class NeoXArgsLogging(NeoXArgsTemplate):
    """
    Logging Arguments
    """

    use_wandb: bool = None
    """Flag indicating if wandb is to be used."""

    wandb_group: str = None
    """Weights and Biases group name - used to group together "runs"."""

    wandb_team: str = None
    """Team name for Weights and Biases."""

    wandb_project: str = "neox"
    """wandb project name"""

    wandb_host: str = "https://api.wandb.ai"
    """url of the wandb host"""

    wandb_init_all_ranks: bool = False
    """Initialize wandb on all ranks."""

    git_hash: str = get_git_commit_hash()
    """current git hash of repository"""

    log_dir: str = None
    """
    Directory to save logs to.
    """

    tensorboard_writer = None
    """
    initialized tensorboard writer
    """

    tensorboard_dir: str = None
    """
    Write TensorBoard logs to this directory.
    """

    use_comet: bool = None
    """Flag indicating if comet is to be used."""

    comet_workspace: Optional[str] = None
    """
    Comet workspace name, if not configured Comet Experiments will be created in the user configured default workspace.
    """

    comet_project: Optional[str] = None
    """
    Comet project name, if not configured Comet Experiments will be created in the Uncategorized Experiments project.
    """

    comet_experiment_name: Optional[str] = None
    """
    Custom name for the Comet experiment. If not provided, a random name is used.
    """

    comet_tags: Optional[list] = None
    """
    List of tags to attach to the created Comet Experiment.
    """

    comet_others: Optional[dict] = None
    """
    Custom metadata to attach to the created Comet Experiment.
    """

    comet_experiment: Any = None
    """
    Initialized comet experiment object used to log data
    """

    log_interval: int = 100
    """
    Interval between logging.
    """

    log_grad_pct_zeros: bool = False
    """
    Log the percentage of zeros for the gradient of each parameter to wandb / tensorboard (useful for debugging). Needs wandb_init_all_ranks set to True if using pipeline parallelism to log all ranks.
    """

    log_param_norm: bool = False
    """
    Log the frob norm of the parameters to wandb / tensorboard (useful for debugging). Needs wandb_init_all_ranks set to True if using pipeline parallelism to log all ranks.
    """

    log_grad_norm: bool = False
    """
    Log the frob norm of the gradients to wandb / tensorboard (useful for debugging).
    (N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because
    deepspeed.)
    """

    log_optimizer_states: bool = False
    """
    Log the frob norm of the optimizer states to wandb / tensorboard (useful for debugging).
    """

    log_gradient_noise_scale: bool = False
    """
    Whether to log the gradient noise scale when training (cf. https://arxiv.org/abs/1812.06162 for explanation)
    """

    gradient_noise_scale_n_batches: int = 5
    """
    Number of batches to accumulate gradients for in the gradient noise scale logger.
    """

    gradient_noise_scale_cpu_offload: bool = False
    """
    Whether to offload the buffered gradients to cpu when measuring gradient noise scale.
    """

    memory_profiling: bool = False
    """
    Whether to take a memory snapshot of the model. Useful for debugging memory issues.
    """

    memory_profiling_path: str = None
    """
    Path to save memory snapshot to.
    """

    profile: bool = False
    """
    Enable nsys and pytorch profiling. When using this option with nsys,
    nsys options should be directly specified in commandline.
    An example nsys commandline is
    ```
    nsys profile -s none -t nvtx,cuda -o <path/to/output_file>
    --force-overwrite true
    --capture-range=cudaProfilerApi
    --capture-range-end=stop
    ```
    """

    profile_step_start: int = 10
    """
    Step to start profiling at.
    """

    profile_step_stop: int = 12
    """
    Step to stop profiling at.
    """


@dataclass
class NeoXArgsOther(NeoXArgsTemplate):
    """
    Misc. Arguments
    """

    distributed_backend: str = "nccl"
    """
    Which backend to use for distributed training.
    """

    local_rank: int = None
    """
    local rank passed from distributed launcher.
    """

    rank: int = None
    """
    global rank of process being run (passed in via distributed launcher)
    """

    lazy_mpu_init: bool = False
    """
    If set to True, initialize_megatron() skips DDP initialization and returns function to complete it instead. Also turns on use-cpu-initialization flag. This is for external DDP manager.
    """

    short_seq_prob: float = 0.1
    """
    Probability of producing a short sequence.
    """

    eod_mask_loss: bool = False
    """
    Mask loss for the end of document tokens.
    """

    adlr_autoresume: bool = False
    """
    Enable auto-resume on adlr cluster.
    """

    adlr_autoresume_object = None
    """
    imported autoresume
    """

    adlr_autoresume_interval: int = 1000
    """
    Intervals over which check for auto-resume termination signal
    """

    seed: int = 1234
    """
    Random seed used for python, numpy, pytorch, and cuda.
    """

    onnx_safe: bool = False
    """
    Use workarounds for known problems with Torch ONNX exporter
    """

    deepscale: bool = False
    """
    (Deprecated) enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)'
    """

    deepscale_config: str = None
    """(Deprecated) deepscale json configuration file."""

    deepspeed_mpi: bool = False
    """
    Run via MPI, this will attempt to discover the necessary variables to initialize torch distributed from the MPI environment
    """

    deepspeed_slurm: bool = False
    """
    Run via SLURM, this will attempt to discover the necessary variables to initialize torch distributed from the SLURM environment
    """

    user_script: str = None
    """
    user script to be run
    """

    iteration: int = None
    """
    Set during training
    """

    do_train: bool = None
    """
    Set during training
    """

    do_valid: bool = None
    """
    Set during training
    """

    do_test: bool = None
    """
    Set during training
    """

    global_num_gpus: int = None
    """
    Set during launching
    """


@dataclass
class NeoXArgsTokenizer(NeoXArgsTemplate):
    """
    Tokenizer Arguments
    """

    tokenizer_type: Literal[
        "GPT2BPETokenizer",
        "HFTokenizer",
        "HFGPT2Tokenizer",
        "SPMTokenizer",
        "CharLevelTokenizer",
        "TiktokenTokenizer",
    ] = "GPT2BPETokenizer"
    """
    Type of tokenizer to use - should be one of ["GPT2BPETokenizer", "HFTokenizer", "HFGPT2Tokenizer", "SPMTokenizer", "CharLevelTokenizer", "TiktokenTokenizer"]
    """

    padded_vocab_size: int = None
    """
    Total (padded) vocabulary size of tokenizer. Configured after launching of training,
    as it's dependent on the parallelism size.
    """

    tokenizer = None
    """
    tokenizer object loaded into memory and accessible by other functions
    """


@dataclass
class NeoXArgsTraining(NeoXArgsTemplate):
    """
    Training Arguments
    """

    data_path: str = None
    """
    Path to combined dataset to split.
    """

    use_shared_fs: bool = True
    """
    Whether to use a shared filesystem for data loading. If False, local rank 0 on all nodes will preprocess the data,
    otherwise only global rank 0 will preprocess the data. This is implemented in megatron/data/gpt2_dataset.py::_build_index_mappings.
    """

    train_data_paths: list = None
    """
    List of paths to train datasets.
    """

    train_label_data_paths: list = None
    """
    List of paths to train label datasets (not shifted by 1 yet!).
    """

    train_reward_data_paths: list = None
    """
    List of paths to train reward datasets
    """

    test_data_paths: list = None
    """
    List of paths to test datasets.
    """

    test_label_data_paths: list = None
    """
    List of paths to test label datasets (not shifted by 1 yet!).
    """

    test_reward_data_paths: list = None
    """
    List of paths to test reward datasets
    """

    valid_data_paths: list = None
    """
    List of paths to validation datasets.
    """

    valid_label_data_paths: list = None
    """
    List of paths to validation label datasets (not shifted by 1 yet!).
    """

    valid_reward_data_paths: list = None
    """
    List of paths to validation reward datasets
    """

    pos_train_data_paths: list = None
    neg_train_data_paths: list = None
    """
    List of paths to positive and negative training datasets.
    """

    pos_train_label_data_paths: list = None
    neg_train_label_data_paths: list = None
    """
    List of paths to positive and negative training label datasets (not shifted by 1 yet!).
    """

    pos_valid_data_paths: list = None
    neg_valid_data_paths: list = None
    """
    List of paths to positive and negative validation datasets.
    """

    pos_valid_label_data_paths: list = None
    neg_valid_label_data_paths: list = None
    """
    List of paths to positive and negative validation label datasets (not shifted by 1 yet!).
    """

    pos_test_data_paths: list = None
    neg_test_data_paths: list = None
    """
    List of paths to positive and negative test datasets.
    """

    pos_test_label_data_paths: list = None
    neg_test_label_data_paths: list = None
    """
    List of paths to positive and negative test label datasets (not shifted by 1 yet!).
    """

    train_data_weights: list = None
    """
    List of 'weights' that decide how often to sample from each training dataset when blending datasets. If None, defaults to equal weighting.
    Should be a list the same length as `train_data_paths`
    """

    valid_data_weights: list = None
    """
    List of 'weights' that decide how often to sample from each validation dataset when blending datasets. If None, defaults to equal weighting.
    Should be a list the same length as `valid_data_paths`
    """

    test_data_weights: list = None
    """
    List of 'weights' that decide how often to sample from each test dataset when blending datasets. If None, defaults to equal weighting.
    Should be a list the same length as `test_data_paths`
    """

    weight_by_num_documents: bool = False
    """
    If True, Builds dataset weights from a multinomial distribution over groups of data according to the number of
    documents in each group.

    WARNING: setting this to True will override any user provided weights

    We sample from a group according to the probability p(L) ∝ |L| ** α,
    where p(L) is the probability of sampling from a given group,
          |L| is the number of examples in that datapoint,
          and α is a coefficient that acts to upsample data from underrepresented groups

    Hence α (`alpha`) allows us to control how much to 'boost' the probability of training on low-resource groups.

    See https://arxiv.org/abs/1911.02116 for more details
    """

    weighted_sampler_alpha: float = 1.0
    """
    Alpha value for `weight_by_num_documents`. Only has an effect if `weight_by_num_documents` = True.

    when alpha = 1, the probability of sampling from a given group = n_samples / total_samples
    as alpha -> 0, the probability of sampling from all groups becomes equal, and number of documents has no effect
    as alpha -> inf, the probability of sampling from the groups with *the most samples* -> 1
    """

    data_impl: Literal["infer", "mmap", "cached"] = "infer"
    """
    Implementation of indexed datasets, can be one of "infer", "cached", or "mmap"
    """

    pack_impl: Literal["packed", "pack_until_overflow", "unpacked"] = "packed"
    """
    Packing implementation, can be one of "packed", "pack_until_overflow", or "unpacked".

    warning: pack_until_overflow is very naive and will likely have issues with pretraining scale datasets
    """

    dataset_impl: Literal["gpt2", "pairwise"] = "gpt2"
    """
    Dataset implementation, can be one of "gpt2" or "pairwise"
    """

    train_impl: Literal["normal", "dpo", "rm", "kto"] = "normal"
    """
    Training implementation, can be one of "normal", "dpo", "kto", or "rm"
    """

    dpo_fp32: bool = True
    """
    Whether to cast logits to fp32 for DPO loss calculation.
    """

    dpo_reference_free: bool = False
    """
    Whether to use reference-free DPO.
    """

    dpo_beta: float = 0.1
    """
    Beta value for DPO
    """

    kto_fp32: bool = True
    """
    Whether to cast logits to fp32 for KTO loss calculation.
    """

    kto_desirable_weight: float = 1.0
    """
    Weight for desirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes.
    """

    kto_undesirable_weight: float = 1.0
    """
    Weight for undesirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes.
    """

    z_loss: float = 0.0
    """
    Z-loss parameter, only implemented for RM training currently.
    https://arxiv.org/pdf/2204.02311
    https://arxiv.org/pdf/2309.10305
    """

    kto_beta: float = 0.1
    """
    Beta value for KTO
    """

    allow_chopped: bool = True
    """
    WARNING: if your packing impl is packed, this is ignored.

    Allow chopped samples in the dataset.
    (e.g if your sequence length is 1024 and you have a sample of length 1026, it will be chopped to 1024)
    """

    mmap_warmup: bool = False
    """
    Warm up mmap files.
    """

    save: str = None
    """
    Output directory to save checkpoints to.
    """

    s3_path: str = None
    """
    Path to s3 bucket for saving checkpoints.
    """

    s3_chunk_size: int = 104_857_600
    """
    The number of bytes in each file chunk when uploading to s3. Defaults to 100MiB.
    """

    config_files: dict = None
    """
    Store of original config files mapping config filename to file contents
    """

    load: str = None
    """
    Directory containing a model checkpoint.
    """

    checkpoint_validation_with_forward_pass: bool = False
    """
    save input and output of a forward pass with the checkpoint and validate after load
    """

    checkpoint_scale: Literal["linear", "log"] = "linear"
    """
    How step at which checkpoints are saved should scale. "linear" implies 1 checkpoint will be saved at every multiple of `checkpoint-factor`,
    while "log" implies that the number of steps between each checkpoint will be multiplied by `checkpoint-factor` at each step, starting from step 1.
    """

    checkpoint_factor: Union[int, float] = None
    """
    Acts as a multiplier on either the "log" or "linear" checkpoint spacing.

    With `checkpoint-scale="linear"`, `checkpoint-factor=20`, and `train-iters=100`, checkpoints will be saved at
    steps [20, 40, 60, 80, 100].

    With `checkpoint-scale="log"`, `checkpoint-factor=2`, and `train-iters=100`, checkpoints will be saved at
    steps [1, 2, 4, 8, 16, 32, 64, 100].

    Note that the last checkpoint step is always saved.
    """

    extra_save_iters: list = None
    """
    Additional iterations when a checkpoint should be saved.
    Must be a list of ints or `None`.
    """

    no_save_optim: bool = False
    """
    Do not save current optimizer.
    """

    no_save_rng: bool = False
    """
    Do not save current rng state.
    """

    no_load_optim: bool = False
    """
    Do not load optimizer when loading checkpoint.
    """

    no_load_rng: bool = False
    """
    Do not load rng state when loading checkpoint.
    """

    finetune: bool = False
    """
    Load model for finetuning. Do not load optimizer or rng state from checkpoint and set iteration to 0. Assumed when loading a release checkpoint.
    """

    batch_size: int = None
    """
    training microbatch size per gpu
    """

    train_iters: int = None
    """
    Number of iterations to run for training.
    """

    train_epochs: int = None
    """
    Number of epochs to run for training. Do not specify both train_epochs and train_iters.
    Not currently compatible with data reweighing, pairwise datasets, and packing other than 'packed'
    """

    eval_iters: int = 100
    """
    Number of iterations to run for evaluation validation/test for.
    """

    keep_last_n_checkpoints: int = None
    """
    Number of last checkpoints to keep
    """

    eval_interval: int = 1000
    """
    Interval between running evaluation on validation set.
    """

    split: str = "969, 30, 1"
    """
    Comma_separated list of proportions for training, validation, and test split. For example the split 90,5,5 will use 90% of data for training, 5% for validation and 5% for test.
    """

    vocab_file: str = None
    """
    Path to the vocab file.
    """

    merge_file: str = None
    """
    Path to the BPE merge file.
    """

    num_workers: int = 2
    """
    Dataloader number of workers.
    """

    exit_interval: int = None
    """
    Exit the program after the iteration is divisible by this value.
    """

    attention_dropout: float = 0.0
    """
    Post attention dropout probability.
    """

    hidden_dropout: float = 0.0
    """
    Dropout probability for hidden state transformer.
    """

    weight_decay: float = 0.1
    """
    Weight decay coefficient for L2 regularization.
    """

    checkpoint_activations: bool = False
    """
    Checkpoint activation to allow for training with larger models, sequences, and batch sizes.
    """

    checkpoint_num_layers: int = 1
    """
    Chunk size (number of layers) for checkpointing.
    """

    deepspeed_activation_checkpointing: bool = True
    """
    DEPRECATED - TODO: remove
    Uses activation checkpointing from deepspeed
    """

    contiguous_checkpointing: bool = False
    """
    Contiguous memory checkpointing for activations.
    """

    checkpoint_in_cpu: bool = False
    """
    Move the activation checkpoints to CPU.
    """

    synchronize_each_layer: bool = False
    """
    does a synchronize at the beginning and end of each checkpointed layer.
    """

    profile_backward: bool = False
    """
    Enables backward pass profiling for checkpointed layers.
    """

    partition_activations: bool = False
    """
    Partition Activations across GPUs before checkpointing.
    """

    clip_grad: float = 1.0
    """
    Gradient clipping based on global L2 norm.
    """

    hysteresis: int = 2
    """
    hysteresis for dynamic loss scaling
    """

    dynamic_loss_scale: bool = None
    """
    flag indicating whether dynamic loss scale is used
    """

    loss_scale: float = None
    """
    Static loss scaling, positive power of 2
    values can improve fp16 convergence. If None, dynamic loss scaling is used.
    """

    loss_scale_window: float = 1000.0
    """
    Window over which to raise/lower dynamic scale.
    """

    min_scale: float = 1.0
    """
    Minimum loss scale for dynamic loss scale.
    """

    char_level_ppl: bool = False
    """
    Whether to calculate character level perplexity as well as token level perplexity. (may incur a time cost)
    """

    use_mup: bool = False
    """
    Whether to use Microsoft's Mup https://github.com/microsoft/mup
    """

    coord_check: bool = False
    """
    Whether to generate a "coord check" plot to verify mup's implementation in neox
    """

    save_base_shapes: bool = False
    """
    Whether to save base shapes for mup. This will save the shapes to the path specified in base-shapes-file.
    """

    base_shapes_file: str = None
    """
    Path to the base shapes to save to/load from
    """

    mup_init_scale: float = 1.0
    """
    Initialization scale: All the parameters are multiplied by this value
    """

    mup_attn_temp: float = 1.0
    """
    Attention temperature: Reciprocal of the multiplier applied to the input to attention softmax
    """

    mup_output_temp: float = 1.0
    """
    Output temperature: Reciprocal of the multiplier applied to the input to softmax that
    produces the distribution over output tokens.
    """

    mup_embedding_mult: float = 1.0
    """
    Scalar by which we multiply the output of the embedding layer
    """

    mup_rp_embedding_mult: float = 1.0
    """
    Scalar by which we multiply vectors representing relative position
    """

    mup_width_scale: int = 2
    """
    What to scale width by when creating the delta model for mup
    """


@dataclass
class NeoXArgsTextgen(NeoXArgsTemplate):
    """
    Text Generation arguments
    """

    text_gen_type: str = None
    """
    How to generate text/sample the model.
    Options: `unconditional`, `input-file`, `interactive`, `precompute`
    """

    precompute_model_name: str = None
    """
    Model name to use for saving precomputed logprobs
    """

    temperature: float = 0.0
    """
    exponential scaling output distribution ("higher == more risk")
    """

    top_p: float = 0.0
    """
    Top-p (nucleus) sampling chooses from the smallest possible set of tokens whose cumulative probability exceeds the probability top_p.
    """

    top_k: int = 0
    """
    integer between 0 and the models vocab size. Filters out any logits with a probability less than that of the top_kth token.
    """

    return_logits: bool = False
    """
    Boolean for whether to return the logits for generated tokens
    """

    maximum_tokens: int = 64
    """
    maximum number of tokens to be generated
    """

    prompt_end: str = "\n"
    """
    a single prompt's end. Defaults to newline
    """

    sample_input_file: str = None
    """
    Get input from file instead of interactive mode, each line is an input.
    """

    sample_output_file: str = "samples.txt"
    """
    Output file
    """

    num_samples: int = 1
    """
    Number of samples to generate unconditionally, defaults to 1 and interactive conditional sampling
    """

    recompute: bool = False
    """
    During generation recompute all attention instead of using previously computed keys/values.
    Should be set to true for sparse attention models
    """

    eval_results_prefix: str = ""
    """
    prefix to which to save evaluation results - final fp will be {eval_results_prefix}_eval_results_yy-mm-dd-HH-MM.json
    """

    eval_tasks: list = None
    """
    Tasks to evaluate on using lm_eval_harness

    NOTE: Requires internet connection
    """

    moe_top_k: int = 1
    """
    Activate top K experts in MoE
    """

    use_tutel: bool = False
    """
    Use Tutel optimizations in MoE
    """

    moe_num_experts: int = 1
    """
    Number of MoE experts
    """

    moe_loss_coeff: float = 0.1
    """
    Coefficient for MoE loss
    """

    moe_train_capacity_factor: float = 1.0
    """
    The capacity of the expert at train time
    """

    moe_eval_capacity_factor: float = 1.0
    """
    The capacity of the expert at eval time
    """

    moe_min_capacity: int = 4
    """
    The minimum capacity per expert regardless of the capacity_factor
    """

    moe_token_dropping: bool = False
    """
    Whether to drop tokens when exceeding capacity
    """

    create_moe_param_group: bool = True
    """
    Whether to create a separate parameter group for MoE parameters
    """

    moe_use_residual: bool = True
    """
    Whether to use residual in MoE
    """

    moe_expert_parallel_size: int = 1
    """
    Number of parallel experts in MoE
    """

    moe_type: str = "megablocks"
    """
    Either `deepspeed` or `megablocks`
    """

    moe_glu: bool = False
    """
    Use gated linear units in MoE
    """

    moe_lbl_in_fp32: bool = False
    """
    Whether to compute the load balancing loss in fp32.
    """

    moe_jitter_eps: float = None
    """
    Coefficient for MoE routing jitter. Jitter is
    not used if set to None
    """

    enable_expert_tensor_parallelism: bool = False
    """
    Enable expert tensor parallelism
    """