## facebook/tart-full-flan-t5-xl
`facebook/tart-full-flan-t5-xl` is a multi-task cross-encoder model trained via instruction-tuning on approximately 40 retrieval tasks, initialized with [google/flan-t5-xl](https://huggingface.co/google/flan-t5-xl).
### Installation
```
git clone https://github.com/facebookresearch/tart
cd tart/TART
pip install -r requirements.txt
```
TART-full can be loaded through our customized EncT5 model.
```python
from src.modeling_enc_t5 import EncT5ForSequenceClassification
from src.tokenization_enc_t5 import EncT5Tokenizer
import torch
import torch.nn.functional as F
import numpy as np

# load TART-full and its tokenizer
model = EncT5ForSequenceClassification.from_pretrained("facebook/tart-full-flan-t5-xl")
tokenizer = EncT5Tokenizer.from_pretrained("facebook/tart-full-flan-t5-xl")
model.eval()

q = "What is the population of Tokyo?"
in_answer = "retrieve a passage that answers this question from Wikipedia"

p_1 = "The population of Japan's capital, Tokyo, dropped by about 48,600 people to just under 14 million at the start of 2022."
p_2 = "Tokyo, officially the Tokyo Metropolis (東京都, Tōkyō-to), is the capital and largest city of Japan."

# 1. TART-full can identify the more relevant paragraph.
features = tokenizer(['{0} [SEP] {1}'.format(in_answer, q), '{0} [SEP] {1}'.format(in_answer, q)], [p_1, p_2], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**features).logits
    normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]

print([p_1, p_2][np.argmax(normalized_scores)])  # "The population of Japan's capital, Tokyo, dropped by about 48,600 people to just under 14 million."

# 2. TART-full can identify the document that is more relevant AND follows the instruction.
in_sim = "You need to find duplicated questions in Wiki forum. Could you find a question that is similar to this question"
q_1 = "How many people live in Tokyo?"
features = tokenizer(['{0} [SEP] {1}'.format(in_sim, q), '{0} [SEP] {1}'.format(in_sim, q)], [p_1, q_1], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**features).logits
    normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]

print([p_1, q_1][np.argmax(normalized_scores)])  # "How many people live in Tokyo?"
```
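
The same pattern extends to scoring an arbitrary list of candidate passages for a single instruction-plus-query pair. The sketch below is not part of the TART repository; it only reuses the model, tokenizer, and softmax scoring shown above, and the `rank_passages` helper and the extra candidate passage are hypothetical.

```python
import numpy as np
import torch
import torch.nn.functional as F

def rank_passages(model, tokenizer, instruction, query, passages):
    """Hypothetical helper: score each passage against 'instruction [SEP] query'
    with the TART-full cross-encoder and return the passages best-first."""
    queries = ['{0} [SEP] {1}'.format(instruction, query)] * len(passages)
    features = tokenizer(queries, passages, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**features).logits
        # probability of the "relevant" class (index 1), one score per passage
        scores = F.softmax(logits, dim=1)[:, 1].tolist()
    order = np.argsort(scores)[::-1]
    return [(passages[i], scores[i]) for i in order]

# reuses model, tokenizer, in_answer, q, p_1, p_2 from the example above;
# the third candidate is illustrative text, not from the model card
candidates = [p_1, p_2, "Tokyo hosted the Summer Olympics in 2021."]
for passage, score in rank_passages(model, tokenizer, in_answer, q, candidates):
    print(round(score, 3), passage)
```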