NLLB-Translator / app.py
mrm8488's picture
Add device support
45d922e
raw
history blame
1.43 kB
from functools import lru_cache

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from langs import LANGS
from ui import title, description
# Hugging Face pipeline task name and the NLLB-200 checkpoint this app serves.
TASK = "translation"
CKPT = "facebook/nllb-200-distilled-600M"

# Weights and tokenizer are loaded once at import time; every translation
# request reuses these same objects instead of reloading them.
model = AutoModelForSeq2SeqLM.from_pretrained(CKPT)
tokenizer = AutoTokenizer.from_pretrained(CKPT)

# transformers pipelines take an integer device: 0 is the first CUDA GPU,
# -1 means run on CPU.
device = -1
if torch.cuda.is_available():
    device = 0
@lru_cache(maxsize=1)
def _get_translation_pipeline():
    """Build the translation pipeline on first use and reuse it afterwards."""
    # Languages and max_length are passed per call in ``translate``, so one
    # pipeline bound to the preloaded model/tokenizer serves every request.
    return pipeline(TASK, model=model, tokenizer=tokenizer, device=device)


def translate(text, src_lang, tgt_lang, max_length=400):
    """
    Translate the text from source lang to target lang.

    Args:
        text: Text to translate.
        src_lang: Source language code, one of ``LANGS`` (presumably NLLB
            codes such as ``"eng_Latn"`` -- TODO confirm against langs.py).
        tgt_lang: Target language code, in the same format as ``src_lang``.
        max_length: Maximum length of the generated translation, forwarded
            to generation. Cast to ``int`` because Gradio sliders may deliver
            the value as a float.

    Returns:
        The translated text, or ``""`` when ``text`` is empty or only
        whitespace.
    """
    # Blank input has nothing to translate; skip model inference entirely.
    if not text or not text.strip():
        return ""
    translator = _get_translation_pipeline()
    result = translator(
        text,
        src_lang=src_lang,
        tgt_lang=tgt_lang,
        max_length=int(max_length),
    )
    # A single input string yields a one-element list of result dicts.
    return result[0]["translation_text"]
# Gradio UI: input components are mapped, in order, onto the positional
# parameters of ``translate`` (text, src_lang, tgt_lang, max_length).
# Components come from the top-level ``gr`` namespace: the legacy
# ``gr.inputs`` wrappers are deprecated (removed in Gradio 4), and
# ``gr.inputs.Slider`` expects ``default=`` rather than ``value=``.
demo = gr.Interface(
    translate,
    [
        gr.Textbox(label="Text"),
        gr.Dropdown(label="Source Language", choices=LANGS),
        gr.Dropdown(label="Target Language", choices=LANGS),
        gr.Slider(label="Max Length", minimum=8, maximum=512, value=400, step=8),
    ],
    ["text"],
    cache_examples=False,
    title=title,
    description=description,
)
demo.launch()