Add new SentenceTransformer model.
Browse files- 1_Pooling/config.json +10 -0
- README.md +465 -0
- config.json +58 -0
- config_sentence_transformers.json +10 -0
- configuration_hf_nomic_bert.py +56 -0
- model.safetensors +3 -0
- modeling_hf_nomic_bert.py +1234 -0
- modules.json +14 -0
- sentence_bert_config.json +4 -0
- special_tokens_map.json +37 -0
- tokenizer.json +0 -0
- tokenizer_config.json +62 -0
- vocab.txt +0 -0
1_Pooling/config.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"word_embedding_dimension": 768,
|
3 |
+
"pooling_mode_cls_token": false,
|
4 |
+
"pooling_mode_mean_tokens": true,
|
5 |
+
"pooling_mode_max_tokens": false,
|
6 |
+
"pooling_mode_mean_sqrt_len_tokens": false,
|
7 |
+
"pooling_mode_weightedmean_tokens": false,
|
8 |
+
"pooling_mode_lasttoken": false,
|
9 |
+
"include_prompt": true
|
10 |
+
}
|
README.md
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language: []
|
3 |
+
library_name: sentence-transformers
|
4 |
+
tags:
|
5 |
+
- sentence-transformers
|
6 |
+
- sentence-similarity
|
7 |
+
- feature-extraction
|
8 |
+
- dataset_size:1K<n<10K
|
9 |
+
- loss:CachedGISTEmbedLoss
|
10 |
+
base_model: nomic-ai/nomic-embed-text-v1.5
|
11 |
+
metrics:
|
12 |
+
- cosine_accuracy
|
13 |
+
- dot_accuracy
|
14 |
+
- manhattan_accuracy
|
15 |
+
- euclidean_accuracy
|
16 |
+
- max_accuracy
|
17 |
+
widget:
|
18 |
+
- source_sentence: Pilot
|
19 |
+
sentences:
|
20 |
+
- Episode Two
|
21 |
+
- dog dinosaur bone
|
22 |
+
- 10' x 12' gazebo
|
23 |
+
- source_sentence: skull
|
24 |
+
sentences:
|
25 |
+
- cool head s
|
26 |
+
- trunk bike rack 4
|
27 |
+
- bread without gluten
|
28 |
+
- source_sentence: pipes
|
29 |
+
sentences:
|
30 |
+
- chillum pipe
|
31 |
+
- Deckle Edge Ruler
|
32 |
+
- dog collar for boxer
|
33 |
+
- source_sentence: ddj400
|
34 |
+
sentences:
|
35 |
+
- lc27h711qenxza
|
36 |
+
- bed frame for full
|
37 |
+
- chicago bears gifts
|
38 |
+
- source_sentence: primes
|
39 |
+
sentences:
|
40 |
+
- Newton
|
41 |
+
- big boys sneakers
|
42 |
+
- large dog clothes
|
43 |
+
pipeline_tag: sentence-similarity
|
44 |
+
model-index:
|
45 |
+
- name: SentenceTransformer based on nomic-ai/nomic-embed-text-v1.5
|
46 |
+
results:
|
47 |
+
- task:
|
48 |
+
type: triplet
|
49 |
+
name: Triplet
|
50 |
+
dataset:
|
51 |
+
name: esci dev
|
52 |
+
type: esci-dev
|
53 |
+
metrics:
|
54 |
+
- type: cosine_accuracy
|
55 |
+
value: 0.6414052697616061
|
56 |
+
name: Cosine Accuracy
|
57 |
+
- type: dot_accuracy
|
58 |
+
value: 0.36637390213299875
|
59 |
+
name: Dot Accuracy
|
60 |
+
- type: manhattan_accuracy
|
61 |
+
value: 0.6404015056461732
|
62 |
+
name: Manhattan Accuracy
|
63 |
+
- type: euclidean_accuracy
|
64 |
+
value: 0.6406524466750314
|
65 |
+
name: Euclidean Accuracy
|
66 |
+
- type: max_accuracy
|
67 |
+
value: 0.6414052697616061
|
68 |
+
name: Max Accuracy
|
69 |
+
---
|
70 |
+
|
71 |
+
# SentenceTransformer based on nomic-ai/nomic-embed-text-v1.5
|
72 |
+
|
73 |
+
This is a [sentence-transformers](https://www.SBERT.net) model finetuned from [nomic-ai/nomic-embed-text-v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5). It maps sentences & paragraphs to a 768-dimensional dense vector space and can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.
|
74 |
+
|
75 |
+
## Model Details
|
76 |
+
|
77 |
+
### Model Description
|
78 |
+
- **Model Type:** Sentence Transformer
|
79 |
+
- **Base model:** [nomic-ai/nomic-embed-text-v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) <!-- at revision 91d2d6bfdddf0b0da840f901b533e99bae30d757 -->
|
80 |
+
- **Maximum Sequence Length:** 8192 tokens
|
81 |
+
- **Output Dimensionality:** 768 tokens
|
82 |
+
- **Similarity Function:** Cosine Similarity
|
83 |
+
<!-- - **Training Dataset:** Unknown -->
|
84 |
+
<!-- - **Language:** Unknown -->
|
85 |
+
<!-- - **License:** Unknown -->
|
86 |
+
|
87 |
+
### Model Sources
|
88 |
+
|
89 |
+
- **Documentation:** [Sentence Transformers Documentation](https://sbert.net)
|
90 |
+
- **Repository:** [Sentence Transformers on GitHub](https://github.com/UKPLab/sentence-transformers)
|
91 |
+
- **Hugging Face:** [Sentence Transformers on Hugging Face](https://huggingface.co/models?library=sentence-transformers)
|
92 |
+
|
93 |
+
### Full Model Architecture
|
94 |
+
|
95 |
+
```
|
96 |
+
SentenceTransformer(
|
97 |
+
(0): Transformer({'max_seq_length': 8192, 'do_lower_case': False}) with Transformer model: NomicBertModel
|
98 |
+
(1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
|
99 |
+
)
|
100 |
+
```
|
101 |
+
|
102 |
+
## Usage
|
103 |
+
|
104 |
+
### Direct Usage (Sentence Transformers)
|
105 |
+
|
106 |
+
First install the Sentence Transformers library:
|
107 |
+
|
108 |
+
```bash
|
109 |
+
pip install -U sentence-transformers
|
110 |
+
```
|
111 |
+
|
112 |
+
Then you can load this model and run inference.
|
113 |
+
```python
|
114 |
+
from sentence_transformers import SentenceTransformer
|
115 |
+
|
116 |
+
# Download from the 🤗 Hub
|
117 |
+
model = SentenceTransformer("sentence_transformers_model_id")
|
118 |
+
# Run inference
|
119 |
+
sentences = [
|
120 |
+
'primes',
|
121 |
+
'Newton',
|
122 |
+
'big boys sneakers',
|
123 |
+
]
|
124 |
+
embeddings = model.encode(sentences)
|
125 |
+
print(embeddings.shape)
|
126 |
+
# [3, 768]
|
127 |
+
|
128 |
+
# Get the similarity scores for the embeddings
|
129 |
+
similarities = model.similarity(embeddings, embeddings)
|
130 |
+
print(similarities.shape)
|
131 |
+
# [3, 3]
|
132 |
+
```
|
133 |
+
|
134 |
+
<!--
|
135 |
+
### Direct Usage (Transformers)
|
136 |
+
|
137 |
+
<details><summary>Click to see the direct usage in Transformers</summary>
|
138 |
+
|
139 |
+
</details>
|
140 |
+
-->
|
141 |
+
|
142 |
+
<!--
|
143 |
+
### Downstream Usage (Sentence Transformers)
|
144 |
+
|
145 |
+
You can finetune this model on your own dataset.
|
146 |
+
|
147 |
+
<details><summary>Click to expand</summary>
|
148 |
+
|
149 |
+
</details>
|
150 |
+
-->
|
151 |
+
|
152 |
+
<!--
|
153 |
+
### Out-of-Scope Use
|
154 |
+
|
155 |
+
*List how the model may foreseeably be misused and address what users ought not to do with the model.*
|
156 |
+
-->
|
157 |
+
|
158 |
+
## Evaluation
|
159 |
+
|
160 |
+
### Metrics
|
161 |
+
|
162 |
+
#### Triplet
|
163 |
+
* Dataset: `esci-dev`
|
164 |
+
* Evaluated with [<code>TripletEvaluator</code>](https://sbert.net/docs/package_reference/sentence_transformer/evaluation.html#sentence_transformers.evaluation.TripletEvaluator)
|
165 |
+
|
166 |
+
| Metric | Value |
|
167 |
+
|:-------------------|:-----------|
|
168 |
+
| cosine_accuracy | 0.6414 |
|
169 |
+
| dot_accuracy | 0.3664 |
|
170 |
+
| manhattan_accuracy | 0.6404 |
|
171 |
+
| euclidean_accuracy | 0.6407 |
|
172 |
+
| **max_accuracy** | **0.6414** |
|
173 |
+
|
174 |
+
<!--
|
175 |
+
## Bias, Risks and Limitations
|
176 |
+
|
177 |
+
*What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
|
178 |
+
-->
|
179 |
+
|
180 |
+
<!--
|
181 |
+
### Recommendations
|
182 |
+
|
183 |
+
*What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
|
184 |
+
-->
|
185 |
+
|
186 |
+
## Training Details
|
187 |
+
|
188 |
+
### Training Dataset
|
189 |
+
|
190 |
+
#### Unnamed Dataset
|
191 |
+
|
192 |
+
|
193 |
+
* Size: 9,090 training samples
|
194 |
+
* Columns: <code>query</code>, <code>pos</code>, and <code>neg</code>
|
195 |
+
* Approximate statistics based on the first 1000 samples:
|
196 |
+
| | query | pos | neg |
|
197 |
+
|:--------|:---------------------------------------------------------------------------------|:----------------------------------------------------------------------------------|:---------------------------------------------------------------------------------|
|
198 |
+
| type | string | string | string |
|
199 |
+
| details | <ul><li>min: 3 tokens</li><li>mean: 7.42 tokens</li><li>max: 30 tokens</li></ul> | <ul><li>min: 3 tokens</li><li>mean: 29.27 tokens</li><li>max: 87 tokens</li></ul> | <ul><li>min: 4 tokens</li><li>mean: 29.8 tokens</li><li>max: 82 tokens</li></ul> |
|
200 |
+
* Samples:
|
201 |
+
| query | pos | neg |
|
202 |
+
|:--------------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
203 |
+
| <code>1 3/4 inch tooled belt strap without belt buckle</code> | <code>BS3501 Solid Brass Leaf Belt Buckle Fits 1-3/4"(45mm) Wide Belt</code> | <code>Nocona Men's Hired Brown Floral Eagle, 40</code> |
|
204 |
+
| <code>7edge phone case peacock</code> | <code>Galaxy S7 Edge Case for Girls Women Clear with Flowers Design Shockproof Protective Cell Phone Cases for Samsung Galaxy S7 Edge 5.5 Inch Cute Floral Pattern Print Flexible Slim Fit Bumper Rubber Cover</code> | <code>Galaxy S7 Case, Galaxy S7 Phone Case with HD Screen Protector for Girls Women, Gritup Cute Clear Gradient Glitter Liquid TPU Slim Phone Case for Samsung Galaxy S7 Teal/Purple</code> |
|
205 |
+
| <code>girls white shoes</code> | <code>adidas Women's Coast Star Shoes, ftwr White/Silver Met./ core Black, 6 M US</code> | <code>Converse Optical White M7650 - HI TOP Size 6 M US Women / 4 M US Men</code> |
|
206 |
+
* Loss: [<code>CachedGISTEmbedLoss</code>](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#cachedgistembedloss) with these parameters:
|
207 |
+
```json
|
208 |
+
{'guide': SentenceTransformer(
|
209 |
+
(0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel
|
210 |
+
(1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
|
211 |
+
(2): Normalize()
|
212 |
+
), 'temperature': 0.01}
|
213 |
+
```
|
214 |
+
|
215 |
+
### Evaluation Dataset
|
216 |
+
|
217 |
+
#### Unnamed Dataset
|
218 |
+
|
219 |
+
|
220 |
+
* Size: 3,985 evaluation samples
|
221 |
+
* Columns: <code>query</code>, <code>pos</code>, and <code>neg</code>
|
222 |
+
* Approximate statistics based on the first 1000 samples:
|
223 |
+
| | query | pos | neg |
|
224 |
+
|:--------|:---------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------|:----------------------------------------------------------------------------------|
|
225 |
+
| type | string | string | string |
|
226 |
+
| details | <ul><li>min: 3 tokens</li><li>mean: 7.28 tokens</li><li>max: 25 tokens</li></ul> | <ul><li>min: 3 tokens</li><li>mean: 28.58 tokens</li><li>max: 116 tokens</li></ul> | <ul><li>min: 3 tokens</li><li>mean: 29.26 tokens</li><li>max: 79 tokens</li></ul> |
|
227 |
+
* Samples:
|
228 |
+
| query | pos | neg |
|
229 |
+
|:--------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
230 |
+
| <code>colors for dining room</code> | <code>AOOS CUSTOM Dimmable LED Neon Signs for Home Bedroom Salon Dining Room Wall Decor (Customization: Texts, Designs, Logos, Languages, Colors, Sizes, Fonts, Color-Changing) (24" / 1 Line Text)</code> | <code>Jetec 5 Pieces EAT Sign Kitchen Wood Rustic Sign Arrow Wall Decor EAT Farmhouse Decoration Hanging Arrow Wooden Sign for Kitchen Wall Home Dining Room (Charming Color)</code> |
|
231 |
+
| <code>mix no 6 heels for women</code> | <code>DREAM PAIRS Women's Hi-Chunk Gold Glitter High Heel Pump Sandals - 6 M US</code> | <code>Fashare Womens High Heels Pointed Toe Bowtie Back Ankle Buckle Strap Wedding Evening Party Dress Pumps Shoes</code> |
|
232 |
+
| <code>goxlrmini</code> | <code>Singing Machine SMM-205 Unidirectional Dynamic Microphone with 10 Ft. Cord,Black, one size</code> | <code>Behringer U-Phoria Studio Pro Complete Recording Bundle with UMC202HD USB Audio Interface - With 20' 6mm Rubber XLR Microphone Cable, On-Stage MBS5000 Broadcast/Webcast Boom Arm with XLR Cable</code> |
|
233 |
+
* Loss: [<code>CachedGISTEmbedLoss</code>](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#cachedgistembedloss) with these parameters:
|
234 |
+
```json
|
235 |
+
{'guide': SentenceTransformer(
|
236 |
+
(0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel
|
237 |
+
(1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
|
238 |
+
(2): Normalize()
|
239 |
+
), 'temperature': 0.01}
|
240 |
+
```
|
241 |
+
|
242 |
+
### Training Hyperparameters
|
243 |
+
#### Non-Default Hyperparameters
|
244 |
+
|
245 |
+
- `per_device_train_batch_size`: 16
|
246 |
+
- `per_device_eval_batch_size`: 16
|
247 |
+
- `num_train_epochs`: 10
|
248 |
+
- `warmup_ratio`: 0.1
|
249 |
+
- `fp16`: True
|
250 |
+
- `batch_sampler`: no_duplicates
|
251 |
+
|
252 |
+
#### All Hyperparameters
|
253 |
+
<details><summary>Click to expand</summary>
|
254 |
+
|
255 |
+
- `overwrite_output_dir`: False
|
256 |
+
- `do_predict`: False
|
257 |
+
- `prediction_loss_only`: True
|
258 |
+
- `per_device_train_batch_size`: 16
|
259 |
+
- `per_device_eval_batch_size`: 16
|
260 |
+
- `per_gpu_train_batch_size`: None
|
261 |
+
- `per_gpu_eval_batch_size`: None
|
262 |
+
- `gradient_accumulation_steps`: 1
|
263 |
+
- `eval_accumulation_steps`: None
|
264 |
+
- `learning_rate`: 5e-05
|
265 |
+
- `weight_decay`: 0.0
|
266 |
+
- `adam_beta1`: 0.9
|
267 |
+
- `adam_beta2`: 0.999
|
268 |
+
- `adam_epsilon`: 1e-08
|
269 |
+
- `max_grad_norm`: 1.0
|
270 |
+
- `num_train_epochs`: 10
|
271 |
+
- `max_steps`: -1
|
272 |
+
- `lr_scheduler_type`: linear
|
273 |
+
- `lr_scheduler_kwargs`: {}
|
274 |
+
- `warmup_ratio`: 0.1
|
275 |
+
- `warmup_steps`: 0
|
276 |
+
- `log_level`: passive
|
277 |
+
- `log_level_replica`: warning
|
278 |
+
- `log_on_each_node`: True
|
279 |
+
- `logging_nan_inf_filter`: True
|
280 |
+
- `save_safetensors`: True
|
281 |
+
- `save_on_each_node`: False
|
282 |
+
- `save_only_model`: False
|
283 |
+
- `no_cuda`: False
|
284 |
+
- `use_cpu`: False
|
285 |
+
- `use_mps_device`: False
|
286 |
+
- `seed`: 42
|
287 |
+
- `data_seed`: None
|
288 |
+
- `jit_mode_eval`: False
|
289 |
+
- `use_ipex`: False
|
290 |
+
- `bf16`: False
|
291 |
+
- `fp16`: True
|
292 |
+
- `fp16_opt_level`: O1
|
293 |
+
- `half_precision_backend`: auto
|
294 |
+
- `bf16_full_eval`: False
|
295 |
+
- `fp16_full_eval`: False
|
296 |
+
- `tf32`: None
|
297 |
+
- `local_rank`: 0
|
298 |
+
- `ddp_backend`: None
|
299 |
+
- `tpu_num_cores`: None
|
300 |
+
- `tpu_metrics_debug`: False
|
301 |
+
- `debug`: []
|
302 |
+
- `dataloader_drop_last`: False
|
303 |
+
- `dataloader_num_workers`: 0
|
304 |
+
- `dataloader_prefetch_factor`: None
|
305 |
+
- `past_index`: -1
|
306 |
+
- `disable_tqdm`: False
|
307 |
+
- `remove_unused_columns`: True
|
308 |
+
- `label_names`: None
|
309 |
+
- `load_best_model_at_end`: False
|
310 |
+
- `ignore_data_skip`: False
|
311 |
+
- `fsdp`: []
|
312 |
+
- `fsdp_min_num_params`: 0
|
313 |
+
- `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
|
314 |
+
- `fsdp_transformer_layer_cls_to_wrap`: None
|
315 |
+
- `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True}
|
316 |
+
- `deepspeed`: None
|
317 |
+
- `label_smoothing_factor`: 0.0
|
318 |
+
- `optim`: adamw_torch
|
319 |
+
- `optim_args`: None
|
320 |
+
- `adafactor`: False
|
321 |
+
- `group_by_length`: False
|
322 |
+
- `length_column_name`: length
|
323 |
+
- `ddp_find_unused_parameters`: None
|
324 |
+
- `ddp_bucket_cap_mb`: None
|
325 |
+
- `ddp_broadcast_buffers`: False
|
326 |
+
- `dataloader_pin_memory`: True
|
327 |
+
- `dataloader_persistent_workers`: False
|
328 |
+
- `skip_memory_metrics`: True
|
329 |
+
- `use_legacy_prediction_loop`: False
|
330 |
+
- `push_to_hub`: False
|
331 |
+
- `resume_from_checkpoint`: None
|
332 |
+
- `hub_model_id`: None
|
333 |
+
- `hub_strategy`: every_save
|
334 |
+
- `hub_private_repo`: False
|
335 |
+
- `hub_always_push`: False
|
336 |
+
- `gradient_checkpointing`: False
|
337 |
+
- `gradient_checkpointing_kwargs`: None
|
338 |
+
- `include_inputs_for_metrics`: False
|
339 |
+
- `fp16_backend`: auto
|
340 |
+
- `push_to_hub_model_id`: None
|
341 |
+
- `push_to_hub_organization`: None
|
342 |
+
- `mp_parameters`:
|
343 |
+
- `auto_find_batch_size`: False
|
344 |
+
- `full_determinism`: False
|
345 |
+
- `torchdynamo`: None
|
346 |
+
- `ray_scope`: last
|
347 |
+
- `ddp_timeout`: 1800
|
348 |
+
- `torch_compile`: False
|
349 |
+
- `torch_compile_backend`: None
|
350 |
+
- `torch_compile_mode`: None
|
351 |
+
- `dispatch_batches`: None
|
352 |
+
- `split_batches`: None
|
353 |
+
- `include_tokens_per_second`: False
|
354 |
+
- `include_num_input_tokens_seen`: False
|
355 |
+
- `neftune_noise_alpha`: None
|
356 |
+
- `batch_sampler`: no_duplicates
|
357 |
+
- `multi_dataset_batch_sampler`: proportional
|
358 |
+
|
359 |
+
</details>
|
360 |
+
|
361 |
+
### Training Logs
|
362 |
+
| Epoch | Step | Training Loss | esci-dev_max_accuracy |
|
363 |
+
|:------:|:----:|:-------------:|:---------------------:|
|
364 |
+
| 0 | 0 | - | 0.6414 |
|
365 |
+
| 0.1757 | 100 | 0.8875 | - |
|
366 |
+
| 0.3515 | 200 | 0.5281 | - |
|
367 |
+
| 0.5272 | 300 | 0.4621 | - |
|
368 |
+
| 0.7030 | 400 | 0.4669 | - |
|
369 |
+
| 0.8787 | 500 | 0.4501 | - |
|
370 |
+
| 1.0545 | 600 | 0.5379 | - |
|
371 |
+
| 1.2302 | 700 | 0.4288 | - |
|
372 |
+
| 1.4060 | 800 | 0.2112 | - |
|
373 |
+
| 1.5817 | 900 | 0.1508 | - |
|
374 |
+
| 1.7575 | 1000 | 0.1133 | - |
|
375 |
+
| 1.9332 | 1100 | 0.1312 | - |
|
376 |
+
| 2.1090 | 1200 | 0.0784 | - |
|
377 |
+
| 2.2847 | 1300 | 0.0983 | - |
|
378 |
+
| 2.4605 | 1400 | 0.106 | - |
|
379 |
+
| 2.6362 | 1500 | 0.1058 | - |
|
380 |
+
| 2.8120 | 1600 | 0.0673 | - |
|
381 |
+
| 2.9877 | 1700 | 0.0355 | - |
|
382 |
+
| 3.1634 | 1800 | 0.0175 | - |
|
383 |
+
| 3.3392 | 1900 | 0.0366 | - |
|
384 |
+
| 3.5149 | 2000 | 0.0332 | - |
|
385 |
+
| 3.6907 | 2100 | 0.0682 | - |
|
386 |
+
| 3.8664 | 2200 | 0.0378 | - |
|
387 |
+
| 4.0422 | 2300 | 0.0239 | - |
|
388 |
+
| 4.2179 | 2400 | 0.0282 | - |
|
389 |
+
| 4.3937 | 2500 | 0.0401 | - |
|
390 |
+
| 4.5694 | 2600 | 0.0268 | - |
|
391 |
+
| 4.7452 | 2700 | 0.0208 | - |
|
392 |
+
| 4.9209 | 2800 | 0.0117 | - |
|
393 |
+
| 5.0967 | 2900 | 0.0045 | - |
|
394 |
+
| 5.2724 | 3000 | 0.0145 | - |
|
395 |
+
| 5.4482 | 3100 | 0.029 | - |
|
396 |
+
| 5.6239 | 3200 | 0.0009 | - |
|
397 |
+
| 5.7996 | 3300 | 0.0033 | - |
|
398 |
+
| 5.9754 | 3400 | 0.0088 | - |
|
399 |
+
| 6.1511 | 3500 | 0.0014 | - |
|
400 |
+
| 6.3269 | 3600 | 0.0027 | - |
|
401 |
+
| 6.5026 | 3700 | 0.0021 | - |
|
402 |
+
| 6.6784 | 3800 | 0.0001 | - |
|
403 |
+
| 6.8541 | 3900 | 0.0025 | - |
|
404 |
+
| 7.0299 | 4000 | 0.0059 | - |
|
405 |
+
| 7.2056 | 4100 | 0.0025 | - |
|
406 |
+
| 7.3814 | 4200 | 0.0029 | - |
|
407 |
+
| 7.5571 | 4300 | 0.0007 | - |
|
408 |
+
| 7.7329 | 4400 | 0.0018 | - |
|
409 |
+
| 7.9086 | 4500 | 0.0032 | - |
|
410 |
+
| 8.0844 | 4600 | 0.0007 | - |
|
411 |
+
| 8.2601 | 4700 | 0.0027 | - |
|
412 |
+
| 8.4359 | 4800 | 0.0027 | - |
|
413 |
+
| 8.6116 | 4900 | 0.0 | - |
|
414 |
+
| 8.7873 | 5000 | 0.0025 | - |
|
415 |
+
| 8.9631 | 5100 | 0.0025 | - |
|
416 |
+
| 9.1388 | 5200 | 0.0014 | - |
|
417 |
+
| 9.3146 | 5300 | 0.0027 | - |
|
418 |
+
| 9.4903 | 5400 | 0.0021 | - |
|
419 |
+
| 9.6661 | 5500 | 0.0 | - |
|
420 |
+
| 9.8418 | 5600 | 0.0025 | - |
|
421 |
+
|
422 |
+
|
423 |
+
### Framework Versions
|
424 |
+
- Python: 3.10.12
|
425 |
+
- Sentence Transformers: 3.0.0
|
426 |
+
- Transformers: 4.38.2
|
427 |
+
- PyTorch: 2.1.2+cu121
|
428 |
+
- Accelerate: 0.27.2
|
429 |
+
- Datasets: 2.19.1
|
430 |
+
- Tokenizers: 0.15.2
|
431 |
+
|
432 |
+
## Citation
|
433 |
+
|
434 |
+
### BibTeX
|
435 |
+
|
436 |
+
#### Sentence Transformers
|
437 |
+
```bibtex
|
438 |
+
@inproceedings{reimers-2019-sentence-bert,
|
439 |
+
title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
|
440 |
+
author = "Reimers, Nils and Gurevych, Iryna",
|
441 |
+
booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
|
442 |
+
month = "11",
|
443 |
+
year = "2019",
|
444 |
+
publisher = "Association for Computational Linguistics",
|
445 |
+
url = "https://arxiv.org/abs/1908.10084",
|
446 |
+
}
|
447 |
+
```
|
448 |
+
|
449 |
+
<!--
|
450 |
+
## Glossary
|
451 |
+
|
452 |
+
*Clearly define terms in order to be accessible across audiences.*
|
453 |
+
-->
|
454 |
+
|
455 |
+
<!--
|
456 |
+
## Model Card Authors
|
457 |
+
|
458 |
+
*Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
|
459 |
+
-->
|
460 |
+
|
461 |
+
<!--
|
462 |
+
## Model Card Contact
|
463 |
+
|
464 |
+
*Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
|
465 |
+
-->
|
config.json
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "models/nomic-embed-text-esci/checkpoint-5600",
|
3 |
+
"activation_function": "swiglu",
|
4 |
+
"architectures": [
|
5 |
+
"NomicBertModel"
|
6 |
+
],
|
7 |
+
"attn_pdrop": 0.0,
|
8 |
+
"auto_map": {
|
9 |
+
"AutoConfig": "configuration_hf_nomic_bert.NomicBertConfig",
|
10 |
+
"AutoModel": "modeling_hf_nomic_bert.NomicBertModel",
|
11 |
+
"AutoModelForMaskedLM": "nomic-ai/nomic-bert-2048--modeling_hf_nomic_bert.NomicBertForPreTraining"
|
12 |
+
},
|
13 |
+
"bos_token_id": null,
|
14 |
+
"causal": false,
|
15 |
+
"dense_seq_output": true,
|
16 |
+
"embd_pdrop": 0.0,
|
17 |
+
"eos_token_id": null,
|
18 |
+
"fused_bias_fc": true,
|
19 |
+
"fused_dropout_add_ln": true,
|
20 |
+
"initializer_range": 0.02,
|
21 |
+
"layer_norm_epsilon": 1e-12,
|
22 |
+
"max_trained_positions": 2048,
|
23 |
+
"mlp_fc1_bias": false,
|
24 |
+
"mlp_fc2_bias": false,
|
25 |
+
"model_type": "nomic_bert",
|
26 |
+
"n_embd": 768,
|
27 |
+
"n_head": 12,
|
28 |
+
"n_inner": 3072,
|
29 |
+
"n_layer": 12,
|
30 |
+
"n_positions": 8192,
|
31 |
+
"pad_vocab_size_multiple": 64,
|
32 |
+
"parallel_block": false,
|
33 |
+
"parallel_block_tied_norm": false,
|
34 |
+
"prenorm": false,
|
35 |
+
"qkv_proj_bias": false,
|
36 |
+
"reorder_and_upcast_attn": false,
|
37 |
+
"resid_pdrop": 0.0,
|
38 |
+
"rotary_emb_base": 1000,
|
39 |
+
"rotary_emb_fraction": 1.0,
|
40 |
+
"rotary_emb_interleaved": false,
|
41 |
+
"rotary_emb_scale_base": null,
|
42 |
+
"rotary_scaling_factor": null,
|
43 |
+
"scale_attn_by_inverse_layer_idx": false,
|
44 |
+
"scale_attn_weights": true,
|
45 |
+
"summary_activation": null,
|
46 |
+
"summary_first_dropout": 0.0,
|
47 |
+
"summary_proj_to_labels": true,
|
48 |
+
"summary_type": "cls_index",
|
49 |
+
"summary_use_proj": true,
|
50 |
+
"torch_dtype": "float32",
|
51 |
+
"transformers_version": "4.38.2",
|
52 |
+
"type_vocab_size": 2,
|
53 |
+
"use_cache": true,
|
54 |
+
"use_flash_attn": true,
|
55 |
+
"use_rms_norm": false,
|
56 |
+
"use_xentropy": true,
|
57 |
+
"vocab_size": 30528
|
58 |
+
}
|
config_sentence_transformers.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"__version__": {
|
3 |
+
"sentence_transformers": "2.4.0.dev0",
|
4 |
+
"transformers": "4.37.2",
|
5 |
+
"pytorch": "2.1.0+cu121"
|
6 |
+
},
|
7 |
+
"prompts": {},
|
8 |
+
"default_prompt_name": null,
|
9 |
+
"similarity_fn_name": null
|
10 |
+
}
|
configuration_hf_nomic_bert.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import GPT2Config
|
2 |
+
|
3 |
+
|
4 |
+
class NomicBertConfig(GPT2Config):
|
5 |
+
model_type = "nomic_bert"
|
6 |
+
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
prenorm=False,
|
10 |
+
parallel_block=False,
|
11 |
+
parallel_block_tied_norm=False,
|
12 |
+
rotary_emb_fraction=0.0,
|
13 |
+
fused_dropout_add_ln=False,
|
14 |
+
fused_bias_fc=False,
|
15 |
+
use_flash_attn=False,
|
16 |
+
use_xentropy=False,
|
17 |
+
qkv_proj_bias=True,
|
18 |
+
rotary_emb_base=10_000,
|
19 |
+
rotary_emb_scale_base=None,
|
20 |
+
rotary_emb_interleaved=False,
|
21 |
+
mlp_fc1_bias=True,
|
22 |
+
mlp_fc2_bias=True,
|
23 |
+
use_rms_norm=False,
|
24 |
+
causal=False,
|
25 |
+
type_vocab_size=2,
|
26 |
+
dense_seq_output=True,
|
27 |
+
pad_vocab_size_multiple=1,
|
28 |
+
tie_word_embeddings=True,
|
29 |
+
rotary_scaling_factor=None,
|
30 |
+
max_trained_positions=2048,
|
31 |
+
**kwargs,
|
32 |
+
):
|
33 |
+
self.prenorm = prenorm
|
34 |
+
self.parallel_block = parallel_block
|
35 |
+
self.parallel_block_tied_norm = parallel_block_tied_norm
|
36 |
+
self.rotary_emb_fraction = rotary_emb_fraction
|
37 |
+
self.tie_word_embeddings = tie_word_embeddings
|
38 |
+
self.fused_dropout_add_ln = fused_dropout_add_ln
|
39 |
+
self.fused_bias_fc = fused_bias_fc
|
40 |
+
self.use_flash_attn = use_flash_attn
|
41 |
+
self.use_xentropy = use_xentropy
|
42 |
+
self.qkv_proj_bias = qkv_proj_bias
|
43 |
+
self.rotary_emb_base = rotary_emb_base
|
44 |
+
self.rotary_emb_scale_base = rotary_emb_scale_base
|
45 |
+
self.rotary_emb_interleaved = rotary_emb_interleaved
|
46 |
+
self.mlp_fc1_bias = mlp_fc1_bias
|
47 |
+
self.mlp_fc2_bias = mlp_fc2_bias
|
48 |
+
self.use_rms_norm = use_rms_norm
|
49 |
+
self.causal = causal
|
50 |
+
self.type_vocab_size = type_vocab_size
|
51 |
+
self.dense_seq_output = dense_seq_output
|
52 |
+
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
53 |
+
self.rotary_scaling_factor = rotary_scaling_factor
|
54 |
+
self.max_trained_positions = max_trained_positions
|
55 |
+
|
56 |
+
super().__init__(**kwargs)
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:85c0ee2699043dfafce03c1ee919a4bf7ef793421919aeb9892d0f41a6b7de62
|
3 |
+
size 546938168
|
modeling_hf_nomic_bert.py
ADDED
@@ -0,0 +1,1234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2022, Tri Dao.
|
2 |
+
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
|
3 |
+
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
|
4 |
+
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
|
5 |
+
|
6 |
+
import logging
|
7 |
+
|
8 |
+
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
|
9 |
+
import os
|
10 |
+
import re
|
11 |
+
from collections import OrderedDict
|
12 |
+
from functools import partial
|
13 |
+
from typing import List, Optional, Tuple, Union
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from einops import rearrange, repeat
|
19 |
+
from safetensors.torch import load_file as safe_load_file
|
20 |
+
from transformers import GPT2Config, PreTrainedModel
|
21 |
+
from transformers.models.bert.modeling_bert import (
|
22 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
23 |
+
MaskedLMOutput,
|
24 |
+
SequenceClassifierOutput,
|
25 |
+
)
|
26 |
+
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
|
27 |
+
from transformers.utils.hub import cached_file, get_checkpoint_shard_files
|
28 |
+
|
29 |
+
from .configuration_hf_nomic_bert import NomicBertConfig
|
30 |
+
|
31 |
+
logger = logging.getLogger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
# adapted from flash attention, added safe serialization option for hf models
|
35 |
+
def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
|
36 |
+
# If not fp32, then we don't want to load directly to the GPU
|
37 |
+
mapped_device = "cpu" if dtype not in [torch.float32, None] else device
|
38 |
+
is_sharded = False
|
39 |
+
load_safe = False
|
40 |
+
resolved_archive_file = None
|
41 |
+
|
42 |
+
weights_path = os.path.join(model_name, WEIGHTS_NAME)
|
43 |
+
weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
|
44 |
+
safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
|
45 |
+
safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
|
46 |
+
|
47 |
+
if os.path.isfile(weights_path):
|
48 |
+
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
|
49 |
+
elif os.path.isfile(weights_index_path):
|
50 |
+
resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False)
|
51 |
+
is_sharded = True
|
52 |
+
elif os.path.isfile(safe_weights_path):
|
53 |
+
resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
|
54 |
+
load_safe = True
|
55 |
+
elif os.path.isfile(safe_weights_index_path):
|
56 |
+
resolved_archive_file = cached_file(
|
57 |
+
model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
|
58 |
+
)
|
59 |
+
is_sharded = True
|
60 |
+
load_safe = True
|
61 |
+
else: # Try loading from HF hub instead of from local files
|
62 |
+
resolved_archive_file = None
|
63 |
+
for weight_name in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
|
64 |
+
resolved_archive_file = cached_file(
|
65 |
+
model_name, weight_name, _raise_exceptions_for_missing_entries=False
|
66 |
+
)
|
67 |
+
if resolved_archive_file is not None:
|
68 |
+
if weight_name in [SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]:
|
69 |
+
load_safe = True
|
70 |
+
if weight_name in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
|
71 |
+
is_sharded = True
|
72 |
+
break
|
73 |
+
|
74 |
+
if resolved_archive_file is None:
|
75 |
+
raise EnvironmentError(f"Model name {model_name} was not found.")
|
76 |
+
|
77 |
+
if load_safe:
|
78 |
+
loader = partial(safe_load_file, device=mapped_device)
|
79 |
+
else:
|
80 |
+
loader = partial(torch.load, map_location=mapped_device)
|
81 |
+
|
82 |
+
if is_sharded:
|
83 |
+
# resolved_archive_file becomes a list of files that point to the different
|
84 |
+
# checkpoint shards in this case.
|
85 |
+
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(model_name, resolved_archive_file)
|
86 |
+
state_dict = {}
|
87 |
+
for sharded_file in resolved_archive_file:
|
88 |
+
state_dict.update(loader(sharded_file))
|
89 |
+
else:
|
90 |
+
state_dict = loader(resolved_archive_file)
|
91 |
+
# Convert dtype before moving to GPU to save memory
|
92 |
+
if dtype is not None:
|
93 |
+
state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
|
94 |
+
state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
|
95 |
+
return state_dict
|
96 |
+
|
97 |
+
|
98 |
+
def filter_shapes(state_dict, model):
|
99 |
+
"""
|
100 |
+
Filters the state dict to match the current model shape.
|
101 |
+
"""
|
102 |
+
filtered_state_dict = {}
|
103 |
+
for key, value in state_dict.items():
|
104 |
+
if key in model.state_dict():
|
105 |
+
if value.shape == model.state_dict()[key].shape:
|
106 |
+
filtered_state_dict[key] = value
|
107 |
+
return filtered_state_dict
|
108 |
+
|
109 |
+
|
110 |
+
def remap_bert_state_dict(
|
111 |
+
state_dict,
|
112 |
+
config,
|
113 |
+
remove_bert=False,
|
114 |
+
remove_cls_weights=False,
|
115 |
+
add_pooling_layer=False,
|
116 |
+
):
|
117 |
+
"""
|
118 |
+
Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
|
119 |
+
"""
|
120 |
+
|
121 |
+
def add_bert_prefix(key):
|
122 |
+
# prepend bert. to the key
|
123 |
+
if key.startswith("bert.") or key.startswith("cls."):
|
124 |
+
return key
|
125 |
+
return f"bert.{key}"
|
126 |
+
|
127 |
+
state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items())
|
128 |
+
|
129 |
+
# LayerNorm
|
130 |
+
def key_mapping_ln_gamma_beta(key):
|
131 |
+
key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
|
132 |
+
key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
|
133 |
+
return key
|
134 |
+
|
135 |
+
state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
|
136 |
+
|
137 |
+
# Layers
|
138 |
+
def key_mapping_layers(key):
|
139 |
+
return re.sub(r"^bert.encoder.layer\.", "bert.encoder.layers.", key)
|
140 |
+
|
141 |
+
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
142 |
+
|
143 |
+
# LayerNorm
|
144 |
+
def key_mapping_ln(key):
|
145 |
+
key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
|
146 |
+
key = re.sub(
|
147 |
+
r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
|
148 |
+
r"bert.encoder.layers.\1.norm1.\2",
|
149 |
+
key,
|
150 |
+
)
|
151 |
+
key = re.sub(
|
152 |
+
r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
|
153 |
+
r"bert.encoder.layers.\1.norm2.\2",
|
154 |
+
key,
|
155 |
+
)
|
156 |
+
key = re.sub(
|
157 |
+
r"^cls.predictions.transform.LayerNorm.(weight|bias)",
|
158 |
+
r"cls.predictions.transform.layer_norm.\1",
|
159 |
+
key,
|
160 |
+
)
|
161 |
+
return key
|
162 |
+
|
163 |
+
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
164 |
+
|
165 |
+
# MLP
|
166 |
+
def key_mapping_mlp(key):
|
167 |
+
key = re.sub(
|
168 |
+
r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
|
169 |
+
r"bert.encoder.layers.\1.mlp.fc1.\2",
|
170 |
+
key,
|
171 |
+
)
|
172 |
+
key = re.sub(
|
173 |
+
r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
|
174 |
+
r"bert.encoder.layers.\1.mlp.fc2.\2",
|
175 |
+
key,
|
176 |
+
)
|
177 |
+
return key
|
178 |
+
|
179 |
+
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
180 |
+
|
181 |
+
# Attention
|
182 |
+
last_layer_subset = getattr(config, "last_layer_subset", False)
|
183 |
+
for d in range(config.num_hidden_layers):
|
184 |
+
if f"bert.encoder.layers.{d}.attention.self.query.weight" not in state_dict:
|
185 |
+
continue
|
186 |
+
Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
|
187 |
+
Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
|
188 |
+
Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
|
189 |
+
bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
|
190 |
+
bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
|
191 |
+
bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
|
192 |
+
if not (last_layer_subset and d == config.num_hidden_layers - 1):
|
193 |
+
state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
|
194 |
+
state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
|
195 |
+
else:
|
196 |
+
state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq
|
197 |
+
state_dict[f"bert.encoder.layers.{d}.attn.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
|
198 |
+
state_dict[f"bert.encoder.layers.{d}.attn.Wq.bias"] = bq
|
199 |
+
state_dict[f"bert.encoder.layers.{d}.attn.Wkv.bias"] = torch.cat([bk, bv], dim=0)
|
200 |
+
|
201 |
+
def key_mapping_attn(key):
|
202 |
+
return re.sub(
|
203 |
+
r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
|
204 |
+
r"bert.encoder.layers.\1.attn.out_proj.\2",
|
205 |
+
key,
|
206 |
+
)
|
207 |
+
|
208 |
+
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
209 |
+
|
210 |
+
def key_mapping_decoder_bias(key):
|
211 |
+
return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
|
212 |
+
|
213 |
+
# remove nsp weights, we don't use
|
214 |
+
state_dict.pop("cls.seq_relationship.weight", None)
|
215 |
+
state_dict.pop("cls.seq_relationship.bias", None)
|
216 |
+
state_dict.pop("bert.embeddings.position_ids", None)
|
217 |
+
|
218 |
+
state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
|
219 |
+
|
220 |
+
if remove_cls_weights:
|
221 |
+
cls_weights = [
|
222 |
+
"cls.predictions.decoder.bias",
|
223 |
+
"cls.predictions.transform.dense.weight",
|
224 |
+
"cls.predictions.transform.dense.bias",
|
225 |
+
"cls.predictions.transform.layer_norm.weight",
|
226 |
+
"cls.predictions.transform.layer_norm.bias",
|
227 |
+
"cls.predictions.decoder.weight",
|
228 |
+
]
|
229 |
+
for weight in cls_weights:
|
230 |
+
state_dict.pop(weight, None)
|
231 |
+
|
232 |
+
# Word embedding
|
233 |
+
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
234 |
+
if pad_vocab_size_multiple > 1:
|
235 |
+
word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
|
236 |
+
state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
|
237 |
+
word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
|
238 |
+
)
|
239 |
+
if not remove_cls_weights:
|
240 |
+
decoder_weight = state_dict["cls.predictions.decoder.weight"]
|
241 |
+
state_dict["cls.predictions.decoder.weight"] = F.pad(
|
242 |
+
decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
|
243 |
+
)
|
244 |
+
# If the vocab was padded, we want to set the decoder bias for those padded indices to be
|
245 |
+
# strongly negative (i.e. the decoder shouldn't predict those indices).
|
246 |
+
# TD [2022-05-09]: I don't think it affects the MLPerf training.
|
247 |
+
if "cls.predictions.decoder.bias" in state_dict:
|
248 |
+
decoder_bias = state_dict["cls.predictions.decoder.bias"]
|
249 |
+
state_dict["cls.predictions.decoder.bias"] = F.pad(
|
250 |
+
decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
|
251 |
+
)
|
252 |
+
|
253 |
+
if add_pooling_layer is False:
|
254 |
+
pooler_weights = [
|
255 |
+
"bert.pooler.dense.weight",
|
256 |
+
"bert.pooler.dense.bias",
|
257 |
+
]
|
258 |
+
for key in pooler_weights:
|
259 |
+
state_dict.pop(key, None)
|
260 |
+
|
261 |
+
if remove_bert:
|
262 |
+
|
263 |
+
def remove_bert_prefix(key):
|
264 |
+
key = re.sub(r"^bert.", "", key)
|
265 |
+
return key
|
266 |
+
|
267 |
+
state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items())
|
268 |
+
|
269 |
+
return state_dict
|
270 |
+
|
271 |
+
|
272 |
+
class NomicBertPreTrainedModel(PreTrainedModel):
|
273 |
+
"""An abstract class to handle weights initialization and
|
274 |
+
a simple interface for dowloading and loading pretrained models.
|
275 |
+
"""
|
276 |
+
|
277 |
+
config_class = NomicBertConfig
|
278 |
+
base_model_prefix = "model"
|
279 |
+
supports_gradient_checkpointing = True
|
280 |
+
_no_split_modules = ["Block"]
|
281 |
+
_skip_keys_device_placement = "past_key_values"
|
282 |
+
|
283 |
+
def __init__(self, config, *inputs, **kwargs):
|
284 |
+
super().__init__(config)
|
285 |
+
if not isinstance(config, GPT2Config):
|
286 |
+
raise ValueError(
|
287 |
+
"Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
|
288 |
+
"To create a model from a Google pretrained model use "
|
289 |
+
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
290 |
+
self.__class__.__name__, self.__class__.__name__
|
291 |
+
)
|
292 |
+
)
|
293 |
+
self.config = config
|
294 |
+
|
295 |
+
@classmethod
|
296 |
+
def from_pretrained(cls, model_name, config=None, *inputs, **kwargs):
|
297 |
+
"""
|
298 |
+
Instantiate a NomicBertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
299 |
+
Download and cache the pre-trained model file if needed.
|
300 |
+
|
301 |
+
Params:
|
302 |
+
pretrained_model_name_or_path: either:
|
303 |
+
- a path or url to a pretrained model archive containing:
|
304 |
+
. `bert_config.json` a configuration file for the model
|
305 |
+
. `pytorch_model.bin` a PyTorch dump of a NomicBertForPretraining instance
|
306 |
+
- a path or url to a pretrained model archive containing:
|
307 |
+
. `bert_config.json` a configuration file for the model
|
308 |
+
. `model.chkpt` a TensorFlow checkpoint
|
309 |
+
*inputs, **kwargs: additional input for the specific NomicBert class
|
310 |
+
(ex: num_labels for NomicBertForSequenceClassification)
|
311 |
+
"""
|
312 |
+
# Instantiate model.
|
313 |
+
if config is None:
|
314 |
+
config = cls.config_class.from_pretrained(model_name)
|
315 |
+
remove_cls = cls != NomicBertForPreTraining
|
316 |
+
remove_bert_prefix = cls != NomicBertForPreTraining and cls != NomicBertForSequenceClassification
|
317 |
+
ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
|
318 |
+
num_labels = kwargs.pop("num_labels", None)
|
319 |
+
rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
|
320 |
+
strict = kwargs.pop("strict", True)
|
321 |
+
if rotary_scaling_factor:
|
322 |
+
config.rotary_scaling_factor = rotary_scaling_factor
|
323 |
+
|
324 |
+
if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
|
325 |
+
config.n_positions = 2048
|
326 |
+
if num_labels:
|
327 |
+
config.num_labels = num_labels
|
328 |
+
|
329 |
+
if "add_pooling_layer" in kwargs:
|
330 |
+
model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
|
331 |
+
else:
|
332 |
+
if cls == NomicBertModel:
|
333 |
+
model = cls(config, *inputs, add_pooling_layer=False)
|
334 |
+
else:
|
335 |
+
model = cls(config, *inputs)
|
336 |
+
# TODO: fix this
|
337 |
+
# Assuming we know what we're doing when loading from disk
|
338 |
+
# Prob a bad assumption but i'm tired and want to train this asap
|
339 |
+
if os.path.exists(model_name):
|
340 |
+
model_path = f"{model_name}/pytorch_model.bin"
|
341 |
+
if os.path.exists(model_path):
|
342 |
+
state_dict = torch.load(f"{model_name}/pytorch_model.bin")
|
343 |
+
else:
|
344 |
+
model_path = f"{model_name}/model.safetensors"
|
345 |
+
if not os.path.exists(model_path):
|
346 |
+
raise ValueError(f"Model path {model_path} not found")
|
347 |
+
state_dict = safe_load_file(model_path)
|
348 |
+
|
349 |
+
if ignore_mismatched_shapes:
|
350 |
+
state_dict = filter_shapes(state_dict, model)
|
351 |
+
load_return = model.load_state_dict(state_dict, strict=False)
|
352 |
+
else:
|
353 |
+
# TODO: can probably check config class and see if we need to remap from a bert model
|
354 |
+
state_dict = state_dict_from_pretrained(model_name)
|
355 |
+
state_dict = remap_bert_state_dict(
|
356 |
+
state_dict,
|
357 |
+
config,
|
358 |
+
remove_bert=remove_bert_prefix,
|
359 |
+
remove_cls_weights=remove_cls,
|
360 |
+
add_pooling_layer=getattr(config, "add_pooling_layer", False),
|
361 |
+
)
|
362 |
+
if ignore_mismatched_shapes:
|
363 |
+
state_dict = filter_shapes(state_dict, model)
|
364 |
+
|
365 |
+
load_return = model.load_state_dict(state_dict, strict=strict)
|
366 |
+
logger.warning(load_return)
|
367 |
+
return model
|
368 |
+
|
369 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
370 |
+
if isinstance(module, NomicBertEncoder):
|
371 |
+
module.gradient_checkpointing = value
|
372 |
+
|
373 |
+
|
374 |
+
# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
|
375 |
+
def _init_weights(module, initializer_range=0.02):
|
376 |
+
if isinstance(module, nn.Linear):
|
377 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
378 |
+
if module.bias is not None:
|
379 |
+
nn.init.zeros_(module.bias)
|
380 |
+
elif isinstance(module, nn.Embedding):
|
381 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
382 |
+
if module.padding_idx is not None:
|
383 |
+
nn.init.zeros_(module.weight[module.padding_idx])
|
384 |
+
|
385 |
+
|
386 |
+
class NomicBertEmbeddings(nn.Module):
|
387 |
+
def __init__(self, config):
|
388 |
+
"""
|
389 |
+
If max_position_embeddings <= 0, there's no position embeddings
|
390 |
+
If type_vocab_size <= 0, there's no token type embeddings
|
391 |
+
"""
|
392 |
+
super().__init__()
|
393 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
394 |
+
self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
|
395 |
+
self.type_vocab_size = config.type_vocab_size
|
396 |
+
if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0:
|
397 |
+
self.position_embeddings = nn.Embedding(
|
398 |
+
config.max_position_embeddings,
|
399 |
+
config.hidden_size,
|
400 |
+
)
|
401 |
+
if self.type_vocab_size > 0:
|
402 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
403 |
+
|
404 |
+
def forward(self, input_ids, position_ids=None, token_type_ids=None):
|
405 |
+
"""
|
406 |
+
input_ids: (batch, seqlen)
|
407 |
+
position_ids: (batch, seqlen)
|
408 |
+
token_type_ids: (batch, seqlen)
|
409 |
+
"""
|
410 |
+
batch_size, seqlen = input_ids.shape
|
411 |
+
embeddings = self.word_embeddings(input_ids)
|
412 |
+
|
413 |
+
if self.type_vocab_size > 0:
|
414 |
+
if token_type_ids is None:
|
415 |
+
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
|
416 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
417 |
+
embeddings = embeddings + token_type_embeddings
|
418 |
+
|
419 |
+
if self.max_position_embeddings > 0:
|
420 |
+
if position_ids is None:
|
421 |
+
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
|
422 |
+
position_embeddings = self.position_embeddings(position_ids)
|
423 |
+
embeddings = embeddings + position_embeddings
|
424 |
+
return embeddings
|
425 |
+
|
426 |
+
|
427 |
+
class NomicBertMLP(nn.Module):
|
428 |
+
def __init__(
|
429 |
+
self,
|
430 |
+
in_features,
|
431 |
+
hidden_features=None,
|
432 |
+
out_features=None,
|
433 |
+
activation=F.gelu,
|
434 |
+
bias1=True,
|
435 |
+
bias2=True,
|
436 |
+
return_residual=False,
|
437 |
+
fused_bias_fc=False,
|
438 |
+
):
|
439 |
+
super().__init__()
|
440 |
+
out_features = out_features if out_features is not None else in_features
|
441 |
+
hidden_features = hidden_features if hidden_features is not None else in_features * 4
|
442 |
+
self.return_residual = return_residual
|
443 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
|
444 |
+
approximate = "tanh" if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
|
445 |
+
self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation
|
446 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
|
447 |
+
|
448 |
+
def forward(self, x):
|
449 |
+
y = self.fc1(x)
|
450 |
+
y = self.activation(y)
|
451 |
+
y = self.fc2(y)
|
452 |
+
return y if not self.return_residual else (y, x)
|
453 |
+
|
454 |
+
|
455 |
+
class NomciBertGatedMLP(nn.Module):
|
456 |
+
def __init__(
|
457 |
+
self,
|
458 |
+
in_features,
|
459 |
+
hidden_features=None,
|
460 |
+
out_features=None,
|
461 |
+
activation=F.sigmoid,
|
462 |
+
bias1=True,
|
463 |
+
bias2=True,
|
464 |
+
multiple_of=256,
|
465 |
+
return_residual=False,
|
466 |
+
fused_bias_fc=True,
|
467 |
+
device=None,
|
468 |
+
dtype=None,
|
469 |
+
):
|
470 |
+
super().__init__()
|
471 |
+
out_features = out_features if out_features is not None else in_features
|
472 |
+
hidden_features = hidden_features if hidden_features is not None else int(8 * in_features / 3)
|
473 |
+
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
474 |
+
self.return_residual = return_residual
|
475 |
+
|
476 |
+
self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1)
|
477 |
+
self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1)
|
478 |
+
self.activation = activation
|
479 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
|
480 |
+
|
481 |
+
def forward(self, x):
|
482 |
+
y = self.fc11(x)
|
483 |
+
gate = self.fc12(x)
|
484 |
+
if self.activation == F.sigmoid: # Special case for GLU
|
485 |
+
y = F.glu(torch.cat([y, gate], dim=-1), dim=-1)
|
486 |
+
else:
|
487 |
+
y = y * self.activation(gate)
|
488 |
+
y = self.fc2(y)
|
489 |
+
return y if not self.return_residual else (y, x)
|
490 |
+
|
491 |
+
|
492 |
+
def rotate_half(x, interleaved=False):
|
493 |
+
if not interleaved:
|
494 |
+
x1, x2 = x.chunk(2, dim=-1)
|
495 |
+
return torch.cat((-x2, x1), dim=-1)
|
496 |
+
else:
|
497 |
+
x1, x2 = x[..., ::2], x[..., 1::2]
|
498 |
+
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
|
499 |
+
|
500 |
+
|
501 |
+
def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False):
|
502 |
+
"""
|
503 |
+
x: (batch_size, seqlen, nheads, headdim)
|
504 |
+
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
505 |
+
"""
|
506 |
+
ro_dim = cos.shape[-1] * 2
|
507 |
+
assert ro_dim <= x.shape[-1]
|
508 |
+
cos, sin = (
|
509 |
+
cos[offset : offset + x.shape[1]],
|
510 |
+
sin[offset : offset + x.shape[1]],
|
511 |
+
)
|
512 |
+
cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
|
513 |
+
sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
|
514 |
+
return torch.cat(
|
515 |
+
[x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
|
516 |
+
dim=-1,
|
517 |
+
)
|
518 |
+
|
519 |
+
|
520 |
+
class NomicBertRotaryEmbedding(nn.Module):
|
521 |
+
def __init__(
|
522 |
+
self,
|
523 |
+
dim: int,
|
524 |
+
base=10000.0,
|
525 |
+
interleaved=False,
|
526 |
+
scale_base=None,
|
527 |
+
pos_idx_in_fp32=True,
|
528 |
+
device=None,
|
529 |
+
):
|
530 |
+
"""
|
531 |
+
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
532 |
+
of 1st half and 2nd half (GPT-NeoX style).
|
533 |
+
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
|
534 |
+
otherwise they might be in lower precision.
|
535 |
+
This option was added because previously (before 2023-07-02), when we construct
|
536 |
+
the position indices, we use the dtype of self.inv_freq. In most cases this would
|
537 |
+
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
|
538 |
+
self.inv_freq would be bf16, and the position indices are also in bf16.
|
539 |
+
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
|
540 |
+
embeddings for some positions will coincide.
|
541 |
+
To maintain compatibility with models previously trained in pure bf16,
|
542 |
+
we add this option.
|
543 |
+
"""
|
544 |
+
super().__init__()
|
545 |
+
self.dim = dim
|
546 |
+
self.base = float(base)
|
547 |
+
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
548 |
+
# Generate and save the inverse frequency buffer (non trainable)
|
549 |
+
inv_freq = self._compute_inv_freq(device)
|
550 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
551 |
+
self.interleaved = interleaved
|
552 |
+
self.scale_base = scale_base
|
553 |
+
scale = (
|
554 |
+
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
555 |
+
if scale_base is not None
|
556 |
+
else None
|
557 |
+
)
|
558 |
+
self.register_buffer("scale", scale, persistent=False)
|
559 |
+
|
560 |
+
self._seq_len_cached = 0
|
561 |
+
self._cos_cached = None
|
562 |
+
self._sin_cached = None
|
563 |
+
self._cos_k_cached = None
|
564 |
+
self._sin_k_cached = None
|
565 |
+
|
566 |
+
def _compute_inv_freq(self, device=None):
|
567 |
+
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
|
568 |
+
|
569 |
+
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
570 |
+
# Reset the tables if the sequence length has changed,
|
571 |
+
# if we're on a new device (possibly due to tracing for instance),
|
572 |
+
# or if we're switching from inference mode to training
|
573 |
+
if (
|
574 |
+
seqlen > self._seq_len_cached
|
575 |
+
or self._cos_cached is None
|
576 |
+
or self._cos_cached.device != device
|
577 |
+
or self._cos_cached.dtype != dtype
|
578 |
+
or (self.training and self._cos_cached.is_inference())
|
579 |
+
):
|
580 |
+
self._seq_len_cached = seqlen
|
581 |
+
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
|
582 |
+
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
|
583 |
+
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
|
584 |
+
if self.pos_idx_in_fp32:
|
585 |
+
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
586 |
+
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
|
587 |
+
# will be large. Having it in bf16 will lose a lot of precision and cause the
|
588 |
+
# cos & sin output to change significantly.
|
589 |
+
# We want to recompute self.inv_freq if it was not loaded in fp32
|
590 |
+
if self.inv_freq.dtype != torch.float32:
|
591 |
+
inv_freq = self._compute_inv_freq(device=device)
|
592 |
+
else:
|
593 |
+
inv_freq = self.inv_freq
|
594 |
+
else:
|
595 |
+
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
596 |
+
inv_freq = self.inv_freq
|
597 |
+
# Don't do einsum, it converts fp32 to fp16 under AMP
|
598 |
+
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
599 |
+
freqs = torch.outer(t, inv_freq)
|
600 |
+
self._cos_cached = torch.cos(freqs).to(dtype)
|
601 |
+
self._sin_cached = torch.sin(freqs).to(dtype)
|
602 |
+
|
603 |
+
def forward(
|
604 |
+
self,
|
605 |
+
qkv: torch.Tensor,
|
606 |
+
kv: Optional[torch.Tensor] = None,
|
607 |
+
seqlen_offset: Union[int, torch.Tensor] = 0,
|
608 |
+
max_seqlen: Optional[int] = None,
|
609 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
610 |
+
"""
|
611 |
+
qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
|
612 |
+
else it's just q of shape (batch, seqlen, nheads, headdim)
|
613 |
+
kv: (batch, seqlen, 2, nheads, headdim)
|
614 |
+
seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
|
615 |
+
Most commonly used in inference when we have KV cache.
|
616 |
+
If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
|
617 |
+
should pass in max_seqlen, which will update the cos / sin cache up to that length.
|
618 |
+
Apply rotary embedding *inplace* to qkv and / or kv.
|
619 |
+
"""
|
620 |
+
seqlen = qkv.shape[1]
|
621 |
+
if seqlen > self._seq_len_cached:
|
622 |
+
self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
|
623 |
+
elif max_seqlen is not None:
|
624 |
+
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
625 |
+
elif isinstance(seqlen_offset, int):
|
626 |
+
self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
|
627 |
+
|
628 |
+
q_rot = apply_rotary_emb(qkv[:, :, 0], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
|
629 |
+
k_rot = apply_rotary_emb(qkv[:, :, 1], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
|
630 |
+
return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
|
631 |
+
|
632 |
+
|
633 |
+
class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
|
634 |
+
def __init__(self, rotary_scaling_factor, max_position_embeddings, **kwargs):
|
635 |
+
super().__init__(**kwargs)
|
636 |
+
self.rotary_scaling_factor = rotary_scaling_factor
|
637 |
+
self.max_position_embeddings = max_position_embeddings
|
638 |
+
|
639 |
+
def _compute_inv_freq(self, base=None, device=None):
|
640 |
+
if base is None:
|
641 |
+
base = self.base
|
642 |
+
return 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
|
643 |
+
|
644 |
+
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
645 |
+
# Reset the tables if the sequence length has changed,
|
646 |
+
# if we're on a new device (possibly due to tracing for instance),
|
647 |
+
# or if we're switching from inference mode to training
|
648 |
+
if seqlen > self.max_position_embeddings:
|
649 |
+
base = self.base * (
|
650 |
+
(self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1)
|
651 |
+
) ** (self.dim / (self.dim - 2))
|
652 |
+
inv_freq = self._compute_inv_freq(base=base, device=device)
|
653 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
654 |
+
|
655 |
+
if (
|
656 |
+
seqlen > self._seq_len_cached
|
657 |
+
or self._cos_cached is None
|
658 |
+
or self._cos_cached.device != device
|
659 |
+
or self._cos_cached.dtype != dtype
|
660 |
+
or (self.training and self._cos_cached.is_inference())
|
661 |
+
):
|
662 |
+
self._seq_len_cached = seqlen
|
663 |
+
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
|
664 |
+
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
|
665 |
+
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
|
666 |
+
if self.pos_idx_in_fp32:
|
667 |
+
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
668 |
+
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
|
669 |
+
# will be large. Having it in bf16 will lose a lot of precision and cause the
|
670 |
+
# cos & sin output to change significantly.
|
671 |
+
# We want to recompute self.inv_freq if it was not loaded in fp32
|
672 |
+
if self.inv_freq.dtype != torch.float32:
|
673 |
+
if seqlen > self.max_position_embeddings:
|
674 |
+
base = self.base * (
|
675 |
+
(self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)
|
676 |
+
) ** (self.dim / (self.dim - 2))
|
677 |
+
else:
|
678 |
+
base = self.base
|
679 |
+
inv_freq = self._compute_inv_freq(device=device, base=base)
|
680 |
+
else:
|
681 |
+
inv_freq = self.inv_freq
|
682 |
+
else:
|
683 |
+
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
684 |
+
inv_freq = self.inv_freq
|
685 |
+
# Don't do einsum, it converts fp32 to fp16 under AMP
|
686 |
+
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
687 |
+
freqs = torch.outer(t, inv_freq)
|
688 |
+
if self.scale is None:
|
689 |
+
self._cos_cached = torch.cos(freqs).to(dtype)
|
690 |
+
self._sin_cached = torch.sin(freqs).to(dtype)
|
691 |
+
else:
|
692 |
+
power = (
|
693 |
+
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
694 |
+
) / self.scale_base
|
695 |
+
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
696 |
+
# We want the multiplication by scale to happen in fp32
|
697 |
+
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
698 |
+
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
699 |
+
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
700 |
+
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
701 |
+
|
702 |
+
|
703 |
+
class NomicBertAttention(nn.Module):
|
704 |
+
"""Multi-head self-attention and cross-attention"""
|
705 |
+
|
706 |
+
def __init__(
|
707 |
+
self,
|
708 |
+
config,
|
709 |
+
) -> None:
|
710 |
+
"""
|
711 |
+
num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
|
712 |
+
return_residual: whether to return the input x along with the output. This is for
|
713 |
+
performance reason: for post-norm architecture, returning the input allows us
|
714 |
+
to fuse the backward of nn.Linear with the residual connection.
|
715 |
+
"""
|
716 |
+
super().__init__()
|
717 |
+
self.embed_dim = config.n_embd
|
718 |
+
self.use_flash_attn = config.use_flash_attn
|
719 |
+
self.fused_bias_fc = config.fused_bias_fc
|
720 |
+
|
721 |
+
self.num_heads = config.n_head
|
722 |
+
self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
|
723 |
+
assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
|
724 |
+
self.head_dim = self.embed_dim // self.num_heads
|
725 |
+
# we don't really support mqa / gqa for now
|
726 |
+
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
|
727 |
+
|
728 |
+
self.register_buffer(
|
729 |
+
"norm_factor",
|
730 |
+
torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
|
731 |
+
persistent=False,
|
732 |
+
)
|
733 |
+
|
734 |
+
self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
|
735 |
+
if self.rotary_emb_dim > 0:
|
736 |
+
if getattr(config, "rotary_scaling_factor", None):
|
737 |
+
self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
|
738 |
+
dim=self.rotary_emb_dim,
|
739 |
+
base=config.rotary_emb_base,
|
740 |
+
scale_base=config.rotary_emb_scale_base,
|
741 |
+
interleaved=config.rotary_emb_interleaved,
|
742 |
+
rotary_scaling_factor=config.rotary_scaling_factor,
|
743 |
+
max_position_embeddings=config.max_trained_positions,
|
744 |
+
)
|
745 |
+
else:
|
746 |
+
self.rotary_emb = NomicBertRotaryEmbedding(
|
747 |
+
dim=self.rotary_emb_dim,
|
748 |
+
base=config.rotary_emb_base,
|
749 |
+
scale_base=config.rotary_emb_scale_base,
|
750 |
+
interleaved=config.rotary_emb_interleaved,
|
751 |
+
)
|
752 |
+
# bug in xformers: https://github.com/facebookresearch/xformers/issues/841
|
753 |
+
# uses the head dimension instead of the sequence dimension
|
754 |
+
self.rotary_head_dim = getattr(config, "rotary_head_dim", False)
|
755 |
+
|
756 |
+
self.Wqkv = nn.Linear(self.embed_dim, qkv_dim, bias=config.qkv_proj_bias)
|
757 |
+
|
758 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
|
759 |
+
self.causal = config.causal
|
760 |
+
self.drop = nn.Dropout(config.attn_pdrop)
|
761 |
+
|
762 |
+
def forward(
|
763 |
+
self,
|
764 |
+
hidden_states: torch.Tensor,
|
765 |
+
attention_mask: Optional[torch.Tensor] = None,
|
766 |
+
position_ids: Optional[torch.LongTensor] = None,
|
767 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
768 |
+
output_attentions: bool = False,
|
769 |
+
use_cache: bool = False,
|
770 |
+
is_padded_inputs: Optional[bool] = True,
|
771 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
772 |
+
max_seq_len: Optional[int] = None,
|
773 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
774 |
+
|
775 |
+
has_layer_past = past_key_value is not None
|
776 |
+
|
777 |
+
if has_layer_past:
|
778 |
+
past_key_value = past_key_value[0]
|
779 |
+
past_len = past_key_value[1]
|
780 |
+
else:
|
781 |
+
past_len = 0
|
782 |
+
|
783 |
+
qkv = self.Wqkv(hidden_states)
|
784 |
+
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
|
785 |
+
|
786 |
+
past_key_value = (past_key_value, past_len + qkv.size(1)) if use_cache else None
|
787 |
+
|
788 |
+
if self.rotary_emb_dim > 0:
|
789 |
+
if self.rotary_head_dim:
|
790 |
+
qkv = rearrange(qkv, "b s three h d -> b h three s d")
|
791 |
+
qkv = self.rotary_emb(qkv, seqlen_offset=past_len)
|
792 |
+
|
793 |
+
if self.rotary_head_dim:
|
794 |
+
qkv = rearrange(qkv, "b h three s d -> b s three h d")
|
795 |
+
|
796 |
+
query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
|
797 |
+
|
798 |
+
query = query.permute(0, 2, 1, 3)
|
799 |
+
key = key.permute(0, 2, 1, 3)
|
800 |
+
value = value.permute(0, 2, 1, 3)
|
801 |
+
|
802 |
+
attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
|
803 |
+
if attention_mask is not None:
|
804 |
+
attention_scores = attention_scores + attention_mask
|
805 |
+
|
806 |
+
attentions_probs = F.softmax(attention_scores, dim=-1)
|
807 |
+
attentions_probs = self.drop(attentions_probs)
|
808 |
+
|
809 |
+
attn_output = torch.matmul(attentions_probs, value)
|
810 |
+
attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")
|
811 |
+
|
812 |
+
attn_output = self.out_proj(attn_output)
|
813 |
+
|
814 |
+
return attn_output
|
815 |
+
|
816 |
+
|
817 |
+
class NomicBertBlock(NomicBertPreTrainedModel):
|
818 |
+
def __init__(
|
819 |
+
self,
|
820 |
+
config,
|
821 |
+
):
|
822 |
+
super().__init__(config=config)
|
823 |
+
self.prenorm = config.prenorm
|
824 |
+
self.fused_dropout_add_ln = config.fused_dropout_add_ln
|
825 |
+
|
826 |
+
self.attn = NomicBertAttention(config)
|
827 |
+
activation = (
|
828 |
+
F.sigmoid
|
829 |
+
if config.activation_function == "glu"
|
830 |
+
else (F.silu if config.activation_function == "swiglu" else F.gelu)
|
831 |
+
)
|
832 |
+
if config.activation_function in ["glu", "swiglu", "geglu"]:
|
833 |
+
self.mlp = NomciBertGatedMLP(
|
834 |
+
config.n_embd,
|
835 |
+
hidden_features=config.n_inner,
|
836 |
+
bias1=config.mlp_fc1_bias,
|
837 |
+
bias2=config.mlp_fc2_bias,
|
838 |
+
activation=activation,
|
839 |
+
fused_bias_fc=config.fused_bias_fc,
|
840 |
+
)
|
841 |
+
else:
|
842 |
+
self.mlp = NomicBertMLP(
|
843 |
+
config.n_embd,
|
844 |
+
hidden_features=config.n_inner,
|
845 |
+
bias1=config.mlp_fc1_bias,
|
846 |
+
bias2=config.mlp_fc2_bias,
|
847 |
+
activation=activation,
|
848 |
+
fused_bias_fc=config.fused_bias_fc,
|
849 |
+
)
|
850 |
+
|
851 |
+
self.dropout1 = nn.Dropout(config.resid_pdrop)
|
852 |
+
self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
853 |
+
self.norm2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
854 |
+
self.dropout2 = nn.Dropout(config.resid_pdrop)
|
855 |
+
|
856 |
+
def forward(
|
857 |
+
self,
|
858 |
+
hidden_states: torch.Tensor,
|
859 |
+
hidden_states2: torch.Tensor,
|
860 |
+
residual: Optional[torch.Tensor] = None,
|
861 |
+
attention_mask: Optional[torch.Tensor] = None,
|
862 |
+
position_ids: Optional[torch.LongTensor] = None,
|
863 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
864 |
+
is_padded_inputs: Optional[bool] = True,
|
865 |
+
output_attentions: Optional[bool] = False,
|
866 |
+
use_cache: Optional[bool] = False,
|
867 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
868 |
+
max_seq_len: Optional[int] = None,
|
869 |
+
):
|
870 |
+
r"""Pass the input through the encoder layer.
|
871 |
+
|
872 |
+
Args:
|
873 |
+
hidden_states: the sequence to the encoder layer (required).
|
874 |
+
residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
|
875 |
+
mixer_subset: for cross-attention only. If not None, will take a subset of x
|
876 |
+
before applying the query projection. Useful for e.g., ViT where we only care
|
877 |
+
about the CLS token in the last layer.
|
878 |
+
"""
|
879 |
+
if self.prenorm:
|
880 |
+
dropped = self.dropout1(hidden_states)
|
881 |
+
residual = (dropped + residual) if residual is not None else dropped
|
882 |
+
hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
|
883 |
+
hidden_states = self.attn(
|
884 |
+
hidden_states,
|
885 |
+
attention_mask=attention_mask,
|
886 |
+
is_padded_inputs=is_padded_inputs,
|
887 |
+
cu_seqlens=cu_seqlens,
|
888 |
+
max_seq_len=max_seq_len,
|
889 |
+
)
|
890 |
+
|
891 |
+
dropped = self.dropout2(hidden_states)
|
892 |
+
residual = (dropped + residual) if residual is not None else dropped
|
893 |
+
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
|
894 |
+
hidden_states = self.mlp(hidden_states)
|
895 |
+
|
896 |
+
return hidden_states, None, residual
|
897 |
+
else:
|
898 |
+
assert residual is None
|
899 |
+
attn_outputs = self.attn(
|
900 |
+
hidden_states,
|
901 |
+
attention_mask=attention_mask,
|
902 |
+
is_padded_inputs=is_padded_inputs,
|
903 |
+
cu_seqlens=cu_seqlens,
|
904 |
+
max_seq_len=max_seq_len,
|
905 |
+
)
|
906 |
+
hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
|
907 |
+
mlp_out = self.mlp(hidden_states)
|
908 |
+
|
909 |
+
hidden_states = self.norm2((self.dropout2(mlp_out) + hidden_states).to(dtype=self.norm2.weight.dtype))
|
910 |
+
return hidden_states, None, None
|
911 |
+
|
912 |
+
|
913 |
+
class NomicBertEncoder(nn.Module):
|
914 |
+
def __init__(self, config: GPT2Config):
|
915 |
+
super().__init__()
|
916 |
+
self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
|
917 |
+
self.gradient_checkpointing = False
|
918 |
+
self.config = config
|
919 |
+
|
920 |
+
def forward(
|
921 |
+
self,
|
922 |
+
hidden_states: torch.LongTensor = None,
|
923 |
+
attention_mask: Optional[torch.Tensor] = None,
|
924 |
+
position_ids: Optional[torch.LongTensor] = None,
|
925 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
926 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
927 |
+
use_cache: Optional[bool] = None,
|
928 |
+
output_attentions: Optional[bool] = None,
|
929 |
+
output_hidden_states: Optional[bool] = None,
|
930 |
+
return_dict: Optional[bool] = None,
|
931 |
+
is_padded_inputs: Optional[bool] = True,
|
932 |
+
):
|
933 |
+
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
934 |
+
This means that we only compute the last layer output for these tokens.
|
935 |
+
subset_mask: (batch, seqlen), dtype=torch.bool
|
936 |
+
"""
|
937 |
+
hidden_states2 = None
|
938 |
+
residual = None
|
939 |
+
|
940 |
+
for _, layer in enumerate(self.layers):
|
941 |
+
if self.gradient_checkpointing and self.training:
|
942 |
+
|
943 |
+
def create_custom_forward(module):
|
944 |
+
def custom_forward(*inputs):
|
945 |
+
# None for past_key_value
|
946 |
+
return module(*inputs)
|
947 |
+
|
948 |
+
return custom_forward
|
949 |
+
|
950 |
+
hidden_states, hidden_states2, residual = torch.utils.checkpoint.checkpoint(
|
951 |
+
create_custom_forward(layer),
|
952 |
+
hidden_states,
|
953 |
+
hidden_states2,
|
954 |
+
residual,
|
955 |
+
attention_mask,
|
956 |
+
None,
|
957 |
+
None,
|
958 |
+
is_padded_inputs,
|
959 |
+
# if you freeze ANY layers, you need `use_reentrant=False`
|
960 |
+
# https://github.com/huggingface/transformers/issues/21381
|
961 |
+
# https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/7
|
962 |
+
use_reentrant=False,
|
963 |
+
)
|
964 |
+
|
965 |
+
else:
|
966 |
+
hidden_states, hidden_states2, residual = layer(
|
967 |
+
hidden_states,
|
968 |
+
hidden_states2,
|
969 |
+
residual,
|
970 |
+
attention_mask,
|
971 |
+
position_ids,
|
972 |
+
None,
|
973 |
+
is_padded_inputs,
|
974 |
+
output_attentions,
|
975 |
+
use_cache,
|
976 |
+
)
|
977 |
+
return hidden_states
|
978 |
+
|
979 |
+
|
980 |
+
class NomicBertPooler(nn.Module):
|
981 |
+
def __init__(self, config):
|
982 |
+
super().__init__()
|
983 |
+
self.dense = nn.Linear(config.n_embd, config.n_embd)
|
984 |
+
self.activation = nn.Tanh()
|
985 |
+
|
986 |
+
def forward(self, hidden_states, pool=True):
|
987 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
988 |
+
# to the first token.
|
989 |
+
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
990 |
+
pooled_output = self.dense(first_token_tensor)
|
991 |
+
pooled_output = self.activation(pooled_output)
|
992 |
+
return pooled_output
|
993 |
+
|
994 |
+
|
995 |
+
class NomicBertPredictionHeadTransform(nn.Module):
|
996 |
+
def __init__(self, config):
|
997 |
+
super().__init__()
|
998 |
+
self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)
|
999 |
+
approximate = "tanh" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
|
1000 |
+
if config.activation_function == "swiglu":
|
1001 |
+
self.transform_act_fn = F.silu
|
1002 |
+
else:
|
1003 |
+
self.transform_act_fn = nn.GELU(approximate=approximate)
|
1004 |
+
|
1005 |
+
self.layer_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
1006 |
+
|
1007 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
1008 |
+
hidden_states = self.dense(hidden_states)
|
1009 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
1010 |
+
hidden_states = self.layer_norm(hidden_states)
|
1011 |
+
|
1012 |
+
return hidden_states
|
1013 |
+
|
1014 |
+
|
1015 |
+
class NomicBertLMPredictionHead(nn.Module):
|
1016 |
+
def __init__(self, config):
|
1017 |
+
super().__init__()
|
1018 |
+
|
1019 |
+
self.transform = NomicBertPredictionHeadTransform(config)
|
1020 |
+
|
1021 |
+
self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=config.mlp_fc1_bias)
|
1022 |
+
|
1023 |
+
def forward(self, hidden_states):
|
1024 |
+
hidden_states = self.transform(hidden_states)
|
1025 |
+
hidden_states = self.decoder(hidden_states)
|
1026 |
+
return hidden_states
|
1027 |
+
|
1028 |
+
|
1029 |
+
class NomicBertPreTrainingHeads(nn.Module):
|
1030 |
+
def __init__(self, config):
|
1031 |
+
super().__init__()
|
1032 |
+
self.predictions = NomicBertLMPredictionHead(config)
|
1033 |
+
|
1034 |
+
def forward(self, sequence_output):
|
1035 |
+
prediction_scores = self.predictions(sequence_output)
|
1036 |
+
return prediction_scores
|
1037 |
+
|
1038 |
+
|
1039 |
+
class NomicBertModel(NomicBertPreTrainedModel):
|
1040 |
+
def __init__(self, config: GPT2Config, add_pooling_layer=True):
|
1041 |
+
super().__init__(config)
|
1042 |
+
self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
1043 |
+
if config.vocab_size % self.pad_vocab_size_multiple != 0:
|
1044 |
+
config.vocab_size += self.pad_vocab_size_multiple - (config.vocab_size % self.pad_vocab_size_multiple)
|
1045 |
+
|
1046 |
+
assert config.activation_function in [
|
1047 |
+
"gelu",
|
1048 |
+
"gelu_new",
|
1049 |
+
"gelu_fast",
|
1050 |
+
"gelu_pytorch_tanh",
|
1051 |
+
"swiglu",
|
1052 |
+
"geglu",
|
1053 |
+
"glu",
|
1054 |
+
]
|
1055 |
+
|
1056 |
+
self.embeddings = NomicBertEmbeddings(config)
|
1057 |
+
self.emb_drop = nn.Dropout(config.resid_pdrop)
|
1058 |
+
self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
1059 |
+
self.encoder = NomicBertEncoder(config)
|
1060 |
+
self.pooler = NomicBertPooler(config) if add_pooling_layer else None
|
1061 |
+
|
1062 |
+
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
1063 |
+
|
1064 |
+
def forward(
|
1065 |
+
self,
|
1066 |
+
input_ids,
|
1067 |
+
attention_mask=None,
|
1068 |
+
position_ids=None,
|
1069 |
+
token_type_ids=None,
|
1070 |
+
return_dict=None,
|
1071 |
+
matryoshka_dim=None,
|
1072 |
+
):
|
1073 |
+
if token_type_ids is None:
|
1074 |
+
token_type_ids = torch.zeros_like(input_ids)
|
1075 |
+
hidden_states = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
1076 |
+
hidden_states = self.emb_ln(hidden_states)
|
1077 |
+
hidden_states = self.emb_drop(hidden_states)
|
1078 |
+
|
1079 |
+
attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
|
1080 |
+
sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)
|
1081 |
+
|
1082 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
1083 |
+
|
1084 |
+
if matryoshka_dim:
|
1085 |
+
sequence_output = sequence_output[:, :matryoshka_dim]
|
1086 |
+
|
1087 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
1088 |
+
last_hidden_state=sequence_output,
|
1089 |
+
pooler_output=pooled_output,
|
1090 |
+
)
|
1091 |
+
|
1092 |
+
|
1093 |
+
class NomicBertForPreTraining(NomicBertPreTrainedModel):
|
1094 |
+
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
1095 |
+
|
1096 |
+
def __init__(self, config: GPT2Config):
|
1097 |
+
super().__init__(config)
|
1098 |
+
|
1099 |
+
self.bert = NomicBertModel(config, add_pooling_layer=getattr(config, "add_pooling_layer", False))
|
1100 |
+
self.cls = NomicBertPreTrainingHeads(config)
|
1101 |
+
self.mlm_loss = nn.CrossEntropyLoss()
|
1102 |
+
|
1103 |
+
# Initialize weights and apply final processing
|
1104 |
+
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
1105 |
+
self.tie_weights()
|
1106 |
+
|
1107 |
+
def tie_weights(self):
|
1108 |
+
self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
|
1109 |
+
|
1110 |
+
def forward(
|
1111 |
+
self,
|
1112 |
+
input_ids,
|
1113 |
+
position_ids=None,
|
1114 |
+
token_type_ids=None,
|
1115 |
+
attention_mask=None,
|
1116 |
+
labels=None,
|
1117 |
+
):
|
1118 |
+
"""
|
1119 |
+
If labels are provided, they must be -100 for masked out tokens (as specified in the attention
|
1120 |
+
mask).
|
1121 |
+
Outputs:
|
1122 |
+
if `labels` and `next_sentence_label` are not `None`:
|
1123 |
+
Outputs the total_loss which is the sum of the masked language modeling loss and the next
|
1124 |
+
sentence classification loss.
|
1125 |
+
if `labels` or `next_sentence_label` is `None`:
|
1126 |
+
Outputs a tuple comprising
|
1127 |
+
- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
|
1128 |
+
- the next sentence classification logits of shape [batch_size, 2].
|
1129 |
+
|
1130 |
+
"""
|
1131 |
+
outputs = self.bert(
|
1132 |
+
input_ids,
|
1133 |
+
position_ids=position_ids,
|
1134 |
+
token_type_ids=token_type_ids,
|
1135 |
+
attention_mask=attention_mask.bool() if attention_mask is not None else None,
|
1136 |
+
)
|
1137 |
+
sequence_output, _ = outputs.last_hidden_state, outputs.pooler_output
|
1138 |
+
|
1139 |
+
prediction_scores = self.cls(sequence_output)
|
1140 |
+
|
1141 |
+
total_loss = None
|
1142 |
+
if labels is not None:
|
1143 |
+
masked_lm_loss = self.mlm_loss(
|
1144 |
+
rearrange(prediction_scores, "... v -> (...) v"),
|
1145 |
+
rearrange(labels, "... -> (...)"),
|
1146 |
+
)
|
1147 |
+
total_loss = masked_lm_loss.float()
|
1148 |
+
|
1149 |
+
return MaskedLMOutput(
|
1150 |
+
loss=total_loss,
|
1151 |
+
logits=prediction_scores,
|
1152 |
+
hidden_states=outputs.hidden_states,
|
1153 |
+
attentions=None,
|
1154 |
+
)
|
1155 |
+
|
1156 |
+
|
1157 |
+
class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
|
1158 |
+
def __init__(self, config):
|
1159 |
+
super().__init__(config)
|
1160 |
+
self.num_labels = config.num_labels
|
1161 |
+
self.config = config
|
1162 |
+
|
1163 |
+
self.bert = NomicBertModel(config)
|
1164 |
+
classifier_dropout = getattr(config, "classifier_dropout", config.embd_pdrop)
|
1165 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
1166 |
+
self.classifier = nn.Linear(config.n_embd, config.num_labels)
|
1167 |
+
|
1168 |
+
# Initialize weights and apply final processing
|
1169 |
+
self.post_init()
|
1170 |
+
|
1171 |
+
def forward(
|
1172 |
+
self,
|
1173 |
+
input_ids: Optional[torch.Tensor] = None,
|
1174 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1175 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
1176 |
+
position_ids: Optional[torch.Tensor] = None,
|
1177 |
+
head_mask: Optional[torch.Tensor] = None,
|
1178 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1179 |
+
labels: Optional[torch.Tensor] = None,
|
1180 |
+
output_attentions: Optional[bool] = None,
|
1181 |
+
output_hidden_states: Optional[bool] = None,
|
1182 |
+
return_dict: Optional[bool] = None,
|
1183 |
+
):
|
1184 |
+
r"""
|
1185 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1186 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1187 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1188 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1189 |
+
"""
|
1190 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1191 |
+
outputs = self.bert(
|
1192 |
+
input_ids,
|
1193 |
+
position_ids=position_ids,
|
1194 |
+
token_type_ids=token_type_ids,
|
1195 |
+
attention_mask=attention_mask.bool() if attention_mask is not None else None,
|
1196 |
+
)
|
1197 |
+
|
1198 |
+
pooled_output = outputs[1]
|
1199 |
+
|
1200 |
+
pooled_output = self.dropout(pooled_output)
|
1201 |
+
logits = self.classifier(pooled_output)
|
1202 |
+
|
1203 |
+
loss = None
|
1204 |
+
if labels is not None:
|
1205 |
+
if self.config.problem_type is None:
|
1206 |
+
if self.num_labels == 1:
|
1207 |
+
self.config.problem_type = "regression"
|
1208 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1209 |
+
self.config.problem_type = "single_label_classification"
|
1210 |
+
else:
|
1211 |
+
self.config.problem_type = "multi_label_classification"
|
1212 |
+
|
1213 |
+
if self.config.problem_type == "regression":
|
1214 |
+
loss_fct = nn.MSELoss()
|
1215 |
+
if self.num_labels == 1:
|
1216 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
1217 |
+
else:
|
1218 |
+
loss = loss_fct(logits, labels)
|
1219 |
+
elif self.config.problem_type == "single_label_classification":
|
1220 |
+
loss_fct = nn.CrossEntropyLoss()
|
1221 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1222 |
+
elif self.config.problem_type == "multi_label_classification":
|
1223 |
+
loss_fct = nn.BCEWithLogitsLoss()
|
1224 |
+
loss = loss_fct(logits, labels)
|
1225 |
+
if not return_dict:
|
1226 |
+
output = (logits,) + outputs[2:]
|
1227 |
+
return ((loss,) + output) if loss is not None else output
|
1228 |
+
|
1229 |
+
return SequenceClassifierOutput(
|
1230 |
+
loss=loss,
|
1231 |
+
logits=logits,
|
1232 |
+
hidden_states=outputs.hidden_states,
|
1233 |
+
attentions=outputs.attentions,
|
1234 |
+
)
|
modules.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"idx": 0,
|
4 |
+
"name": "0",
|
5 |
+
"path": "",
|
6 |
+
"type": "sentence_transformers.models.Transformer"
|
7 |
+
},
|
8 |
+
{
|
9 |
+
"idx": 1,
|
10 |
+
"name": "1",
|
11 |
+
"path": "1_Pooling",
|
12 |
+
"type": "sentence_transformers.models.Pooling"
|
13 |
+
}
|
14 |
+
]
|
sentence_bert_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"max_seq_length": 8192,
|
3 |
+
"do_lower_case": false
|
4 |
+
}
|
special_tokens_map.json
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": {
|
3 |
+
"content": "[CLS]",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"mask_token": {
|
10 |
+
"content": "[MASK]",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"pad_token": {
|
17 |
+
"content": "[PAD]",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
},
|
23 |
+
"sep_token": {
|
24 |
+
"content": "[SEP]",
|
25 |
+
"lstrip": false,
|
26 |
+
"normalized": false,
|
27 |
+
"rstrip": false,
|
28 |
+
"single_word": false
|
29 |
+
},
|
30 |
+
"unk_token": {
|
31 |
+
"content": "[UNK]",
|
32 |
+
"lstrip": false,
|
33 |
+
"normalized": false,
|
34 |
+
"rstrip": false,
|
35 |
+
"single_word": false
|
36 |
+
}
|
37 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"0": {
|
4 |
+
"content": "[PAD]",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"100": {
|
12 |
+
"content": "[UNK]",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"101": {
|
20 |
+
"content": "[CLS]",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
},
|
27 |
+
"102": {
|
28 |
+
"content": "[SEP]",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": false,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false,
|
33 |
+
"special": true
|
34 |
+
},
|
35 |
+
"103": {
|
36 |
+
"content": "[MASK]",
|
37 |
+
"lstrip": false,
|
38 |
+
"normalized": false,
|
39 |
+
"rstrip": false,
|
40 |
+
"single_word": false,
|
41 |
+
"special": true
|
42 |
+
}
|
43 |
+
},
|
44 |
+
"clean_up_tokenization_spaces": true,
|
45 |
+
"cls_token": "[CLS]",
|
46 |
+
"do_lower_case": true,
|
47 |
+
"mask_token": "[MASK]",
|
48 |
+
"max_length": 8192,
|
49 |
+
"model_max_length": 8192,
|
50 |
+
"pad_to_multiple_of": null,
|
51 |
+
"pad_token": "[PAD]",
|
52 |
+
"pad_token_type_id": 0,
|
53 |
+
"padding_side": "right",
|
54 |
+
"sep_token": "[SEP]",
|
55 |
+
"stride": 0,
|
56 |
+
"strip_accents": null,
|
57 |
+
"tokenize_chinese_chars": true,
|
58 |
+
"tokenizer_class": "BertTokenizer",
|
59 |
+
"truncation_side": "right",
|
60 |
+
"truncation_strategy": "longest_first",
|
61 |
+
"unk_token": "[UNK]"
|
62 |
+
}
|
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|