|
import streamlit as st
|
|
import io
|
|
import os
|
|
import yaml
|
|
import pyarrow
|
|
import tokenizers
|
|
|
|
# Explicitly allow tokenizer parallelism; setting the variable up front also
# silences the Hugging Face "TOKENIZERS_PARALLELISM" fork warning.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Use the full browser width for the demo UI (must run before other st calls).
st.set_page_config(layout="wide")
|
|
|
|
@st.cache_resource
def from_library():
    """Import the retro_reader package lazily and cache the handles.

    Deferring the import to a cached function keeps the (heavy) library
    load out of module import time on Streamlit re-runs.

    Returns:
        tuple: (constants module, RetroReader class) from ``retro_reader``.
    """
    from retro_reader import RetroReader, constants

    return constants, RetroReader
|
|
|
|
# Resolve the cached library handles once; C holds the constants module,
# RetroReader the reader class used by the loaders below.
C, RetroReader = from_library()
|
|
|
|
# Custom hash functions for st.cache_resource: these types are unhashable
# (or expensive to hash) by Streamlit's default hasher, so each is mapped to
# a constant and caching keys on the remaining arguments instead.
my_hash_func = {
    io.TextIOWrapper: lambda _: None,
    pyarrow.lib.Buffer: lambda _: 0,
    tokenizers.Tokenizer: lambda _: None,
    tokenizers.AddedToken: lambda _: None
}
|
|
|
|
@st.cache_resource(hash_funcs=my_hash_func)
def load_en_electra_base_model():
    """Build and cache the English ELECTRA-base retro reader.

    Returns:
        The reader produced by ``RetroReader.load`` from its YAML config.
    """
    return RetroReader.load(config_file="configs/inference_en_electra_base.yaml")
|
|
|
|
@st.cache_resource(hash_funcs=my_hash_func)
def load_en_electra_large_model():
    """Build and cache the English ELECTRA-large retro reader.

    Returns:
        The reader produced by ``RetroReader.load`` from its YAML config.
    """
    return RetroReader.load(config_file="configs/inference_en_electra_large.yaml")
|
|
|
|
# Eagerly instantiate both readers at startup (each call is memoized by
# st.cache_resource), keyed by the model name shown in the UI select box.
RETRO_READER_HOST = {
    "google/electra-base-discriminator": load_en_electra_base_model(),
    "google/electra-large-discriminator": load_en_electra_large_model(),
}
|
|
|
|
def display_top_predictions(nbest_preds, top_k=10):
    """Render the top-k n-best answer candidates in Streamlit.

    Args:
        nbest_preds: Either a list of prediction dicts (each carrying
            ``'text'`` and ``'probability'`` keys) or a mapping of
            example-id -> such a list, as the reader emits for a single
            query (its fixed id is ``"id-01"``).
        top_k: Maximum number of candidates to display.
    """
    if not isinstance(nbest_preds, list):
        # Unwrap the per-example mapping. Prefer the reader's fixed "id-01"
        # key, but fall back to the first entry so an unexpected id does not
        # raise KeyError (previously this indexed "id-01" unconditionally).
        if "id-01" in nbest_preds:
            nbest_preds = nbest_preds["id-01"]
        else:
            nbest_preds = next(iter(nbest_preds.values()), [])

    # Highest-probability candidates first, truncated to top_k.
    sorted_preds = sorted(nbest_preds, key=lambda p: p["probability"], reverse=True)[:top_k]

    st.markdown("### Top Predictions")
    for rank, pred in enumerate(sorted_preds, 1):
        st.markdown(f"**{rank}. {pred['text']}** - Probability: {pred['probability']*100:.2f}%")
|
|
|
|
def main():
    """Streamlit entry point: sidebar, model picker, and the QA form.

    Reads the selected model from RETRO_READER_HOST and runs it on the
    submitted query/context pair, then renders the answer and (optionally)
    the n-best candidate list.
    """
    # --- Sidebar: project blurb, illustration, and credits ---
    # NOTE(review): several headings below contain mojibake (e.g. "π") that
    # looks like mis-decoded emoji; preserved as-is since it is runtime text.
    st.sidebar.title("π Welcome to Retro Reader")
    st.sidebar.write("""
    MRC-RetroReader is a machine reading comprehension (MRC) model designed for reading comprehension tasks. The model leverages advanced neural network architectures to provide high accuracy in understanding and responding to textual queries.
    """)
    image_url = "img.jpg"
    st.sidebar.image(image_url, use_column_width=True)
    st.sidebar.title("Contributors")
    st.sidebar.write("""
    - Phan Van Hoang
    - Pham Long Khanh
    """)

    # --- Main panel: model selection ---
    st.title("Retrospective Reader Demo")
    st.markdown("## Model nameπ¨")
    option = st.selectbox(
        label="Choose the model used in retro reader",
        options=(
            "[1] google/electra-base-discriminator",
            "[2] google/electra-large-discriminator"
        ),
        index=1,
    )
    # Options look like "[N] <model name>"; keep only the model name part.
    _, model_name = option.split(" ")
    retro_reader = RETRO_READER_HOST[model_name]

    lang_prefix = "EN"          # selects the EN_* example/help constants in C
    height = 200                # context text-area height in pixels
    return_submodule_outputs = True

    # --- Query/context input form ---
    with st.form(key="my_form"):
        st.markdown("## Type your query β")
        query = st.text_input(
            label="",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_QUERY"),
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_QUERY_HELP_TEXT"),
        )
        # BUGFIX: this heading previously duplicated "Type your query";
        # the widget below collects the context passage, not the query.
        st.markdown("## Type your context π¬")
        context = st.text_area(
            label="",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_CONTEXTS"),
            height=height,
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_CONTEXT_HELP_TEXT"),
        )
        submit_button = st.form_submit_button(label="Submit")

    if submit_button:
        with st.spinner("π Please wait.."):
            outputs = retro_reader(query=query, context=context, return_submodule_outputs=return_submodule_outputs)
        # outputs[0] maps example id -> answer text; "id-01" is the reader's
        # fixed id for a single query. outputs[1] is the answer score.
        answer, score = outputs[0]["id-01"], outputs[1]
        if not answer:
            answer = "No answer"
        st.markdown("## π Results")
        st.write(answer)
        if return_submodule_outputs:
            # Extra submodule outputs: external-verifier score, n-best
            # candidate list, and the answerability score difference.
            score_ext, nbest_preds, score_diff = outputs[2:]
            display_top_predictions(nbest_preds)
|
|
|
|
# Run the Streamlit app when executed as a script (streamlit run <file>).
if __name__ == "__main__":
    main()
|
|
|