Spaces:
Sleeping
Sleeping
import html

import streamlit as st
from transformers import pipeline
# set_page_config goes first: older Streamlit releases require it to be the
# first Streamlit call, and the cached loader below may draw a spinner.
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline for dsfsi/zabantu-nso-120m.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` makes the model load once per server process and
    returns the same pipeline object to every later rerun.
    NOTE(review): the cached pipeline is shared across sessions/threads —
    confirm concurrent calls into it are acceptable.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-nso-120m')


unmasker = _load_unmasker()
def fill_mask(sentences, predictor=None):
    """Run the fill-mask model over every sentence that has one ``<mask>`` token.

    Args:
        sentences: Iterable of input strings, one sentence each.
        predictor: Optional callable used instead of the module-level
            ``unmasker`` pipeline; it is called with a single sentence string.
            Defaults to ``unmasker``.

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each sentence that
        was sent to the model to the model's output for it. ``warnings`` is a
        list of user-facing messages for non-blank sentences that were
        skipped. Blank or whitespace-only entries are ignored silently.
    """
    if predictor is None:
        predictor = unmasker

    results = {}
    warnings = []
    for sentence in sentences:
        if not sentence.strip():
            # Blank lines (e.g. from a trailing newline) are not sentences.
            continue
        mask_count = sentence.count("<mask>")
        if mask_count == 0:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
        elif mask_count > 1:
            # With several masks the transformers fill-mask pipeline returns a
            # nested list (one candidate list per mask), which the flat-list
            # rendering code cannot display.
            warnings.append(
                f"Warning: Only one <mask> token per sentence is supported: {sentence}"
            )
        else:
            results[sentence] = predictor(sentence)
    return results, warnings
def replace_mask(sentence, predicted_word):
    """Return *sentence* with every ``<mask>`` token swapped for *predicted_word*.

    The inserted word is wrapped in ``**`` so it renders bold as Markdown.
    """
    bold_word = f"**{predicted_word}**"
    # Splitting on the token and re-joining with the replacement is
    # equivalent to str.replace for a non-empty token.
    return bold_word.join(sentence.split("<mask>"))
st.title("Fill Mask | Zabantu-nso-120m")
st.write("")  # blank spacer element under the title

col1, col2 = st.columns(2)

# Seed per-session defaults on the first run only; later reruns keep whatever
# is already stored in st.session_state.
for _state_key, _default in (('text_input', ""), ('warnings', [])):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default
with col1:
    with st.container(border=True):
        st.markdown("Input :clipboard:")
        sample_sentence = "bašomedi ba polase ya dinamune ya zebediela citrus ba hlomile magato a <mask> malebana le go se sepetšwe botse ga dilo ka polaseng eo."
        text_input = st.text_area(
            "Enter sentences with <mask> token:",
            value=st.session_state['text_input']
        )
        # One sentence per line. splitlines() does not yield the empty entry
        # that split("\n") produces for a trailing newline.
        input_sentences = text_input.splitlines()

        button1, button2, _ = st.columns([2, 2, 4])
        # `result` / `warnings` are bound only on the rerun in which a button
        # was clicked; the output column checks for `result` before rendering.
        with button1:
            if st.button("Test Example"):
                result, warnings = fill_mask(sample_sentence.splitlines())
                # Store like Submit does, so warnings from an earlier Submit
                # are not left on screen next to the example's results.
                st.session_state['warnings'] = warnings
        with button2:
            if st.button("Submit"):
                result, warnings = fill_mask(input_sentences)
                st.session_state['warnings'] = warnings

        # Warnings live in session state, so they persist across reruns until
        # the next button click replaces them.
        for warning in st.session_state['warnings']:
            st.warning(warning)

        st.markdown("Example")
        st.code(sample_sentence, wrap_lines=True)
with col2:
    with st.container(border=True):
        st.markdown("Output :bar_chart:")
        # `result` is bound only on a rerun where an input button was clicked;
        # at module level locals() is the script's global namespace.
        if 'result' in locals() and result:
            for sentence, predictions in result.items():
                # Assumes one <mask> per sentence, so `predictions` is a flat
                # list of dicts with 'token_str' and 'score'; index 0 is
                # treated as the top prediction.
                for prediction in predictions:
                    score = prediction['score'] * 100
                    # The token is model output placed into raw HTML, so
                    # escape it before rendering with unsafe_allow_html.
                    safe_word = html.escape(prediction['token_str'])
                    st.markdown(
                        '<div class="bar">'
                        f'<div class="bar-fill" style="width: {score}%;"></div>'
                        '</div>'
                        '<div class="container">'
                        f'<div style="align-items: left;">{safe_word}</div>'
                        f'<div style="align-items: center;">{score:.2f}%</div>'
                        '</div>',
                        unsafe_allow_html=True,
                    )
                if predictions:
                    # Show the top-ranked completion right under its own bars,
                    # so each group of bars is labelled by its sentence.
                    full_sentence = replace_mask(sentence, predictions[0]['token_str'])
                    st.write(f"**Sentence:** {full_sentence}")
# Page-wide styles injected as a raw <style> element. The string is a raw
# literal so the `\:` selector escapes reach the browser verbatim instead of
# being parsed (and warned about) as invalid Python escape sequences.
# .container / .bar / .bar-fill style the prediction bars in the output column.
# NOTE(review): the .gr-button-primary and Tailwind-style (*-orange-*) rules
# look like leftovers from another UI framework; confirm the page actually
# emits these classes, otherwise they can be dropped.
css = r"""
<style>
footer {display:none !important;}
.gr-button-primary {
    z-index: 14;
    height: 43px;
    width: 130px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important;
    background: none rgb(17, 20, 45) !important;
    border: none !important;
    text-align: center !important;
    font-family: Poppins !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 12px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: none !important;
}
.gr-button-primary:hover{
    z-index: 14;
    height: 43px;
    width: 130px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important;
    background: none rgb(66, 133, 244) !important;
    border: none !important;
    text-align: center !important;
    font-family: Poppins !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 12px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
}
.hover\:bg-orange-50:hover {
    --tw-bg-opacity: 1 !important;
    background-color: rgb(229,225,255) !important;
}
.to-orange-200 {
    --tw-gradient-to: rgb(37 56 133 / 37%) !important;
}
.from-orange-400 {
    --tw-gradient-from: rgb(17, 20, 45) !important;
    --tw-gradient-to: rgb(255 150 51 / 0);
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group-hover\:from-orange-500{
    --tw-gradient-from:rgb(17, 20, 45) !important;
    --tw-gradient-to: rgb(37 56 133 / 37%);
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group:hover .group-hover\:text-orange-500{
    --tw-text-opacity: 1 !important;
    color:rgb(37 56 133 / var(--tw-text-opacity)) !important;
}
.container {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 5px;
    width: 100%;
}
.bar {
    /* width: 70%; */
    background-color: #e6e6e6;
    border-radius: 12px;
    overflow: hidden;
    margin-right: 10px;
    height: 5px;
}
.bar-fill {
    background-color: #17152e;
    height: 100%;
    border-radius: 12px;
}
</style>
"""
st.markdown(css, unsafe_allow_html=True)