upload
- README.md +69 -0
- config.json +29 -0
- data_config.json +51 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +1 -0
- spiece.model +3 -0
- tokenizer.json +0 -0
- tokenizer_config.json +1 -0
- train_script.py +253 -0
- trainer_state.json +0 -0
README.md
ADDED
@@ -0,0 +1,69 @@
---
language: en
datasets:
- sentence-transformers/reddit-title-body
- sentence-transformers/embedding-training-data
widget:
- text: "Python is an interpreted, high-level and general-purpose programming language. Python's design philosophy emphasizes code readability with its notable use of significant whitespace. Its language constructs and object-oriented approach aim to help programmers write clear, logical code for small and large-scale projects."

license: apache-2.0
---

# doc2query/all-t5-base-v1

This is a [doc2query](https://arxiv.org/abs/1904.08375) model based on T5 (also known as [docT5query](https://cs.uwaterloo.ca/~jimmylin/publications/Nogueira_Lin_2019_docTTTTTquery-v2.pdf)).

It can be used for:
- **Document expansion**: Generate 20-40 queries for each of your paragraphs and index the paragraphs together with the generated queries in a standard BM25 index such as Elasticsearch, OpenSearch, or Lucene. The generated queries help to close the lexical gap of lexical search, as they contain synonyms. Further, they re-weight words, giving important words a higher weight even if they appear seldom in a paragraph. In our [BEIR](https://arxiv.org/abs/2104.08663) paper we showed that BM25+docT5query is a powerful search engine. The [BEIR repository](https://github.com/UKPLab/beir) contains an example of how to use docT5query with Pyserini.
- **Domain-specific training data generation**: The model can generate training data for embedding models. On [SBERT.net](https://www.sbert.net/examples/unsupervised_learning/query_generation/README.html) there is an example of how to use the model to generate (query, text) pairs for a given collection of unlabeled texts. These pairs can then be used to train powerful dense embedding models.

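The document-expansion step above needs no model at indexing time beyond the generated queries themselves: they are simply concatenated with the paragraph before it goes into the BM25 index. A minimal sketch (the `expand_document` helper is illustrative, not part of any library; in practice the queries come from `model.generate()` as shown in the Usage section):

```python
def expand_document(paragraph: str, generated_queries: list[str]) -> str:
    """Concatenate a paragraph with its generated queries.

    The combined text is what gets stored in the BM25 index. The extra
    query terms add synonyms and repeat important words, which raises
    their term weight even if they occur only once in the paragraph.
    """
    return paragraph + " " + " ".join(generated_queries)


doc = "Python is an interpreted, high-level programming language."
queries = ["what is python", "which programming language is interpreted"]
expanded = expand_document(doc, queries)
print(expanded)
```

The original paragraph is still what gets shown to users; only the index stores the expanded text.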
## Usage
```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_name = 'doc2query/all-t5-base-v1'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

text = "Python is an interpreted, high-level and general-purpose programming language. Python's design philosophy emphasizes code readability with its notable use of significant whitespace. Its language constructs and object-oriented approach aim to help programmers write clear, logical code for small and large-scale projects."

input_ids = tokenizer.encode(text, max_length=384, truncation=True, return_tensors='pt')
outputs = model.generate(
    input_ids=input_ids,
    max_length=64,
    do_sample=True,
    top_p=0.95,
    num_return_sequences=5)

print("Text:")
print(text)

print("\nGenerated Queries:")
for i in range(len(outputs)):
    query = tokenizer.decode(outputs[i], skip_special_tokens=True)
    print(f'{i + 1}: {query}')
```

**Note:** `model.generate()` uses sampling and is therefore non-deterministic: it produces different queries each time you run it.

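If you need the same queries across runs (for example, when expanding a fixed corpus), you can seed PyTorch's RNG before sampling: `do_sample=True` draws from the model's output distribution, so fixing the seed fixes the draws. A small sketch of the seeding behavior (the commented `generate()` call mirrors the Usage section; the seed value 42 is arbitrary):

```python
import torch

# Same seed -> identical random draws, hence identical sampled queries.
torch.manual_seed(42)
first = torch.rand(3)

torch.manual_seed(42)
second = torch.rand(3)

assert torch.equal(first, second)

# Applied to the model from the Usage section (not executed here):
# torch.manual_seed(42)
# outputs = model.generate(input_ids=input_ids, max_length=64,
#                          do_sample=True, top_p=0.95,
#                          num_return_sequences=5)
```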
## Training

This model was fine-tuned from [google/t5-v1_1-base](https://huggingface.co/google/t5-v1_1-base) for 570k training steps. For the training script, see `train_script.py` in this repository.

The input text was truncated to 384 word pieces. Output text was generated up to 64 word pieces.

This model was trained on a large collection of datasets. For the exact dataset names and weights, see `data_config.json` in this repository. Most of the datasets are available at [https://huggingface.co/sentence-transformers](https://huggingface.co/sentence-transformers).

The datasets include, among others:
- (title, body) pairs from [Reddit](https://huggingface.co/datasets/sentence-transformers/reddit-title-body)
- (title, body) pairs and (title, answer) pairs from StackExchange and Yahoo Answers!
- (title, review) pairs from Amazon reviews
- (query, paragraph) pairs from MS MARCO, NQ, and GooAQ
- (question, duplicate_question) pairs from Quora and WikiAnswers
- (title, abstract) pairs from S2ORC

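The `weight` values in `data_config.json` control how often each task is sampled during training. The mechanics can be sketched with stdlib weighted sampling (this illustrates proportional sampling in general, not the exact logic of `train_script.py`):

```python
import random
from collections import Counter

# Task weights as listed in data_config.json
weights = {"text2reddit": 30, "question2title": 30, "answer2question": 30,
           "abstract2title": 10, "review2title": 10, "news2title": 5,
           "text2query": 10, "question2question": 10}

random.seed(0)  # fixed seed so the counts below are reproducible
tasks = list(weights)
draws = random.choices(tasks, weights=[weights[t] for t in tasks], k=10000)
counts = Counter(draws)

# A weight-30 task should be drawn roughly 3x as often as a weight-10 task.
ratio = counts["text2reddit"] / counts["abstract2title"]
print(round(ratio, 1))
```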
## Prefix

This model was trained **without a prefix**. In contrast to [doc2query/all-with_prefix-t5-base-v1](https://huggingface.co/doc2query/all-with_prefix-t5-base-v1), you cannot specify which type of transformation (answer2question, review2title, etc.) you want. This can lead to a mixture of output types.

config.json
ADDED
@@ -0,0 +1,29 @@
{
  "_name_or_path": "google/t5-v1_1-base",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "gradient_checkpointing": false,
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.10.2",
  "use_cache": true,
  "vocab_size": 32128
}
data_config.json
ADDED
@@ -0,0 +1,51 @@
{
    "text2reddit": {
        "weight": 30,
        "files": {"/home/reddit/reddit-title-body/reddit_title_text_2010.jsonl.gz": 1,
            "/home/reddit/reddit-title-body/reddit_title_text_2011.jsonl.gz": 1,
            "/home/reddit/reddit-title-body/reddit_title_text_2012.jsonl.gz": 3,
            "/home/reddit/reddit-title-body/reddit_title_text_2013.jsonl.gz": 5,
            "/home/reddit/reddit-title-body/reddit_title_text_2014.jsonl.gz": 8,
            "/home/reddit/reddit-title-body/reddit_title_text_2015.jsonl.gz": 11,
            "/home/reddit/reddit-title-body/reddit_title_text_2016.jsonl.gz": 12,
            "/home/reddit/reddit-title-body/reddit_title_text_2017.jsonl.gz": 13,
            "/home/reddit/reddit-title-body/reddit_title_text_2018.jsonl.gz": 15,
            "/home/reddit/reddit-title-body/reddit_title_text_2019.jsonl.gz": 19,
            "/home/reddit/reddit-title-body/reddit_title_text_2020.jsonl.gz": 23,
            "/home/reddit/reddit-title-body/reddit_title_text_2021.jsonl.gz": 12
        }
    },
    "question2title": {
        "weight": 30,
        "files": {"/home/sbert_pretrained_models/datasets/embedding-training-data/yahoo_answers_title_question.jsonl.gz": 40,
"/home/stackexchange_archive/jsonl/skeptics.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/writers.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/astronomy.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/vi.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/cstheory.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/engineering.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/french.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/economics.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/anime.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/islam.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/expressionengine.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/politics.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/history.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/christianity.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/boardgames.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/civicrm.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/craftcms.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/hinduism.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/networkengineering.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/german.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/philosophy.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/gardening.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/space.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/bicycles.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/quant.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/puzzling.stackexchange.com.jsonl.gz": 2, 
"/home/stackexchange_archive/jsonl/law.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/arduino.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/aviation.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/softwarerecs.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/movies.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/music.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/emacs.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/dsp.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/japanese.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/mechanics.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/crypto.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/cooking.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/photo.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/workplace.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/biology.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/bitcoin.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/worldbuilding.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/datascience.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/ux.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/webapps.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/graphicdesign.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/raspberrypi.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/money.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/judaism.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/ethereum.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/academia.stackexchange.com.jsonl.gz": 3, 
"/home/stackexchange_archive/jsonl/chemistry.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/webmasters.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/meta.stackoverflow.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/cs.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/jsonl/travel.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/jsonl/rpg.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/jsonl/codereview.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/jsonl/gamedev.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/jsonl/android.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/jsonl/softwareengineering.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/jsonl/security.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/jsonl/diy.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/jsonl/scifi.stackexchange.com.jsonl.gz": 6, "/home/stackexchange_archive/jsonl/mathematica.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/jsonl/drupal.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/jsonl/blender.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/jsonl/dba.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/jsonl/ell.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/jsonl/meta.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/jsonl/gaming.stackexchange.com.jsonl.gz": 8, "/home/stackexchange_archive/jsonl/sharepoint.stackexchange.com.jsonl.gz": 8, "/home/stackexchange_archive/jsonl/magento.stackexchange.com.jsonl.gz": 9, "/home/stackexchange_archive/jsonl/wordpress.stackexchange.com.jsonl.gz": 9, "/home/stackexchange_archive/jsonl/salesforce.stackexchange.com.jsonl.gz": 9, "/home/stackexchange_archive/jsonl/english.stackexchange.com.jsonl.gz": 10, "/home/stackexchange_archive/jsonl/apple.stackexchange.com.jsonl.gz": 10, 
"/home/stackexchange_archive/jsonl/mathoverflow.net.jsonl.gz": 10, "/home/stackexchange_archive/jsonl/gis.stackexchange.com.jsonl.gz": 11, "/home/stackexchange_archive/jsonl/electronics.stackexchange.com.jsonl.gz": 12, "/home/stackexchange_archive/jsonl/physics.stackexchange.com.jsonl.gz": 15, "/home/stackexchange_archive/jsonl/stats.stackexchange.com.jsonl.gz": 15, "/home/stackexchange_archive/jsonl/unix.stackexchange.com.jsonl.gz": 16, "/home/stackexchange_archive/jsonl/tex.stackexchange.com.jsonl.gz": 17, "/home/stackexchange_archive/jsonl/serverfault.com.jsonl.gz": 23, "/home/stackexchange_archive/jsonl/askubuntu.com.jsonl.gz": 29, "/home/stackexchange_archive/jsonl/superuser.com.jsonl.gz": 36, "/home/stackexchange_archive/jsonl/small_stackexchanges.jsonl.gz": 37, "/home/stackexchange_archive/jsonl/math.stackexchange.com.jsonl.gz": 83, "/home/stackexchange_archive/jsonl/stackoverflow.com-Posts.jsonl.gz": 83
        }
    },
    "answer2question": {
        "weight": 30,
        "files": {"/home/sbert_pretrained_models/datasets/embedding-training-data/yahoo_answers_title_answer.jsonl.gz": 80, "/home/sbert_pretrained_models/datasets/embedding-training-data/amazon-qa.jsonl.gz": 80,
"/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/islam.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/anime.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/french.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/civicrm.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/expressionengine.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/history.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/politics.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/craftcms.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/christianity.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/softwarerecs.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/boardgames.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/networkengineering.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/space.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/quant.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/philosophy.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/gardening.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/german.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/bicycles.stackexchange.com.jsonl.gz": 2, 
"/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/law.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/arduino.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/emacs.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/dsp.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/puzzling.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/movies.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/mechanics.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/aviation.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/biology.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/crypto.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/music.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/datascience.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/japanese.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/bitcoin.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/cooking.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/photo.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/workplace.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/meta.stackoverflow.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/raspberrypi.stackexchange.com.jsonl.gz": 2, 
"/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/webapps.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/judaism.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/ethereum.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/worldbuilding.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/chemistry.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/graphicdesign.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/ux.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/money.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/cs.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/webmasters.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/academia.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/travel.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/android.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/gamedev.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/rpg.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/codereview.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/softwareengineering.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/security.stackexchange.com.jsonl.gz": 5, 
"/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/diy.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/blender.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/scifi.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/mathematica.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/meta.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/drupal.stackexchange.com.jsonl.gz": 6, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/dba.stackexchange.com.jsonl.gz": 6, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/ell.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/magento.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/sharepoint.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/gaming.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/wordpress.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/mathoverflow.net.jsonl.gz": 8, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/salesforce.stackexchange.com.jsonl.gz": 8, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/apple.stackexchange.com.jsonl.gz": 8, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/gis.stackexchange.com.jsonl.gz": 9, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/english.stackexchange.com.jsonl.gz": 9, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/stats.stackexchange.com.jsonl.gz": 10, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/electronics.stackexchange.com.jsonl.gz": 11, 
"/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/physics.stackexchange.com.jsonl.gz": 12, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/unix.stackexchange.com.jsonl.gz": 13, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/tex.stackexchange.com.jsonl.gz": 15, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/serverfault.com.jsonl.gz": 20, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/askubuntu.com.jsonl.gz": 22, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/superuser.com.jsonl.gz": 30, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/small_stackexchanges.jsonl.gz": 38, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/math.stackexchange.com.jsonl.gz": 83, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/stackoverflow.com-Posts.jsonl.gz": 83}
    },
    "abstract2title": {
        "weight": 10,
        "files": {"/home/sbert_pretrained_models/datasets/embedding-training-data/S2ORC_title_abstract.jsonl.gz": 100}
    },
    "review2title": {
        "weight": 10,
        "files": {"/home/sbert_pretrained_models/datasets/embedding-training-data/amazon_review_2018.jsonl.gz": 100}
    },
    "news2title": {
        "weight": 5,
        "files": {"/home/sbert_pretrained_models/datasets/embedding-training-data/agnews.jsonl.gz": 50, "/home/sbert_pretrained_models/datasets/embedding-training-data/ccnews_title_text.jsonl.gz": 50, "/home/sbert_pretrained_models/datasets/embedding-training-data/xsum.jsonl.gz": 25}
    },
    "text2query": {
        "weight": 10,
        "files": {"/home/sbert_pretrained_models/datasets/embedding-training-data/msmarco-triplets.jsonl.gz": 33, "/home/sbert_pretrained_models/datasets/embedding-training-data/gooaq_pairs.jsonl.gz": 66, "/home/sbert_pretrained_models/datasets/embedding-training-data/NQ-train_pairs.jsonl.gz": 6}
    },
    "question2question": {
        "weight": 10,
        "files": {"/home/sbert_pretrained_models/datasets/embedding-training-data/WikiAnswers.jsonl.gz": 95, "/home/sbert_pretrained_models/datasets/embedding-training-data/quora_duplicates.jsonl.gz": 5}
    }
}
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f44600793a9a002ab1a0d8d956b8c00f99adc0b9587e77485b7724de7a968f61
size 990445401
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
{"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"]}
spiece.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
size 791656
tokenizer.json
ADDED
The diff for this file is too large to render.
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
{"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 100, "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"], "model_max_length": 512, "name_or_path": "google/t5-v1_1-base", "special_tokens_map_file": 
"/root/.cache/huggingface/transformers/76bf19bfedb85afbe644966ca9ab7b0404d753a41bf601115bced39f825ffa9c.c94798918c92ded6aeef2d2f0e666d2cc4145eca1aa6e1336fde07f2e13e2f46", "sp_model_kwargs": {}, "tokenizer_class": "T5Tokenizer"}
train_script.py
ADDED
@@ -0,0 +1,253 @@
import argparse
import logging
from torch.utils.data import Dataset, IterableDataset
import gzip
import json
from transformers import Seq2SeqTrainer, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments
import sys
from datetime import datetime
import torch
import random
from shutil import copyfile
import os
import wandb
import re


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="google/t5-v1_1-base")
parser.add_argument("--train_file", required=True)
parser.add_argument("--epochs", default=1, type=int)
parser.add_argument("--batch_size", default=16, type=int)
parser.add_argument("--max_source_length", default=384, type=int)
parser.add_argument("--max_target_length", default=64, type=int)
parser.add_argument("--name", required=True)
parser.add_argument("--train_size", default=100*1000*1000, type=int)
parser.add_argument("--eval_size", default=10000, type=int)
parser.add_argument("--fp16", default=False, action='store_true')
parser.add_argument("--no_prefix", default=False, action='store_true')
args = parser.parse_args()

wandb.init(project="doc2query", name=f"{args.name}-{args.model_name}")


class PairDataset:
    def __init__(self, filepath):
        self.filepath = filepath
        self.examples = []

    def __iter__(self):
        with gzip.open(self.filepath, 'rt') as fIn:
            for line in fIn:
                example = self.get_example(json.loads(line))

                if example is not None:
                    self.examples.append(example)
                    yield example

        while True:
            random.shuffle(self.examples)
            for ex in self.examples:
                yield ex

    def get_example(self, raw_example):
        if isinstance(raw_example, dict):
            if 'set' in raw_example:
                example = random.sample(raw_example['set'], 2)
            elif 'query' in raw_example:
                example = [raw_example['query'], random.choice(raw_example['pos'])]
            else:
                raise ValueError("Unknown format: "+str(raw_example))
        else:
            example = [raw_example[0], raw_example[1]]

        return example


class RedditTitleDataset(PairDataset):
    def get_example(self, raw_example):
        return [self.clean_title(raw_example['title']), raw_example['body']]

    def clean_title(self, text):
        text = text.replace("&amp;", "&").strip()
        if text.startswith("["):
            text = re.sub(r"^\[[a-zA-Z0-9]+\]", "", text).strip()

        if text.endswith("]"):
            text = re.sub(r"\[[a-zA-Z0-9\.]+\]$", "", text).strip()

        if text.startswith("/r"):
            text = re.sub(r"^/[a-zA-Z0-9/]+[;,: \-]+", "", text).strip()

        return text


class StackExchangeTitleBodyDataset(PairDataset):
    def get_example(self, raw_example):
        return raw_example['texts']


class MultiDataset(IterableDataset):
    def __init__(self, train_config_path, num_samples):
        self.num_samples = num_samples

        with open(train_config_path) as fIn:
            train_config = json.load(fIn)

        self.categories = []
        self.files = {}
        self.file2dataset = {}
        self.file2datasetIter = {}

        for prefix in train_config:
            self.categories.extend([prefix]*train_config[prefix]['weight'])
            self.files[prefix] = []

            for filename, weight in train_config[prefix]['files'].items():
                self.files[prefix].extend([filename]*weight)
                dataset = self.OpenDataset(filename)
                self.file2dataset[filename] = dataset
                self.file2datasetIter[filename] = iter(dataset)

            random.shuffle(self.files[prefix])

        random.shuffle(self.categories)

    def OpenDataset(self, filepath):
        if 'reddit_title_text' in filepath:
            dataset = RedditTitleDataset(filepath)
        elif 'stackexchange_archive/jsonl' in filepath:
            dataset = StackExchangeTitleBodyDataset(filepath)
        else:
            dataset = PairDataset(filepath)
        return dataset

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        while True:
            category = random.choice(self.categories)
            filepath = random.choice(self.files[category])
            dataset = self.file2datasetIter[filepath]
            pair = next(dataset)

            # Add the category prefix to the input
            if not args.no_prefix:
                pair[1] = category+": "+pair[1].strip()
            yield pair

    def delete_examples_cache(self):
        for dataset in self.file2dataset.values():
            dataset.examples = []


def main():
    ############ Model
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    save_steps = 5000

    output_dir = 'output/'+args.name+'-'+args.model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    print("Output dir:", output_dir)

    # Write this training script and the data config to the output path
    os.makedirs(output_dir, exist_ok=True)

    copyfile(args.train_file, os.path.join(output_dir, 'data_config.json'))
    train_script_path = os.path.join(output_dir, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

    ####

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        fp16=args.fp16,
        fp16_backend="amp",
        per_device_train_batch_size=args.batch_size,
        evaluation_strategy="steps",
        save_steps=save_steps,
        logging_steps=100,
        eval_steps=save_steps,  #logging_steps,
        warmup_steps=1000,
        save_total_limit=1,
        num_train_epochs=args.epochs,
        report_to="wandb",
    )

    ############ Load datasets
    train_dataset = MultiDataset(args.train_file, args.train_size)
    train_dataset_iter = iter(train_dataset)
    eval_dataset = [next(train_dataset_iter) for _ in range(args.eval_size)]
    train_dataset.delete_examples_cache()  # Make sure dev data is not re-used for training

    for i in range(50):
        print("Target:", eval_dataset[i][0])
        print("Input:", eval_dataset[i][1])
        print("\n\n===================\n\n")

    print("Train dataset len:", len(train_dataset))

    def data_collator(examples):
        targets = [row[0] for row in examples]
        inputs = [row[1] for row in examples]
        label_pad_token_id = -100

        model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=8 if training_args.fp16 else None)

        # Setup the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=args.max_target_length, padding=True, truncation=True, pad_to_multiple_of=8 if training_args.fp16 else None)

        # Replace all tokenizer.pad_token_id in the labels by -100 to ignore padding in the loss
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else label_pad_token_id) for l in label] for label in labels["input_ids"]
        ]

        model_inputs["labels"] = torch.tensor(labels["input_ids"])
        return model_inputs

    ## Define the trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator
    )

    ### Train and save the model
    train_result = trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()

# Script was called via:
#python train_hf_trainer_prefix.py --train_file train_config.json --name all-datasets-v1-no_prefix --no_prefix
trainer_state.json
ADDED
The diff for this file is too large to render.
See raw diff