SeyedAli commited on
Commit
10e889f
·
1 Parent(s): c8484f6

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -1,35 +1,17 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
2
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
 
 
 
4
  *.h5 filter=lfs diff=lfs merge=lfs -text
5
+ *.tflite filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.arrow filter=lfs diff=lfs merge=lfs -text
10
+ *.ftz filter=lfs diff=lfs merge=lfs -text
11
  *.joblib filter=lfs diff=lfs merge=lfs -text
 
 
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
14
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
15
  *.pt filter=lfs diff=lfs merge=lfs -text
16
  *.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
17
  *tfevents* filter=lfs diff=lfs merge=lfs -text
1_Pooling/config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "word_embedding_dimension": 384,
3
+ "pooling_mode_cls_token": false,
4
+ "pooling_mode_mean_tokens": true,
5
+ "pooling_mode_max_tokens": false,
6
+ "pooling_mode_mean_sqrt_len_tokens": false
7
+ }
README.md CHANGED
@@ -1,3 +1,249 @@
1
  ---
2
- license: mit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ pipeline_tag: sentence-similarity
3
+ tags:
4
+ - sentence-transformers
5
+ - feature-extraction
6
+ - sentence-similarity
7
+ datasets:
8
+ - flax-sentence-embeddings/stackexchange_xml
9
+ - ms_marco
10
+ - gooaq
11
+ - yahoo_answers_topics
12
+ - search_qa
13
+ - eli5
14
+ - natural_questions
15
+ - trivia_qa
16
+ - embedding-data/QQP
17
+ - embedding-data/PAQ_pairs
18
+ - embedding-data/Amazon-QA
19
+ - embedding-data/WikiAnswers
20
  ---
21
+
22
+ # multi-qa-MiniLM-L6-cos-v1
23
+ This is a [sentence-transformers](https://www.SBERT.net) model: It maps sentences & paragraphs to a 384 dimensional dense vector space and was designed for **semantic search**. It has been trained on 215M (question, answer) pairs from diverse sources. For an introduction to semantic search, have a look at: [SBERT.net - Semantic Search](https://www.sbert.net/examples/applications/semantic-search/README.html)
24
+
25
+
26
+ ## Usage (Sentence-Transformers)
27
+ Using this model becomes easy when you have [sentence-transformers](https://www.SBERT.net) installed:
28
+
29
+ ```
30
+ pip install -U sentence-transformers
31
+ ```
32
+
33
+ Then you can use the model like this:
34
+ ```python
35
+ from sentence_transformers import SentenceTransformer, util
36
+
37
+ query = "How many people live in London?"
38
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
39
+
40
+ #Load the model
41
+ model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')
42
+
43
+ #Encode query and documents
44
+ query_emb = model.encode(query)
45
+ doc_emb = model.encode(docs)
46
+
47
+ #Compute dot score between query and all document embeddings
48
+ scores = util.dot_score(query_emb, doc_emb)[0].cpu().tolist()
49
+
50
+ #Combine docs & scores
51
+ doc_score_pairs = list(zip(docs, scores))
52
+
53
+ #Sort by decreasing score
54
+ doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
55
+
56
+ #Output passages & scores
57
+ for doc, score in doc_score_pairs:
58
+ print(score, doc)
59
+ ```
60
+
61
+
62
+ ## PyTorch Usage (HuggingFace Transformers)
63
+ Without [sentence-transformers](https://www.SBERT.net), you can use the model like this: First, you pass your input through the transformer model, then you have to apply the correct pooling-operation on-top of the contextualized word embeddings.
64
+
65
+ ```python
66
+ from transformers import AutoTokenizer, AutoModel
67
+ import torch
68
+ import torch.nn.functional as F
69
+
70
+ #Mean Pooling - Take average of all tokens
71
+ def mean_pooling(model_output, attention_mask):
72
+ token_embeddings = model_output.last_hidden_state
73
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
74
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
75
+
76
+
77
+ #Encode text
78
+ def encode(texts):
79
+ # Tokenize sentences
80
+ encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
81
+
82
+ # Compute token embeddings
83
+ with torch.no_grad():
84
+ model_output = model(**encoded_input, return_dict=True)
85
+
86
+ # Perform pooling
87
+ embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
88
+
89
+ # Normalize embeddings
90
+ embeddings = F.normalize(embeddings, p=2, dim=1)
91
+
92
+ return embeddings
93
+
94
+
95
+ # Sentences we want sentence embeddings for
96
+ query = "How many people live in London?"
97
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
98
+
99
+ # Load model from HuggingFace Hub
100
+ tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
101
+ model = AutoModel.from_pretrained("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
102
+
103
+ #Encode query and docs
104
+ query_emb = encode(query)
105
+ doc_emb = encode(docs)
106
+
107
+ #Compute dot score between query and all document embeddings
108
+ scores = torch.mm(query_emb, doc_emb.transpose(0, 1))[0].cpu().tolist()
109
+
110
+ #Combine docs & scores
111
+ doc_score_pairs = list(zip(docs, scores))
112
+
113
+ #Sort by decreasing score
114
+ doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
115
+
116
+ #Output passages & scores
117
+ for doc, score in doc_score_pairs:
118
+ print(score, doc)
119
+ ```
120
+
121
+ ## TensorFlow Usage (HuggingFace Transformers)
122
+ Similarly to the PyTorch example above, to use the model with TensorFlow you pass your input through the transformer model, then you have to apply the correct pooling-operation on-top of the contextualized word embeddings.
123
+
124
+ ```python
125
+ from transformers import AutoTokenizer, TFAutoModel
126
+ import tensorflow as tf
127
+
128
+ #Mean Pooling - Take attention mask into account for correct averaging
129
+ def mean_pooling(model_output, attention_mask):
130
+ token_embeddings = model_output.last_hidden_state
131
+ input_mask_expanded = tf.cast(tf.tile(tf.expand_dims(attention_mask, -1), [1, 1, token_embeddings.shape[-1]]), tf.float32)
132
+ return tf.math.reduce_sum(token_embeddings * input_mask_expanded, 1) / tf.math.maximum(tf.math.reduce_sum(input_mask_expanded, 1), 1e-9)
133
+
134
+
135
+ #Encode text
136
+ def encode(texts):
137
+ # Tokenize sentences
138
+ encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='tf')
139
+
140
+ # Compute token embeddings
141
+ model_output = model(**encoded_input, return_dict=True)
142
+
143
+ # Perform pooling
144
+ embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
145
+
146
+ # Normalize embeddings
147
+ embeddings = tf.math.l2_normalize(embeddings, axis=1)
148
+
149
+ return embeddings
150
+
151
+
152
+ # Sentences we want sentence embeddings for
153
+ query = "How many people live in London?"
154
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
155
+
156
+ # Load model from HuggingFace Hub
157
+ tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
158
+ model = TFAutoModel.from_pretrained("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
159
+
160
+ #Encode query and docs
161
+ query_emb = encode(query)
162
+ doc_emb = encode(docs)
163
+
164
+ #Compute dot score between query and all document embeddings
165
+ scores = (query_emb @ tf.transpose(doc_emb))[0].numpy().tolist()
166
+
167
+ #Combine docs & scores
168
+ doc_score_pairs = list(zip(docs, scores))
169
+
170
+ #Sort by decreasing score
171
+ doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
172
+
173
+ #Output passages & scores
174
+ for doc, score in doc_score_pairs:
175
+ print(score, doc)
176
+ ```
177
+
178
+ ## Technical Details
179
+
180
+ The following table lists some technical details on how this model must be used:
181
+
182
+ | Setting | Value |
183
+ | --- | :---: |
184
+ | Dimensions | 384 |
185
+ | Produces normalized embeddings | Yes |
186
+ | Pooling-Method | Mean pooling |
187
+ | Suitable score functions | dot-product (`util.dot_score`), cosine-similarity (`util.cos_sim`), or euclidean distance |
188
+
189
+ Note: When loaded with `sentence-transformers`, this model produces normalized embeddings with length 1. In that case, dot-product and cosine-similarity are equivalent. dot-product is preferred as it is faster. Euclidean distance is proportional to dot-product and can also be used.
190
+
191
+ ----
192
+
193
+
194
+ ## Background
195
+
196
+ The project aims to train sentence embedding models on very large sentence level datasets using a self-supervised
197
+ contrastive learning objective. We use a contrastive learning objective: given a sentence from the pair, the model should predict which out of a set of randomly sampled other sentences, was actually paired with it in our dataset.
198
+
199
+ We developed this model during the
200
+ [Community week using JAX/Flax for NLP & CV](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104),
201
+ organized by Hugging Face. We developed this model as part of the project:
202
+ [Train the Best Sentence Embedding Model Ever with 1B Training Pairs](https://discuss.huggingface.co/t/train-the-best-sentence-embedding-model-ever-with-1b-training-pairs/7354). We benefited from efficient hardware infrastructure to run the project: 7 TPUs v3-8, as well as intervention from Google's Flax, JAX, and Cloud team members about efficient deep learning frameworks.
203
+
204
+ ## Intended uses
205
+
206
+ Our model is intended to be used for semantic search: It encodes queries / questions and text paragraphs in a dense vector space. It finds relevant documents for the given passages.
207
+
208
+ Note that there is a limit of 512 word pieces: Text longer than that will be truncated. Further note that the model was just trained on input text up to 250 word pieces. It might not work well for longer text.
209
+
210
+
211
+
212
+ ## Training procedure
213
+
214
+ The full training script is accessible in this current repository: `train_script.py`.
215
+
216
+ ### Pre-training
217
+
218
+ We use the pretrained [`nreimers/MiniLM-L6-H384-uncased`](https://huggingface.co/nreimers/MiniLM-L6-H384-uncased) model. Please refer to the model card for more detailed information about the pre-training procedure.
219
+
220
+ #### Training
221
+
222
+ We use the concatenation from multiple datasets to fine-tune our model. In total we have about 215M (question, answer) pairs.
223
+ We sampled each dataset according to a weighted probability; the configuration is detailed in the `data_config.json` file.
224
+
225
+ The model was trained with [MultipleNegativesRankingLoss](https://www.sbert.net/docs/package_reference/losses.html#multiplenegativesrankingloss) using Mean-pooling, cosine-similarity as similarity function, and a scale of 20.
226
+
227
+
228
+
229
+
230
+ | Dataset | Number of training tuples |
231
+ |--------------------------------------------------------|:--------------------------:|
232
+ | [WikiAnswers](https://github.com/afader/oqa#wikianswers-corpus) Duplicate question pairs from WikiAnswers | 77,427,422 |
233
+ | [PAQ](https://github.com/facebookresearch/PAQ) Automatically generated (Question, Paragraph) pairs for each paragraph in Wikipedia | 64,371,441 |
234
+ | [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Body) pairs from all StackExchanges | 25,316,456 |
235
+ | [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Answer) pairs from all StackExchanges | 21,396,559 |
236
+ | [MS MARCO](https://microsoft.github.io/msmarco/) Triplets (query, answer, hard_negative) for 500k queries from Bing search engine | 17,579,773 |
237
+ | [GOOAQ: Open Question Answering with Diverse Answer Types](https://github.com/allenai/gooaq) (query, answer) pairs for 3M Google queries and Google featured snippet | 3,012,496 |
238
+ | [Amazon-QA](http://jmcauley.ucsd.edu/data/amazon/qa/) (Question, Answer) pairs from Amazon product pages | 2,448,839 |
239
+ | [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Answer) pairs from Yahoo Answers | 1,198,260 |
240
+ | [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Question, Answer) pairs from Yahoo Answers | 681,164 |
241
+ | [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Question) pairs from Yahoo Answers | 659,896 |
242
+ | [SearchQA](https://huggingface.co/datasets/search_qa) (Question, Answer) pairs for 140k questions, each with Top5 Google snippets on that question | 582,261 |
243
+ | [ELI5](https://huggingface.co/datasets/eli5) (Question, Answer) pairs from Reddit ELI5 (explainlikeimfive) | 325,475 |
244
+ | [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions pairs (titles) | 304,525 |
245
+ | [Quora Question Triplets](https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs) (Question, Duplicate_Question, Hard_Negative) triplets for Quora Questions Pairs dataset | 103,663 |
246
+ | [Natural Questions (NQ)](https://ai.google.com/research/NaturalQuestions) (Question, Paragraph) pairs for 100k real Google queries with relevant Wikipedia paragraph | 100,231 |
247
+ | [SQuAD2.0](https://rajpurkar.github.io/SQuAD-explorer/) (Question, Paragraph) pairs from SQuAD2.0 dataset | 87,599 |
248
+ | [TriviaQA](https://huggingface.co/datasets/trivia_qa) (Question, Evidence) pairs | 73,346 |
249
+ | **Total** | **214,988,242** |
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "nreimers/MiniLM-L6-H384-uncased",
3
+ "architectures": [
4
+ "BertModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "gradient_checkpointing": false,
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.1,
10
+ "hidden_size": 384,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 1536,
13
+ "layer_norm_eps": 1e-12,
14
+ "max_position_embeddings": 512,
15
+ "model_type": "bert",
16
+ "num_attention_heads": 12,
17
+ "num_hidden_layers": 6,
18
+ "pad_token_id": 0,
19
+ "position_embedding_type": "absolute",
20
+ "transformers_version": "4.8.2",
21
+ "type_vocab_size": 2,
22
+ "use_cache": true,
23
+ "vocab_size": 30522
24
+ }
config_sentence_transformers.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "2.0.0",
4
+ "transformers": "4.6.1",
5
+ "pytorch": "1.8.1"
6
+ }
7
+ }
data_config.json ADDED
@@ -0,0 +1,942 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "name": "stackexchange_title_body/skeptics.stackexchange.com.jsonl.gz",
4
+ "lines": 10009,
5
+ "weight": 3
6
+ },
7
+ {
8
+ "name": "stackexchange_Title_Answer/islam.stackexchange.com.jsonl.gz",
9
+ "lines": 10052,
10
+ "weight": 3
11
+ },
12
+ {
13
+ "name": "stackexchange_Title_Answer/anime.stackexchange.com.jsonl.gz",
14
+ "lines": 10131,
15
+ "weight": 3
16
+ },
17
+ {
18
+ "name": "stackexchange_title_body/writers.stackexchange.com.jsonl.gz",
19
+ "lines": 10157,
20
+ "weight": 3
21
+ },
22
+ {
23
+ "name": "stackexchange_title_body/astronomy.stackexchange.com.jsonl.gz",
24
+ "lines": 10462,
25
+ "weight": 3
26
+ },
27
+ {
28
+ "name": "stackexchange_title_body/vi.stackexchange.com.jsonl.gz",
29
+ "lines": 10551,
30
+ "weight": 3
31
+ },
32
+ {
33
+ "name": "stackexchange_Title_Answer/french.stackexchange.com.jsonl.gz",
34
+ "lines": 10578,
35
+ "weight": 3
36
+ },
37
+ {
38
+ "name": "stackexchange_title_body/cstheory.stackexchange.com.jsonl.gz",
39
+ "lines": 10642,
40
+ "weight": 3
41
+ },
42
+ {
43
+ "name": "stackexchange_Title_Answer/civicrm.stackexchange.com.jsonl.gz",
44
+ "lines": 10648,
45
+ "weight": 3
46
+ },
47
+ {
48
+ "name": "stackexchange_Title_Answer/expressionengine.stackexchange.com.jsonl.gz",
49
+ "lines": 10742,
50
+ "weight": 3
51
+ },
52
+ {
53
+ "name": "stackexchange_title_body/engineering.stackexchange.com.jsonl.gz",
54
+ "lines": 10753,
55
+ "weight": 3
56
+ },
57
+ {
58
+ "name": "stackexchange_Title_Answer/history.stackexchange.com.jsonl.gz",
59
+ "lines": 10766,
60
+ "weight": 3
61
+ },
62
+ {
63
+ "name": "stackexchange_title_body/french.stackexchange.com.jsonl.gz",
64
+ "lines": 10794,
65
+ "weight": 3
66
+ },
67
+ {
68
+ "name": "stackexchange_Title_Answer/politics.stackexchange.com.jsonl.gz",
69
+ "lines": 11047,
70
+ "weight": 3
71
+ },
72
+ {
73
+ "name": "stackexchange_title_body/economics.stackexchange.com.jsonl.gz",
74
+ "lines": 11115,
75
+ "weight": 3
76
+ },
77
+ {
78
+ "name": "stackexchange_Title_Answer/craftcms.stackexchange.com.jsonl.gz",
79
+ "lines": 11236,
80
+ "weight": 3
81
+ },
82
+ {
83
+ "name": "stackexchange_title_body/anime.stackexchange.com.jsonl.gz",
84
+ "lines": 11444,
85
+ "weight": 3
86
+ },
87
+ {
88
+ "name": "stackexchange_Title_Answer/christianity.stackexchange.com.jsonl.gz",
89
+ "lines": 11498,
90
+ "weight": 3
91
+ },
92
+ {
93
+ "name": "stackexchange_Title_Answer/softwarerecs.stackexchange.com.jsonl.gz",
94
+ "lines": 11761,
95
+ "weight": 3
96
+ },
97
+ {
98
+ "name": "stackexchange_Title_Answer/boardgames.stackexchange.com.jsonl.gz",
99
+ "lines": 11805,
100
+ "weight": 3
101
+ },
102
+ {
103
+ "name": "stackexchange_title_body/islam.stackexchange.com.jsonl.gz",
104
+ "lines": 11853,
105
+ "weight": 3
106
+ },
107
+ {
108
+ "name": "stackexchange_title_body/expressionengine.stackexchange.com.jsonl.gz",
109
+ "lines": 11866,
110
+ "weight": 3
111
+ },
112
+ {
113
+ "name": "stackexchange_title_body/politics.stackexchange.com.jsonl.gz",
114
+ "lines": 11894,
115
+ "weight": 3
116
+ },
117
+ {
118
+ "name": "stackexchange_title_body/history.stackexchange.com.jsonl.gz",
119
+ "lines": 12021,
120
+ "weight": 3
121
+ },
122
+ {
123
+ "name": "stackexchange_title_body/christianity.stackexchange.com.jsonl.gz",
124
+ "lines": 12108,
125
+ "weight": 3
126
+ },
127
+ {
128
+ "name": "stackexchange_title_body/boardgames.stackexchange.com.jsonl.gz",
129
+ "lines": 12149,
130
+ "weight": 3
131
+ },
132
+ {
133
+ "name": "stackexchange_title_body/civicrm.stackexchange.com.jsonl.gz",
134
+ "lines": 12543,
135
+ "weight": 3
136
+ },
137
+ {
138
+ "name": "stackexchange_title_body/craftcms.stackexchange.com.jsonl.gz",
139
+ "lines": 12574,
140
+ "weight": 3
141
+ },
142
+ {
143
+ "name": "stackexchange_Title_Answer/networkengineering.stackexchange.com.jsonl.gz",
144
+ "lines": 12590,
145
+ "weight": 3
146
+ },
147
+ {
148
+ "name": "stackexchange_Title_Answer/space.stackexchange.com.jsonl.gz",
149
+ "lines": 12893,
150
+ "weight": 3
151
+ },
152
+ {
153
+ "name": "stackexchange_Title_Answer/quant.stackexchange.com.jsonl.gz",
154
+ "lines": 12933,
155
+ "weight": 3
156
+ },
157
+ {
158
+ "name": "stackexchange_Title_Answer/philosophy.stackexchange.com.jsonl.gz",
159
+ "lines": 13114,
160
+ "weight": 3
161
+ },
162
+ {
163
+ "name": "stackexchange_Title_Answer/gardening.stackexchange.com.jsonl.gz",
164
+ "lines": 13246,
165
+ "weight": 3
166
+ },
167
+ {
168
+ "name": "stackexchange_title_body/hinduism.stackexchange.com.jsonl.gz",
169
+ "lines": 13450,
170
+ "weight": 4
171
+ },
172
+ {
173
+ "name": "stackexchange_title_body/networkengineering.stackexchange.com.jsonl.gz",
174
+ "lines": 13454,
175
+ "weight": 4
176
+ },
177
+ {
178
+ "name": "stackexchange_Title_Answer/german.stackexchange.com.jsonl.gz",
179
+ "lines": 13733,
180
+ "weight": 4
181
+ },
182
+ {
183
+ "name": "stackexchange_title_body/german.stackexchange.com.jsonl.gz",
184
+ "lines": 13950,
185
+ "weight": 4
186
+ },
187
+ {
188
+ "name": "stackexchange_title_body/philosophy.stackexchange.com.jsonl.gz",
189
+ "lines": 14829,
190
+ "weight": 4
191
+ },
192
+ {
193
+ "name": "stackexchange_title_body/gardening.stackexchange.com.jsonl.gz",
194
+ "lines": 15136,
195
+ "weight": 4
196
+ },
197
+ {
198
+ "name": "stackexchange_title_body/space.stackexchange.com.jsonl.gz",
199
+ "lines": 15142,
200
+ "weight": 4
201
+ },
202
+ {
203
+ "name": "stackexchange_Title_Answer/bicycles.stackexchange.com.jsonl.gz",
204
+ "lines": 15708,
205
+ "weight": 4
206
+ },
207
+ {
208
+ "name": "stackexchange_Title_Answer/law.stackexchange.com.jsonl.gz",
209
+ "lines": 16133,
210
+ "weight": 4
211
+ },
212
+ {
213
+ "name": "stackexchange_Title_Answer/arduino.stackexchange.com.jsonl.gz",
214
+ "lines": 16281,
215
+ "weight": 4
216
+ },
217
+ {
218
+ "name": "stackexchange_title_body/bicycles.stackexchange.com.jsonl.gz",
219
+ "lines": 16353,
220
+ "weight": 4
221
+ },
222
+ {
223
+ "name": "stackexchange_Title_Answer/emacs.stackexchange.com.jsonl.gz",
224
+ "lines": 16830,
225
+ "weight": 4
226
+ },
227
+ {
228
+ "name": "stackexchange_title_body/quant.stackexchange.com.jsonl.gz",
229
+ "lines": 17261,
230
+ "weight": 4
231
+ },
232
+ {
233
+ "name": "stackexchange_Title_Answer/dsp.stackexchange.com.jsonl.gz",
234
+ "lines": 17430,
235
+ "weight": 4
236
+ },
237
+ {
238
+ "name": "stackexchange_Title_Answer/puzzling.stackexchange.com.jsonl.gz",
239
+ "lines": 17448,
240
+ "weight": 4
241
+ },
242
+ {
243
+ "name": "stackexchange_title_body/puzzling.stackexchange.com.jsonl.gz",
244
+ "lines": 17851,
245
+ "weight": 5
246
+ },
247
+ {
248
+ "name": "stackexchange_title_body/law.stackexchange.com.jsonl.gz",
249
+ "lines": 17941,
250
+ "weight": 5
251
+ },
252
+ {
253
+ "name": "stackexchange_Title_Answer/movies.stackexchange.com.jsonl.gz",
254
+ "lines": 18243,
255
+ "weight": 5
256
+ },
257
+ {
258
+ "name": "stackexchange_Title_Answer/mechanics.stackexchange.com.jsonl.gz",
259
+ "lines": 18613,
260
+ "weight": 5
261
+ },
262
+ {
263
+ "name": "stackexchange_Title_Answer/aviation.stackexchange.com.jsonl.gz",
264
+ "lines": 18755,
265
+ "weight": 5
266
+ },
267
+ {
268
+ "name": "stackexchange_Title_Answer/biology.stackexchange.com.jsonl.gz",
269
+ "lines": 19277,
270
+ "weight": 5
271
+ },
272
+ {
273
+ "name": "stackexchange_Title_Answer/crypto.stackexchange.com.jsonl.gz",
274
+ "lines": 19404,
275
+ "weight": 5
276
+ },
277
+ {
278
+ "name": "stackexchange_title_body/arduino.stackexchange.com.jsonl.gz",
279
+ "lines": 19553,
280
+ "weight": 5
281
+ },
282
+ {
283
+ "name": "stackexchange_Title_Answer/music.stackexchange.com.jsonl.gz",
284
+ "lines": 19936,
285
+ "weight": 5
286
+ },
287
+ {
288
+ "name": "stackexchange_title_body/aviation.stackexchange.com.jsonl.gz",
289
+ "lines": 20139,
290
+ "weight": 5
291
+ },
292
+ {
293
+ "name": "stackexchange_title_body/softwarerecs.stackexchange.com.jsonl.gz",
294
+ "lines": 20142,
295
+ "weight": 5
296
+ },
297
+ {
298
+ "name": "stackexchange_title_body/movies.stackexchange.com.jsonl.gz",
299
+ "lines": 20181,
300
+ "weight": 5
301
+ },
302
+ {
303
+ "name": "stackexchange_Title_Answer/datascience.stackexchange.com.jsonl.gz",
304
+ "lines": 20503,
305
+ "weight": 5
306
+ },
307
+ {
308
+ "name": "stackexchange_title_body/music.stackexchange.com.jsonl.gz",
309
+ "lines": 20636,
310
+ "weight": 5
311
+ },
312
+ {
313
+ "name": "stackexchange_Title_Answer/japanese.stackexchange.com.jsonl.gz",
314
+ "lines": 20948,
315
+ "weight": 5
316
+ },
317
+ {
318
+ "name": "stackexchange_title_body/emacs.stackexchange.com.jsonl.gz",
319
+ "lines": 21055,
320
+ "weight": 5
321
+ },
322
+ {
323
+ "name": "stackexchange_title_body/dsp.stackexchange.com.jsonl.gz",
324
+ "lines": 21252,
325
+ "weight": 5
326
+ },
327
+ {
328
+ "name": "stackexchange_title_body/japanese.stackexchange.com.jsonl.gz",
329
+ "lines": 22056,
330
+ "weight": 5
331
+ },
332
+ {
333
+ "name": "stackexchange_Title_Answer/bitcoin.stackexchange.com.jsonl.gz",
334
+ "lines": 22474,
335
+ "weight": 6
336
+ },
337
+ {
338
+ "name": "stackexchange_Title_Answer/cooking.stackexchange.com.jsonl.gz",
339
+ "lines": 22641,
340
+ "weight": 6
341
+ },
342
+ {
343
+ "name": "stackexchange_title_body/mechanics.stackexchange.com.jsonl.gz",
344
+ "lines": 22868,
345
+ "weight": 6
346
+ },
347
+ {
348
+ "name": "stackexchange_Title_Answer/photo.stackexchange.com.jsonl.gz",
349
+ "lines": 23204,
350
+ "weight": 6
351
+ },
352
+ {
353
+ "name": "stackexchange_title_body/crypto.stackexchange.com.jsonl.gz",
354
+ "lines": 23231,
355
+ "weight": 6
356
+ },
357
+ {
358
+ "name": "stackexchange_title_body/cooking.stackexchange.com.jsonl.gz",
359
+ "lines": 23705,
360
+ "weight": 6
361
+ },
362
+ {
363
+ "name": "stackexchange_title_body/photo.stackexchange.com.jsonl.gz",
364
+ "lines": 23753,
365
+ "weight": 6
366
+ },
367
+ {
368
+ "name": "stackexchange_Title_Answer/workplace.stackexchange.com.jsonl.gz",
369
+ "lines": 24012,
370
+ "weight": 6
371
+ },
372
+ {
373
+ "name": "stackexchange_Title_Answer/meta.stackoverflow.com.jsonl.gz",
374
+ "lines": 24044,
375
+ "weight": 6
376
+ },
377
+ {
378
+ "name": "stackexchange_Title_Answer/raspberrypi.stackexchange.com.jsonl.gz",
379
+ "lines": 24143,
380
+ "weight": 6
381
+ },
382
+ {
383
+ "name": "stackexchange_title_body/workplace.stackexchange.com.jsonl.gz",
384
+ "lines": 24189,
385
+ "weight": 6
386
+ },
387
+ {
388
+ "name": "stackexchange_title_body/biology.stackexchange.com.jsonl.gz",
389
+ "lines": 24447,
390
+ "weight": 6
391
+ },
392
+ {
393
+ "name": "stackexchange_Title_Answer/webapps.stackexchange.com.jsonl.gz",
394
+ "lines": 24867,
395
+ "weight": 6
396
+ },
397
+ {
398
+ "name": "stackexchange_title_body/bitcoin.stackexchange.com.jsonl.gz",
399
+ "lines": 25374,
400
+ "weight": 6
401
+ },
402
+ {
403
+ "name": "stackexchange_Title_Answer/judaism.stackexchange.com.jsonl.gz",
404
+ "lines": 26085,
405
+ "weight": 6
406
+ },
407
+ {
408
+ "name": "stackexchange_Title_Answer/ethereum.stackexchange.com.jsonl.gz",
409
+ "lines": 26124,
410
+ "weight": 6
411
+ },
412
+ {
413
+ "name": "stackexchange_Title_Answer/worldbuilding.stackexchange.com.jsonl.gz",
414
+ "lines": 26210,
415
+ "weight": 6
416
+ },
417
+ {
418
+ "name": "stackexchange_title_body/worldbuilding.stackexchange.com.jsonl.gz",
419
+ "lines": 26763,
420
+ "weight": 7
421
+ },
422
+ {
423
+ "name": "stackexchange_Title_Answer/chemistry.stackexchange.com.jsonl.gz",
424
+ "lines": 27061,
425
+ "weight": 7
426
+ },
427
+ {
428
+ "name": "stackexchange_title_body/datascience.stackexchange.com.jsonl.gz",
429
+ "lines": 27397,
430
+ "weight": 7
431
+ },
432
+ {
433
+ "name": "stackexchange_Title_Answer/graphicdesign.stackexchange.com.jsonl.gz",
434
+ "lines": 28083,
435
+ "weight": 7
436
+ },
437
+ {
438
+ "name": "stackexchange_Title_Answer/ux.stackexchange.com.jsonl.gz",
439
+ "lines": 28901,
440
+ "weight": 7
441
+ },
442
+ {
443
+ "name": "stackexchange_title_body/ux.stackexchange.com.jsonl.gz",
444
+ "lines": 29403,
445
+ "weight": 7
446
+ },
447
+ {
448
+ "name": "stackexchange_Title_Answer/money.stackexchange.com.jsonl.gz",
449
+ "lines": 29404,
450
+ "weight": 7
451
+ },
452
+ {
453
+ "name": "stackexchange_title_body/webapps.stackexchange.com.jsonl.gz",
454
+ "lines": 29697,
455
+ "weight": 7
456
+ },
457
+ {
458
+ "name": "stackexchange_Title_Answer/cs.stackexchange.com.jsonl.gz",
459
+ "lines": 30010,
460
+ "weight": 7
461
+ },
462
+ {
463
+ "name": "stackexchange_title_body/graphicdesign.stackexchange.com.jsonl.gz",
464
+ "lines": 30233,
465
+ "weight": 7
466
+ },
467
+ {
468
+ "name": "stackexchange_Title_Answer/webmasters.stackexchange.com.jsonl.gz",
469
+ "lines": 30370,
470
+ "weight": 7
471
+ },
472
+ {
473
+ "name": "stackexchange_title_body/raspberrypi.stackexchange.com.jsonl.gz",
474
+ "lines": 30625,
475
+ "weight": 7
476
+ },
477
+ {
478
+ "name": "stackexchange_title_body/money.stackexchange.com.jsonl.gz",
479
+ "lines": 32021,
480
+ "weight": 8
481
+ },
482
+ {
483
+ "name": "stackexchange_title_body/judaism.stackexchange.com.jsonl.gz",
484
+ "lines": 32028,
485
+ "weight": 8
486
+ },
487
+ {
488
+ "name": "stackexchange_Title_Answer/academia.stackexchange.com.jsonl.gz",
489
+ "lines": 32137,
490
+ "weight": 8
491
+ },
492
+ {
493
+ "name": "stackexchange_title_body/ethereum.stackexchange.com.jsonl.gz",
494
+ "lines": 32760,
495
+ "weight": 8
496
+ },
497
+ {
498
+ "name": "stackexchange_title_body/academia.stackexchange.com.jsonl.gz",
499
+ "lines": 34331,
500
+ "weight": 8
501
+ },
502
+ {
503
+ "name": "stackexchange_title_body/chemistry.stackexchange.com.jsonl.gz",
504
+ "lines": 34506,
505
+ "weight": 8
506
+ },
507
+ {
508
+ "name": "stackexchange_title_body/webmasters.stackexchange.com.jsonl.gz",
509
+ "lines": 34559,
510
+ "weight": 8
511
+ },
512
+ {
513
+ "name": "stackexchange_title_body/meta.stackoverflow.com.jsonl.gz",
514
+ "lines": 36456,
515
+ "weight": 9
516
+ },
517
+ {
518
+ "name": "stackexchange_Title_Answer/travel.stackexchange.com.jsonl.gz",
519
+ "lines": 36533,
520
+ "weight": 9
521
+ },
522
+ {
523
+ "name": "stackexchange_Title_Answer/android.stackexchange.com.jsonl.gz",
524
+ "lines": 38077,
525
+ "weight": 9
526
+ },
527
+ {
528
+ "name": "stackexchange_title_body/cs.stackexchange.com.jsonl.gz",
529
+ "lines": 38314,
530
+ "weight": 9
531
+ },
532
+ {
533
+ "name": "stackexchange_Title_Answer/gamedev.stackexchange.com.jsonl.gz",
534
+ "lines": 40154,
535
+ "weight": 10
536
+ },
537
+ {
538
+ "name": "stackexchange_Title_Answer/rpg.stackexchange.com.jsonl.gz",
539
+ "lines": 40435,
540
+ "weight": 10
541
+ },
542
+ {
543
+ "name": "stackexchange_title_body/travel.stackexchange.com.jsonl.gz",
544
+ "lines": 41227,
545
+ "weight": 10
546
+ },
547
+ {
548
+ "name": "stackexchange_Title_Answer/codereview.stackexchange.com.jsonl.gz",
549
+ "lines": 41748,
550
+ "weight": 10
551
+ },
552
+ {
553
+ "name": "stackexchange_title_body/rpg.stackexchange.com.jsonl.gz",
554
+ "lines": 42303,
555
+ "weight": 10
556
+ },
557
+ {
558
+ "name": "stackexchange_title_body/codereview.stackexchange.com.jsonl.gz",
559
+ "lines": 45765,
560
+ "weight": 11
561
+ },
562
+ {
563
+ "name": "stackexchange_title_body/gamedev.stackexchange.com.jsonl.gz",
564
+ "lines": 46485,
565
+ "weight": 11
566
+ },
567
+ {
568
+ "name": "stackexchange_Title_Answer/softwareengineering.stackexchange.com.jsonl.gz",
569
+ "lines": 51326,
570
+ "weight": 12
571
+ },
572
+ {
573
+ "name": "stackexchange_Title_Answer/security.stackexchange.com.jsonl.gz",
574
+ "lines": 51355,
575
+ "weight": 12
576
+ },
577
+ {
578
+ "name": "stackexchange_title_body/android.stackexchange.com.jsonl.gz",
579
+ "lines": 51608,
580
+ "weight": 12
581
+ },
582
+ {
583
+ "name": "stackexchange_Title_Answer/diy.stackexchange.com.jsonl.gz",
584
+ "lines": 52896,
585
+ "weight": 12
586
+ },
587
+ {
588
+ "name": "stackexchange_title_body/softwareengineering.stackexchange.com.jsonl.gz",
589
+ "lines": 53942,
590
+ "weight": 13
591
+ },
592
+ {
593
+ "name": "stackexchange_Title_Answer/blender.stackexchange.com.jsonl.gz",
594
+ "lines": 54153,
595
+ "weight": 13
596
+ },
597
+ {
598
+ "name": "stackexchange_Title_Answer/scifi.stackexchange.com.jsonl.gz",
599
+ "lines": 54805,
600
+ "weight": 13
601
+ },
602
+ {
603
+ "name": "stackexchange_title_body/security.stackexchange.com.jsonl.gz",
604
+ "lines": 58000,
605
+ "weight": 14
606
+ },
607
+ {
608
+ "name": "stackexchange_Title_Answer/mathematica.stackexchange.com.jsonl.gz",
609
+ "lines": 59895,
610
+ "weight": 14
611
+ },
612
+ {
613
+ "name": "stackexchange_title_body/diy.stackexchange.com.jsonl.gz",
614
+ "lines": 60083,
615
+ "weight": 14
616
+ },
617
+ {
618
+ "name": "stackexchange_Title_Answer/meta.stackexchange.com.jsonl.gz",
619
+ "lines": 60744,
620
+ "weight": 14
621
+ },
622
+ {
623
+ "name": "stackexchange_title_body/scifi.stackexchange.com.jsonl.gz",
624
+ "lines": 61528,
625
+ "weight": 14
626
+ },
627
+ {
628
+ "name": "stackexchange_Title_Answer/drupal.stackexchange.com.jsonl.gz",
629
+ "lines": 67817,
630
+ "weight": 16
631
+ },
632
+ {
633
+ "name": "stackexchange_Title_Answer/dba.stackexchange.com.jsonl.gz",
634
+ "lines": 71449,
635
+ "weight": 17
636
+ },
637
+ {
638
+ "name": "stackexchange_title_body/mathematica.stackexchange.com.jsonl.gz",
639
+ "lines": 73131,
640
+ "weight": 17
641
+ },
642
+ {
643
+ "name": "stackexchange_Title_Answer/ell.stackexchange.com.jsonl.gz",
644
+ "lines": 77892,
645
+ "weight": 18
646
+ },
647
+ {
648
+ "name": "stackexchange_Title_Answer/magento.stackexchange.com.jsonl.gz",
649
+ "lines": 79241,
650
+ "weight": 18
651
+ },
652
+ {
653
+ "name": "stackexchange_title_body/drupal.stackexchange.com.jsonl.gz",
654
+ "lines": 79717,
655
+ "weight": 18
656
+ },
657
+ {
658
+ "name": "stackexchange_Title_Answer/sharepoint.stackexchange.com.jsonl.gz",
659
+ "lines": 80420,
660
+ "weight": 19
661
+ },
662
+ {
663
+ "name": "stackexchange_title_body/blender.stackexchange.com.jsonl.gz",
664
+ "lines": 80766,
665
+ "weight": 19
666
+ },
667
+ {
668
+ "name": "stackexchange_title_body/dba.stackexchange.com.jsonl.gz",
669
+ "lines": 81871,
670
+ "weight": 19
671
+ },
672
+ {
673
+ "name": "stackexchange_Title_Answer/gaming.stackexchange.com.jsonl.gz",
674
+ "lines": 82887,
675
+ "weight": 19
676
+ },
677
+ {
678
+ "name": "stackexchange_title_body/ell.stackexchange.com.jsonl.gz",
679
+ "lines": 83271,
680
+ "weight": 19
681
+ },
682
+ {
683
+ "name": "stackexchange_title_body/meta.stackexchange.com.jsonl.gz",
684
+ "lines": 83510,
685
+ "weight": 19
686
+ },
687
+ {
688
+ "name": "stackexchange_Title_Answer/wordpress.stackexchange.com.jsonl.gz",
689
+ "lines": 83621,
690
+ "weight": 19
691
+ },
692
+ {
693
+ "name": "stackexchange_Title_Answer/mathoverflow.net.jsonl.gz",
694
+ "lines": 85289,
695
+ "weight": 20
696
+ },
697
+ {
698
+ "name": "stackexchange_Title_Answer/salesforce.stackexchange.com.jsonl.gz",
699
+ "lines": 87272,
700
+ "weight": 20
701
+ },
702
+ {
703
+ "name": "stackexchange_title_body/gaming.stackexchange.com.jsonl.gz",
704
+ "lines": 88912,
705
+ "weight": 21
706
+ },
707
+ {
708
+ "name": "stackexchange_Title_Answer/apple.stackexchange.com.jsonl.gz",
709
+ "lines": 92487,
710
+ "weight": 21
711
+ },
712
+ {
713
+ "name": "stackexchange_title_body/sharepoint.stackexchange.com.jsonl.gz",
714
+ "lines": 94011,
715
+ "weight": 22
716
+ },
717
+ {
718
+ "name": "stackexchange_title_body/magento.stackexchange.com.jsonl.gz",
719
+ "lines": 99991,
720
+ "weight": 23
721
+ },
722
+ {
723
+ "name": "stackexchange_Title_Answer/gis.stackexchange.com.jsonl.gz",
724
+ "lines": 100254,
725
+ "weight": 23
726
+ },
727
+ {
728
+ "name": "stackexchange_title_body/wordpress.stackexchange.com.jsonl.gz",
729
+ "lines": 100474,
730
+ "weight": 23
731
+ },
732
+ {
733
+ "name": "stackexchange_Title_Answer/english.stackexchange.com.jsonl.gz",
734
+ "lines": 100640,
735
+ "weight": 23
736
+ },
737
+ {
738
+ "name": "stackexchange_title_body/salesforce.stackexchange.com.jsonl.gz",
739
+ "lines": 105260,
740
+ "weight": 24
741
+ },
742
+ {
743
+ "name": "stackexchange_title_body/english.stackexchange.com.jsonl.gz",
744
+ "lines": 109522,
745
+ "weight": 25
746
+ },
747
+ {
748
+ "name": "stackexchange_title_body/apple.stackexchange.com.jsonl.gz",
749
+ "lines": 110622,
750
+ "weight": 25
751
+ },
752
+ {
753
+ "name": "stackexchange_Title_Answer/stats.stackexchange.com.jsonl.gz",
754
+ "lines": 115679,
755
+ "weight": 27
756
+ },
757
+ {
758
+ "name": "stackexchange_title_body/mathoverflow.net.jsonl.gz",
759
+ "lines": 120851,
760
+ "weight": 28
761
+ },
762
+ {
763
+ "name": "stackexchange_Title_Answer/electronics.stackexchange.com.jsonl.gz",
764
+ "lines": 129494,
765
+ "weight": 30
766
+ },
767
+ {
768
+ "name": "stackexchange_title_body/gis.stackexchange.com.jsonl.gz",
769
+ "lines": 131000,
770
+ "weight": 30
771
+ },
772
+ {
773
+ "name": "stackexchange_Title_Answer/physics.stackexchange.com.jsonl.gz",
774
+ "lines": 141230,
775
+ "weight": 32
776
+ },
777
+ {
778
+ "name": "stackexchange_title_body/electronics.stackexchange.com.jsonl.gz",
779
+ "lines": 143582,
780
+ "weight": 33
781
+ },
782
+ {
783
+ "name": "TriviaQA_pairs.jsonl.gz",
784
+ "lines": 73346,
785
+ "weight": 34
786
+ },
787
+ {
788
+ "name": "stackexchange_Title_Answer/unix.stackexchange.com.jsonl.gz",
789
+ "lines": 155414,
790
+ "weight": 36
791
+ },
792
+ {
793
+ "name": "stackexchange_Title_Answer/tex.stackexchange.com.jsonl.gz",
794
+ "lines": 171628,
795
+ "weight": 39
796
+ },
797
+ {
798
+ "name": "squad_pairs.jsonl.gz",
799
+ "lines": 87599,
800
+ "weight": 40
801
+ },
802
+ {
803
+ "name": "stackexchange_title_body/physics.stackexchange.com.jsonl.gz",
804
+ "lines": 173307,
805
+ "weight": 40
806
+ },
807
+ {
808
+ "name": "stackexchange_title_body/stats.stackexchange.com.jsonl.gz",
809
+ "lines": 173466,
810
+ "weight": 40
811
+ },
812
+ {
813
+ "name": "stackexchange_title_body/unix.stackexchange.com.jsonl.gz",
814
+ "lines": 185997,
815
+ "weight": 42
816
+ },
817
+ {
818
+ "name": "NQ-train_pairs.jsonl.gz",
819
+ "lines": 100231,
820
+ "weight": 46
821
+ },
822
+ {
823
+ "name": "stackexchange_title_body/tex.stackexchange.com.jsonl.gz",
824
+ "lines": 202954,
825
+ "weight": 46
826
+ },
827
+ {
828
+ "name": "quora_duplicates_triplets.jsonl.gz",
829
+ "lines": 103663,
830
+ "weight": 47
831
+ },
832
+ {
833
+ "name": "stackexchange_Title_Answer/serverfault.com.jsonl.gz",
834
+ "lines": 238507,
835
+ "weight": 54
836
+ },
837
+ {
838
+ "name": "stackexchange_Title_Answer/askubuntu.com.jsonl.gz",
839
+ "lines": 267135,
840
+ "weight": 61
841
+ },
842
+ {
843
+ "name": "stackexchange_title_body/serverfault.com.jsonl.gz",
844
+ "lines": 270904,
845
+ "weight": 62
846
+ },
847
+ {
848
+ "name": "stackexchange_duplicate_questions_title_title.jsonl.gz",
849
+ "lines": 304525,
850
+ "weight": 69
851
+ },
852
+ {
853
+ "name": "stackexchange_title_body/askubuntu.com.jsonl.gz",
854
+ "lines": 347925,
855
+ "weight": 79
856
+ },
857
+ {
858
+ "name": "stackexchange_Title_Answer/superuser.com.jsonl.gz",
859
+ "lines": 352610,
860
+ "weight": 80
861
+ },
862
+ {
863
+ "name": "stackexchange_title_body/superuser.com.jsonl.gz",
864
+ "lines": 435463,
865
+ "weight": 99
866
+ },
867
+ {
868
+ "name": "stackexchange_title_body/small_stackexchanges.jsonl.gz",
869
+ "lines": 448146,
870
+ "weight": 102
871
+ },
872
+ {
873
+ "name": "stackexchange_Title_Answer/small_stackexchanges.jsonl.gz",
874
+ "lines": 460256,
875
+ "weight": 104
876
+ },
877
+ {
878
+ "name": "eli5_question_answer.jsonl.gz",
879
+ "lines": 325475,
880
+ "weight": 147
881
+ },
882
+ {
883
+ "name": "yahoo_answers_title_question.jsonl.gz",
884
+ "lines": 659896,
885
+ "weight": 149
886
+ },
887
+ {
888
+ "name": "PAQ_pairs.jsonl.gz",
889
+ "lines": 64371441,
890
+ "weight": 150
891
+ },
892
+ {
893
+ "name": "WikiAnswers_pairs.jsonl.gz",
894
+ "lines": 77427422,
895
+ "weight": 150
896
+ },
897
+ {
898
+ "name": "stackexchange_Title_Answer/math.stackexchange.com.jsonl.gz",
899
+ "lines": 1100953,
900
+ "weight": 226
901
+ },
902
+ {
903
+ "name": "yahoo_answers_title_answer.jsonl.gz",
904
+ "lines": 1198260,
905
+ "weight": 226
906
+ },
907
+ {
908
+ "name": "stackexchange_title_body/math.stackexchange.com.jsonl.gz",
909
+ "lines": 1338443,
910
+ "weight": 226
911
+ },
912
+ {
913
+ "name": "stackexchange_Title_Answer/stackoverflow.com-Posts.jsonl.gz",
914
+ "lines": 15768211,
915
+ "weight": 226
916
+ },
917
+ {
918
+ "name": "stackexchange_title_body/stackoverflow.com-Posts.jsonl.gz",
919
+ "lines": 18562443,
920
+ "weight": 226
921
+ },
922
+ {
923
+ "name": "searchQA_question_top5_snippets_merged.jsonl.gz",
924
+ "lines": 582261,
925
+ "weight": 263
926
+ },
927
+ {
928
+ "name": "amazon-qa-train-pairs.jsonl.gz",
929
+ "lines": 2448839,
930
+ "weight": 451
931
+ },
932
+ {
933
+ "name": "gooaq_pairs.jsonl.gz",
934
+ "lines": 3012496,
935
+ "weight": 451
936
+ },
937
+ {
938
+ "name": "msmarco-query_passage_negative_v2.jsonl.gz",
939
+ "lines": 17579773,
940
+ "weight": 1000
941
+ }
942
+ ]
modules.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ },
14
+ {
15
+ "idx": 2,
16
+ "name": "2",
17
+ "path": "2_Normalize",
18
+ "type": "sentence_transformers.models.Normalize"
19
+ }
20
+ ]
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df507ec1743de52aa0a1b401183c5fd8ca18e846689421ea6c94cae014d9b26b
3
+ size 90888945
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "max_seq_length": 512,
3
+ "do_lower_case": false
4
+ }
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tf_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d356e3a2867336eb0c469be0151e1edf79666e6399dd050a4c382e3c407432a2
3
+ size 91005696
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "special_tokens_map_file": null, "name_or_path": "nreimers/MiniLM-L6-H384-uncased", "do_basic_tokenize": true, "never_split": null, "tokenizer_class": "BertTokenizer", "model_max_length": 512}
train_script.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train script for a single file
3
+
4
+ Need to set the TPU address first:
5
+ export XRT_TPU_CONFIG="localservice;0;localhost:51011"
6
+ """
7
+
8
+ import torch.multiprocessing as mp
9
+ import threading
10
+ import time
11
+ import random
12
+ import sys
13
+ import argparse
14
+ import gzip
15
+ import json
16
+ import logging
17
+ import tqdm
18
+ import torch
19
+ from torch import nn
20
+ from torch.utils.data import DataLoader
21
+ import torch
22
+ import torch_xla
23
+ import torch_xla.core
24
+ import torch_xla.core.functions
25
+ import torch_xla.core.xla_model as xm
26
+ import torch_xla.distributed.xla_multiprocessing as xmp
27
+ import torch_xla.distributed.parallel_loader as pl
28
+ import os
29
+ from shutil import copyfile
30
+
31
+
32
+ from transformers import (
33
+ AdamW,
34
+ AutoModel,
35
+ AutoTokenizer,
36
+ get_linear_schedule_with_warmup,
37
+ set_seed,
38
+ )
39
+
40
class AutoModelForSentenceEmbedding(nn.Module):
    """Wrap a Hugging Face ``AutoModel`` so it returns one embedding per input text.

    The token embeddings of the underlying transformer are pooled (mean or
    CLS) and, unless disabled, L2-normalized.

    Args:
        model_name: Name or path handed to ``AutoModel.from_pretrained``.
        tokenizer: Tokenizer stored alongside the model; it is only used by
            ``save_pretrained``.
        args: Namespace providing ``pooling`` (``'mean'`` or ``'cls'``) and
            ``no_normalize`` (bool).

    Raises:
        ValueError: If ``args.pooling`` is not a supported pooling mode.
    """

    def __init__(self, model_name, tokenizer, args):
        super().__init__()

        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would leave ``forward`` without an embedding.
        if args.pooling not in ('mean', 'cls'):
            raise ValueError(
                "Unsupported pooling mode {!r}; expected 'mean' or 'cls'".format(args.pooling)
            )

        self.model = AutoModel.from_pretrained(model_name)
        self.normalize = not args.no_normalize
        self.tokenizer = tokenizer
        self.pooling = args.pooling

    def forward(self, **kwargs):
        """Encode a tokenized batch.

        ``kwargs`` are forwarded unchanged to the wrapped model and must
        contain ``attention_mask``.
        """
        model_output = self.model(**kwargs)
        if self.pooling == 'mean':
            embeddings = self.mean_pooling(model_output, kwargs['attention_mask'])
        elif self.pooling == 'cls':
            embeddings = self.cls_pooling(model_output, kwargs['attention_mask'])
        else:
            raise ValueError("Unsupported pooling mode {!r}".format(self.pooling))

        if self.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings

    def mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings, ignoring positions masked out by ``attention_mask``."""
        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp the token count so a fully masked row does not divide by zero.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    def cls_pooling(self, model_output, attention_mask):
        """Return the embedding of the first token; ``attention_mask`` is unused."""
        return model_output[0][:, 0]

    def save_pretrained(self, output_path):
        """Write tokenizer, model config and weights to ``output_path``.

        Tokenizer and config are written by the master ordinal only; the
        weights go through ``xm.save``, which is called on every process.
        """
        if xm.is_master_ordinal():
            self.tokenizer.save_pretrained(output_path)
            self.model.config.save_pretrained(output_path)

        xm.save(self.model.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
77
+
78
+
79
+
80
+
81
def train_function(index, args, queue):
    """Training loop executed by each spawned XLA process.

    Args:
        index: Process index supplied by ``xmp.spawn``; not used in the body.
        args: Parsed command-line namespace (model, steps, save_steps,
            max_length_a/b, scale, output, ...).
        queue: Queue from which one per-device batch is taken per step. Each
            batch is a list of samples with either two texts
            (anchor, positive) or three texts (anchor, positive, negative).
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForSentenceEmbedding(args.model, tokenizer, args)

    ### Train Loop
    device = xm.xla_device()
    model = model.to(device)

    # Instantiate optimizer
    optimizer = AdamW(params=model.parameters(), lr=2e-5, correct_bias=True)

    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=args.steps,
    )

    # Now we train the model
    cross_entropy_loss = nn.CrossEntropyLoss()
    max_grad_norm = 1

    def tokenize_column(batch, column, max_length):
        # Every column is padded to a fixed length (padding="max_length"),
        # so tensor shapes stay constant from step to step.
        return tokenizer(
            [sample[column] for sample in batch],
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding="max_length",
        )

    model.train()

    for global_step in tqdm.trange(args.steps, disable=not xm.is_master_ordinal()):
        #### Get the batch data
        batch = queue.get()

        if len(batch[0]) == 2:  # (anchor, positive)
            text1 = tokenize_column(batch, 0, args.max_length_a)
            text2 = tokenize_column(batch, 1, args.max_length_b)

            ### Compute embeddings
            embeddings_a = model(**text1.to(device))
            embeddings_b = model(**text2.to(device))

            ### Gather all embeddings
            embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
            embeddings_b = torch_xla.core.functions.all_gather(embeddings_b)

            ### Compute similarity scores (anchors x positives)
            scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale

            ### Compute cross-entropy loss
            # Example a[i] should match with b[i]
            labels = torch.arange(len(scores), dtype=torch.long, device=embeddings_a.device)

            ## Symmetric loss as in CLIP
            loss = (cross_entropy_loss(scores, labels) + cross_entropy_loss(scores.transpose(0, 1), labels)) / 2

        else:  # (anchor, positive, negative)
            text1 = tokenize_column(batch, 0, args.max_length_a)
            text2 = tokenize_column(batch, 1, args.max_length_b)
            text3 = tokenize_column(batch, 2, args.max_length_b)

            embeddings_a = model(**text1.to(device))
            embeddings_b1 = model(**text2.to(device))
            embeddings_b2 = model(**text3.to(device))

            embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
            embeddings_b1 = torch_xla.core.functions.all_gather(embeddings_b1)
            embeddings_b2 = torch_xla.core.functions.all_gather(embeddings_b2)

            # Positives first, then negatives, so the positive for a[i]
            # stays at column i.
            embeddings_b = torch.cat([embeddings_b1, embeddings_b2])

            ### Compute similarity scores (anchors x (positives + negatives))
            scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale

            ### Compute cross-entropy loss
            # Example a[i] should match with b[i]
            labels = torch.arange(len(scores), dtype=torch.long, device=embeddings_a.device)

            ## One-way loss
            loss = cross_entropy_loss(scores, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        xm.optimizer_step(optimizer, barrier=True)
        lr_scheduler.step()

        # Save model
        if (global_step + 1) % args.save_steps == 0:
            output_path = os.path.join(args.output, str(global_step + 1))
            xm.master_print("save model: " + output_path)
            model.save_pretrained(output_path)

    output_path = os.path.join(args.output, "final")
    xm.master_print("save model final: " + output_path)
    model.save_pretrained(output_path)
176
+
177
+
178
def produce_data(args, queue, filepaths, dataset_indices):
    """Endlessly fill ``queue`` with per-device batches drawn from the datasets.

    Args:
        args: Parsed command-line namespace; reads ``batch_size``,
            ``batch_size_triplets``, ``nprocs`` and ``datasets_per_batch``.
        queue: Queue receiving one list of samples per ``put``.
        filepaths: Dataset file paths; paths containing ``"reddit_"`` are
            read with ``RedditDataset``, all others with ``Dataset``.
        dataset_indices: Indices into ``filepaths``; an index that appears
            more often is chosen more often by ``random.choice``.

    This function never returns.
    """
    global_batch_size = args.batch_size * args.nprocs  # Global batch size
    num_same_dataset = args.nprocs // args.datasets_per_batch
    print("producer", "global_batch_size", global_batch_size)
    print("producer", "num_same_dataset", num_same_dataset)

    datasets = []
    for filepath in filepaths:
        if "reddit_" in filepath:  # Special dataset class for Reddit files
            data_obj = RedditDataset(filepath)
        else:
            data_obj = Dataset(filepath)
        datasets.append(iter(data_obj))

    # Store if dataset is in a 2 col or 3 col format.
    # NOTE: this consumes (and drops) the first sample of every dataset.
    num_cols = {idx: len(next(dataset)) for idx, dataset in enumerate(datasets)}

    while True:
        texts_in_batch = set()
        batch_format = None  # 2 vs 3 col format for this batch

        # Add data from several sub datasets
        for _ in range(args.datasets_per_batch):
            # Re-draw until the dataset has the same 2/3 col format as the
            # first dataset picked for this batch.
            valid_dataset = False
            while not valid_dataset:
                data_idx = random.choice(dataset_indices)
                if batch_format is None:
                    batch_format = num_cols[data_idx]
                    valid_dataset = True
                else:  # Check that this dataset has the same format
                    valid_dataset = (batch_format == num_cols[data_idx])

            # Get data from this dataset
            dataset = datasets[data_idx]
            local_batch_size = args.batch_size
            if batch_format == 3 and args.batch_size_triplets is not None:
                local_batch_size = args.batch_size_triplets

            for _ in range(num_same_dataset):
                for _ in range(args.nprocs):
                    batch_device = []  # A batch for one device
                    while len(batch_device) < local_batch_size:
                        sample = next(dataset)
                        # Skip samples sharing any text with one already
                        # queued since texts_in_batch was last reset.
                        if any(text in texts_in_batch for text in sample):
                            continue
                        texts_in_batch.update(sample)
                        batch_device.append(sample)

                    queue.put(batch_device)
233
+
234
+
235
class RedditDataset:
    """
    A class that handles the reddit data files.

    Iterating yields ``[response, context]`` pairs from a gzipped JSON-lines
    file and restarts from the top of the file whenever its end is reached.
    """

    def __init__(self, filepath):
        self.filepath = filepath

    def __iter__(self):
        """Yield ``[response, context]`` pairs forever.

        Lines lacking a ``"response"`` or ``"context"`` key are skipped.

        Raises:
            ValueError: If a complete pass over the file produced no pair
                (otherwise the loop would spin forever without yielding).
        """
        while True:
            found_pair = False
            # JSON text is UTF-8; do not depend on the locale's default encoding.
            with gzip.open(self.filepath, "rt", encoding="utf-8") as fIn:
                for line in fIn:
                    data = json.loads(line)

                    if "response" in data and "context" in data:
                        found_pair = True
                        yield [data["response"], data["context"]]

            if not found_pair:
                raise ValueError(
                    "No record with 'response' and 'context' found in {}".format(self.filepath)
                )
250
+
251
class Dataset:
    """
    A class that handles one dataset.

    Each line of the gzipped JSON-lines file is either a list of texts or a
    dict whose ``"texts"`` entry is that list. Iteration never ends.
    """

    def __init__(self, filepath):
        self.filepath = filepath

    def __iter__(self):
        """Yield samples forever.

        The first pass yields samples in file order while caching them. If
        the file has fewer than ``max_dataset_size`` samples, every later
        pass yields the cached samples in a freshly shuffled order. Larger
        files are not cached and are re-read from disk in file order.

        Raises:
            ValueError: If samples differ in their number of texts, or if
                the file contains no samples at all.
        """
        max_dataset_size = 20 * 1000 * 1000  # Cache small datasets in memory
        dataset = []
        data_format = None

        # ``dataset is None`` marks a file too large to cache: keep streaming.
        while dataset is None or len(dataset) == 0:
            # JSON text is UTF-8; do not depend on the locale's default encoding.
            with gzip.open(self.filepath, "rt", encoding="utf-8") as fIn:
                for line in fIn:
                    data = json.loads(line)
                    if isinstance(data, dict):
                        data = data['texts']

                    if data_format is None:
                        data_format = len(data)

                    # Ensure that all entries are of the same 2/3 col format
                    if len(data) != data_format:
                        raise ValueError(
                            "Inconsistent number of texts in {}: expected {}, got {}".format(
                                self.filepath, data_format, len(data)
                            )
                        )

                    if dataset is not None:
                        dataset.append(data)
                        if len(dataset) >= max_dataset_size:
                            dataset = None

                    yield data

            # An empty file would otherwise be re-read forever without yielding.
            if dataset is not None and len(dataset) == 0:
                raise ValueError("No samples found in {}".format(self.filepath))

        # Data loaded. Now stream to the queue
        # Shuffle for each epoch
        while True:
            random.shuffle(dataset)
            for data in dataset:
                yield data
289
+
290
+
291
+
292
if __name__ == "__main__":
    # Entry point: parse arguments, snapshot the configuration into the
    # output folder, start the data producer and spawn the training processes.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='nreimers/MiniLM-L6-H384-uncased')
    parser.add_argument('--steps', type=int, default=2000)
    parser.add_argument('--save_steps', type=int, default=10000)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--batch_size_triplets', type=int, default=None)
    parser.add_argument('--max_length_a', type=int, default=128)
    parser.add_argument('--max_length_b', type=int, default=128)
    parser.add_argument('--nprocs', type=int, default=8)
    parser.add_argument('--datasets_per_batch', type=int, default=2, help="Number of datasets per batch")
    parser.add_argument('--scale', type=float, default=20, help="Use 20 for cossim, and 1 when you work with unnormalized embeddings with dot product")
    parser.add_argument('--no_normalize', action="store_true", default=False, help="If set: Embeddings are not normalized")
    parser.add_argument('--pooling', default='mean')
    parser.add_argument('--data_folder', default="/data", help="Folder with your dataset files")
    parser.add_argument('data_config', help="A data_config.json file")
    parser.add_argument('output')
    args = parser.parse_args()

    # Ensure nprocs is divisible by datasets_per_batch. Reported through the
    # parser rather than ``assert`` so the check survives ``python -O``.
    if (args.nprocs % args.datasets_per_batch) != 0:
        parser.error("--nprocs must be divisible by --datasets_per_batch")

    logging.info("Output: %s", args.output)
    if os.path.exists(args.output):
        print("Output folder already exists.")
        input("Continue?")

    # Write train script to output path
    os.makedirs(args.output, exist_ok=True)

    data_config_path = os.path.join(args.output, 'data_config.json')
    copyfile(args.data_config, data_config_path)

    train_script_path = os.path.join(args.output, 'train_script.py')
    copyfile(__file__, train_script_path)
    # Record the exact command line at the end of the copied script.
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

    # Load data config
    with open(args.data_config) as fIn:
        data_config = json.load(fIn)

    queue = mp.Queue(maxsize=100 * args.nprocs)

    filepaths = []
    dataset_indices = []
    for idx, data in enumerate(data_config):
        filepaths.append(os.path.join(os.path.expanduser(args.data_folder), data['name']))
        # Repeat the index ``weight`` times so random.choice samples this
        # dataset proportionally to its weight.
        dataset_indices.extend([idx] * data['weight'])

    # Start producer
    p = mp.Process(target=produce_data, args=(args, queue, filepaths, dataset_indices))
    p.start()

    # Run training
    print("Start processes:", args.nprocs)
    xmp.spawn(train_function, args=(args, queue), nprocs=args.nprocs, start_method='fork')
    print("Training done")
    print("It might be that not all processes exit automatically. In that case you must manually kill this process.")
    print("With 'pkill python' you can kill all remaining python processes")
    p.kill()
    # sys.exit() rather than the interactive helper exit(), which the site
    # module only installs when it is loaded.
    sys.exit()
357
+
358
+
359
+
360
+ # Script was called via:
361
+ #python train_many_data_files_v2.py --steps 200000 --batch_size 128 --model nreimers/MiniLM-L6-H384-uncased --max_length_a 64 --max_length_b 250 train_data_configs/multi-qa_v1.json output/multi-qa_v1-MiniLM-L6-mean_cos
vocab.txt ADDED
The diff for this file is too large to render. See raw diff