Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import MT5ForConditionalGeneration, T5Tokenizer | |
import time | |
@st.cache_resource
def load_model():
    """Load the mT5 Romanian diacritics-restoration model.

    Streamlit re-executes the whole script on every interaction, and
    ``generate_text`` calls this function on every request. Without caching,
    the mT5-base weights would be reloaded from disk each time.
    ``st.cache_resource`` keeps one shared instance for the server process.

    Returns:
        The pretrained ``MT5ForConditionalGeneration``; files are cached on
        disk under ``cache/``.
    """
    model = MT5ForConditionalGeneration.from_pretrained(
        'iliemihai/mt5-base-romanian-diacritics',
        cache_dir='cache/',
    )
    return model
@st.cache_resource
def load_tokenizer():
    """Load the SentencePiece tokenizer matching the diacritics model.

    Cached with ``st.cache_resource`` so it is built once per server process
    instead of on every script rerun / request.

    Returns:
        The ``T5Tokenizer`` for ``iliemihai/mt5-base-romanian-diacritics``
        (non-legacy tokenization), with files cached under ``cache/``.
    """
    tokenizer = T5Tokenizer.from_pretrained(
        'iliemihai/mt5-base-romanian-diacritics',
        legacy=False,
        cache_dir='cache/',
    )
    return tokenizer
def initialize_app():
    """Configure the Streamlit page and draw the app header.

    Sets the tab title, the favicon and the "About" entry of the app menu,
    then renders the page title and a caption underneath it. Intended to be
    the first thing ``main`` does, since it configures the page.
    """
    page_settings = {
        "page_title": "Dia-critic",
        "page_icon": "public/favicon.ico",
        "menu_items": {"About": "### Contact\n ✉️[email protected]"},
    }
    st.set_page_config(**page_settings)

    # Header shown at the top of the page.
    st.title("🖋️Dia-critic")
    st.caption("Made with :heart: by NEBO Technologies")
def generate_text(text):
    """Restore Romanian diacritics in *text* with the mT5 model.

    Args:
        text: Romanian text, possibly written without diacritics.

    Returns:
        The decoded model output as a string, with special tokens removed.

    Note:
        The tokenizer truncates the input to 256 tokens, so only the first
        ~256 tokens of a longer text are corrected; the rest is dropped.
    """
    model = load_model()
    tokenizer = load_tokenizer()
    inputs = tokenizer(text, max_length=256, truncation=True, return_tensors="pt")
    input_len = inputs["input_ids"].shape[-1]
    # Without an explicit limit, generate() stops at the generation config's
    # max_length (20 tokens by default in transformers, unless the checkpoint
    # overrides it), which cuts off the corrected text. The output should be
    # about as long as the input; leave headroom because words with
    # diacritics may tokenize into more pieces. Generation still stops at EOS.
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=2 * input_len + 8,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def main():
    """Render the Dia-critic page and correct the text on button press.

    Shows a text area with a live character count. When the "Corectează"
    button is pressed, the text is run through the model and the result is
    shown in a bordered container. Empty or whitespace-only input shows a
    warning instead.
    """
    initialize_app()
    input_text = st.text_area("Introduceți textul mai jos")
    st.write(f'{len(input_text)} caractere.')
    if st.button("Corectează"):
        # Whitespace-only input is effectively empty; sending it to the model
        # would only waste a generation call.
        if input_text.strip():
            with st.spinner('Sarcină în desfășurare...'):
                res = generate_text(input_text)
            with st.container(border=True):
                st.markdown(res)
        else:
            st.warning("Câmpul este gol!")
# Run the app only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()