File size: 1,377 Bytes
797458a ac9765d 797458a ac9765d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
@st.cache_resource
def _load_tokenizer_and_model():
    """Load the tokenizer and seq2seq model for the translation checkpoint.

    Streamlit re-executes this script from the top on every widget
    interaction.  Without caching, both ``from_pretrained`` calls would run
    again on each rerun; ``st.cache_resource`` keeps the loaded objects and
    hands back the same instances on later reruns.

    Returns:
        A ``(tokenizer, model)`` tuple loaded from the same checkpoint.
    """
    checkpoint = "EzekielMW/Eksl_dataset"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )


# Module-level names are kept because the rest of the script refers to them.
tokenizer, model = _load_tokenizer_and_model()
def translate(text, source_language, target_language):
    """Translate ``text`` from ``source_language`` to ``target_language``.

    The text is lower-cased before tokenization.  The first input token id
    is overwritten with the id the tokenizer returns for
    ``source_language``, and the id for ``target_language`` is passed to
    ``generate`` as ``forced_bos_token_id``.  Generation uses beam search
    with 5 beams and ``max_length=100``.

    Args:
        text: The string to translate.
        source_language: Language token of the input (the UI offers
            ``'eng'``, ``'swa'`` and ``'ksl'``).
        target_language: Language token of the desired output.

    Returns:
        The decoded first sequence with special tokens skipped; upper-cased
        when ``target_language`` is ``'ksl'``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    source_token_id = tokenizer.convert_tokens_to_ids(source_language)
    target_token_id = tokenizer.convert_tokens_to_ids(target_language)

    encoded = tokenizer(text.lower(), return_tensors="pt").to(device)
    # Replace the first token of the (single) input sequence with the
    # source-language token id.
    encoded["input_ids"][0][0] = source_token_id

    generation_settings = {
        "forced_bos_token_id": target_token_id,
        "max_length": 100,
        "num_beams": 5,
    }
    output_ids = model.to(device).generate(**encoded, **generation_settings)

    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    translation = decoded[0]
    return translation.upper() if target_language == "ksl" else translation
# --- Streamlit page: collect the text and language pair, then translate. ---
st.title('Translation App')

# Language tokens offered for both the source and the target selector.
_language_options = ['eng', 'swa', 'ksl']

# Text input
text = st.text_input('Enter text to translate')
source_language = st.selectbox('Source Language', _language_options)
target_language = st.selectbox('Target Language', _language_options)

if st.button('Translate'):
    # Treat whitespace-only input like empty input so the model is not run
    # on text that has no content.
    if text.strip():
        translation = translate(text, source_language, target_language)
        st.write(f'Translation: {translation}')
    else:
        st.write('Please enter text to translate.')
|