dwzhu commited on
Commit
8f238f4
·
verified ·
1 Parent(s): c33a9ec

Upload folder using huggingface_hub

Browse files
added_tokens.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "[CLS]": 101,
3
+ "[MASK]": 103,
4
+ "[PAD]": 0,
5
+ "[SEP]": 102,
6
+ "[UNK]": 100
7
+ }
config.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/home/v-daweizhu/teamdrive/longembed/models/nomic-ai/nomic-bert-2048",
3
+ "activation_function": "swiglu",
4
+ "architectures": [
5
+ "NomicBertModel"
6
+ ],
7
+ "attn_pdrop": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_hf_nomic_bert.NomicBertConfig",
10
+ "AutoModel": "modeling_hf_nomic_bert.NomicBertModel",
11
+ "AutoModelForMaskedLM": "modeling_hf_nomic_bert.NomicBertForPreTraining",
12
+ "AutoModelForSequenceClassification": "modeling_hf_nomic_bert.NomicBertForSequenceClassification"
13
+ },
14
+ "bos_token_id": null,
15
+ "causal": false,
16
+ "dense_seq_output": true,
17
+ "embd_pdrop": 0.1,
18
+ "eos_token_id": null,
19
+ "fused_bias_fc": true,
20
+ "fused_dropout_add_ln": true,
21
+ "initializer_range": 0.02,
22
+ "layer_norm_epsilon": 1e-12,
23
+ "mlp_fc1_bias": false,
24
+ "mlp_fc2_bias": false,
25
+ "model_type": "nomic_bert",
26
+ "n_embd": 768,
27
+ "n_head": 12,
28
+ "n_inner": 3072,
29
+ "n_layer": 12,
30
+ "n_positions": 2048,
31
+ "pad_vocab_size_multiple": 64,
32
+ "parallel_block": false,
33
+ "parallel_block_tied_norm": false,
34
+ "prenorm": false,
35
+ "qkv_proj_bias": false,
36
+ "reorder_and_upcast_attn": false,
37
+ "resid_pdrop": 0.1,
38
+ "rotary_emb_base": 1000,
39
+ "rotary_emb_fraction": 1.0,
40
+ "rotary_emb_interleaved": false,
41
+ "rotary_emb_scale_base": null,
42
+ "rotary_scaling_factor": null,
43
+ "scale_attn_by_inverse_layer_idx": false,
44
+ "scale_attn_weights": true,
45
+ "summary_activation": null,
46
+ "summary_first_dropout": 0.1,
47
+ "summary_proj_to_labels": true,
48
+ "summary_type": "cls_index",
49
+ "summary_use_proj": true,
50
+ "torch_dtype": "float32",
51
+ "transformers_version": "4.34.0",
52
+ "type_vocab_size": 2,
53
+ "use_cache": true,
54
+ "use_flash_attn": true,
55
+ "use_rms_norm": false,
56
+ "use_xentropy": true,
57
+ "vocab_size": 30528
58
+ }
configuration_hf_nomic_bert.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2Config
2
+
3
+
4
+ class NomicBertConfig(GPT2Config):
5
+ model_type = "nomic_bert"
6
+
7
+ def __init__(self,
8
+ prenorm=False,
9
+ parallel_block=False,
10
+ parallel_block_tied_norm=False,
11
+ rotary_emb_fraction=0.0,
12
+ fused_dropout_add_ln=False,
13
+ fused_bias_fc=False,
14
+ use_flash_attn=False,
15
+ use_xentropy=False,
16
+ qkv_proj_bias=True,
17
+ rotary_emb_base=1000,
18
+ rotary_emb_scale_base=None,
19
+ rotary_emb_interleaved=False,
20
+ mlp_fc1_bias=True,
21
+ mlp_fc2_bias=True,
22
+ use_rms_norm=False,
23
+ causal=False,
24
+ type_vocab_size=2,
25
+ dense_seq_output=True,
26
+ pad_vocab_size_multiple=1,
27
+ tie_word_embeddings=True,
28
+ **kwargs,
29
+ ):
30
+ self.prenorm = prenorm
31
+ self.parallel_block = parallel_block
32
+ self.parallel_block_tied_norm = parallel_block_tied_norm
33
+ self.rotary_emb_fraction = rotary_emb_fraction
34
+ self.tie_word_embeddings = tie_word_embeddings
35
+ self.fused_dropout_add_ln = fused_dropout_add_ln
36
+ self.fused_bias_fc = fused_bias_fc
37
+ self.use_flash_attn = use_flash_attn
38
+ self.use_xentropy = use_xentropy
39
+ self.qkv_proj_bias = qkv_proj_bias
40
+ self.rotary_emb_base = rotary_emb_base
41
+ self.rotary_emb_scale_base = rotary_emb_scale_base
42
+ self.rotary_emb_interleaved = rotary_emb_interleaved
43
+ self.mlp_fc1_bias = mlp_fc1_bias
44
+ self.mlp_fc2_bias = mlp_fc2_bias
45
+ self.use_rms_norm = use_rms_norm
46
+ self.causal = causal
47
+ self.type_vocab_size = type_vocab_size
48
+ self.dense_seq_output = dense_seq_output
49
+ self.pad_vocab_size_multiple = pad_vocab_size_multiple
50
+
51
+ super().__init__(**kwargs)
modeling_hf_nomic_bert.py ADDED
@@ -0,0 +1,1280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, Tri Dao.
2
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
3
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
+
6
+ import logging
7
+
8
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
+ import os
10
+ import re
11
+ from collections import OrderedDict
12
+ from functools import partial
13
+ from typing import List, Optional, Tuple, Union
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from einops import rearrange, repeat
19
+ from safetensors.torch import load_file as safe_load_file
20
+ from transformers import GPT2Config, PreTrainedModel
21
+ from transformers.models.bert.modeling_bert import (
22
+ BaseModelOutputWithPoolingAndCrossAttentions,
23
+ MaskedLMOutput,
24
+ SequenceClassifierOutput,
25
+ )
26
+ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
27
+ from transformers.utils.hub import cached_file, get_checkpoint_shard_files
28
+
29
+ from .configuration_hf_nomic_bert import NomicBertConfig
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ # adapted from flash attention, added safe serialization option for hf models
35
+ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
36
+ # If not fp32, then we don't want to load directly to the GPU
37
+ mapped_device = "cpu" if dtype not in [torch.float32, None] else device
38
+ is_sharded = False
39
+ load_safe = False
40
+ resolved_archive_file = None
41
+
42
+ weights_path = os.path.join(model_name, WEIGHTS_NAME)
43
+ weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
44
+ safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
45
+ safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
46
+
47
+ if os.path.isfile(weights_path):
48
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
49
+ elif os.path.isfile(weights_index_path):
50
+ resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False)
51
+ is_sharded = True
52
+ elif os.path.isfile(safe_weights_path):
53
+ resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
54
+ load_safe = True
55
+ elif os.path.isfile(safe_weights_index_path):
56
+ resolved_archive_file = cached_file(
57
+ model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
58
+ )
59
+ is_sharded = True
60
+ load_safe = True
61
+ else: # Try loading from HF hub instead of from local files
62
+ weight_name = WEIGHTS_NAME if not safe_serialization else SAFE_WEIGHTS_NAME
63
+ resolved_archive_file = cached_file(model_name, weight_name, _raise_exceptions_for_missing_entries=False)
64
+ if resolved_archive_file is None:
65
+ weight_index = WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
66
+ resolved_archive_file = cached_file(model_name, weight_index, _raise_exceptions_for_missing_entries=False)
67
+ if resolved_archive_file is not None:
68
+ is_sharded = True
69
+
70
+ load_safe = safe_serialization
71
+
72
+ if resolved_archive_file is None:
73
+ raise EnvironmentError(f"Model name {model_name} was not found.")
74
+
75
+ if load_safe:
76
+ loader = partial(safe_load_file, device=mapped_device)
77
+ else:
78
+ loader = partial(torch.load, map_location=mapped_device)
79
+
80
+ if is_sharded:
81
+ # resolved_archive_file becomes a list of files that point to the different
82
+ # checkpoint shards in this case.
83
+ resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(model_name, resolved_archive_file)
84
+ state_dict = {}
85
+ for sharded_file in resolved_archive_file:
86
+ state_dict.update(loader(sharded_file))
87
+ else:
88
+ state_dict = loader(resolved_archive_file)
89
+ # Convert dtype before moving to GPU to save memory
90
+ if dtype is not None:
91
+ state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
92
+ state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
93
+ return state_dict
94
+
95
+
96
+ def filter_shapes(state_dict, model):
97
+ """
98
+ Filters the state dict to match the current model shape.
99
+ """
100
+ filtered_state_dict = {}
101
+ for key, value in state_dict.items():
102
+ if key in model.state_dict():
103
+ if value.shape == model.state_dict()[key].shape:
104
+ filtered_state_dict[key] = value
105
+ return filtered_state_dict
106
+
107
+
108
+ def remap_bert_state_dict(
109
+ state_dict,
110
+ config,
111
+ remove_bert=False,
112
+ remove_cls_weights=False,
113
+ add_pooling_layer=False,
114
+ ):
115
+ """
116
+ Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
117
+ """
118
+
119
+ def add_bert_prefix(key):
120
+ # prepend bert. to the key
121
+ if key.startswith("bert.") or key.startswith("cls."):
122
+ return key
123
+ return f"bert.{key}"
124
+
125
+ state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items())
126
+
127
+ # LayerNorm
128
+ def key_mapping_ln_gamma_beta(key):
129
+ key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
130
+ key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
131
+ return key
132
+
133
+ state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
134
+
135
+ # Layers
136
+ def key_mapping_layers(key):
137
+ return re.sub(r"^bert.encoder.layer\.", "bert.encoder.layers.", key)
138
+
139
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
140
+
141
+ # LayerNorm
142
+ def key_mapping_ln(key):
143
+ key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
144
+ key = re.sub(
145
+ r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
146
+ r"bert.encoder.layers.\1.norm1.\2",
147
+ key,
148
+ )
149
+ key = re.sub(
150
+ r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
151
+ r"bert.encoder.layers.\1.norm2.\2",
152
+ key,
153
+ )
154
+ key = re.sub(
155
+ r"^cls.predictions.transform.LayerNorm.(weight|bias)",
156
+ r"cls.predictions.transform.layer_norm.\1",
157
+ key,
158
+ )
159
+ return key
160
+
161
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
162
+
163
+ # MLP
164
+ def key_mapping_mlp(key):
165
+ key = re.sub(
166
+ r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
167
+ r"bert.encoder.layers.\1.mlp.fc1.\2",
168
+ key,
169
+ )
170
+ key = re.sub(
171
+ r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
172
+ r"bert.encoder.layers.\1.mlp.fc2.\2",
173
+ key,
174
+ )
175
+ return key
176
+
177
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
178
+
179
+ # Attention
180
+ last_layer_subset = getattr(config, "last_layer_subset", False)
181
+ for d in range(config.num_hidden_layers):
182
+ if f"bert.encoder.layers.{d}.attention.self.query.weight" not in state_dict:
183
+ continue
184
+ Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
185
+ Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
186
+ Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
187
+ bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
188
+ bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
189
+ bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
190
+ if not (last_layer_subset and d == config.num_hidden_layers - 1):
191
+ state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
192
+ state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
193
+ else:
194
+ state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq
195
+ state_dict[f"bert.encoder.layers.{d}.attn.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
196
+ state_dict[f"bert.encoder.layers.{d}.attn.Wq.bias"] = bq
197
+ state_dict[f"bert.encoder.layers.{d}.attn.Wkv.bias"] = torch.cat([bk, bv], dim=0)
198
+
199
+ def key_mapping_attn(key):
200
+ return re.sub(
201
+ r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
202
+ r"bert.encoder.layers.\1.attn.out_proj.\2",
203
+ key,
204
+ )
205
+
206
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
207
+
208
+ def key_mapping_decoder_bias(key):
209
+ return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
210
+
211
+ # remove nsp weights, we don't use
212
+ state_dict.pop("cls.seq_relationship.weight", None)
213
+ state_dict.pop("cls.seq_relationship.bias", None)
214
+ state_dict.pop("bert.embeddings.position_ids", None)
215
+
216
+ state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
217
+
218
+ if remove_cls_weights:
219
+ cls_weights = [
220
+ "cls.predictions.decoder.bias",
221
+ "cls.predictions.transform.dense.weight",
222
+ "cls.predictions.transform.dense.bias",
223
+ "cls.predictions.transform.layer_norm.weight",
224
+ "cls.predictions.transform.layer_norm.bias",
225
+ "cls.predictions.decoder.weight",
226
+ ]
227
+ for weight in cls_weights:
228
+ state_dict.pop(weight, None)
229
+
230
+ # Word embedding
231
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
232
+ if pad_vocab_size_multiple > 1:
233
+ word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
234
+ state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
235
+ word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
236
+ )
237
+ if not remove_cls_weights:
238
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
239
+ state_dict["cls.predictions.decoder.weight"] = F.pad(
240
+ decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
241
+ )
242
+ # If the vocab was padded, we want to set the decoder bias for those padded indices to be
243
+ # strongly negative (i.e. the decoder shouldn't predict those indices).
244
+ # TD [2022-05-09]: I don't think it affects the MLPerf training.
245
+ if "cls.predictions.decoder.bias" in state_dict:
246
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
247
+ state_dict["cls.predictions.decoder.bias"] = F.pad(
248
+ decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
249
+ )
250
+
251
+ if add_pooling_layer is False:
252
+ pooler_weights = [
253
+ "bert.pooler.dense.weight",
254
+ "bert.pooler.dense.bias",
255
+ ]
256
+ for key in pooler_weights:
257
+ state_dict.pop(key, None)
258
+
259
+ if remove_bert:
260
+
261
+ def remove_bert_prefix(key):
262
+ key = re.sub(r"^bert.", "", key)
263
+ return key
264
+
265
+ state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items())
266
+
267
+ return state_dict
268
+
269
+
270
+ class NomicBertPreTrainedModel(PreTrainedModel):
271
+ """An abstract class to handle weights initialization and
272
+ a simple interface for dowloading and loading pretrained models.
273
+ """
274
+
275
+ config_class = NomicBertConfig
276
+ base_model_prefix = "model"
277
+ supports_gradient_checkpointing = True
278
+ _no_split_modules = ["Block"]
279
+ _skip_keys_device_placement = "past_key_values"
280
+
281
+ def __init__(self, config, *inputs, **kwargs):
282
+ super().__init__(config)
283
+ if not isinstance(config, GPT2Config):
284
+ raise ValueError(
285
+ "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
286
+ "To create a model from a Google pretrained model use "
287
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
288
+ self.__class__.__name__, self.__class__.__name__
289
+ )
290
+ )
291
+ self.config = config
292
+
293
+ @classmethod
294
+ def from_pretrained(cls, model_name, config=None, *inputs, **kwargs):
295
+ """
296
+ Instantiate a NomicBertPreTrainedModel from a pre-trained model file or a pytorch state dict.
297
+ Download and cache the pre-trained model file if needed.
298
+
299
+ Params:
300
+ pretrained_model_name_or_path: either:
301
+ - a path or url to a pretrained model archive containing:
302
+ . `bert_config.json` a configuration file for the model
303
+ . `pytorch_model.bin` a PyTorch dump of a NomicBertForPretraining instance
304
+ - a path or url to a pretrained model archive containing:
305
+ . `bert_config.json` a configuration file for the model
306
+ . `model.chkpt` a TensorFlow checkpoint
307
+ *inputs, **kwargs: additional input for the specific NomicBert class
308
+ (ex: num_labels for NomicBertForSequenceClassification)
309
+ """
310
+ # Instantiate model.
311
+ if config is None:
312
+ config = cls.config_class.from_pretrained(model_name)
313
+ remove_cls = cls != NomicBertForPreTraining
314
+ remove_bert_prefix = cls != NomicBertForPreTraining and cls != NomicBertForSequenceClassification
315
+ ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
316
+ num_labels = kwargs.pop("num_labels", None)
317
+ rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
318
+ strict = kwargs.pop("strict", True)
319
+ config.rotary_scaling_factor = rotary_scaling_factor
320
+ if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
321
+ config.n_positions = 2048
322
+ if num_labels:
323
+ config.num_labels = num_labels
324
+
325
+ if "add_pooling_layer" in kwargs:
326
+ model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
327
+ else:
328
+ model = cls(config, *inputs)
329
+ # TODO: fix this
330
+ # Assuming we know what we're doing when loading from disk
331
+ # Prob a bad assumption but i'm tired and want to train this asap
332
+ if os.path.exists(model_name):
333
+ model_path = f"{model_name}/pytorch_model.bin"
334
+ if os.path.exists(model_path):
335
+ state_dict = torch.load(f"{model_name}/pytorch_model.bin")
336
+ else:
337
+ model_path = f"{model_name}/model.safetensors"
338
+ if not os.path.exists(model_path):
339
+ raise ValueError(f"Model path {model_path} not found")
340
+ state_dict = safe_load_file(model_path)
341
+
342
+ if ignore_mismatched_shapes:
343
+ state_dict = filter_shapes(state_dict, model)
344
+ load_return = model.load_state_dict(state_dict, strict=False)
345
+ else:
346
+ # TODO: can probably check config class and see if we need to remap from a bert model
347
+ state_dict = state_dict_from_pretrained(model_name)
348
+ state_dict = remap_bert_state_dict(
349
+ state_dict,
350
+ config,
351
+ remove_bert=remove_bert_prefix,
352
+ remove_cls_weights=remove_cls,
353
+ add_pooling_layer=getattr(config, "add_pooling_layer", False),
354
+ )
355
+ if ignore_mismatched_shapes:
356
+ state_dict = filter_shapes(state_dict, model)
357
+
358
+ load_return = model.load_state_dict(state_dict, strict=strict)
359
+ logger.warning(load_return)
360
+ return model
361
+
362
+ def _set_gradient_checkpointing(self, module, value=False):
363
+ if isinstance(module, NomicBertEncoder):
364
+ module.gradient_checkpointing = value
365
+
366
+
367
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
368
+ def _init_weights(module, initializer_range=0.02):
369
+ if isinstance(module, nn.Linear):
370
+ nn.init.normal_(module.weight, std=initializer_range)
371
+ if module.bias is not None:
372
+ nn.init.zeros_(module.bias)
373
+ elif isinstance(module, nn.Embedding):
374
+ nn.init.normal_(module.weight, std=initializer_range)
375
+ if module.padding_idx is not None:
376
+ nn.init.zeros_(module.weight[module.padding_idx])
377
+
378
+
379
+ class NomicBertEmbeddings(nn.Module):
380
+ def __init__(self, config):
381
+ """
382
+ If max_position_embeddings <= 0, there's no position embeddings
383
+ If type_vocab_size <= 0, there's no token type embeddings
384
+ """
385
+ super().__init__()
386
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
387
+ self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
388
+ self.type_vocab_size = config.type_vocab_size
389
+ if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0:
390
+ self.position_embeddings = nn.Embedding(
391
+ config.max_position_embeddings,
392
+ config.hidden_size,
393
+ )
394
+ if self.type_vocab_size > 0:
395
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
396
+
397
+ def forward(self, input_ids, position_ids=None, token_type_ids=None):
398
+ """
399
+ input_ids: (batch, seqlen)
400
+ position_ids: (batch, seqlen)
401
+ token_type_ids: (batch, seqlen)
402
+ """
403
+ batch_size, seqlen = input_ids.shape
404
+ embeddings = self.word_embeddings(input_ids)
405
+
406
+ if self.type_vocab_size > 0:
407
+ if token_type_ids is None:
408
+ token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
409
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
410
+ embeddings = embeddings + token_type_embeddings
411
+
412
+ if self.max_position_embeddings > 0:
413
+ if position_ids is None:
414
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
415
+ position_embeddings = self.position_embeddings(position_ids)
416
+ embeddings = embeddings + position_embeddings
417
+ return embeddings
418
+
419
+
420
+ class NomicBertMLP(nn.Module):
421
+ def __init__(
422
+ self,
423
+ in_features,
424
+ hidden_features=None,
425
+ out_features=None,
426
+ activation=F.gelu,
427
+ bias1=True,
428
+ bias2=True,
429
+ return_residual=False,
430
+ fused_bias_fc=False,
431
+ ):
432
+ super().__init__()
433
+ out_features = out_features if out_features is not None else in_features
434
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
435
+ self.return_residual = return_residual
436
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
437
+ approximate = "tanh" if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
438
+ self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation
439
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
440
+
441
+ def forward(self, x):
442
+ y = self.fc1(x)
443
+ y = self.activation(y)
444
+ y = self.fc2(y)
445
+ return y if not self.return_residual else (y, x)
446
+
447
+
448
+ class NomciBertGatedMLP(nn.Module):
449
+ def __init__(
450
+ self,
451
+ in_features,
452
+ hidden_features=None,
453
+ out_features=None,
454
+ activation=F.sigmoid,
455
+ bias1=True,
456
+ bias2=True,
457
+ multiple_of=256,
458
+ return_residual=False,
459
+ fused_bias_fc=True,
460
+ device=None,
461
+ dtype=None,
462
+ ):
463
+ super().__init__()
464
+ out_features = out_features if out_features is not None else in_features
465
+ hidden_features = hidden_features if hidden_features is not None else int(8 * in_features / 3)
466
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
467
+ self.return_residual = return_residual
468
+
469
+ self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1)
470
+ self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1)
471
+ self.activation = activation
472
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
473
+
474
+ def forward(self, x):
475
+ y = self.fc11(x)
476
+ gate = self.fc12(x)
477
+ if self.activation == F.sigmoid: # Special case for GLU
478
+ y = F.glu(torch.cat([y, gate], dim=-1), dim=-1)
479
+ else:
480
+ y = y * self.activation(gate)
481
+ y = self.fc2(y)
482
+ return y if not self.return_residual else (y, x)
483
+
484
+
485
+ def rotate_half(x, interleaved=False):
486
+ if not interleaved:
487
+ x1, x2 = x.chunk(2, dim=-1)
488
+ return torch.cat((-x2, x1), dim=-1)
489
+ else:
490
+ x1, x2 = x[..., ::2], x[..., 1::2]
491
+ return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
492
+
493
+
494
+ def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False):
495
+ """
496
+ x: (batch_size, seqlen, nheads, headdim)
497
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
498
+ """
499
+ ro_dim = cos.shape[-1] * 2
500
+ assert ro_dim <= x.shape[-1]
501
+ cos, sin = (
502
+ cos[offset : offset + x.shape[1]],
503
+ sin[offset : offset + x.shape[1]],
504
+ )
505
+ cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
506
+ sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
507
+ return torch.cat(
508
+ [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
509
+ dim=-1,
510
+ )
511
+
512
+
513
+ class NomicBertRotaryEmbedding(nn.Module):
514
+ def __init__(
515
+ self,
516
+ dim: int,
517
+ base=10000.0,
518
+ interleaved=False,
519
+ scale_base=None,
520
+ pos_idx_in_fp32=True,
521
+ device=None,
522
+ ):
523
+ """
524
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
525
+ of 1st half and 2nd half (GPT-NeoX style).
526
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
527
+ otherwise they might be in lower precision.
528
+ This option was added because previously (before 2023-07-02), when we construct
529
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
530
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
531
+ self.inv_freq would be bf16, and the position indices are also in bf16.
532
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
533
+ embeddings for some positions will coincide.
534
+ To maintain compatibility with models previously trained in pure bf16,
535
+ we add this option.
536
+ """
537
+ super().__init__()
538
+ self.dim = dim
539
+ self.base = float(base)
540
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
541
+ # Generate and save the inverse frequency buffer (non trainable)
542
+ inv_freq = self._compute_inv_freq(device)
543
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
544
+ self.interleaved = interleaved
545
+ self.scale_base = scale_base
546
+ scale = (
547
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
548
+ if scale_base is not None
549
+ else None
550
+ )
551
+ self.register_buffer("scale", scale, persistent=False)
552
+
553
+ self._seq_len_cached = 0
554
+ self._cos_cached = None
555
+ self._sin_cached = None
556
+ self._cos_k_cached = None
557
+ self._sin_k_cached = None
558
+
559
+ def _compute_inv_freq(self, device=None):
560
+ return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
561
+
562
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
563
+ # Reset the tables if the sequence length has changed,
564
+ # if we're on a new device (possibly due to tracing for instance),
565
+ # or if we're switching from inference mode to training
566
+ if (
567
+ seqlen > self._seq_len_cached
568
+ or self._cos_cached is None
569
+ or self._cos_cached.device != device
570
+ or self._cos_cached.dtype != dtype
571
+ or (self.training and self._cos_cached.is_inference())
572
+ ):
573
+ self._seq_len_cached = seqlen
574
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
575
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
576
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
577
+ if self.pos_idx_in_fp32:
578
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
579
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
580
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
581
+ # cos & sin output to change significantly.
582
+ # We want to recompute self.inv_freq if it was not loaded in fp32
583
+ if self.inv_freq.dtype != torch.float32:
584
+ inv_freq = self._compute_inv_freq(device=device)
585
+ else:
586
+ inv_freq = self.inv_freq
587
+ else:
588
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
589
+ inv_freq = self.inv_freq
590
+ # Don't do einsum, it converts fp32 to fp16 under AMP
591
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
592
+ freqs = torch.outer(t, inv_freq)
593
+ self._cos_cached = torch.cos(freqs).to(dtype)
594
+ self._sin_cached = torch.sin(freqs).to(dtype)
595
+
596
+ def forward(
597
+ self,
598
+ qkv: torch.Tensor,
599
+ kv: Optional[torch.Tensor] = None,
600
+ seqlen_offset: Union[int, torch.Tensor] = 0,
601
+ max_seqlen: Optional[int] = None,
602
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
603
+ """
604
+ qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
605
+ else it's just q of shape (batch, seqlen, nheads, headdim)
606
+ kv: (batch, seqlen, 2, nheads, headdim)
607
+ seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
608
+ Most commonly used in inference when we have KV cache.
609
+ If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
610
+ should pass in max_seqlen, which will update the cos / sin cache up to that length.
611
+ Apply rotary embedding *inplace* to qkv and / or kv.
612
+ """
613
+ seqlen = qkv.shape[1]
614
+ if seqlen > self._seq_len_cached:
615
+ self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
616
+ elif max_seqlen is not None:
617
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
618
+ elif isinstance(seqlen_offset, int):
619
+ self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
620
+
621
+ q_rot = apply_rotary_emb(qkv[:, :, 0], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
622
+ k_rot = apply_rotary_emb(qkv[:, :, 1], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
623
+ return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
624
+
625
+
626
+ class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
627
+ def __init__(self, rotary_scaling_factor, max_position_embeddings, **kwargs):
628
+ super().__init__(**kwargs)
629
+ self.rotary_scaling_factor = rotary_scaling_factor
630
+ self.max_position_embeddings = max_position_embeddings
631
+
632
+ def _compute_inv_freq(self, base=None, device=None):
633
+ if base is None:
634
+ base = self.base
635
+ return 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
636
+
637
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
638
+ # Reset the tables if the sequence length has changed,
639
+ # if we're on a new device (possibly due to tracing for instance),
640
+ # or if we're switching from inference mode to training
641
+ if seqlen > self.max_position_embeddings:
642
+ base = self.base * (
643
+ (self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1)
644
+ ) ** (self.dim / (self.dim - 2))
645
+ inv_freq = self._compute_inv_freq(base=base, device=device)
646
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
647
+
648
+ if (
649
+ seqlen > self._seq_len_cached
650
+ or self._cos_cached is None
651
+ or self._cos_cached.device != device
652
+ or self._cos_cached.dtype != dtype
653
+ or (self.training and self._cos_cached.is_inference())
654
+ ):
655
+ self._seq_len_cached = seqlen
656
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
657
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
658
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
659
+ if self.pos_idx_in_fp32:
660
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
661
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
662
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
663
+ # cos & sin output to change significantly.
664
+ # We want to recompute self.inv_freq if it was not loaded in fp32
665
+ if self.inv_freq.dtype != torch.float32:
666
+ if seqlen > self.max_position_embeddings:
667
+ base = self.base * (
668
+ (self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)
669
+ ) ** (self.dim / (self.dim - 2))
670
+ else:
671
+ base = self.base
672
+ inv_freq = self._compute_inv_freq(device=device, base=base)
673
+ else:
674
+ inv_freq = self.inv_freq
675
+ else:
676
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
677
+ inv_freq = self.inv_freq
678
+ # Don't do einsum, it converts fp32 to fp16 under AMP
679
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
680
+ freqs = torch.outer(t, inv_freq)
681
+ if self.scale is None:
682
+ self._cos_cached = torch.cos(freqs).to(dtype)
683
+ self._sin_cached = torch.sin(freqs).to(dtype)
684
+ else:
685
+ power = (
686
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
687
+ ) / self.scale_base
688
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
689
+ # We want the multiplication by scale to happen in fp32
690
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
691
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
692
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
693
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
694
+
695
+
696
+ class NomicBertAttention(nn.Module):
697
+ """Multi-head self-attention and cross-attention"""
698
+
699
+ def __init__(
700
+ self,
701
+ config,
702
+ ) -> None:
703
+ """
704
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
705
+ return_residual: whether to return the input x along with the output. This is for
706
+ performance reason: for post-norm architecture, returning the input allows us
707
+ to fuse the backward of nn.Linear with the residual connection.
708
+ """
709
+ super().__init__()
710
+ self.embed_dim = config.n_embd
711
+ self.use_flash_attn = config.use_flash_attn
712
+ self.fused_bias_fc = config.fused_bias_fc
713
+
714
+ self.num_heads = config.n_head
715
+ self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
716
+ assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
717
+ self.head_dim = self.embed_dim // self.num_heads
718
+ # we don't really support mqa / gqa for now
719
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
720
+
721
+ self.register_buffer(
722
+ "norm_factor",
723
+ torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
724
+ persistent=False,
725
+ )
726
+
727
+ self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
728
+ if self.rotary_emb_dim > 0:
729
+ if getattr(config, "rotary_scaling_factor", None):
730
+ self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
731
+ dim=self.rotary_emb_dim,
732
+ base=config.rotary_emb_base,
733
+ scale_base=config.rotary_emb_scale_base,
734
+ interleaved=config.rotary_emb_interleaved,
735
+ rotary_scaling_factor=config.rotary_scaling_factor,
736
+ max_position_embeddings=config.max_trained_positions,
737
+ )
738
+ else:
739
+ self.rotary_emb = NomicBertRotaryEmbedding(
740
+ dim=self.rotary_emb_dim,
741
+ base=config.rotary_emb_base,
742
+ scale_base=config.rotary_emb_scale_base,
743
+ interleaved=config.rotary_emb_interleaved,
744
+ )
745
+ # bug in xformers: https://github.com/facebookresearch/xformers/issues/841
746
+ # uses the head dimension instead of the sequence dimension
747
+ self.rotary_head_dim = getattr(config, "rotary_head_dim", False)
748
+
749
+ self.Wqkv = nn.Linear(self.embed_dim, qkv_dim, bias=config.qkv_proj_bias)
750
+
751
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
752
+ self.causal = config.causal
753
+ self.drop = nn.Dropout(config.attn_pdrop)
754
+
755
+ def forward(
756
+ self,
757
+ hidden_states: torch.Tensor,
758
+ attention_mask: Optional[torch.Tensor] = None,
759
+ position_ids: Optional[torch.LongTensor] = None,
760
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
761
+ output_attentions: bool = False,
762
+ use_cache: bool = False,
763
+ is_padded_inputs: Optional[bool] = True,
764
+ cu_seqlens: Optional[torch.Tensor] = None,
765
+ max_seq_len: Optional[int] = None,
766
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
767
+
768
+ has_layer_past = past_key_value is not None
769
+
770
+ if has_layer_past:
771
+ past_key_value = past_key_value[0]
772
+ past_len = past_key_value[1]
773
+ else:
774
+ past_len = 0
775
+
776
+ qkv = self.Wqkv(hidden_states)
777
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
778
+
779
+ past_key_value = (past_key_value, past_len + qkv.size(1)) if use_cache else None
780
+
781
+ if self.rotary_emb_dim > 0:
782
+ if self.rotary_head_dim:
783
+ qkv = rearrange(qkv, "b s three h d -> b h three s d")
784
+ qkv = self.rotary_emb(qkv, seqlen_offset=past_len)
785
+
786
+ if self.rotary_head_dim:
787
+ qkv = rearrange(qkv, "b h three s d -> b s three h d")
788
+
789
+ query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
790
+
791
+ query = query.permute(0, 2, 1, 3)
792
+ key = key.permute(0, 2, 1, 3)
793
+ value = value.permute(0, 2, 1, 3)
794
+
795
+ bsz, n_heads, seq_len, head_dim = query.shape
796
+ attention_mask = attention_mask.expand(bsz, n_heads, seq_len, seq_len).type_as(query)
797
+
798
+ import xformers.ops as xops
799
+ attn_output = xops.memory_efficient_attention(
800
+ query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2),
801
+ attn_bias=attention_mask, p=0
802
+ ).reshape(bsz, seq_len, n_heads * head_dim)
803
+
804
+ attn_output = self.out_proj(attn_output)
805
+
806
+ return attn_output
807
+
808
+ def bak_forward(
809
+ self,
810
+ hidden_states: torch.Tensor,
811
+ attention_mask: Optional[torch.Tensor] = None,
812
+ position_ids: Optional[torch.LongTensor] = None,
813
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
814
+ output_attentions: bool = False,
815
+ use_cache: bool = False,
816
+ is_padded_inputs: Optional[bool] = True,
817
+ cu_seqlens: Optional[torch.Tensor] = None,
818
+ max_seq_len: Optional[int] = None,
819
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
820
+
821
+ has_layer_past = past_key_value is not None
822
+
823
+ if has_layer_past:
824
+ past_key_value = past_key_value[0]
825
+ past_len = past_key_value[1]
826
+ else:
827
+ past_len = 0
828
+
829
+ qkv = self.Wqkv(hidden_states)
830
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
831
+
832
+ past_key_value = (past_key_value, past_len + qkv.size(1)) if use_cache else None
833
+
834
+ if self.rotary_emb_dim > 0:
835
+ if self.rotary_head_dim:
836
+ qkv = rearrange(qkv, "b s three h d -> b h three s d")
837
+ qkv = self.rotary_emb(qkv, seqlen_offset=past_len)
838
+
839
+ if self.rotary_head_dim:
840
+ qkv = rearrange(qkv, "b h three s d -> b s three h d")
841
+
842
+ query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
843
+
844
+ query = query.permute(0, 2, 1, 3)
845
+ key = key.permute(0, 2, 1, 3)
846
+ value = value.permute(0, 2, 1, 3)
847
+
848
+ attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
849
+ if attention_mask is not None:
850
+ attention_scores = attention_scores + attention_mask
851
+
852
+ attentions_probs = F.softmax(attention_scores, dim=-1)
853
+ attentions_probs = self.drop(attentions_probs)
854
+
855
+ attn_output = torch.matmul(attentions_probs, value)
856
+ attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")
857
+
858
+ attn_output = self.out_proj(attn_output)
859
+
860
+ return attn_output
861
+
862
+
863
+ class NomicBertBlock(nn.Module):
864
+ def __init__(
865
+ self,
866
+ config,
867
+ ):
868
+ super().__init__()
869
+ self.prenorm = config.prenorm
870
+ self.fused_dropout_add_ln = config.fused_dropout_add_ln
871
+
872
+ self.attn = NomicBertAttention(config)
873
+ activation = (
874
+ F.sigmoid
875
+ if config.activation_function == "glu"
876
+ else (F.silu if config.activation_function == "swiglu" else F.gelu)
877
+ )
878
+ if config.activation_function in ["glu", "swiglu", "geglu"]:
879
+ self.mlp = NomciBertGatedMLP(
880
+ config.n_embd,
881
+ hidden_features=config.n_inner,
882
+ bias1=config.mlp_fc1_bias,
883
+ bias2=config.mlp_fc2_bias,
884
+ activation=activation,
885
+ fused_bias_fc=config.fused_bias_fc,
886
+ )
887
+ else:
888
+ self.mlp = NomicBertMLP(
889
+ config.n_embd,
890
+ hidden_features=config.n_inner,
891
+ bias1=config.mlp_fc1_bias,
892
+ bias2=config.mlp_fc2_bias,
893
+ activation=activation,
894
+ fused_bias_fc=config.fused_bias_fc,
895
+ )
896
+
897
+ self.dropout1 = nn.Dropout(config.resid_pdrop)
898
+ self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
899
+ self.norm2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
900
+ self.dropout2 = nn.Dropout(config.resid_pdrop)
901
+
902
+ def forward(
903
+ self,
904
+ hidden_states: torch.Tensor,
905
+ hidden_states2: torch.Tensor,
906
+ residual: Optional[torch.Tensor] = None,
907
+ attention_mask: Optional[torch.Tensor] = None,
908
+ position_ids: Optional[torch.LongTensor] = None,
909
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
910
+ is_padded_inputs: Optional[bool] = True,
911
+ output_attentions: Optional[bool] = False,
912
+ use_cache: Optional[bool] = False,
913
+ cu_seqlens: Optional[torch.Tensor] = None,
914
+ max_seq_len: Optional[int] = None,
915
+ ):
916
+ r"""Pass the input through the encoder layer.
917
+
918
+ Args:
919
+ hidden_states: the sequence to the encoder layer (required).
920
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
921
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
922
+ before applying the query projection. Useful for e.g., ViT where we only care
923
+ about the CLS token in the last layer.
924
+ """
925
+ if self.prenorm:
926
+ dropped = self.dropout1(hidden_states)
927
+ residual = (dropped + residual) if residual is not None else dropped
928
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
929
+ hidden_states = self.attn(
930
+ hidden_states,
931
+ attention_mask=attention_mask,
932
+ is_padded_inputs=is_padded_inputs,
933
+ cu_seqlens=cu_seqlens,
934
+ max_seq_len=max_seq_len,
935
+ )
936
+
937
+ dropped = self.dropout2(hidden_states)
938
+ residual = (dropped + residual) if residual is not None else dropped
939
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
940
+ hidden_states = self.mlp(hidden_states)
941
+
942
+ return hidden_states, None, residual
943
+ else:
944
+ assert residual is None
945
+ attn_outputs = self.attn(
946
+ hidden_states,
947
+ attention_mask=attention_mask,
948
+ is_padded_inputs=is_padded_inputs,
949
+ cu_seqlens=cu_seqlens,
950
+ max_seq_len=max_seq_len,
951
+ )
952
+ hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
953
+ mlp_out = self.mlp(hidden_states)
954
+
955
+ hidden_states = self.norm2((self.dropout2(mlp_out) + hidden_states).to(dtype=self.norm2.weight.dtype))
956
+ return hidden_states, None, None
957
+
958
+
959
+ class NomicBertEncoder(nn.Module):
960
+ def __init__(self, config: GPT2Config):
961
+ super().__init__()
962
+ self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
963
+ self.gradient_checkpointing = False
964
+ self.config = config
965
+
966
+ def forward(
967
+ self,
968
+ hidden_states: torch.LongTensor = None,
969
+ attention_mask: Optional[torch.Tensor] = None,
970
+ position_ids: Optional[torch.LongTensor] = None,
971
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
972
+ inputs_embeds: Optional[torch.FloatTensor] = None,
973
+ use_cache: Optional[bool] = None,
974
+ output_attentions: Optional[bool] = None,
975
+ output_hidden_states: Optional[bool] = None,
976
+ return_dict: Optional[bool] = None,
977
+ is_padded_inputs: Optional[bool] = True,
978
+ ):
979
+ """If subset_mask is not None, we only want output for the subset of the sequence.
980
+ This means that we only compute the last layer output for these tokens.
981
+ subset_mask: (batch, seqlen), dtype=torch.bool
982
+ """
983
+ hidden_states2 = None
984
+ residual = None
985
+
986
+ for _, layer in enumerate(self.layers):
987
+ if self.gradient_checkpointing and self.training:
988
+
989
+ def create_custom_forward(module):
990
+ def custom_forward(*inputs):
991
+ # None for past_key_value
992
+ return module(*inputs)
993
+
994
+ return custom_forward
995
+
996
+ hidden_states, hidden_states2, residual = torch.utils.checkpoint.checkpoint(
997
+ create_custom_forward(layer),
998
+ hidden_states,
999
+ hidden_states2,
1000
+ residual,
1001
+ attention_mask,
1002
+ None,
1003
+ None,
1004
+ is_padded_inputs,
1005
+ # if you freeze ANY layers, you need `use_reentrant=False`
1006
+ # https://github.com/huggingface/transformers/issues/21381
1007
+ # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/7
1008
+ use_reentrant=False,
1009
+ )
1010
+
1011
+ else:
1012
+ hidden_states, hidden_states2, residual = layer(
1013
+ hidden_states,
1014
+ hidden_states2,
1015
+ residual,
1016
+ attention_mask,
1017
+ position_ids,
1018
+ None,
1019
+ is_padded_inputs,
1020
+ output_attentions,
1021
+ use_cache,
1022
+ )
1023
+ return hidden_states
1024
+
1025
+
1026
+ class NomicBertPooler(nn.Module):
1027
+ def __init__(self, config):
1028
+ super().__init__()
1029
+ self.dense = nn.Linear(config.n_embd, config.n_embd)
1030
+ self.activation = nn.Tanh()
1031
+
1032
+ def forward(self, hidden_states, pool=True):
1033
+ # We "pool" the model by simply taking the hidden state corresponding
1034
+ # to the first token.
1035
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
1036
+ pooled_output = self.dense(first_token_tensor)
1037
+ pooled_output = self.activation(pooled_output)
1038
+ return pooled_output
1039
+
1040
+
1041
+ class NomicBertPredictionHeadTransform(nn.Module):
1042
+ def __init__(self, config):
1043
+ super().__init__()
1044
+ self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)
1045
+ approximate = "tanh" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
1046
+ if config.activation_function == "swiglu":
1047
+ self.transform_act_fn = F.silu
1048
+ else:
1049
+ self.transform_act_fn = nn.GELU(approximate=approximate)
1050
+
1051
+ self.layer_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1052
+
1053
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1054
+ hidden_states = self.dense(hidden_states)
1055
+ hidden_states = self.transform_act_fn(hidden_states)
1056
+ hidden_states = self.layer_norm(hidden_states)
1057
+
1058
+ return hidden_states
1059
+
1060
+
1061
+ class NomicBertLMPredictionHead(nn.Module):
1062
+ def __init__(self, config):
1063
+ super().__init__()
1064
+
1065
+ self.transform = NomicBertPredictionHeadTransform(config)
1066
+
1067
+ self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=config.mlp_fc1_bias)
1068
+
1069
+ def forward(self, hidden_states):
1070
+ hidden_states = self.transform(hidden_states)
1071
+ hidden_states = self.decoder(hidden_states)
1072
+ return hidden_states
1073
+
1074
+
1075
+ class NomicBertPreTrainingHeads(nn.Module):
1076
+ def __init__(self, config):
1077
+ super().__init__()
1078
+ self.predictions = NomicBertLMPredictionHead(config)
1079
+
1080
+ def forward(self, sequence_output):
1081
+ prediction_scores = self.predictions(sequence_output)
1082
+ return prediction_scores
1083
+
1084
+
1085
+ class NomicBertModel(NomicBertPreTrainedModel):
1086
+ def __init__(self, config: GPT2Config, add_pooling_layer=True):
1087
+ super().__init__(config)
1088
+ self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
1089
+ if config.vocab_size % self.pad_vocab_size_multiple != 0:
1090
+ config.vocab_size += self.pad_vocab_size_multiple - (config.vocab_size % self.pad_vocab_size_multiple)
1091
+
1092
+ assert config.activation_function in [
1093
+ "gelu",
1094
+ "gelu_new",
1095
+ "gelu_fast",
1096
+ "gelu_pytorch_tanh",
1097
+ "swiglu",
1098
+ "geglu",
1099
+ "glu",
1100
+ ]
1101
+
1102
+ self.embeddings = NomicBertEmbeddings(config)
1103
+ self.emb_drop = nn.Dropout(config.resid_pdrop)
1104
+ self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1105
+ self.encoder = NomicBertEncoder(config)
1106
+ self.pooler = NomicBertPooler(config) if add_pooling_layer else None
1107
+
1108
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
1109
+
1110
+ def forward(
1111
+ self,
1112
+ input_ids,
1113
+ position_ids=None,
1114
+ token_type_ids=None,
1115
+ attention_mask=None,
1116
+ return_dict=None,
1117
+ matryoshka_dim=None,
1118
+ ):
1119
+ if token_type_ids is None:
1120
+ token_type_ids = torch.zeros_like(input_ids)
1121
+ hidden_states = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
1122
+ hidden_states = self.emb_ln(hidden_states)
1123
+ hidden_states = self.emb_drop(hidden_states)
1124
+
1125
+ attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
1126
+ sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)
1127
+
1128
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1129
+
1130
+ if matryoshka_dim:
1131
+ sequence_output = sequence_output[:, :matryoshka_dim]
1132
+
1133
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1134
+ last_hidden_state=sequence_output,
1135
+ pooler_output=pooled_output,
1136
+ )
1137
+
1138
+
1139
+ class NomicBertForPreTraining(NomicBertPreTrainedModel):
1140
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
1141
+
1142
+ def __init__(self, config: GPT2Config):
1143
+ super().__init__(config)
1144
+
1145
+ self.bert = NomicBertModel(config, add_pooling_layer=getattr(config, "add_pooling_layer", False))
1146
+ self.cls = NomicBertPreTrainingHeads(config)
1147
+ self.mlm_loss = nn.CrossEntropyLoss()
1148
+
1149
+ # Initialize weights and apply final processing
1150
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
1151
+ self.tie_weights()
1152
+
1153
+ def tie_weights(self):
1154
+ self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
1155
+
1156
+ def forward(
1157
+ self,
1158
+ input_ids,
1159
+ position_ids=None,
1160
+ token_type_ids=None,
1161
+ attention_mask=None,
1162
+ labels=None,
1163
+ ):
1164
+ """
1165
+ If labels are provided, they must be -100 for masked out tokens (as specified in the attention
1166
+ mask).
1167
+ Outputs:
1168
+ if `labels` and `next_sentence_label` are not `None`:
1169
+ Outputs the total_loss which is the sum of the masked language modeling loss and the next
1170
+ sentence classification loss.
1171
+ if `labels` or `next_sentence_label` is `None`:
1172
+ Outputs a tuple comprising
1173
+ - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
1174
+ - the next sentence classification logits of shape [batch_size, 2].
1175
+
1176
+ """
1177
+ outputs = self.bert(
1178
+ input_ids,
1179
+ position_ids=position_ids,
1180
+ token_type_ids=token_type_ids,
1181
+ attention_mask=attention_mask.bool() if attention_mask is not None else None,
1182
+ )
1183
+ sequence_output, _ = outputs.last_hidden_state, outputs.pooler_output
1184
+
1185
+ prediction_scores = self.cls(sequence_output)
1186
+
1187
+ total_loss = None
1188
+ if labels is not None:
1189
+ masked_lm_loss = self.mlm_loss(
1190
+ rearrange(prediction_scores, "... v -> (...) v"),
1191
+ rearrange(labels, "... -> (...)"),
1192
+ )
1193
+ total_loss = masked_lm_loss.float()
1194
+
1195
+ return MaskedLMOutput(
1196
+ loss=total_loss,
1197
+ logits=prediction_scores,
1198
+ hidden_states=outputs.hidden_states,
1199
+ attentions=None,
1200
+ )
1201
+
1202
+
1203
+ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1204
+ def __init__(self, config):
1205
+ super().__init__(config)
1206
+ self.num_labels = config.num_labels
1207
+ self.config = config
1208
+
1209
+ self.bert = NomicBertModel(config)
1210
+ classifier_dropout = getattr(config, "classifier_dropout", config.embd_pdrop)
1211
+ self.dropout = nn.Dropout(classifier_dropout)
1212
+ self.classifier = nn.Linear(config.n_embd, config.num_labels)
1213
+
1214
+ # Initialize weights and apply final processing
1215
+ self.post_init()
1216
+
1217
+ def forward(
1218
+ self,
1219
+ input_ids: Optional[torch.Tensor] = None,
1220
+ attention_mask: Optional[torch.Tensor] = None,
1221
+ token_type_ids: Optional[torch.Tensor] = None,
1222
+ position_ids: Optional[torch.Tensor] = None,
1223
+ head_mask: Optional[torch.Tensor] = None,
1224
+ inputs_embeds: Optional[torch.Tensor] = None,
1225
+ labels: Optional[torch.Tensor] = None,
1226
+ output_attentions: Optional[bool] = None,
1227
+ output_hidden_states: Optional[bool] = None,
1228
+ return_dict: Optional[bool] = None,
1229
+ ):
1230
+ r"""
1231
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1232
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1233
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1234
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1235
+ """
1236
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1237
+ outputs = self.bert(
1238
+ input_ids,
1239
+ position_ids=position_ids,
1240
+ token_type_ids=token_type_ids,
1241
+ attention_mask=attention_mask.bool() if attention_mask is not None else None,
1242
+ )
1243
+
1244
+ pooled_output = outputs[1]
1245
+
1246
+ pooled_output = self.dropout(pooled_output)
1247
+ logits = self.classifier(pooled_output)
1248
+
1249
+ loss = None
1250
+ if labels is not None:
1251
+ if self.config.problem_type is None:
1252
+ if self.num_labels == 1:
1253
+ self.config.problem_type = "regression"
1254
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1255
+ self.config.problem_type = "single_label_classification"
1256
+ else:
1257
+ self.config.problem_type = "multi_label_classification"
1258
+
1259
+ if self.config.problem_type == "regression":
1260
+ loss_fct = nn.MSELoss()
1261
+ if self.num_labels == 1:
1262
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1263
+ else:
1264
+ loss = loss_fct(logits, labels)
1265
+ elif self.config.problem_type == "single_label_classification":
1266
+ loss_fct = nn.CrossEntropyLoss()
1267
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1268
+ elif self.config.problem_type == "multi_label_classification":
1269
+ loss_fct = nn.BCEWithLogitsLoss()
1270
+ loss = loss_fct(logits, labels)
1271
+ if not return_dict:
1272
+ output = (logits,) + outputs[2:]
1273
+ return ((loss,) + output) if loss is not None else output
1274
+
1275
+ return SequenceClassifierOutput(
1276
+ loss=loss,
1277
+ logits=logits,
1278
+ hidden_states=outputs.hidden_states,
1279
+ attentions=outputs.attentions,
1280
+ )
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79606f054675f2f6f3ea58c0c5727f16914b91d6590cff0e1a78c78c67d67b5e
3
+ size 546961421
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "additional_special_tokens": [],
45
+ "clean_up_tokenization_spaces": true,
46
+ "cls_token": "[CLS]",
47
+ "do_lower_case": true,
48
+ "mask_token": "[MASK]",
49
+ "model_max_length": 512,
50
+ "pad_token": "[PAD]",
51
+ "sep_token": "[SEP]",
52
+ "strip_accents": null,
53
+ "tokenize_chinese_chars": true,
54
+ "tokenizer_class": "BertTokenizer",
55
+ "unk_token": "[UNK]"
56
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff