File size: 4,098 Bytes
550665c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import streamlit as st
import io
import os
import yaml
import pyarrow
import tokenizers

# Allow tokenizer parallelism; also silences the HuggingFace tokenizers
# fork-warning that would otherwise appear in the Streamlit logs.
# NOTE(review): "true" enables parallelism — confirm this is intended rather
# than the more common "false" used to suppress the warning.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Setting page config to wide mode (must run before any other st.* call)
st.set_page_config(layout="wide")

@st.cache_resource
def from_library():
    """Import the retro_reader package once and cache the handles.

    Returns:
        Tuple of (constants module, RetroReader class). Deferring the import
        to a cached function keeps the heavy library load out of every rerun.
    """
    from retro_reader import RetroReader as reader_cls
    from retro_reader import constants as consts
    return consts, reader_cls

C, RetroReader = from_library()

# Custom hash functions for st.cache_resource: these types are unhashable
# (or expensive/unstable to hash), so each is mapped to a constant so that
# Streamlit's cache keying ignores them.
my_hash_func = {
    io.TextIOWrapper: lambda _: None,
    pyarrow.lib.Buffer: lambda _: 0,
    tokenizers.Tokenizer: lambda _: None,
    tokenizers.AddedToken: lambda _: None
}

@st.cache_resource(hash_funcs=my_hash_func)
def load_en_electra_base_model():
    """Build (and cache) the retro reader backed by ELECTRA-base."""
    return RetroReader.load(config_file="configs/inference_en_electra_base.yaml")

@st.cache_resource(hash_funcs=my_hash_func)
def load_en_electra_large_model():
    """Build (and cache) the retro reader backed by ELECTRA-large."""
    return RetroReader.load(config_file="configs/inference_en_electra_large.yaml")

# Maps the model name shown in the UI selectbox to a ready retro reader.
# NOTE(review): both models are loaded eagerly here at import time; if startup
# memory/latency matters, consider deferring each load until selected.
RETRO_READER_HOST = {
    "google/electra-base-discriminator": load_en_electra_base_model(),
    "google/electra-large-discriminator": load_en_electra_large_model(),
}

def display_top_predictions(nbest_preds, top_k=10, preds_key="id-01"):
    """Render the top-k n-best answer predictions in the Streamlit page.

    Args:
        nbest_preds: Either a list of prediction dicts (each containing
            "text" and "probability"), or a mapping whose ``preds_key`` entry
            holds that list.
        top_k: Maximum number of predictions to display.
        preds_key: Key used to unwrap the list when ``nbest_preds`` is a
            mapping (previously hard-coded; now parameterized, same default).

    Raises:
        KeyError: If ``nbest_preds`` is a mapping without ``preds_key``.
    """
    if not isinstance(nbest_preds, list):
        # Submodule output wraps the list under a per-example id key.
        nbest_preds = nbest_preds[preds_key]

    st.markdown("### Top Predictions")
    if not nbest_preds:
        # Robustness: make an empty n-best list visible instead of silent.
        st.markdown("_No predictions available._")
        return

    sorted_preds = sorted(nbest_preds, key=lambda p: p["probability"], reverse=True)[:top_k]
    for rank, pred in enumerate(sorted_preds, 1):
        st.markdown(f"**{rank}. {pred['text']}** - Probability: {pred['probability']*100:.2f}%")

def main():
    """Drive the Streamlit UI: sidebar info, model picker, query form, results."""
    # Sidebar Introduction
    st.sidebar.title("πŸ“ Welcome to Retro Reader")
    st.sidebar.write("""

    MRC-RetroReader is a machine reading comprehension (MRC) model designed for reading comprehension tasks. The model leverages advanced neural network architectures to provide high accuracy in understanding and responding to textual queries.

    """)
    image_url = "img.jpg"  # Replace this URL with your actual image URL or local path
    st.sidebar.image(image_url, use_column_width=True)
    st.sidebar.title("Contributors")
    st.sidebar.write("""

    - Phan Van Hoang

    - Pham Long Khanh

    """)

    st.title("Retrospective Reader Demo")
    st.markdown("## Model name🚨")
    option = st.selectbox(
        label="Choose the model used in retro reader",
        options=(
            "[1] google/electra-base-discriminator",
            "[2] google/electra-large-discriminator"
        ),
        index=1,
    )
    # Options look like "[N] <model-name>"; the "[N]" index tag is discarded.
    _, model_name = option.split(" ")
    retro_reader = RETRO_READER_HOST[model_name]

    lang_prefix = "EN"  # selects the EN_* example/help constants from C
    height = 200
    return_submodule_outputs = True

    with st.form(key="my_form"):
        st.markdown("## Type your query ❓")
        query = st.text_input(
            label="",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_QUERY"),
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_QUERY_HELP_TEXT"),
        )
        # BUG FIX: this header previously duplicated "Type your query";
        # it labels the context area.
        st.markdown("## Type your context πŸ’¬")
        context = st.text_area(
            label="",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_CONTEXTS"),
            height=height,
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_CONTEXT_HELP_TEXT"),
        )
        submit_button = st.form_submit_button(label="Submit")

    if submit_button:
        with st.spinner("πŸ•’ Please wait.."):
            outputs = retro_reader(query=query, context=context, return_submodule_outputs=return_submodule_outputs)
        # outputs[0] maps example id -> answer text; outputs[1] is the score.
        answer, score = outputs[0]["id-01"], outputs[1]
        if not answer:
            answer = "No answer"
        st.markdown("## πŸ“œ Results")
        st.write(answer)
        if return_submodule_outputs:
            # Extra diagnostics: sketch/intensive scores and the n-best list.
            score_ext, nbest_preds, score_diff = outputs[2:]
            display_top_predictions(nbest_preds)

# Script entry point: launch the Streamlit app.
if __name__ == "__main__":
    main()