import io
import os

import streamlit as st
import yaml
import pyarrow
import tokenizers

os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Set the page config to wide mode
st.set_page_config(layout="wide")


@st.cache_resource
def from_library():
    from retro_reader import RetroReader
    from retro_reader import constants as C
    return C, RetroReader


C, RetroReader = from_library()

# Custom hash functions for objects Streamlit cannot hash natively
my_hash_func = {
    io.TextIOWrapper: lambda _: None,
    pyarrow.lib.Buffer: lambda _: 0,
    tokenizers.Tokenizer: lambda _: None,
    tokenizers.AddedToken: lambda _: None,
}


@st.cache_resource(hash_funcs=my_hash_func)
def load_en_electra_base_model():
    config_file = "configs/inference_en_electra_base.yaml"
    return RetroReader.load(config_file=config_file)


@st.cache_resource(hash_funcs=my_hash_func)
def load_en_electra_large_model():
    config_file = "configs/inference_en_electra_large.yaml"
    return RetroReader.load(config_file=config_file)


RETRO_READER_HOST = {
    "google/electra-base-discriminator": load_en_electra_base_model(),
    "google/electra-large-discriminator": load_en_electra_large_model(),
}


def display_top_predictions(nbest_preds, top_k=10):
    # nbest_preds may be a dictionary keyed by example id that wraps the list
    if not isinstance(nbest_preds, list):
        nbest_preds = nbest_preds["id-01"]  # Adjust key as per actual structure
    sorted_preds = sorted(nbest_preds, key=lambda x: x["probability"], reverse=True)[:top_k]
    st.markdown("### Top Predictions")
    for i, pred in enumerate(sorted_preds, 1):
        st.markdown(f"**{i}. {pred['text']}** - Probability: {pred['probability']*100:.2f}%")


def main():
    # Sidebar introduction
    st.sidebar.title("📝 Welcome to Retro Reader")
    st.sidebar.write("""
        MRC-RetroReader is a retrospective reader model for machine reading comprehension (MRC).
        It leverages advanced neural network architectures to provide high accuracy in
        understanding and answering textual queries.
    """)

    image_url = "img.jpg"  # Replace with your actual image URL or local path
    st.sidebar.image(image_url, use_column_width=True)

    st.sidebar.title("Contributors")
    st.sidebar.write("""
    - Phan Van Hoang
    - Pham Long Khanh
    """)

    st.title("Retrospective Reader Demo")
    st.markdown("## Model name 🚨")
    option = st.selectbox(
        label="Choose the model used in retro reader",
        options=(
            "[1] google/electra-base-discriminator",
            "[2] google/electra-large-discriminator",
        ),
        index=1,
    )
    lang_code, model_name = option.split(" ")
    retro_reader = RETRO_READER_HOST[model_name]

    lang_prefix = "EN"
    height = 200
    return_submodule_outputs = True

    with st.form(key="my_form"):
        st.markdown("## Type your query ❓")
        query = st.text_input(
            label="",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_QUERY"),
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_QUERY_HELP_TEXT"),
        )
        st.markdown("## Type your context 💬")
        context = st.text_area(
            label="",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_CONTEXTS"),
            height=height,
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_CONTEXT_HELP_TEXT"),
        )
        submit_button = st.form_submit_button(label="Submit")

    if submit_button:
        with st.spinner("🕒 Please wait..."):
            outputs = retro_reader(
                query=query,
                context=context,
                return_submodule_outputs=return_submodule_outputs,
            )
        answer, score = outputs[0]["id-01"], outputs[1]
        if not answer:
            answer = "No answer"
        st.markdown("## 📜 Results")
        st.write(answer)
        if return_submodule_outputs:
            score_ext, nbest_preds, score_diff = outputs[2:]
            display_top_predictions(nbest_preds)


if __name__ == "__main__":
    main()
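
# --- Usage note (not part of the app logic) ---
# A minimal way to launch this demo locally, assuming the script is saved as
# app.py and the configs/ directory with both inference YAML files is present
# (file name and layout are assumptions, not confirmed by the source):
#
#   streamlit run app.py
#
# Both ELECTRA checkpoints are loaded once at import time via the cached
# loaders above, so the first start may take a while.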