Safetensors
Korean
new
reranker
korean
custom_code
sigridjineth commited on
Commit
1539793
·
verified ·
1 Parent(s): fdefb2d

Upload folder using huggingface_hub

Browse files
checkpoint-200/config.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/workspace/sigrid/kozistr-ko-sentence-embeddings/reranker-data/gte-korean-reranker-base-241210-c",
3
+ "architectures": [
4
+ "NewForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration.NewConfig",
9
+ "AutoModel": "Alibaba-NLP/new-impl--modeling.NewModel",
10
+ "AutoModelForMaskedLM": "Alibaba-NLP/new-impl--modeling.NewForMaskedLM",
11
+ "AutoModelForMultipleChoice": "Alibaba-NLP/new-impl--modeling.NewForMultipleChoice",
12
+ "AutoModelForQuestionAnswering": "Alibaba-NLP/new-impl--modeling.NewForQuestionAnswering",
13
+ "AutoModelForSequenceClassification": "modeling.NewForSequenceClassification",
14
+ "AutoModelForTokenClassification": "Alibaba-NLP/new-impl--modeling.NewForTokenClassification"
15
+ },
16
+ "classifier_dropout": 0.0,
17
+ "hidden_act": "gelu",
18
+ "hidden_dropout_prob": 0.1,
19
+ "hidden_size": 768,
20
+ "id2label": {
21
+ "0": "LABEL_0"
22
+ },
23
+ "initializer_range": 0.02,
24
+ "intermediate_size": 3072,
25
+ "label2id": {
26
+ "LABEL_0": 0
27
+ },
28
+ "layer_norm_eps": 1e-12,
29
+ "layer_norm_type": "layer_norm",
30
+ "logn_attention_clip1": false,
31
+ "logn_attention_scale": false,
32
+ "max_position_embeddings": 8192,
33
+ "model_type": "new",
34
+ "num_attention_heads": 12,
35
+ "num_hidden_layers": 12,
36
+ "pack_qkv": true,
37
+ "pad_token_id": 1,
38
+ "position_embedding_type": "rope",
39
+ "rope_scaling": {
40
+ "factor": 8.0,
41
+ "type": "ntk"
42
+ },
43
+ "rope_theta": 20000,
44
+ "torch_dtype": "float16",
45
+ "transformers_version": "4.44.2",
46
+ "type_vocab_size": 1,
47
+ "unpad_inputs": false,
48
+ "use_memory_efficient_attention": false,
49
+ "vocab_size": 250048
50
+ }
checkpoint-200/configuration.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The GTE Team Authors and Alibaba Group.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ NEW model configuration"""
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+
23
+ class NewConfig(PretrainedConfig):
24
+ r"""
25
+ This is the configuration class to store the configuration of a [`NewModel`] or a [`TFNewModel`]. It is used to
26
+ instantiate a NEW model according to the specified arguments, defining the model architecture. Instantiating a
27
+ configuration with the defaults will yield a similar configuration to that of the NEW
28
+ [izhx/new-base-en](https://huggingface.co/izhx/new-base-en) architecture.
29
+
30
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
+ documentation from [`PretrainedConfig`] for more information.
32
+
33
+
34
+ Args:
35
+ vocab_size (`int`, *optional*, defaults to 30522):
36
+ Vocabulary size of the NEW model. Defines the number of different tokens that can be represented by the
37
+ `inputs_ids` passed when calling [`NewModel`] or [`TFNewModel`].
38
+ hidden_size (`int`, *optional*, defaults to 768):
39
+ Dimensionality of the encoder layers and the pooler layer.
40
+ num_hidden_layers (`int`, *optional*, defaults to 12):
41
+ Number of hidden layers in the Transformer encoder.
42
+ num_attention_heads (`int`, *optional*, defaults to 12):
43
+ Number of attention heads for each attention layer in the Transformer encoder.
44
+ intermediate_size (`int`, *optional*, defaults to 3072):
45
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
46
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
47
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
48
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
49
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
50
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
51
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
52
+ The dropout ratio for the attention probabilities.
53
+ max_position_embeddings (`int`, *optional*, defaults to 512):
54
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
55
+ just in case (e.g., 512 or 1024 or 2048).
56
+ type_vocab_size (`int`, *optional*, defaults to 2):
57
+ The vocabulary size of the `token_type_ids` passed when calling [`NewModel`] or [`TFNewModel`].
58
+ initializer_range (`float`, *optional*, defaults to 0.02):
59
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
60
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
61
+ The epsilon used by the layer normalization layers.
62
+ position_embedding_type (`str`, *optional*, defaults to `"rope"`):
63
+ Type of position embedding. Choose one of `"absolute"`, `"rope"`.
64
+ rope_theta (`float`, *optional*, defaults to 10000.0):
65
+ The base period of the RoPE embeddings.
66
+ rope_scaling (`Dict`, *optional*):
67
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
68
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
69
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
70
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
71
+ these scaling strategies behave:
72
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
73
+ experimental feature, subject to breaking API changes in future versions.
74
+ classifier_dropout (`float`, *optional*):
75
+ The dropout ratio for the classification head.
76
+
77
+ Examples:
78
+
79
+ ```python
80
+ >>> from transformers import NewConfig, NewModel
81
+
82
+ >>> # Initializing a NEW izhx/new-base-en style configuration
83
+ >>> configuration = NewConfig()
84
+
85
+ >>> # Initializing a model (with random weights) from the izhx/new-base-en style configuration
86
+ >>> model = NewModel(configuration)
87
+
88
+ >>> # Accessing the model configuration
89
+ >>> configuration = model.config
90
+ ```"""
91
+
92
+ model_type = "new"
93
+
94
+ def __init__(
95
+ self,
96
+ vocab_size=30528,
97
+ hidden_size=768,
98
+ num_hidden_layers=12,
99
+ num_attention_heads=12,
100
+ intermediate_size=3072,
101
+ hidden_act="gelu",
102
+ hidden_dropout_prob=0.1,
103
+ attention_probs_dropout_prob=0.0,
104
+ max_position_embeddings=2048,
105
+ type_vocab_size=1,
106
+ initializer_range=0.02,
107
+ layer_norm_type='layer_norm',
108
+ layer_norm_eps=1e-12,
109
+ # pad_token_id=0,
110
+ position_embedding_type="rope",
111
+ rope_theta=10000.0,
112
+ rope_scaling=None,
113
+ classifier_dropout=None,
114
+ pack_qkv=True,
115
+ unpad_inputs=False,
116
+ use_memory_efficient_attention=False,
117
+ logn_attention_scale=False,
118
+ logn_attention_clip1=False,
119
+ **kwargs,
120
+ ):
121
+ super().__init__(**kwargs)
122
+
123
+ self.vocab_size = vocab_size
124
+ self.hidden_size = hidden_size
125
+ self.num_hidden_layers = num_hidden_layers
126
+ self.num_attention_heads = num_attention_heads
127
+ self.hidden_act = hidden_act
128
+ self.intermediate_size = intermediate_size
129
+ self.hidden_dropout_prob = hidden_dropout_prob
130
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
131
+ self.max_position_embeddings = max_position_embeddings
132
+ self.type_vocab_size = type_vocab_size
133
+ self.initializer_range = initializer_range
134
+ self.layer_norm_type = layer_norm_type
135
+ self.layer_norm_eps = layer_norm_eps
136
+ self.position_embedding_type = position_embedding_type
137
+ self.rope_theta = rope_theta
138
+ self.rope_scaling = rope_scaling
139
+ self.classifier_dropout = classifier_dropout
140
+
141
+ self.pack_qkv = pack_qkv
142
+ self.unpad_inputs = unpad_inputs
143
+ self.use_memory_efficient_attention = use_memory_efficient_attention
144
+ self.logn_attention_scale = logn_attention_scale
145
+ self.logn_attention_clip1 = logn_attention_clip1
checkpoint-200/global_step200/mp_rank_00_model_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a065a800810f8bf3a86a107772dca7e31bcdf93e42c41ff15e8bc84e57082939
3
+ size 4283468624
checkpoint-200/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step200
checkpoint-200/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efccdc583e449d5ad7b1ab1beda845d50c68d48c9acd72a1eef1ece00b1ac8b1
3
+ size 611934706
checkpoint-200/modeling.py ADDED
@@ -0,0 +1,1418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The GTE Team Authors and Alibaba Group.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch NEW model."""
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutput,
29
+ BaseModelOutputWithPooling,
30
+ MaskedLMOutput,
31
+ MultipleChoiceModelOutput,
32
+ QuestionAnsweringModelOutput,
33
+ SequenceClassifierOutput,
34
+ ModelOutput,
35
+ )
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.utils import logging
38
+
39
+ try:
40
+ import xformers.ops as xops
41
+ except ImportError as e:
42
+ xops = None
43
+
44
+ from .configuration import NewConfig
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+ # Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
51
+ # Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
52
+ class IndexFirstAxis(torch.autograd.Function):
53
+ @staticmethod
54
+ def forward(ctx, input, indices):
55
+ ctx.save_for_backward(indices)
56
+ assert input.ndim >= 2
57
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
58
+ second_dim = other_shape.numel()
59
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
60
+ # return input[indices]
61
+ # return torch.gather(
62
+ # rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
63
+ # ).reshape(-1, *other_shape)
64
+ return torch.gather(
65
+ input.view(ctx.first_axis_dim, second_dim),
66
+ 0,
67
+ indices.unsqueeze(-1).expand(indices.size(0), second_dim)
68
+ ).reshape(-1, *other_shape)
69
+
70
+ @staticmethod
71
+ def backward(ctx, grad_output):
72
+ (indices,) = ctx.saved_tensors
73
+ assert grad_output.ndim >= 2
74
+ other_shape = grad_output.shape[1:]
75
+ # grad_output = rearrange(grad_output, "b ... -> b (...)")
76
+ grad_output = grad_output.view(grad_output.size(0), other_shape.numel())
77
+ grad_input = torch.zeros(
78
+ [ctx.first_axis_dim, grad_output.shape[1]],
79
+ device=grad_output.device,
80
+ dtype=grad_output.dtype,
81
+ )
82
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
83
+ # grad_input[indices] = grad_output
84
+ # grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
85
+ grad_input.scatter_(
86
+ 0, indices.unsqueeze(-1).expand(indices.size(0), grad_output.size(1)), grad_output
87
+ )
88
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
89
+
90
+
91
+ index_first_axis = IndexFirstAxis.apply
92
+
93
+
94
+ def unpad_input(hidden_states, attention_mask=None, indices=None):
95
+ """
96
+ Arguments:
97
+ hidden_states: (batch, seqlen, ...)
98
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
99
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
100
+ Return:
101
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
102
+ """
103
+ if indices is None:
104
+ assert attention_mask is not None
105
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
106
+
107
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
108
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
109
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
110
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
111
+ # so we write custom forward and backward to make it a bit faster.
112
+ hidden_states = hidden_states.view(-1, *hidden_states.shape[2:])
113
+ return index_first_axis(hidden_states, indices)
114
+
115
+
116
+ class IndexPutFirstAxis(torch.autograd.Function):
117
+ @staticmethod
118
+ def forward(
119
+ ctx,
120
+ values: torch.Tensor,
121
+ indices: torch.Tensor,
122
+ first_axis_dim
123
+ ) -> torch.Tensor:
124
+ ctx.save_for_backward(indices)
125
+ assert indices.ndim == 1
126
+ assert values.ndim >= 2
127
+ output = torch.zeros(
128
+ first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
129
+ )
130
+ output[indices] = values
131
+ return output
132
+
133
+ @staticmethod
134
+ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
135
+ indices, = ctx.saved_tensors
136
+ grad_values = grad_output[indices]
137
+ return grad_values, None, None
138
+
139
+
140
+ index_put_first_axis = IndexPutFirstAxis.apply
141
+
142
+
143
+ def pad_input(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
144
+ """Add padding to sequences.
145
+
146
+ Arguments:
147
+ inputs: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
148
+ indices: (total_nnz), `indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()`
149
+ batch: int batch_size
150
+ seqlen: int max sequence length
151
+
152
+ Returns:
153
+ inputs: (batch, seqlen, ...)
154
+ """
155
+ output = index_put_first_axis(inputs, indices, batch * seqlen)
156
+ return output.view(batch, seqlen, *inputs.shape[1:])
157
+
158
+
159
+ def rotate_half(x):
160
+ """Rotates half the hidden dims of the input."""
161
+ x1 = x[..., : x.shape[-1] // 2]
162
+ x2 = x[..., x.shape[-1] // 2 :]
163
+ return torch.cat((-x2, x1), dim=-1)
164
+
165
+
166
+ def apply_rotary_pos_emb(q, k, cos, sin):
167
+ """Applies Rotary Position Embedding to the query and key tensors.
168
+
169
+ Args:
170
+ q (`torch.Tensor`): The query tensor.
171
+ k (`torch.Tensor`): The key tensor.
172
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
173
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
174
+ Returns:
175
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
176
+ """
177
+ cos, sin = cos.to(q.dtype), sin.to(q.dtype)
178
+ q_embed = (q * cos) + (rotate_half(q) * sin)
179
+ k_embed = (k * cos) + (rotate_half(k) * sin)
180
+ return q_embed, k_embed
181
+
182
+
183
+ class RotaryEmbedding(torch.nn.Module):
184
+ def __init__(self, dim, max_position_embeddings=512, base=10000.0, device=None):
185
+ super().__init__()
186
+
187
+ self.dim = dim
188
+ self.max_position_embeddings = max_position_embeddings
189
+ self.base = base
190
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
191
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
192
+
193
+ # Build here to make `torch.jit.trace` work.
194
+ self._set_cos_sin_cache(
195
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
196
+ )
197
+
198
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
199
+ self.max_seq_len_cached = seq_len
200
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
201
+
202
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
203
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
204
+ emb = torch.cat((freqs, freqs), dim=-1)
205
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
206
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
207
+
208
+ def forward(self, x, seq_len=None):
209
+ # x: [bs, num_attention_heads, seq_len, head_size]
210
+ if seq_len > self.max_seq_len_cached:
211
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
212
+
213
+ return (
214
+ self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
215
+ self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
216
+ )
217
+
218
+
219
+ class NTKScalingRotaryEmbedding(RotaryEmbedding):
220
+ """RotaryEmbedding extended with fixed and mixed NTK scaling. https://kexue.fm/archives/9706 """
221
+
222
+ def __init__(self, dim, max_position_embeddings=512, base=10000, device=None, scaling_factor=1.0, mixed_b=None):
223
+ self.scaling_factor = scaling_factor
224
+ self.mixed_b = mixed_b
225
+ super().__init__(dim, max_position_embeddings, base, device)
226
+ max_position_embeddings = max_position_embeddings * self.scaling_factor
227
+ self._set_cos_sin_cache(max_position_embeddings, self.inv_freq.device, torch.get_default_dtype())
228
+
229
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
230
+ self.max_seq_len_cached = seq_len
231
+
232
+ if seq_len > self.max_position_embeddings:
233
+ base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
234
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
235
+
236
+ if self.mixed_b is None:
237
+ inv_freq = inv_freq / self.scaling_factor ** (2 / self.dim) # (6)
238
+ else:
239
+ a = torch.tensor(self.scaling_factor).log() / (self.dim / 2) ** self.mixed_b # (13)
240
+ lambda_1_m = (a * torch.arange(1, self.dim // 2 + 1).float().to(device) ** self.mixed_b).exp() # (12)
241
+ inv_freq = inv_freq / lambda_1_m # (10)
242
+
243
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
244
+
245
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
246
+
247
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
248
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
249
+ emb = torch.cat((freqs, freqs), dim=-1)
250
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
251
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
252
+
253
+
254
+ class RMSNorm(nn.Module):
255
+ def __init__(self, hidden_size, eps=1e-6):
256
+ """
257
+ RMSNorm is equivalent to T5LayerNorm
258
+ """
259
+ super().__init__()
260
+ self.weight = nn.Parameter(torch.ones(hidden_size))
261
+ self.variance_epsilon = eps
262
+
263
+ def forward(self, hidden_states):
264
+ input_dtype = hidden_states.dtype
265
+ hidden_states = hidden_states.to(torch.float32)
266
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
267
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
268
+ return self.weight * hidden_states.to(input_dtype)
269
+
270
+
271
+ LAYER_NORM = {
272
+ 'layer_norm': nn.LayerNorm,
273
+ 'rms_norm': RMSNorm
274
+ }
275
+
276
+
277
+ class NewEmbeddings(nn.Module):
278
+ """
279
+ Embedding and Unpadding.
280
+ """
281
+
282
+ def __init__(self, config: NewConfig):
283
+ super().__init__()
284
+ self.padding_idx = config.pad_token_id
285
+ self.word_embeddings = nn.Embedding(
286
+ config.vocab_size, config.hidden_size, padding_idx=self.padding_idx
287
+ )
288
+
289
+ self.position_embedding_type = config.position_embedding_type
290
+ if self.position_embedding_type == 'absolute':
291
+ self.position_embeddings = nn.Embedding(
292
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
293
+ )
294
+ elif self.position_embedding_type == 'rope':
295
+ self._init_rope(config)
296
+ else:
297
+ raise ValueError
298
+
299
+ self.type_vocab_size = config.type_vocab_size
300
+ if self.type_vocab_size > 0:
301
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
302
+
303
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
304
+ # any TensorFlow checkpoint file
305
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
306
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
307
+ # position_ids is contiguous in memory and excluded when serialized
308
+ self.register_buffer(
309
+ "position_ids", torch.arange(config.max_position_embeddings), persistent=False
310
+ )
311
+
312
+ def _init_rope(self, config):
313
+ kwargs = dict(
314
+ dim=int(config.hidden_size / config.num_attention_heads),
315
+ max_position_embeddings=config.max_position_embeddings,
316
+ base=config.rope_theta
317
+ )
318
+ if config.rope_scaling is None:
319
+ self.rotary_emb = RotaryEmbedding(**kwargs)
320
+ else:
321
+ kwargs.update(scaling_factor=config.rope_scaling["factor"])
322
+ scaling_type = config.rope_scaling["type"]
323
+ if scaling_type == 'ntk':
324
+ kwargs.update(mixed_b=config.rope_scaling.get('mixed_b', None))
325
+ self.rotary_emb = NTKScalingRotaryEmbedding(**kwargs)
326
+ # elif scaling_type == "linear":
327
+ # self.rotary_emb = LinearScalingRotaryEmbedding(**kwargs)
328
+ # elif scaling_type == "dynamic":
329
+ # self.rotary_emb = DynamicNTKScalingRotaryEmbedding(**kwargs)
330
+ else:
331
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
332
+
333
+ def forward(
334
+ self,
335
+ unpad_inputs: bool,
336
+ input_ids: Optional[torch.Tensor] = None,
337
+ attention_mask: Optional[torch.Tensor] = None,
338
+ length: Optional[List[int]] = None,
339
+ token_type_ids: Optional[torch.Tensor] = None,
340
+ position_ids: Optional[torch.Tensor] = None,
341
+ inputs_embeds: Optional[torch.Tensor] = None,
342
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple], Optional[List[int]]]:
343
+ """
344
+ """
345
+ if inputs_embeds is None:
346
+ device, input_shape = input_ids.device, input_ids.shape
347
+ else:
348
+ device, input_shape = inputs_embeds.device, inputs_embeds.shape[:2]
349
+ batch_size, seq_length = input_shape
350
+
351
+ # Set attention_mask if it's None
352
+ if attention_mask is None:
353
+ attention_mask = torch.ones(input_shape, device=device)
354
+ if length is not None:
355
+ for i, l in enumerate(length):
356
+ attention_mask[i, l:] = 0
357
+
358
+ # Set attention_mask_bool for unpadding
359
+ if unpad_inputs:
360
+ attention_mask_bool = attention_mask.bool()
361
+ if length is None:
362
+ length = attention_mask.sum(-1).tolist()
363
+
364
+ # Get word embeddings
365
+ if inputs_embeds is None:
366
+ if unpad_inputs:
367
+ input_ids = input_ids[attention_mask_bool].unsqueeze(0)
368
+ inputs_embeds = self.word_embeddings(input_ids)
369
+ else:
370
+ if unpad_inputs:
371
+ inputs_embeds = inputs_embeds[attention_mask_bool].unsqueeze(0)
372
+ embeddings = inputs_embeds
373
+
374
+ # Set and unpad position_ids
375
+ if position_ids is None:
376
+ if seq_length > self.position_ids.size(0):
377
+ self.register_buffer(
378
+ "position_ids", torch.arange(seq_length, device=embeddings.device), persistent=False
379
+ )
380
+ if unpad_inputs:
381
+ # [1, cumsum_seq_len]
382
+ position_ids = torch.cat([self.position_ids[:l] for l in length]).unsqueeze(0)
383
+ else:
384
+ # [bs, seq_len]
385
+ position_ids = self.position_ids[:seq_length].expand(batch_size, -1)
386
+ elif unpad_inputs:
387
+ position_ids = position_ids[attention_mask_bool].unsqueeze(0) # [1, cumsum_seq_len]
388
+
389
+ # Compute rotary embedding
390
+ if self.position_embedding_type == 'rope':
391
+ rope_cos, rope_sin = self.rotary_emb(inputs_embeds, seq_len=seq_length)
392
+ rope_cos = rope_cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
393
+ rope_sin = rope_sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
394
+ rope_embeds = rope_cos, rope_sin
395
+ else:
396
+ rope_embeds = None
397
+
398
+ if self.type_vocab_size > 0:
399
+ if token_type_ids is None:
400
+ token_type_ids = position_ids.mul(0)
401
+ else:
402
+ if self.type_vocab_size < 2:
403
+ token_type_ids.mul_(0)
404
+ if unpad_inputs:
405
+ token_type_ids = token_type_ids[attention_mask_bool].unsqueeze(0)
406
+
407
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
408
+ embeddings = embeddings + token_type_embeddings
409
+
410
+ # BERT position
411
+ if self.position_embedding_type == "absolute":
412
+ position_embeddings = self.position_embeddings(position_ids)
413
+ embeddings = embeddings + position_embeddings
414
+
415
+ embeddings = self.LayerNorm(embeddings)
416
+ embeddings = self.dropout(embeddings)
417
+
418
+ return embeddings, attention_mask, rope_embeds, length
419
+
420
+
421
+ class NewAttention(nn.Module):
422
+ def __init__(self, config: NewConfig, pack_qkv=None, use_memory_efficient_attention=None):
423
+ super().__init__()
424
+ self.config = config
425
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
426
+ raise ValueError(
427
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
428
+ f"heads ({config.num_attention_heads})"
429
+ )
430
+
431
+ self.hidden_size = config.hidden_size
432
+ self.num_attention_heads = config.num_attention_heads
433
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
434
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
435
+
436
+ if pack_qkv is None:
437
+ pack_qkv = config.pack_qkv
438
+ self.pack_qkv = pack_qkv
439
+
440
+ if self.pack_qkv:
441
+ self.qkv_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=True)
442
+ else:
443
+ self.q_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
444
+ self.k_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
445
+ self.v_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
446
+
447
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
448
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
449
+
450
+ if use_memory_efficient_attention is None:
451
+ use_memory_efficient_attention = self.config.use_memory_efficient_attention
452
+ self.use_memory_efficient_attention = use_memory_efficient_attention
453
+ self.memory_efficient_attention = None if xops is None else xops.memory_efficient_attention
454
+ if self.use_memory_efficient_attention:
455
+ assert self.memory_efficient_attention is not None, 'please install xformers'
456
+
457
+ def forward(
458
+ self,
459
+ hidden_states: torch.Tensor,
460
+ attention_bias: torch.FloatTensor,
461
+ rope_embeds: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
462
+ padding_inputs: Optional[Tuple] = None, # indices, batch, seqlen
463
+ attention_scale: Optional[torch.FloatTensor] = None,
464
+ head_mask: Optional[torch.FloatTensor] = None,
465
+ output_attentions: Optional[bool] = False,
466
+ qkv_inputs: Optional[Tuple] = None, # For RetroMAE
467
+ ) -> Tuple[torch.Tensor, ...]:
468
+ shape_hd = (self.num_attention_heads, self.attention_head_size)
469
+ # qkv
470
+ if self.pack_qkv and qkv_inputs is None:
471
+ qkv_pack = self.qkv_proj(hidden_states).split(self.all_head_size, dim=-1)
472
+ else:
473
+ if qkv_inputs is None:
474
+ qkv_inputs = (hidden_states, hidden_states, hidden_states)
475
+ qkv_pack = [
476
+ getattr(self, n + '_proj')(s) for s, n in zip(qkv_inputs, 'qkv')
477
+ ]
478
+ query_states, key_states, value_states = [t.view(t.shape[:-1] + shape_hd) for t in qkv_pack]
479
+
480
+ if self.config.position_embedding_type == 'rope':
481
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, *rope_embeds)
482
+
483
+ dtype = query_states.dtype
484
+
485
+ if self.config.logn_attention_scale and attention_scale is not None:
486
+ # https://kexue.fm/archives/8823
487
+ query_states = query_states * attention_scale.to(dtype)
488
+
489
+ if padding_inputs is not None:
490
+ query_states = pad_input(query_states.squeeze(), *padding_inputs)
491
+ key_states = pad_input(key_states.squeeze(), *padding_inputs)
492
+ value_states = pad_input(value_states.squeeze(), *padding_inputs)
493
+
494
+ if self.use_memory_efficient_attention:
495
+ assert self.memory_efficient_attention is not None, "xformers is not loaded"
496
+ assert output_attentions is False, "memory_efficient_attention do not output attentions"
497
+ assert head_mask is None, "Not support yet"
498
+ attention_probs = None
499
+ if torch.is_tensor(attention_bias):
500
+ attention_bias = attention_bias.to(dtype)
501
+ context_layer = self.memory_efficient_attention(
502
+ query_states,
503
+ key_states,
504
+ value_states,
505
+ attn_bias=attention_bias,
506
+ p=self.dropout.p
507
+ )
508
+ else:
509
+ if output_attentions and isinstance(self, NewSdpaAttention):
510
+ raise RuntimeError("SDPA do not output attentions")
511
+ context_layer, attention_probs = self._attention(
512
+ query_states, key_states, value_states, attention_bias, head_mask
513
+ )
514
+
515
+ if padding_inputs is not None:
516
+ context_layer = unpad_input(context_layer, indices=padding_inputs[0])
517
+
518
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
519
+ context_layer = context_layer.view(new_context_layer_shape)
520
+
521
+ # output proj
522
+ attn_output = self.o_proj(context_layer)
523
+
524
+ # add attentions if we output them
525
+ outputs = (attn_output, attention_probs) if output_attentions else (attn_output,)
526
+ return outputs
527
+
528
+ def _attention(self, query_states, key_states, value_states, attention_bias, head_mask):
529
+ """
530
+ Args:
531
+ q/k/v: (B, L, n_head, head_dim),
532
+ Returns:
533
+ attn_output: (B L, n_head, head_dim)
534
+ """
535
+ query_states = query_states.transpose(1, 2)
536
+ key_states = key_states.transpose(1, 2)
537
+ value_states = value_states.transpose(1, 2)
538
+ # Take the dot product between "query" and "key" to get the raw attention scores.
539
+ attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
540
+
541
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
542
+ if attention_bias is not None:
543
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
544
+ attention_scores = attention_scores + attention_bias
545
+
546
+ # Normalize the attention scores to probabilities.
547
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
548
+
549
+ # This is actually dropping out entire tokens to attend to, which might
550
+ # seem a bit unusual, but is taken from the original Transformer paper.
551
+ if self.dropout.p > 0:
552
+ attention_probs = self.dropout(attention_probs)
553
+
554
+ # Mask heads if we want to
555
+ if head_mask is not None:
556
+ attention_probs = attention_probs * head_mask
557
+
558
+ context_layer = torch.matmul(attention_probs, value_states)
559
+
560
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
561
+ return context_layer, attention_probs
562
+
563
+
564
+ class NewSdpaAttention(NewAttention):
565
+ """
566
+ New attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
567
+ `NewAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
568
+ SDPA API.
569
+ """
570
+ def __init__(self, config: NewConfig, **kwargs):
571
+ super().__init__(config, **kwargs)
572
+ # torch.backends.cuda.enable_mem_efficient_sdp(False)
573
+ # logger.warning(
574
+ # "Disable memory efficient attention kernel for `NewSdpaAttention`, you can set "
575
+ # "`use_memory_efficient_attention=True` if it expected to use."
576
+ # )
577
+
578
+ def _attention(self, query_states, key_states, value_states, attention_bias, head_mask):
579
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
580
+ query_states.transpose(1, 2),
581
+ key_states.transpose(1, 2),
582
+ value_states.transpose(1, 2),
583
+ attn_mask=attention_bias,
584
+ dropout_p=self.dropout.p if self.training else 0.0,
585
+ )
586
+ attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
587
+ return attn_output, None
588
+
589
+
590
+ NEW_ATTENTION_CLASSES = {
591
+ "eager": NewAttention,
592
+ # "flash_attention_2": , # TODO
593
+ "sdpa": NewSdpaAttention,
594
+ }
595
+
596
+
597
+ class NewGatedMLP(nn.Module):
598
+ """
599
+ GLU Variants Improve Transformer.
600
+ """
601
+
602
+ def __init__(self, config: NewConfig):
603
+ super().__init__()
604
+ self.intermediate_size = config.intermediate_size
605
+ self.up_gate_proj = nn.Linear(config.hidden_size, self.intermediate_size * 2, bias=False)
606
+ self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=True)
607
+ self.act_fn = ACT2FN[config.hidden_act]
608
+ if config.hidden_dropout_prob > 0:
609
+ self.hidden_dropout = nn.Dropout(config.hidden_dropout_prob)
610
+ else:
611
+ self.hidden_dropout = None
612
+
613
+ def forward(self, hidden_states):
614
+ up_gate = self.up_gate_proj(hidden_states)
615
+ up_states, gate = torch.split(up_gate, self.intermediate_size, dim=-1)
616
+ gate = self.act_fn(gate)
617
+ gated_states = gate * up_states
618
+ if self.hidden_dropout is not None:
619
+ gated_states = self.hidden_dropout(gated_states)
620
+ down_states = self.down_proj(gated_states)
621
+ return down_states
622
+
623
+
624
+ class NewLayer(nn.Module):
625
+ def __init__(
626
+ self,
627
+ config: NewConfig,
628
+ pack_qkv=None,
629
+ use_memory_efficient_attention=None,
630
+ attn_implementation=None
631
+ ):
632
+ super().__init__()
633
+ if attn_implementation is None:
634
+ attn_implementation = config._attn_implementation
635
+ if use_memory_efficient_attention is None:
636
+ use_memory_efficient_attention = config.use_memory_efficient_attention
637
+ if use_memory_efficient_attention:
638
+ if attn_implementation != 'eager':
639
+ logger.warning_once(f"Override {attn_implementation=} to 'eager' as {use_memory_efficient_attention=}")
640
+ attn_implementation = 'eager' # Since it will be SDPA by default for torch>=2.1.1
641
+ self.attention = NEW_ATTENTION_CLASSES[attn_implementation](
642
+ config, pack_qkv=pack_qkv, use_memory_efficient_attention=use_memory_efficient_attention
643
+ )
644
+ self.mlp = NewGatedMLP(config)
645
+
646
+ ln_class = LAYER_NORM[config.layer_norm_type]
647
+ self.attn_ln = ln_class(config.hidden_size, eps=config.layer_norm_eps)
648
+ self.mlp_ln = ln_class(config.hidden_size, eps=config.layer_norm_eps)
649
+
650
+ if config.hidden_dropout_prob > 0:
651
+ self.hidden_dropout = nn.Dropout(config.hidden_dropout_prob)
652
+ else:
653
+ self.hidden_dropout = None
654
+
655
+ def forward(
656
+ self,
657
+ hidden_states: torch.Tensor,
658
+ attention_bias: torch.FloatTensor,
659
+ rope_embeds: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
660
+ padding_inputs: Optional[Tuple] = None, # indices, batch, seqlen
661
+ attention_scale: Optional[torch.FloatTensor] = None,
662
+ subset_indices: Optional[torch.LongTensor] = None,
663
+ head_mask: Optional[torch.FloatTensor] = None,
664
+ output_attentions: Optional[bool] = False,
665
+ qkv_inputs: Optional[Tuple] = None, # For RetroMAE
666
+ ) -> Tuple[torch.Tensor, ...]:
667
+ # Multi head self attention
668
+ residual = hidden_states if qkv_inputs is None else qkv_inputs[0]
669
+ attention_outputs = self.attention(
670
+ hidden_states,
671
+ attention_bias,
672
+ rope_embeds,
673
+ padding_inputs,
674
+ attention_scale,
675
+ head_mask,
676
+ output_attentions=output_attentions,
677
+ qkv_inputs=qkv_inputs,
678
+ )
679
+ hidden_states = attention_outputs[0]
680
+ if self.hidden_dropout is not None:
681
+ hidden_states = self.hidden_dropout(hidden_states)
682
+ hidden_states = residual + hidden_states
683
+
684
+ # In pretraining, after the attention of last layer, we only need the masked tokens.
685
+ if subset_indices is not None:
686
+ hidden_states = hidden_states[subset_indices]
687
+
688
+ hidden_states = self.attn_ln(hidden_states)
689
+
690
+ # Fully Connected
691
+ residual = hidden_states
692
+ hidden_states = self.mlp(hidden_states)
693
+ if self.hidden_dropout is not None:
694
+ hidden_states = self.hidden_dropout(hidden_states)
695
+ hidden_states = residual + hidden_states
696
+ hidden_states = self.mlp_ln(hidden_states)
697
+
698
+ # add self attentions if we output attention weights
699
+ outputs = (hidden_states,) + attention_outputs[1:]
700
+ return outputs
701
+
702
+
703
+ class NewEncoder(nn.Module):
704
+ def __init__(self, config):
705
+ super().__init__()
706
+ self.config = config
707
+ self.layer = nn.ModuleList([NewLayer(config) for _ in range(config.num_hidden_layers)])
708
+ self.gradient_checkpointing = False
709
+
710
+ def forward(
711
+ self,
712
+ hidden_states: torch.Tensor,
713
+ attention_bias: Optional[torch.FloatTensor] = None,
714
+ rope_embeds: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
715
+ padding_inputs: Optional[Tuple] = None, # indices, batch, seqlen
716
+ attention_scale: Optional[torch.FloatTensor] = None,
717
+ subset_indices: Optional[torch.LongTensor] = None,
718
+ head_mask: Optional[torch.FloatTensor] = None,
719
+ output_attentions: Optional[bool] = False,
720
+ output_hidden_states: Optional[bool] = False,
721
+ return_dict: Optional[bool] = True,
722
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
723
+ all_hidden_states = () if output_hidden_states else None
724
+ all_self_attentions = () if output_attentions else None
725
+
726
+ for i, layer_module in enumerate(self.layer):
727
+ if output_hidden_states:
728
+ all_hidden_states = all_hidden_states + (hidden_states,)
729
+
730
+ if i >= len(self.layer) - 1:
731
+ layer_subset_indices = subset_indices
732
+ else:
733
+ layer_subset_indices = None
734
+
735
+ layer_head_mask = head_mask[i] if head_mask is not None else None
736
+
737
+ if self.gradient_checkpointing and self.training:
738
+ layer_outputs = self._gradient_checkpointing_func(
739
+ layer_module.__call__,
740
+ hidden_states,
741
+ attention_bias,
742
+ rope_embeds,
743
+ padding_inputs,
744
+ attention_scale,
745
+ layer_subset_indices,
746
+ layer_head_mask,
747
+ )
748
+ else:
749
+ layer_outputs = layer_module(
750
+ hidden_states,
751
+ attention_bias,
752
+ rope_embeds,
753
+ padding_inputs,
754
+ attention_scale,
755
+ layer_subset_indices,
756
+ layer_head_mask,
757
+ output_attentions,
758
+ )
759
+
760
+ hidden_states = layer_outputs[0]
761
+ if output_attentions:
762
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
763
+
764
+ if output_hidden_states:
765
+ all_hidden_states = all_hidden_states + (hidden_states,)
766
+
767
+ if not return_dict:
768
+ return tuple(
769
+ v
770
+ for v in [
771
+ hidden_states,
772
+ all_hidden_states,
773
+ all_self_attentions,
774
+ ]
775
+ if v is not None
776
+ )
777
+ return BaseModelOutput(
778
+ last_hidden_state=hidden_states,
779
+ hidden_states=all_hidden_states,
780
+ attentions=all_self_attentions,
781
+ )
782
+
783
+
784
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->New
785
+ class NewPooler(nn.Module):
786
+ def __init__(self, config):
787
+ super().__init__()
788
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
789
+ self.activation = nn.Tanh()
790
+
791
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
792
+ # We "pool" the model by simply taking the hidden state corresponding
793
+ # to the first token.
794
+ first_token_tensor = hidden_states[:, 0]
795
+ pooled_output = self.dense(first_token_tensor)
796
+ pooled_output = self.activation(pooled_output)
797
+ return pooled_output
798
+
799
+
800
+ class NewPreTrainedModel(PreTrainedModel):
801
+ """
802
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
803
+ models.
804
+ """
805
+
806
+ config_class = NewConfig
807
+ base_model_prefix = "new"
808
+ supports_gradient_checkpointing = True
809
+ _supports_sdpa = True
810
+
811
+ def _init_weights(self, module):
812
+ """Initialize the weights"""
813
+ if isinstance(module, nn.Linear):
814
+ # Slightly different from the TF version which uses truncated_normal for initialization
815
+ # cf https://github.com/pytorch/pytorch/pull/5617
816
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
817
+ if module.bias is not None:
818
+ module.bias.data.zero_()
819
+ elif isinstance(module, nn.Embedding):
820
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
821
+ if module.padding_idx is not None:
822
+ module.weight.data[module.padding_idx].zero_()
823
+ elif isinstance(module, nn.LayerNorm):
824
+ module.bias.data.zero_()
825
+ module.weight.data.fill_(1.0)
826
+
827
+
828
+ class NewModel(NewPreTrainedModel):
829
+ """
830
+ The bare New Model transformer outputting raw hidden-states without any specific head on top.
831
+ """
832
+
833
+ def __init__(self, config: NewConfig, add_pooling_layer=False):
834
+ super().__init__(config)
835
+ self.config = config
836
+
837
+ self.embeddings = NewEmbeddings(config)
838
+ self.encoder = NewEncoder(config)
839
+
840
+ self.pooler = NewPooler(config) if add_pooling_layer else None
841
+
842
+ # Initialize weights and apply final processing
843
+ self.post_init()
844
+
845
+ def get_input_embeddings(self):
846
+ return self.embeddings.word_embeddings
847
+
848
+ def set_input_embeddings(self, value):
849
+ self.embeddings.word_embeddings = value
850
+
851
+ def forward(
852
+ self,
853
+ input_ids: Optional[torch.Tensor] = None,
854
+ attention_mask: Optional[torch.Tensor] = None,
855
+ length: Optional[List[int]] = None,
856
+ subset_indices: Optional[torch.LongTensor] = None,
857
+ token_type_ids: Optional[torch.Tensor] = None,
858
+ position_ids: Optional[torch.Tensor] = None,
859
+ head_mask: Optional[torch.Tensor] = None,
860
+ inputs_embeds: Optional[torch.Tensor] = None,
861
+ output_attentions: Optional[bool] = None,
862
+ output_hidden_states: Optional[bool] = None,
863
+ return_dict: Optional[bool] = None,
864
+ unpad_inputs: Optional[bool] = None,
865
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
866
+ r"""
867
+ length (`list` of length `batch_size`, *optional*):
868
+ If is `None`, return padded `last_hidden_state`.
869
+ subset_indices ():
870
+ pass
871
+ unpad_inputs (`bool`, *optional*):
872
+ pass
873
+ """
874
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
875
+ output_hidden_states = (
876
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
877
+ )
878
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
879
+ unpad_inputs = unpad_inputs if unpad_inputs is not None else self.config.unpad_inputs
880
+ output_padded = length is None
881
+
882
+ if input_ids is not None and inputs_embeds is not None:
883
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
884
+ elif input_ids is not None:
885
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
886
+ input_shape = input_ids.size()
887
+ elif inputs_embeds is not None:
888
+ input_shape = inputs_embeds.size()[:-1]
889
+ else:
890
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
891
+
892
+ # TODO: not used
893
+ # # Prepare head mask if needed
894
+ # # 1.0 in head_mask indicate we keep the head
895
+ # # attention_probs has shape bsz x n_heads x N x N
896
+ # # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
897
+ # # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
898
+ # head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
899
+
900
+ # Get embeddings, may unpad them
901
+ (embedding_output, attention_mask, rope_embeds, length) = self.embeddings(
902
+ unpad_inputs,
903
+ input_ids=input_ids,
904
+ attention_mask=attention_mask,
905
+ length=length,
906
+ token_type_ids=token_type_ids,
907
+ position_ids=position_ids,
908
+ inputs_embeds=inputs_embeds
909
+ )
910
+
911
+ batch_size, seq_length = input_shape
912
+ if unpad_inputs and self.config.use_memory_efficient_attention:
913
+ attention_bias = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(length)
914
+ else:
915
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
916
+ # ourselves in which case we just need to make it broadcastable to all heads.
917
+ attention_bias = self.get_extended_attention_mask(attention_mask, input_shape)
918
+ if self.config.use_memory_efficient_attention:
919
+ # Invalid shape for attention bias: torch.Size([48, 1, 1, 512]) (expected (48, 12, 512, 512))
920
+ attention_bias = attention_bias.expand(-1, self.config.num_attention_heads, seq_length, -1)
921
+
922
+ padding_inputs = None
923
+ if unpad_inputs and (output_padded or not self.config.use_memory_efficient_attention):
924
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
925
+ if not self.config.use_memory_efficient_attention:
926
+ padding_inputs = (indices, *input_shape)
927
+
928
+ attention_scale = None
929
+ if self.config.logn_attention_scale:
930
+ logger.warning_once("TODO: logn_attention_scale")
931
+ # # attention scale log_512(input_len)
932
+ # attention_scale = attention_mask.sum(1).log() / torch.tensor(self.config.max_position_embeddings).log()
933
+ # # inference-time logn scale need clip 1
934
+ # if self.config.logn_attention_clip1:
935
+ # attention_scale.clip_(1)
936
+ # attention_scale = attention_scale[:, None, None, None]
937
+ # else:
938
+ # attention_scale = None
939
+
940
+ encoder_outputs = self.encoder(
941
+ embedding_output,
942
+ attention_bias=attention_bias,
943
+ rope_embeds=rope_embeds,
944
+ padding_inputs=padding_inputs,
945
+ attention_scale=attention_scale,
946
+ subset_indices=subset_indices,
947
+ head_mask=head_mask,
948
+ output_attentions=output_attentions,
949
+ output_hidden_states=output_hidden_states,
950
+ return_dict=return_dict,
951
+ )
952
+ sequence_output = encoder_outputs[0]
953
+ if unpad_inputs and output_padded:
954
+ sequence_output = pad_input(
955
+ sequence_output.squeeze(), indices, batch_size, seq_length
956
+ )
957
+
958
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
959
+
960
+ if not return_dict:
961
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
962
+
963
+ return BaseModelOutputWithPooling(
964
+ last_hidden_state=sequence_output,
965
+ pooler_output=pooled_output,
966
+ hidden_states=encoder_outputs.hidden_states,
967
+ attentions=encoder_outputs.attentions,
968
+ )
969
+
970
+
971
+ class NewLMPredictionHead(nn.Module):
972
+ def __init__(self, config):
973
+ super().__init__()
974
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
975
+ self.transform_act_fn = ACT2FN[config.hidden_act]
976
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
977
+
978
+ # The output weights are the same as the input embeddings, but there is
979
+ # an output-only bias for each token.
980
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
981
+
982
+ def forward(self, hidden_states):
983
+ hidden_states = self.dense(hidden_states)
984
+ hidden_states = self.transform_act_fn(hidden_states)
985
+ hidden_states = self.norm(hidden_states)
986
+ hidden_states = self.decoder(hidden_states)
987
+ return hidden_states
988
+
989
+
990
+ class NewForMaskedLM(NewPreTrainedModel):
991
+ _tied_weights_keys = ["lm_head.decoder.bias", "lm_head.decoder.weight"]
992
+
993
+ def __init__(self, config: NewConfig):
994
+ super().__init__(config)
995
+ self.new = NewModel(config, add_pooling_layer=False)
996
+ self.lm_head = NewLMPredictionHead(config)
997
+ self.loss_fct = nn.CrossEntropyLoss()
998
+
999
+ # Initialize weights and apply final processing
1000
+ self.post_init()
1001
+
1002
+ def get_output_embeddings(self):
1003
+ return self.lm_head.decoder
1004
+
1005
+ def set_output_embeddings(self, new_embeddings):
1006
+ self.lm_head.decoder = new_embeddings
1007
+
1008
+ def forward(
1009
+ self,
1010
+ input_ids: Optional[torch.Tensor] = None,
1011
+ attention_mask: Optional[torch.Tensor] = None,
1012
+ token_type_ids: Optional[torch.Tensor] = None,
1013
+ position_ids: Optional[torch.Tensor] = None,
1014
+ head_mask: Optional[torch.Tensor] = None,
1015
+ inputs_embeds: Optional[torch.Tensor] = None,
1016
+ labels: Optional[torch.Tensor] = None,
1017
+ output_attentions: Optional[bool] = None,
1018
+ output_hidden_states: Optional[bool] = None,
1019
+ return_dict: Optional[bool] = None,
1020
+ unpad_inputs: Optional[bool] = None,
1021
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1022
+ r"""
1023
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1024
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1025
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1026
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1027
+ """
1028
+
1029
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1030
+
1031
+ if labels is None or not self.new.config.unpad_inputs:
1032
+ length = None
1033
+ subset_indices = None
1034
+ else:
1035
+ length = attention_mask.sum(-1).tolist()
1036
+ labels = labels[attention_mask.bool()].unsqueeze(0)
1037
+ subset_indices = labels > -100
1038
+
1039
+ outputs = self.new(
1040
+ input_ids,
1041
+ attention_mask=attention_mask,
1042
+ length=length,
1043
+ subset_indices=subset_indices,
1044
+ token_type_ids=token_type_ids,
1045
+ position_ids=position_ids,
1046
+ head_mask=head_mask,
1047
+ inputs_embeds=inputs_embeds,
1048
+ output_attentions=output_attentions,
1049
+ output_hidden_states=output_hidden_states,
1050
+ return_dict=return_dict,
1051
+ unpad_inputs=unpad_inputs,
1052
+ )
1053
+
1054
+ sequence_output = outputs[0]
1055
+ prediction_scores = self.lm_head(sequence_output)
1056
+
1057
+ masked_lm_loss = None
1058
+ if labels is not None:
1059
+ if subset_indices is None:
1060
+ mask = attention_mask.bool()
1061
+ prediction_scores = prediction_scores[mask]
1062
+ labels = labels[mask]
1063
+ else:
1064
+ labels = labels[subset_indices]
1065
+ masked_lm_loss = self.loss_fct(prediction_scores, labels)
1066
+
1067
+ if not return_dict:
1068
+ output = (prediction_scores,) + outputs[2:]
1069
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1070
+
1071
+ return MaskedLMOutput(
1072
+ loss=masked_lm_loss,
1073
+ logits=prediction_scores,
1074
+ hidden_states=outputs.hidden_states,
1075
+ attentions=outputs.attentions,
1076
+ )
1077
+
1078
+
1079
+ class NewForSequenceClassification(NewPreTrainedModel):
1080
+ def __init__(self, config):
1081
+ super().__init__(config)
1082
+ self.num_labels = config.num_labels
1083
+ self.config = config
1084
+
1085
+ self.new = NewModel(config, add_pooling_layer=True)
1086
+ classifier_dropout = (
1087
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1088
+ )
1089
+ self.dropout = nn.Dropout(classifier_dropout)
1090
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1091
+
1092
+ # Initialize weights and apply final processing
1093
+ self.post_init()
1094
+
1095
+ def forward(
1096
+ self,
1097
+ input_ids: Optional[torch.Tensor] = None,
1098
+ attention_mask: Optional[torch.Tensor] = None,
1099
+ token_type_ids: Optional[torch.Tensor] = None,
1100
+ position_ids: Optional[torch.Tensor] = None,
1101
+ head_mask: Optional[torch.Tensor] = None,
1102
+ inputs_embeds: Optional[torch.Tensor] = None,
1103
+ labels: Optional[torch.Tensor] = None,
1104
+ output_attentions: Optional[bool] = None,
1105
+ output_hidden_states: Optional[bool] = None,
1106
+ return_dict: Optional[bool] = None,
1107
+ unpad_inputs: Optional[bool] = None,
1108
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1109
+ r"""
1110
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1111
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1112
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1113
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1114
+ """
1115
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1116
+
1117
+ outputs = self.new(
1118
+ input_ids,
1119
+ attention_mask=attention_mask,
1120
+ token_type_ids=token_type_ids,
1121
+ position_ids=position_ids,
1122
+ head_mask=head_mask,
1123
+ inputs_embeds=inputs_embeds,
1124
+ output_attentions=output_attentions,
1125
+ output_hidden_states=output_hidden_states,
1126
+ return_dict=return_dict,
1127
+ unpad_inputs=unpad_inputs,
1128
+ )
1129
+
1130
+ pooled_output = outputs[1]
1131
+
1132
+ pooled_output = self.dropout(pooled_output)
1133
+ logits = self.classifier(pooled_output)
1134
+
1135
+ loss = None
1136
+ if labels is not None:
1137
+ if self.config.problem_type is None:
1138
+ if self.num_labels == 1:
1139
+ self.config.problem_type = "regression"
1140
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1141
+ self.config.problem_type = "single_label_classification"
1142
+ else:
1143
+ self.config.problem_type = "multi_label_classification"
1144
+
1145
+ if self.config.problem_type == "regression":
1146
+ loss_fct = nn.MSELoss()
1147
+ if self.num_labels == 1:
1148
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1149
+ else:
1150
+ loss = loss_fct(logits, labels)
1151
+ elif self.config.problem_type == "single_label_classification":
1152
+ loss_fct = nn.CrossEntropyLoss()
1153
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1154
+ elif self.config.problem_type == "multi_label_classification":
1155
+ loss_fct = nn.BCEWithLogitsLoss()
1156
+ loss = loss_fct(logits, labels)
1157
+
1158
+ if not return_dict:
1159
+ output = (logits,) + outputs[2:]
1160
+ return ((loss,) + output) if loss is not None else output
1161
+
1162
+ return SequenceClassifierOutput(
1163
+ loss=loss,
1164
+ logits=logits,
1165
+ hidden_states=outputs.hidden_states,
1166
+ attentions=outputs.attentions,
1167
+ )
1168
+
1169
+
1170
+ class NewForMultipleChoice(NewPreTrainedModel):
1171
+ def __init__(self, config):
1172
+ super().__init__(config)
1173
+
1174
+ self.new = NewModel(config, add_pooling_layer=True)
1175
+ classifier_dropout = (
1176
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1177
+ )
1178
+ self.dropout = nn.Dropout(classifier_dropout)
1179
+ self.classifier = nn.Linear(config.hidden_size, 1)
1180
+
1181
+ # Initialize weights and apply final processing
1182
+ self.post_init()
1183
+
1184
+ def forward(
1185
+ self,
1186
+ input_ids: Optional[torch.Tensor] = None,
1187
+ attention_mask: Optional[torch.Tensor] = None,
1188
+ token_type_ids: Optional[torch.Tensor] = None,
1189
+ position_ids: Optional[torch.Tensor] = None,
1190
+ head_mask: Optional[torch.Tensor] = None,
1191
+ inputs_embeds: Optional[torch.Tensor] = None,
1192
+ labels: Optional[torch.Tensor] = None,
1193
+ output_attentions: Optional[bool] = None,
1194
+ output_hidden_states: Optional[bool] = None,
1195
+ return_dict: Optional[bool] = None,
1196
+ unpad_inputs: Optional[bool] = None,
1197
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1198
+ r"""
1199
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1200
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1201
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1202
+ `input_ids` above)
1203
+ """
1204
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1205
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1206
+
1207
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1208
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1209
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1210
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1211
+ inputs_embeds = (
1212
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1213
+ if inputs_embeds is not None
1214
+ else None
1215
+ )
1216
+
1217
+ outputs = self.new(
1218
+ input_ids,
1219
+ attention_mask=attention_mask,
1220
+ token_type_ids=token_type_ids,
1221
+ position_ids=position_ids,
1222
+ head_mask=head_mask,
1223
+ inputs_embeds=inputs_embeds,
1224
+ output_attentions=output_attentions,
1225
+ output_hidden_states=output_hidden_states,
1226
+ return_dict=return_dict,
1227
+ unpad_inputs=unpad_inputs,
1228
+ )
1229
+
1230
+ pooled_output = outputs[1]
1231
+
1232
+ pooled_output = self.dropout(pooled_output)
1233
+ logits = self.classifier(pooled_output)
1234
+ reshaped_logits = logits.view(-1, num_choices)
1235
+
1236
+ loss = None
1237
+ if labels is not None:
1238
+ loss_fct = nn.CrossEntropyLoss()
1239
+ loss = loss_fct(reshaped_logits, labels)
1240
+
1241
+ if not return_dict:
1242
+ output = (reshaped_logits,) + outputs[2:]
1243
+ return ((loss,) + output) if loss is not None else output
1244
+
1245
+ return MultipleChoiceModelOutput(
1246
+ loss=loss,
1247
+ logits=reshaped_logits,
1248
+ hidden_states=outputs.hidden_states,
1249
+ attentions=outputs.attentions,
1250
+ )
1251
+
1252
+
1253
+ @dataclass
1254
+ class NewTokenClassifierOutput(ModelOutput):
1255
+ loss: Optional[torch.FloatTensor] = None
1256
+ logits: torch.FloatTensor = None
1257
+ last_hidden_state: torch.FloatTensor = None
1258
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
1259
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
1260
+
1261
+
1262
+ class NewForTokenClassification(NewPreTrainedModel):
1263
+ def __init__(self, config):
1264
+ super().__init__(config)
1265
+ self.num_labels = config.num_labels
1266
+
1267
+ self.new = NewModel(config, add_pooling_layer=False)
1268
+ classifier_dropout = (
1269
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1270
+ )
1271
+ self.dropout = nn.Dropout(classifier_dropout)
1272
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1273
+
1274
+ # Initialize weights and apply final processing
1275
+ self.post_init()
1276
+
1277
+ def forward(
1278
+ self,
1279
+ input_ids: Optional[torch.Tensor] = None,
1280
+ attention_mask: Optional[torch.Tensor] = None,
1281
+ token_type_ids: Optional[torch.Tensor] = None,
1282
+ position_ids: Optional[torch.Tensor] = None,
1283
+ head_mask: Optional[torch.Tensor] = None,
1284
+ inputs_embeds: Optional[torch.Tensor] = None,
1285
+ labels: Optional[torch.Tensor] = None,
1286
+ output_attentions: Optional[bool] = None,
1287
+ output_hidden_states: Optional[bool] = None,
1288
+ return_dict: Optional[bool] = None,
1289
+ unpad_inputs: Optional[bool] = None,
1290
+ ) -> Union[Tuple[torch.Tensor], NewTokenClassifierOutput]:
1291
+ r"""
1292
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1293
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1294
+ """
1295
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1296
+
1297
+ outputs = self.new(
1298
+ input_ids,
1299
+ attention_mask=attention_mask,
1300
+ token_type_ids=token_type_ids,
1301
+ position_ids=position_ids,
1302
+ head_mask=head_mask,
1303
+ inputs_embeds=inputs_embeds,
1304
+ output_attentions=output_attentions,
1305
+ output_hidden_states=output_hidden_states,
1306
+ return_dict=return_dict,
1307
+ unpad_inputs=unpad_inputs,
1308
+ )
1309
+
1310
+ sequence_output = outputs[0]
1311
+
1312
+ sequence_output = self.dropout(sequence_output)
1313
+ logits = self.classifier(sequence_output)
1314
+
1315
+ loss = None
1316
+ if labels is not None:
1317
+ loss_fct = nn.CrossEntropyLoss()
1318
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1319
+
1320
+ if not return_dict:
1321
+ output = (logits,) + outputs[2:]
1322
+ return ((loss,) + output) if loss is not None else output
1323
+
1324
+ return NewTokenClassifierOutput(
1325
+ loss=loss,
1326
+ logits=logits,
1327
+ last_hidden_state=sequence_output,
1328
+ hidden_states=outputs.hidden_states,
1329
+ attentions=outputs.attentions,
1330
+ )
1331
+
1332
+
1333
+ class NewForQuestionAnswering(NewPreTrainedModel):
1334
+ def __init__(self, config):
1335
+ super().__init__(config)
1336
+ self.num_labels = config.num_labels
1337
+
1338
+ self.new = NewModel(config, add_pooling_layer=False)
1339
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1340
+
1341
+ # Initialize weights and apply final processing
1342
+ self.post_init()
1343
+
1344
+ def forward(
1345
+ self,
1346
+ input_ids: Optional[torch.Tensor] = None,
1347
+ attention_mask: Optional[torch.Tensor] = None,
1348
+ token_type_ids: Optional[torch.Tensor] = None,
1349
+ position_ids: Optional[torch.Tensor] = None,
1350
+ head_mask: Optional[torch.Tensor] = None,
1351
+ inputs_embeds: Optional[torch.Tensor] = None,
1352
+ start_positions: Optional[torch.Tensor] = None,
1353
+ end_positions: Optional[torch.Tensor] = None,
1354
+ output_attentions: Optional[bool] = None,
1355
+ output_hidden_states: Optional[bool] = None,
1356
+ return_dict: Optional[bool] = None,
1357
+ unpad_inputs: Optional[bool] = None,
1358
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1359
+ r"""
1360
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1361
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1362
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1363
+ are not taken into account for computing the loss.
1364
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1365
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1366
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1367
+ are not taken into account for computing the loss.
1368
+ """
1369
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1370
+
1371
+ outputs = self.new(
1372
+ input_ids,
1373
+ attention_mask=attention_mask,
1374
+ token_type_ids=token_type_ids,
1375
+ position_ids=position_ids,
1376
+ head_mask=head_mask,
1377
+ inputs_embeds=inputs_embeds,
1378
+ output_attentions=output_attentions,
1379
+ output_hidden_states=output_hidden_states,
1380
+ return_dict=return_dict,
1381
+ unpad_inputs=unpad_inputs,
1382
+ )
1383
+
1384
+ sequence_output = outputs[0]
1385
+
1386
+ logits = self.qa_outputs(sequence_output)
1387
+ start_logits, end_logits = logits.split(1, dim=-1)
1388
+ start_logits = start_logits.squeeze(-1).contiguous()
1389
+ end_logits = end_logits.squeeze(-1).contiguous()
1390
+
1391
+ total_loss = None
1392
+ if start_positions is not None and end_positions is not None:
1393
+ # If we are on multi-GPU, split add a dimension
1394
+ if len(start_positions.size()) > 1:
1395
+ start_positions = start_positions.squeeze(-1)
1396
+ if len(end_positions.size()) > 1:
1397
+ end_positions = end_positions.squeeze(-1)
1398
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1399
+ ignored_index = start_logits.size(1)
1400
+ start_positions = start_positions.clamp(0, ignored_index)
1401
+ end_positions = end_positions.clamp(0, ignored_index)
1402
+
1403
+ loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
1404
+ start_loss = loss_fct(start_logits, start_positions)
1405
+ end_loss = loss_fct(end_logits, end_positions)
1406
+ total_loss = (start_loss + end_loss) / 2
1407
+
1408
+ if not return_dict:
1409
+ output = (start_logits, end_logits) + outputs[2:]
1410
+ return ((total_loss,) + output) if total_loss is not None else output
1411
+
1412
+ return QuestionAnsweringModelOutput(
1413
+ loss=total_loss,
1414
+ start_logits=start_logits,
1415
+ end_logits=end_logits,
1416
+ hidden_states=outputs.hidden_states,
1417
+ attentions=outputs.attentions,
1418
+ )
checkpoint-200/rng_state_0.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:491b54b4d910a39ffed7cc8b9a1a7522a2d46cb24fde3b6e4ea9e59e2b769678
3
+ size 15984
checkpoint-200/rng_state_1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb57c59b1177680843a1b2932345a1bfdb11f77e64fae2e5464422d422b59fae
3
+ size 15984
checkpoint-200/rng_state_2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13de24762cc88e04c2704b85d4e642bcf9031ef34e7817506fbb495c62e467f8
3
+ size 15984
checkpoint-200/rng_state_3.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:888d6550bcdbac3947838d60366b2f5873011d3b7f300f87f38058d7845e6a2e
3
+ size 15920
checkpoint-200/rng_state_4.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09dfc7f8fde7376ae2317747d57da83ba9ac4b7494b06444f753f2fbaddee07a
3
+ size 15920
checkpoint-200/rng_state_5.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ada59ce425718dd1feb844fb76799ec8199471c2f84b4055764e543aa59b6d03
3
+ size 15920
checkpoint-200/rng_state_6.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22e7d80046ff31c27340393abfc26100bdcecfb1edb86eeaf7c87625008ca77b
3
+ size 15984
checkpoint-200/rng_state_7.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e197de07e3bd0af6fddcdbccbafb25a8b6175280fbc1e2b113d3bc50508a9fe
3
+ size 15920
checkpoint-200/special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": true,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
checkpoint-200/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b11d4b2e33aa47529535e2aa477b7bae26b5490e6bc956fa46422dfefd4b9366
3
+ size 17082833
checkpoint-200/tokenizer_config.json ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "250001": {
36
+ "content": "<mask>",
37
+ "lstrip": true,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "bos_token": "<s>",
45
+ "clean_up_tokenization_spaces": true,
46
+ "cls_token": "<s>",
47
+ "eos_token": "</s>",
48
+ "mask_token": "<mask>",
49
+ "max_length": 512,
50
+ "model_max_length": 32768,
51
+ "pad_to_multiple_of": null,
52
+ "pad_token": "<pad>",
53
+ "pad_token_type_id": 0,
54
+ "padding_side": "right",
55
+ "sep_token": "</s>",
56
+ "stride": 0,
57
+ "tokenizer_class": "XLMRobertaTokenizer",
58
+ "truncation_side": "right",
59
+ "truncation_strategy": "longest_first",
60
+ "unk_token": "<unk>"
61
+ }
checkpoint-200/trainer_state.json ADDED
@@ -0,0 +1,1433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 19.96879875195008,
5
+ "eval_steps": 500,
6
+ "global_step": 200,
7
+ "is_hyper_param_search": false,
8
+ "is_local_process_zero": true,
9
+ "is_world_process_zero": true,
10
+ "log_history": [
11
+ {
12
+ "epoch": 0.0998439937597504,
13
+ "grad_norm": 0.4140388733942774,
14
+ "learning_rate": 0.0,
15
+ "loss": 1.9149,
16
+ "step": 1
17
+ },
18
+ {
19
+ "epoch": 0.1996879875195008,
20
+ "grad_norm": 4.125320422125423,
21
+ "learning_rate": 0.0001,
22
+ "loss": 2.0842,
23
+ "step": 2
24
+ },
25
+ {
26
+ "epoch": 0.2995319812792512,
27
+ "grad_norm": 4.114561571796027,
28
+ "learning_rate": 0.0001,
29
+ "loss": 2.0845,
30
+ "step": 3
31
+ },
32
+ {
33
+ "epoch": 0.3993759750390016,
34
+ "grad_norm": 7.814135082802521,
35
+ "learning_rate": 9.94949494949495e-05,
36
+ "loss": 2.2529,
37
+ "step": 4
38
+ },
39
+ {
40
+ "epoch": 0.49921996879875197,
41
+ "grad_norm": 4.959578007928057,
42
+ "learning_rate": 9.8989898989899e-05,
43
+ "loss": 2.1802,
44
+ "step": 5
45
+ },
46
+ {
47
+ "epoch": 0.5990639625585024,
48
+ "grad_norm": 4.020980884960991,
49
+ "learning_rate": 9.848484848484849e-05,
50
+ "loss": 2.0996,
51
+ "step": 6
52
+ },
53
+ {
54
+ "epoch": 0.6989079563182528,
55
+ "grad_norm": 2.913148956691379,
56
+ "learning_rate": 9.797979797979798e-05,
57
+ "loss": 2.0739,
58
+ "step": 7
59
+ },
60
+ {
61
+ "epoch": 0.7987519500780031,
62
+ "grad_norm": 1.5904378694728827,
63
+ "learning_rate": 9.747474747474747e-05,
64
+ "loss": 2.0438,
65
+ "step": 8
66
+ },
67
+ {
68
+ "epoch": 0.8985959438377535,
69
+ "grad_norm": 1.1120971271683253,
70
+ "learning_rate": 9.696969696969698e-05,
71
+ "loss": 2.023,
72
+ "step": 9
73
+ },
74
+ {
75
+ "epoch": 0.9984399375975039,
76
+ "grad_norm": 1.4857926676190432,
77
+ "learning_rate": 9.646464646464647e-05,
78
+ "loss": 2.0256,
79
+ "step": 10
80
+ },
81
+ {
82
+ "epoch": 1.0982839313572543,
83
+ "grad_norm": 1.7514837650567026,
84
+ "learning_rate": 9.595959595959596e-05,
85
+ "loss": 2.0297,
86
+ "step": 11
87
+ },
88
+ {
89
+ "epoch": 1.1981279251170047,
90
+ "grad_norm": 1.3684063083659384,
91
+ "learning_rate": 9.545454545454546e-05,
92
+ "loss": 1.9984,
93
+ "step": 12
94
+ },
95
+ {
96
+ "epoch": 1.2979719188767551,
97
+ "grad_norm": 0.8391319984906789,
98
+ "learning_rate": 9.494949494949495e-05,
99
+ "loss": 1.9938,
100
+ "step": 13
101
+ },
102
+ {
103
+ "epoch": 1.3978159126365055,
104
+ "grad_norm": 0.7680152250335479,
105
+ "learning_rate": 9.444444444444444e-05,
106
+ "loss": 1.9984,
107
+ "step": 14
108
+ },
109
+ {
110
+ "epoch": 1.497659906396256,
111
+ "grad_norm": 1.1427900590537006,
112
+ "learning_rate": 9.393939393939395e-05,
113
+ "loss": 2.0059,
114
+ "step": 15
115
+ },
116
+ {
117
+ "epoch": 1.5975039001560063,
118
+ "grad_norm": 1.2852588832884364,
119
+ "learning_rate": 9.343434343434344e-05,
120
+ "loss": 2.0057,
121
+ "step": 16
122
+ },
123
+ {
124
+ "epoch": 1.6973478939157567,
125
+ "grad_norm": 0.8509981577726656,
126
+ "learning_rate": 9.292929292929293e-05,
127
+ "loss": 1.97,
128
+ "step": 17
129
+ },
130
+ {
131
+ "epoch": 1.797191887675507,
132
+ "grad_norm": 0.4374249257660765,
133
+ "learning_rate": 9.242424242424242e-05,
134
+ "loss": 1.9852,
135
+ "step": 18
136
+ },
137
+ {
138
+ "epoch": 1.8970358814352575,
139
+ "grad_norm": 1.006945747433108,
140
+ "learning_rate": 9.191919191919192e-05,
141
+ "loss": 1.9736,
142
+ "step": 19
143
+ },
144
+ {
145
+ "epoch": 1.9968798751950079,
146
+ "grad_norm": 1.1714326150671521,
147
+ "learning_rate": 9.141414141414141e-05,
148
+ "loss": 1.9866,
149
+ "step": 20
150
+ },
151
+ {
152
+ "epoch": 2.0967238689547583,
153
+ "grad_norm": 0.6697915843016325,
154
+ "learning_rate": 9.090909090909092e-05,
155
+ "loss": 1.9669,
156
+ "step": 21
157
+ },
158
+ {
159
+ "epoch": 2.1965678627145087,
160
+ "grad_norm": 0.43542954442572934,
161
+ "learning_rate": 9.040404040404041e-05,
162
+ "loss": 1.9596,
163
+ "step": 22
164
+ },
165
+ {
166
+ "epoch": 2.296411856474259,
167
+ "grad_norm": 0.8895989581186896,
168
+ "learning_rate": 8.98989898989899e-05,
169
+ "loss": 1.9777,
170
+ "step": 23
171
+ },
172
+ {
173
+ "epoch": 2.3962558502340094,
174
+ "grad_norm": 0.748473401890919,
175
+ "learning_rate": 8.93939393939394e-05,
176
+ "loss": 1.9828,
177
+ "step": 24
178
+ },
179
+ {
180
+ "epoch": 2.49609984399376,
181
+ "grad_norm": 0.4762840239068188,
182
+ "learning_rate": 8.888888888888889e-05,
183
+ "loss": 1.9863,
184
+ "step": 25
185
+ },
186
+ {
187
+ "epoch": 2.5959438377535102,
188
+ "grad_norm": 0.4634914120924797,
189
+ "learning_rate": 8.83838383838384e-05,
190
+ "loss": 1.9728,
191
+ "step": 26
192
+ },
193
+ {
194
+ "epoch": 2.6957878315132606,
195
+ "grad_norm": 0.576721007312459,
196
+ "learning_rate": 8.787878787878789e-05,
197
+ "loss": 1.9813,
198
+ "step": 27
199
+ },
200
+ {
201
+ "epoch": 2.795631825273011,
202
+ "grad_norm": 0.4717088615288276,
203
+ "learning_rate": 8.737373737373738e-05,
204
+ "loss": 1.9709,
205
+ "step": 28
206
+ },
207
+ {
208
+ "epoch": 2.8954758190327614,
209
+ "grad_norm": 0.5243076095653101,
210
+ "learning_rate": 8.686868686868688e-05,
211
+ "loss": 1.9889,
212
+ "step": 29
213
+ },
214
+ {
215
+ "epoch": 2.995319812792512,
216
+ "grad_norm": 0.35563844256479116,
217
+ "learning_rate": 8.636363636363637e-05,
218
+ "loss": 1.969,
219
+ "step": 30
220
+ },
221
+ {
222
+ "epoch": 3.095163806552262,
223
+ "grad_norm": 0.5040313506272054,
224
+ "learning_rate": 8.585858585858586e-05,
225
+ "loss": 1.9701,
226
+ "step": 31
227
+ },
228
+ {
229
+ "epoch": 3.1950078003120126,
230
+ "grad_norm": 0.5293887294443628,
231
+ "learning_rate": 8.535353535353535e-05,
232
+ "loss": 1.9774,
233
+ "step": 32
234
+ },
235
+ {
236
+ "epoch": 3.294851794071763,
237
+ "grad_norm": 0.33336016676106733,
238
+ "learning_rate": 8.484848484848486e-05,
239
+ "loss": 1.9702,
240
+ "step": 33
241
+ },
242
+ {
243
+ "epoch": 3.3946957878315134,
244
+ "grad_norm": 0.5156182664373749,
245
+ "learning_rate": 8.434343434343435e-05,
246
+ "loss": 1.9552,
247
+ "step": 34
248
+ },
249
+ {
250
+ "epoch": 3.4945397815912638,
251
+ "grad_norm": 0.410792592829029,
252
+ "learning_rate": 8.383838383838384e-05,
253
+ "loss": 1.9642,
254
+ "step": 35
255
+ },
256
+ {
257
+ "epoch": 3.594383775351014,
258
+ "grad_norm": 0.40267682408922495,
259
+ "learning_rate": 8.333333333333334e-05,
260
+ "loss": 1.9688,
261
+ "step": 36
262
+ },
263
+ {
264
+ "epoch": 3.6942277691107646,
265
+ "grad_norm": 0.3869359148412346,
266
+ "learning_rate": 8.282828282828283e-05,
267
+ "loss": 1.9733,
268
+ "step": 37
269
+ },
270
+ {
271
+ "epoch": 3.794071762870515,
272
+ "grad_norm": 0.37728712869432585,
273
+ "learning_rate": 8.232323232323233e-05,
274
+ "loss": 1.9648,
275
+ "step": 38
276
+ },
277
+ {
278
+ "epoch": 3.8939157566302653,
279
+ "grad_norm": 0.3922418131207954,
280
+ "learning_rate": 8.181818181818183e-05,
281
+ "loss": 1.9689,
282
+ "step": 39
283
+ },
284
+ {
285
+ "epoch": 3.9937597503900157,
286
+ "grad_norm": 0.26353046722639645,
287
+ "learning_rate": 8.131313131313132e-05,
288
+ "loss": 1.9727,
289
+ "step": 40
290
+ },
291
+ {
292
+ "epoch": 4.093603744149766,
293
+ "grad_norm": 0.3911091474488452,
294
+ "learning_rate": 8.080808080808081e-05,
295
+ "loss": 1.9631,
296
+ "step": 41
297
+ },
298
+ {
299
+ "epoch": 4.1934477379095165,
300
+ "grad_norm": 0.33402240826623614,
301
+ "learning_rate": 8.03030303030303e-05,
302
+ "loss": 1.9665,
303
+ "step": 42
304
+ },
305
+ {
306
+ "epoch": 4.2932917316692665,
307
+ "grad_norm": 0.34654808232868395,
308
+ "learning_rate": 7.97979797979798e-05,
309
+ "loss": 1.9646,
310
+ "step": 43
311
+ },
312
+ {
313
+ "epoch": 4.393135725429017,
314
+ "grad_norm": 0.3031078864703629,
315
+ "learning_rate": 7.92929292929293e-05,
316
+ "loss": 1.9693,
317
+ "step": 44
318
+ },
319
+ {
320
+ "epoch": 4.492979719188767,
321
+ "grad_norm": 0.35342072957234116,
322
+ "learning_rate": 7.878787878787879e-05,
323
+ "loss": 1.9688,
324
+ "step": 45
325
+ },
326
+ {
327
+ "epoch": 4.592823712948518,
328
+ "grad_norm": 0.3918161921811716,
329
+ "learning_rate": 7.828282828282829e-05,
330
+ "loss": 1.9609,
331
+ "step": 46
332
+ },
333
+ {
334
+ "epoch": 4.692667706708268,
335
+ "grad_norm": 0.24995683506017796,
336
+ "learning_rate": 7.777777777777778e-05,
337
+ "loss": 1.9515,
338
+ "step": 47
339
+ },
340
+ {
341
+ "epoch": 4.792511700468019,
342
+ "grad_norm": 0.3308078104166398,
343
+ "learning_rate": 7.727272727272727e-05,
344
+ "loss": 1.9607,
345
+ "step": 48
346
+ },
347
+ {
348
+ "epoch": 4.892355694227769,
349
+ "grad_norm": 0.3130926472973521,
350
+ "learning_rate": 7.676767676767676e-05,
351
+ "loss": 1.9699,
352
+ "step": 49
353
+ },
354
+ {
355
+ "epoch": 4.99219968798752,
356
+ "grad_norm": 0.30892356920484393,
357
+ "learning_rate": 7.626262626262627e-05,
358
+ "loss": 1.9645,
359
+ "step": 50
360
+ },
361
+ {
362
+ "epoch": 5.09204368174727,
363
+ "grad_norm": 0.2804202715276883,
364
+ "learning_rate": 7.575757575757576e-05,
365
+ "loss": 1.9569,
366
+ "step": 51
367
+ },
368
+ {
369
+ "epoch": 5.1918876755070205,
370
+ "grad_norm": 0.2789049399636327,
371
+ "learning_rate": 7.525252525252525e-05,
372
+ "loss": 1.9585,
373
+ "step": 52
374
+ },
375
+ {
376
+ "epoch": 5.29173166926677,
377
+ "grad_norm": 0.2906929505804403,
378
+ "learning_rate": 7.474747474747475e-05,
379
+ "loss": 1.9565,
380
+ "step": 53
381
+ },
382
+ {
383
+ "epoch": 5.391575663026521,
384
+ "grad_norm": 0.2033727950080347,
385
+ "learning_rate": 7.424242424242424e-05,
386
+ "loss": 1.9755,
387
+ "step": 54
388
+ },
389
+ {
390
+ "epoch": 5.491419656786271,
391
+ "grad_norm": 0.31364461369416163,
392
+ "learning_rate": 7.373737373737373e-05,
393
+ "loss": 1.9647,
394
+ "step": 55
395
+ },
396
+ {
397
+ "epoch": 5.591263650546022,
398
+ "grad_norm": 0.2531087381531638,
399
+ "learning_rate": 7.323232323232324e-05,
400
+ "loss": 1.9578,
401
+ "step": 56
402
+ },
403
+ {
404
+ "epoch": 5.691107644305772,
405
+ "grad_norm": 0.23764498764830225,
406
+ "learning_rate": 7.272727272727273e-05,
407
+ "loss": 1.9617,
408
+ "step": 57
409
+ },
410
+ {
411
+ "epoch": 5.790951638065523,
412
+ "grad_norm": 0.24888591334854687,
413
+ "learning_rate": 7.222222222222222e-05,
414
+ "loss": 1.963,
415
+ "step": 58
416
+ },
417
+ {
418
+ "epoch": 5.890795631825273,
419
+ "grad_norm": 0.2647075657405339,
420
+ "learning_rate": 7.171717171717171e-05,
421
+ "loss": 1.9685,
422
+ "step": 59
423
+ },
424
+ {
425
+ "epoch": 5.990639625585024,
426
+ "grad_norm": 0.27820470985704615,
427
+ "learning_rate": 7.121212121212121e-05,
428
+ "loss": 1.9654,
429
+ "step": 60
430
+ },
431
+ {
432
+ "epoch": 6.090483619344774,
433
+ "grad_norm": 0.20068946885468097,
434
+ "learning_rate": 7.07070707070707e-05,
435
+ "loss": 1.9667,
436
+ "step": 61
437
+ },
438
+ {
439
+ "epoch": 6.190327613104524,
440
+ "grad_norm": 0.25026234630326394,
441
+ "learning_rate": 7.020202020202021e-05,
442
+ "loss": 1.9542,
443
+ "step": 62
444
+ },
445
+ {
446
+ "epoch": 6.290171606864274,
447
+ "grad_norm": 0.22856925269883635,
448
+ "learning_rate": 6.96969696969697e-05,
449
+ "loss": 1.9573,
450
+ "step": 63
451
+ },
452
+ {
453
+ "epoch": 6.390015600624025,
454
+ "grad_norm": 0.2392183076591563,
455
+ "learning_rate": 6.91919191919192e-05,
456
+ "loss": 1.9647,
457
+ "step": 64
458
+ },
459
+ {
460
+ "epoch": 6.489859594383775,
461
+ "grad_norm": 0.20384525102843132,
462
+ "learning_rate": 6.86868686868687e-05,
463
+ "loss": 1.9628,
464
+ "step": 65
465
+ },
466
+ {
467
+ "epoch": 6.589703588143526,
468
+ "grad_norm": 0.23941897200051984,
469
+ "learning_rate": 6.818181818181818e-05,
470
+ "loss": 1.9667,
471
+ "step": 66
472
+ },
473
+ {
474
+ "epoch": 6.689547581903276,
475
+ "grad_norm": 0.20375278551444306,
476
+ "learning_rate": 6.767676767676769e-05,
477
+ "loss": 1.9572,
478
+ "step": 67
479
+ },
480
+ {
481
+ "epoch": 6.789391575663027,
482
+ "grad_norm": 0.20727005267599333,
483
+ "learning_rate": 6.717171717171718e-05,
484
+ "loss": 1.9581,
485
+ "step": 68
486
+ },
487
+ {
488
+ "epoch": 6.889235569422777,
489
+ "grad_norm": 0.22300809533132504,
490
+ "learning_rate": 6.666666666666667e-05,
491
+ "loss": 1.9693,
492
+ "step": 69
493
+ },
494
+ {
495
+ "epoch": 6.9890795631825275,
496
+ "grad_norm": 0.21742318730398613,
497
+ "learning_rate": 6.616161616161617e-05,
498
+ "loss": 1.9656,
499
+ "step": 70
500
+ },
501
+ {
502
+ "epoch": 7.0889235569422775,
503
+ "grad_norm": 0.20343822391223,
504
+ "learning_rate": 6.565656565656566e-05,
505
+ "loss": 1.9656,
506
+ "step": 71
507
+ },
508
+ {
509
+ "epoch": 7.188767550702028,
510
+ "grad_norm": 0.2364200066637671,
511
+ "learning_rate": 6.515151515151516e-05,
512
+ "loss": 1.95,
513
+ "step": 72
514
+ },
515
+ {
516
+ "epoch": 7.288611544461778,
517
+ "grad_norm": 0.18261048615751524,
518
+ "learning_rate": 6.464646464646466e-05,
519
+ "loss": 1.9538,
520
+ "step": 73
521
+ },
522
+ {
523
+ "epoch": 7.388455538221529,
524
+ "grad_norm": 0.24533487813163474,
525
+ "learning_rate": 6.414141414141415e-05,
526
+ "loss": 1.952,
527
+ "step": 74
528
+ },
529
+ {
530
+ "epoch": 7.488299531981279,
531
+ "grad_norm": 0.2539612930496735,
532
+ "learning_rate": 6.363636363636364e-05,
533
+ "loss": 1.9598,
534
+ "step": 75
535
+ },
536
+ {
537
+ "epoch": 7.58814352574103,
538
+ "grad_norm": 0.2991457603613546,
539
+ "learning_rate": 6.313131313131313e-05,
540
+ "loss": 1.9733,
541
+ "step": 76
542
+ },
543
+ {
544
+ "epoch": 7.68798751950078,
545
+ "grad_norm": 0.2209105824741669,
546
+ "learning_rate": 6.262626262626264e-05,
547
+ "loss": 1.9531,
548
+ "step": 77
549
+ },
550
+ {
551
+ "epoch": 7.787831513260531,
552
+ "grad_norm": 0.31698310918964695,
553
+ "learning_rate": 6.212121212121213e-05,
554
+ "loss": 1.964,
555
+ "step": 78
556
+ },
557
+ {
558
+ "epoch": 7.887675507020281,
559
+ "grad_norm": 0.17584276182725114,
560
+ "learning_rate": 6.161616161616162e-05,
561
+ "loss": 1.9644,
562
+ "step": 79
563
+ },
564
+ {
565
+ "epoch": 7.9875195007800315,
566
+ "grad_norm": 0.29996919622824225,
567
+ "learning_rate": 6.111111111111112e-05,
568
+ "loss": 1.9555,
569
+ "step": 80
570
+ },
571
+ {
572
+ "epoch": 8.087363494539781,
573
+ "grad_norm": 0.2406502202676367,
574
+ "learning_rate": 6.060606060606061e-05,
575
+ "loss": 1.9554,
576
+ "step": 81
577
+ },
578
+ {
579
+ "epoch": 8.187207488299531,
580
+ "grad_norm": 0.32705142732170855,
581
+ "learning_rate": 6.01010101010101e-05,
582
+ "loss": 1.9517,
583
+ "step": 82
584
+ },
585
+ {
586
+ "epoch": 8.287051482059283,
587
+ "grad_norm": 0.27249925952338305,
588
+ "learning_rate": 5.959595959595959e-05,
589
+ "loss": 1.9474,
590
+ "step": 83
591
+ },
592
+ {
593
+ "epoch": 8.386895475819033,
594
+ "grad_norm": 0.29448831027669287,
595
+ "learning_rate": 5.90909090909091e-05,
596
+ "loss": 1.9459,
597
+ "step": 84
598
+ },
599
+ {
600
+ "epoch": 8.486739469578783,
601
+ "grad_norm": 0.29998154037028857,
602
+ "learning_rate": 5.858585858585859e-05,
603
+ "loss": 1.9606,
604
+ "step": 85
605
+ },
606
+ {
607
+ "epoch": 8.586583463338533,
608
+ "grad_norm": 0.23153724936859055,
609
+ "learning_rate": 5.808080808080808e-05,
610
+ "loss": 1.9598,
611
+ "step": 86
612
+ },
613
+ {
614
+ "epoch": 8.686427457098285,
615
+ "grad_norm": 0.22081595887056477,
616
+ "learning_rate": 5.757575757575758e-05,
617
+ "loss": 1.9586,
618
+ "step": 87
619
+ },
620
+ {
621
+ "epoch": 8.786271450858035,
622
+ "grad_norm": 0.19177670537863922,
623
+ "learning_rate": 5.707070707070707e-05,
624
+ "loss": 1.9715,
625
+ "step": 88
626
+ },
627
+ {
628
+ "epoch": 8.886115444617785,
629
+ "grad_norm": 0.25725928107907137,
630
+ "learning_rate": 5.6565656565656563e-05,
631
+ "loss": 1.9602,
632
+ "step": 89
633
+ },
634
+ {
635
+ "epoch": 8.985959438377535,
636
+ "grad_norm": 0.26044371305524344,
637
+ "learning_rate": 5.606060606060606e-05,
638
+ "loss": 1.9607,
639
+ "step": 90
640
+ },
641
+ {
642
+ "epoch": 9.085803432137286,
643
+ "grad_norm": 0.23728151561491595,
644
+ "learning_rate": 5.555555555555556e-05,
645
+ "loss": 1.9588,
646
+ "step": 91
647
+ },
648
+ {
649
+ "epoch": 9.185647425897036,
650
+ "grad_norm": 0.20354348868729488,
651
+ "learning_rate": 5.5050505050505056e-05,
652
+ "loss": 1.9492,
653
+ "step": 92
654
+ },
655
+ {
656
+ "epoch": 9.285491419656786,
657
+ "grad_norm": 0.18672087839741056,
658
+ "learning_rate": 5.4545454545454546e-05,
659
+ "loss": 1.9457,
660
+ "step": 93
661
+ },
662
+ {
663
+ "epoch": 9.385335413416536,
664
+ "grad_norm": 0.1939858201242329,
665
+ "learning_rate": 5.4040404040404044e-05,
666
+ "loss": 1.9453,
667
+ "step": 94
668
+ },
669
+ {
670
+ "epoch": 9.485179407176288,
671
+ "grad_norm": 0.19172060706771135,
672
+ "learning_rate": 5.353535353535354e-05,
673
+ "loss": 1.958,
674
+ "step": 95
675
+ },
676
+ {
677
+ "epoch": 9.585023400936038,
678
+ "grad_norm": 0.1837920882880991,
679
+ "learning_rate": 5.303030303030303e-05,
680
+ "loss": 1.9577,
681
+ "step": 96
682
+ },
683
+ {
684
+ "epoch": 9.684867394695788,
685
+ "grad_norm": 0.2162949878555464,
686
+ "learning_rate": 5.2525252525252536e-05,
687
+ "loss": 1.9622,
688
+ "step": 97
689
+ },
690
+ {
691
+ "epoch": 9.784711388455538,
692
+ "grad_norm": 0.19325381586186333,
693
+ "learning_rate": 5.2020202020202026e-05,
694
+ "loss": 1.9433,
695
+ "step": 98
696
+ },
697
+ {
698
+ "epoch": 9.88455538221529,
699
+ "grad_norm": 0.2018142831658023,
700
+ "learning_rate": 5.151515151515152e-05,
701
+ "loss": 1.9605,
702
+ "step": 99
703
+ },
704
+ {
705
+ "epoch": 9.98439937597504,
706
+ "grad_norm": 0.176671565601027,
707
+ "learning_rate": 5.101010101010101e-05,
708
+ "loss": 1.9578,
709
+ "step": 100
710
+ },
711
+ {
712
+ "epoch": 10.08424336973479,
713
+ "grad_norm": 0.2117788085352089,
714
+ "learning_rate": 5.050505050505051e-05,
715
+ "loss": 1.9478,
716
+ "step": 101
717
+ },
718
+ {
719
+ "epoch": 10.18408736349454,
720
+ "grad_norm": 0.1816135304249716,
721
+ "learning_rate": 5e-05,
722
+ "loss": 1.9423,
723
+ "step": 102
724
+ },
725
+ {
726
+ "epoch": 10.283931357254291,
727
+ "grad_norm": 0.2680310363226074,
728
+ "learning_rate": 4.94949494949495e-05,
729
+ "loss": 1.9519,
730
+ "step": 103
731
+ },
732
+ {
733
+ "epoch": 10.383775351014041,
734
+ "grad_norm": 0.17934299698555412,
735
+ "learning_rate": 4.898989898989899e-05,
736
+ "loss": 1.9625,
737
+ "step": 104
738
+ },
739
+ {
740
+ "epoch": 10.48361934477379,
741
+ "grad_norm": 0.19786074542682824,
742
+ "learning_rate": 4.848484848484849e-05,
743
+ "loss": 1.95,
744
+ "step": 105
745
+ },
746
+ {
747
+ "epoch": 10.58346333853354,
748
+ "grad_norm": 0.17490489580858018,
749
+ "learning_rate": 4.797979797979798e-05,
750
+ "loss": 1.9513,
751
+ "step": 106
752
+ },
753
+ {
754
+ "epoch": 10.683307332293293,
755
+ "grad_norm": 0.224513887757472,
756
+ "learning_rate": 4.7474747474747476e-05,
757
+ "loss": 1.9499,
758
+ "step": 107
759
+ },
760
+ {
761
+ "epoch": 10.783151326053042,
762
+ "grad_norm": 0.16993980203530532,
763
+ "learning_rate": 4.696969696969697e-05,
764
+ "loss": 1.944,
765
+ "step": 108
766
+ },
767
+ {
768
+ "epoch": 10.882995319812792,
769
+ "grad_norm": 0.18436224376063975,
770
+ "learning_rate": 4.6464646464646464e-05,
771
+ "loss": 1.9494,
772
+ "step": 109
773
+ },
774
+ {
775
+ "epoch": 10.982839313572542,
776
+ "grad_norm": 0.1858801208244774,
777
+ "learning_rate": 4.595959595959596e-05,
778
+ "loss": 1.9504,
779
+ "step": 110
780
+ },
781
+ {
782
+ "epoch": 11.082683307332294,
783
+ "grad_norm": 0.21397140437157122,
784
+ "learning_rate": 4.545454545454546e-05,
785
+ "loss": 1.9467,
786
+ "step": 111
787
+ },
788
+ {
789
+ "epoch": 11.182527301092044,
790
+ "grad_norm": 0.16934479906947233,
791
+ "learning_rate": 4.494949494949495e-05,
792
+ "loss": 1.9475,
793
+ "step": 112
794
+ },
795
+ {
796
+ "epoch": 11.282371294851794,
797
+ "grad_norm": 0.17710883760980722,
798
+ "learning_rate": 4.4444444444444447e-05,
799
+ "loss": 1.9418,
800
+ "step": 113
801
+ },
802
+ {
803
+ "epoch": 11.382215288611544,
804
+ "grad_norm": 0.2278025006688675,
805
+ "learning_rate": 4.3939393939393944e-05,
806
+ "loss": 1.9456,
807
+ "step": 114
808
+ },
809
+ {
810
+ "epoch": 11.482059282371296,
811
+ "grad_norm": 0.18727166458531316,
812
+ "learning_rate": 4.343434343434344e-05,
813
+ "loss": 1.9408,
814
+ "step": 115
815
+ },
816
+ {
817
+ "epoch": 11.581903276131046,
818
+ "grad_norm": 0.17348080665741175,
819
+ "learning_rate": 4.292929292929293e-05,
820
+ "loss": 1.9367,
821
+ "step": 116
822
+ },
823
+ {
824
+ "epoch": 11.681747269890796,
825
+ "grad_norm": 0.21559975863343248,
826
+ "learning_rate": 4.242424242424243e-05,
827
+ "loss": 1.9509,
828
+ "step": 117
829
+ },
830
+ {
831
+ "epoch": 11.781591263650546,
832
+ "grad_norm": 0.20515384184563593,
833
+ "learning_rate": 4.191919191919192e-05,
834
+ "loss": 1.9503,
835
+ "step": 118
836
+ },
837
+ {
838
+ "epoch": 11.881435257410295,
839
+ "grad_norm": 0.17579996751101729,
840
+ "learning_rate": 4.141414141414142e-05,
841
+ "loss": 1.9443,
842
+ "step": 119
843
+ },
844
+ {
845
+ "epoch": 11.981279251170047,
846
+ "grad_norm": 0.1870399234707776,
847
+ "learning_rate": 4.0909090909090915e-05,
848
+ "loss": 1.9507,
849
+ "step": 120
850
+ },
851
+ {
852
+ "epoch": 12.081123244929797,
853
+ "grad_norm": 0.2323975590399996,
854
+ "learning_rate": 4.0404040404040405e-05,
855
+ "loss": 1.9486,
856
+ "step": 121
857
+ },
858
+ {
859
+ "epoch": 12.180967238689547,
860
+ "grad_norm": 0.17332911391024705,
861
+ "learning_rate": 3.98989898989899e-05,
862
+ "loss": 1.9441,
863
+ "step": 122
864
+ },
865
+ {
866
+ "epoch": 12.280811232449299,
867
+ "grad_norm": 0.23886491083540215,
868
+ "learning_rate": 3.939393939393939e-05,
869
+ "loss": 1.9489,
870
+ "step": 123
871
+ },
872
+ {
873
+ "epoch": 12.380655226209049,
874
+ "grad_norm": 0.192192583869745,
875
+ "learning_rate": 3.888888888888889e-05,
876
+ "loss": 1.936,
877
+ "step": 124
878
+ },
879
+ {
880
+ "epoch": 12.480499219968799,
881
+ "grad_norm": 0.24070020033146947,
882
+ "learning_rate": 3.838383838383838e-05,
883
+ "loss": 1.9363,
884
+ "step": 125
885
+ },
886
+ {
887
+ "epoch": 12.580343213728549,
888
+ "grad_norm": 0.17061145664967614,
889
+ "learning_rate": 3.787878787878788e-05,
890
+ "loss": 1.947,
891
+ "step": 126
892
+ },
893
+ {
894
+ "epoch": 12.680187207488299,
895
+ "grad_norm": 0.20420044689274344,
896
+ "learning_rate": 3.7373737373737376e-05,
897
+ "loss": 1.9462,
898
+ "step": 127
899
+ },
900
+ {
901
+ "epoch": 12.78003120124805,
902
+ "grad_norm": 0.16640664781155742,
903
+ "learning_rate": 3.686868686868687e-05,
904
+ "loss": 1.9404,
905
+ "step": 128
906
+ },
907
+ {
908
+ "epoch": 12.8798751950078,
909
+ "grad_norm": 0.17534875646136103,
910
+ "learning_rate": 3.6363636363636364e-05,
911
+ "loss": 1.9441,
912
+ "step": 129
913
+ },
914
+ {
915
+ "epoch": 12.97971918876755,
916
+ "grad_norm": 0.1881647742956635,
917
+ "learning_rate": 3.5858585858585855e-05,
918
+ "loss": 1.9452,
919
+ "step": 130
920
+ },
921
+ {
922
+ "epoch": 13.0795631825273,
923
+ "grad_norm": 0.21130090774448568,
924
+ "learning_rate": 3.535353535353535e-05,
925
+ "loss": 1.938,
926
+ "step": 131
927
+ },
928
+ {
929
+ "epoch": 13.179407176287052,
930
+ "grad_norm": 0.19012207225624486,
931
+ "learning_rate": 3.484848484848485e-05,
932
+ "loss": 1.93,
933
+ "step": 132
934
+ },
935
+ {
936
+ "epoch": 13.279251170046802,
937
+ "grad_norm": 0.19535583015453165,
938
+ "learning_rate": 3.434343434343435e-05,
939
+ "loss": 1.9423,
940
+ "step": 133
941
+ },
942
+ {
943
+ "epoch": 13.379095163806552,
944
+ "grad_norm": 0.1972934873185412,
945
+ "learning_rate": 3.3838383838383844e-05,
946
+ "loss": 1.9449,
947
+ "step": 134
948
+ },
949
+ {
950
+ "epoch": 13.478939157566302,
951
+ "grad_norm": 0.21172258190614646,
952
+ "learning_rate": 3.3333333333333335e-05,
953
+ "loss": 1.9423,
954
+ "step": 135
955
+ },
956
+ {
957
+ "epoch": 13.578783151326054,
958
+ "grad_norm": 0.20243808248600392,
959
+ "learning_rate": 3.282828282828283e-05,
960
+ "loss": 1.9454,
961
+ "step": 136
962
+ },
963
+ {
964
+ "epoch": 13.678627145085803,
965
+ "grad_norm": 0.29468220957824104,
966
+ "learning_rate": 3.232323232323233e-05,
967
+ "loss": 1.9329,
968
+ "step": 137
969
+ },
970
+ {
971
+ "epoch": 13.778471138845553,
972
+ "grad_norm": 0.1852836649334086,
973
+ "learning_rate": 3.181818181818182e-05,
974
+ "loss": 1.9397,
975
+ "step": 138
976
+ },
977
+ {
978
+ "epoch": 13.878315132605305,
979
+ "grad_norm": 0.17635021846243693,
980
+ "learning_rate": 3.131313131313132e-05,
981
+ "loss": 1.9414,
982
+ "step": 139
983
+ },
984
+ {
985
+ "epoch": 13.978159126365055,
986
+ "grad_norm": 0.1837620268343685,
987
+ "learning_rate": 3.080808080808081e-05,
988
+ "loss": 1.9265,
989
+ "step": 140
990
+ },
991
+ {
992
+ "epoch": 14.078003120124805,
993
+ "grad_norm": 0.1851416429157977,
994
+ "learning_rate": 3.0303030303030306e-05,
995
+ "loss": 1.938,
996
+ "step": 141
997
+ },
998
+ {
999
+ "epoch": 14.177847113884555,
1000
+ "grad_norm": 0.18177436704033564,
1001
+ "learning_rate": 2.9797979797979796e-05,
1002
+ "loss": 1.9338,
1003
+ "step": 142
1004
+ },
1005
+ {
1006
+ "epoch": 14.277691107644305,
1007
+ "grad_norm": 0.20249599488147646,
1008
+ "learning_rate": 2.9292929292929294e-05,
1009
+ "loss": 1.943,
1010
+ "step": 143
1011
+ },
1012
+ {
1013
+ "epoch": 14.377535101404057,
1014
+ "grad_norm": 0.1914943764672633,
1015
+ "learning_rate": 2.878787878787879e-05,
1016
+ "loss": 1.9381,
1017
+ "step": 144
1018
+ },
1019
+ {
1020
+ "epoch": 14.477379095163807,
1021
+ "grad_norm": 0.18144339446500468,
1022
+ "learning_rate": 2.8282828282828282e-05,
1023
+ "loss": 1.9493,
1024
+ "step": 145
1025
+ },
1026
+ {
1027
+ "epoch": 14.577223088923557,
1028
+ "grad_norm": 0.22871591479507436,
1029
+ "learning_rate": 2.777777777777778e-05,
1030
+ "loss": 1.9394,
1031
+ "step": 146
1032
+ },
1033
+ {
1034
+ "epoch": 14.677067082683307,
1035
+ "grad_norm": 0.2409736531836878,
1036
+ "learning_rate": 2.7272727272727273e-05,
1037
+ "loss": 1.9363,
1038
+ "step": 147
1039
+ },
1040
+ {
1041
+ "epoch": 14.776911076443058,
1042
+ "grad_norm": 0.21702411701682794,
1043
+ "learning_rate": 2.676767676767677e-05,
1044
+ "loss": 1.9324,
1045
+ "step": 148
1046
+ },
1047
+ {
1048
+ "epoch": 14.876755070202808,
1049
+ "grad_norm": 0.186963824720383,
1050
+ "learning_rate": 2.6262626262626268e-05,
1051
+ "loss": 1.9254,
1052
+ "step": 149
1053
+ },
1054
+ {
1055
+ "epoch": 14.976599063962558,
1056
+ "grad_norm": 0.20551876684974787,
1057
+ "learning_rate": 2.575757575757576e-05,
1058
+ "loss": 1.9385,
1059
+ "step": 150
1060
+ },
1061
+ {
1062
+ "epoch": 15.076443057722308,
1063
+ "grad_norm": 0.17794734645273458,
1064
+ "learning_rate": 2.5252525252525256e-05,
1065
+ "loss": 1.935,
1066
+ "step": 151
1067
+ },
1068
+ {
1069
+ "epoch": 15.17628705148206,
1070
+ "grad_norm": 0.19787955354426204,
1071
+ "learning_rate": 2.474747474747475e-05,
1072
+ "loss": 1.9286,
1073
+ "step": 152
1074
+ },
1075
+ {
1076
+ "epoch": 15.27613104524181,
1077
+ "grad_norm": 0.21663975391838738,
1078
+ "learning_rate": 2.4242424242424244e-05,
1079
+ "loss": 1.9274,
1080
+ "step": 153
1081
+ },
1082
+ {
1083
+ "epoch": 15.37597503900156,
1084
+ "grad_norm": 0.19056508068402894,
1085
+ "learning_rate": 2.3737373737373738e-05,
1086
+ "loss": 1.9328,
1087
+ "step": 154
1088
+ },
1089
+ {
1090
+ "epoch": 15.47581903276131,
1091
+ "grad_norm": 0.20643529046597323,
1092
+ "learning_rate": 2.3232323232323232e-05,
1093
+ "loss": 1.9374,
1094
+ "step": 155
1095
+ },
1096
+ {
1097
+ "epoch": 15.575663026521061,
1098
+ "grad_norm": 0.17428582721990332,
1099
+ "learning_rate": 2.272727272727273e-05,
1100
+ "loss": 1.9435,
1101
+ "step": 156
1102
+ },
1103
+ {
1104
+ "epoch": 15.675507020280811,
1105
+ "grad_norm": 0.17915807350384474,
1106
+ "learning_rate": 2.2222222222222223e-05,
1107
+ "loss": 1.9342,
1108
+ "step": 157
1109
+ },
1110
+ {
1111
+ "epoch": 15.775351014040561,
1112
+ "grad_norm": 0.17934386940217817,
1113
+ "learning_rate": 2.171717171717172e-05,
1114
+ "loss": 1.9252,
1115
+ "step": 158
1116
+ },
1117
+ {
1118
+ "epoch": 15.875195007800311,
1119
+ "grad_norm": 0.16971494417624172,
1120
+ "learning_rate": 2.1212121212121215e-05,
1121
+ "loss": 1.9333,
1122
+ "step": 159
1123
+ },
1124
+ {
1125
+ "epoch": 15.975039001560063,
1126
+ "grad_norm": 0.1710725442382166,
1127
+ "learning_rate": 2.070707070707071e-05,
1128
+ "loss": 1.9397,
1129
+ "step": 160
1130
+ },
1131
+ {
1132
+ "epoch": 16.07488299531981,
1133
+ "grad_norm": 0.16048331708079347,
1134
+ "learning_rate": 2.0202020202020203e-05,
1135
+ "loss": 1.9354,
1136
+ "step": 161
1137
+ },
1138
+ {
1139
+ "epoch": 16.174726989079563,
1140
+ "grad_norm": 0.2209212482793433,
1141
+ "learning_rate": 1.9696969696969697e-05,
1142
+ "loss": 1.9572,
1143
+ "step": 162
1144
+ },
1145
+ {
1146
+ "epoch": 16.274570982839315,
1147
+ "grad_norm": 0.17292517371584637,
1148
+ "learning_rate": 1.919191919191919e-05,
1149
+ "loss": 1.9384,
1150
+ "step": 163
1151
+ },
1152
+ {
1153
+ "epoch": 16.374414976599063,
1154
+ "grad_norm": 0.1756696399704993,
1155
+ "learning_rate": 1.8686868686868688e-05,
1156
+ "loss": 1.9287,
1157
+ "step": 164
1158
+ },
1159
+ {
1160
+ "epoch": 16.474258970358814,
1161
+ "grad_norm": 0.193814973934712,
1162
+ "learning_rate": 1.8181818181818182e-05,
1163
+ "loss": 1.9285,
1164
+ "step": 165
1165
+ },
1166
+ {
1167
+ "epoch": 16.574102964118566,
1168
+ "grad_norm": 0.21108116449806094,
1169
+ "learning_rate": 1.7676767676767676e-05,
1170
+ "loss": 1.9249,
1171
+ "step": 166
1172
+ },
1173
+ {
1174
+ "epoch": 16.673946957878314,
1175
+ "grad_norm": 0.164152325154632,
1176
+ "learning_rate": 1.7171717171717173e-05,
1177
+ "loss": 1.9335,
1178
+ "step": 167
1179
+ },
1180
+ {
1181
+ "epoch": 16.773790951638066,
1182
+ "grad_norm": 0.1934976757474289,
1183
+ "learning_rate": 1.6666666666666667e-05,
1184
+ "loss": 1.9344,
1185
+ "step": 168
1186
+ },
1187
+ {
1188
+ "epoch": 16.873634945397814,
1189
+ "grad_norm": 0.17861559674997443,
1190
+ "learning_rate": 1.6161616161616165e-05,
1191
+ "loss": 1.9315,
1192
+ "step": 169
1193
+ },
1194
+ {
1195
+ "epoch": 16.973478939157566,
1196
+ "grad_norm": 0.16812713496720977,
1197
+ "learning_rate": 1.565656565656566e-05,
1198
+ "loss": 1.9278,
1199
+ "step": 170
1200
+ },
1201
+ {
1202
+ "epoch": 17.073322932917318,
1203
+ "grad_norm": 0.19243202935397094,
1204
+ "learning_rate": 1.5151515151515153e-05,
1205
+ "loss": 1.939,
1206
+ "step": 171
1207
+ },
1208
+ {
1209
+ "epoch": 17.173166926677066,
1210
+ "grad_norm": 0.16546322416204856,
1211
+ "learning_rate": 1.4646464646464647e-05,
1212
+ "loss": 1.9295,
1213
+ "step": 172
1214
+ },
1215
+ {
1216
+ "epoch": 17.273010920436818,
1217
+ "grad_norm": 0.19615095413628908,
1218
+ "learning_rate": 1.4141414141414141e-05,
1219
+ "loss": 1.9357,
1220
+ "step": 173
1221
+ },
1222
+ {
1223
+ "epoch": 17.37285491419657,
1224
+ "grad_norm": 0.16562858231287156,
1225
+ "learning_rate": 1.3636363636363637e-05,
1226
+ "loss": 1.9372,
1227
+ "step": 174
1228
+ },
1229
+ {
1230
+ "epoch": 17.472698907956318,
1231
+ "grad_norm": 0.1755423564949021,
1232
+ "learning_rate": 1.3131313131313134e-05,
1233
+ "loss": 1.9208,
1234
+ "step": 175
1235
+ },
1236
+ {
1237
+ "epoch": 17.57254290171607,
1238
+ "grad_norm": 0.16572591523274388,
1239
+ "learning_rate": 1.2626262626262628e-05,
1240
+ "loss": 1.9196,
1241
+ "step": 176
1242
+ },
1243
+ {
1244
+ "epoch": 17.672386895475817,
1245
+ "grad_norm": 0.16066050812369387,
1246
+ "learning_rate": 1.2121212121212122e-05,
1247
+ "loss": 1.9379,
1248
+ "step": 177
1249
+ },
1250
+ {
1251
+ "epoch": 17.77223088923557,
1252
+ "grad_norm": 0.18230307180057742,
1253
+ "learning_rate": 1.1616161616161616e-05,
1254
+ "loss": 1.9344,
1255
+ "step": 178
1256
+ },
1257
+ {
1258
+ "epoch": 17.87207488299532,
1259
+ "grad_norm": 0.16147840026521357,
1260
+ "learning_rate": 1.1111111111111112e-05,
1261
+ "loss": 1.9249,
1262
+ "step": 179
1263
+ },
1264
+ {
1265
+ "epoch": 17.97191887675507,
1266
+ "grad_norm": 0.17234298543336798,
1267
+ "learning_rate": 1.0606060606060607e-05,
1268
+ "loss": 1.9341,
1269
+ "step": 180
1270
+ },
1271
+ {
1272
+ "epoch": 18.07176287051482,
1273
+ "grad_norm": 0.16952419332241464,
1274
+ "learning_rate": 1.0101010101010101e-05,
1275
+ "loss": 1.9382,
1276
+ "step": 181
1277
+ },
1278
+ {
1279
+ "epoch": 18.171606864274573,
1280
+ "grad_norm": 0.17503197241676455,
1281
+ "learning_rate": 9.595959595959595e-06,
1282
+ "loss": 1.9277,
1283
+ "step": 182
1284
+ },
1285
+ {
1286
+ "epoch": 18.27145085803432,
1287
+ "grad_norm": 0.16018657280969506,
1288
+ "learning_rate": 9.090909090909091e-06,
1289
+ "loss": 1.9259,
1290
+ "step": 183
1291
+ },
1292
+ {
1293
+ "epoch": 18.371294851794072,
1294
+ "grad_norm": 0.16577134954028483,
1295
+ "learning_rate": 8.585858585858587e-06,
1296
+ "loss": 1.9391,
1297
+ "step": 184
1298
+ },
1299
+ {
1300
+ "epoch": 18.47113884555382,
1301
+ "grad_norm": 0.1758462044127833,
1302
+ "learning_rate": 8.080808080808082e-06,
1303
+ "loss": 1.9316,
1304
+ "step": 185
1305
+ },
1306
+ {
1307
+ "epoch": 18.570982839313572,
1308
+ "grad_norm": 0.16928715932805172,
1309
+ "learning_rate": 7.5757575757575764e-06,
1310
+ "loss": 1.9218,
1311
+ "step": 186
1312
+ },
1313
+ {
1314
+ "epoch": 18.670826833073324,
1315
+ "grad_norm": 0.16185874983512785,
1316
+ "learning_rate": 7.0707070707070704e-06,
1317
+ "loss": 1.9244,
1318
+ "step": 187
1319
+ },
1320
+ {
1321
+ "epoch": 18.770670826833072,
1322
+ "grad_norm": 0.16445906712178507,
1323
+ "learning_rate": 6.565656565656567e-06,
1324
+ "loss": 1.9425,
1325
+ "step": 188
1326
+ },
1327
+ {
1328
+ "epoch": 18.870514820592824,
1329
+ "grad_norm": 0.16313460189322437,
1330
+ "learning_rate": 6.060606060606061e-06,
1331
+ "loss": 1.9336,
1332
+ "step": 189
1333
+ },
1334
+ {
1335
+ "epoch": 18.970358814352576,
1336
+ "grad_norm": 0.15990081630753986,
1337
+ "learning_rate": 5.555555555555556e-06,
1338
+ "loss": 1.9178,
1339
+ "step": 190
1340
+ },
1341
+ {
1342
+ "epoch": 19.070202808112324,
1343
+ "grad_norm": 0.16547636636850527,
1344
+ "learning_rate": 5.050505050505051e-06,
1345
+ "loss": 1.9281,
1346
+ "step": 191
1347
+ },
1348
+ {
1349
+ "epoch": 19.170046801872076,
1350
+ "grad_norm": 0.1625270231867559,
1351
+ "learning_rate": 4.5454545454545455e-06,
1352
+ "loss": 1.9348,
1353
+ "step": 192
1354
+ },
1355
+ {
1356
+ "epoch": 19.269890795631824,
1357
+ "grad_norm": 0.16385675767663568,
1358
+ "learning_rate": 4.040404040404041e-06,
1359
+ "loss": 1.9305,
1360
+ "step": 193
1361
+ },
1362
+ {
1363
+ "epoch": 19.369734789391575,
1364
+ "grad_norm": 0.16718542619114216,
1365
+ "learning_rate": 3.5353535353535352e-06,
1366
+ "loss": 1.9376,
1367
+ "step": 194
1368
+ },
1369
+ {
1370
+ "epoch": 19.469578783151327,
1371
+ "grad_norm": 0.16595125072244407,
1372
+ "learning_rate": 3.0303030303030305e-06,
1373
+ "loss": 1.9264,
1374
+ "step": 195
1375
+ },
1376
+ {
1377
+ "epoch": 19.569422776911075,
1378
+ "grad_norm": 0.16912445317015737,
1379
+ "learning_rate": 2.5252525252525253e-06,
1380
+ "loss": 1.9252,
1381
+ "step": 196
1382
+ },
1383
+ {
1384
+ "epoch": 19.669266770670827,
1385
+ "grad_norm": 0.15257787442711698,
1386
+ "learning_rate": 2.0202020202020206e-06,
1387
+ "loss": 1.9312,
1388
+ "step": 197
1389
+ },
1390
+ {
1391
+ "epoch": 19.76911076443058,
1392
+ "grad_norm": 0.17270934725449602,
1393
+ "learning_rate": 1.5151515151515152e-06,
1394
+ "loss": 1.9241,
1395
+ "step": 198
1396
+ },
1397
+ {
1398
+ "epoch": 19.868954758190327,
1399
+ "grad_norm": 0.16771403116909167,
1400
+ "learning_rate": 1.0101010101010103e-06,
1401
+ "loss": 1.9342,
1402
+ "step": 199
1403
+ },
1404
+ {
1405
+ "epoch": 19.96879875195008,
1406
+ "grad_norm": 0.17132458008674775,
1407
+ "learning_rate": 5.050505050505052e-07,
1408
+ "loss": 1.9347,
1409
+ "step": 200
1410
+ }
1411
+ ],
1412
+ "logging_steps": 1.0,
1413
+ "max_steps": 200,
1414
+ "num_input_tokens_seen": 0,
1415
+ "num_train_epochs": 20,
1416
+ "save_steps": 1000,
1417
+ "stateful_callbacks": {
1418
+ "TrainerControl": {
1419
+ "args": {
1420
+ "should_epoch_stop": false,
1421
+ "should_evaluate": false,
1422
+ "should_log": false,
1423
+ "should_save": true,
1424
+ "should_training_stop": true
1425
+ },
1426
+ "attributes": {}
1427
+ }
1428
+ },
1429
+ "total_flos": 0.0,
1430
+ "train_batch_size": 64,
1431
+ "trial_name": null,
1432
+ "trial_params": null
1433
+ }
checkpoint-200/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e73c8433f9bd77a861884878bc3440468454cf7de3a2ce588dd20f479c1091e1
3
+ size 6840
config.json CHANGED
@@ -1,17 +1,17 @@
1
  {
2
- "_name_or_path": "gte-korean-reranker-base-241210-c",
3
  "architectures": [
4
  "NewForSequenceClassification"
5
  ],
6
  "attention_probs_dropout_prob": 0.0,
7
  "auto_map": {
8
  "AutoConfig": "configuration.NewConfig",
9
- "AutoModel": "new-impl--modeling.NewModel",
10
- "AutoModelForMaskedLM": "new-impl--modeling.NewForMaskedLM",
11
- "AutoModelForMultipleChoice": "new-impl--modeling.NewForMultipleChoice",
12
- "AutoModelForQuestionAnswering": "new-impl--modeling.NewForQuestionAnswering",
13
  "AutoModelForSequenceClassification": "modeling.NewForSequenceClassification",
14
- "AutoModelForTokenClassification": "new-impl--modeling.NewForTokenClassification"
15
  },
16
  "classifier_dropout": 0.0,
17
  "hidden_act": "gelu",
 
1
  {
2
+ "_name_or_path": "/workspace/sigrid/kozistr-ko-sentence-embeddings/reranker-data/gte-korean-reranker-base-241210-c",
3
  "architectures": [
4
  "NewForSequenceClassification"
5
  ],
6
  "attention_probs_dropout_prob": 0.0,
7
  "auto_map": {
8
  "AutoConfig": "configuration.NewConfig",
9
+ "AutoModel": "Alibaba-NLP/new-impl--modeling.NewModel",
10
+ "AutoModelForMaskedLM": "Alibaba-NLP/new-impl--modeling.NewForMaskedLM",
11
+ "AutoModelForMultipleChoice": "Alibaba-NLP/new-impl--modeling.NewForMultipleChoice",
12
+ "AutoModelForQuestionAnswering": "Alibaba-NLP/new-impl--modeling.NewForQuestionAnswering",
13
  "AutoModelForSequenceClassification": "modeling.NewForSequenceClassification",
14
+ "AutoModelForTokenClassification": "Alibaba-NLP/new-impl--modeling.NewForTokenClassification"
15
  },
16
  "classifier_dropout": 0.0,
17
  "hidden_act": "gelu",
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2b9dcb302bd0fdf62bcbd0a594656e39ef4e0343906372f7613d148b5bad281d
3
  size 611934706
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efccdc583e449d5ad7b1ab1beda845d50c68d48c9acd72a1eef1ece00b1ac8b1
3
  size 611934706
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c211c72915ee94f355409ecc5f90293478e9296d3d5f570648edaa36903e639a
3
  size 6840
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e73c8433f9bd77a861884878bc3440468454cf7de3a2ce588dd20f479c1091e1
3
  size 6840