jwieting commited on
Commit
ceee6db
·
1 Parent(s): 7b753e5

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +80 -0
  2. config.json +35 -0
  3. modeling_vmsst.py +19 -0
  4. pytorch_model.bin +3 -0
README.md CHANGED
@@ -1,3 +1,83 @@
1
  ---
2
  license: apache-2.0
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ pipeline_tag: sentence-similarity
4
+ tags:
5
+ - cross-lingual
6
+ - multilingual
7
+ - question-answering
8
+ - retrieval
9
+ - sentence-similarity
10
+ - variational
11
  ---
12
+
13
+ # VMSST
14
+
15
+ Published as a long paper at ACL 2023.
16
+
17
+ Contrastive learning has been successfully used for retrieval of semantically aligned sentences, but it often requires large batch sizes and carefully engineered heuristics to work well. In this paper, we instead propose a generative model for learning multilingual text embeddings which can be used to retrieve or score sentence pairs. Our model operates on parallel data in N languages and, through an approximation we introduce, efficiently encourages source separation in this multilingual setting, separating semantic information that is shared between translations from stylistic or language-specific variation. We show careful large-scale comparisons between contrastive and generation-based approaches for learning multilingual text embeddings, a comparison that has not been done to the best of our knowledge despite the popularity of these approaches. We evaluate this method on a suite of tasks including semantic similarity, bitext mining, and cross-lingual question retrieval––the last of which we introduce in this paper. Overall, our Variational Multilingual Source-Separation Transformer (VMSST) model outperforms both a strong contrastive and generative baseline on these tasks.
18
+
19
+ ## Checkpoints
20
+
21
+ T5X (Jax): https://storage.googleapis.com/gresearch/vmsst/vmsst-large-2048-t5x.zip
22
+
23
+ PyTorch: https://storage.googleapis.com/gresearch/vmsst/vmsst-large-2048-pytorch.zip
24
+
25
+ ## Usage
26
+
27
+ ### Installation
28
+
29
+ 1. Clone the following repository from Google Research.
30
+
31
+ ```
32
+ git clone -b master --single-branch https://github.com/google-research/google-research.git
33
+ ```
34
+
35
+ 2. Make sure `google-research` is the current directory:
36
+
37
+ ```
38
+ cd google-research/vmsst
39
+ ```
40
+
41
+ 3. Create and activate a new virtualenv:
42
+
43
+ ```
44
+ python -m venv vmsst
45
+ source vmsst/bin/activate
46
+ ```
47
+
48
+ 4. This repository is tested on Python 3.10+. Install required packages:
49
+
50
+ ```
51
+ pip install -r requirements.txt
52
+ ```
53
+
54
+ ### Test
55
+
56
+ To test that the checkpoint and installation are working as intended, run:
57
+
58
+ bash run.sh
59
+
60
+ The expected cosine similarity scores for the three sentences pairs are:
61
+
62
+ 0.2573888301849365, 0.1563197821378708, and 0.28531330823898315.
63
+
64
+ ### Inference
65
+
66
+ To embed a list of sentences:
67
+
68
+ python score_sentence_pairs.py --sentence_pair_file test_data/test_sentence_pairs.tsv
69
+
70
+ To score a list of sentence pairs:
71
+
72
+ python embed_sentences.py --sentence_file test_data/test_sentences.txt
73
+
74
+ ## Citation
75
+
76
+ If you use our code or models your work please cite:
77
+
78
+ @article{wieting2022beyond,
79
+ title={Beyond Contrastive Learning: A Variational Generative Model for Multilingual Retrieval},
80
+ author={Wieting, John and Clark, Jonathan H and Cohen, William W and Neubig, Graham and Berg-Kirkpatrick, Taylor},
81
+ journal={arXiv preprint arXiv:2212.10726},
82
+ year={2022}
83
+ }
config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "jwieting/vmsst",
3
+ "architectures": [
4
+ "MT5EncoderWithProjection"
5
+ ],
6
+ "auto_map": {
7
+ "AutoModel": "jwieting/vmsst--modeling_vmsst.MT5EncoderWithProjection"
8
+ },
9
+ "d_ff": 2816,
10
+ "d_kv": 64,
11
+ "d_model": 1024,
12
+ "decoder_start_token_id": 0,
13
+ "dense_act_fn": "gelu_new",
14
+ "dropout_rate": 0.1,
15
+ "eos_token_id": 1,
16
+ "feed_forward_proj": "gated-gelu",
17
+ "initializer_factor": 1.0,
18
+ "is_encoder_decoder": true,
19
+ "is_gated_act": true,
20
+ "layer_norm_epsilon": 1e-06,
21
+ "model_type": "mt5",
22
+ "num_decoder_layers": 24,
23
+ "num_heads": 16,
24
+ "num_layers": 24,
25
+ "output_past": true,
26
+ "pad_token_id": 0,
27
+ "relative_attention_max_distance": 128,
28
+ "relative_attention_num_buckets": 32,
29
+ "tie_word_embeddings": false,
30
+ "tokenizer_class": "T5Tokenizer",
31
+ "torch_dtype": "float32",
32
+ "transformers_version": "4.30.2",
33
+ "use_cache": true,
34
+ "vocab_size": 250112
35
+ }
modeling_vmsst.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import tqdm
3
+ from torch import nn
4
+ from transformers import MT5EncoderModel, MT5PreTrainedModel
5
+
6
+ class MT5EncoderWithProjection(MT5PreTrainedModel):
7
+ def __init__(self, config):
8
+ super().__init__(config)
9
+ self.config = config
10
+ self.mt5_encoder = MT5EncoderModel(config)
11
+ self.projection = nn.Linear(config.d_model, config.d_model, bias=False)
12
+ self.post_init()
13
+
14
+ def forward(self, **input_args):
15
+ hidden_states = self.mt5_encoder(**input_args).last_hidden_state
16
+ mask = input_args['attention_mask']
17
+ batch_embeddings = torch.sum(hidden_states * mask[:, :, None], dim=1) / torch.sum(mask, dim=1)[:, None]
18
+ batch_embeddings = self.projection(batch_embeddings)
19
+ return batch_embeddings
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed70bb240affcc90a945b5905dc643778806ecf9e3c1ff6542de24fa70056228
3
+ size 2262056637