“SufurElite” committed on
Commit eff5003
1 Parent(s): 47af08a
ELC_ParserBERT_10M_textonly_predictions.json.gz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:62805b3ad647effbdd913914d89158d728eec3c044ba4daf136b4f245d85b8f4
+ size 1184499
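The added file is a Git LFS pointer, so the gzipped predictions themselves live in LFS storage rather than in this diff. A minimal sketch of one way to fetch and read them, assuming the `huggingface_hub` package, a placeholder repository id, and that the payload is a single JSON document (none of which is part of this commit):

```python
import gzip
import json

from huggingface_hub import hf_hub_download

# Hypothetical repository id -- substitute the actual model repo this commit belongs to.
path = hf_hub_download(
    repo_id="<namespace>/ELC-ParserBERT-10M",
    filename="ELC_ParserBERT_10M_textonly_predictions.json.gz",
)

# The LFS pointer above records a ~1.2 MB gzip-compressed JSON payload.
with gzip.open(path, "rt", encoding="utf-8") as f:
    predictions = json.load(f)
```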
LICENSE ADDED
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,95 @@
  ---
  license: apache-2.0
  ---
+
+ # ELC-ParserBERT
+
+ This model is an adaptation of the [Every Layer Counts BERT model](<https://aclanthology.org/2023.conll-babylm.20/>) that incorporates the `Parser Network` from [StructFormer](<https://arxiv.org/abs/2012.00857>). It was trained for the [BabyLM 2024 challenge](https://babylm.github.io/index.html)'s Strict-Small track.
+
+ ## Dataset
+
+ The training data for the challenge can be accessed through OSF [here](https://osf.io/ad7qg/). This model was trained on the 10M-token training dataset.
+
+ ### Order in Pretraining
+
+ After the data is segmented, the segments are ordered by increasing difficulty according to the flesch_reading_ease metric. This ordering is kept when the shuffle flag is omitted during training and discarded when shuffling is enabled; this model was trained with shuffling enabled.
+
+ ## Hyperparameters
+
+ ### Base Model
+
+ | Hyperparameter | Value |
+ | -------------- | ----- |
+ | Initial learning rate | 5e-3 |
+ | Batch size | 256 |
+ | Steps | 13495 |
+ | shuffled | True |
+ | attention_probs_dropout_prob | 0.1 |
+ | classifier_dropout | 0.2 |
+ | hidden_dropout_prob | 0.1 |
+ | hidden_size | 384 |
+ | intermediate_size | 1024 |
+ | layer_norm_eps | 1e-07 |
+ | max_position_embeddings | 512 |
+ | num_attention_heads | 6 |
+ | num_hidden_layers | 12 |
+ | vocab_size | 16384 |
+ | n_parser_layers | 4 |
+ | parser_conv_size | 9 |
+
+ ### Fine-tuning
+
+ The fine-tuning hyperparameters were left unchanged from the organizers' defaults, except for following the patience approach used by ELC-BERT in last year's challenge; in particular:
+
+ | Hyperparameter | Value |
+ | -------------- | ----- |
+ | Initial learning rate | 5e-5 |
+ | Batch size | 64 |
+ | Maximum epochs | 10 |
+ | Evaluate every (epochs) | 1 |
+ | Patience | 10 (for CoLA, MRPC, RTE, BoolQ, MultiRC, and WSC), 100 (for MNLI, MNLI-MM, QQP, QNLI, and SST-2) |
+ | Seed | 12 |
+
+ ## Credit
+
+ As mentioned above, this model is an adaptation of Every Layer Counts (ELC) BERT and StructFormer; the citations and code repositories for both are listed below.
+
+ * StructFormer
+ * [StructFormer Github](<https://github.com/google-research/google-research/tree/master/structformer>)
+
+ * ```bibtex
+ @misc{shen2020structformer,
+ title={StructFormer: Joint Unsupervised Induction of Dependency and Constituency Structure from Masked Language Modeling},
+ author={Yikang Shen and Yi Tay and Che Zheng and Dara Bahri and Donald Metzler and Aaron Courville},
+ year={2020},
+ eprint={2012.00857},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}}```
+ * ELC-BERT:
+ * [ELC-BERT Github](<https://github.com/ltgoslo/elc-bert>)
+ * [ELC-BERT 10M Hugging Face](https://huggingface.co/lgcharpe/ELC_BERT_small_baby_10M)
+ * ```bibtex
+ @inproceedings{georges-gabriel-charpentier-samuel-2023-layers,
+ title = "Not all layers are equally as important: Every Layer Counts {BERT}",
+ author = "Georges Gabriel Charpentier, Lucas and
+ Samuel, David",
+ editor = "Warstadt, Alex and
+ Mueller, Aaron and
+ Choshen, Leshem and
+ Wilcox, Ethan and
+ Zhuang, Chengxu and
+ Ciro, Juan and
+ Mosquera, Rafael and
+ Paranjabe, Bhargavi and
+ Williams, Adina and
+ Linzen, Tal and
+ Cotterell, Ryan",
+ booktitle = "Proceedings of the BabyLM Challenge at the 27th Conference on Computational Natural Language Learning",
+ month = dec,
+ year = "2023",
+ address = "Singapore",
+ publisher = "Association for Computational Linguistics",
+ url = "https://aclanthology.org/2023.conll-babylm.20",
+ doi = "10.18653/v1/2023.conll-babylm.20",
+ pages = "238--252",
+ }```
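The curriculum ordering described in the README can be sketched in a few lines. This is a minimal sketch only: it assumes the `textstat` package (not part of this repository), that `segments` is a plain list of text strings, and that "increasing difficulty" means sorting from high to low Flesch reading ease, since a higher score indicates easier text.

```python
import random

import textstat  # assumed dependency; not part of this commit


def order_segments(segments, shuffle=False, seed=0):
    """Order text segments from easiest to hardest.

    flesch_reading_ease is higher for easier text, so sorting scores in
    descending order corresponds to increasing difficulty.
    """
    ordered = sorted(segments, key=textstat.flesch_reading_ease, reverse=True)
    if shuffle:
        # This checkpoint was trained with shuffling enabled, which discards
        # the curriculum ordering.
        random.Random(seed).shuffle(ordered)
    return ordered
```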
__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "architectures": [
+     "LtgBertForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "auto_map": {
+     "AutoConfig": "configuration_ltgbert.LtgBertConfig",
+     "AutoModelForMaskedLM": "modeling_ltgbert.LtgBertForMaskedLM",
+     "AutoModelForSequenceClassification": "modeling_ltgbert.LtgBertForSequenceClassification"
+   },
+   "classifier_dropout": 0.2,
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 384,
+   "intermediate_size": 1024,
+   "layer_norm_eps": 1e-07,
+   "max_position_embeddings": 512,
+   "model_type": "ltgbert",
+   "num_attention_heads": 6,
+   "num_hidden_layers": 12,
+   "output_all_encoded_layers": true,
+   "pad_token_id": 3,
+   "position_bucket_size": 32,
+   "torch_dtype": "float32",
+   "transformers_version": "4.26.0",
+   "vocab_size": 16384
+ }
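Because `config.json` registers the custom classes through `auto_map`, the checkpoint can be loaded with the `Auto*` classes as long as remote code is trusted. A minimal sketch, assuming a placeholder repository id and that a tokenizer is published alongside these files (neither is shown in this commit):

```python
from transformers import AutoConfig, AutoModelForMaskedLM

repo_id = "<namespace>/ELC-ParserBERT-10M"  # placeholder, not part of this commit

# trust_remote_code makes transformers import configuration_ltgbert.py and
# modeling_ltgbert.py from the repository, as registered in auto_map above.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
```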
configuration_ltgbert.py ADDED
@@ -0,0 +1,106 @@
+ # coding=utf-8
+ # Copyright 2023 Language Technology Group from University of Oslo and The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """ LTG-BERT configuration """
+
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ LTG_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+     "bnc-bert-span": "https://huggingface.co/ltg/bnc-bert-span",
+     "bnc-bert-span-2x": "https://huggingface.co/ltg/bnc-bert-span-2x",
+     "bnc-bert-span-0.5x": "https://huggingface.co/ltg/bnc-bert-span-0.5x",
+     "bnc-bert-span-0.25x": "https://huggingface.co/ltg/bnc-bert-span-0.25x",
+     "bnc-bert-span-order": "https://huggingface.co/ltg/bnc-bert-span-order",
+     "bnc-bert-span-document": "https://huggingface.co/ltg/bnc-bert-span-document",
+     "bnc-bert-span-word": "https://huggingface.co/ltg/bnc-bert-span-word",
+     "bnc-bert-span-subword": "https://huggingface.co/ltg/bnc-bert-span-subword",
+     "norbert3-xs": "https://huggingface.co/ltg/norbert3-xs/config.json",
+     "norbert3-small": "https://huggingface.co/ltg/norbert3-small/config.json",
+     "norbert3-base": "https://huggingface.co/ltg/norbert3-base/config.json",
+     "norbert3-large": "https://huggingface.co/ltg/norbert3-large/config.json",
+     "norbert3-oversampled-base": "https://huggingface.co/ltg/norbert3-oversampled-base/config.json",
+     "norbert3-ncc-base": "https://huggingface.co/ltg/norbert3-ncc-base/config.json",
+     "norbert3-nak-base": "https://huggingface.co/ltg/norbert3-nak-base/config.json",
+     "norbert3-nb-base": "https://huggingface.co/ltg/norbert3-nb-base/config.json",
+     "norbert3-wiki-base": "https://huggingface.co/ltg/norbert3-wiki-base/config.json",
+     "norbert3-c4-base": "https://huggingface.co/ltg/norbert3-c4-base/config.json",
+ }
+
+
+ class LtgBertConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`LtgBertModel`]. It is used to
+     instantiate an LTG-BERT model according to the specified arguments, defining the model architecture.
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+     Args:
+         vocab_size (`int`, *optional*, defaults to 16384):
+             Vocabulary size of the LTG-BERT model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`LtgBertModel`].
+         hidden_size (`int`, *optional*, defaults to 768):
+             Dimensionality of the encoder layers and the pooler layer.
+         num_hidden_layers (`int`, *optional*, defaults to 12):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 12):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         intermediate_size (`int`, *optional*, defaults to 2048):
+             Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+         hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+         attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+             The dropout ratio for the attention probabilities.
+         max_position_embeddings (`int`, *optional*, defaults to 512):
+             The maximum sequence length that this model might ever be used with. Typically set this to something large
+             just in case (e.g., 512 or 1024 or 2048).
+         layer_norm_eps (`float`, *optional*, defaults to 1e-7):
+             The epsilon used by the layer normalization layers.
+         classifier_dropout (`float`, *optional*):
+             The dropout ratio for the classification head.
+     """
+     model_type = "ltgbert"
+
+     def __init__(
+         self,
+         vocab_size=16384,
+         attention_probs_dropout_prob=0.1,
+         hidden_dropout_prob=0.1,
+         hidden_size=768,
+         intermediate_size=2048,
+         max_position_embeddings=512,
+         position_bucket_size=32,
+         num_attention_heads=12,
+         num_hidden_layers=12,
+         layer_norm_eps=1.0e-7,
+         pad_token_id=4,
+         output_all_encoded_layers=True,
+         classifier_dropout=None,
+         **kwargs,
+     ):
+         super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.output_all_encoded_layers = output_all_encoded_layers
+         self.position_bucket_size = position_bucket_size
+         self.layer_norm_eps = layer_norm_eps
+         self.classifier_dropout = classifier_dropout
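As a sketch of how the values in `config.json` map onto this class, the checkpoint's architecture could also be instantiated directly; the parser hyperparameters listed in the README (n_parser_layers=4, parser_conv_size=9) are not stored in the config but match the `ParserNetwork` defaults used by `modeling_ltgbert.py`. Treat this as an illustrative example rather than code from the repository.

```python
from configuration_ltgbert import LtgBertConfig

# Values mirror the config.json added in this commit.
config = LtgBertConfig(
    vocab_size=16384,
    hidden_size=384,
    intermediate_size=1024,
    num_hidden_layers=12,
    num_attention_heads=6,
    max_position_embeddings=512,
    position_bucket_size=32,
    layer_norm_eps=1e-07,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    classifier_dropout=0.2,
    pad_token_id=3,
)
```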
modeling_ltgbert.py ADDED
@@ -0,0 +1,1294 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Language Technology Group from University of Oslo and The HuggingFace Inc. team.
3
+ # And Copyright 2024 The Google Research Authors.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # Base implementation of the LTG-BERT/ELC-BERT Model is from Language Technology Group from University of Oslo and The HuggingFace Inc., Team
18
+ # The StructFormer components is from The Google Research Authors - the authors were Yikang Shen and Yi Tay and Che Zheng and Dara Bahri and Donald Metzler and Aaron Courville
19
+ # (and the code can be from here: https://github.com/google-research/google-research/tree/master/structformer), both were using Apache license, Version 2.0
20
+
21
+ """ PyTorch LTG-(ELC)-ParserBERT model."""
22
+
23
+
24
+ import math
25
+ from typing import List, Optional, Tuple, Union
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ from torch.utils import checkpoint
31
+
32
+ from .configuration_ltgbert import LtgBertConfig
33
+ from transformers.modeling_utils import PreTrainedModel
34
+ from transformers.activations import gelu_new
35
+ from transformers.modeling_outputs import (
36
+ MaskedLMOutput,
37
+ MultipleChoiceModelOutput,
38
+ QuestionAnsweringModelOutput,
39
+ SequenceClassifierOutput,
40
+ TokenClassifierOutput,
41
+ BaseModelOutput,
42
+ )
43
+ from transformers.pytorch_utils import softmax_backward_data
44
+ from transformers.utils import (
45
+ add_start_docstrings,
46
+ add_start_docstrings_to_model_forward,
47
+ )
48
+
49
+
50
+ _CHECKPOINT_FOR_DOC = "ltg/bnc-bert-span"
51
+ _CONFIG_FOR_DOC = "LtgBertConfig"
52
+
53
+
54
+ LTG_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
55
+ "bnc-bert-span",
56
+ "bnc-bert-span-2x",
57
+ "bnc-bert-span-0.5x",
58
+ "bnc-bert-span-0.25x",
59
+ "bnc-bert-span-order",
60
+ "bnc-bert-span-document",
61
+ "bnc-bert-span-word",
62
+ "bnc-bert-span-subword",
63
+ "norbert3-xs",
64
+ "norbert3-small",
65
+ "norbert3-base",
66
+ "norbert3-large",
67
+ "norbert3-oversampled-base",
68
+ "norbert3-ncc-base",
69
+ "norbert3-nak-base",
70
+ "norbert3-nb-base",
71
+ "norbert3-wiki-base",
72
+ "norbert3-c4-base",
73
+ ]
74
+
75
+
76
+ class Conv1d(nn.Module):
77
+ """1D convolution layer."""
78
+
79
+ def __init__(self, hidden_size, kernel_size, dilation=1):
80
+ """Initialization.
81
+
82
+ Args:
83
+ hidden_size: dimension of input embeddings
84
+ kernel_size: convolution kernel size
85
+ dilation: the spacing between the kernel points
86
+ """
87
+ super(Conv1d, self).__init__()
88
+
89
+ if kernel_size % 2 == 0:
90
+ padding = (kernel_size // 2) * dilation
91
+ self.shift = True
92
+ else:
93
+ padding = ((kernel_size - 1) // 2) * dilation
94
+ self.shift = False
95
+ self.conv = nn.Conv1d(
96
+ hidden_size, hidden_size, kernel_size, padding=padding, dilation=dilation
97
+ )
98
+
99
+ def forward(self, x):
100
+ """Compute convolution.
101
+
102
+ Args:
103
+ x: input embeddings
104
+ Returns:
105
+ conv_output: convolution results
106
+ """
107
+
108
+ if self.shift:
109
+ return self.conv(x.transpose(1, 2)).transpose(1, 2)[:, 1:]
110
+ else:
111
+ return self.conv(x.transpose(1, 2)).transpose(1, 2)
112
+
113
+
114
+ def cumprod(x, reverse=False, exclusive=False):
115
+ """cumulative product."""
116
+ if reverse:
117
+ x = x.flip([-1])
118
+
119
+ if exclusive:
120
+ x = F.pad(x[:, :, :-1], (1, 0), value=1)
121
+
122
+ cx = x.cumprod(-1)
123
+
124
+ if reverse:
125
+ cx = cx.flip([-1])
126
+ return cx
127
+
128
+
129
+ def cumsum(x, reverse=False, exclusive=False):
130
+ """cumulative sum."""
131
+ bsz, _, length = x.size()
132
+ device = x.device
133
+ if reverse:
134
+ if exclusive:
135
+ w = torch.ones([bsz, length, length], device=device).tril(-1)
136
+ else:
137
+ w = torch.ones([bsz, length, length], device=device).tril(0)
138
+ cx = torch.bmm(x, w)
139
+ else:
140
+ if exclusive:
141
+ w = torch.ones([bsz, length, length], device=device).triu(1)
142
+ else:
143
+ w = torch.ones([bsz, length, length], device=device).triu(0)
144
+ cx = torch.bmm(x, w)
145
+ return cx
146
+
147
+
148
+ def cummin(x, reverse=False, exclusive=False, max_value=1e4):
149
+ """cumulative min."""
150
+ if reverse:
151
+ if exclusive:
152
+ x = F.pad(x[:, :, 1:], (0, 1), value=max_value)
153
+ x = x.flip([-1]).cummin(-1)[0].flip([-1])
154
+ else:
155
+ if exclusive:
156
+ x = F.pad(x[:, :, :-1], (1, 0), value=max_value)
157
+ x = x.cummin(-1)[0]
158
+ return x
159
+
160
+
161
+ class ParserNetwork(nn.Module):
162
+ def __init__(
163
+ self,
164
+ config,
165
+ pad=0,
166
+ n_parser_layers=4,
167
+ conv_size=9,
168
+ relations=("head", "child"),
169
+ weight_act="softmax",
170
+ ):
171
+ """
172
+ hidden_size: dimension of input embeddings
173
+ nlayers: number of layers
174
+ ntokens: number of output categories
175
+ nhead: number of self-attention heads
176
+ dropout: dropout rate
177
+ pad: pad token index
178
+ n_parser_layers: number of parsing layers
179
+ conv_size: convolution kernel size for parser
180
+ relations: relations that are used to compute self attention
181
+ weight_act: relations distribution activation function
182
+ """
183
+ super(ParserNetwork, self).__init__()
184
+ self.hidden_size = config.hidden_size
185
+ self.num_hidden_layers = config.num_hidden_layers
186
+ self.num_attention_heads = config.num_attention_heads
187
+
188
+ self.parser_layers = nn.ModuleList(
189
+ [
190
+ nn.Sequential(
191
+ Conv1d(self.hidden_size, conv_size),
192
+ nn.LayerNorm(self.hidden_size, elementwise_affine=False),
193
+ nn.Tanh(),
194
+ )
195
+ for _ in range(n_parser_layers)
196
+ ]
197
+ )
198
+
199
+ self.distance_ff = nn.Sequential(
200
+ Conv1d(self.hidden_size, 2),
201
+ nn.LayerNorm(self.hidden_size, elementwise_affine=False),
202
+ nn.Tanh(),
203
+ nn.Linear(self.hidden_size, 1),
204
+ )
205
+
206
+ self.height_ff = nn.Sequential(
207
+ nn.Linear(self.hidden_size, self.hidden_size),
208
+ nn.LayerNorm(self.hidden_size, elementwise_affine=False),
209
+ nn.Tanh(),
210
+ nn.Linear(self.hidden_size, 1),
211
+ )
212
+
213
+ n_rel = len(relations)
214
+ self._rel_weight = nn.Parameter(
215
+ torch.zeros((self.num_hidden_layers, self.num_attention_heads, n_rel))
216
+ )
217
+ self._rel_weight.data.normal_(0, 0.1)
218
+
219
+ self._scaler = nn.Parameter(torch.zeros(2))
220
+
221
+ self.n_parse_layers = n_parser_layers
222
+ self.weight_act = weight_act
223
+ self.relations = relations
224
+ self.pad = pad
225
+
226
+ @property
227
+ def scaler(self):
228
+ return self._scaler.exp()
229
+
230
+ @property
231
+ def rel_weight(self):
232
+ if self.weight_act == "sigmoid":
233
+ return torch.sigmoid(self._rel_weight)
234
+ elif self.weight_act == "softmax":
235
+ return torch.softmax(self._rel_weight, dim=-1)
236
+
237
+ def parse(self, x, h):
238
+ """
239
+ Parse input sentence.
240
+ Args:
241
+ x: input tokens (required).
242
+ h: static embeddings
243
+ Returns:
244
+ distance: syntactic distance
245
+ height: syntactic height
246
+ """
247
+
248
+ mask = x != self.pad
249
+ mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
250
+
251
+ for i in range(self.n_parse_layers):
252
+ h = h.masked_fill(~mask[:, :, None], 0)
253
+ h = self.parser_layers[i](h)
254
+
255
+ height = self.height_ff(h).squeeze(-1)
256
+ height.masked_fill_(~mask, -1e4)
257
+
258
+ distance = self.distance_ff(h).squeeze(-1)
259
+ distance.masked_fill_(~mask_shifted, 1e4)
260
+
261
+ # Calbrating the distance and height to the same level
262
+ length = distance.size(1)
263
+ height_max = height[:, None, :].expand(-1, length, -1)
264
+ height_max = torch.cummax(
265
+ height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e4, dim=-1
266
+ )[0].triu(0)
267
+
268
+ margin_left = torch.relu(
269
+ F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e4) - height_max
270
+ )
271
+ margin_right = torch.relu(distance[:, None, :] - height_max)
272
+ margin = torch.where(
273
+ margin_left > margin_right, margin_right, margin_left
274
+ ).triu(0)
275
+
276
+ margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
277
+ margin.masked_fill_(~margin_mask, 0)
278
+ margin = margin.max()
279
+
280
+ distance = distance - margin
281
+
282
+ return distance, height
283
+
284
+ def compute_block(self, distance, height):
285
+ """Compute constituents from distance and height."""
286
+
287
+ beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
288
+
289
+ gamma = torch.sigmoid(-beta_logits)
290
+ ones = torch.ones_like(gamma)
291
+
292
+ block_mask_left = cummin(
293
+ gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1
294
+ )
295
+ block_mask_left = block_mask_left - F.pad(
296
+ block_mask_left[:, :, :-1], (1, 0), value=0
297
+ )
298
+ block_mask_left.tril_(0)
299
+
300
+ block_mask_right = cummin(
301
+ gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1
302
+ )
303
+ block_mask_right = block_mask_right - F.pad(
304
+ block_mask_right[:, :, 1:], (0, 1), value=0
305
+ )
306
+ block_mask_right.triu_(0)
307
+
308
+ block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
309
+ block = cumsum(block_mask_left).tril(0) + cumsum(
310
+ block_mask_right, reverse=True
311
+ ).triu(1)
312
+
313
+ return block_p, block
314
+
315
+ def compute_head(self, height):
316
+ """Estimate head for each constituent."""
317
+
318
+ _, length = height.size()
319
+ head_logits = height * self.scaler[1]
320
+ index = torch.arange(length, device=height.device)
321
+
322
+ mask = (index[:, None, None] <= index[None, None, :]) * (
323
+ index[None, None, :] <= index[None, :, None]
324
+ )
325
+ head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
326
+ head_logits.masked_fill_(~mask[None, :, :, :], -1e4)
327
+
328
+ head_p = torch.softmax(head_logits, dim=-1)
329
+
330
+ return head_p
331
+
332
+ def generate_mask(self, x, distance, height):
333
+ """Compute head and cibling distribution for each token."""
334
+
335
+ batch_size, length = x.size()
336
+
337
+ eye = torch.eye(length, device=x.device, dtype=torch.bool)
338
+ eye = eye[None, :, :].expand((batch_size, -1, -1))
339
+
340
+ block_p, block = self.compute_block(distance, height)
341
+ head_p = self.compute_head(height)
342
+ head = torch.einsum("blij,bijh->blh", block_p, head_p)
343
+ head = head.masked_fill(eye, 0)
344
+ child = head.transpose(1, 2)
345
+ cibling = torch.bmm(head, child).masked_fill(eye, 0)
346
+
347
+ rel_list = []
348
+ if "head" in self.relations:
349
+ rel_list.append(head)
350
+ if "child" in self.relations:
351
+ rel_list.append(child)
352
+ if "cibling" in self.relations:
353
+ rel_list.append(cibling)
354
+
355
+ rel = torch.stack(rel_list, dim=1)
356
+
357
+ rel_weight = self.rel_weight
358
+
359
+ dep = torch.einsum("lhr,brij->lbhij", rel_weight, rel)
360
+ att_mask = dep.reshape(
361
+ self.num_hidden_layers, batch_size, self.num_attention_heads, length, length
362
+ )
363
+
364
+ return att_mask, cibling, head, block
365
+
366
+ def forward(self, x, embeddings):
367
+ """
368
+ Pass the x tokens through the parse network, get the syntactic height and distances
369
+ and compute the distribution for each token
370
+ """
371
+
372
+ x = torch.transpose(x, 0, 1)
373
+ embeddings = torch.transpose(embeddings, 0, 1)
374
+
375
+ distance, height = self.parse(x, embeddings)
376
+ att_mask, cibling, head, block = self.generate_mask(x, distance, height)
377
+ return att_mask, cibling, head, block
378
+
379
+
380
+ class Encoder(nn.Module):
381
+ def __init__(self, config, activation_checkpointing=False):
382
+ super().__init__()
383
+ self.layers = nn.ModuleList(
384
+ [EncoderLayer(config, i) for i in range(config.num_hidden_layers)]
385
+ )
386
+
387
+ for i, layer in enumerate(self.layers):
388
+ layer.mlp.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
389
+ layer.mlp.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
390
+
391
+ self.activation_checkpointing = activation_checkpointing
392
+
393
+ def forward(self, hidden_states, attention_mask, relative_embedding):
394
+ hidden_states, attention_probs = [hidden_states], []
395
+
396
+ for i in range(len(self.layers)):
397
+ if self.activation_checkpointing:
398
+ hidden_state, attention_p = checkpoint.checkpoint(
399
+ self.layers[i], hidden_states, attention_mask, relative_embedding
400
+ )
401
+ else:
402
+ hidden_state, attention_p = self.layers[i](
403
+ hidden_states, attention_mask[i], relative_embedding
404
+ )
405
+
406
+ hidden_states.append(hidden_state)
407
+ attention_probs.append(attention_p)
408
+
409
+ return hidden_states, attention_probs
410
+
411
+
412
+ class MaskClassifier(nn.Module):
413
+ def __init__(self, config, subword_embedding):
414
+ super().__init__()
415
+ self.nonlinearity = nn.Sequential(
416
+ nn.LayerNorm(
417
+ config.hidden_size, config.layer_norm_eps, elementwise_affine=False
418
+ ),
419
+ nn.Linear(config.hidden_size, config.hidden_size),
420
+ nn.GELU(),
421
+ nn.LayerNorm(
422
+ config.hidden_size, config.layer_norm_eps, elementwise_affine=False
423
+ ),
424
+ nn.Dropout(config.hidden_dropout_prob),
425
+ nn.Linear(subword_embedding.size(1), subword_embedding.size(0)),
426
+ )
427
+ self.initialize(config.hidden_size, subword_embedding)
428
+
429
+ def initialize(self, hidden_size, embedding):
430
+ std = math.sqrt(2.0 / (5.0 * hidden_size))
431
+ nn.init.trunc_normal_(
432
+ self.nonlinearity[1].weight, mean=0.0, std=std, a=-2 * std, b=2 * std
433
+ )
434
+ self.nonlinearity[-1].weight = embedding
435
+ self.nonlinearity[1].bias.data.zero_()
436
+ self.nonlinearity[-1].bias.data.zero_()
437
+
438
+ def forward(self, x, masked_lm_labels=None):
439
+ if masked_lm_labels is not None:
440
+ x = torch.index_select(
441
+ x.flatten(0, 1),
442
+ 0,
443
+ torch.nonzero(masked_lm_labels.flatten() != -100).squeeze(),
444
+ )
445
+ x = self.nonlinearity(x)
446
+ return x
447
+
448
+
449
+ class EncoderLayer(nn.Module):
450
+ def __init__(self, config, layer_num):
451
+ super().__init__()
452
+ self.attention = Attention(config)
453
+ self.mlp = FeedForward(config)
454
+ temp = torch.zeros(layer_num + 1)
455
+ temp[-1] = 1
456
+ self.prev_layer_weights = nn.Parameter(temp)
457
+
458
+ def forward(self, hidden_states, padding_mask, relative_embedding):
459
+ prev_layer_weights = F.softmax(self.prev_layer_weights, dim=-1)
460
+ x = prev_layer_weights[0] * hidden_states[0]
461
+ for i, hidden_state in enumerate(hidden_states[1:]):
462
+ x = x + prev_layer_weights[i + 1] * hidden_state
463
+ attention_output, attention_probs = self.attention(
464
+ x, padding_mask, relative_embedding
465
+ )
466
+ x = attention_output
467
+ x = x + self.mlp(x)
468
+ return x, attention_probs
469
+
470
+
471
+ class GeGLU(nn.Module):
472
+ def forward(self, x):
473
+ x, gate = x.chunk(2, dim=-1)
474
+ x = x * gelu_new(gate)
475
+ return x
476
+
477
+
478
+ class FeedForward(nn.Module):
479
+ def __init__(self, config):
480
+ super().__init__()
481
+ self.mlp = nn.Sequential(
482
+ nn.LayerNorm(
483
+ config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False
484
+ ),
485
+ nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False),
486
+ GeGLU(),
487
+ nn.LayerNorm(
488
+ config.intermediate_size,
489
+ eps=config.layer_norm_eps,
490
+ elementwise_affine=False,
491
+ ),
492
+ nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
493
+ nn.Dropout(config.hidden_dropout_prob),
494
+ )
495
+ self.initialize(config.hidden_size)
496
+
497
+ def initialize(self, hidden_size):
498
+ std = math.sqrt(2.0 / (5.0 * hidden_size))
499
+ nn.init.trunc_normal_(
500
+ self.mlp[1].weight, mean=0.0, std=std, a=-2 * std, b=2 * std
501
+ )
502
+ nn.init.trunc_normal_(
503
+ self.mlp[-2].weight, mean=0.0, std=std, a=-2 * std, b=2 * std
504
+ )
505
+
506
+ def forward(self, x):
507
+ return self.mlp(x)
508
+
509
+
510
+ class MaskedSoftmax(torch.autograd.Function):
511
+ @staticmethod
512
+ def forward(self, x, mask, dim):
513
+ self.dim = dim
514
+ x.masked_fill_(mask, float("-inf"))
515
+ x = torch.softmax(x, self.dim)
516
+ x.masked_fill_(mask, 0.0)
517
+ self.save_for_backward(x)
518
+ return x
519
+
520
+ @staticmethod
521
+ def backward(self, grad_output):
522
+ (output,) = self.saved_tensors
523
+ input_grad = softmax_backward_data(self, grad_output, output, self.dim, output)
524
+ return input_grad, None, None
525
+
526
+
527
+ class Attention(nn.Module):
528
+ def __init__(self, config):
529
+ super().__init__()
530
+
531
+ self.config = config
532
+
533
+ if config.hidden_size % config.num_attention_heads != 0:
534
+ raise ValueError(
535
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads {config.num_attention_heads}"
536
+ )
537
+
538
+ self.hidden_size = config.hidden_size
539
+ self.num_heads = config.num_attention_heads
540
+ self.head_size = config.hidden_size // config.num_attention_heads
541
+
542
+ self.in_proj_qk = nn.Linear(
543
+ config.hidden_size, 2 * config.hidden_size, bias=True
544
+ )
545
+ self.in_proj_v = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
546
+ self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
547
+
548
+ self.pre_layer_norm = nn.LayerNorm(
549
+ config.hidden_size, config.layer_norm_eps, elementwise_affine=False
550
+ )
551
+ self.post_layer_norm = nn.LayerNorm(
552
+ config.hidden_size, config.layer_norm_eps, elementwise_affine=True
553
+ )
554
+
555
+ position_indices = torch.arange(
556
+ config.max_position_embeddings, dtype=torch.long
557
+ ).unsqueeze(1) - torch.arange(
558
+ config.max_position_embeddings, dtype=torch.long
559
+ ).unsqueeze(
560
+ 0
561
+ )
562
+ position_indices = self.make_log_bucket_position(
563
+ position_indices,
564
+ config.position_bucket_size,
565
+ config.max_position_embeddings,
566
+ )
567
+ position_indices = config.position_bucket_size - 1 + position_indices
568
+ self.register_buffer("position_indices", position_indices, persistent=True)
569
+
570
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
571
+ self.scale = 1.0 / math.sqrt(3 * self.head_size)
572
+ self.initialize()
573
+
574
+ def make_log_bucket_position(self, relative_pos, bucket_size, max_position):
575
+ sign = torch.sign(relative_pos)
576
+ mid = bucket_size // 2
577
+ abs_pos = torch.where(
578
+ (relative_pos < mid) & (relative_pos > -mid),
579
+ mid - 1,
580
+ torch.abs(relative_pos).clamp(max=max_position - 1),
581
+ )
582
+ log_pos = (
583
+ torch.ceil(
584
+ torch.log(abs_pos / mid)
585
+ / math.log((max_position - 1) / mid)
586
+ * (mid - 1)
587
+ ).int()
588
+ + mid
589
+ )
590
+ bucket_pos = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long()
591
+ return bucket_pos
592
+
593
+ def initialize(self):
594
+ std = math.sqrt(2.0 / (5.0 * self.hidden_size))
595
+ nn.init.trunc_normal_(
596
+ self.in_proj_qk.weight, mean=0.0, std=std, a=-2 * std, b=2 * std
597
+ )
598
+ nn.init.trunc_normal_(
599
+ self.in_proj_v.weight, mean=0.0, std=std, a=-2 * std, b=2 * std
600
+ )
601
+ nn.init.trunc_normal_(
602
+ self.out_proj.weight, mean=0.0, std=std, a=-2 * std, b=2 * std
603
+ )
604
+ self.in_proj_qk.bias.data.zero_()
605
+ self.in_proj_v.bias.data.zero_()
606
+ self.out_proj.bias.data.zero_()
607
+
608
+ def compute_attention_scores(self, hidden_states, relative_embedding):
609
+ key_len, batch_size, _ = hidden_states.size()
610
+ query_len = key_len
611
+
612
+ if self.position_indices.size(0) < query_len:
613
+ position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(
614
+ 1
615
+ ) - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
616
+ position_indices = self.make_log_bucket_position(
617
+ position_indices, self.position_bucket_size, 512
618
+ )
619
+ position_indices = self.position_bucket_size - 1 + position_indices
620
+ self.position_indices = position_indices.to(hidden_states.device)
621
+
622
+ hidden_states = self.pre_layer_norm(hidden_states)
623
+
624
+ query, key = self.in_proj_qk(hidden_states).chunk(2, dim=2) # shape: [T, B, D]
625
+ value = self.in_proj_v(hidden_states) # shape: [T, B, D]
626
+
627
+ query = query.reshape(
628
+ query_len, batch_size * self.num_heads, self.head_size
629
+ ).transpose(0, 1)
630
+ key = key.reshape(
631
+ key_len, batch_size * self.num_heads, self.head_size
632
+ ).transpose(0, 1)
633
+ value = value.view(
634
+ key_len, batch_size * self.num_heads, self.head_size
635
+ ).transpose(0, 1)
636
+
637
+ attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
638
+
639
+ query_pos, key_pos = self.in_proj_qk(self.dropout(relative_embedding)).chunk(
640
+ 2, dim=-1
641
+ ) # shape: [2T-1, D]
642
+ query_pos = query_pos.view(
643
+ -1, self.num_heads, self.head_size
644
+ ) # shape: [2T-1, H, D]
645
+ key_pos = key_pos.view(
646
+ -1, self.num_heads, self.head_size
647
+ ) # shape: [2T-1, H, D]
648
+
649
+ query = query.view(batch_size, self.num_heads, query_len, self.head_size)
650
+ key = key.view(batch_size, self.num_heads, query_len, self.head_size)
651
+
652
+ attention_c_p = torch.einsum(
653
+ "bhqd,khd->bhqk", query, key_pos.squeeze(1) * self.scale
654
+ )
655
+ attention_p_c = torch.einsum(
656
+ "bhkd,qhd->bhqk", key * self.scale, query_pos.squeeze(1)
657
+ )
658
+
659
+ position_indices = self.position_indices[:query_len, :key_len].expand(
660
+ batch_size, self.num_heads, -1, -1
661
+ )
662
+ attention_c_p = attention_c_p.gather(3, position_indices)
663
+ attention_p_c = attention_p_c.gather(2, position_indices)
664
+
665
+ attention_scores = attention_scores.view(
666
+ batch_size, self.num_heads, query_len, key_len
667
+ )
668
+ attention_scores.add_(attention_c_p)
669
+ attention_scores.add_(attention_p_c)
670
+
671
+ return attention_scores, value
672
+
673
+ def compute_output(self, attention_probs, value):
674
+ attention_probs = self.dropout(attention_probs)
675
+ context = torch.bmm(attention_probs.flatten(0, 1), value) # shape: [B*H, Q, D]
676
+ context = context.transpose(0, 1).reshape(
677
+ context.size(1), -1, self.hidden_size
678
+ ) # shape: [Q, B, H*D]
679
+ context = self.out_proj(context)
680
+ context = self.post_layer_norm(context)
681
+ context = self.dropout(context)
682
+ return context
683
+
684
+ def forward(self, hidden_states, attention_mask, relative_embedding):
685
+ attention_scores, value = self.compute_attention_scores(
686
+ hidden_states, relative_embedding
687
+ )
688
+ attention_probs = torch.sigmoid(attention_scores) * attention_mask
689
+ return self.compute_output(attention_probs, value), attention_probs.detach()
690
+
691
+
692
+ class Embedding(nn.Module):
693
+ def __init__(self, config):
694
+ super().__init__()
695
+ self.hidden_size = config.hidden_size
696
+
697
+ self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
698
+ self.word_layer_norm = nn.LayerNorm(
699
+ config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False
700
+ )
701
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
702
+
703
+ self.relative_embedding = nn.Parameter(
704
+ torch.empty(2 * config.position_bucket_size - 1, config.hidden_size)
705
+ )
706
+ self.relative_layer_norm = nn.LayerNorm(
707
+ config.hidden_size, eps=config.layer_norm_eps
708
+ )
709
+
710
+ self.initialize()
711
+
712
+ def initialize(self):
713
+ std = math.sqrt(2.0 / (5.0 * self.hidden_size))
714
+ nn.init.trunc_normal_(
715
+ self.relative_embedding, mean=0.0, std=std, a=-2 * std, b=2 * std
716
+ )
717
+ nn.init.trunc_normal_(
718
+ self.word_embedding.weight, mean=0.0, std=std, a=-2 * std, b=2 * std
719
+ )
720
+
721
+ def forward(self, input_ids):
722
+ word_embedding = self.dropout(
723
+ self.word_layer_norm(self.word_embedding(input_ids))
724
+ )
725
+ relative_embeddings = self.relative_layer_norm(self.relative_embedding)
726
+ return word_embedding, relative_embeddings
727
+
728
+
729
+ #
730
+ # HuggingFace wrappers
731
+ #
732
+
733
+
734
+ class LtgBertPreTrainedModel(PreTrainedModel):
735
+ """
736
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
737
+ models.
738
+ """
739
+
740
+ config_class = LtgBertConfig
741
+ base_model_prefix = "bnc-bert"
742
+ supports_gradient_checkpointing = True
743
+
744
+ def _set_gradient_checkpointing(self, module, value=False):
745
+ if isinstance(module, Encoder):
746
+ module.activation_checkpointing = value
747
+
748
+ def _init_weights(self, _):
749
+ pass # everything is already initialized
750
+
751
+
752
+ LTG_BERT_START_DOCSTRING = r"""
753
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
754
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
755
+ etc.)
756
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
757
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
758
+ and behavior.
759
+ Parameters:
760
+ config ([`LtgBertConfig`]): Model configuration class with all the parameters of the model.
761
+ Initializing with a config file does not load the weights associated with the model, only the
762
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
763
+ """
764
+
765
+ LTG_BERT_INPUTS_DOCSTRING = r"""
766
+ Args:
767
+ input_ids (`torch.LongTensor` of shape `({0})`):
768
+ Indices of input sequence tokens in the vocabulary.
769
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
770
+ [`PreTrainedTokenizer.__call__`] for details.
771
+ [What are input IDs?](../glossary#input-ids)
772
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
773
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
774
+ - 1 for tokens that are **not masked**,
775
+ - 0 for tokens that are **masked**.
776
+ [What are attention masks?](../glossary#attention-mask)
777
+ output_hidden_states (`bool`, *optional*):
778
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
779
+ more detail.
780
+ output_attentions (`bool`, *optional*):
781
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
782
+ tensors for more detail.
783
+ return_dict (`bool`, *optional*):
784
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
785
+ """
786
+
787
+
788
+ @add_start_docstrings(
789
+ "The bare LTG-BERT transformer outputting raw hidden-states without any specific head on top.",
790
+ LTG_BERT_START_DOCSTRING,
791
+ )
792
+ class LtgBertModel(LtgBertPreTrainedModel):
793
+ def __init__(self, config, add_mlm_layer=False):
794
+ super().__init__(config)
795
+ self.config = config
796
+
797
+ self.embedding = Embedding(config)
798
+ self.parser_network = ParserNetwork(config, pad=config.pad_token_id)
799
+ self.transformer = Encoder(config, activation_checkpointing=False)
800
+ self.classifier = (
801
+ MaskClassifier(config, self.embedding.word_embedding.weight)
802
+ if add_mlm_layer
803
+ else None
804
+ )
805
+
806
+ def get_input_embeddings(self):
807
+ return self.embedding.word_embedding
808
+
809
+ def set_input_embeddings(self, value):
810
+ self.embedding.word_embedding = value
811
+
812
+ def get_contextualized_embeddings(
813
+ self,
814
+ input_ids: Optional[torch.Tensor] = None,
815
+ attention_mask: Optional[torch.Tensor] = None,
816
+ ) -> List[torch.Tensor]:
817
+ if input_ids is not None:
818
+ input_shape = input_ids.size()
819
+ else:
820
+ raise ValueError("You have to specify input_ids")
821
+
822
+ batch_size, seq_length = input_shape
823
+ device = input_ids.device
824
+
825
+ static_embeddings, relative_embedding = self.embedding(input_ids.t())
826
+ att_mask, cibling, head, block = self.parser_network(
827
+ input_ids.t(), static_embeddings
828
+ )
829
+ contextualized_embeddings, attention_probs = self.transformer(
830
+ static_embeddings, att_mask, relative_embedding
831
+ )
832
+ contextualized_embeddings = [
833
+ e.transpose(0, 1) for e in contextualized_embeddings
834
+ ]
835
+ last_layer = contextualized_embeddings[-1]
836
+ contextualized_embeddings = [contextualized_embeddings[0]] + [
837
+ contextualized_embeddings[i] - contextualized_embeddings[i - 1]
838
+ for i in range(1, len(contextualized_embeddings))
839
+ ]
840
+ return last_layer, contextualized_embeddings, attention_probs
841
+
842
+ @add_start_docstrings_to_model_forward(
843
+ LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
844
+ )
845
+ def forward(
846
+ self,
847
+ input_ids: Optional[torch.Tensor] = None,
848
+ attention_mask: Optional[torch.Tensor] = None,
849
+ output_hidden_states: Optional[bool] = None,
850
+ output_attentions: Optional[bool] = None,
851
+ return_dict: Optional[bool] = None,
852
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
853
+ output_attentions = (
854
+ output_attentions
855
+ if output_attentions is not None
856
+ else self.config.output_attentions
857
+ )
858
+ output_hidden_states = (
859
+ output_hidden_states
860
+ if output_hidden_states is not None
861
+ else self.config.output_hidden_states
862
+ )
863
+ return_dict = (
864
+ return_dict if return_dict is not None else self.config.use_return_dict
865
+ )
866
+
867
+ (
868
+ sequence_output,
869
+ contextualized_embeddings,
870
+ attention_probs,
871
+ ) = self.get_contextualized_embeddings(input_ids, attention_mask)
872
+
873
+ if not return_dict:
874
+ return (
875
+ sequence_output,
876
+ *([contextualized_embeddings] if output_hidden_states else []),
877
+ *([attention_probs] if output_attentions else []),
878
+ )
879
+
880
+ return BaseModelOutput(
881
+ last_hidden_state=sequence_output,
882
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
883
+ attentions=attention_probs if output_attentions else None,
884
+ )
885
+
+
+ @add_start_docstrings(
+     """LTG-BERT model with a `language modeling` head on top.""",
+     LTG_BERT_START_DOCSTRING,
+ )
+ class LtgBertForMaskedLM(LtgBertModel):
+     _keys_to_ignore_on_load_unexpected = ["head"]
+
+     def __init__(self, config):
+         super().__init__(config, add_mlm_layer=True)
+
+     def get_output_embeddings(self):
+         return self.classifier.nonlinearity[-1].weight
+
+     def set_output_embeddings(self, new_embeddings):
+         self.classifier.nonlinearity[-1].weight = new_embeddings
+
+     @add_start_docstrings_to_model_forward(
+         LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+             config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
+             the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+         """
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         (
+             sequence_output,
+             contextualized_embeddings,
+             attention_probs,
+         ) = self.get_contextualized_embeddings(input_ids, attention_mask)
+         subword_prediction = self.classifier(sequence_output)
+
+         masked_lm_loss = None
+         if labels is not None:
+             masked_lm_loss = F.cross_entropy(
+                 subword_prediction.flatten(0, 1), labels.flatten()
+             )
+
+         if not return_dict:
+             output = (
+                 subword_prediction,
+                 *([contextualized_embeddings] if output_hidden_states else []),
+                 *([attention_probs] if output_attentions else []),
+             )
+             return (
+                 ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+             )
+
+         return MaskedLMOutput(
+             loss=masked_lm_loss,
+             logits=subword_prediction,
+             hidden_states=contextualized_embeddings if output_hidden_states else None,
+             attentions=attention_probs if output_attentions else None,
+         )
954
+
+
+ class Classifier(nn.Module):
+     def __init__(self, config, num_labels: int):
+         super().__init__()
+
+         drop_out = getattr(config, "classifier_dropout", config.hidden_dropout_prob)
+
+         self.nonlinearity = nn.Sequential(
+             nn.LayerNorm(
+                 config.hidden_size, config.layer_norm_eps, elementwise_affine=False
+             ),
+             nn.Linear(config.hidden_size, config.hidden_size),
+             nn.GELU(),
+             nn.LayerNorm(
+                 config.hidden_size, config.layer_norm_eps, elementwise_affine=False
+             ),
+             nn.Dropout(drop_out),
+             nn.Linear(config.hidden_size, num_labels),
+         )
+         self.initialize(config.hidden_size)
+
+     def initialize(self, hidden_size):
+         std = math.sqrt(2.0 / (5.0 * hidden_size))
+         nn.init.trunc_normal_(
+             self.nonlinearity[1].weight, mean=0.0, std=std, a=-2 * std, b=2 * std
+         )
+         nn.init.trunc_normal_(
+             self.nonlinearity[-1].weight, mean=0.0, std=std, a=-2 * std, b=2 * std
+         )
+         self.nonlinearity[1].bias.data.zero_()
+         self.nonlinearity[-1].bias.data.zero_()
+
+     def forward(self, x):
+         x = self.nonlinearity(x)
+         return x
+
991
+
+ @add_start_docstrings(
+     """
+     LTG-BERT model with a sequence classification/regression head on top (a linear layer on top of the pooled
+     output) e.g. for GLUE tasks.
+     """,
+     LTG_BERT_START_DOCSTRING,
+ )
+ class LtgBertForSequenceClassification(LtgBertModel):
+     _keys_to_ignore_on_load_unexpected = ["classifier"]
+     _keys_to_ignore_on_load_missing = ["head"]
+
+     def __init__(self, config):
+         super().__init__(config, add_mlm_layer=False)
+
+         self.num_labels = config.num_labels
+         self.head = Classifier(config, self.num_labels)
+
+     @add_start_docstrings_to_model_forward(
+         LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss);
+             if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         (
+             sequence_output,
+             contextualized_embeddings,
+             attention_probs,
+         ) = self.get_contextualized_embeddings(input_ids, attention_mask)
+         logits = self.head(sequence_output[:, 0, :])
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (
+                     labels.dtype == torch.long or labels.dtype == torch.int
+                 ):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = nn.MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = nn.CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = nn.BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+
+         if not return_dict:
+             output = (
+                 logits,
+                 *([contextualized_embeddings] if output_hidden_states else []),
+                 *([attention_probs] if output_attentions else []),
+             )
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=contextualized_embeddings if output_hidden_states else None,
+             attentions=attention_probs if output_attentions else None,
+         )
1077
+
+
+ @add_start_docstrings(
+     """
+     LTG-BERT model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+     Named-Entity-Recognition (NER) tasks.
+     """,
+     LTG_BERT_START_DOCSTRING,
+ )
+ class LtgBertForTokenClassification(LtgBertModel):
+     _keys_to_ignore_on_load_unexpected = ["classifier"]
+     _keys_to_ignore_on_load_missing = ["head"]
+
+     def __init__(self, config):
+         super().__init__(config, add_mlm_layer=False)
+
+         self.num_labels = config.num_labels
+         self.head = Classifier(config, self.num_labels)
+
+     @add_start_docstrings_to_model_forward(
+         LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         (
+             sequence_output,
+             contextualized_embeddings,
+             attention_probs,
+         ) = self.get_contextualized_embeddings(input_ids, attention_mask)
+         logits = self.head(sequence_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (
+                 logits,
+                 *([contextualized_embeddings] if output_hidden_states else []),
+                 *([attention_probs] if output_attentions else []),
+             )
+             return ((loss,) + output) if loss is not None else output
+
+         return TokenClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=contextualized_embeddings if output_hidden_states else None,
+             attentions=attention_probs if output_attentions else None,
+         )
+
1141
+
+ @add_start_docstrings(
+     """
+     LTG-BERT model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
+     layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+     """,
+     LTG_BERT_START_DOCSTRING,
+ )
+ class LtgBertForQuestionAnswering(LtgBertModel):
+     _keys_to_ignore_on_load_unexpected = ["classifier"]
+     _keys_to_ignore_on_load_missing = ["head"]
+
+     def __init__(self, config):
+         super().__init__(config, add_mlm_layer=False)
+
+         self.num_labels = config.num_labels
+         self.head = Classifier(config, self.num_labels)
+
+     @add_start_docstrings_to_model_forward(
+         LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         start_positions: Optional[torch.Tensor] = None,
+         end_positions: Optional[torch.Tensor] = None,
+     ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         (
+             sequence_output,
+             contextualized_embeddings,
+             attention_probs,
+         ) = self.get_contextualized_embeddings(input_ids, attention_mask)
+         logits = self.head(sequence_output)
+
+         start_logits, end_logits = logits.split(1, dim=-1)
+         start_logits = start_logits.squeeze(-1).contiguous()
+         end_logits = end_logits.squeeze(-1).contiguous()
+
+         total_loss = None
+         if start_positions is not None and end_positions is not None:
+             # If we are on multi-GPU, squeeze the extra dimension
+             if len(start_positions.size()) > 1:
+                 start_positions = start_positions.squeeze(-1)
+             if len(end_positions.size()) > 1:
+                 end_positions = end_positions.squeeze(-1)
+
+             # Sometimes the start/end positions are outside our model inputs; we ignore these terms.
+             ignored_index = start_logits.size(1)
+             start_positions = start_positions.clamp(0, ignored_index)
+             end_positions = end_positions.clamp(0, ignored_index)
+
+             loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
+             start_loss = loss_fct(start_logits, start_positions)
+             end_loss = loss_fct(end_logits, end_positions)
+             total_loss = (start_loss + end_loss) / 2
+
+         if not return_dict:
+             output = (
+                 start_logits,
+                 end_logits,
+                 *([contextualized_embeddings] if output_hidden_states else []),
+                 *([attention_probs] if output_attentions else []),
+             )
+             return ((total_loss,) + output) if total_loss is not None else output
+
+         return QuestionAnsweringModelOutput(
+             loss=total_loss,
+             start_logits=start_logits,
+             end_logits=end_logits,
+             hidden_states=contextualized_embeddings if output_hidden_states else None,
+             attentions=attention_probs if output_attentions else None,
+         )
+
1224
+
+ @add_start_docstrings(
+     """
+     LTG-BERT model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+     softmax) e.g. for RocStories/SWAG tasks.
+     """,
+     LTG_BERT_START_DOCSTRING,
+ )
+ class LtgBertForMultipleChoice(LtgBertModel):
+     _keys_to_ignore_on_load_unexpected = ["classifier"]
+     _keys_to_ignore_on_load_missing = ["head"]
+
+     def __init__(self, config):
+         super().__init__(config, add_mlm_layer=False)
+
+         self.num_labels = getattr(config, "num_labels", 2)
+         self.head = Classifier(config, self.num_labels)
+
+     @add_start_docstrings_to_model_forward(
+         LTG_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+         num_choices = input_ids.shape[1]
+
+         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+         flat_attention_mask = (
+             attention_mask.view(-1, attention_mask.size(-1))
+             if attention_mask is not None
+             else None
+         )
+
+         (
+             sequence_output,
+             contextualized_embeddings,
+             attention_probs,
+         ) = self.get_contextualized_embeddings(flat_input_ids, flat_attention_mask)
+         logits = self.head(sequence_output)
+         reshaped_logits = logits.view(-1, num_choices)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(reshaped_logits, labels)
+
+         if not return_dict:
+             output = (
+                 reshaped_logits,
+                 *([contextualized_embeddings] if output_hidden_states else []),
+                 *([attention_probs] if output_attentions else []),
+             )
+             return ((loss,) + output) if loss is not None else output
+
+         return MultipleChoiceModelOutput(
+             loss=loss,
+             logits=reshaped_logits,
+             hidden_states=contextualized_embeddings if output_hidden_states else None,
+             attentions=attention_probs if output_attentions else None,
+         )
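
A minimal usage sketch for the task heads defined above. The repository id and the presence of an `auto_map` entry in `config.json` (so that `trust_remote_code=True` can resolve the custom classes) are assumptions, not part of this commit:

```python
# Hypothetical usage sketch; repo id and auto_map registration are assumptions.
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

repo_id = "SufurElite/ELC_ParserBERT_10M_textonly"  # placeholder identifier

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
with torch.no_grad():
    # Pass only the arguments LtgBertForMaskedLM.forward accepts.
    out = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
    )

# Rank vocabulary items for the masked position.
mask_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero()[0, 1]
top_ids = out.logits[0, mask_index].topk(5).indices.tolist()
print(tokenizer.convert_ids_to_tokens(top_ids))
```

The same pattern would apply to the other heads (e.g. `AutoModelForSequenceClassification` for the GLUE-style class), provided they are also registered in `auto_map`.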
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5b7628dc36c1c996847b8ac8e32cd959e09cfab55511e6dd935b904f2374f0b0
+ size 159209798
results.md ADDED
@@ -0,0 +1,125 @@
+ # Results
+
+ The results here were produced by running `score_predictions.py` from the [babylm evaluation pipeline](https://github.com/babylm/evaluation-pipeline-2024) on the `ELC_ParserBERT_10M_textonly_predictions.json.gz` file in this directory, which contains the model's predictions for the different evaluation tasks.
+
+ ## Overall Results
+
+ Here are the average results per section and the macroaverage, compared with the baseline models:
+
+ | Model | BLiMP | BLiMP Supplement | EWoK | GLUE | *Macroaverage* |
+ | --- | --- | --- | --- | --- | --- |
+ | BabyLlama | 69.8 | 59.5 | 50.7 | 63.3 | 60.8 |
+ | LTG-BERT | 60.6 | 60.8 | 48.9 | 60.3 | 57.7 |
+ | ELC-ParserBERT | 59.6 | 57.7 | 63.1 | 44.5 | 56.2 |
+
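The *Macroaverage* column is consistent with an unweighted mean of the four section scores; a minimal check (values taken from the ELC-ParserBERT row above):

```python
# Unweighted mean of the four section averages for ELC-ParserBERT.
section_scores = {"BLiMP": 59.6, "BLiMP Supplement": 57.7, "EWoK": 63.1, "GLUE": 44.5}
macroaverage = sum(section_scores.values()) / len(section_scores)
print(f"{macroaverage:.1f}")  # 56.2, matching the table
```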
+ ## The Breakdown Per Section
+
+ | GLUE subtask | Score |
+ | -------------- | ------- |
+ | cola (MCC) | 0.042 |
+ | sst2 | 0.502 |
+ | mrpc (F1) | 0.82 |
+ | qqp (F1) | 0 |
+ | mnli | 0.357 |
+ | mnli-mm | 0.355 |
+ | qnli | 0.491 |
+ | rte | 0.496 |
+ | boolq | 0.585 |
+ | multirc | 0.63 |
+ | wsc | 0.615 |
+ | *Average* | 0.445 |
+
+ | BLiMP subtask | Score |
+ | --------------------------------------------------- | ------- |
+ | adjunct_island | 0.712 |
+ | anaphor_gender_agreement | 0.593 |
+ | anaphor_number_agreement | 0.647 |
+ | animate_subject_passive | 0.594 |
+ | animate_subject_trans | 0.47 |
+ | causative | 0.726 |
+ | complex_NP_island | 0.447 |
+ | coordinate_structure_constraint_complex_left_branch | 0.39 |
+ | coordinate_structure_constraint_object_extraction | 0.806 |
+ | determiner_noun_agreement_1 | 0.793 |
+ | determiner_noun_agreement_2 | 0.936 |
+ | determiner_noun_agreement_irregular_1 | 0.467 |
+ | determiner_noun_agreement_irregular_2 | 0.394 |
+ | determiner_noun_agreement_with_adj_2 | 0.889 |
+ | determiner_noun_agreement_with_adj_irregular_1 | 0.834 |
+ | determiner_noun_agreement_with_adj_irregular_2 | 0.848 |
+ | determiner_noun_agreement_with_adjective_1 | 0.758 |
+ | distractor_agreement_relational_noun | 0.212 |
+ | distractor_agreement_relative_clause | 0.282 |
+ | drop_argument | 0.485 |
+ | ellipsis_n_bar_1 | 0.505 |
+ | ellipsis_n_bar_2 | 0.342 |
+ | existential_there_object_raising | 0.447 |
+ | existential_there_quantifiers_1 | 0.385 |
+ | existential_there_quantifiers_2 | 0.396 |
+ | existential_there_subject_raising | 0.476 |
+ | expletive_it_object_raising | 0.44 |
+ | inchoative | 0.527 |
+ | intransitive | 0.484 |
+ | irregular_past_participle_adjectives | 0.348 |
+ | irregular_past_participle_verbs | 0.594 |
+ | irregular_plural_subject_verb_agreement_1 | 0.634 |
+ | irregular_plural_subject_verb_agreement_2 | 0.687 |
+ | left_branch_island_echo_question | 0.634 |
+ | left_branch_island_simple_question | 0.615 |
+ | matrix_question_npi_licensor_present | 0.206 |
+ | npi_present_1 | 0.362 |
+ | npi_present_2 | 0.347 |
+ | only_npi_licensor_present | 0.964 |
+ | only_npi_scope | 0.89 |
+ | passive_1 | 0.514 |
+ | passive_2 | 0.482 |
+ | principle_A_c_command | 0.635 |
+ | principle_A_case_1 | 0.999 |
+ | principle_A_case_2 | 0.78 |
+ | principle_A_domain_1 | 0.893 |
+ | principle_A_domain_2 | 0.623 |
+ | principle_A_domain_3 | 0.556 |
+ | principle_A_reconstruction | 0.339 |
+ | regular_plural_subject_verb_agreement_1 | 0.628 |
+ | regular_plural_subject_verb_agreement_2 | 0.663 |
+ | sentential_negation_npi_licensor_present | 0.93 |
+ | sentential_negation_npi_scope | 0.722 |
+ | sentential_subject_island | 0.361 |
+ | superlative_quantifiers_1 | 0.702 |
+ | superlative_quantifiers_2 | 0.498 |
+ | tough_vs_raising_1 | 0.351 |
+ | tough_vs_raising_2 | 0.648 |
+ | transitive | 0.645 |
+ | wh_island | 0.719 |
+ | wh_questions_object_gap | 0.657 |
+ | wh_questions_subject_gap | 0.861 |
+ | wh_questions_subject_gap_long_distance | 0.937 |
+ | wh_vs_that_no_gap | 0.969 |
+ | wh_vs_that_no_gap_long_distance | 0.969 |
+ | wh_vs_that_with_gap | 0.222 |
+ | wh_vs_that_with_gap_long_distance | 0.063 |
+ | *Average* | 0.596 |
+
+ | BLiMP supplement subtask | Score |
+ | -------------------------- | ------- |
+ | hypernym | 0.531 |
+ | qa_congruence_easy | 0.641 |
+ | qa_congruence_tricky | 0.521 |
+ | subject_aux_inversion | 0.614 |
+ | turn_taking | 0.579 |
+ | *Average* | 0.577 |
+
+ | EWoK subtask | Score |
+ | ----------------------- | ------- |
+ | agent-properties | 0.738 |
+ | material-dynamics | 0.81 |
+ | material-properties | 0.6 |
+ | physical-dynamics | 0.383 |
+ | physical-interactions | 0.599 |
+ | physical-relations | 0.817 |
+ | quantitative-properties | 0.427 |
+ | social-interactions | 0.565 |
+ | social-properties | 0.561 |
+ | social-relations | 0.807 |
+ | spatial-relations | 0.635 |
+ | *Average* | 0.631 |
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "model_max_length": 1000000000000000019884624838656,
+   "tokenizer_class": "PreTrainedTokenizerFast"
+ }
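
The `tokenizer.json` file plus the special tokens listed in `special_tokens_map.json` are enough to rebuild a working tokenizer locally. A minimal sketch, assuming the files have been downloaded into the working directory:

```python
# Rebuild the tokenizer from the committed files (file paths are assumptions).
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="tokenizer.json",
    cls_token="[CLS]",
    sep_token="[SEP]",
    pad_token="[PAD]",
    unk_token="[UNK]",
    mask_token="[MASK]",
)

print(tokenizer("A short example sentence."))  # input_ids, attention_mask, ...
```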