# Hugging Face Spaces page status at capture time: Runtime error
import functools
import os

import fasttext
import gradio as gr
import requests
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Heading passed to gr.Interface(title=...).
title = "Community Tab Language Detection & Translation"

# Blurb passed to gr.Interface(description=...). The stray " | |" page-scrape
# residue that trailed each line has been removed so it is not shown to users.
description = """
When comments are created in the community tab, detect the language of the content.
Then, if the detected language is different from the user's language, display an option to translate it.
"""
# Load the translation checkpoint and its tokenizer once at import time;
# translate() hands both to the transformers pipeline on every call.
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

# Device index forwarded to pipeline(device=...): 0 when CUDA is available,
# otherwise -1.
if torch.cuda.is_available():
    device = 0
else:
    device = -1
print(f"Is CUDA available: {torch.cuda.is_available()}")
# Maps the language names offered in the UI to the language codes that
# translate() passes to the translation pipeline as src_lang / tgt_lang.
language_code_map = {
    "English": "eng_Latn",
    "French": "fra_Latn",
    "German": "deu_Latn",
    "Spanish": "spa_Latn",
    "Korean": "kor_Hang",
    "Japanese": "jpn_Jpan",
}
@functools.lru_cache(maxsize=1)
def _load_language_id_model():
    """Load the fastText language-identification model once and cache it.

    The model file ``lid218e.bin`` is expected next to this source file.
    """
    model_path = os.path.join(os.path.dirname(__file__), "lid218e.bin")
    return fasttext.load_model(model_path)


def identify_language(text):
    """Return the language code fastText predicts for *text*.

    Args:
        text: The user-supplied content. It may span several lines.

    Returns:
        The top predicted label with its ``__label__`` prefix removed
        (e.g. ``"eng_Latn"``), or ``""`` when the model yields no prediction.
    """
    label_prefix = "__label__"

    # fastText's predict() handles one line at a time and raises ValueError
    # when the input contains a newline, so flatten multi-line input first.
    single_line = text.replace("\n", " ")

    # e.g. (('__label__eng_Latn',), array([0.81148803]))
    labels, _scores = _load_language_id_model().predict(single_line, k=1)

    # Guard against an empty prediction (e.g. blank input) instead of
    # failing with IndexError on labels[0].
    if not labels:
        return ""

    top_label = labels[0]
    if top_label.startswith(label_prefix):
        top_label = top_label[len(label_prefix):]
    return top_label
def translate(text, src_lang, tgt_lang):
    """Translate *text* from *src_lang* into *tgt_lang*.

    Args:
        text: The content to translate.
        src_lang: Source language name; must be a key of ``language_code_map``.
        tgt_lang: Target language name; must be a key of ``language_code_map``.

    Returns:
        The ``translation_text`` of the first pipeline result.

    Raises:
        KeyError: If either language name is not in ``language_code_map``.
    """
    # Resolve both display names to codes before building the pipeline.
    source_code, target_code = language_code_map[src_lang], language_code_map[tgt_lang]

    translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=source_code,
        tgt_lang=target_code,
        device=device,
    )

    outputs = translator(text)
    first_output = outputs[0]
    return first_output['translation_text']
def query(text, user_lang):
    """Detect the language of *text* and translate it into the user's language.

    Args:
        text: The content entered by the user.
        user_lang: The user's language name; must be a key of
            ``language_code_map`` (the UI radio only offers those).

    Returns:
        A two-item list ``[detected_language_code, message_or_translation]``.
        The second item is the translation, or an explanatory message when
        no translation is needed or the detected language cannot be mapped
        to a supported source language.

    Raises:
        KeyError: If *user_lang* is not in ``language_code_map``.
    """
    detected_code = identify_language(text)

    # identify_language() returns a code (e.g. "eng_Latn") while user_lang is
    # a display name (e.g. "English"), so compare code against code.
    user_code = language_code_map[user_lang]
    if detected_code == user_code:
        return [
            detected_code,
            "User's content language is the same as the language of the input",
        ]

    # translate() expects display names, so map the detected code back.
    name_by_code = {code: name for name, code in language_code_map.items()}
    detected_name = name_by_code.get(detected_code)
    if detected_name is None:
        return [
            detected_code,
            f"Translation from detected language '{detected_code}' is not supported",
        ]

    return [detected_code, translate(text, detected_name, user_lang)]
# Sample rows for gr.Interface(examples=...). Each row supplies one value per
# input component, in order: [user input text, user's content language].
# The previous rows had three columns (text, source, target), which did not
# match the two inputs the interface declares.
examples = [
    ["Hello, world", "French"],
    ["Can I have a cheeseburger?", "German"],
    ["Hasta la vista", "German"],
    ["동경에 휴가를 간다", "Japanese"],
]
# Wire query() to a text box plus a language selector, show the detected
# language and the translation, and start the app.
gr.Interface(
    fn=query,
    inputs=[
        gr.Textbox(label="User Input", lines=3),
        gr.Radio(
            choices=["English", "Spanish", "Korean", "French", "German", "Japanese"],
            label="User's Content Language",
            value="English",
        ),
    ],
    outputs=[
        gr.Textbox(label="Detected Language", lines=1),
        gr.Textbox(label="Translation", lines=3),
    ],
    examples=examples,
    title=title,
    description=description,
).launch()