Spaces:
Running
Running
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module to compute TREC evaluation scores.""" | |
import datasets | |
import pandas as pd | |
from trectools import TrecEval, TrecQrel, TrecRun | |
import evaluate | |
_CITATION = """\ | |
@inproceedings{palotti2019, | |
author = {Palotti, Joao and Scells, Harrisen and Zuccon, Guido}, | |
title = {TrecTools: an open-source Python library for Information Retrieval practitioners involved in TREC-like campaigns}, | |
series = {SIGIR'19}, | |
year = {2019}, | |
location = {Paris, France}, | |
publisher = {ACM} | |
} | |
""" | |
# Short human-readable summary of the metric.
# The trailing backslashes join the physical lines, so the resulting string is
# a single line with no embedded newlines.
_DESCRIPTION = """\
The TREC Eval metric combines a number of information retrieval metrics such as \
precision and nDCG. It is used to score rankings of retrieved documents with reference values."""
# Usage text for the metric: argument/return description plus a doctest-style
# example. The example scores one run of two documents against a qrel with a
# single relevant document, so P@5 is 1/5 = 0.2.
_KWARGS_DESCRIPTION = """
Calculates TREC evaluation scores based on a run and qrel.
Args:
    predictions: list containing a single run.
    references: list containing a single qrel.
Returns:
    dict: TREC evaluation scores.
Examples:
    >>> trec = evaluate.load("trec_eval")
    >>> qrel = {
    ...     "query": [0],
    ...     "q0": ["0"],
    ...     "docid": ["doc_1"],
    ...     "rel": [2]
    ... }
    >>> run = {
    ...     "query": [0, 0],
    ...     "q0": ["q0", "q0"],
    ...     "docid": ["doc_2", "doc_1"],
    ...     "rank": [0, 1],
    ...     "score": [1.5, 1.2],
    ...     "system": ["test", "test"]
    ... }
    >>> results = trec.compute(references=[qrel], predictions=[run])
    >>> print(results["P@5"])
    0.2
"""
class TRECEval(evaluate.Metric):
    """Compute TREC evaluation scores for one run against one qrel using ``trectools``."""

    def _info(self):
        """Describe the metric and declare the input schema it accepts.

        Returns:
            evaluate.MetricInfo: metadata built from the module-level
            description, citation and usage strings, plus the feature schema.
            ``predictions`` (a run) carries the columns ``query``, ``q0``,
            ``docid``, ``rank``, ``score`` and ``system``; ``references``
            (a qrel) carries ``query``, ``q0``, ``docid`` and ``rel``. Every
            column is a sequence of values.
        """
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": {
                        "query": datasets.Sequence(datasets.Value("int64")),
                        "q0": datasets.Sequence(datasets.Value("string")),
                        "docid": datasets.Sequence(datasets.Value("string")),
                        "rank": datasets.Sequence(datasets.Value("int64")),
                        "score": datasets.Sequence(datasets.Value("float")),
                        "system": datasets.Sequence(datasets.Value("string")),
                    },
                    "references": {
                        "query": datasets.Sequence(datasets.Value("int64")),
                        "q0": datasets.Sequence(datasets.Value("string")),
                        "docid": datasets.Sequence(datasets.Value("string")),
                        "rel": datasets.Sequence(datasets.Value("int64")),
                    },
                }
            ),
            homepage="https://github.com/joaopalotti/trectools",
        )

    def _compute(self, references, predictions):
        """Return the TREC evaluation scores for a single run/qrel pair.

        Args:
            references: list containing exactly one qrel, in a form accepted by
                ``pd.DataFrame`` (the declared schema is a mapping of column
                name to list of values).
            predictions: list containing exactly one run, in the same form.

        Returns:
            dict: the run id and summary counts (``runid``, ``num_ret``,
            ``num_rel``, ``num_rel_ret``, ``num_q``), the aggregate scores
            ``map``, ``gm_map``, ``bpref``, ``Rprec`` and ``recip_rank``, and
            ``P@k`` / ``NDCG@k`` for each k in 5, 10, 15, 20, 30, 100, 200,
            500 and 1000.

        Raises:
            ValueError: if ``predictions`` or ``references`` does not contain
                exactly one element.
        """
        # Require exactly one of each. Checking only for "more than one" would
        # let empty input through and fail below with a bare IndexError.
        if len(predictions) != 1 or len(references) != 1:
            raise ValueError(
                f"You can only pass one prediction and reference per evaluation. You passed {len(predictions)} prediction(s) and {len(references)} reference(s)."
            )

        df_run = pd.DataFrame(predictions[0])
        df_qrel = pd.DataFrame(references[0])

        # Build the trectools objects from in-memory frames instead of files.
        # NOTE(review): the placeholder filename presumably only satisfies code
        # paths that read `.filename`; nothing is read from disk here - confirm.
        trec_run = TrecRun()
        trec_run.filename = "placeholder.file"
        trec_run.run_data = df_run

        trec_qrel = TrecQrel()
        trec_qrel.filename = "placeholder.file"
        trec_qrel.qrels_data = df_qrel

        trec_eval = TrecEval(trec_run, trec_qrel)

        # Rank cutoffs shared by the precision and nDCG reports.
        cutoffs = (5, 10, 15, 20, 30, 100, 200, 500, 1000)

        result = {}
        result["runid"] = trec_eval.run.get_runid()
        result["num_ret"] = trec_eval.get_retrieved_documents(per_query=False)
        result["num_rel"] = trec_eval.get_relevant_documents(per_query=False)
        result["num_rel_ret"] = trec_eval.get_relevant_retrieved_documents(per_query=False)
        result["num_q"] = len(trec_eval.run.topics())
        result["map"] = trec_eval.get_map(depth=10000, per_query=False, trec_eval=True)
        result["gm_map"] = trec_eval.get_geometric_map(depth=10000, trec_eval=True)
        result["bpref"] = trec_eval.get_bpref(depth=1000, per_query=False, trec_eval=True)
        result["Rprec"] = trec_eval.get_rprec(depth=1000, per_query=False, trec_eval=True)
        result["recip_rank"] = trec_eval.get_reciprocal_rank(depth=1000, per_query=False, trec_eval=True)

        # Two separate passes keep the key order of the result dict:
        # every P@k entry first, then every NDCG@k entry.
        for v in cutoffs:
            result[f"P@{v}"] = trec_eval.get_precision(depth=v, per_query=False, trec_eval=True)
        for v in cutoffs:
            result[f"NDCG@{v}"] = trec_eval.get_ndcg(depth=v, per_query=False, trec_eval=True)

        return result