import streamlit as st
import io
import os
import yaml
import pyarrow
import tokenizers
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Setting page config to wide mode
st.set_page_config(layout="wide")

@st.cache_resource
def from_library():
    from retro_reader import RetroReader
    from retro_reader import constants as C
    return C, RetroReader


C, RetroReader = from_library()
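
# Importing retro_reader inside a cached function keeps the heavy import off
# Streamlit's rerun path: it runs once per server process, and later reruns
# reuse the cached (constants, RetroReader) pair.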

my_hash_func = {
    io.TextIOWrapper: lambda _: None,
    pyarrow.lib.Buffer: lambda _: 0,
    tokenizers.Tokenizer: lambda _: None,
    tokenizers.AddedToken: lambda _: None,
}
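
# The mapping above tells st.cache_resource how to hash types it cannot hash
# by default (open file handles, Arrow buffers, tokenizer objects): each is
# mapped to a constant, effectively excluding it from the cache key, which is
# safe because these objects do not determine which model gets loaded.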

@st.cache_resource(hash_funcs=my_hash_func)
def load_en_electra_base_model():
    config_file = "configs/inference_en_electra_base.yaml"
    return RetroReader.load(config_file=config_file)


@st.cache_resource(hash_funcs=my_hash_func)
def load_en_electra_large_model():
    config_file = "configs/inference_en_electra_large.yaml"
    return RetroReader.load(config_file=config_file)

RETRO_READER_HOST = {
    "google/electra-base-discriminator": load_en_electra_base_model(),
    "google/electra-large-discriminator": load_en_electra_large_model(),
}
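
# Both models are loaded eagerly when the script first runs; the resource
# cache keeps them alive across Streamlit reruns, so picking a model in the
# UI is just a dictionary lookup rather than a fresh checkpoint load.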

def display_top_predictions(nbest_preds, top_k=10):
    # nbest_preds may arrive as a dict keyed by example id that wraps the
    # actual list of candidate answers
    if not isinstance(nbest_preds, list):
        nbest_preds = nbest_preds["id-01"]  # Adjust key as per actual structure
    sorted_preds = sorted(nbest_preds, key=lambda x: x["probability"], reverse=True)[:top_k]
    st.markdown("### Top Predictions")
    for i, pred in enumerate(sorted_preds, 1):
        st.markdown(f"**{i}. {pred['text']}** - Probability: {pred['probability'] * 100:.2f}%")

def main():
    # Sidebar introduction
    st.sidebar.title("Welcome to Retro Reader")
    st.sidebar.write("""
    MRC-RetroReader is a machine reading comprehension (MRC) model that
    leverages advanced neural network architectures to understand and
    answer textual queries with high accuracy.
    """)
    image_url = "img.jpg"  # Replace with your actual image URL or local path
    st.sidebar.image(image_url, use_column_width=True)
    st.sidebar.title("Contributors")
    st.sidebar.write("""
    - Phan Van Hoang
    - Pham Long Khanh
    """)
    st.title("Retrospective Reader Demo")
    st.markdown("## Model name")
    option = st.selectbox(
        label="Choose the model used in retro reader",
        options=(
            "[1] google/electra-base-discriminator",
            "[2] google/electra-large-discriminator",
        ),
        index=1,
    )
    # Options are formatted "[n] <model name>"; keep only the model name
    _, model_name = option.split(" ")
    retro_reader = RETRO_READER_HOST[model_name]
    lang_prefix = "EN"
    height = 200
    return_submodule_outputs = True
    with st.form(key="my_form"):
        st.markdown("## Type your query")
        query = st.text_input(
            label="",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_QUERY"),
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_QUERY_HELP_TEXT"),
        )
        st.markdown("## Type your context")
        context = st.text_area(
            label="",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_CONTEXTS"),
            height=height,
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_CONTEXT_HELP_TEXT"),
        )
        submit_button = st.form_submit_button(label="Submit")
    if submit_button:
        with st.spinner("Please wait..."):
            outputs = retro_reader(
                query=query,
                context=context,
                return_submodule_outputs=return_submodule_outputs,
            )
        answer, score = outputs[0]["id-01"], outputs[1]
        if not answer:
            answer = "No answer"
        st.markdown("## Results")
        st.write(answer)
        if return_submodule_outputs:
            score_ext, nbest_preds, score_diff = outputs[2:]
            display_top_predictions(nbest_preds)


if __name__ == "__main__":
    main()
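
# Minimal local run (assumes retro_reader is installed and the configs/
# directory plus img.jpg sit next to this script):
#   streamlit run app.py  # substitute this file's actual name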