Spaces:
Runtime error
Runtime error
import streamlit as st | |
from transformers import pipeline | |
import pandas as pd | |
modelos_opcao =[ | |
"Narrativa/mbart-large-50-finetuned-opus-en-pt-translation", | |
# "unicamp-dl/translation-en-pt-t5" # desempenho inferior ao MBART (porém, mais rápido) | |
] | |
# Carrega o modelo | |
def carregar_modelo_e_tokenizador_mbart(modelo): | |
# https://huggingface.co/Narrativa/mbart-large-50-finetuned-opus-en-pt-translation | |
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration | |
st.write(f'Carregando modelo {modelo}') | |
tokenizer = MBart50TokenizerFast.from_pretrained(modelo) | |
model = MBartForConditionalGeneration.from_pretrained(modelo).to("cuda") | |
tokenizer.src_lang = 'en_XX' | |
return model, tokenizer | |
# TODO:batch? | |
def traduzir_en_pt(text): | |
inputs = tokenizer(text, return_tensors='pt') | |
input_ids = inputs.input_ids.to('cuda') | |
attention_mask = inputs.attention_mask.to('cuda') | |
output = model.generate(input_ids, attention_mask=attention_mask, forced_bos_token_id=tokenizer.lang_code_to_id['pt_XX']) | |
return tokenizer.decode(output[0], skip_special_tokens=True) | |
################### | |
#### interface #### | |
################### | |
# Cabeçalho | |
st.title('Tradutor de datasets (inglês para português)') | |
# Carrega dataset | |
dataset = st.file_uploader("Carrege o dataset (coluna a ser traduzida deve ser nomeada como 'texto')", type=["csv"]) | |
st.write('Carregando dataset...') | |
if dataset is not None: | |
st.write('🎲 Dataset carregado com sucesso!') | |
dataset = pd.read_csv(dataset) | |
st.write(dataset) | |
modelo_selecionado = st.selectbox('Escolha um modelo', modelos_opcao) | |
if st.button("Carregar modelo escolhido"): | |
tokenizer, model = carregar_modelo_e_tokenizador_mbart(modelo_selecionado) | |
st.write(f"🎰 Modelo {modelo_selecionado} carregado com sucesso! 🔥") | |
qtde_linhas_traduzir = st.slider('Quantidade de linhas a serem traduzidas', 1, len(dataset), 50) | |
if st.button(f"Traduzir {qtde_linhas_traduzir} linhas"): | |
for i in range(qtde_linhas_traduzir): | |
st.write(f'🔡 Traduzindo linha {i+1}...') | |
st.write(f'Texto: {dataset.iloc[i]["texto"]}') | |
texto_traduzido= traduzir_en_pt(dataset.iloc[i]["texto"]) | |
st.write(f'Tradução: {texto_traduzido}') | |
# adiciona traducao em nova coluna dataset | |
dataset["traduzido"]= texto_traduzido | |
st.write("Fim 👍") | |