# HTK/app.py
import streamlit as st
import io
import os
import yaml
import pyarrow
import tokenizers
# Set tokenizers parallelism explicitly (this also suppresses the fork warning
# emitted by huggingface/tokenizers when the variable is unset).
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Use the full browser width for the app layout.
st.set_page_config(layout="wide")
@st.cache_resource
def from_library():
    # Import lazily so the heavyweight import is done once and cached by Streamlit.
    from retro_reader import RetroReader
    from retro_reader import constants as C
    return C, RetroReader

C, RetroReader = from_library()
# Types that Streamlit's cache cannot hash; map them to constant values so they
# are effectively ignored when cache keys are computed.
my_hash_func = {
    io.TextIOWrapper: lambda _: None,
    pyarrow.lib.Buffer: lambda _: 0,
    tokenizers.Tokenizer: lambda _: None,
    tokenizers.AddedToken: lambda _: None,
}
@st.cache_resource(hash_funcs=my_hash_func)
def load_en_electra_base_model():
    config_file = "configs/inference_en_electra_base.yaml"
    return RetroReader.load(config_file=config_file)

@st.cache_resource(hash_funcs=my_hash_func)
def load_en_electra_large_model():
    config_file = "configs/inference_en_electra_large.yaml"
    return RetroReader.load(config_file=config_file)
# Both models are loaded eagerly at startup; st.cache_resource keeps a single
# instance of each across reruns.
RETRO_READER_HOST = {
    "google/electra-base-discriminator": load_en_electra_base_model(),
    "google/electra-large-discriminator": load_en_electra_large_model(),
}
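# Alternative sketch (not wired into the app): load models lazily so that only the
# model the user actually selects is instantiated. It reuses the cached loader
# functions above; the helper name below is hypothetical.
def get_retro_reader_lazily(model_name: str):
    loaders = {
        "google/electra-base-discriminator": load_en_electra_base_model,
        "google/electra-large-discriminator": load_en_electra_large_model,
    }
    return loaders[model_name]()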
def display_top_predictions(nbest_preds, top_k=10):
    # nbest_preds may be a dict keyed by example id rather than a plain list.
    if not isinstance(nbest_preds, list):
        nbest_preds = nbest_preds["id-01"]  # Adjust the key to match the actual structure
    sorted_preds = sorted(nbest_preds, key=lambda x: x["probability"], reverse=True)[:top_k]
    st.markdown("### Top Predictions")
    for i, pred in enumerate(sorted_preds, 1):
        st.markdown(f"**{i}. {pred['text']}** - Probability: {pred['probability'] * 100:.2f}%")
def main():
    # Sidebar introduction
    st.sidebar.title("📝 Welcome to Retro Reader")
    st.sidebar.write("""
MRC-RetroReader is a machine reading comprehension (MRC) model. It leverages advanced
neural network architectures to understand a passage and answer textual queries about
it with high accuracy.
""")
    image_url = "img.jpg"  # Replace with your own image URL or local file path
    st.sidebar.image(image_url, use_column_width=True)
    st.sidebar.title("Contributors")
    st.sidebar.write("""
- Phan Van Hoang
- Pham Long Khanh
""")
    st.title("Retrospective Reader Demo")
    st.markdown("## Model name 🚨")
    option = st.selectbox(
        label="Choose the model used in retro reader",
        options=(
            "[1] google/electra-base-discriminator",
            "[2] google/electra-large-discriminator",
        ),
        index=1,
    )
    # Options look like "[1] google/electra-base-discriminator"; strip the numeric prefix.
    _, model_name = option.split(" ", 1)
    retro_reader = RETRO_READER_HOST[model_name]
    lang_prefix = "EN"
    height = 200
    return_submodule_outputs = True
    with st.form(key="my_form"):
        st.markdown("## Type your query ❓")
        query = st.text_input(
            label="",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_QUERY"),
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_QUERY_HELP_TEXT"),
        )
        st.markdown("## Type your context 💬")
        context = st.text_area(
            label="",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_CONTEXTS"),
            height=height,
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_CONTEXT_HELP_TEXT"),
        )
        submit_button = st.form_submit_button(label="Submit")
    if submit_button:
        with st.spinner("🕒 Please wait..."):
            outputs = retro_reader(
                query=query,
                context=context,
                return_submodule_outputs=return_submodule_outputs,
            )
        # outputs[0] maps example ids to answer text; outputs[1] is the answer score.
        answer, score = outputs[0]["id-01"], outputs[1]
        if not answer:
            answer = "No answer"
        st.markdown("## 📜 Results")
        st.write(answer)
        if return_submodule_outputs:
            # Remaining submodule outputs: score_ext, the n-best predictions, and score_diff.
            score_ext, nbest_preds, score_diff = outputs[2:]
            display_top_predictions(nbest_preds)
if __name__ == "__main__":
    main()