File size: 3,834 Bytes
62977bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import os

from abc import ABC, abstractmethod
from enum import Enum, unique
from typing import List

from pyserini.search import JLuceneSearcherResult


@unique
class OutputFormat(Enum):
    """Run-file formats an output writer can produce.

    Each member's value is the lowercase form of its name, so a format can be
    looked up from a string with ``OutputFormat('trec')``.
    """

    TREC = 'trec'
    MSMARCO = 'msmarco'
    KILT = 'kilt'


class OutputWriter(ABC):
    """Abstract base for writers that serialize ranked hits to a run file.

    Instances are used as context managers: ``__enter__`` opens ``file_path``
    (creating its parent directory if needed) and ``__exit__`` closes it.
    Subclasses implement :meth:`write` to emit one topic's hits in a specific
    format, usually by iterating :meth:`hits_iterator`.
    """

    def __init__(self, file_path: str, mode: str = 'w',
                 max_hits: int = 1000, tag: str = None, topics: dict = None,
                 use_max_passage: bool = False, max_passage_delimiter: str = None, max_passage_hits: int = 100):
        """Store output settings; the file itself is opened in ``__enter__``.

        Args:
            file_path: Destination path of the run file.
            mode: Mode passed to ``open`` (default ``'w'``).
            max_hits: Per-topic hit limit when ``use_max_passage`` is false.
            tag: Run tag made available to subclasses.
            topics: Topic data made available to subclasses.
            use_max_passage: If true, collapse each hit id to its document id
                and keep only the first hit seen for each document.
            max_passage_delimiter: Separator between the document id and the
                passage suffix; ignored unless ``use_max_passage`` is true.
            max_passage_hits: Per-topic limit that replaces ``max_hits`` when
                ``use_max_passage`` is true.
        """
        self.file_path = file_path
        self.mode = mode
        self.tag = tag
        self.topics = topics
        self.use_max_passage = use_max_passage
        self.max_passage_delimiter = max_passage_delimiter if use_max_passage else None
        self.max_hits = max_passage_hits if use_max_passage else max_hits
        self._file = None  # opened in __enter__, closed in __exit__

    def __enter__(self):
        """Open ``file_path`` with ``mode``, creating its parent directory if missing."""
        dirname = os.path.dirname(self.file_path)
        if dirname:  # a bare filename has no directory component to create
            os.makedirs(dirname, exist_ok=True)
        self._file = open(self.file_path, self.mode)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Close the output file; exceptions raised in the ``with`` body propagate."""
        self._file.close()

    def hits_iterator(self, hits: List['JLuceneSearcherResult']):
        """Yield ``(docid, rank, score, hit)`` for each hit that should be written.

        Ranks start at 1 and are assigned consecutively to emitted hits. At
        most ``self.max_hits`` tuples are produced; a non-positive limit
        produces none. Docids are stripped of surrounding whitespace. In
        max-passage mode, each id is truncated at the first
        ``max_passage_delimiter`` and any later hit for an already-emitted
        document is skipped without consuming a rank.

        Args:
            hits: Search results exposing ``docid`` and ``score`` attributes.
        """
        # Check the limit before the loop: testing it only after a yield
        # would always let the first hit through, even for a limit of 0.
        if self.max_hits <= 0:
            return

        seen_docids = set()
        rank = 1
        for hit in hits:
            # Strip first so both paths normalize ids the same way.
            docid = hit.docid.strip()
            if self.use_max_passage and self.max_passage_delimiter:
                # Keep only the part before the first delimiter (the document id).
                docid = docid.split(self.max_passage_delimiter)[0]

            if self.use_max_passage:
                # One line per document: later hits for the same document are skipped.
                if docid in seen_docids:
                    continue
                seen_docids.add(docid)

            yield docid, rank, hit.score, hit

            rank += 1
            if rank > self.max_hits:
                break

    @abstractmethod
    def write(self, topic: str, hits: List['JLuceneSearcherResult']):
        """Write the hits retrieved for ``topic`` to the open output file."""
        raise NotImplementedError()


class TrecWriter(OutputWriter):
    """Output writer for the six-column, space-separated TREC run format."""

    def write(self, topic: str, hits: List[JLuceneSearcherResult]):
        """Write one ``topic Q0 docid rank score tag`` line per emitted hit.

        The score is printed with six decimal places; ``self.tag`` is
        interpolated as-is into the last column.
        """
        for docid, rank, score, _hit in self.hits_iterator(hits):
            line = f'{topic} Q0 {docid} {rank} {score:.6f} {self.tag}\n'
            self._file.write(line)


class MsMarcoWriter(OutputWriter):
    """Output writer for the tab-separated MS MARCO run format."""

    def write(self, topic: str, hits: List[JLuceneSearcherResult]):
        """Write one line of tab-separated topic, docid and rank per emitted hit.

        Scores are not part of this format and are discarded.
        """
        for docid, rank, _score, _hit in self.hits_iterator(hits):
            row = f'{topic}\t{docid}\t{rank}\n'
            self._file.write(row)


class KiltWriter(OutputWriter):
    """Output writer that emits one JSON object per line, one line per topic."""

    def write(self, topic: str, hits: List[JLuceneSearcherResult]):
        """Write ``self.topics[topic]`` extended with the retrieved provenance.

        The output object is a shallow copy of the topic's entry with an
        ``"output"`` key set to ``[{"provenance": [...]}]``. Each provenance
        item is ``{"wikipedia_id": docid}`` for an emitted hit. The
        caller-supplied ``topics`` mapping is not modified.

        Raises:
            KeyError: if ``topic`` is not present in ``self.topics``.
        """
        # Copy instead of mutating: the entry belongs to the caller's
        # ``topics`` mapping and must not accumulate an "output" key.
        datapoint = dict(self.topics[topic])
        provenance = [{"wikipedia_id": docid} for docid, _rank, _score, _hit in self.hits_iterator(hits)]
        datapoint["output"] = [{"provenance": provenance}]
        json.dump(datapoint, self._file)
        self._file.write('\n')


def get_output_writer(file_path: str, output_format: OutputFormat, *args, **kwargs) -> OutputWriter:
    """Instantiate the writer class registered for ``output_format``.

    Any extra positional and keyword arguments are forwarded to the writer's
    constructor after ``file_path``.

    Raises:
        KeyError: if ``output_format`` is not one of the mapped formats.
    """
    writer_classes = {
        OutputFormat.TREC: TrecWriter,
        OutputFormat.MSMARCO: MsMarcoWriter,
        OutputFormat.KILT: KiltWriter,
    }
    writer_cls = writer_classes[output_format]
    return writer_cls(file_path, *args, **kwargs)


def tie_breaker(hits):
    """Return a new list of ``hits`` by descending score, ties by ascending docid.

    The input sequence is left unchanged.
    """
    ordered = list(hits)
    ordered.sort(key=lambda h: (-h.score, h.docid))
    return ordered