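"""Search demo for the XSum training set, built with Spacerini.

A Gradio app that loads a prebuilt Lucene index from ``index/`` with Pyserini,
runs BM25 retrieval over it for each submitted query, and renders the hits as
HTML with links back to the ``xsum`` dataset on the Hugging Face Hub.
"""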
import json
import os

import gradio as gr
from typing import Any, Dict, Optional, Union
from pyserini.search import LuceneSearcher, FaissSearcher
from pyserini.index.lucene import IndexReader


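# A searcher may be dense (Faiss) or sparse (Lucene); this demo only instantiates the sparse BM25 path.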
Searcher = Union[FaissSearcher, LuceneSearcher]

def _load_sparse_searcher(language: str, k1: Optional[float] = None, b: Optional[float] = None) -> Searcher:
    """Load a BM25 searcher over the local Lucene index, optionally overriding the k1/b parameters."""
    searcher = LuceneSearcher('index/')
    searcher.set_language(language)
    if k1 is not None and b is not None:
        # Override Lucene's default BM25 parameters only when both are supplied.
        searcher.set_bm25(k1, b)

    return searcher


# The underlying xsum dataset is public, so results link out with the public-dataset style below.
DATASET_IS_PRIVATE = False


def get_docid_html(docid):
    """Render a document ID as an HTML link back to the xsum dataset on the Hugging Face Hub."""
    if DATASET_IS_PRIVATE:
        docid_html = (
            '<a '
            'class="underline-on-hover" '
            'style="color:#AA4A44;" '
            'href="https://huggingface.co/datasets/xsum" '
            f'target="_blank"><b>πŸ”’xsum</b></a><span style="color: #7978FF;">/{docid}</span>'
        )
    else:
        docid_html = (
            '<a '
            'class="underline-on-hover" '
            'title="This dataset is licensed apache-2.0" '
            'style="color:#2D31FA;" '
            'href="https://huggingface.co/datasets/xsum" '
            f'target="_blank"><b>xsum</b></a><span style="color: #7978FF;">/{docid}</span>'
        )
    return docid_html

def fetch_index_stats(index_path: str) -> Dict[str, Any]:
    """Fetch statistics for a Lucene index.

    Parameters
    ----------
    index_path : str
        Path to the index directory.

    Returns
    -------
    Dict[str, Any]
        Index statistics with keys ``total_terms``, ``documents`` and ``unique_terms``.
    """
    assert os.path.exists(index_path), f"Index path {index_path} does not exist"
    index_reader = IndexReader(index_path)
    return index_reader.stats()

def process_results(results, highlight_terms=None):
    """Format a results dict (text/docid/score/lang columns) as HTML, bolding any highlight terms."""
    if highlight_terms is None:
        highlight_terms = []
    # `results` maps column names to lists, so count the hits rather than the dict keys.
    if len(results["text"]) == 0:
        return """<br><p style='font-family: Arial; color:Silver; text-align: center;'>
                No results retrieved.</p><br><hr>"""

    results_html = ""
    for i in range(len(results["text"])):
        tokens = results["text"][i].split()
        tokens_html = []
        for token in tokens:
            if token in highlight_terms:
                tokens_html.append("<b>{}</b>".format(token))
            else:
                tokens_html.append(token)
        tokens_html = " ".join(tokens_html)
        meta_html = """
                <p class='underline-on-hover' style='font-size:12px; font-family: Arial; color:#585858; text-align: left;'>
            """
        docid_html = get_docid_html(results["docid"][i])
        results_html += """{}
            <p style='font-size:20px; font-family: Arial; color:#7978FF; text-align: left;'>Document ID: {}</p>
            <p style='font-size:14px; font-family: Arial; color:#7978FF; text-align: left;'>Score: {}</p>
            <p style='font-size:12px; font-family: Arial; color:MediumAquaMarine'>Language: {}</p>
            <p style='font-family: Arial;font-size:15px;'>{}</p>
            <br>
        """.format(
            meta_html, docid_html, results["score"][i], results["lang"], tokens_html
        )
    return results_html + "<hr>"

def search(query, language, num_results=10):
    """Run a BM25 search against the local index and return the hits rendered as HTML."""
    searcher = _load_sparse_searcher(language=language)
    hits = searcher.search(query, k=num_results)

    results_dict = {"text": [], "docid": [], "score": [], "lang": language}
    for hit in hits:
        raw_doc = json.loads(hit.raw)
        results_dict["text"].append(raw_doc["contents"])
        results_dict["docid"].append(raw_doc["id"])
        results_dict["score"].append(hit.score)

    return process_results(results_dict)

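# Compute index statistics once at startup so they can be shown in the page header.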
stats = fetch_index_stats('index/')

description = f"""# <h2 style="text-align: center;"> πŸš€ πŸ”Ž XSum Train Dataset Search πŸ” πŸš€ </h2>
<p style="text-align: center;font-size:15px;">A search space built on the Extreme Summarization (XSUM) Dataset with Spacerini</p>
<p style="text-align: center;font-size:20px;">Dataset Statistics: Total Number of Documents = <b>{stats["documents"]}</b>, Number of Terms = <b>{stats["total_terms"]}</b> </p>"""

demo = gr.Blocks(
    css=".underline-on-hover:hover { text-decoration: underline; } .flagging { font-size:12px; color:Silver; }"
)

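# Lay out the UI: header, query box, language dropdown, max-results slider, submit button and results panel.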
with demo:
    with gr.Row():
        gr.Markdown(value=description)
    with gr.Row():
        query = gr.Textbox(lines=1, max_lines=1, placeholder="Type your query here...", label="Query")
    with gr.Row():
        lang = gr.Dropdown(
            choices=[
                "en",
            ],
            value="en",
            label="Language",
        )
    with gr.Row():
        k = gr.Slider(1, 100, value=10, step=1, label="Max Results")
    with gr.Row():
        submit_btn = gr.Button("Submit")
    with gr.Row():
        results = gr.HTML(label="Results")


    def submit(query, lang, k):
        """Handle a query from the textbox or the Submit button and return the rendered results."""
        query = (query or "").strip()
        if not query:
            # Nothing to search for: clear the results panel.
            return ""
        return search(query, lang, k)

    query.submit(fn=submit, inputs=[query, lang, k], outputs=[results])
    submit_btn.click(submit, inputs=[query, lang, k], outputs=[results])
demo.launch(enable_queue=True, debug=True)