import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the tokenizer and the seq2seq translation model from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("EzekielMW/Eksl_dataset")
model = AutoModelForSeq2SeqLM.from_pretrained("EzekielMW/Eksl_dataset")
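
# Optional: the two from_pretrained calls above are re-executed on every Streamlit
# rerun. A minimal sketch of cached loading with st.cache_resource (the helper
# name below is illustrative, not part of the original app):
#
# @st.cache_resource
# def load_model_and_tokenizer():
#     tok = AutoTokenizer.from_pretrained("EzekielMW/Eksl_dataset")
#     mdl = AutoModelForSeq2SeqLM.from_pretrained("EzekielMW/Eksl_dataset")
#     return tok, mdl
#
# tokenizer, model = load_model_and_tokenizer()
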
def translate(text, source_language, target_language):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenize the lower-cased input and move the tensors to the selected device.
    inputs = tokenizer(text.lower(), return_tensors="pt").to(device)

    # Replace the first input token with the source-language code token.
    inputs['input_ids'][0][0] = tokenizer.convert_tokens_to_ids(source_language)

    # Force the decoder to start with the target-language code token and
    # generate the translation with beam search.
    translated_tokens = model.to(device).generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_language),
        max_length=100,
        num_beams=5,
    )
    result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    # Upper-case the output when translating into KSL.
    if target_language == 'ksl':
        result = result.upper()

    return result
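
# Example call (illustrative; 'eng', 'swa' and 'ksl' are the language codes the
# UI below passes in):
#   translate("How are you?", "eng", "ksl")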

st.title('Translation App')

text = st.text_input('Enter text to translate')
source_language = st.selectbox('Source Language', ['eng', 'swa', 'ksl'])
target_language = st.selectbox('Target Language', ['eng', 'swa', 'ksl'])

if st.button('Translate'):
    if text:
        translation = translate(text, source_language, target_language)
        st.write(f'Translation: {translation}')
    else:
        st.write('Please enter text to translate.')
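
# Possible refinement (a sketch, not part of the original app): also warn when
# the source and target languages are identical before translating, e.g.
#
# if source_language == target_language:
#     st.write('Please choose two different languages.')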