Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,49 +1,44 @@
|
|
|
|
1 |
import torch
|
2 |
-
from transformers import
|
|
|
|
|
|
|
3 |
from IndicTransToolkit import IndicProcessor
|
4 |
-
import
|
|
|
5 |
|
6 |
-
#
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
#
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
"Maithili (mai_Deva)": "mai_Deva",
|
22 |
-
"Santali (sat_Olck)": "sat_Olck",
|
23 |
-
"Dogri (doi_Deva)": "doi_Deva",
|
24 |
-
"Malayalam (mal_Mlym)": "mal_Mlym",
|
25 |
-
"Sindhi (snd_Arab)": "snd_Arab",
|
26 |
-
"English (eng_Latn)": "eng_Latn",
|
27 |
-
"Marathi (mar_Deva)": "mar_Deva",
|
28 |
-
"Sindhi (snd_Deva)": "snd_Deva",
|
29 |
-
"Konkani (gom_Deva)": "gom_Deva",
|
30 |
-
"Manipuri (mni_Beng)": "mni_Beng",
|
31 |
-
"Tamil (tam_Taml)": "tam_Taml",
|
32 |
-
"Gujarati (guj_Gujr)": "guj_Gujr",
|
33 |
-
"Manipuri (mni_Mtei)": "mni_Mtei",
|
34 |
-
"Telugu (tel_Telu)": "tel_Telu",
|
35 |
-
"Hindi (hin_Deva)": "hin_Deva",
|
36 |
-
"Nepali (npi_Deva)": "npi_Deva",
|
37 |
-
"Urdu (urd_Arab)": "urd_Arab",
|
38 |
-
"Kannada (kan_Knda)": "kan_Knda",
|
39 |
-
"Odia (ory_Orya)": "ory_Orya",
|
40 |
-
}
|
41 |
-
|
42 |
-
# Define the translation function
|
43 |
-
def translate(text, src_lang, tgt_lang):
|
44 |
-
batch = ip.preprocess_batch([text], src_lang=src_lang, tgt_lang=tgt_lang)
|
45 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
46 |
-
inputs = tokenizer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
with torch.no_grad():
|
48 |
generated_tokens = model.generate(
|
49 |
**inputs,
|
@@ -53,22 +48,42 @@ def translate(text, src_lang, tgt_lang):
|
|
53 |
num_beams=5,
|
54 |
num_return_sequences=1,
|
55 |
)
|
|
|
56 |
with tokenizer.as_target_tokenizer():
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
-
#
|
61 |
with gr.Blocks() as demo:
|
62 |
-
gr.Markdown("
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
68 |
|
69 |
-
|
70 |
-
|
71 |
-
translation = translate(text, LANGUAGES[src_lang], LANGUAGES[tgt_lang])
|
72 |
-
translation_output.value = translation
|
73 |
|
74 |
demo.launch()
|
|
|
1 |
+
import gradio as gr
|
2 |
import torch
|
3 |
+
from transformers import (
|
4 |
+
AutoModelForSeq2SeqLM,
|
5 |
+
AutoTokenizer,
|
6 |
+
)
|
7 |
from IndicTransToolkit import IndicProcessor
|
8 |
+
import os
|
9 |
+
import subprocess
|
10 |
|
11 |
+
# Function to clone the repository and set up the environment
|
12 |
+
# One-shot guard: setup_repo() is invoked from the request path, so it must
# be idempotent and cheap after the first successful run.
_REPO_READY = False

def setup_repo():
    """Clone the IndicTrans2 repository and run its install script, once.

    Side effects: clones https://github.com/AI4Bharat/IndicTrans2 into the
    working directory (if absent) and executes huggingface_interface/install.sh.

    Raises
    ------
    subprocess.CalledProcessError
        If the clone or the install script fails (check=True).
    """
    global _REPO_READY
    if _REPO_READY:
        return

    repo_url = "https://github.com/AI4Bharat/IndicTrans2"
    # Absolute path so the existence check is immune to any cwd changes
    # elsewhere in the process (the old code chdir'd and then re-cloned
    # a nested copy on the next call).
    repo_dir = os.path.abspath("IndicTrans2")

    if not os.path.exists(repo_dir):
        subprocess.run(["git", "clone", repo_url, repo_dir], check=True)

    # BUG FIX: `subprocess.run(["source", "install.sh"], shell=True)` hands
    # only "source" to the shell (list + shell=True uses argv[0] as the whole
    # command), and `source` is a shell builtin anyway — the script never ran.
    # Run it through bash in its own directory instead of mutating the
    # process-wide cwd with os.chdir.
    subprocess.run(
        ["bash", "install.sh"],
        cwd=os.path.join(repo_dir, "huggingface_interface"),
        check=True,
    )
    _REPO_READY = True
|
23 |
|
24 |
+
# Function to process translation
|
25 |
+
# Lazily-initialized model assets, cached in a module dict so the expensive
# checkpoint download/load happens once per process, not once per request.
_ASSETS = {}

def _get_assets():
    """Load and cache the tokenizer, model, IndicProcessor and device."""
    if not _ASSETS:
        setup_repo()  # one-time environment setup
        model_name = "ai4bharat/indictrans2-indic-indic-1B"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # BUG FIX: the model must live on the same device as the tokenized
        # inputs. Previously the inputs were moved to CUDA (when available)
        # while the model stayed on the CPU, crashing generate() on GPU hosts.
        model = (
            AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
            .to(device)
            .eval()
        )
        _ASSETS.update(
            tokenizer=tokenizer,
            model=model,
            ip=IndicProcessor(inference=True),
            device=device,
        )
    return _ASSETS

def translate(input_text, src_lang, tgt_lang):
    """Translate *input_text* between two IndicTrans2 language codes.

    Parameters
    ----------
    input_text : str
        The sentence to translate.
    src_lang, tgt_lang : str
        IndicTrans2 language codes such as "hin_Deva" or "eng_Latn".

    Returns
    -------
    str
        The translated sentence.
    """
    assets = _get_assets()
    tokenizer, model = assets["tokenizer"], assets["model"]
    ip, device = assets["ip"], assets["device"]

    batch = ip.preprocess_batch([input_text], src_lang=src_lang, tgt_lang=tgt_lang)
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(device)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            # NOTE(review): the kwargs between **inputs and num_beams were
            # truncated from the diff view; restored from the canonical
            # AI4Bharat inference example — confirm against the deployed app.
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    with tokenizer.as_target_tokenizer():
        translation = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )[0]

    # NOTE(review): the original never calls ip.postprocess_batch(), which the
    # upstream pipeline normally uses to restore native-script entities —
    # left unchanged here to preserve behavior, but worth confirming.
    return translation
