---
pipeline_tag: text-classification
tags:
- transformers
- information-retrieval
language: pl
license: apache-2.0
---

<h1 align="center">polish-reranker-base-mse</h1>

This is a Polish text ranking model trained with mean squared error (MSE) distillation on a large dataset of text pairs comprising 1.4 million queries and 10 million documents. The training data consisted of the following parts:

1. The Polish MS MARCO training split (800k queries).
2. The ELI5 dataset translated into Polish (over 500k queries).
3. A collection of Polish medical questions and answers (approximately 100k queries).

As the teacher model, we employed [unicamp-dl/mt5-13b-mmarco-100k](https://huggingface.co/unicamp-dl/mt5-13b-mmarco-100k), a large multilingual reranker based on the mT5-XXL architecture. As the student model, we chose [Polish RoBERTa](https://huggingface.co/sdadas/polish-roberta-base-v2).

In the MSE method, the student is trained to directly replicate the teacher's outputs: it minimizes the mean squared error between its own relevance scores and the scores the teacher assigns to the same query-document pairs.
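
To make the objective concrete, here is a toy NumPy sketch of MSE distillation. It is not the authors' training code: the student is a hypothetical linear scorer over synthetic pair features, and the "teacher scores" are generated at random rather than produced by the mT5 reranker. The student is updated by gradient descent on the mean squared error between its scores and the teacher's scores.

```python
import numpy as np


def mse_distillation_loss(student_scores, teacher_scores):
    """Mean squared error between student and teacher relevance scores."""
    student = np.asarray(student_scores, dtype=np.float64)
    teacher = np.asarray(teacher_scores, dtype=np.float64)
    return float(np.mean((student - teacher) ** 2))


# Synthetic stand-ins (hypothetical): each row represents one (query, document) pair.
rng = np.random.default_rng(0)
features = rng.normal(size=(64, 8))
teacher_scores = features @ rng.normal(size=8)  # pretend these came from the teacher

# Toy linear student trained to reproduce the teacher's raw scores.
weights = np.zeros(8)
learning_rate = 0.1
for _ in range(500):
    residual = features @ weights - teacher_scores
    # Gradient of mean((Xw - t)^2) with respect to w is (2 / n) * X^T (Xw - t).
    gradient = 2.0 / len(features) * (features.T @ residual)
    weights -= learning_rate * gradient

initial_loss = mse_distillation_loss(np.zeros(len(features)), teacher_scores)
final_loss = mse_distillation_loss(features @ weights, teacher_scores)
print(f"loss before: {initial_loss:.4f}, after: {final_loss:.6f}")
```

In the real setup the student is a cross-encoder and the loss is backpropagated through the whole network, but the target is the same: match the teacher's score for each pair, rather than a binary relevance label.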

## Usage (Sentence-Transformers)

You can use the model with [sentence-transformers](https://www.SBERT.net) as follows:

```python
from sentence_transformers import CrossEncoder
import torch

query = "Jak dożyć 100 lat?"  # "How to live to 100?"
answers = [
    "Trzeba zdrowo się odżywiać i uprawiać sport.",  # "Eat healthily and do sports."
    "Trzeba pić alkohol, imprezować i jeździć szybkimi autami.",  # "Drink alcohol, party and drive fast cars."
    "Gdy trwała kampania politycy zapewniali, że rozprawią się z zakazem niedzielnego handlu."  # unrelated: politics
]

model = CrossEncoder(
    "sdadas/polish-reranker-base-mse",
    # Identity keeps the raw regression scores instead of squashing them with a sigmoid.
    # (Newer sentence-transformers releases may name this argument `activation_fn`.)
    default_activation_function=torch.nn.Identity(),
    max_length=512,
    device="cuda" if torch.cuda.is_available() else "cpu"
)
# Each input is a [query, passage] pair; a higher score means higher relevance.
pairs = [[query, answer] for answer in answers]
results = model.predict(pairs)
print(results.tolist())
```
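
Since `predict` returns one raw relevance score per pair, reranking amounts to sorting the candidates by score in descending order. A minimal sketch with a hypothetical helper `rerank`; the scores below are placeholder values, not actual model outputs.

```python
def rerank(documents, scores, top_k=None):
    """Return (document, score) pairs sorted by descending relevance score."""
    if len(documents) != len(scores):
        raise ValueError("documents and scores must have the same length")
    ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
    return ranked if top_k is None else ranked[:top_k]


# Placeholder scores for illustration only (not real model outputs).
candidates = ["doc A", "doc B", "doc C"]
placeholder_scores = [0.3, 2.1, -1.4]
for document, score in rerank(candidates, placeholder_scores, top_k=2):
    print(f"{score:.2f}\t{document}")
```

With the example above you would call `rerank(answers, results.tolist())`. Recent sentence-transformers versions also provide a `CrossEncoder.rank` method that performs this sorting directly.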

## Usage (Hugging Face Transformers)

The model can also be used with Hugging Face Transformers as follows:

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import torch

query = "Jak dożyć 100 lat?"  # "How to live to 100?"
answers = [
    "Trzeba zdrowo się odżywiać i uprawiać sport.",
    "Trzeba pić alkohol, imprezować i jeździć szybkimi autami.",
    "Gdy trwała kampania politycy zapewniali, że rozprawią się z zakazem niedzielnego handlu."
]

model_name = "sdadas/polish-reranker-base-mse"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Join query and passage with "</s></s>", mirroring RoBERTa's sentence-pair format.
texts = [f"{query}</s></s>{answer}" for answer in answers]
tokens = tokenizer(texts, padding="longest", max_length=512, truncation=True, return_tensors="pt")
with torch.no_grad():  # inference only, no gradients needed
    output = model(**tokens)
# logits has shape (num_answers, 1); drop the last axis to get one score per answer,
# which also keeps a 1-D result when only a single answer is scored.
results = np.squeeze(output.logits.numpy(), axis=-1)
print(results.tolist())
```
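
The snippet above encodes all pairs in a single forward pass. For a long candidate list it may be preferable to score in fixed-size batches to bound memory use. Below is a framework-agnostic sketch: `score_in_batches` and `score_batch` are hypothetical names, where `score_batch` stands for a function you would write around the tokenizer and model calls shown above; the dummy scorer exists only to make the example runnable.

```python
from typing import Callable, List, Sequence


def score_in_batches(
    query: str,
    documents: Sequence[str],
    score_batch: Callable[[List[str]], List[float]],
    batch_size: int = 32,
) -> List[float]:
    """Score documents in chunks of `batch_size`, preserving input order."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    scores: List[float] = []
    for start in range(0, len(documents), batch_size):
        chunk = documents[start:start + batch_size]
        # Same input format as the example above: query and passage joined by "</s></s>".
        texts = [f"{query}</s></s>{doc}" for doc in chunk]
        batch_scores = list(score_batch(texts))
        if len(batch_scores) != len(texts):
            raise ValueError("score_batch must return one score per input text")
        scores.extend(batch_scores)
    return scores


# Dummy scorer for illustration only: "scores" each text by its length.
def dummy_score_batch(texts: List[str]) -> List[float]:
    return [float(len(t)) for t in texts]


docs = ["a", "bb", "ccc", "dddd", "eeeee"]
print(score_in_batches("q", docs, dummy_score_batch, batch_size=2))
```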

## Evaluation Results

The model achieves an **NDCG@10** of **57.50** in the Rerankers category of the Polish Information Retrieval Benchmark (PIRB). See the [PIRB Leaderboard](https://huggingface.co/spaces/sdadas/pirb) for detailed results.
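
For reference, NDCG@10 divides the discounted cumulative gain (DCG) of the top 10 ranked documents by the DCG of an ideal ordering of the judged documents. The sketch below uses linear gain and a log2(rank + 1) discount; this is one common formulation, and we have not verified that PIRB uses exactly this variant. The function returns a value in [0, 1] for a single query, so 57.50 corresponds to 0.575 on this scale; a benchmark score is typically the average over all queries.

```python
import math
from typing import Optional, Sequence


def dcg_at_k(relevances: Sequence[float], k: int) -> float:
    """Discounted cumulative gain with linear gain and log2(rank + 1) discount."""
    return sum(rel / math.log2(rank + 1) for rank, rel in enumerate(relevances[:k], start=1))


def ndcg_at_k(
    ranked_relevances: Sequence[float],
    k: int = 10,
    all_relevances: Optional[Sequence[float]] = None,
) -> float:
    """NDCG@k for one query.

    ranked_relevances: graded relevance of each returned document, in ranked order.
    all_relevances: relevance of every judged document for the query (defaults to
    ranked_relevances); used to build the ideal ranking.
    """
    pool = ranked_relevances if all_relevances is None else all_relevances
    ideal = dcg_at_k(sorted(pool, reverse=True), k)
    if ideal == 0:
        return 0.0  # no relevant documents: define NDCG as 0
    return dcg_at_k(ranked_relevances, k) / ideal


print(ndcg_at_k([0, 1]))     # relevant document ranked second
print(ndcg_at_k([3, 2, 1]))  # already ideally ordered
```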