import json
import os
import traceback
from typing import List, Tuple
import gradio as gr
import requests
from huggingface_hub import HfApi
# Client used for all Hugging Face Hub API calls in this module.
hf_api = HfApi()

# Every dataset listed for the "bigscience-data" author, keyed by the last
# "/"-separated component of its id. The access token is read from the
# "bigscience_data_token" environment variable (None when it is unset).
roots_datasets = dict(
    (entry.id.split("/")[-1], entry)
    for entry in hf_api.list_datasets(
        author="bigscience-data",
        use_auth_token=os.environ.get("bigscience_data_token"),
    )
)
def get_docid_html(docid):
    """Render a document id as a "dataset / id" fragment.

    Args:
        docid: String made of exactly three "/"-separated parts; the second
            part is the dataset name used to look up ``roots_datasets`` and
            the third is the id shown after the dataset.

    Returns:
        The rendered fragment. Datasets flagged ``private`` get a lock
        prefix in front of the dataset name.

    Raises:
        ValueError: If ``docid`` does not split into exactly three parts.
        KeyError: If the dataset is not present in ``roots_datasets``.
    """
    _org, dataset, doc_number = docid.split("/")
    info = roots_datasets[dataset]

    if not info.private:
        public_template = "\n{dataset}\n/{docid}"
        return public_template.format(
            # Last ":"-separated piece of the dataset's first tag.
            metadata=info.tags[0].split(":")[-1],
            dataset=dataset,
            docid=doc_number,
        )

    private_template = "\n🔒{dataset}\n/{docid}"
    return private_template.format(dataset=dataset, docid=doc_number)
# Labels that can follow PII_PREFIX inside a text; process_pii() rewrites
# every "<PII_PREFIX><label>" occurrence it finds.
PII_TAGS = {"KEY", "EMAIL", "USER", "IP_ADDRESS", "ID", "IPv4", "IPv6"}
# Literal marker that precedes one of the PII_TAGS labels in a text.
PII_PREFIX = "PI:"
def process_pii(text):
    """Replace PII placeholders in *text* with a "REDACTED <label>" notice.

    Every occurrence of ``PII_PREFIX`` immediately followed by one of the
    labels in ``PII_TAGS`` is substituted; all other text is left untouched.

    Args:
        text: The string to scan.

    Returns:
        A new string with the placeholders substituted.
    """
    redacted = text
    for label in PII_TAGS:
        marker = PII_PREFIX + label
        notice = "REDACTED {}".format(label)
        redacted = redacted.replace(marker, notice)
    return redacted
def flag(query, language, num_results, issue_description):
    """Report a problematic query to the search backend (best effort).

    The report is sent as a JSON POST to the URL stored in the ``address``
    environment variable. Failures are logged and never propagated, so a
    broken backend cannot take down the caller.

    Args:
        query: The query text being reported.
        language: Language code of the query; the special value
            ``"detect_language"`` means no ``lang`` field is sent.
        num_results: Sent to the backend as ``k``.
        issue_description: Free-text description of the problem.

    Returns:
        Always the empty string.
    """
    post_data = {
        "query": query,
        "k": num_results,
        "flag": True,
        "description": issue_description,
    }
    if language != "detect_language":
        post_data["lang"] = language

    try:
        output = requests.post(
            os.environ.get("address"),
            headers={"Content-type": "application/json"},
            data=json.dumps(post_data),
            timeout=120,
        )
        # The decoded value is not needed; parsing only checks that the
        # backend answered with well-formed JSON.
        json.loads(output.text)
    except Exception:
        # `Exception` rather than a bare `except` so that KeyboardInterrupt
        # and SystemExit still propagate; the traceback keeps the root cause
        # visible in the logs.
        print("Error flagging")
        traceback.print_exc()
    return ""
def format_result(result, highlight_terms, exact_search, datasets_filter=None):
    """Render a single search hit as a text fragment.

    Args:
        result: A ``(text, url, docid)`` triple. ``docid`` is a
            "/"-separated string whose second component is the dataset name.
        highlight_terms: For an exact search, the query string to locate in
            ``text``; otherwise a container of tokens to mark.
        exact_search: Whether ``highlight_terms`` is a single literal string.
        datasets_filter: Optional collection of dataset names. When given,
            hits from any other dataset are suppressed.

    Returns:
        The rendered fragment, or ``""`` when the hit is filtered out.
    """
    text, url, docid = result

    if datasets_filter is not None:
        dataset = docid.split("/")[1]
        if dataset not in datasets_filter:
            return ""

    if exact_search:
        query_start = text.find(highlight_terms)
        if query_start < 0:
            # str.find() returns -1 when the query is absent; using that as a
            # slice index would duplicate and garble the text, so leave the
            # text unmarked instead.
            tokens_html = text
        else:
            query_end = query_start + len(highlight_terms)
            tokens_html = (
                text[:query_start]
                + "{}".format(text[query_start:query_end])
                + text[query_end:]
            )
    else:
        # Whitespace-tokenise and mark each token found in highlight_terms.
        tokens_html = " ".join(
            "{}".format(token) if token in highlight_terms else token
            for token in text.split()
        )
    tokens_html = process_pii(tokens_html)

    url_html = "\n{url}\n".format(url=url) if url is not None else ""
    docid_html = get_docid_html(docid)

    # The template has exactly three fields. It used to be given a fourth
    # value, which silently dropped the document text and rendered a stub
    # language value in its place.
    result_html = "{}\nDocument ID: {}\n{}\n".format(
        url_html, docid_html, tokens_html
    )
    return "\n" + result_html + "\n"
