kusht55 committed on
Commit
73836a5
1 Parent(s): 18dcfdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -54
app.py CHANGED
@@ -1,49 +1,44 @@
 
1
  import torch
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
 
 
3
  from IndicTransToolkit import IndicProcessor
4
- import gradio as gr
 
5
 
6
- # Define the model and tokenizer
7
- model_name = "ai4bharat/indictrans2-indic-indic-1B"
8
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
9
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
10
- ip = IndicProcessor(inference=True)
 
 
 
 
 
 
 
11
 
12
- # Define the language codes
13
- LANGUAGES = {
14
- "Assamese (asm_Beng)": "asm_Beng",
15
- "Kashmiri (kas_Arab)": "kas_Arab",
16
- "Punjabi (pan_Guru)": "pan_Guru",
17
- "Bengali (ben_Beng)": "ben_Beng",
18
- "Kashmiri (kas_Deva)": "kas_Deva",
19
- "Sanskrit (san_Deva)": "san_Deva",
20
- "Bodo (brx_Deva)": "brx_Deva",
21
- "Maithili (mai_Deva)": "mai_Deva",
22
- "Santali (sat_Olck)": "sat_Olck",
23
- "Dogri (doi_Deva)": "doi_Deva",
24
- "Malayalam (mal_Mlym)": "mal_Mlym",
25
- "Sindhi (snd_Arab)": "snd_Arab",
26
- "English (eng_Latn)": "eng_Latn",
27
- "Marathi (mar_Deva)": "mar_Deva",
28
- "Sindhi (snd_Deva)": "snd_Deva",
29
- "Konkani (gom_Deva)": "gom_Deva",
30
- "Manipuri (mni_Beng)": "mni_Beng",
31
- "Tamil (tam_Taml)": "tam_Taml",
32
- "Gujarati (guj_Gujr)": "guj_Gujr",
33
- "Manipuri (mni_Mtei)": "mni_Mtei",
34
- "Telugu (tel_Telu)": "tel_Telu",
35
- "Hindi (hin_Deva)": "hin_Deva",
36
- "Nepali (npi_Deva)": "npi_Deva",
37
- "Urdu (urd_Arab)": "urd_Arab",
38
- "Kannada (kan_Knda)": "kan_Knda",
39
- "Odia (ory_Orya)": "ory_Orya",
40
- }
41
-
42
- # Define the translation function
43
- def translate(text, src_lang, tgt_lang):
44
- batch = ip.preprocess_batch([text], src_lang=src_lang, tgt_lang=tgt_lang)
45
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
46
- inputs = tokenizer(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)
 
 
 
 
 
 
 
47
  with torch.no_grad():
48
  generated_tokens = model.generate(
49
  **inputs,
@@ -53,22 +48,42 @@ def translate(text, src_lang, tgt_lang):
53
  num_beams=5,
54
  num_return_sequences=1,
55
  )
 
56
  with tokenizer.as_target_tokenizer():
57
- generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
58
- return generated_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # Create a Gradio interface
61
  with gr.Blocks() as demo:
62
- gr.Markdown("### Indic Translations")
63
- input_text = gr.Textbox(label="Input Text", placeholder="Enter text to translate")
64
- src_lang = gr.Dropdown(label="Source Language", choices=list(LANGUAGES.keys()))
65
- tgt_lang = gr.Dropdown(label="Target Language", choices=list(LANGUAGES.keys()))
66
- translate_button = gr.Button("Translate")
67
- translation_output = gr.Textbox(label="Translation", interactive=False)
 
 
 
68
 
69
- @translate_button.click
70
- def on_translate(text, src_lang, tgt_lang):
71
- translation = translate(text, LANGUAGES[src_lang], LANGUAGES[tgt_lang])
72
- translation_output.value = translation
73
 
74
  demo.launch()
 
1
+ import gradio as gr
2
  import torch
3
+ from transformers import (
4
+ AutoModelForSeq2SeqLM,
5
+ AutoTokenizer,
6
+ )
7
  from IndicTransToolkit import IndicProcessor
8
+ import os
9
+ import subprocess
10
 
11
+ # Function to clone the repository and set up the environment
12
def setup_repo():
    """Clone IndicTrans2 and run its Hugging Face install script, at most once.

    The repository is cloned into ``IndicTrans2`` under the current working
    directory (skipped when that directory already exists), then
    ``install.sh`` is executed from its ``huggingface_interface`` folder.

    The process working directory is left untouched, so calling this function
    repeatedly (it is invoked on every translation request) keeps resolving
    the same checkout instead of a path nested inside the previous one.

    A failing clone or install is logged and the function returns without
    marking the setup as done, so the next call retries.

    Raises:
        FileNotFoundError: if ``git``/``bash`` is not installed, or the
            ``huggingface_interface`` directory is missing from the checkout.
    """
    import logging  # function-scope import keeps this block self-contained

    # Successful setup is remembered on the function object itself.
    if getattr(setup_repo, "_done", False):
        return

    log = logging.getLogger(__name__)
    repo_url = "https://github.com/AI4Bharat/IndicTrans2"
    repo_dir = os.path.abspath("IndicTrans2")

    if not os.path.exists(repo_dir):
        clone = subprocess.run(["git", "clone", repo_url, repo_dir])
        if clone.returncode != 0:
            log.warning("git clone of %s failed (exit %d)", repo_url, clone.returncode)
            return

    # Run the installer with `cwd=` rather than os.chdir so the interpreter's
    # own working directory never changes.  The script is invoked through bash
    # explicitly: `["source", "install.sh"]` with shell=True only executes
    # `source` with no file, so the script was never actually run.
    interface_dir = os.path.join(repo_dir, "huggingface_interface")
    install = subprocess.run(["bash", "install.sh"], cwd=interface_dir)
    if install.returncode != 0:
        log.warning("install.sh failed (exit %d)", install.returncode)
        return

    setup_repo._done = True
