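# NOTE(review): the first lines originally read "Spaces: / Paused / Paused" -
# presumably Hugging Face Spaces page chrome captured together with the file
# rather than program text; confirm and remove.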
import html
import json
import os
import traceback
from typing import List, Tuple

import gradio as gr
import requests
from huggingface_hub import HfApi
# Client used for all Hugging Face Hub queries in this module.
hf_api = HfApi()

# Datasets authored by "bigscience-data", keyed by the last path component of
# their Hub id (e.g. "bigscience-data/foo" -> "foo"). The access token is read
# from the "bigscience_data_token" environment variable (None when unset).
roots_datasets = {
    info.id.rsplit("/", 1)[-1]: info
    for info in hf_api.list_datasets(
        author="bigscience-data",
        use_auth_token=os.getenv("bigscience_data_token"),
    )
}
def get_docid_html(docid):
    """Render a document id as an HTML snippet that links to its dataset.

    Args:
        docid: Identifier of the form ``<org>/<dataset>/<document id>``.

    Returns:
        An HTML string: a link to the dataset page under
        ``bigscience-data`` followed by the document id. Private datasets
        get a red link with an explanatory tooltip; other datasets get a
        tooltip built from their license tag. A dataset that is absent from
        ``roots_datasets`` is rendered as plain text without a link.

    Raises:
        ValueError: If ``docid`` has fewer than three "/"-separated parts.
    """
    # Split at most twice so a "/" inside the trailing document id cannot
    # break the three-way unpacking.
    _, dataset, doc_suffix = docid.split("/", 2)

    metadata = roots_datasets.get(dataset)
    if metadata is None:
        # The dataset is not in the listing fetched at import time; show the
        # id without a link instead of failing with KeyError.
        return """<span style="color:#7978FF; ">{dataset}/{docid}</span>""".format(
            dataset=dataset, docid=doc_suffix
        )

    if metadata.private:
        # NOTE(review): the "π" in front of {dataset} looks like a mis-decoded
        # emoji (possibly a lock) - confirm against the deployed file.
        return """
        <a title="This dataset is private. See the introductory text for more information"
            style="color:#AA4A44; font-weight: bold; text-decoration:none"
            onmouseover="style='color:#AA4A44; font-weight: bold; text-decoration:underline'"
            onmouseout="style='color:#AA4A44; font-weight: bold; text-decoration:none'"
            href="https://huggingface.co/datasets/bigscience-data/{dataset}"
            target="_blank">
            π{dataset}
        </a>
        <span style="color:#7978FF; ">/{docid}</span>""".format(
            dataset=dataset, docid=doc_suffix
        )

    # Prefer an explicit "license:<name>" tag; fall back to the first tag (the
    # previous behaviour) and finally to "unknown" when there are no tags.
    tags = metadata.tags or []
    license_tag = next((tag for tag in tags if tag.startswith("license:")), None)
    if license_tag is None:
        license_tag = tags[0] if tags else "unknown"

    return """
    <a title="This dataset is licensed {metadata}"
        style="color:#7978FF; font-weight: bold; text-decoration:none"
        onmouseover="style='color:#7978FF; font-weight: bold; text-decoration:underline'"
        onmouseout="style='color:#7978FF; font-weight: bold; text-decoration:none'"
        href="https://huggingface.co/datasets/bigscience-data/{dataset}"
        target="_blank">
        {dataset}
    </a>
    <span style="color:#7978FF; ">/{docid}</span>""".format(
        metadata=license_tag.split(":")[-1], dataset=dataset, docid=doc_suffix
    )
# Kinds of personally identifiable information that appear as "PI:<KIND>"
# markers in the text handled by process_pii.
PII_TAGS = {"KEY", "EMAIL", "USER", "IP_ADDRESS", "ID", "IPv4", "IPv6"}
PII_PREFIX = "PI:"


def process_pii(text):
    """Replace every ``PI:<KIND>`` marker in ``text`` with a highlighted label.

    Args:
        text: String that may contain markers built from ``PII_PREFIX`` and
            one of ``PII_TAGS``.

    Returns:
        The text with each marker replaced by a bold ``<mark>`` element
        reading ``REDACTED <KIND>``.
    """
    redaction_template = (
        """<b><mark style="background: Fuchsia; color: Lime;">REDACTED {}</mark></b>"""
    )
    for pii_kind in PII_TAGS:
        marker = PII_PREFIX + pii_kind
        text = text.replace(marker, redaction_template.format(pii_kind))
    return text
def flag(query, language, num_results, issue_description):
    """Report a problematic query to the search backend (best effort).

    Posts the query with ``"flag": True`` and the user's description to the
    URL stored in the ``address`` environment variable.

    Args:
        query: The query string being flagged.
        language: Language code, or ``"detect_language"`` to omit the
            ``lang`` field from the request.
        num_results: Sent to the backend as ``k``.
        issue_description: Free-text description of the issue.

    Returns:
        Always the empty string; failures are logged, never raised.
    """
    post_data = {
        "query": query,
        "k": num_results,
        "flag": True,
        "description": issue_description,
    }
    if language != "detect_language":
        post_data["lang"] = language

    try:
        response = requests.post(
            os.environ.get("address"),
            headers={"Content-type": "application/json"},
            data=json.dumps(post_data),
            timeout=120,
        )
        # The decoded body is not used; decoding only checks that the backend
        # answered with JSON.
        json.loads(response.text)
    except Exception:
        # Flagging must never break the UI, but the cause should stay visible
        # in the logs. Unlike a bare `except:`, this lets KeyboardInterrupt
        # and SystemExit propagate.
        print("Error flagging")
        traceback.print_exc()
    return ""
def format_result(result, highlight_terms, exact_search, datasets_filter=None):
    """Render one search result as an HTML paragraph.

    Args:
        result: ``(text, url, docid)`` tuple; ``url`` may be ``None``.
        highlight_terms: For exact search, the query string to bold at its
            first occurrence; otherwise a container tested with ``in``
            against each whitespace-separated token.
        exact_search: Selects between the two highlighting modes above.
        datasets_filter: Optional collection of dataset names. When given,
            results whose dataset (second "/"-separated part of ``docid``)
            is not in it are dropped.

    Returns:
        The HTML for the result wrapped in ``<p>``, or ``""`` when the result
        is excluded by ``datasets_filter``.
    """
    text, url, docid = result

    if datasets_filter is not None:
        dataset = docid.split("/")[1]
        if dataset not in datasets_filter:
            return ""

    # The text and URL come from the indexed corpus, so they are escaped
    # before being embedded in markup; only the tags added here stay live.
    if exact_search:
        query_start = text.find(highlight_terms) if highlight_terms else -1
        if query_start == -1:
            # No match: show the text as is. Slicing with -1 would bold the
            # wrong span and shuffle the text.
            tokens_html = html.escape(text, quote=False)
        else:
            query_end = query_start + len(highlight_terms)
            tokens_html = "{}<b>{}</b>{}".format(
                html.escape(text[:query_start], quote=False),
                html.escape(text[query_start:query_end], quote=False),
                html.escape(text[query_end:], quote=False),
            )
    else:
        tokens_html = " ".join(
            "<b>{}</b>".format(html.escape(token, quote=False))
            if token in highlight_terms
            else html.escape(token, quote=False)
            for token in text.split()
        )
    # PII markers are plain "PI:<KIND>" strings, untouched by the escaping.
    tokens_html = process_pii(tokens_html)

    if url is None:
        url_html = ""
    else:
        url_html = """
        <span style='font-size:12px; font-family: Arial; color:Silver; text-align: left;'>
            <a style='text-decoration:none; color:Silver;'
                onmouseover="style='text-decoration:underline; color:Silver;'"
                onmouseout="style='text-decoration:none; color:Silver;'"
                href='{url}'
                target="_blank">
                {url}
            </a>
        </span><br>
        """.format(
            # quote=True (the default) also escapes the single quote that
            # delimits the href attribute.
            url=html.escape(str(url))
        )

    docid_html = get_docid_html(docid)
    # Placeholder: only ever rendered inside the HTML comment below.
    language = "FIXME"
    result_html = """{}
        <span style='font-size:14px; font-family: Arial; color:#7978FF; text-align: left;'>Document ID: {}</span>
        <button type="button" onclick="alert('Hello world!')">Flag result</button><br>
        <!-- <span style='font-size:12px; font-family: Arial; color:MediumAquaMarine'>Language: {}</span><br> -->
        <span style='font-family: Arial;'>{}</span><br>
        <br>
    """.format(
        url_html, docid_html, language, tokens_html
    )
    return "<p>" + result_html + "</p>"
def format_result_page(
    results, highlight_terms, num_results, exact_search, datasets_filter=None
):
    """Render a list of results, dropping those that render to ``""``.

    Args:
        results: Iterable of result tuples accepted by ``format_result``.
        highlight_terms: Forwarded to ``format_result``.
        num_results: Accepted for interface compatibility; not used here.
        exact_search: Forwarded to ``format_result``.
        datasets_filter: Forwarded to ``format_result``.

    Returns:
        List of non-empty HTML strings, in the order of ``results``.
    """
    rendered = (
        format_result(entry, highlight_terms, exact_search, datasets_filter)
        for entry in results
    )
    return [snippet for snippet in rendered if snippet != ""]
def extract_results_from_payload(query, language, payload, exact_search):
    """Flatten a backend payload into display tuples.

    Args:
        query: The (unquoted) query; used as the highlight string for exact
            search.
        language: Accepted for interface compatibility; not used here.
        payload: Decoded backend response. For exact search,
            ``payload["results"]`` is read as a list and
            ``payload["num_results"]`` is required. Otherwise
            ``payload["results"]`` is read as a mapping whose values are
            lists of results, and ``payload["highlight_terms"]`` is required.
        exact_search: Selects which of the two payload shapes to read.

    Returns:
        ``(processed_results, highlight_terms, num_results, datasets)``:
        a list of ``(text, url, docid)`` tuples (``url`` is ``None`` when the
        result carries no ``meta["url"]``), the highlight terms,
        the total match count (``None`` unless exact search), and the sorted
        list of distinct dataset names found in the docids.
    """
    if exact_search:
        highlight_terms = query
        num_results = payload["num_results"]
        raw_results = payload["results"]
    else:
        highlight_terms = payload["highlight_terms"]
        num_results = None
        raw_results = [
            result
            for results_for_lang in payload["results"].values()
            for result in results_for_lang
        ]

    processed_results = []
    datasets = set()
    for result in raw_results:
        # "meta" may be missing or None; both mean "no URL".
        meta = result.get("meta") or {}
        docid = result["docid"]
        # Same parsing as format_result: the dataset is the second path part,
        # whatever follows it.
        datasets.add(docid.split("/")[1])
        processed_results.append((result["text"], meta.get("url"), docid))

    # Sorted so the dataset filter lists choices in a stable order.
    return processed_results, highlight_terms, num_results, sorted(datasets)
def process_error(error_type, payload=None):
    """Build the HTML message shown for a backend error.

    Args:
        error_type: Error type string as returned by
            ``extract_error_from_payload``.
        payload: Optional decoded backend response. For
            ``"unsupported_lang"`` the detected language is read from
            ``payload["err"]["meta"]["detected_lang"]`` when available.

    Returns:
        An HTML string. For ``"unsupported_lang"`` it names the detected
        language (``unknown`` when it cannot be read from ``payload``); for
        any other type it is a generic message quoting ``error_type``.
    """
    if error_type == "unsupported_lang":
        # The payload used to be read from an undefined name; it is now an
        # optional argument and every lookup tolerates missing pieces.
        detected_lang = "unknown"
        if payload is not None:
            err_meta = (payload.get("err") or {}).get("meta") or {}
            detected_lang = err_meta.get("detected_lang", "unknown")
        return f"""
        <p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
        Detected language <b>{detected_lang}</b> is not supported.<br>
        Please choose a language from the dropdown or type another query.
        </p><br><hr><br>"""
    return f"""
    <p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
    The search backend reported an error: <b>{error_type}</b>.
    </p><br><hr><br>"""
def extract_error_from_payload(payload):
    """Return ``payload["err"]["type"]``, or ``None`` if there is no ``"err"`` key."""
    return payload["err"]["type"] if "err" in payload else None
def request_payload(query, language, exact_search, num_results=10, received_results=0):
    """POST a search request to the backend and return the decoded JSON reply.

    Args:
        query: Query string to search for.
        language: Language code; ``"detect_language"`` omits the ``lang``
            field from the request.
        exact_search: When true the request goes to the fixed exact-search
            address, otherwise to the URL in the ``address`` environment
            variable.
        num_results: Sent to the backend as ``k``.
        received_results: Sent to the backend as ``received_results``.

    Returns:
        The response body parsed with ``json.loads``.
    """
    if exact_search:
        address = "http://34.105.160.81:8080"
    else:
        address = os.environ.get("address")

    post_data = {"query": query, "k": num_results, "received_results": received_results}
    if language != "detect_language":
        post_data["lang"] = language

    response = requests.post(
        address,
        headers={"Content-type": "application/json"},
        data=json.dumps(post_data),
        timeout=60,
    )
    return json.loads(response.text)
# Page heading, rendered through gr.Markdown.
# NOTE(review): the "πΈ" / "π" characters look like mis-decoded emoji (likely
# a blossom and magnifying glasses) - confirm against the deployed file.
title = """<p style="text-align: center; font-size:28px"> πΈ π ROOTS search tool π πΈ </p>"""

# Introductory Markdown shown below the heading.
description = """
The ROOTS corpus was developed during the [BigScience workshop](https://bigscience.huggingface.co/) for the purpose
of training the Multilingual Large Language Model [BLOOM](https://huggingface.co/bigscience/bloom). This tool allows
you to search through the ROOTS corpus. We serve a BM25 index for each language or group of languages included in
ROOTS. You can read more about the details of the tool design
[here](https://huggingface.co/spaces/bigscience-data/scisearch/blob/main/roots_search_tool_specs.pdf). For more
information and instructions on how to access the full corpus check [this form](https://forms.gle/qyYswbEL5kA23Wu99)."""
if __name__ == "__main__":
    # Number of gr.HTML result slots in the layout. Every callback that
    # targets the slots must return exactly this many values for them.
    MAX_RESULT_SLOTS = 100

    demo = gr.Blocks(
        css=".underline-on-hover:hover { text-decoration: underline; } .flagging { font-size:12px; color:Silver; }"
    )

    # NOTE(review): the nesting of the layout below was reconstructed from a
    # copy of the file that had lost its indentation - confirm which
    # components belong inside which gr.Row().
    with demo:
        # State shared between the callbacks.
        processed_results_state = gr.State([])
        highlight_terms_state = gr.State([])
        num_results_state = gr.State(0)
        exact_search_state = gr.State(False)
        lang_state = gr.State("")
        max_page_size_state = gr.State(100)
        received_results_state = gr.State(0)

        with gr.Row():
            gr.Markdown(value=title)
        with gr.Row():
            gr.Markdown(value=description)
        with gr.Row():
            query = gr.Textbox(
                lines=1,
                max_lines=1,
                placeholder="Put your query in double quotes for exact search.",
                label="Query",
            )
        with gr.Row():
            lang = gr.Dropdown(
                choices=[
                    "ar",
                    "ca",
                    "code",
                    "en",
                    "es",
                    "eu",
                    "fr",
                    "id",
                    "indic",
                    "nigercongo",
                    "pt",
                    "vi",
                    "zh",
                    "detect_language",
                ],
                value="en",
                label="Language",
            )
            k = gr.Slider(1, 100, value=10, step=1, label="Max Results")
        with gr.Row():
            submit_btn = gr.Button("Submit")
        with gr.Row(visible=False) as datasets_filter:
            available_datasets = gr.Dropdown(
                type="value",
                choices=[],
                value=[],
                label="Datasets Filter",
                multiselect=True,
            )
        with gr.Row():
            header_html = gr.HTML(label="Header", value="hello")
        results_html = [gr.HTML(label="Results") for _ in range(MAX_RESULT_SLOTS)]
        with gr.Row(visible=False) as pagination:
            next_page_btn = gr.Button("Next Page")

        def fill_result_slots(snippets):
            """Spread rendered results over the fixed list of result slots.

            Used slots are explicitly made visible again (an earlier, shorter
            result list may have hidden them); the rest are hidden. Results
            beyond MAX_RESULT_SLOTS are dropped so the number of returned
            values always matches the number of output components.
            """
            shown = snippets[:MAX_RESULT_SLOTS]
            updates = [gr.update(value=snippet, visible=True) for snippet in shown]
            updates += [gr.update(visible=False)] * (MAX_RESULT_SLOTS - len(shown))
            return updates

        def run_query(query, lang, k, dropdown_input, max_page_size, received_results):
            """Query the backend and render the results.

            A query wrapped in double quotes is an exact search: the quotes
            are stripped and ``max_page_size`` replaces ``k``. Otherwise runs
            of whitespace in the query are collapsed.

            Returns:
                ``(processed_results, highlight_terms, num_results,
                exact_search, results_html, datasets, header)``. For an empty
                query or a backend error the result fields are empty and
                ``header`` carries the message (if any), so callers can
                always unpack the tuple.
            """
            query = (query or "").strip()
            exact_search = False
            if query.startswith('"') and query.endswith('"') and len(query) >= 2:
                exact_search = True
                query = query[1:-1]
                k = max_page_size
            else:
                query = " ".join(query.split())
            if query == "":
                return [], [], 0, exact_search, [], [], ""

            print("submitting", query, lang, k)
            payload = request_payload(query, lang, exact_search, k, received_results)
            err = extract_error_from_payload(payload)
            if err is not None:
                print("backend error", err)
                # process_error may yield nothing for error types it does not
                # know; fall back to a generic message in that case.
                header = process_error(err) or (
                    "<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; "
                    "text-align: center;'>The search backend reported an error: "
                    "<b>{}</b></p>".format(err)
                )
                return [], [], 0, exact_search, [], [], header

            (
                processed_results,
                highlight_terms,
                num_results,
                ds,
            ) = extract_results_from_payload(
                query,
                lang,
                payload,
                exact_search,
            )

            header = ""
            if lang == "detect_language" and not exact_search:
                # TODO: replace the "FIX ME!" placeholder with the language
                # the backend actually detected.
                header += """<div style='font-family: Arial; color:MediumAquaMarine; text-align: center; line-height: 3em'>
                    Detected language: <b style='color:MediumAquaMarine'>{}</b></div>""".format(
                    "FIX ME!"
                )
            if len(processed_results) == 0:
                header += """<div style='font-family: Arial; color:Silver; text-align: left; line-height: 3em'>
                    No results found.</div>"""
            elif num_results is not None:
                header += """<div style='font-family: Arial; color:MediumAquaMarine; text-align: center; line-height: 3em'>
                    Total number of matches: <b style='color:MediumAquaMarine'>{}</b></div>""".format(
                    num_results
                )

            rendered = format_result_page(
                processed_results, highlight_terms, num_results, exact_search
            )
            return (
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                rendered,
                ds,
                header,
            )

        def submit(query, lang, k, dropdown_input, max_page_size):
            """Run a fresh query (offset 0) and refresh every output."""
            (
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                rendered,
                datasets,
                header,
            ) = run_query(query, lang, k, dropdown_input, max_page_size, 0)
            # num_results is None for non-exact searches.
            has_more_results = bool(
                exact_search
                and num_results is not None
                and num_results > max_page_size
            )
            return [
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                gr.update(visible=True),
                gr.Dropdown.update(choices=datasets, value=datasets),
                gr.update(visible=has_more_results),
                len(processed_results),
                header,
            ] + fill_result_slots(rendered)

        def next_page(
            query,
            lang,
            k,
            dropdown_input,
            max_page_size,
            received_results,
            processed_results,
        ):
            """Fetch the page that starts after ``received_results`` results."""
            (
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                rendered,
                datasets,
                header,
            ) = run_query(
                query, lang, k, dropdown_input, max_page_size, received_results
            )
            num_processed_results = len(processed_results)
            has_more_results = bool(
                exact_search
                and num_results is not None
                and num_results > max_page_size
            )
            print("num_processed_results", num_processed_results)
            print("has_more_results", has_more_results)
            print("received_results", received_results)
            return [
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                gr.update(visible=True),
                gr.Dropdown.update(choices=datasets, value=datasets),
                gr.update(
                    visible=num_processed_results >= max_page_size
                    and has_more_results
                ),
                received_results + num_processed_results,
                header,
            ] + fill_result_slots(rendered)

        def filter_datasets(
            lang,
            processed_results,
            highlight_terms,
            num_results,
            exact_search,
            datasets_filter,
        ):
            """Re-render the stored results, keeping only the selected datasets."""
            rendered = format_result_page(
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                datasets_filter,
            )
            return fill_result_slots(rendered)

        # Outputs shared by the callbacks that run a query; the order must
        # match the lists returned by submit() and next_page().
        query_outputs = [
            processed_results_state,
            highlight_terms_state,
            num_results_state,
            exact_search_state,
            datasets_filter,
            available_datasets,
            pagination,
            received_results_state,
            header_html,
        ] + results_html
        submit_inputs = [query, lang, k, available_datasets, max_page_size_state]
        filter_inputs = [
            lang,
            processed_results_state,
            highlight_terms_state,
            num_results_state,
            exact_search_state,
            available_datasets,
        ]

        query.submit(fn=submit, inputs=submit_inputs, outputs=query_outputs)
        submit_btn.click(submit, inputs=submit_inputs, outputs=query_outputs)
        next_page_btn.click(
            next_page,
            inputs=[
                query,
                lang,
                k,
                available_datasets,
                max_page_size_state,
                received_results_state,
                processed_results_state,
            ],
            outputs=query_outputs,
        )
        # Re-apply the dataset filter when the selection changes and when the
        # first result slot changes.
        available_datasets.change(
            filter_datasets, inputs=filter_inputs, outputs=results_html
        )
        results_html[0].change(
            filter_datasets, inputs=filter_inputs, outputs=results_html
        )

    demo.launch(enable_queue=True, debug=True)