File size: 1,377 Bytes
797458a
ac9765d
 
797458a
ac9765d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Hugging Face Hub repository holding the fine-tuned seq2seq model + tokenizer.
MODEL_NAME = "EzekielMW/Eksl_dataset"


@st.cache_resource
def _load_model_and_tokenizer(model_name=MODEL_NAME):
    """Load the seq2seq model and its tokenizer from the Hugging Face Hub.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the model would be re-loaded on every rerun.
    ``st.cache_resource`` keeps a single shared instance per ``model_name``
    across reruns and sessions.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return loaded_tokenizer, loaded_model


# Module-level names read by ``translate``.
tokenizer, model = _load_model_and_tokenizer()

def _language_token_id(language_code):
    """Return the vocabulary id of the language-tag token *language_code*.

    ``tokenizer.convert_tokens_to_ids`` does not fail on unknown tokens; it
    returns the tokenizer's unknown-token id instead, which would make the
    model translate from/into an ``<unk>`` "language" without any error.

    Raises:
        ValueError: if *language_code* is not a token in the vocabulary.
    """
    token_id = tokenizer.convert_tokens_to_ids(language_code)
    if token_id is None or token_id == tokenizer.unk_token_id:
        raise ValueError(f"Unknown language code: {language_code!r}")
    return token_id


def translate(text, source_language, target_language, *, max_length=100, num_beams=5):
    """Translate *text* from *source_language* into *target_language*.

    Args:
        text: Input sentence. It is lower-cased before tokenization.
        source_language: Language-tag token of the input (the UI offers
            ``'eng'``, ``'swa'`` and ``'ksl'``).
        target_language: Language-tag token of the desired output.
        max_length: Maximum length, in tokens, of the generated sequence.
        num_beams: Beam width passed to ``model.generate``.

    Returns:
        The decoded translation as a string; upper-cased when the target
        language is ``'ksl'``.

    Raises:
        ValueError: if either language code is not in the tokenizer vocabulary.
    """
    # Resolve (and validate) both language tags before doing any model work.
    source_token_id = _language_token_id(source_language)
    target_token_id = _language_token_id(target_language)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = tokenizer(text.lower(), return_tensors="pt").to(device)
    # Overwrite the first encoder input token with the source-language tag.
    # NOTE(review): presumably the model was fine-tuned with the language tag
    # at position 0 of the encoder input - confirm against the training code.
    inputs['input_ids'][0][0] = source_token_id
    translated_tokens = model.to(device).generate(
        **inputs,
        # Force the first generated token to be the target-language tag.
        forced_bos_token_id=target_token_id,
        max_length=max_length,
        num_beams=num_beams,
    )
    result = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    # NOTE(review): KSL output appears to be expected as upper-case glosses.
    if target_language == 'ksl':
        result = result.upper()

    return result

st.title('Translation App')

# Language tags offered in both selectboxes; translate() looks each one up
# as a token in the tokenizer's vocabulary.
LANGUAGE_CODES = ['eng', 'swa', 'ksl']

# Inputs: the text to translate plus source and target languages.
text = st.text_input('Enter text to translate')
source_language = st.selectbox('Source Language', LANGUAGE_CODES)
target_language = st.selectbox('Target Language', LANGUAGE_CODES)

if st.button('Translate'):
    # Treat whitespace-only input as empty so the model is never asked to
    # translate a blank prompt.
    cleaned_text = text.strip()
    if cleaned_text:
        translation = translate(cleaned_text, source_language, target_language)
        st.write(f'Translation: {translation}')
    else:
        st.write('Please enter text to translate.')