---
pipeline_tag: text-classification
tags:
  - transformers
  - information-retrieval
language: pl
license: apache-2.0
---

polish-reranker-base-mse

This is a Polish text ranking model trained with mean squared error (MSE) distillation on a large dataset of text pairs comprising 1.4 million queries and 10 million documents. The training data consisted of the following parts: 1) the Polish MS MARCO training split (800k queries); 2) the ELI5 dataset translated into Polish (over 500k queries); 3) a collection of Polish medical questions and answers (approximately 100k queries). As the teacher model, we employed unicamp-dl/mt5-13b-mmarco-100k, a large multilingual reranker based on the MT5-XXL architecture. As the student model, we chose Polish RoBERTa. In the MSE method, the student is trained to directly replicate the relevance scores returned by the teacher by minimizing the mean squared error between the student's and the teacher's scores.
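The MSE objective described above can be illustrated with a small, self-contained sketch. It assumes that the teacher and the student each produce one raw relevance score per (query, document) pair. The teacher scores and the learning rate below are hypothetical values. The student is reduced to a list of free scores instead of a RoBERTa network, so the sketch shows the loss and its gradient rather than the authors' actual training pipeline.

```python
# Minimal sketch of MSE distillation (illustrative only, not the authors' training code).
# Assumption: teacher and student each emit one raw relevance score (logit)
# per (query, document) pair, and the student is fitted to the teacher's scores.


def mse_distillation_loss(student_scores, teacher_scores):
    """Mean squared error between student and teacher scores for the same pairs."""
    if len(student_scores) != len(teacher_scores):
        raise ValueError("score lists must have the same length")
    n = len(student_scores)
    return sum((s - t) ** 2 for s, t in zip(student_scores, teacher_scores)) / n


def mse_gradient(student_scores, teacher_scores):
    """Gradient of the loss with respect to each student score: 2 * (s - t) / n."""
    n = len(student_scores)
    return [2.0 * (s - t) / n for s, t in zip(student_scores, teacher_scores)]


# Hypothetical teacher scores for three (query, document) pairs.
teacher = [3.2, -1.5, -4.0]
student = [0.0, 0.0, 0.0]

# Plain gradient descent applied directly to the scores. This shows the student
# converging to the teacher's outputs; a real student would update network weights.
learning_rate = 0.5
for _ in range(200):
    grads = mse_gradient(student, teacher)
    student = [s - learning_rate * g for s, g in zip(student, grads)]

print(round(mse_distillation_loss(student, teacher), 6))
```

In real training, the gradient is backpropagated through the student's weights. The model therefore learns to reproduce teacher scores for pairs it has not seen, not just for the training pairs.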

Usage (Sentence-Transformers)

You can use the model with sentence-transformers as follows:

```python
from sentence_transformers import CrossEncoder
import torch

query = "Jak dożyć 100 lat?"  # "How to live to 100?"
answers = [
    "Trzeba zdrowo się odżywiać i uprawiać sport.",  # eat healthily and do sports
    "Trzeba pić alkohol, imprezować i jeździć szybkimi autami.",  # drink, party, drive fast cars
    "Gdy trwała kampania politycy zapewniali, że rozprawią się z zakazem niedzielnego handlu."  # unrelated (politics)
]

model = CrossEncoder(
    "sdadas/polish-reranker-base-mse",
    # Identity keeps the raw regression scores (no sigmoid is applied)
    default_activation_function=torch.nn.Identity(),
    max_length=512,
    device="cuda" if torch.cuda.is_available() else "cpu"
)
# One (query, answer) pair per candidate; a higher score means more relevant
pairs = [[query, answer] for answer in answers]
results = model.predict(pairs)
print(results.tolist())
```

Usage (Hugging Face Transformers)

The model can also be used directly with Hugging Face Transformers:

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import torch

query = "Jak dożyć 100 lat?"
answers = [
    "Trzeba zdrowo się odżywiać i uprawiać sport.",
    "Trzeba pić alkohol, imprezować i jeździć szybkimi autami.",
    "Gdy trwała kampania politycy zapewniali, że rozprawią się z zakazem niedzielnego handlu."
]

model_name = "sdadas/polish-reranker-base-mse"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Join query and answer with RoBERTa's pair separator ("</s></s>")
texts = [f"{query}</s></s>{answer}" for answer in answers]
tokens = tokenizer(texts, padding="longest", max_length=512, truncation=True, return_tensors="pt")
# Inference only: disable gradient tracking
with torch.no_grad():
    output = model(**tokens)
# Logits have shape (num_answers, 1); squeeze to one raw score per answer
results = np.squeeze(output.logits.numpy())
print(results.tolist())
```
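The model returns one unnormalized relevance score per (query, answer) pair. Reranking a candidate list therefore amounts to sorting the candidates by that score. The helper below is a hypothetical illustration and is not part of sentence-transformers or Transformers. The candidate names and scores are arbitrary placeholders, not outputs of this model. In practice, you would pass the scores printed by either snippet above.

```python
def rerank(candidates, scores, top_k=None):
    """Return (candidate, score) pairs sorted by descending relevance score."""
    if len(candidates) != len(scores):
        raise ValueError("need exactly one score per candidate")
    ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
    return ranked if top_k is None else ranked[:top_k]


# Placeholder candidates and scores for illustration only (NOT real model outputs).
candidates = ["doc_a", "doc_b", "doc_c"]
placeholder_scores = [0.3, 2.1, -1.7]

for doc, score in rerank(candidates, placeholder_scores, top_k=2):
    print(doc, score)
```

Python's `sorted` is stable, so candidates with equal scores keep their original order.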

Evaluation Results

The model achieves an NDCG@10 of 57.50 in the Rerankers category of the Polish Information Retrieval Benchmark (PIRB). See the PIRB Leaderboard for detailed results.
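NDCG@10 rewards a ranker for placing relevant documents near the top of its first 10 results. The score is normalized by the best achievable ordering. The sketch below assumes the linear-gain formulation used by trec_eval-style tools, where each document's gain is its relevance grade and the document at 1-based position i is discounted by log2(i + 1). Some implementations use an exponential gain of 2^rel - 1 instead. Reported benchmark figures such as 57.50 are presumably the mean over queries multiplied by 100. This is not PIRB's evaluation code, and the relevance grades below are hypothetical.

```python
import math


def dcg_at_k(relevances, k):
    """Discounted cumulative gain of the first k items (linear gain, log2 discount)."""
    # rank is 0-based here, so 1-based position i is discounted by log2(i + 1)
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances[:k]))


def ndcg_at_k(ranked_relevances, judged_relevances, k=10):
    """NDCG@k for a single query.

    ranked_relevances: relevance grades of the documents in the order the model ranked them.
    judged_relevances: all relevance grades judged for the query (used for the ideal ranking).
    """
    ideal_dcg = dcg_at_k(sorted(judged_relevances, reverse=True), k)
    if ideal_dcg == 0:
        return 0.0
    return dcg_at_k(ranked_relevances, k) / ideal_dcg


# Hypothetical query with one relevant document (grade 1) among three candidates.
judged = [1, 0, 0]
print(ndcg_at_k([1, 0, 0], judged))  # relevant document ranked first
print(ndcg_at_k([0, 1, 0], judged))  # relevant document ranked second
```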

Citation

```bibtex
@article{dadas2024assessing,
  title={Assessing generalization capability of text ranking models in Polish},
  author={Sławomir Dadas and Małgorzata Grębowiec},
  year={2024},
  eprint={2402.14318},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```