UnarineLeo commited on
Commit
ddd231c
·
verified ·
1 Parent(s): d028f12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -23
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  from transformers import pipeline
 
3
 
4
  unmasker = pipeline('fill-mask', model='dsfsi/zabantu-nso-120m')
5
 
@@ -21,6 +22,7 @@ def replace_mask(sentence, predicted_word):
21
 
22
  st.title("Fill Mask | Zabantu-nso-120m")
23
  st.write(f"")
 
24
  st.markdown("This is a variant of Zabantu pre-trained on a monolingual dataset of Sepedi(nso) sentences on a transformer network with 120 million traininable parameters.")
25
 
26
  col1, col2 = st.columns(2)
@@ -34,26 +36,37 @@ if 'warnings' not in st.session_state:
34
  with col1:
35
  with st.container(border=True):
36
  st.markdown("Input :clipboard:")
37
- sample_sentence = "bašomedi ba polase ya dinamune ya zebediela citrus ba hlomile magato a <mask> malebana le go se sepetšwe botse ga dilo ka polaseng eo."
38
-
39
- text_input = st.text_area(
40
- "Enter sentences with <mask> token:",
41
- value=st.session_state['text_input']
42
- )
43
-
44
- input_sentences = text_input.split("\n")
45
-
46
- button1, button2, _ = st.columns([2, 2, 4])
47
- with button1:
48
- if st.button("Test Example"):
49
- # st.rerun()
50
- result, warnings = fill_mask(sample_sentence.split("\n"))
51
- # st.session_state['text_input'] = sample_sentence
52
 
53
- with button2:
54
- if st.button("Submit"):
 
55
  result, warnings = fill_mask(input_sentences)
56
  st.session_state['warnings'] = warnings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  if st.session_state['warnings']:
59
  for warning in st.session_state['warnings']:
@@ -61,12 +74,14 @@ with col1:
61
 
62
  st.markdown("Example")
63
  st.code(sample_sentence, wrap_lines=True)
 
 
64
 
65
  with col2:
66
  with st.container(border=True):
67
  st.markdown("Output :bar_chart:")
68
  if 'result' in locals() and result:
69
- if result:
70
  for sentence, predictions in result.items():
71
  for prediction in predictions:
72
  predicted_word = prediction['token_str']
@@ -82,12 +97,34 @@ with col2:
82
  </div>
83
  """, unsafe_allow_html=True)
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  if 'result' in locals():
86
- if result:
87
- for sentence, predictions in result.items():
88
- predicted_word = predictions[0]['token_str']
89
- full_sentence = replace_mask(sentence, predicted_word)
90
- st.write(f"**Sentence:** {full_sentence }")
 
 
91
 
92
  css = """
93
  <style>
 
1
  import streamlit as st
2
  from transformers import pipeline
3
+ from io import StringIO
4
 
5
  unmasker = pipeline('fill-mask', model='dsfsi/zabantu-nso-120m')
6
 
 
22
 
23
  st.title("Fill Mask | Zabantu-nso-120m")
24
  st.write(f"")
25
+
26
  st.markdown("This is a variant of Zabantu pre-trained on a monolingual dataset of Sepedi(nso) sentences on a transformer network with 120 million traininable parameters.")
27
 
28
  col1, col2 = st.columns(2)
 
36
  with col1:
37
  with st.container(border=True):
38
  st.markdown("Input :clipboard:")
39
+
40
+ select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
41
+ sample_sentence = "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis."
42
+
43
+ option_selected = st.selectbox(f"Select an input option:", select_options, index=0)
44
+
45
+ if option_selected == 'Enter text input':
46
+ text_input = st.text_area(
47
+ "Enter sentences with <mask> token:",
48
+ value=st.session_state['text_input']
49
+ )
 
 
 
 
50
 
51
+ input_sentences = text_input.split("\n")
52
+
53
+ if st.button("Submit",use_container_width=True):
54
  result, warnings = fill_mask(input_sentences)
55
  st.session_state['warnings'] = warnings
56
+
57
+ if option_selected == 'Upload a file(csv/txt)':
58
+
59
+ uploaded_file = st.file_uploader("Choose a file")
60
+ if uploaded_file is not None:
61
+
62
+ stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
63
+ string_data = stringio.read()
64
+
65
+ input_sentences = string_data.split("\n")
66
+
67
+ if st.button("Submit",use_container_width=True):
68
+ result, warnings = fill_mask(input_sentences)
69
+ st.session_state['warnings'] = warnings
70
 
71
  if st.session_state['warnings']:
72
  for warning in st.session_state['warnings']:
 
74
 
75
  st.markdown("Example")
76
  st.code(sample_sentence, wrap_lines=True)
77
+ if st.button("Test Example",use_container_width=True):
78
+ result, warnings = fill_mask(sample_sentence.split("\n"))
79
 
80
  with col2:
81
  with st.container(border=True):
82
  st.markdown("Output :bar_chart:")
83
  if 'result' in locals() and result:
84
+ if len(result) == 1:
85
  for sentence, predictions in result.items():
86
  for prediction in predictions:
87
  predicted_word = prediction['token_str']
 
97
  </div>
98
  """, unsafe_allow_html=True)
99
 
100
+ else:
101
+ index = 0
102
+ for sentence, predictions in result.items():
103
+ index += 1
104
+ if predictions:
105
+ top_prediction = predictions[0]
106
+ predicted_word = top_prediction['token_str']
107
+ score = top_prediction['score'] * 100
108
+
109
+ st.markdown(f"""
110
+ <div class="bar">
111
+ <div class="bar-fill" style="width: {score}%;"></div>
112
+ </div>
113
+ <div class="container">
114
+ <div style="align-items: left;">{predicted_word} (line {index})</div>
115
+ <div style="align-items: right;">{score:.2f}%</div>
116
+ </div>
117
+ """, unsafe_allow_html=True)
118
+
119
+
120
  if 'result' in locals():
121
+ if result:
122
+ line = 0
123
+ for sentence, predictions in result.items():
124
+ line += 1
125
+ predicted_word = predictions[0]['token_str']
126
+ full_sentence = replace_mask(sentence, predicted_word)
127
+ st.write(f"**Sentence {line}:** {full_sentence }")
128
 
129
  css = """
130
  <style>