# Spaces: Runtime error
import pandas as pd | |
from datasets import load_metric | |
import os | |
import streamlit as st | |
from transformers import MarianMTModel, MarianTokenizer | |
def downloading_model():
    """Load the data, metric, tokenizer and both translation models.

    Reads ``sentence_pair.json`` into a DataFrame, loads the ``sacrebleu``
    metric, and loads one tokenizer plus two Marian models: the
    ``Helsinki-NLP/opus-mt-id-en`` checkpoint and the
    ``wolfrage89/annual_report_translation_id_en`` checkpoint. The single
    tokenizer comes from the Helsinki checkpoint and is returned for use
    with both models.

    Returns:
        tuple: ``(sentence_pair_df, metric, tokenizer, original_model,
        finetuned_model)`` in that order.
    """
    # NOTE(review): this is invoked from the script's top level and nothing
    # here caches the result; presumably every Streamlit rerun repeats these
    # loads -- consider Streamlit's caching decorators. TODO confirm.
    base_checkpoint = "Helsinki-NLP/opus-mt-id-en"
    finetuned_checkpoint = "wolfrage89/annual_report_translation_id_en"

    pairs = pd.read_json("sentence_pair.json")
    bleu = load_metric("sacrebleu")
    shared_tokenizer = MarianTokenizer.from_pretrained(base_checkpoint)
    base_model = MarianMTModel.from_pretrained(base_checkpoint)
    tuned_model = MarianMTModel.from_pretrained(finetuned_checkpoint)

    return pairs, bleu, shared_tokenizer, base_model, tuned_model
def get_translation(model, tokenizer, text, *, max_length=104):
    """Translate one piece of text with ``model`` and return the decoded string.

    Args:
        model: Object exposing ``generate(**encoded_inputs)``; the script
            passes ``MarianMTModel`` instances.
        tokenizer: Callable that encodes a list of texts and also exposes
            ``decode(tokens, skip_special_tokens=...)``; the script passes a
            ``MarianTokenizer``.
        text: The source sentence to translate.
        max_length: Value forwarded to the tokenizer as ``max_length``
            (together with ``truncation=True``). Defaults to 104, the value
            that used to be hard-coded, so existing callers are unaffected.

    Returns:
        The result of ``tokenizer.decode`` on the first generated sequence,
        with special tokens skipped.
    """
    encoded_inputs = tokenizer(
        [text],
        return_tensors='pt',
        max_length=max_length,
        truncation=True,
    )
    # Only one text was encoded, so only the first generated sequence is used.
    generated = model.generate(**encoded_inputs)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
def get_bleu_score(translated_sentence, reference_sentence, metric):
    """Score a single translation against a single reference.

    Args:
        translated_sentence: The prediction handed to ``metric.add``.
        reference_sentence: The reference; it is wrapped in a one-element
            list before being handed to ``metric.add``.
        metric: Object exposing ``add(prediction=..., reference=...)`` and
            ``compute()``; the script passes the loaded ``sacrebleu`` metric.

    Returns:
        The value stored under ``'score'`` in the mapping returned by
        ``metric.compute()``.
    """
    references = [reference_sentence]
    metric.add(prediction=translated_sentence, reference=references)
    result = metric.compute()
    return result['score']
# Initialization: make sure every session-state slot the page reads exists.
for state_key in (
    "bahasa_input",
    "ideal_translation",
    "original_translation",
    "finetuned_translation",
):
    if state_key not in st.session_state:
        st.session_state[state_key] = ""

sentence_pair_df, metric, tokenizer, original_model, finetuned_model = downloading_model()

# Sidebar: title plus the two action buttons, each followed by a caption.
sidebar = st.sidebar
sidebar.title("Bahasa to English Translation (Finance Domain)")
sidebar.markdown("---")
random_button = sidebar.button("Random")
sidebar.write("Randomly generates a bahasa sentence")
sidebar.markdown("---")
translate_button = sidebar.button("Translate", help="translate bahasa to english")
sidebar.write("Translate!")

# "Random": pick one sentence pair and drop any translations already shown.
if random_button:
    picked_pair = sentence_pair_df.sample(1)
    st.session_state['bahasa_input'] = picked_pair['bahasa'].item()
    st.session_state['ideal_translation'] = picked_pair['english'].item()
    for stale_key in ('original_translation', 'finetuned_translation'):
        st.session_state[stale_key] = ""

# "Translate": run both models on the stored input and score each output
# against the stored ideal translation. Both score variables are assigned on
# every path here because the layout below prints them whenever the button
# was pressed.
if translate_button:
    source_text = st.session_state['bahasa_input']
    if len(source_text) == 0:
        # Nothing to translate: blank the three output fields.
        for cleared_key in (
            'original_translation',
            'finetuned_translation',
            'ideal_translation',
        ):
            st.session_state[cleared_key] = ""
        original_bleu_score = 0
        finetuned_bleu_score = 0
    else:
        original_text = get_translation(original_model, tokenizer, source_text)
        st.session_state['original_translation'] = original_text
        finetuned_text = get_translation(finetuned_model, tokenizer, source_text)
        st.session_state['finetuned_translation'] = finetuned_text

        target_text = st.session_state['ideal_translation']
        original_bleu_score = get_bleu_score(original_text, target_text, metric)
        finetuned_bleu_score = get_bleu_score(finetuned_text, target_text, metric)

# Main area: two columns, input/pretrained on the left, target/finetuned on
# the right.
with st.container():
    left_col, right_col = st.columns(2)

    with left_col:
        # The text area's current value is written back to session state,
        # which is where the translate handler above reads the input from.
        edited_input = st.text_area(
            "Bahasa (Input text here)",
            value=st.session_state['bahasa_input'],
            height=200,
        )
        st.session_state['bahasa_input'] = edited_input
        st.text_area(
            "Pretrained model Translation (Helsinki_id_en)",
            value=st.session_state['original_translation'],
            height=200,
        )
        if translate_button:
            st.write("Bleu score: ", original_bleu_score)

    with right_col:
        st.text_area(
            "Ideal translation (Target)",
            value=st.session_state['ideal_translation'],
            height=200,
        )
        st.text_area(
            "Finetuned translation (Finetuned on annual report)",
            value=st.session_state['finetuned_translation'],
            height=200,
        )
        if translate_button:
            st.write("Bleu Score: ", finetuned_bleu_score)