23
 
24
+ # Function to process translation
25
+ def translate(input_text, src_lang, tgt_lang):
26
+ setup_repo() # Ensure the repo is set up
27
+ model_name = "ai4bharat/indictrans2-indic-indic-1B"
28
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
29
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
30
+ ip = IndicProcessor(inference=True)
31
+
32
+ batch = ip.preprocess_batch([input_text], src_lang=src_lang, tgt_lang=tgt_lang)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
34
+ inputs = tokenizer(
35
+ batch,
36
+ truncation=True,
37
+ padding="longest",
38
+ return_tensors="pt",
39
+ return_attention_mask=True,
40
+ ).to(DEVICE)
41
+
42
  with torch.no_grad():
43
  generated_tokens = model.generate(
44
  **inputs,
 
48
  num_beams=5,
49
  num_return_sequences=1,
50
  )
51
+
52
  with tokenizer.as_target_tokenizer():
53
+ translation = tokenizer.batch_decode(
54
+ generated_tokens.detach().cpu().tolist(),
55
+ skip_special_tokens=True,
56
+ clean_up_tokenization_spaces=True,
57
+ )[0]
58
+
59
+ return translation
60
+
61
# (display name, IndicTrans2 language code) pairs offered in the UI.
languages = [
    ("Assamese", "asm_Beng"), ("Kashmiri (Arabic)", "kas_Arab"), ("Punjabi", "pan_Guru"),
    ("Bengali", "ben_Beng"), ("Kashmiri (Devanagari)", "kas_Deva"), ("Sanskrit", "san_Deva"),
    ("Bodo", "brx_Deva"), ("Maithili", "mai_Deva"), ("Santali", "sat_Olck"),
    ("Dogri", "doi_Deva"), ("Malayalam", "mal_Mlym"), ("Sindhi (Arabic)", "snd_Arab"),
    ("English", "eng_Latn"), ("Marathi", "mar_Deva"), ("Sindhi (Devanagari)", "snd_Deva"),
    ("Konkani", "gom_Deva"), ("Manipuri (Bengali)", "mni_Beng"), ("Tamil", "tam_Taml"),
    ("Gujarati", "guj_Gujr"), ("Manipuri (Meitei)", "mni_Mtei"), ("Telugu", "tel_Telu"),
    ("Hindi", "hin_Deva"), ("Nepali", "npi_Deva"), ("Urdu", "urd_Arab"),
    ("Kannada", "kan_Knda"), ("Odia", "ory_Orya"),
]

# Lookup from the display name shown in a dropdown to its language code.
_LANGUAGE_CODES = dict(languages)


def _translate_selection(input_text, src_lang, tgt_lang):
    """Translate ``input_text`` between two languages chosen by display name.

    The dropdowns hand back display names (e.g. ``"Hindi"``), while
    ``translate`` forwards its language arguments to the preprocessor, which
    works with codes (e.g. ``"hin_Deva"``).  This wrapper performs that
    mapping before delegating.

    Args:
        input_text: Text typed into the input box.
        src_lang: Display name selected as the source language.
        tgt_lang: Display name selected as the target language.

    Returns:
        The translated text, or an empty string when there is nothing to
        translate.

    Raises:
        gr.Error: if either language has not been selected.
    """
    if not input_text or not input_text.strip():
        return ""
    if src_lang not in _LANGUAGE_CODES or tgt_lang not in _LANGUAGE_CODES:
        raise gr.Error("Please select both a source and a target language.")
    return translate(input_text, _LANGUAGE_CODES[src_lang], _LANGUAGE_CODES[tgt_lang])


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# IndicTrans2 Translation")
    with gr.Row():
        with gr.Column():
            language_names = [name for name, _code in languages]
            input_text = gr.Textbox(label="Input Text")
            src_lang = gr.Dropdown(label="Source Language", choices=language_names, type="value")
            tgt_lang = gr.Dropdown(label="Target Language", choices=language_names, type="value")
            translate_button = gr.Button("Translate")

        output_text = gr.Textbox(label="Translated Output")

    # Run the translation when the button is clicked; selections are mapped
    # from display names to language codes by the wrapper.
    translate_button.click(
        fn=_translate_selection,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=output_text,
    )

demo.launch()