" def format_result_page( results, highlight_terms, num_results, exact_search, datasets_filter=None ): results_html = [] for result in results: result_html = format_result( result, highlight_terms, exact_search, datasets_filter ) if result_html != "": results_html.append(result_html) return results_html def extract_results_from_payload(query, language, payload, exact_search): results = payload["results"] processed_results = list() datasets = set() highlight_terms = None num_results = None if exact_search: highlight_terms = query num_results = payload["num_results"] else: highlight_terms = payload["highlight_terms"] results = [] for lang, res_for_lang in payload["results"].items(): for result in res_for_lang: results.append(result) for result in results: text = result["text"] url = ( result["meta"]["url"] if "meta" in result and result["meta"] is not None and "url" in result["meta"] else None ) docid = result["docid"] _, dataset, _ = docid.split("/") datasets.add(dataset) processed_results.append((text, url, docid)) return processed_results, highlight_terms, num_results, list(datasets) def process_error(error_type): if error_type == "unsupported_lang": detected_lang = payload["err"]["meta"]["detected_lang"] return f"""
Detected language {detected_lang} is not supported.
Please choose a language from the dropdown or type another query.
🌸 🔎 ROOTS search tool 🔍 🌸
""" ) description = """ The ROOTS corpus was developed during the [BigScience workshop](https://bigscience.huggingface.co/) for the purpose of training the Multilingual Large Language Model [BLOOM](https://huggingface.co/bigscience/bloom). This tool allows you to search through the ROOTS corpus. We serve a BM25 index for each language or group of languages included in ROOTS. You can read more about the details of the tool design [here](https://huggingface.co/spaces/bigscience-data/scisearch/blob/main/roots_search_tool_specs.pdf). For more information and instructions on how to access the full corpus check [this form](https://forms.gle/qyYswbEL5kA23Wu99).""" if __name__ == "__main__": demo = gr.Blocks( css=".underline-on-hover:hover { text-decoration: underline; } .flagging { font-size:12px; color:Silver; }" ) with demo: processed_results_state = gr.State([]) highlight_terms_state = gr.State([]) num_results_state = gr.State(0) exact_search_state = gr.State(False) lang_state = gr.State("") max_page_size_state = gr.State(100) received_results_state = gr.State(0) with gr.Row(): gr.Markdown(value=title) with gr.Row(): gr.Markdown(value=description) with gr.Row(): query = gr.Textbox( lines=1, max_lines=1, placeholder="Put your query in double quotes for exact search.", label="Query", ) with gr.Row(): lang = gr.Dropdown( choices=[ "ar", "ca", "code", "en", "es", "eu", "fr", "id", "indic", "nigercongo", "pt", "vi", "zh", "detect_language", ], value="en", label="Language", ) k = gr.Slider(1, 100, value=10, step=1, label="Max Results") with gr.Row(): submit_btn = gr.Button("Submit") with gr.Row(visible=False) as datasets_filter: available_datasets = gr.Dropdown( type="value", choices=[], value=[], label="Datasets Filter", multiselect=True, ) with gr.Row(): header_html = gr.HTML(label="Header", value="hello") results_html = [] for i in range(100): results_html.append(gr.HTML(label="Results")) with gr.Row(visible=False) as pagination: next_page_btn = gr.Button("Next Page") def 
run_query(query, lang, k, dropdown_input, max_page_size, received_results): query = query.strip() exact_search = False if query.startswith('"') and query.endswith('"') and len(query) >= 2: exact_search = True query = query[1:-1] k = max_page_size else: query = " ".join(query.split()) if query == "" or query is None: return None print("submitting", query, lang, k) payload = request_payload(query, lang, exact_search, k, received_results) err = extract_error_from_payload(payload) if err is not None: return process_error(err) ( processed_results, highlight_terms, num_results, ds, ) = extract_results_from_payload( query, lang, payload, exact_search, ) header_html = "" if lang == "detect_language" and not exact_search: header_html += """