# HuggingFace Hub file-viewer residue (author Wootang01, "Update app.py",
# commit c649c75, 1.06 kB) — turned into a comment so the file is valid Python.
import streamlit as st
import torch
import sacremoses
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import FSMTForConditionalGeneration, FSMTTokenizer
# Page title shown at the top of the Streamlit app.
st.title("Paraphraser Three -- Back Translation")
# Free-form text box; the sentence entered here is back-translated below.
user_input = st.text_area("Input sentence.")
@st.cache_resource  # requires streamlit >= 1.18 — TODO confirm the Space's pinned version
def load_en2de():
    """Return the English->German translation pipeline (t5-base).

    Cached with ``st.cache_resource`` so the model is downloaded and loaded
    once per server process instead of on every Streamlit script rerun
    (Streamlit re-executes the whole file on each user interaction).
    """
    return pipeline("translation_en_to_de", model="t5-base")
@st.cache_resource  # requires streamlit >= 1.18 — TODO confirm the Space's pinned version
def load_de2en():
    """Return ``(tokenizer, model)`` for German->English back translation.

    Loads the facebook/wmt19-de-en FSMT checkpoint. Cached with
    ``st.cache_resource`` so the large checkpoint is fetched and
    instantiated only once per server process, not on every rerun.
    """
    model_name = "facebook/wmt19-de-en"
    tokenizer = FSMTTokenizer.from_pretrained(model_name)
    model_de_to_en = FSMTForConditionalGeneration.from_pretrained(model_name)
    return tokenizer, model_de_to_en
# Instantiate the two translation models (EN->DE pipeline, DE->EN FSMT).
en2de = load_en2de()
tokenizer_de2en, de2en = load_de2en()

# Only run the round trip when the user actually typed something: on first
# page load the text area is "", and feeding an empty string through both
# models wastes time and can raise inside the pipelines.
if user_input.strip():
    # Step 1: English -> German.
    en_to_de_output = en2de(user_input)
    translated_text = en_to_de_output[0]['translation_text']
    # Step 2: German -> English (the "back translation" that paraphrases).
    input_ids = tokenizer_de2en.encode(translated_text, return_tensors="pt")
    output_ids = de2en.generate(input_ids)[0]
    augmented_text = tokenizer_de2en.decode(output_ids, skip_special_tokens=True)
    st.write("Paraphrased text using back translation: ", augmented_text)