|
60 |
+
|
61 |
+
# List of languages with their code names
|
62 |
+
# (display name, IndicTrans2 language code) pairs offered in the UI.
# Order matters: it is the order the dropdowns present to the user.
languages = [
    ("Assamese", "asm_Beng"),
    ("Kashmiri (Arabic)", "kas_Arab"),
    ("Punjabi", "pan_Guru"),
    ("Bengali", "ben_Beng"),
    ("Kashmiri (Devanagari)", "kas_Deva"),
    ("Sanskrit", "san_Deva"),
    ("Bodo", "brx_Deva"),
    ("Maithili", "mai_Deva"),
    ("Santali", "sat_Olck"),
    ("Dogri", "doi_Deva"),
    ("Malayalam", "mal_Mlym"),
    ("Sindhi (Arabic)", "snd_Arab"),
    ("English", "eng_Latn"),
    ("Marathi", "mar_Deva"),
    ("Sindhi (Devanagari)", "snd_Deva"),
    ("Konkani", "gom_Deva"),
    ("Manipuri (Bengali)", "mni_Beng"),
    ("Tamil", "tam_Taml"),
    ("Gujarati", "guj_Gujr"),
    ("Manipuri (Meitei)", "mni_Mtei"),
    ("Telugu", "tel_Telu"),
    ("Hindi", "hin_Deva"),
    ("Nepali", "npi_Deva"),
    ("Urdu", "urd_Arab"),
    ("Kannada", "kan_Knda"),
    ("Odia", "ory_Orya"),
]
|
73 |
|
74 |
+
# Gradio interface
|
75 |
# Gradio interface: text box + source/target language dropdowns -> translation.
with gr.Blocks() as demo:
    gr.Markdown("# IndicTrans2 Translation")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Text")
            src_lang = gr.Dropdown(
                label="Source Language",
                choices=[name for name, _ in languages],
                type="value",
            )
            tgt_lang = gr.Dropdown(
                label="Target Language",
                choices=[name for name, _ in languages],
                type="value",
            )
            translate_button = gr.Button("Translate")

        output_text = gr.Textbox(label="Translated Output")

    # BUG FIX: the dropdowns return display names ("Hindi"), but translate()
    # feeds its arguments straight into the IndicTrans2 preprocessor, which
    # expects language codes ("hin_Deva") — the previous version looked the
    # codes up in a dict before calling. Map name -> code here.
    _name_to_code = dict(languages)

    def _translate_by_name(text, src_name, tgt_name):
        """Adapter between UI display names and IndicTrans2 language codes."""
        return translate(text, _name_to_code[src_name], _name_to_code[tgt_name])

    translate_button.click(
        fn=_translate_by_name,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
    )

demo.launch()
|