Spaces:
Sleeping
Sleeping
UnarineLeo
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import streamlit as st
|
2 |
from transformers import pipeline
|
|
|
3 |
|
4 |
unmasker = pipeline('fill-mask', model='dsfsi/zabantu-nso-120m')
|
5 |
|
@@ -21,6 +22,7 @@ def replace_mask(sentence, predicted_word):
|
|
21 |
|
22 |
st.title("Fill Mask | Zabantu-nso-120m")
|
23 |
st.write(f"")
|
|
|
24 |
st.markdown("This is a variant of Zabantu pre-trained on a monolingual dataset of Sepedi(nso) sentences on a transformer network with 120 million traininable parameters.")
|
25 |
|
26 |
col1, col2 = st.columns(2)
|
@@ -34,26 +36,37 @@ if 'warnings' not in st.session_state:
|
|
34 |
with col1:
|
35 |
with st.container(border=True):
|
36 |
st.markdown("Input :clipboard:")
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
if st.button("Test Example"):
|
49 |
-
# st.rerun()
|
50 |
-
result, warnings = fill_mask(sample_sentence.split("\n"))
|
51 |
-
# st.session_state['text_input'] = sample_sentence
|
52 |
|
53 |
-
|
54 |
-
|
|
|
55 |
result, warnings = fill_mask(input_sentences)
|
56 |
st.session_state['warnings'] = warnings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
if st.session_state['warnings']:
|
59 |
for warning in st.session_state['warnings']:
|
@@ -61,12 +74,14 @@ with col1:
|
|
61 |
|
62 |
st.markdown("Example")
|
63 |
st.code(sample_sentence, wrap_lines=True)
|
|
|
|
|
64 |
|
65 |
with col2:
|
66 |
with st.container(border=True):
|
67 |
st.markdown("Output :bar_chart:")
|
68 |
if 'result' in locals() and result:
|
69 |
-
if result:
|
70 |
for sentence, predictions in result.items():
|
71 |
for prediction in predictions:
|
72 |
predicted_word = prediction['token_str']
|
@@ -82,12 +97,34 @@ with col2:
|
|
82 |
</div>
|
83 |
""", unsafe_allow_html=True)
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
if 'result' in locals():
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
91 |
|
92 |
css = """
|
93 |
<style>
|
|
|
1 |
import streamlit as st
|
2 |
from transformers import pipeline
|
3 |
+
from io import StringIO
|
4 |
|
5 |
unmasker = pipeline('fill-mask', model='dsfsi/zabantu-nso-120m')
|
6 |
|
|
|
22 |
|
23 |
st.title("Fill Mask | Zabantu-nso-120m")
|
24 |
st.write(f"")
|
25 |
+
|
26 |
st.markdown("This is a variant of Zabantu pre-trained on a monolingual dataset of Sepedi(nso) sentences on a transformer network with 120 million traininable parameters.")
|
27 |
|
28 |
col1, col2 = st.columns(2)
|
|
|
36 |
with col1:
|
37 |
with st.container(border=True):
|
38 |
st.markdown("Input :clipboard:")
|
39 |
+
|
40 |
+
select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
|
41 |
+
sample_sentence = "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis."
|
42 |
+
|
43 |
+
option_selected = st.selectbox(f"Select an input option:", select_options, index=0)
|
44 |
+
|
45 |
+
if option_selected == 'Enter text input':
|
46 |
+
text_input = st.text_area(
|
47 |
+
"Enter sentences with <mask> token:",
|
48 |
+
value=st.session_state['text_input']
|
49 |
+
)
|
|
|
|
|
|
|
|
|
50 |
|
51 |
+
input_sentences = text_input.split("\n")
|
52 |
+
|
53 |
+
if st.button("Submit",use_container_width=True):
|
54 |
result, warnings = fill_mask(input_sentences)
|
55 |
st.session_state['warnings'] = warnings
|
56 |
+
|
57 |
+
if option_selected == 'Upload a file(csv/txt)':
|
58 |
+
|
59 |
+
uploaded_file = st.file_uploader("Choose a file")
|
60 |
+
if uploaded_file is not None:
|
61 |
+
|
62 |
+
stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
|
63 |
+
string_data = stringio.read()
|
64 |
+
|
65 |
+
input_sentences = string_data.split("\n")
|
66 |
+
|
67 |
+
if st.button("Submit",use_container_width=True):
|
68 |
+
result, warnings = fill_mask(input_sentences)
|
69 |
+
st.session_state['warnings'] = warnings
|
70 |
|
71 |
if st.session_state['warnings']:
|
72 |
for warning in st.session_state['warnings']:
|
|
|
74 |
|
75 |
st.markdown("Example")
|
76 |
st.code(sample_sentence, wrap_lines=True)
|
77 |
+
if st.button("Test Example",use_container_width=True):
|
78 |
+
result, warnings = fill_mask(sample_sentence.split("\n"))
|
79 |
|
80 |
with col2:
|
81 |
with st.container(border=True):
|
82 |
st.markdown("Output :bar_chart:")
|
83 |
if 'result' in locals() and result:
|
84 |
+
if len(result) == 1:
|
85 |
for sentence, predictions in result.items():
|
86 |
for prediction in predictions:
|
87 |
predicted_word = prediction['token_str']
|
|
|
97 |
</div>
|
98 |
""", unsafe_allow_html=True)
|
99 |
|
100 |
+
else:
|
101 |
+
index = 0
|
102 |
+
for sentence, predictions in result.items():
|
103 |
+
index += 1
|
104 |
+
if predictions:
|
105 |
+
top_prediction = predictions[0]
|
106 |
+
predicted_word = top_prediction['token_str']
|
107 |
+
score = top_prediction['score'] * 100
|
108 |
+
|
109 |
+
st.markdown(f"""
|
110 |
+
<div class="bar">
|
111 |
+
<div class="bar-fill" style="width: {score}%;"></div>
|
112 |
+
</div>
|
113 |
+
<div class="container">
|
114 |
+
<div style="align-items: left;">{predicted_word} (line {index})</div>
|
115 |
+
<div style="align-items: right;">{score:.2f}%</div>
|
116 |
+
</div>
|
117 |
+
""", unsafe_allow_html=True)
|
118 |
+
|
119 |
+
|
120 |
if 'result' in locals():
|
121 |
+
if result:
|
122 |
+
line = 0
|
123 |
+
for sentence, predictions in result.items():
|
124 |
+
line += 1
|
125 |
+
predicted_word = predictions[0]['token_str']
|
126 |
+
full_sentence = replace_mask(sentence, predicted_word)
|
127 |
+
st.write(f"**Sentence {line}:** {full_sentence }")
|
128 |
|
129 |
css = """
|
130 |
<style>
|