"""Dia-critic: a Streamlit app that restores Romanian diacritics with an mT5 model."""
import time  # NOTE(review): currently unused in this file.

import streamlit as st
from transformers import MT5ForConditionalGeneration, T5Tokenizer

# Hugging Face Hub checkpoint shared by the model and tokenizer loaders.
MODEL_NAME = 'iliemihai/mt5-base-romanian-diacritics'
# Local directory used to cache downloaded weights and vocabulary.
CACHE_DIR = 'cache/'
# Token budget applied to both the (truncated) input and the generated output.
MAX_TOKENS = 256


@st.cache_resource
def load_model():
    """Load the mT5 diacritics model.

    Decorated with ``st.cache_resource`` so the model is loaded once and reused
    across Streamlit reruns instead of being reloaded on every interaction.

    Returns:
        The ``MT5ForConditionalGeneration`` instance for ``MODEL_NAME``.
    """
    return MT5ForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)


@st.cache_resource
def load_tokenizer():
    """Load the tokenizer matching ``load_model`` (cached like the model).

    Returns:
        The ``T5Tokenizer`` for ``MODEL_NAME``, created with ``legacy=False``.
    """
    return T5Tokenizer.from_pretrained(MODEL_NAME, legacy=False, cache_dir=CACHE_DIR)


def initialize_app():
    """Configure the page (title, icon, About menu) and draw the header."""
    # set_page_config must run before any other Streamlit element is drawn.
    st.set_page_config(
        page_title="Dia-critic",
        page_icon="public/favicon.ico",
        menu_items={
            "About": "### Contact\n ✉️florinbobis@gmail.com",
        },
    )
    st.title("🖋️Dia-critic")
    st.caption("Made with :heart: by NEBO Technologies")


def generate_text(text, max_length=MAX_TOKENS):
    """Run *text* through the diacritics model and return the decoded output.

    Args:
        text: Input text, presumably Romanian written without diacritics.
        max_length: Token budget. The input is truncated to this many tokens
            and at most this many new tokens are generated.

    Returns:
        The model's output decoded to a string, with special tokens removed.
    """
    model = load_model()
    tokenizer = load_tokenizer()
    # NOTE: input beyond `max_length` tokens is silently dropped by truncation.
    inputs = tokenizer(text, max_length=max_length, truncation=True, return_tensors="pt")
    # Without an explicit limit, generate() falls back to the generation
    # config's max_length (20 tokens by default in transformers unless the
    # checkpoint overrides it), which cuts longer corrections short. The
    # corrected text should be roughly as long as the input, so reuse its budget.
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_length,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def main():
    """Render the page and correct the entered text when the button is pressed."""
    initialize_app()

    input_text = st.text_area("Introduceți textul mai jos")
    st.write(f'{len(input_text)} caractere.')

    if not st.button("Corectează"):
        return

    # Whitespace-only input has nothing to correct; treat it as empty.
    if not input_text.strip():
        st.warning("Câmpul este gol!")
        return

    with st.spinner('Sarcină în desfășurare...'):
        res = generate_text(input_text)
    with st.container(border=True):
        st.markdown(res)


if __name__ == "__main__":
    main()