Text2Text Generation
Transformers
PyTorch
English
t5
text-generation-inference
Inference Endpoints
nreimers commited on
Commit
28d82c0
1 Parent(s): 2e7228b
README.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ datasets:
4
+ - sentence-transformers/reddit-title-body
5
+ - sentence-transformers/embedding-training-data
6
+ widget:
7
+ - text: "Python is an interpreted, high-level and general-purpose programming language. Python's design philosophy emphasizes code readability with its notable use of significant whitespace. Its language constructs and object-oriented approach aim to help programmers write clear, logical code for small and large-scale projects."
8
+
9
+ license: apache-2.0
10
+ ---
11
+
12
+ # doc2query/all-t5-base-v1
13
+
14
+ This is a [doc2query](https://arxiv.org/abs/1904.08375) model based on T5 (also known as [docT5query](https://cs.uwaterloo.ca/~jimmylin/publications/Nogueira_Lin_2019_docTTTTTquery-v2.pdf)).
15
+
16
+ It can be used for:
17
+ - **Document expansion**: You generate for your paragraphs 20-40 queries and index the paragraphs and the generates queries in a standard BM25 index like Elasticsearch, OpenSearch, or Lucene. The generated queries help to close the lexical gap of lexical search, as the generate queries contain synonyms. Further, it re-weights words giving important words a higher weight even if they appear seldomn in a paragraph. In our [BEIR](https://arxiv.org/abs/2104.08663) paper we showed that BM25+docT5query is a powerful search engine. In the [BEIR repository](https://github.com/UKPLab/beir) we have an example how to use docT5query with Pyserini.
18
+ - **Domain Specific Training Data Generation**: It can be used to generate training data to learn an embedding model. On [SBERT.net](https://www.sbert.net/examples/unsupervised_learning/query_generation/README.html) we have an example how to use the model to generate (query, text) pairs for a given collection of unlabeled texts. These pairs can then be used to train powerful dense embedding models.
19
+
20
+ ## Usage
21
+ ```python
22
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
23
+
24
+ model_name = 'doc2query/all-t5-base-v1'
25
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
26
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
27
+
28
+ text = "Python is an interpreted, high-level and general-purpose programming language. Python's design philosophy emphasizes code readability with its notable use of significant whitespace. Its language constructs and object-oriented approach aim to help programmers write clear, logical code for small and large-scale projects."
29
+
30
+
31
+ input_ids = tokenizer.encode(text, max_length=384, truncation=True, return_tensors='pt')
32
+ outputs = model.generate(
33
+ input_ids=input_ids,
34
+ max_length=64,
35
+ do_sample=True,
36
+ top_p=0.95,
37
+ num_return_sequences=5)
38
+
39
+ print("Text:")
40
+ print(text)
41
+
42
+ print("\nGenerated Queries:")
43
+ for i in range(len(outputs)):
44
+ query = tokenizer.decode(outputs[i], skip_special_tokens=True)
45
+ print(f'{i + 1}: {query}')
46
+ ```
47
+
48
+ **Note:** `model.generate()` is non-deterministic. It produces different queries each time you run it.
49
+
50
+ ## Training
51
+ This model fine-tuned [google/t5-v1_1-base](https://huggingface.co/google/t5-v1_1-base) for 570k training steps. For the training script, see the `train_script.py` in this repository.
52
+
53
+ The input-text was truncated to 384 word pieces. Output text was generated up to 64 word pieces.
54
+
55
+ This model was trained on a large collection of datasets. For the exact datasets names and weights see the `data_config.json` in this repository. Most of the datasets are available at [https://huggingface.co/sentence-transformers](https://huggingface.co/sentence-transformers).
56
+
57
+ The datasets include besides others:
58
+ - (title, body) pairs from [Reddit](https://huggingface.co/datasets/sentence-transformers/reddit-title-body)
59
+ - (title, body) pairs and (title, answer) pairs from StackExchange and Yahoo Answers!
60
+ - (title, review) pairs from Amazon reviews
61
+ - (query, paragraph) pairs from MS MARCO, NQ, and GooAQ
62
+ - (question, duplicate_question) from Quora and WikiAnswers
63
+ - (title, abstract) pairs from S2ORC
64
+
65
+ ## Prefix
66
+
67
+ This model was trained **without a prefix**. In contrast to [doc2query/all-with_prefix-t5-base-v1](https://huggingface.co/doc2query/all-with_prefix-t5-base-v1) you cannot specify what type of transformation (answer2question, review2title) etc. you will have. This can lead to a mixture of output values.
68
+
69
+
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "google/t5-v1_1-base",
3
+ "architectures": [
4
+ "T5ForConditionalGeneration"
5
+ ],
6
+ "d_ff": 2048,
7
+ "d_kv": 64,
8
+ "d_model": 768,
9
+ "decoder_start_token_id": 0,
10
+ "dropout_rate": 0.1,
11
+ "eos_token_id": 1,
12
+ "feed_forward_proj": "gated-gelu",
13
+ "gradient_checkpointing": false,
14
+ "initializer_factor": 1.0,
15
+ "is_encoder_decoder": true,
16
+ "layer_norm_epsilon": 1e-06,
17
+ "model_type": "t5",
18
+ "num_decoder_layers": 12,
19
+ "num_heads": 12,
20
+ "num_layers": 12,
21
+ "output_past": true,
22
+ "pad_token_id": 0,
23
+ "relative_attention_num_buckets": 32,
24
+ "tie_word_embeddings": false,
25
+ "torch_dtype": "float32",
26
+ "transformers_version": "4.10.2",
27
+ "use_cache": true,
28
+ "vocab_size": 32128
29
+ }
data_config.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "text2reddit": {
3
+ "weight": 30,
4
+ "files": {"/home/reddit/reddit-title-body/reddit_title_text_2010.jsonl.gz": 1,
5
+ "/home/reddit/reddit-title-body/reddit_title_text_2011.jsonl.gz": 1,
6
+ "/home/reddit/reddit-title-body/reddit_title_text_2012.jsonl.gz": 3,
7
+ "/home/reddit/reddit-title-body/reddit_title_text_2013.jsonl.gz": 5,
8
+ "/home/reddit/reddit-title-body/reddit_title_text_2014.jsonl.gz": 8,
9
+ "/home/reddit/reddit-title-body/reddit_title_text_2015.jsonl.gz": 11,
10
+ "/home/reddit/reddit-title-body/reddit_title_text_2016.jsonl.gz": 12,
11
+ "/home/reddit/reddit-title-body/reddit_title_text_2017.jsonl.gz": 13,
12
+ "/home/reddit/reddit-title-body/reddit_title_text_2018.jsonl.gz": 15,
13
+ "/home/reddit/reddit-title-body/reddit_title_text_2019.jsonl.gz": 19,
14
+ "/home/reddit/reddit-title-body/reddit_title_text_2020.jsonl.gz": 23,
15
+ "/home/reddit/reddit-title-body/reddit_title_text_2021.jsonl.gz": 12
16
+ }
17
+
18
+ },
19
+ "question2title": {
20
+ "weight": 30,
21
+ "files": {"/home/sbert_pretrained_models/datasets/embedding-training-data/yahoo_answers_title_question.jsonl.gz": 40,
22
+ "/home/stackexchange_archive/jsonl/skeptics.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/writers.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/astronomy.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/vi.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/cstheory.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/engineering.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/french.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/economics.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/anime.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/islam.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/expressionengine.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/politics.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/history.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/christianity.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/boardgames.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/jsonl/civicrm.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/craftcms.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/hinduism.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/networkengineering.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/german.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/philosophy.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/gardening.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/space.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/bicycles.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/quant.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/puzzling.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/law.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/arduino.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/aviation.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/softwarerecs.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/movies.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/music.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/emacs.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/dsp.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/japanese.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/mechanics.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/crypto.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/cooking.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/photo.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/workplace.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/jsonl/biology.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/bitcoin.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/worldbuilding.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/datascience.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/ux.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/webapps.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/graphicdesign.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/raspberrypi.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/money.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/judaism.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/ethereum.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/academia.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/chemistry.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/webmasters.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/meta.stackoverflow.com.jsonl.gz": 3, "/home/stackexchange_archive/jsonl/cs.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/jsonl/travel.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/jsonl/rpg.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/jsonl/codereview.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/jsonl/gamedev.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/jsonl/android.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/jsonl/softwareengineering.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/jsonl/security.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/jsonl/diy.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/jsonl/scifi.stackexchange.com.jsonl.gz": 6, "/home/stackexchange_archive/jsonl/mathematica.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/jsonl/drupal.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/jsonl/blender.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/jsonl/dba.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/jsonl/ell.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/jsonl/meta.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/jsonl/gaming.stackexchange.com.jsonl.gz": 8, "/home/stackexchange_archive/jsonl/sharepoint.stackexchange.com.jsonl.gz": 8, "/home/stackexchange_archive/jsonl/magento.stackexchange.com.jsonl.gz": 9, "/home/stackexchange_archive/jsonl/wordpress.stackexchange.com.jsonl.gz": 9, "/home/stackexchange_archive/jsonl/salesforce.stackexchange.com.jsonl.gz": 9, "/home/stackexchange_archive/jsonl/english.stackexchange.com.jsonl.gz": 10, "/home/stackexchange_archive/jsonl/apple.stackexchange.com.jsonl.gz": 10, "/home/stackexchange_archive/jsonl/mathoverflow.net.jsonl.gz": 10, "/home/stackexchange_archive/jsonl/gis.stackexchange.com.jsonl.gz": 11, "/home/stackexchange_archive/jsonl/electronics.stackexchange.com.jsonl.gz": 12, "/home/stackexchange_archive/jsonl/physics.stackexchange.com.jsonl.gz": 15, "/home/stackexchange_archive/jsonl/stats.stackexchange.com.jsonl.gz": 15, "/home/stackexchange_archive/jsonl/unix.stackexchange.com.jsonl.gz": 16, "/home/stackexchange_archive/jsonl/tex.stackexchange.com.jsonl.gz": 17, "/home/stackexchange_archive/jsonl/serverfault.com.jsonl.gz": 23, "/home/stackexchange_archive/jsonl/askubuntu.com.jsonl.gz": 29, "/home/stackexchange_archive/jsonl/superuser.com.jsonl.gz": 36, "/home/stackexchange_archive/jsonl/small_stackexchanges.jsonl.gz": 37, "/home/stackexchange_archive/jsonl/math.stackexchange.com.jsonl.gz": 83, "/home/stackexchange_archive/jsonl/stackoverflow.com-Posts.jsonl.gz": 83
23
+ }
24
+ },
25
+ "answer2question": {
26
+ "weight": 30,
27
+ "files": {"/home/sbert_pretrained_models/datasets/embedding-training-data/yahoo_answers_title_answer.jsonl.gz": 80, "/home/sbert_pretrained_models/datasets/embedding-training-data/amazon-qa.jsonl.gz": 80,
28
+ "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/islam.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/anime.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/french.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/civicrm.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/expressionengine.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/history.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/politics.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/craftcms.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/christianity.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/softwarerecs.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/boardgames.stackexchange.com.jsonl.gz": 1, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/networkengineering.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/space.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/quant.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/philosophy.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/gardening.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/german.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/bicycles.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/law.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/arduino.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/emacs.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/dsp.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/puzzling.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/movies.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/mechanics.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/aviation.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/biology.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/crypto.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/music.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/datascience.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/japanese.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/bitcoin.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/cooking.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/photo.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/workplace.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/meta.stackoverflow.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/raspberrypi.stackexchange.com.jsonl.gz": 2, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/webapps.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/judaism.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/ethereum.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/worldbuilding.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/chemistry.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/graphicdesign.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/ux.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/money.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/cs.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/webmasters.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/academia.stackexchange.com.jsonl.gz": 3, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/travel.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/android.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/gamedev.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/rpg.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/codereview.stackexchange.com.jsonl.gz": 4, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/softwareengineering.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/security.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/diy.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/blender.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/scifi.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/mathematica.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/meta.stackexchange.com.jsonl.gz": 5, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/drupal.stackexchange.com.jsonl.gz": 6, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/dba.stackexchange.com.jsonl.gz": 6, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/ell.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/magento.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/sharepoint.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/gaming.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/wordpress.stackexchange.com.jsonl.gz": 7, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/mathoverflow.net.jsonl.gz": 8, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/salesforce.stackexchange.com.jsonl.gz": 8, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/apple.stackexchange.com.jsonl.gz": 8, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/gis.stackexchange.com.jsonl.gz": 9, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/english.stackexchange.com.jsonl.gz": 9, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/stats.stackexchange.com.jsonl.gz": 10, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/electronics.stackexchange.com.jsonl.gz": 11, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/physics.stackexchange.com.jsonl.gz": 12, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/unix.stackexchange.com.jsonl.gz": 13, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/tex.stackexchange.com.jsonl.gz": 15, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/serverfault.com.jsonl.gz": 20, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/askubuntu.com.jsonl.gz": 22, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/superuser.com.jsonl.gz": 30, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/small_stackexchanges.jsonl.gz": 38, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/math.stackexchange.com.jsonl.gz": 83, "/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/stackoverflow.com-Posts.jsonl.gz": 83}
29
+ },
30
+ "abstract2title": {
31
+ "weight": 10,
32
+ "files": {"/home/sbert_pretrained_models/datasets/embedding-training-data/S2ORC_title_abstract.jsonl.gz": 100}
33
+ },
34
+ "review2title": {
35
+ "weight": 10,
36
+ "files": {"/home/sbert_pretrained_models/datasets/embedding-training-data/amazon_review_2018.jsonl.gz": 100}
37
+ },
38
+ "news2title": {
39
+ "weight": 5,
40
+ "files": {"/home/sbert_pretrained_models/datasets/embedding-training-data/agnews.jsonl.gz": 50, "/home/sbert_pretrained_models/datasets/embedding-training-data/ccnews_title_text.jsonl.gz": 50, "/home/sbert_pretrained_models/datasets/embedding-training-data/xsum.jsonl.gz": 25}
41
+ },
42
+ "text2query": {
43
+ "weight": 10,
44
+ "files": {"/home/sbert_pretrained_models/datasets/embedding-training-data/msmarco-triplets.jsonl.gz": 33, "/home/sbert_pretrained_models/datasets/embedding-training-data/gooaq_pairs.jsonl.gz": 66, "/home/sbert_pretrained_models/datasets/embedding-training-data/NQ-train_pairs.jsonl.gz": 6}
45
+ },
46
+ "question2question": {
47
+ "weight": 10,
48
+ "files": {"/home/sbert_pretrained_models/datasets/embedding-training-data/WikiAnswers.jsonl.gz": 95, "/home/sbert_pretrained_models/datasets/embedding-training-data/quora_duplicates.jsonl.gz": 5}
49
+ }
50
+
51
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f44600793a9a002ab1a0d8d956b8c00f99adc0b9587e77485b7724de7a968f61
3
+ size 990445401
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"]}
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 100, "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"], "model_max_length": 512, "name_or_path": "google/t5-v1_1-base", "special_tokens_map_file": "/root/.cache/huggingface/transformers/76bf19bfedb85afbe644966ca9ab7b0404d753a41bf601115bced39f825ffa9c.c94798918c92ded6aeef2d2f0e666d2cc4145eca1aa6e1336fde07f2e13e2f46", "sp_model_kwargs": {}, "tokenizer_class": "T5Tokenizer"}
train_script.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ from torch.utils.data import Dataset, IterableDataset
4
+ import gzip
5
+ import json
6
+ from transformers import Seq2SeqTrainer, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments
7
+ import sys
8
+ from datetime import datetime
9
+ import torch
10
+ import random
11
+ from shutil import copyfile
12
+ import os
13
+ import wandb
14
+ import re
15
+
16
+
17
+ logging.basicConfig(
18
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
19
+ datefmt="%Y-%m-%d %H:%M:%S",
20
+ handlers=[logging.StreamHandler(sys.stdout)],
21
+ )
22
+
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("--model_name", default="google/t5-v1_1-base")
25
+ parser.add_argument("--train_file", required=True)
26
+ parser.add_argument("--epochs", default=1, type=int)
27
+ parser.add_argument("--batch_size", default=16, type=int)
28
+ parser.add_argument("--max_source_length", default=384, type=int)
29
+ parser.add_argument("--max_target_length", default=64, type=int)
30
+ parser.add_argument("--name", required=True)
31
+ parser.add_argument("--train_size", default=100*1000*1000, type=int)
32
+ parser.add_argument("--eval_size", default=10000, type=int)
33
+ parser.add_argument("--fp16", default=False, action='store_true')
34
+ parser.add_argument("--no_prefix", default=False, action='store_true')
35
+ args = parser.parse_args()
36
+
37
+ wandb.init(project="doc2query", name=f"{args.name}-{args.model_name}")
38
+
39
+
40
+ class PairDataset:
41
+ def __init__(self, filepath):
42
+ self.filepath = filepath
43
+ self.examples = []
44
+
45
+ def __iter__(self):
46
+ with gzip.open(self.filepath, 'rt') as fIn:
47
+ for line in fIn:
48
+ example = self.get_example(json.loads(line))
49
+
50
+ if example is not None:
51
+ self.examples.append(example)
52
+ yield example
53
+
54
+ while True:
55
+ random.shuffle(self.examples)
56
+ for ex in self.examples:
57
+ yield ex
58
+
59
+
60
+ def get_example(self, raw_example):
61
+ if isinstance(raw_example, dict):
62
+ if 'set' in raw_example:
63
+ example = random.sample(raw_example['set'], 2)
64
+ elif 'query' in raw_example:
65
+ example = [raw_example['query'], random.choice(raw_example['pos'])]
66
+ else:
67
+ raise ValueError("Unknown format: "+str(raw_example))
68
+ else:
69
+ example = [raw_example[0], raw_example[1]]
70
+
71
+ return example
72
+
73
+
74
+
75
+
76
+ class RedditTitleDataset(PairDataset):
77
+ def get_example(self, raw_example):
78
+ return [self.clean_title(raw_example['title']), raw_example['body']]
79
+
80
+
81
+ def clean_title(self, text):
82
+ text = text.replace("&amp;", "&").strip()
83
+ if text.startswith("["):
84
+ text = re.sub("^\[[a-zA-Z0-9]+\]", "", text).strip()
85
+
86
+ if text.endswith("]"):
87
+ text = re.sub("\[[a-zA-Z0-9\.]+\]$", "", text).strip()
88
+
89
+ if text.startswith("/r"):
90
+ text = re.sub("^/[a-zA-Z0-9/]+[;,: \-]+", "", text).strip()
91
+
92
+ return text
93
+
94
+
95
+ class StackExchangeTitleBodyDataset(PairDataset):
96
+ def get_example(self, raw_example):
97
+ return raw_example['texts']
98
+
99
+
100
+ class MultiDataset(IterableDataset):
101
+ def __init__(self, train_config_path, num_samples):
102
+ self.num_samples = num_samples
103
+
104
+ with open(train_config_path) as fIn:
105
+ train_config = json.load(fIn)
106
+
107
+ self.categories = []
108
+ self.files = {}
109
+ self.file2dataset = {}
110
+ self.file2datasetIter = {}
111
+
112
+ for prefix in train_config:
113
+ self.categories.extend([prefix]*train_config[prefix]['weight'])
114
+ self.files[prefix] = []
115
+
116
+ for filename, weight in train_config[prefix]['files'].items():
117
+ self.files[prefix].extend([filename]*weight)
118
+ dataset = self.OpenDataset(filename)
119
+ self.file2dataset[filename] = dataset
120
+ self.file2datasetIter[filename] = iter(dataset)
121
+
122
+ random.shuffle(self.files[prefix])
123
+
124
+ random.shuffle(self.categories)
125
+
126
+
127
+
128
+
129
+ def OpenDataset(self, filepath):
130
+ if 'reddit_title_text' in filepath:
131
+ dataset = RedditTitleDataset(filepath)
132
+ elif 'stackexchange_archive/jsonl' in filepath:
133
+ dataset = StackExchangeTitleBodyDataset(filepath)
134
+ else:
135
+ dataset = PairDataset(filepath)
136
+ return dataset
137
+
138
+
139
+ def __len__(self):
140
+ return self.num_samples
141
+
142
+ def __iter__(self):
143
+ while True:
144
+ category = random.choice(self.categories)
145
+ filepath = random.choice(self.files[category])
146
+ dataset = self.file2datasetIter[filepath]
147
+ pair = next(dataset)
148
+
149
+ #Add prefix to the input
150
+ if not args.no_prefix:
151
+ pair[1] = category+": "+pair[1].strip()
152
+ yield pair
153
+
154
+ def delete_examples_cache(self):
155
+ for dataset in self.file2dataset.values():
156
+ dataset.examples = []
157
+
158
+
159
+
160
+ def main():
161
+ ############ Model
162
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
163
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
164
+
165
+ save_steps = 5000
166
+
167
+ output_dir = 'output/'+args.name+'-'+args.model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
168
+ print("Output dir:", output_dir)
169
+
170
+ # Write self to path
171
+ os.makedirs(output_dir, exist_ok=True)
172
+
173
+ copyfile(args.train_file, os.path.join(output_dir, 'data_config.json'))
174
+ train_script_path = os.path.join(output_dir, 'train_script.py')
175
+ copyfile(__file__, train_script_path)
176
+ with open(train_script_path, 'a') as fOut:
177
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
178
+
179
+ ####
180
+
181
+ training_args = Seq2SeqTrainingArguments(
182
+ output_dir=output_dir,
183
+ fp16=args.fp16,
184
+ fp16_backend="amp",
185
+ per_device_train_batch_size=args.batch_size,
186
+ evaluation_strategy="steps",
187
+ save_steps=save_steps,
188
+ logging_steps=100,
189
+ eval_steps=save_steps, #logging_steps,
190
+ warmup_steps=1000,
191
+ save_total_limit=1,
192
+ num_train_epochs=args.epochs,
193
+ report_to="wandb",
194
+ )
195
+
196
+ ############ Arguments
197
+
198
+ ############ Load datasets
199
+
200
+
201
+ train_dataset = MultiDataset(args.train_file, args.train_size)
202
+ train_dataset_iter = iter(train_dataset)
203
+ eval_dataset = [next(train_dataset_iter) for _ in range(args.eval_size)]
204
+ train_dataset.delete_examples_cache() #Make sure dev data is no re-used for training
205
+
206
+ for i in range(50):
207
+ print("Target:", eval_dataset[i][0])
208
+ print("Input:", eval_dataset[i][1])
209
+ print("\n\n===================\n\n")
210
+
211
+ print("Train dataset len:", len(train_dataset))
212
+
213
+
214
+ def data_collator(examples):
215
+ targets = [row[0] for row in examples]
216
+ inputs = [row[1] for row in examples]
217
+ label_pad_token_id = -100
218
+
219
+ model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=8 if training_args.fp16 else None)
220
+
221
+ # Setup the tokenizer for targets
222
+ with tokenizer.as_target_tokenizer():
223
+ labels = tokenizer(targets, max_length=args.max_target_length, padding=True, truncation=True, pad_to_multiple_of=8 if training_args.fp16 else None)
224
+
225
+ # replace all tokenizer.pad_token_id in the labels by -100 to ignore padding in the loss.
226
+ labels["input_ids"] = [
227
+ [(l if l != tokenizer.pad_token_id else label_pad_token_id) for l in label] for label in labels["input_ids"]
228
+ ]
229
+
230
+
231
+ model_inputs["labels"] = torch.tensor(labels["input_ids"])
232
+ return model_inputs
233
+
234
+ ## Define the trainer
235
+ trainer = Seq2SeqTrainer(
236
+ model=model,
237
+ args=training_args,
238
+ train_dataset=train_dataset,
239
+ eval_dataset=eval_dataset,
240
+ tokenizer=tokenizer,
241
+ data_collator=data_collator
242
+ )
243
+
244
+ ### Save the model
245
+ train_result = trainer.train()
246
+ trainer.save_model()
247
+
248
+
249
+ if __name__ == "__main__":
250
+ main()
251
+
252
+ # Script was called via:
253
+ #python train_hf_trainer_prefix.py --train_file train_config.json --name all-datasets-v1-no_prefix --no_prefix
trainer_state.json ADDED
The diff for this file is too large to render. See raw diff