vukosi committed on
Commit
3096ba9
·
1 Parent(s): 8b740b2

Filling Mask

Browse files
Files changed (1) hide show
  1. app.py +52 -11
app.py CHANGED
@@ -54,9 +54,20 @@ st.sidebar.markdown("""
54
  # -------------------- CACHING FUNCTIONS --------------------
55
  @st.cache_resource
56
  def load_mask_filling_model():
57
- tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa")
58
- model = AutoModelForMaskedLM.from_pretrained("dsfsi/PuoBERTa")
59
- return pipeline("fill-mask", model=model, tokenizer=tokenizer, top_k=5)
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  @st.cache_resource
62
  def load_pos_model():
@@ -77,6 +88,23 @@ def load_news_classification_model():
77
  return pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
78
 
79
  # -------------------- UTILITY FUNCTIONS --------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def merge_entities(output):
81
  """Merge consecutive entities of the same type"""
82
  merged = []
@@ -166,25 +194,31 @@ tab1, tab2, tab3, tab4 = st.tabs(["🎭 Mask Filling", "🏷️ POS Tagging", "
166
  # -------------------- MASK FILLING TAB --------------------
167
  with tab1:
168
  st.header("Mask Filling")
169
- st.write("Fill in the blanks in Setswana sentences using `[MASK]` token.")
170
 
171
  mask_examples = [
172
- "Ke rata go [MASK] dijo tsa Batswana.",
173
- "Botswana ke naga e e [MASK] mo Afrika Borwa.",
174
- "Bana ba [MASK] sekolo ka Mosupologo.",
175
- "Re tshwanetse go [MASK] tikologo ya rona."
176
  ]
177
 
178
  mask_input = get_input_text("mask", mask_examples)
179
 
180
  if st.button("Fill Masks", key="mask_button") and mask_input.strip():
181
- if "[MASK]" not in mask_input:
182
- st.warning("Please include [MASK] token in your text.")
 
 
 
 
183
  else:
184
  with st.spinner("Filling masks..."):
185
  try:
186
  mask_filler = load_mask_filling_model()
187
- results = mask_filler(mask_input)
 
 
188
 
189
  st.subheader("Predictions")
190
  for i, result in enumerate(results, 1):
@@ -193,6 +227,13 @@ with tab1:
193
 
194
  except Exception as e:
195
  st.error(f"Error: {str(e)}")
 
 
 
 
 
 
 
196
 
197
  # -------------------- POS TAGGING TAB --------------------
198
  with tab2:
 
54
  # -------------------- CACHING FUNCTIONS --------------------
55
  @st.cache_resource
56
  def load_mask_filling_model():
57
+ try:
58
+ tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa")
59
+ model = AutoModelForMaskedLM.from_pretrained("dsfsi/PuoBERTa")
60
+
61
+ # Create pipeline and verify mask token
62
+ pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, top_k=5)
63
+
64
+ # Debug: print mask token for verification
65
+ print(f"Mask token being used: {tokenizer.mask_token}")
66
+
67
+ return pipe
68
+ except Exception as e:
69
+ st.error(f"Failed to load mask filling model: {str(e)}")
70
+ return None
71
 
72
  @st.cache_resource
73
  def load_pos_model():
 
88
  return pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
89
 
90
  # -------------------- UTILITY FUNCTIONS --------------------
91
+
92
def get_correct_mask_token(text, tokenizer):
    """Rewrite common mask placeholders in *text* to the tokenizer's mask token.

    The placeholders recognised are ``[MASK]``, ``<mask>`` and the
    HTML-escaped ``&lt;mask&gt;``. Each occurrence is replaced with
    ``tokenizer.mask_token`` so the text matches what the model expects.

    Args:
        text: The input sentence, possibly containing one of the placeholders.
        tokenizer: Any object exposing a ``mask_token`` string attribute.

    Returns:
        The text with every recognised placeholder replaced by the
        tokenizer's mask token. Text without placeholders is returned
        unchanged.

    Raises:
        ValueError: If ``tokenizer.mask_token`` is ``None`` or empty. Without
            this check ``str.replace`` would raise an opaque ``TypeError`` for
            ``None``, or silently delete the placeholders for ``""``.
    """
    mask_token = tokenizer.mask_token
    if not mask_token:
        raise ValueError(
            "Tokenizer does not define a mask token; cannot normalise mask placeholders."
        )

    # Known spellings of the mask placeholder that users may type or paste.
    for alias in ("[MASK]", "<mask>", "&lt;mask&gt;"):
        # Replacing the token with itself is a no-op, so skip it.
        if alias != mask_token:
            text = text.replace(alias, mask_token)
    return text
102
+
103
+ # Then in your mask filling section, use:
104
+ # corrected_input = get_correct_mask_token(mask_input, mask_filler.tokenizer)
105
+ # results = mask_filler(corrected_input)
106
+
107
+
108
  def merge_entities(output):
109
  """Merge consecutive entities of the same type"""
110
  merged = []
 
194
  # -------------------- MASK FILLING TAB --------------------
195
  with tab1:
196
  st.header("Mask Filling")
197
+ st.write("Fill in the blanks in Setswana sentences using `<mask>` token.")
198
 
199
  mask_examples = [
200
+ "Ke rata go <mask> dijo tsa Batswana.",
201
+ "Botswana ke naga e e <mask> mo Afrika Borwa.",
202
+ "Bana ba <mask> sekolo ka Mosupologo.",
203
+ "Re tshwanetse go <mask> tikologo ya rona."
204
  ]
205
 
206
  mask_input = get_input_text("mask", mask_examples)
207
 
208
  if st.button("Fill Masks", key="mask_button") and mask_input.strip():
209
+ # Check for both mask formats and convert if needed
210
+ if "[MASK]" in mask_input:
211
+ mask_input = mask_input.replace("[MASK]", "<mask>")
212
+ st.info("Converted [MASK] to <mask> format")
213
+ elif "<mask>" not in mask_input:
214
+ st.warning("Please include <mask> token in your text.")
215
  else:
216
  with st.spinner("Filling masks..."):
217
  try:
218
  mask_filler = load_mask_filling_model()
219
+ corrected_input = get_correct_mask_token(mask_input, mask_filler.tokenizer)
220
+ results = mask_filler(corrected_input)
221
+ # results = mask_filler(mask_input)
222
 
223
  st.subheader("Predictions")
224
  for i, result in enumerate(results, 1):
 
227
 
228
  except Exception as e:
229
  st.error(f"Error: {str(e)}")
230
+ # Debug information
231
+ st.info(f"Input text: {mask_input}")
232
+ try:
233
+ mask_filler = load_mask_filling_model()
234
+ st.info(f"Model mask token: {mask_filler.tokenizer.mask_token}")
235
+ except:
236
+ pass
237
 
238
  # -------------------- POS TAGGING TAB --------------------
239
  with tab2: