---
license: llama2
---

# RepLLaMA-7B-Passage-MRL

[Fine-Tuning LLaMA for Multi-Stage Text Retrieval](https://arxiv.org/abs/2310.08319).

Xueguang Ma, Liang Wang, Nan Yang, Furu Wei, Jimmy Lin, arXiv 2023

This model is fine-tuned from LLaMA-2-7B using LoRA. Because Matryoshka Representation Learning (MRL) is applied during training, the embedding size is **flexible**: query and passage embeddings have at most 4096 dimensions, and a prefix of the first *d* dimensions can be used as a smaller embedding.
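
MRL trains the leading dimensions of an embedding to be useful on their own, so shrinking an embedding amounts to keeping a prefix and rescaling it to unit length, as the usage example in this card does. The sketch below shows only that arithmetic in plain Python, independent of the model. `truncate_and_normalize` and `similarity` are hypothetical helper names, and the toy 4-dimensional vectors stand in for real 4096-dimensional embeddings; this is an illustration, not code from the RepLLaMA release.

```python
import math
from typing import List, Sequence


def truncate_and_normalize(embedding: Sequence[float], dim: int) -> List[float]:
    """Keep the first `dim` coordinates of an MRL embedding and rescale to unit L2 norm.

    For this model `dim` would be at most 4096 (the full embedding size).
    """
    if not 0 < dim <= len(embedding):
        raise ValueError(f"dim must be in [1, {len(embedding)}], got {dim}")
    prefix = list(embedding[:dim])
    norm = math.sqrt(sum(x * x for x in prefix))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in prefix]


def similarity(query: Sequence[float], passage: Sequence[float], dim: int) -> float:
    """Dot product of the truncated, renormalized vectors (cosine similarity at `dim`)."""
    q = truncate_and_normalize(query, dim)
    p = truncate_and_normalize(passage, dim)
    return sum(a * b for a, b in zip(q, p))


# Toy vectors standing in for full-size query and passage embeddings
query = [1.0, 2.0, 2.0, 0.0]
passage = [2.0, 1.0, 2.0, 4.0]
for d in (2, 3, 4):
    print(d, round(similarity(query, passage, d), 4))
# 2 0.8
# 3 0.8889
# 4 0.5333
```

Scores are only comparable when the query and passage embeddings are truncated to the same dimensionality, which is why both sides use the same `dim` here.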

## Training Data

The model is fine-tuned for 1 epoch on the training split of the [MS MARCO Passage Ranking](https://microsoft.github.io/msmarco/Datasets) dataset.

Please check our paper for details.
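
## Usage

Below is an example that encodes a query and a passage and then computes their similarity from their embeddings. Each input gets a `query: ` or `passage: ` prefix and a trailing `</s>`; the embedding is the final hidden state of the last token, truncated to the chosen dimensionality and L2-normalized.
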
```python
import torch
from transformers import AutoModel, AutoTokenizer

# Load the tokenizer and model
# (if this repository holds only a LoRA adapter, transformers also needs the `peft` package to load it)
tokenizer = AutoTokenizer.from_pretrained('castorini/repllama-v1-mrl-7b-lora-passage')
model = AutoModel.from_pretrained('castorini/repllama-v1-mrl-7b-lora-passage')

# Embedding size to keep (MRL): any prefix length up to the full 4096 dimensions
embedding_dim = 512

# Define query and passage inputs
query = "What is llama?"
title = "Llama"
passage = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."
query_input = tokenizer(f'query: {query}</s>', return_tensors='pt')
passage_input = tokenizer(f'passage: {title} {passage}</s>', return_tensors='pt')

# Run the model forward to compute embeddings and the query-passage similarity score
with torch.no_grad():
    # Query embedding: last-token (</s>) hidden state, truncated to embedding_dim, then L2-normalized
    query_outputs = model(**query_input)
    query_embedding = query_outputs.last_hidden_state[0][-1][:embedding_dim]
    query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=0)

    # Passage embedding, computed the same way
    passage_outputs = model(**passage_input)
    passage_embedding = passage_outputs.last_hidden_state[0][-1][:embedding_dim]
    passage_embedding = torch.nn.functional.normalize(passage_embedding, p=2, dim=0)

    # Similarity score: dot product of unit vectors, i.e. cosine similarity
    score = torch.dot(query_embedding, passage_embedding)

print(score.item())
```

## Batch Inference and Training

An unofficial replication of the inference and training code is available in [Tevatron](https://github.com/texttron/tevatron/tree/main/examples/repllama).
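
When several inputs are encoded at once, sequences are padded to a common length, so the last position in a row is no longer necessarily the `</s>` token that the single-example code reads with index `[-1]`. The minimal sketch below assumes right padding and an attention mask of 1s for real tokens and 0s for padding, and selects each row's last real token instead. The helper names are hypothetical, and this is not taken from the Tevatron code, which may handle padding differently.

```python
from typing import List, Sequence


def last_token_positions(attention_mask: Sequence[Sequence[int]]) -> List[int]:
    """Index of the last real token in each row of a right-padded batch.

    With right padding a row looks like [1, 1, ..., 1, 0, ..., 0], so the
    last real token (the appended </s>) sits at index sum(row) - 1.
    """
    positions = []
    for row in attention_mask:
        length = sum(row)
        if length == 0:
            raise ValueError("a sequence in the batch has no real tokens")
        positions.append(length - 1)
    return positions


def pool_last_token(hidden_states, attention_mask):
    """Pick each sequence's hidden state at its last real token.

    `hidden_states` is a nested list shaped [batch][seq_len][hidden_size].
    """
    positions = last_token_positions(attention_mask)
    return [seq[pos] for seq, pos in zip(hidden_states, positions)]


# Toy batch: two sequences padded to length 4, hidden size 2
mask = [[1, 1, 1, 1], [1, 1, 0, 0]]
hidden = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]],
    [[1.0, 1.1], [1.2, 1.3], [9.9, 9.9], [9.9, 9.9]],  # last two positions are padding
]
print(last_token_positions(mask))  # [3, 1]
print(pool_last_token(hidden, mask))  # [[0.7, 0.8], [1.2, 1.3]]
```

With PyTorch tensors the same selection can be written as `last_hidden_state[torch.arange(last_hidden_state.size(0)), attention_mask.sum(dim=1) - 1]`. The LLaMA-2 tokenizer typically ships without a padding token, so one has to be set (along with the padding side) before batching.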

## Citation

If you find our paper or models helpful, please consider citing it as follows:

```
@article{rankllama,
  title={Fine-Tuning LLaMA for Multi-Stage Text Retrieval},
  author={Xueguang Ma and Liang Wang and Nan Yang and Furu Wei and Jimmy Lin},
  year={2023},
  journal={arXiv:2310.08319},
}
```