GEETHANAYAGI commited on
Commit
f9d7028
·
verified ·
1 Parent(s): 5e885fd

Upload 79 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. IndicTrans2/.gitignore +148 -0
  2. IndicTrans2/LICENSE +21 -0
  3. IndicTrans2/README.md +523 -0
  4. IndicTrans2/apply_sentence_piece.sh +48 -0
  5. IndicTrans2/baseline_eval/azure_translate.py +183 -0
  6. IndicTrans2/baseline_eval/google_translate.py +129 -0
  7. IndicTrans2/baseline_eval/m2m100_inference.py +148 -0
  8. IndicTrans2/baseline_eval/mbart_inference.py +159 -0
  9. IndicTrans2/baseline_eval/nllb_moe_cpu_inference.py +157 -0
  10. IndicTrans2/compute_comet_score.sh +84 -0
  11. IndicTrans2/compute_metrics.sh +29 -0
  12. IndicTrans2/compute_metrics_significance.sh +66 -0
  13. IndicTrans2/eval.sh +54 -0
  14. IndicTrans2/eval_rev.sh +55 -0
  15. IndicTrans2/finetune.sh +54 -0
  16. IndicTrans2/huggingface_interface/.gitignore +1 -0
  17. IndicTrans2/huggingface_interface/README.md +119 -0
  18. IndicTrans2/huggingface_interface/colab_inference.ipynb +458 -0
  19. IndicTrans2/huggingface_interface/configuration_indictrans.py +309 -0
  20. IndicTrans2/huggingface_interface/convert_indictrans_checkpoint_to_pytorch.py +107 -0
  21. IndicTrans2/huggingface_interface/example.py +275 -0
  22. IndicTrans2/huggingface_interface/install.sh +49 -0
  23. IndicTrans2/huggingface_interface/modeling_indictrans.py +1801 -0
  24. IndicTrans2/huggingface_interface/train_lora.py +355 -0
  25. IndicTrans2/huggingface_interface/train_lora.sh +35 -0
  26. IndicTrans2/inference/__init__.py +0 -0
  27. IndicTrans2/inference/custom_interactive.py +304 -0
  28. IndicTrans2/inference/download.py +5 -0
  29. IndicTrans2/inference/engine.py +472 -0
  30. IndicTrans2/inference/flores_codes_map_indic.py +83 -0
  31. IndicTrans2/inference/indic_num_map.py +117 -0
  32. IndicTrans2/inference/model_configs/__init__.py +1 -0
  33. IndicTrans2/inference/model_configs/custom_transformer.py +82 -0
  34. IndicTrans2/inference/normalize-punctuation.perl +90 -0
  35. IndicTrans2/inference/normalize_punctuation.py +60 -0
  36. IndicTrans2/inference/normalize_punctuation.sh +33 -0
  37. IndicTrans2/inference/normalize_regex_inference.py +105 -0
  38. IndicTrans2/inference/requirements.txt +11 -0
  39. IndicTrans2/inference/triton_server/Dockerfile +25 -0
  40. IndicTrans2/inference/triton_server/README.md +22 -0
  41. IndicTrans2/inference/triton_server/azure_ml/README.md +56 -0
  42. IndicTrans2/inference/triton_server/azure_ml/deployment.yml +13 -0
  43. IndicTrans2/inference/triton_server/azure_ml/endpoint.yml +3 -0
  44. IndicTrans2/inference/triton_server/azure_ml/environment.yml +14 -0
  45. IndicTrans2/inference/triton_server/azure_ml/model.yml +5 -0
  46. IndicTrans2/inference/triton_server/client.py +55 -0
  47. IndicTrans2/inference/triton_server/dhruva/ulca_model.json +0 -0
  48. IndicTrans2/inference/triton_server/triton_repo/nmt/1/model.py +167 -0
  49. IndicTrans2/inference/triton_server/triton_repo/nmt/config.pbtxt +32 -0
  50. IndicTrans2/inference/utils.map_token_lang.tsv +26 -0
IndicTrans2/.gitignore ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ignore libs and data folder we use
2
+ indic_nlp_library
3
+ indic_nlp_resources
4
+ fairseq
5
+ devtest
6
+ checkpoints
7
+ eval_benchmarks
8
+
9
+ # Byte-compiled / optimized / DLL files
10
+ __pycache__/
11
+ *.py[cod]
12
+ *$py.class
13
+
14
+ # C extensions
15
+ *.so
16
+
17
+ # Distribution / packaging
18
+ .Python
19
+ build/
20
+ develop-eggs/
21
+ dist/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ var/
30
+ wheels/
31
+ share/python-wheels/
32
+ *.egg-info/
33
+ .installed.cfg
34
+ *.egg
35
+ MANIFEST
36
+
37
+ # PyInstaller
38
+ # Usually these files are written by a python script from a template
39
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
40
+ *.manifest
41
+ *.spec
42
+
43
+ # Installer logs
44
+ pip-log.txt
45
+ pip-delete-this-directory.txt
46
+
47
+ # Unit test / coverage reports
48
+ htmlcov/
49
+ .tox/
50
+ .nox/
51
+ .coverage
52
+ .coverage.*
53
+ .cache
54
+ nosetests.xml
55
+ coverage.xml
56
+ *.cover
57
+ *.py,cover
58
+ .hypothesis/
59
+ .pytest_cache/
60
+ cover/
61
+
62
+ # Translations
63
+ *.mo
64
+ *.pot
65
+
66
+ # Django stuff:
67
+ *.log
68
+ local_settings.py
69
+ db.sqlite3
70
+ db.sqlite3-journal
71
+
72
+ # Flask stuff:
73
+ instance/
74
+ .webassets-cache
75
+
76
+ # Scrapy stuff:
77
+ .scrapy
78
+
79
+ # Sphinx documentation
80
+ docs/_build/
81
+
82
+ # PyBuilder
83
+ .pybuilder/
84
+ target/
85
+
86
+ # Jupyter Notebook
87
+ .ipynb_checkpoints
88
+
89
+ # IPython
90
+ profile_default/
91
+ ipython_config.py
92
+
93
+ # pyenv
94
+ # For a library or package, you might want to ignore these files since the code is
95
+ # intended to run in multiple environments; otherwise, check them in:
96
+ # .python-version
97
+
98
+ # pipenv
99
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
101
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
102
+ # install all needed dependencies.
103
+ #Pipfile.lock
104
+
105
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106
+ __pypackages__/
107
+
108
+ # Celery stuff
109
+ celerybeat-schedule
110
+ celerybeat.pid
111
+
112
+ # SageMath parsed files
113
+ *.sage.py
114
+
115
+ # Environments
116
+ .env
117
+ .venv
118
+ env/
119
+ venv/
120
+ ENV/
121
+ env.bak/
122
+ venv.bak/
123
+
124
+ # Spyder project settings
125
+ .spyderproject
126
+ .spyproject
127
+
128
+ # Rope project settings
129
+ .ropeproject
130
+
131
+ # mkdocs documentation
132
+ /site
133
+
134
+ # mypy
135
+ .mypy_cache/
136
+ .dmypy.json
137
+ dmypy.json
138
+
139
+ # Pyre type checker
140
+ .pyre/
141
+
142
+ # pytype static type analyzer
143
+ .pytype/
144
+
145
+ # Cython debug symbols
146
+ cython_debug/
147
+
148
+ .DS_Store
IndicTrans2/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) AI4Bharat.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE
IndicTrans2/README.md ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # IndicTrans2
2
+
3
+ [📜 Paper](https://arxiv.org/abs/2305.16307) | [🌐 Website](https://ai4bharat.iitm.ac.in/indic-trans2) | [▶️ Demo](https://models.ai4bharat.org/#/nmt/v2) | [🤗 HF Interface](https://github.com/AI4Bharat/IndicTrans2/tree/main/huggingface_interface) | [![colab link](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI4Bharat/IndicTrans2/blob/main/huggingface_interface/colab_inference.ipynb)
4
+
5
+ IndicTrans2 is the first open-source transformer-based multilingual NMT model that supports high-quality translations across all the 22 scheduled Indic languages — including multiple scripts for low-resouce languages like Kashmiri, Manipuri and Sindhi. It adopts script unification wherever feasible to leverage transfer learning by lexical sharing between languages. Overall, the model supports five scripts Perso-Arabic (Kashmiri, Sindhi, Urdu), Ol Chiki (Santali), Meitei (Manipuri), Latin (English), and Devanagari (used for all the remaining languages).
6
+
7
+ We open-souce all our training dataset (BPCC), back-translation data (BPCC-BT), final IndicTrans2 models, evaluation benchmarks (IN22, which includes IN22-Gen and IN22-Conv) and training and inference scripts for easier use and adoption within the research community. We hope that this will foster even more research in low-resource Indic languages, leading to further improvements in the quality of low-resource translation through contributions from the research community.
8
+
9
+ This code repository contains instructions for downloading the artifacts associated with IndicTrans2, as well as the code for training/fine-tuning the multilingual NMT models.
10
+
11
+ Here is the list of languages supported by the IndicTrans2 models:
12
+
13
+ <table>
14
+ <tbody>
15
+ <tr>
16
+ <td>Assamese (asm_Beng)</td>
17
+ <td>Kashmiri (Arabic) (kas_Arab)</td>
18
+ <td>Punjabi (pan_Guru)</td>
19
+ </tr>
20
+ <tr>
21
+ <td>Bengali (ben_Beng)</td>
22
+ <td>Kashmiri (Devanagari) (kas_Deva)</td>
23
+ <td>Sanskrit (san_Deva)</td>
24
+ </tr>
25
+ <tr>
26
+ <td>Bodo (brx_Deva)</td>
27
+ <td>Maithili (mai_Deva)</td>
28
+ <td>Santali (sat_Olck)</td>
29
+ </tr>
30
+ <tr>
31
+ <td>Dogri (doi_Deva)</td>
32
+ <td>Malayalam (mal_Mlym)</td>
33
+ <td>Sindhi (Arabic) (snd_Arab)</td>
34
+ </tr>
35
+ <tr>
36
+ <td>English (eng_Latn)</td>
37
+ <td>Marathi (mar_Deva)</td>
38
+ <td>Sindhi (Devanagari) (snd_Deva)</td>
39
+ </tr>
40
+ <tr>
41
+ <td>Konkani (gom_Deva)</td>
42
+ <td>Manipuri (Bengali) (mni_Beng)</td>
43
+ <td>Tamil (tam_Taml)</td>
44
+ </tr>
45
+ <tr>
46
+ <td>Gujarati (guj_Gujr)</td>
47
+ <td>Manipuri (Meitei) (mni_Mtei)</td>
48
+ <td>Telugu (tel_Telu)</td>
49
+ </tr>
50
+ <tr>
51
+ <td>Hindi (hin_Deva)</td>
52
+ <td>Nepali (npi_Deva)</td>
53
+ <td>Urdu (urd_Arab)</td>
54
+ </tr>
55
+ <tr>
56
+ <td>Kannada (kan_Knda)</td>
57
+ <td>Odia (ory_Orya)</td>
58
+ <td></td>
59
+ </tr>
60
+ </tbody>
61
+ </table>
62
+
63
+ ## Updates
64
+
65
+ - 🚨 Dec 30, 2023 - Migrated IndicTrans2 tokenizer for HF compatible IndicTrans2 models to [IndicTransTokenizer](https://github.com/VarunGumma/IndicTransTokenizer) and will be maintained separately there from now onwards. Add LoRA fine-tuning scripts for our IndicTrans2 models in [huggingface_interface](https://github.com/AI4Bharat/IndicTrans2/tree/main/huggingface_interface).
66
+ - 🚨 Dec 1, 2023 - Release of Indic-Indic model and corresponding distilled variants for each base model. Please refer to the [Download section](https://github.com/AI4Bharat/IndicTrans2#multilingual-translation-models) for the checkpoints.
67
+ - 🚨 Sep 9, 2023 - Added HF compatible IndicTrans2 models. Please refer to the [README](https://github.com/AI4Bharat/IndicTrans2/tree/main/huggingface_interface) for detailed example usage.
68
+
69
+ ## Tables of Contents
70
+
71
+ - [Download Models and Other Artifacts](#download-models-and-other-artifacts)
72
+ - [Multilingual Translation Models](#multilingual-translation-models)
73
+ - [Training Data](#training-data)
74
+ - [Evaluation Data](#evaluation-data)
75
+ - [Installation](#installation)
76
+ - [Data](#data)
77
+ - [Training](#training)
78
+ - [Evaluation](#evaluation)
79
+ - [Preparing Data for Training](#preparing-data-for-training)
80
+ - [Using our SPM model and Fairseq dictionary](#using-our-spm-model-and-fairseq-dictionary)
81
+ - [Training your own SPM models and learning Fairseq dictionary](#training-your-own-spm-models-and-learning-fairseq-dictionary)
82
+ - [Training / Fine-tuning](#training--fine-tuning)
83
+ - [Inference](#inference)
84
+ - [Fairseq Inference](#fairseq-inference)
85
+ - [CT2 Inference](#ct2-inference)
86
+ - [Evaluations](#evaluations)
87
+ - [Baseline Evaluation](#baseline-evaluation)
88
+ - [LICENSE](#license)
89
+ - [Citation](#citation)
90
+
91
+ ## Download Models and Other Artifacts
92
+
93
+ ### Multilingual Translation Models
94
+
95
+ | Model | En-Indic | Indic-En | Indic-Indic | Evaluations |
96
+ | ---------------------------- | ----------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
97
+ | Base (used for benchmarking) | [download](https://indictrans2-public.objectstore.e2enetworks.net/it2_preprint_ckpts/en-indic-preprint.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/it2_preprint_ckpts/indic-en-preprint.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/it2_preprint_ckpts/indic-indic.zip) | [translations](https://indictrans2-public.objectstore.e2enetworks.net/translation_outputs.zip) (as of May 10, 2023), [metrics](https://drive.google.com/drive/folders/1lOOdaU0VdRSBgJEsNav5zC7wwLBis9NI?usp=sharing) |
98
+ | Distilled | [download](https://indictrans2-public.objectstore.e2enetworks.net/it2_distilled_ckpts/en-indic.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/it2_distilled_ckpts/indic-en.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/it2_distilled_ckpts/indic-indic.zip) |
99
+
100
+ ### Training Data
101
+
102
+ | Data | URL |
103
+ | ---------------------------------------- | ------------------------------------------------------------------------------ |
104
+ | Bharat Parallel Corpus Collection (BPCC) | [download](https://indictrans2-public.objectstore.e2enetworks.net/BPCC.zip) |
105
+ | Back-translation (BPCC-BT) | [download](https://indictrans2-public.objectstore.e2enetworks.net/BT_data.zip) |
106
+
107
+ ### Evaluation Data
108
+
109
+ | Data | URL |
110
+ | ----------------------- | ------------------------------------------------------------------------------------ |
111
+ | IN22 test set | [download](https://indictrans2-public.objectstore.e2enetworks.net/IN22_testset.zip) |
112
+ | FLORES-22 Indic dev set | [download](https://indictrans2-public.objectstore.e2enetworks.net/flores-22_dev.zip) |
113
+
114
+ ## Installation
115
+
116
+ Instructions to setup and install everything before running the code.
117
+
118
+ ```bash
119
+ # Clone the github repository and navigate to the project directory.
120
+ git clone https://github.com/AI4Bharat/IndicTrans2
121
+ cd IndicTrans2
122
+
123
+ # Install all the dependencies and requirements associated with the project.
124
+ source install.sh
125
+ ```
126
+
127
+ Note: We recommend creating a virtual environment with python>=3.7.
128
+
129
+ ### Additional notes about Installation
130
+ The ``prepare_data_joint_finetuning.sh`` and ``prepare_data_joint_training.sh`` scripts expect that the sentencepiece commandline utility and GNU parallel are installed.
131
+ 1. To install the sentencepiece command line utility, please follow the instructions [here](https://github.com/google/sentencepiece?tab=readme-ov-file#build-and-install-sentencepiece-command-line-tools-from-c-source).
132
+ 2. Please check if GNU parallel is installed, if not please install the same or alternatively in case of installation issues, remove ``parallel --pipe --keep-order`` from the respective training / finetuning script as well as ``apply_sentence_piece.sh``.
133
+
134
+
135
+ ## Data
136
+
137
+ ### Training
138
+
139
+ Bharat Parallel Corpus Collection (BPCC) is a comprehensive and publicly available parallel corpus that includes both existing and new data for all 22 scheduled Indic languages. It is comprised of two parts: BPCC-Mined and BPCC-Human, totaling approximately 230 million bitext pairs. BPCC-Mined contains about 228 million pairs, with nearly 126 million pairs newly added as a part of this work. On the other hand, BPCC-Human consists of 2.2 million gold standard English-Indic pairs, with an additional 644K bitext pairs from English Wikipedia sentences (forming the BPCC-H-Wiki subset) and 139K sentences covering everyday use cases (forming the BPCC-H-Daily subset). It is worth highlighting that BPCC provides the first available datasets for 7 languages and significantly increases the available data for all languages covered.
140
+
141
+ You can find the contribution from different sources in the following table:
142
+
143
+ <table>
144
+ <tbody>
145
+ <tr>
146
+ <td rowspan="4">BPCC-Mined</th>
147
+ <td rowspan="2">Existing</th>
148
+ <td>Samanantar</th>
149
+ <td>19.4M</th>
150
+ </tr>
151
+ <tr>
152
+ <td>NLLB</th>
153
+ <td>85M</th>
154
+ </tr>
155
+ <tr>
156
+ <td rowspan="2">Newly Added</th>
157
+ <td>Samanantar++</th>
158
+ <td>121.6M</th>
159
+ </tr>
160
+ <tr>
161
+ <td>Comparable</th>
162
+ <td>4.3M</th>
163
+ </tr>
164
+ <tr>
165
+ <td rowspan="5">BPCC-Human</td>
166
+ <td rowspan="3">Existing</td>
167
+ <td>NLLB</td>
168
+ <td>18.5K</td>
169
+ </tr>
170
+ <tr>
171
+ <td>ILCI</td>
172
+ <td>1.3M</td>
173
+ </tr>
174
+ <tr>
175
+ <td>Massive</td>
176
+ <td>115K</td>
177
+ </tr>
178
+ <tr>
179
+ <td rowspan="2">Newly Added</td>
180
+ <td>Wiki</td>
181
+ <td>644K</td>
182
+ </tr>
183
+ <tr>
184
+ <td>Daily</td>
185
+ <td>139K</td>
186
+ </tr>
187
+ </tbody>
188
+ </table>
189
+
190
+ Additionally, we provide augmented back-translation data generated by our intermediate IndicTrans2 models for training purposes. Please refer our paper for more details on the selection of sample proportions and sources.
191
+
192
+ <table>
193
+ <tbody>
194
+ <tr>
195
+ <td>English BT data (English Original)</td>
196
+ <td>401.9M</td>
197
+ </tr>
198
+ <tr>
199
+ <td>Indic BT data (Indic Original)</td>
200
+ <td>400.9M</td>
201
+ </tr>
202
+ </tbody>
203
+ </table>
204
+
205
+ <br>
206
+
207
+ ### Evaluation
208
+
209
+ IN22 test set is a newly created comprehensive benchmark for evaluating machine translation performance in multi-domain, n-way parallel contexts across 22 Indic languages. It has been created from three distinct subsets, namely IN22-Wiki, IN22-Web and IN22-Conv. The Wikipedia and Web sources subsets offer diverse content spanning news, entertainment, culture, legal, and India-centric topics. IN22-Wiki and IN22-Web have been combined and considered for evaluation purposes and released as IN22-Gen. Meanwhile, IN22-Conv the conversation domain subset is designed to assess translation quality in typical day-to-day conversational-style applications.
210
+
211
+ <table>
212
+ <tbody>
213
+ <tr>
214
+ <td>IN22-Gen (IN22-Wiki + IN22-Web)</td>
215
+ <td>1024 sentences</td>
216
+ <td>🤗 <a href="https://huggingface.co/datasets/ai4bharat/IN22-Gen">ai4bharat/IN22-Gen</td>
217
+ </tr>
218
+ <tr>
219
+ <td>IN22-Conv</td>
220
+ <td>1503 sentences</td>
221
+ <td>🤗 <a href="https://huggingface.co/datasets/ai4bharat/IN22-Conv">ai4bharat/IN22-Conv</td>
222
+ </tr>
223
+ </tbody>
224
+ </table>
225
+
226
+ You can download the data artifacts released as a part of this work from the [following section](#download-models-and-other-artifacts).
227
+
228
+ ## Preparing Data for Training
229
+
230
+ BPCC data is organized under different subsets as described above, where each subset contains language pair subdirectories with the sentences pairs. We also provide LaBSE and LASER for the mined subsets of BPCC. In order to replicate our training setup, you will need to combine the data for corresponding language pairs from different subsets and remove overlapping bitext pairs if any.
231
+
232
+ Here is the expected directory structure of the data:
233
+
234
+ ```bash
235
+ BPCC
236
+ ├── eng_Latn-asm_Beng
237
+ │ ├── train.eng_Latn
238
+ │ └── train.asm_Beng
239
+ ├── eng_Latn-ben_Beng
240
+ └── ...
241
+ ```
242
+
243
+ While we provide deduplicated subsets with the current available benchmarks, we highly recommend performing deduplication using the combined monolingual side of all the benchmarks. You can use the following command for deduplication once you combine the monolingual side of all the benchmarks in the directory.
244
+
245
+ ```python3
246
+ python3 scripts/dedup_benchmark.py <in_data_dir> <out_data_dir> <benchmark_dir>
247
+ ```
248
+
249
+ - `<in_data_dir>`: path to the directory containing train data for each language pair in the format `{src_lang}-{tgt_lang}`
250
+ - `<out_data_dir>`: path to the directory where the deduplicated train data will be written for each language pair in the format `{src_lang}-{tgt_lang}`
251
+ - `<benchmark_dir>`: path to the directory containing the language-wise monolingual side of dev/test set, with monolingual files named as `test.{lang}`
252
+
253
+ ### Using our SPM model and Fairseq dictionary
254
+
255
+ Once you complete the deduplication of the training data with the available benchmarks, you can preprocess and binarize the data for training models. Please download our trained SPM model and learned Fairseq dictionary using the following links for your experiments.
256
+
257
+ | | En-Indic | Indic-En | Indic-Indic |
258
+ | ------------------ | -------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------- |
259
+ | SPM model | [download](https://indictrans2-public.objectstore.e2enetworks.net/en-indic-spm.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/indic-en-spm.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/indic-indic-spm.zip) |
260
+ | Fairseq dictionary | [download](https://indictrans2-public.objectstore.e2enetworks.net/en-indic-fairseq-dict.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/indic-en-fairseq-dict.zip) | [download](https://indictrans2-public.objectstore.e2enetworks.net/indic-indic-fairseq-dict.zip) |
261
+
262
+ To prepare the data for training En-Indic model, please do the following:
263
+
264
+ 1. Download the SPM model in the experiment directory and rename it as `vocab`.
265
+ 2. Download the Fairseq dictionary in the experiment directory and rename it as `final_bin`.
266
+
267
+ Here is the expected directory for training En-Indic model:
268
+
269
+ ```bash
270
+ en-indic-exp
271
+ ├── train
272
+ │ ├── eng_Latn-asm_Beng
273
+ │ │ ├── train.eng_Latn
274
+ │ │ └── train.asm_Beng
275
+ │ ├── eng_Latn-ben_Beng
276
+ │ └── ...
277
+ ├── devtest
278
+ │ └── all
279
+ │ ├── eng_Latn-asm_Beng
280
+ │ │ ├── dev.eng_Latn
281
+ │ │ └── dev.asm_Beng
282
+ │ ├── eng_Latn-ben_Beng
283
+ │ └── ...
284
+ ├── vocab
285
+ │ ├── model.SRC
286
+ │ ├── model.TGT
287
+ │ ├── vocab.SRC
288
+ │ └── vocab.TGT
289
+ └── final_bin
290
+ ├── dict.SRC.txt
291
+ └── dict.TGT.txt
292
+ ```
293
+
294
+ To prepare data for training the Indic-En model, you should reverse the language pair directories within the train and devtest directories. Additionally, make sure to download the corresponding SPM model and Fairseq dictionary and put them in the experiment directory, similar to the procedure mentioned above for En-Indic model training.
295
+
296
+ You can binarize the data for model training using the following:
297
+
298
+ ```bash
299
+ bash prepare_data_joint_finetuning.sh <exp_dir>
300
+ ```
301
+
302
+ - `<exp_dir>`: path to the directory containing the raw data for binarization
303
+
304
+ You will need to follow the same steps for data preparation in case of fine-tuning models.
305
+
306
+ ### Training your own SPM models and learning Fairseq dictionary
307
+
308
+ If you want to train your own SPM model and learn Fairseq dictionary, then please do the following:
309
+
310
+ 1. Collect a balanced amount of English and Indic monolingual data (we use around 3 million sentences per language-script combination). If some languages have limited data available, increase their representation to achieve a fair distribution of tokens across languages.
311
+ 2. Perform script unification for Indic languages wherever possible using `scripts/preprocess_translate.py` and concatenate all Indic data into a single file.
312
+ 3. Train two SPM models, one for English and other for Indic side using the following:
313
+
314
+ ```bash
315
+ spm_train --input=train.indic --model_prefix=<model_name> --vocab_size=<vocab_size> --character_coverage=1.0 --model_type=BPE
316
+ ```
317
+
318
+ 4. Copy the trained SPM models in the experiment directory mentioned earlier and learn the Fairseq dictionary using the following:
319
+
320
+ ```bash
321
+ bash prepare_data_joint_training.sh <exp_dir>
322
+ ```
323
+
324
+ 5. You will need to use the same Fairseq dictionary for any subsequent fine-tuning experiments and refer to the steps described above ([link](#using-our-spm-model-and-fairseq-dictionary)).
325
+
326
+ ## Training / Fine-tuning
327
+
328
+ After binarizing the data, you can use train.sh to train the models. We provide the default hyperparameters used in this work. You can modify the hyperparameters as per your requirement if needed. If you want to train the model on a customized architecture, then please define the architecture in `model_configs/custom_transformer.py`. You can start the model training with the following command:
329
+
330
+ ```bash
331
+ bash train.sh <exp_dir> <model_arch>
332
+ ```
333
+
334
+ - `<exp_dir>`: path to the directory containing the binarized data
335
+ - `<model_arch>`: custom transformer architecture used for model training
336
+
337
+ For fine-tuning, the initial steps remain the same. However, the `finetune.sh` script includes an additional argument, `pretrained_ckpt`, which specifies the model checkpoint to be loaded for further fine-tuning. You can perform fine-tuning using the following command:
338
+
339
+ ```bash
340
+ bash finetune.sh <exp_dir> <model_arch> <pretrained_ckpt>
341
+ ```
342
+
343
+ - `<exp_dir>`: path to the directory containing the binarized data
344
+ - `<model_arch>`: custom transformer architecture used for model training
345
+ - `transformer_18_18` - For IT2 Base models
346
+ - `transformer_base18L` - For IT2 Distilled models
347
+ - `<pretrained_ckpt>`: path to the fairseq model checkpoint to be loaded for further fine-tuning
348
+
349
+ You can download the model artifacts released as a part of this work from the [following section](#download-models-and-other-artifacts).
350
+
351
+ The pretrained checkpoints have 3 directories, a fairseq model directory and 2 CT-ported model directories. Please note that the CT2 models are provided only for efficient inference. For fine-tuning purposes you should use the `fairseq_model`. Post that you can use the [fairseq-ct2-converter](https://opennmt.net/CTranslate2/guides/fairseq.html) to port your fine-tuned checkpoints to CT2 for faster inference.
352
+
353
+ ## Inference
354
+
355
+ ### Fairseq Inference
356
+
357
+ In order to run inference on our pretrained models using bash interface, please use the following:
358
+
359
+ ```bash
360
+ bash joint_translate.sh <infname> <outfname> <src_lang> <tgt_lang> <ckpt_dir>
361
+ ```
362
+
363
+ - `infname`: path to the input file containing sentences
364
+ - `outfname`: path to the output file where the translations should be stored
365
+ - `src_lang`: source language
366
+ - `tgt_lang`: target language
367
+ - `ckpt_dir`: path to the fairseq model checkpoint directory
368
+
369
+ If you want to run the inference using python interface then please execute the following block of code from the root directory:
370
+
371
+ ```python3
372
+ from inference.engine import Model
373
+
374
+ model = Model(ckpt_dir, model_type="fairseq")
375
+
376
+ sents = [sent1, sent2,...]
377
+
378
+ # for a batch of sentences
379
+ model.batch_translate(sents, src_lang, tgt_lang)
380
+
381
+ # for a paragraph
382
+ model.translate_paragraph(text, src_lang, tgt_lang)
383
+ ```
384
+
385
+ ### CT2 Inference
386
+
387
+ In order to run inference on CT2-ported model using python inference then please execute the following block of code from the root directory:
388
+
389
+ ```python3
390
+ from inference.engine import Model
391
+
392
+ model = Model(ckpt_dir, model_type="ctranslate2")
393
+
394
+ sents = [sent1, sent2,...]
395
+
396
+ # for a batch of sentences
397
+ model.batch_translate(sents, src_lang, tgt_lang)
398
+
399
+ # for a paragraph
400
+ model.translate_paragraph(text, src_lang, tgt_lang)
401
+ ```
402
+
403
+ ## Evaluations
404
+
405
+ We consider the chrF++ as our primary metric. Additionally, we also report the BLEU and Comet scores.
406
+ We also perform statistical significance tests for each metric to ascertain whether the differences are statistically significant.
407
+
408
+ In order to run our evaluation scripts, you will need to organize the evaluation test sets into the following directory structure:
409
+
410
+ ```bash
411
+ eval_benchmarks
412
+ ├── flores
413
+ │ └── eng_Latn-asm_Beng
414
+ │ ├── test.eng_Latn
415
+ │ └── test.asm_Beng
416
+ ├── in22-gen
417
+ ├── in22-conv
418
+ ├── ntrex
419
+ └── ...
420
+ ```
421
+
422
+ To compute the BLEU and chrF++ scores for prediction file, you can use the following command:
423
+
424
+ ```bash
425
+ bash compute_metrics.sh <pred_fname> <ref_fname> <tgt_lang>
426
+ ```
427
+
428
+ - `pred_fname`: path to the model translations
429
+ - `ref_fname`: path to the reference translations
430
+ - `tgt_lang`: target language
431
+
432
+ In order to automate the inference over the individual test sets for En-Indic, you can use the following command:
433
+
434
+ ```bash
435
+ bash eval.sh <devtest_data_dir> <ckpt_dir> <system>
436
+ ```
437
+
438
+ - `<devtest_data_dir>`: path to the evaluation set with language pair subdirectories (for example, flores directory in the above tree structure)
439
+ - `<ckpt_dir>`: path to the fairseq model checkpoint directory
440
+ - `<system>`: system name suffix to store the predictions in the format `test.{lang}.pred.{system}`
441
+
442
+ In case of Indic-En evaluation, please use the following command:
443
+
444
+ ```bash
445
+ bash eval_rev.sh <devtest_data_dir> <ckpt_dir> <system>
446
+ ```
447
+
448
+ - `<devtest_data_dir>`: path to the evaluation set with language pair subdirectories (for example, flores directory in the above tree structure)
449
+ - `<ckpt_dir>`: path to the fairseq model checkpoint directory
450
+ - `<system>`: system name suffix to store the predictions in the format `test.{lang}.pred.{system}`
451
+
452
+ **_Note: You don’t need to reverse the test set directions for each language pair._**
453
+
454
+ In case of Indic-Indic evaluation, please use the following command:
455
+
456
+ ```bash
457
+ bash pivot_eval.sh <devtest_data_dir> <pivot_lang> <src2pivot_ckpt_dir> <pivot2tgt_ckpt_dir> <system>
458
+ ```
459
+
460
+ - `<devtest_data_dir>`: path to the evaluation set with language pair subdirectories (for example, flores directory in the above tree structure)
461
+ - `<pivot_lang>`: pivot language (default should be `eng_Latn`)
462
+ - `<src2pivot_ckpt_dir>`: path to the fairseq Indic-En model checkpoint directory
463
+ - `<pivot2tgt_ckpt_dir>`: path to the fairseq En-Indic model checkpoint directory
464
+ - `<system>`: system name suffix to store the predictions in the format test.{lang}.pred.{system}
465
+
466
+ In order to perform significance testing for BLEU and chrF++ metrics after you have the predictions for different systems, you can use the following command:
467
+
468
+ ```bash
469
+ bash compute_comet_metrics_significance.sh <devtest_data_dir>
470
+ ```
471
+
472
+ - `<devtest_data_dir>`: path to the evaluation set with language pair subdirectories (for example, flores directory in the above tree structure)
473
+
474
+ Similarly, to compute the COMET scores and perform significance testing on predictions of different systems, you can use the following command.
475
+
476
+ ```bash
477
+ bash compute_comet_score.sh <devtest_data_dir>
478
+ ```
479
+
480
+ - `<devtest_data_dir>`: path to the evaluation set with language pair subdirectories (for example, flores directory in the above tree structure)
481
+
482
+ Please note that as we compute significance tests with the same script and automate everything, it is best to have all the predictions for all the systems in place to avoid repeating anything.
483
+ Also, we define the systems in the script itself, if you want to try out other systems, make sure to edit it there itself.
484
+
485
+ ### Baseline Evaluation
486
+
487
+ To generate the translation results for baseline models such as M2M-100, MBART, Azure, Google, and NLLB MoE, you can check the scripts provided in the "baseline_eval" directory of this repository. For NLLB distilled, you can either modify NLLB_MoE eval or use this [repository](https://github.com/pluiez/NLLB-inference). Similarly, for IndicTrans inference, please refer to this [repository](https://github.com/ai4bharat/IndicTrans).
488
+
489
+ You can download the translation outputs released as a part of this work from the [following section](#download-models-and-other-artifacts).
490
+
491
+ ## LICENSE
492
+
493
+ The following table lists the licenses associated with the different artifacts released as a part of this work:
494
+
495
+ | Artifact | LICENSE |
496
+ | ----------------------------------------------------- | --------------------------------------------------------------------- |
497
+ | Existing Mined Corpora (NLLB & Samanantar) | [CC0](https://creativecommons.org/share-your-work/public-domain/cc0/) |
498
+ | Existing Seed Corpora (NLLB-Seed, ILCI, MASSIVE) | [CC0](https://creativecommons.org/share-your-work/public-domain/cc0/) |
499
+ | Newly Added Mined Corpora (Samanantar++ & Comparable) | [CC0](https://creativecommons.org/share-your-work/public-domain/cc0/) |
500
+ | Newly Added Seed Corpora (BPCC-H-Wiki & BPCC-H-Daily) | [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/) |
501
+ | Newly Created IN-22 test set (IN22-Gen & IN22-Conv) | [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/) |
502
+ | Back-translation data (BPCC-BT) | [CC0](https://creativecommons.org/share-your-work/public-domain/cc0/) |
503
+ | Model checkpoints | [MIT](https://github.com/ai4bharat/IndicTrans2/blob/main/LICENSE) |
504
+
505
+ The mined corpora collection (BPCC-Mined), existing seed corpora (NLLB-Seed, ILCI, MASSIVE), Backtranslation data (BPCC-BT), are released under the following licensing scheme:
506
+
507
+ - We do not own any of the text from which this data has been extracted.
508
+ - We license the actual packaging of this data under the Creative Commons [CC0 license (“no rights reserved”)](https://creativecommons.org/share-your-work/public-domain/cc0/).
509
+ - To the extent possible under law, [AI4Bharat](https://ai4bharat.iitm.ac.in/) has waived all copyright and related or neighboring rights to BPCC-Mined, existing seed corpora (NLLB-Seed, ILCI, MASSIVE) and BPCC-BT.
510
+
511
+ ## Citation
512
+
513
+ ```bibtex
514
+ @article{gala2023indictrans,
515
+ title={IndicTrans2: Towards High-Quality and Accessible Machine Translation Models for all 22 Scheduled Indian Languages},
516
+ author={Jay Gala and Pranjal A Chitale and A K Raghavan and Varun Gumma and Sumanth Doddapaneni and Aswanth Kumar M and Janki Atul Nawale and Anupama Sujatha and Ratish Puduppully and Vivek Raghavan and Pratyush Kumar and Mitesh M Khapra and Raj Dabre and Anoop Kunchukuttan},
517
+ journal={Transactions on Machine Learning Research},
518
+ issn={2835-8856},
519
+ year={2023},
520
+ url={https://openreview.net/forum?id=vfT4YuzAYA},
521
+ note={}
522
+ }
523
+ ```
IndicTrans2/apply_sentence_piece.sh ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # This script tokenizes the preprocessed train and dev set using the trained spm models.
4
+
5
+
6
+ echo `date`
7
+ exp_dir=$1 # path to the experiment directory
8
+ data_dir=$2 # path to the data directory where all lang pairs are concatenated
9
+ bpe_dir=$3 # path to the tokenized data directory
10
+ src_lang=$4 # source language
11
+ tgt_lang=$5 # target language
12
+ split=$6 # name of the split
13
+ parallel_installed=${7:-false} # If GNU Parallel is installed or not
14
+
15
+ in_split_dir=$data_dir/$split
16
+ out_split_dir=$bpe_dir/$split
17
+
18
+ echo "Apply Sentence Piece tokenization to SRC corpus"
19
+ # for very large datasets, it is recommended to use gnu-parallel to speed up applying bpe
20
+
21
+ if $parallel_installed; then
22
+ parallel --pipe --keep-order \
23
+ spm_encode --model=$exp_dir/vocab/model.SRC \
24
+ --output_format=piece \
25
+ < $in_split_dir.$src_lang \
26
+ > $out_split_dir.$src_lang
27
+ else
28
+ spm_encode --model=$exp_dir/vocab/model.SRC \
29
+ --output_format=piece \
30
+ < $in_split_dir.$src_lang \
31
+ > $out_split_dir.$src_lang
32
+ fi
33
+
34
+ echo "Apply Sentence Piece tokenization to TGT corpus"
35
+ # for very large datasets, it is recommended to use gnu-parallel to speed up applying bpe
36
+
37
+ if $parallel_installed; then
38
+ parallel --pipe --keep-order \
39
+ spm_encode --model=$exp_dir/vocab/model.TGT \
40
+ --output_format=piece \
41
+ < $in_split_dir.$tgt_lang \
42
+ > $out_split_dir.$tgt_lang
43
+ else
44
+ spm_encode --model=$exp_dir/vocab/model.TGT \
45
+ --output_format=piece \
46
+ < $in_split_dir.$tgt_lang \
47
+ > $out_split_dir.$tgt_lang
48
+ fi
IndicTrans2/baseline_eval/azure_translate.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import glob
4
+ import requests
5
+ from urllib.parse import urlencode
6
+ from dotenv import dotenv_values
7
+ import traceback
8
+ import time
9
+
10
+ flores_to_iso = {
11
+ "asm_Beng": "as",
12
+ "ben_Beng": "bn",
13
+ "brx_Deva": "brx",
14
+ "doi_Deva": "doi",
15
+ "eng_Latn": "en",
16
+ "gom_Deva": "gom",
17
+ "guj_Gujr": "gu",
18
+ "hin_Deva": "hi",
19
+ "kan_Knda": "kn",
20
+ "kas_Arab": "ks",
21
+ "kas_Deva": "ks_Deva",
22
+ "mai_Deva": "mai",
23
+ "mal_Mlym": "ml",
24
+ "mar_Deva": "mr",
25
+ "mni_Beng": "mni_Beng",
26
+ "mni_Mtei": "mni",
27
+ "npi_Deva": "ne",
28
+ "ory_Orya": "or",
29
+ "pan_Guru": "pa",
30
+ "san_Deva": "sa",
31
+ "sat_Olck": "sat",
32
+ "snd_Arab": "sd",
33
+ "snd_Deva": "sd_Deva",
34
+ "tam_Taml": "ta",
35
+ "tel_Telu": "te",
36
+ "urd_Arab": "ur",
37
+ }
38
+
39
+
40
+ class AzureTranslator:
41
+ def __init__(
42
+ self,
43
+ subscription_key: str,
44
+ region: str,
45
+ endpoint: str = "https://api.cognitive.microsofttranslator.com",
46
+ ) -> None:
47
+ self.http_headers = {
48
+ "Ocp-Apim-Subscription-Key": subscription_key,
49
+ "Ocp-Apim-Subscription-Region": region,
50
+ }
51
+ self.translate_endpoint = endpoint + "/translate?api-version=3.0&"
52
+ self.languages_endpoint = endpoint + "/languages?api-version=3.0"
53
+
54
+ self.supported_languages = self.get_supported_languages()
55
+
56
+ def get_supported_languages(self) -> dict:
57
+ return requests.get(self.languages_endpoint).json()["translation"]
58
+
59
+ def batch_translate(self, texts: list, src_lang: str, tgt_lang: str) -> list:
60
+ if not texts:
61
+ return texts
62
+
63
+ src_lang = flores_to_iso[src_lang]
64
+ tgt_lang = flores_to_iso[tgt_lang]
65
+
66
+ if src_lang not in self.supported_languages:
67
+ raise NotImplementedError(
68
+ f"Source language code: `{src_lang}` not supported!"
69
+ )
70
+
71
+ if tgt_lang not in self.supported_languages:
72
+ raise NotImplementedError(
73
+ f"Target language code: `{tgt_lang}` not supported!"
74
+ )
75
+
76
+ body = [{"text": text} for text in texts]
77
+ query_string = urlencode(
78
+ {
79
+ "from": src_lang,
80
+ "to": tgt_lang,
81
+ }
82
+ )
83
+
84
+ try:
85
+ response = requests.post(
86
+ self.translate_endpoint + query_string,
87
+ headers=self.http_headers,
88
+ json=body,
89
+ )
90
+ except:
91
+ traceback.print_exc()
92
+ return None
93
+
94
+ try:
95
+ response = response.json()
96
+ except:
97
+ traceback.print_exc()
98
+ print("Response:", response.text)
99
+ return None
100
+
101
+ return [payload["translations"][0]["text"] for payload in response]
102
+
103
+ def text_translate(self, text: str, src_lang: str, tgt_lang: str) -> str:
104
+ return self.batch_translate([text], src_lang, tgt_lang)[0]
105
+
106
+
107
+ if __name__ == "__main__":
108
+ root_dir = sys.argv[1]
109
+
110
+ # Expects a .env file containing the API credentials.
111
+ config = dotenv_values(os.path.join(os.path.dirname(__file__), ".env"))
112
+
113
+ t = AzureTranslator(
114
+ config["AZURE_TRANSLATOR_TEXT_SUBSCRIPTION_KEY"],
115
+ config["AZURE_TRANSLATOR_TEXT_REGION"],
116
+ config["AZURE_TRANSLATOR_TEXT_ENDPOINT"],
117
+ )
118
+
119
+ pairs = sorted(glob.glob(os.path.join(root_dir, "*")))
120
+
121
+ for i, pair in enumerate(pairs):
122
+ basename = os.path.basename(pair)
123
+
124
+ print(pair)
125
+
126
+ src_lang, tgt_lang = basename.split("-")
127
+
128
+ print(f"{src_lang} - {tgt_lang}")
129
+
130
+ # source to target translations
131
+ src_infname = os.path.join(pair, f"test.{src_lang}")
132
+ tgt_outfname = os.path.join(pair, f"test.{tgt_lang}.pred.azure")
133
+ if not os.path.exists(src_infname):
134
+ continue
135
+
136
+ src_sents = [
137
+ sent.replace("\n", "").strip()
138
+ for sent in open(src_infname, "r").read().split("\n")
139
+ if sent
140
+ ]
141
+
142
+ if not os.path.exists(tgt_outfname):
143
+ try:
144
+ translations = []
145
+ for i in range(0, len(src_sents), 128):
146
+ start, end = i, int(min(i + 128, len(src_sents)))
147
+ translations.extend(
148
+ t.batch_translate(src_sents[start:end], src_lang, tgt_lang)
149
+ )
150
+ with open(tgt_outfname, "w") as f:
151
+ f.write("\n".join(translations))
152
+
153
+ time.sleep(10)
154
+ except Exception as e:
155
+ print(e)
156
+ continue
157
+
158
+ # target to source translations
159
+ tgt_infname = os.path.join(pair, f"test.{tgt_lang}")
160
+ src_outfname = os.path.join(pair, f"test.{src_lang}.pred.azure")
161
+ if not os.path.exists(tgt_infname):
162
+ continue
163
+
164
+ tgt_sents = [
165
+ sent.replace("\n", "").strip()
166
+ for sent in open(tgt_infname, "r").read().split("\n")
167
+ if sent
168
+ ]
169
+
170
+ if not os.path.exists(src_outfname):
171
+ try:
172
+ translations = []
173
+ for i in range(0, len(tgt_sents), 128):
174
+ start, end = i, int(min(i + 128, len(tgt_sents)))
175
+ translations.extend(
176
+ t.batch_translate(tgt_sents[start:end], tgt_lang, src_lang)
177
+ )
178
+ with open(src_outfname, "w") as f:
179
+ f.write("\n".join(translations))
180
+ except Exception as e:
181
+ continue
182
+
183
+ time.sleep(10)
IndicTrans2/baseline_eval/google_translate.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import glob
4
+ from tqdm import tqdm
5
+ from google.cloud import translate
6
+
7
+ # Expects a json file containing the API credentials.
8
+ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.join(
9
+ os.path.dirname(__file__), r"api_key.json"
10
+ )
11
+
12
+ flores_to_iso = {
13
+ "asm_Beng": "as",
14
+ "ben_Beng": "bn",
15
+ "doi_Deva": "doi",
16
+ "eng_Latn": "en",
17
+ "gom_Deva": "gom",
18
+ "guj_Gujr": "gu",
19
+ "hin_Deva": "hi",
20
+ "kan_Knda": "kn",
21
+ "mai_Deva": "mai",
22
+ "mal_Mlym": "ml",
23
+ "mar_Deva": "mr",
24
+ "mni_Mtei": "mni_Mtei",
25
+ "npi_Deva": "ne",
26
+ "ory_Orya": "or",
27
+ "pan_Guru": "pa",
28
+ "san_Deva": "sa",
29
+ "sat_Olck": "sat",
30
+ "snd_Arab": "sd",
31
+ "tam_Taml": "ta",
32
+ "tel_Telu": "te",
33
+ "urd_Arab": "ur",
34
+ }
35
+
36
+
37
+ # Copy the project id from the json file containing API credentials
38
+ def translate_text(text, src_lang, tgt_lang, project_id="project_id"):
39
+
40
+ src_lang = flores_to_iso[src_lang]
41
+ tgt_lang = flores_to_iso[tgt_lang]
42
+
43
+ if src_lang == "mni_Mtei":
44
+ src_lang = "mni-Mtei"
45
+
46
+ if tgt_lang == "mni_Mtei":
47
+ tgt_lang = "mni-Mtei"
48
+
49
+ client = translate.TranslationServiceClient()
50
+
51
+ location = "global"
52
+
53
+ parent = f"projects/{project_id}/locations/{location}"
54
+
55
+ response = client.translate_text(
56
+ request={
57
+ "parent": parent,
58
+ "contents": [text],
59
+ "mime_type": "text/plain", # mime types: text/plain, text/html
60
+ "source_language_code": src_lang,
61
+ "target_language_code": tgt_lang,
62
+ }
63
+ )
64
+
65
+ translated_text = ""
66
+ for translation in response.translations:
67
+ translated_text += translation.translated_text
68
+
69
+ return translated_text
70
+
71
+
72
+ if __name__ == "__main__":
73
+ root_dir = sys.argv[1]
74
+
75
+ pairs = sorted(glob.glob(os.path.join(root_dir, "*")))
76
+
77
+ for pair in pairs:
78
+
79
+ print(pair)
80
+
81
+ basename = os.path.basename(pair)
82
+
83
+ src_lang, tgt_lang = basename.split("-")
84
+ if src_lang not in flores_to_iso.keys() or tgt_lang not in flores_to_iso.keys():
85
+ continue
86
+
87
+ if src_lang == "eng_Latn":
88
+ lang = tgt_lang
89
+ else:
90
+ lang = src_lang
91
+
92
+ lang = flores_to_iso[lang]
93
+
94
+ if lang not in "as bn doi gom gu hi kn mai ml mni_Mtei mr ne or pa sa sd ta te ur":
95
+ continue
96
+
97
+ print(f"{src_lang} - {tgt_lang}")
98
+
99
+ # source to target translations
100
+
101
+ src_infname = os.path.join(pair, f"test.{src_lang}")
102
+ tgt_outfname = os.path.join(pair, f"test.{tgt_lang}.pred.google")
103
+ if os.path.exists(src_infname) and not os.path.exists(tgt_outfname):
104
+ src_sents = [
105
+ sent.replace("\n", "").strip()
106
+ for sent in open(src_infname, "r").read().split("\n")
107
+ if sent
108
+ ]
109
+ translations = [
110
+ translate_text(text, src_lang, tgt_lang).strip() for text in tqdm(src_sents)
111
+ ]
112
+ with open(tgt_outfname, "w") as f:
113
+ f.write("\n".join(translations))
114
+
115
+ # # target to source translations
116
+ tgt_infname = os.path.join(pair, f"test.{tgt_lang}")
117
+ src_outfname = os.path.join(pair, f"test.{src_lang}.pred.google")
118
+ if os.path.exists(tgt_infname) and not os.path.exists(src_outfname):
119
+ tgt_sents = [
120
+ sent.replace("\n", "").strip()
121
+ for sent in open(tgt_infname, "r").read().split("\n")
122
+ if sent
123
+ ]
124
+ translations = [
125
+ translate_text(text, tgt_lang, src_lang).strip() for text in tqdm(tgt_sents)
126
+ ]
127
+
128
+ with open(src_outfname, "w") as f:
129
+ f.write("\n".join(translations))
IndicTrans2/baseline_eval/m2m100_inference.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import sys
4
+ from tqdm import tqdm
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
+
8
+
9
+ # dictionary mapping flores codes to M2M-100 supported codes
10
+ langs_supported = {
11
+ "eng_Latn": "en",
12
+ "ben_Beng": "bn",
13
+ "guj_Gujr": "gu",
14
+ "hin_Deva": "hi",
15
+ "kan_Knda": "kn",
16
+ "mal_Mlym": "ml",
17
+ "mar_Deva": "mr",
18
+ "npi_Deva": "ne",
19
+ "ory_Orya": "or",
20
+ "pan_Guru": "pa",
21
+ "snd_Arab": "sd",
22
+ "tam_Taml": "ta",
23
+ "urd_Arab": "ur",
24
+ }
25
+
26
+
27
+ def predict(batch, tokenizer, model, bos_token_id):
28
+ encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
29
+ generated_tokens = model.generate(
30
+ **encoded_batch,
31
+ num_beams=5,
32
+ max_length=256,
33
+ min_length=0,
34
+ forced_bos_token_id=bos_token_id,
35
+ )
36
+ hypothesis = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
37
+ return hypothesis
38
+
39
+
40
+ def main(devtest_data_dir, batch_size):
41
+ # load the pre-trained M2M-100 tokenizer and model
42
+ model_name = "facebook/m2m100-12B-last-ckpt"
43
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
44
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
45
+ model.eval()
46
+
47
+ # iterate over a list of language pairs from `devtest_data_dir`
48
+ for pair in sorted(os.listdir(devtest_data_dir)):
49
+ if "-" not in pair:
50
+ continue
51
+
52
+ src_lang, tgt_lang = pair.split("-")
53
+
54
+ # check if the source and target languages are supported
55
+ if (
56
+ src_lang not in langs_supported.keys()
57
+ or tgt_lang not in langs_supported.keys()
58
+ ):
59
+ print(f"Skipping {src_lang}-{tgt_lang} ...")
60
+ continue
61
+
62
+ # -------------------------------------------------------------------
63
+ # source to target evaluation
64
+ # -------------------------------------------------------------------
65
+ print(f"Evaluating {src_lang}-{tgt_lang} ...")
66
+
67
+ infname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}")
68
+ outfname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}.pred.m2m100")
69
+
70
+ with open(infname, "r") as f:
71
+ src_sents = f.read().split("\n")
72
+
73
+ add_new_line = False
74
+ if src_sents[-1] == "":
75
+ add_new_line = True
76
+ src_sents = src_sents[:-1]
77
+
78
+ # set the source language for tokenization
79
+ tokenizer.src_lang = langs_supported[src_lang]
80
+
81
+ # process sentences in batches and generate predictions
82
+ hypothesis = []
83
+ for i in tqdm(range(0, len(src_sents), batch_size)):
84
+ start, end = i, int(min(len(src_sents), i + batch_size))
85
+ batch = src_sents[start:end]
86
+ bos_token_id = tokenizer.lang_code_to_id[langs_supported[tgt_lang]]
87
+ hypothesis += predict(batch, tokenizer, model, bos_token_id)
88
+
89
+ assert len(hypothesis) == len(src_sents)
90
+
91
+ hypothesis = [
92
+ re.sub("\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
93
+ for x in hypothesis
94
+ ]
95
+ if add_new_line:
96
+ hypothesis = hypothesis
97
+
98
+ with open(outfname, "w") as f:
99
+ f.write("\n".join(hypothesis))
100
+
101
+ # -------------------------------------------------------------------
102
+ # target to source evaluation
103
+ # -------------------------------------------------------------------
104
+ infname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}")
105
+ outfname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}.pred.m2m100")
106
+
107
+ with open(infname, "r") as f:
108
+ src_sents = f.read().split("\n")
109
+
110
+ add_new_line = False
111
+ if src_sents[-1] == "":
112
+ add_new_line = True
113
+ src_sents = src_sents[:-1]
114
+
115
+ # set the source language for tokenization
116
+ tokenizer.src_lang = langs_supported[tgt_lang]
117
+
118
+ # process sentences in batches and generate predictions
119
+ hypothesis = []
120
+ for i in tqdm(range(0, len(src_sents), batch_size)):
121
+ start, end = i, int(min(len(src_sents), i + batch_size))
122
+ batch = src_sents[start:end]
123
+ bos_token_id = tokenizer.lang_code_to_id[langs_supported[src_lang]]
124
+ hypothesis += predict(batch, tokenizer, model, bos_token_id)
125
+
126
+ assert len(hypothesis) == len(src_sents)
127
+
128
+ hypothesis = [
129
+ re.sub("\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
130
+ for x in hypothesis
131
+ ]
132
+ if add_new_line:
133
+ hypothesis = hypothesis
134
+
135
+ with open(outfname, "w") as f:
136
+ f.write("\n".join(hypothesis))
137
+
138
+
139
+ if __name__ == "__main__":
140
+ # expects En-X subdirectories pairs within the devtest data directory
141
+ devtest_data_dir = sys.argv[1]
142
+ batch_size = int(sys.argv[2])
143
+
144
+ if not torch.cuda.is_available():
145
+ print("No GPU available")
146
+ sys.exit(1)
147
+
148
+ main(devtest_data_dir, batch_size)
IndicTrans2/baseline_eval/mbart_inference.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import sys
4
+ from tqdm import tqdm
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
+
8
+
9
+ # dictionary mapping flores codes to mBART supported codes
10
+ langs_supported = {
11
+ "eng_Latn": "en_XX",
12
+ "guj_Gujr": "gu_IN",
13
+ "hin_Deva": "hi_IN",
14
+ "npi_Deva": "ne_NP",
15
+ "ben_Beng": "bn_IN",
16
+ "mal_Mlym": "ml_IN",
17
+ "mar_Deva": "mr_IN",
18
+ "tam_Taml": "ta_IN",
19
+ "tel_Telu": "te_IN",
20
+ "urd_Arab": "ur_PK",
21
+ }
22
+
23
+
24
+ def predict(batch, tokenizer, model, bos_token_id):
25
+ encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
26
+ generated_tokens = model.generate(
27
+ **encoded_batch,
28
+ num_beams=5,
29
+ max_length=256,
30
+ min_length=0,
31
+ forced_bos_token_id=bos_token_id,
32
+ )
33
+ hypothesis = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
34
+ return hypothesis
35
+
36
+
37
+ def main(devtest_data_dir, batch_size):
38
+ # load the pre-trained mBART tokenizers and models for English-XX and XX-English translation
39
+ enxx_model_name = "facebook/mbart-large-50-one-to-many-mmt"
40
+ xxen_model_name = "facebook/mbart-large-50-many-to-one-mmt"
41
+ tokenizers = {
42
+ "enxx": AutoTokenizer.from_pretrained(enxx_model_name),
43
+ "xxen": AutoTokenizer.from_pretrained(xxen_model_name),
44
+ }
45
+ models = {
46
+ "enxx": AutoModelForSeq2SeqLM.from_pretrained(enxx_model_name).cuda(),
47
+ "xxen": AutoModelForSeq2SeqLM.from_pretrained(xxen_model_name).cuda(),
48
+ }
49
+
50
+ # set the models to evaluation mode
51
+ for model_name in models:
52
+ models[model_name].eval()
53
+
54
+ # iterate over a list of language pairs from `devtest_data_dir`
55
+ for pair in sorted(os.listdir(devtest_data_dir)):
56
+ if "-" not in pair:
57
+ continue
58
+
59
+ src_lang, tgt_lang = pair.split("-")
60
+
61
+ # check if the source and target languages are supported
62
+ if (
63
+ src_lang not in langs_supported.keys()
64
+ or tgt_lang not in langs_supported.keys()
65
+ ):
66
+ print(f"Skipping {src_lang}-{tgt_lang} ...")
67
+ continue
68
+
69
+ # -------------------------------------------------------------------
70
+ # source to target evaluation
71
+ # -------------------------------------------------------------------
72
+ print(f"Evaluating {src_lang}-{tgt_lang} ...")
73
+
74
+ infname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}")
75
+ outfname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}.pred.mbart50")
76
+
77
+ with open(infname, "r") as f:
78
+ src_sents = f.read().split("\n")
79
+
80
+ add_new_line = False
81
+ if src_sents[-1] == "":
82
+ add_new_line = True
83
+ src_sents = src_sents[:-1]
84
+
85
+ # set the source language for tokenization
86
+ tokenizers["enxx"].src_lang = langs_supported[src_lang]
87
+
88
+ # process sentences in batches and generate predictions
89
+ hypothesis = []
90
+ for i in tqdm(range(0, len(src_sents), batch_size)):
91
+ start, end = i, int(min(len(src_sents), i + batch_size))
92
+ batch = src_sents[start:end]
93
+ bos_token_id = tokenizers["enxx"].lang_code_to_id[langs_supported[tgt_lang]]
94
+ hypothesis += predict(
95
+ batch, tokenizers["enxx"], models["enxx"], bos_token_id
96
+ )
97
+
98
+ assert len(hypothesis) == len(src_sents)
99
+
100
+ hypothesis = [
101
+ re.sub("\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
102
+ for x in hypothesis
103
+ ]
104
+ if add_new_line:
105
+ hypothesis = hypothesis
106
+
107
+ with open(outfname, "w") as f:
108
+ f.write("\n".join(hypothesis))
109
+
110
+ # -------------------------------------------------------------------
111
+ # target to source evaluation
112
+ # -------------------------------------------------------------------
113
+ infname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}")
114
+ outfname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}.pred.mbart50")
115
+
116
+ with open(infname, "r") as f:
117
+ src_sents = f.read().split("\n")
118
+
119
+ add_new_line = False
120
+ if src_sents[-1] == "":
121
+ add_new_line = True
122
+ src_sents = src_sents[:-1]
123
+
124
+ # set the source language for tokenization
125
+ tokenizers["xxen"].src_lang = langs_supported[tgt_lang]
126
+
127
+ # process sentences in batches and generate predictions
128
+ hypothesis = []
129
+ for i in tqdm(range(0, len(src_sents), batch_size)):
130
+ start, end = i, int(min(len(src_sents), i + batch_size))
131
+ batch = src_sents[start:end]
132
+ bos_token_id = tokenizers["xxen"].lang_code_to_id[langs_supported[src_lang]]
133
+ hypothesis += predict(
134
+ batch, tokenizers["xxen"], models["xxen"], bos_token_id
135
+ )
136
+
137
+ assert len(hypothesis) == len(src_sents)
138
+
139
+ hypothesis = [
140
+ re.sub("\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
141
+ for x in hypothesis
142
+ ]
143
+ if add_new_line:
144
+ hypothesis = hypothesis
145
+
146
+ with open(outfname, "w") as f:
147
+ f.write("\n".join(hypothesis))
148
+
149
+
150
+ if __name__ == "__main__":
151
+ # expects En-X subdirectories pairs within the devtest data directory
152
+ devtest_data_dir = sys.argv[1]
153
+ batch_size = int(sys.argv[2])
154
+
155
+ if not torch.cuda.is_available():
156
+ print("No GPU available")
157
+ sys.exit(1)
158
+
159
+ main(devtest_data_dir, batch_size)
IndicTrans2/baseline_eval/nllb_moe_cpu_inference.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import sys
4
+ from tqdm import tqdm
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
+
8
+ langs_supported = [
9
+ "asm_Beng",
10
+ "ben_Beng",
11
+ "guj_Gujr",
12
+ "eng_Latn",
13
+ "hin_Deva",
14
+ "kas_Deva",
15
+ "kas_Arab",
16
+ "kan_Knda",
17
+ "mal_Mlym",
18
+ "mai_Deva",
19
+ "mar_Deva",
20
+ "mni_Beng",
21
+ "npi_Deva",
22
+ "ory_Orya",
23
+ "pan_Guru",
24
+ "san_Deva",
25
+ "snd_Arab",
26
+ "sat_Olck",
27
+ "tam_Taml",
28
+ "tel_Telu",
29
+ "urd_Arab",
30
+ ]
31
+
32
+
33
+ def predict(batch, tokenizer, model, bos_token_id):
34
+ encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
35
+ generated_tokens = model.generate(
36
+ **encoded_batch,
37
+ num_beams=5,
38
+ max_length=256,
39
+ min_length=0,
40
+ forced_bos_token_id=bos_token_id,
41
+ )
42
+ hypothesis = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
43
+ return hypothesis
44
+
45
+
46
+ def main(devtest_data_dir, batch_size):
47
+ # load the pre-trained NLLB tokenizer and model
48
+ model_name = "facebook/nllb-moe-54b"
49
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
50
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
51
+ model.eval()
52
+
53
+ # iterate over a list of language pairs from `devtest_data_dir`
54
+ for pair in sorted(os.listdir(devtest_data_dir)):
55
+ if "-" not in pair:
56
+ continue
57
+
58
+ src_lang, tgt_lang = pair.split("-")
59
+
60
+ # check if the source and target languages are supported
61
+ if (
62
+ src_lang not in langs_supported.keys()
63
+ or tgt_lang not in langs_supported.keys()
64
+ ):
65
+ print(f"Skipping {src_lang}-{tgt_lang} ...")
66
+ continue
67
+
68
+ # -------------------------------------------------------------------
69
+ # source to target evaluation
70
+ # -------------------------------------------------------------------
71
+ print(f"Evaluating {src_lang}-{tgt_lang} ...")
72
+
73
+ infname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}")
74
+ outfname = os.path.join(
75
+ devtest_data_dir, pair, f"test.{tgt_lang}.pred.nllb_moe"
76
+ )
77
+
78
+ with open(infname, "r") as f:
79
+ src_sents = f.read().split("\n")
80
+
81
+ add_new_line = False
82
+ if src_sents[-1] == "":
83
+ add_new_line = True
84
+ src_sents = src_sents[:-1]
85
+
86
+ # set the source language for tokenization
87
+ tokenizer.src_lang = src_lang
88
+
89
+ # process sentences in batches and generate predictions
90
+ hypothesis = []
91
+ for i in tqdm(range(0, len(src_sents), batch_size)):
92
+ start, end = i, int(min(len(src_sents), i + batch_size))
93
+ batch = src_sents[start:end]
94
+ if tgt_lang == "sat_Olck":
95
+ bos_token_id = tokenizer.lang_code_to_id["sat_Beng"]
96
+ else:
97
+ bos_token_id = tokenizer.lang_code_to_id[tgt_lang]
98
+ hypothesis += predict(batch, tokenizer, model, bos_token_id)
99
+
100
+ assert len(hypothesis) == len(src_sents)
101
+
102
+ hypothesis = [
103
+ re.sub("\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
104
+ for x in hypothesis
105
+ ]
106
+ if add_new_line:
107
+ hypothesis = hypothesis
108
+
109
+ with open(outfname, "w") as f:
110
+ f.write("\n".join(hypothesis))
111
+
112
+ # -------------------------------------------------------------------
113
+ # target to source evaluation
114
+ # -------------------------------------------------------------------
115
+ infname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}")
116
+ outfname = os.path.join(
117
+ devtest_data_dir, pair, f"test.{src_lang}.pred.nllb_moe"
118
+ )
119
+
120
+ with open(infname, "r") as f:
121
+ src_sents = f.read().split("\n")
122
+
123
+ add_new_line = False
124
+ if src_sents[-1] == "":
125
+ add_new_line = True
126
+ src_sents = src_sents[:-1]
127
+
128
+ # set the source language for tokenization
129
+ tokenizer.src_lang = "sat_Beng" if tgt_lang == "sat_Olck" else tgt_lang
130
+
131
+ # process sentences in batches and generate predictions
132
+ hypothesis = []
133
+ for i in tqdm(range(0, len(src_sents), batch_size)):
134
+ start, end = i, int(min(len(src_sents), i + batch_size))
135
+ batch = src_sents[start:end]
136
+ bos_token_id = tokenizer.lang_code_to_id[langs_supported[src_lang]]
137
+ hypothesis += predict(batch, tokenizer, model, bos_token_id)
138
+
139
+ assert len(hypothesis) == len(src_sents)
140
+
141
+ hypothesis = [
142
+ re.sub("\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
143
+ for x in hypothesis
144
+ ]
145
+ if add_new_line:
146
+ hypothesis = hypothesis
147
+
148
+ with open(outfname, "w") as f:
149
+ f.write("\n".join(hypothesis))
150
+
151
+
152
+ if __name__ == "__main__":
153
+ # expects En-X subdirectories pairs within the devtest data directory
154
+ devtest_data_dir = sys.argv[1]
155
+ batch_size = int(sys.argv[2])
156
+
157
+ main(devtest_data_dir, batch_size)
IndicTrans2/compute_comet_score.sh ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # This script computes COMET metrics and also performs significance testing on the evaluation set
4
+ # where each subdirectory contains En-X pair
5
+
6
+
7
+ echo `date`
8
+ devtest_data_dir=$1 # path to the evaluation directory
9
+ model_name=${2-"Unbabel/wmt22-comet-da"} # name of the model checkpoint
10
+
11
+ # predefined list of languages supported by COMET
12
+ langs=(asm_Beng ben_Beng guj_Gujr hin_Deva kan_Knda mal_Mlym mar_Deva ory_Orya pan_Guru tam_Taml tel_Telu urd_Arab)
13
+
14
+ # we predefine a set of systems which we consider for evaluation
15
+ # feel free to change the below line in case you want to add or remove any system
16
+ system=(google azure nllb mbart50 m2m100 it1 it2)
17
+
18
+
19
+ # iterate over the list of predefined languages
20
+ for lang in "${langs[@]}"; do
21
+
22
+ mkdir -p "$devtest_data_dir/eng_Latn-$lang/comet"
23
+
24
+ # --------------------------------------------------------------
25
+ # COMET score computation
26
+ # --------------------------------------------------------------
27
+
28
+ # iterate over the list of predefined systems
29
+ for sys in "${system[@]}"; do
30
+
31
+ echo "${sys}"
32
+
33
+ # en - indic direction
34
+ if [ -f "$devtest_data_dir/eng_Latn-$lang/test.$lang.pred.$sys" ]; then
35
+ echo "eng_Latn-${lang}"
36
+
37
+ src_fname=$devtest_data_dir/eng_Latn-$lang/test.eng_Latn
38
+ pred_fname=$devtest_data_dir/eng_Latn-$lang/test.$lang.pred.$sys
39
+ ref_fname=$devtest_data_dir/eng_Latn-$lang/test.$lang
40
+ out_fname=$devtest_data_dir/eng_Latn-$lang/comet/eng_Latn_${lang}_${sys}_comet.txt
41
+
42
+ # Compute COMET scores using the `comet-score`
43
+ comet-score -s $src_fname -t $pred_fname -r $ref_fname --gpus 1 --model $model_name --quiet --only_system > $out_fname
44
+ fi
45
+
46
+ # indic - en direction
47
+ if [ -f "$devtest_data_dir/eng_Latn-$lang/test.eng_Latn.pred.$sys" ]; then
48
+ echo "${lang}-eng_Latn"
49
+
50
+ src_fname=$devtest_data_dir/eng_Latn-$lang/test.$lang
51
+ pred_fname=$devtest_data_dir/eng_Latn-$lang/test.eng_Latn.pred.$sys
52
+ ref_fname=$devtest_data_dir/eng_Latn-$lang/test.eng_Latn
53
+ out_fname=$devtest_data_dir/eng_Latn-$lang/comet/${lang}_eng_Latn_${sys}_comet.txt
54
+
55
+ # Compute COMET scores using the `comet-score`
56
+ comet-score -s $src_fname -t $pred_fname -r $ref_fname --gpus 1 --model $model_name --quiet --only_system > $out_fname
57
+ fi
58
+
59
+ done
60
+
61
+ # --------------------------------------------------------------
62
+ # COMET significance testing
63
+ # --------------------------------------------------------------
64
+
65
+ # en - indic direction
66
+ src_fname=$devtest_data_dir/eng_Latn-$lang/test.eng_Latn
67
+ pred_fname=$devtest_data_dir/eng_Latn-$lang/test.$lang.pred.*
68
+ ref_fname=$devtest_data_dir/eng_Latn-$lang/test.$lang
69
+ out_fname=$devtest_data_dir/eng_Latn-$lang/comet/eng_Latn_${lang}_comet_stat.txt
70
+
71
+ # Compute COMET significance scores using the `comet-compare`
72
+ comet-compare -s $src_fname -t $pred_fname -r $ref_fname > $out_fname
73
+
74
+
75
+ # indic-en direction
76
+ src_fname=$devtest_data_dir/eng_Latn-$lang/test.$lang
77
+ pred_fname=$devtest_data_dir/eng_Latn-$lang/test.eng_Latn.pred.*
78
+ ref_fname=$devtest_data_dir/eng_Latn-$lang/test.eng_Latn
79
+ out_fname=$devtest_data_dir/eng_Latn-$lang/comet/${lang}_eng_Latn_comet_stat.txt
80
+
81
+ # Compute COMET significance scores using the `comet-compare`
82
+ comet-compare -s $src_fname -t $pred_fname -r $ref_fname > $out_fname
83
+
84
+ done
IndicTrans2/compute_metrics.sh ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # This script compute the evaluation metrics such as BLEU, chrF, chrF++ using the
4
+ # detokenized predictions of the translation systems using sacrebleu (version 2.3.1).
5
+ # If the target language is:
6
+ # English: directly use Moses tokenizer that is internally supported (`mteval-v13a`)
7
+ # Indic: use IndicNLP tokenizers and skip tokenization step in sacrebleu.
8
+
9
+
10
+ echo `date`
11
+ pred_fname=$1 # path to the predction file
12
+ ref_fname=$2 # path to the reference file
13
+ tgt_lang=$3 # target language
14
+
15
+
16
+ if [ $tgt_lang == 'eng_Latn' ]; then
17
+ # directly tokenize the prediction and reference files using sacrebleu and compute the metric
18
+ sacrebleu $ref_fname < $pred_fname -m bleu chrf
19
+ sacrebleu $ref_fname < $pred_fname -m chrf --chrf-word-order 2
20
+ else
21
+
22
+ # indicnlp tokenize prediction and reference files before evaluation
23
+ input_size=`python scripts/preprocess_translate.py $ref_fname $ref_fname.tok $tgt_lang false false`
24
+ input_size=`python scripts/preprocess_translate.py $pred_fname $pred_fname.tok $tgt_lang false false`
25
+
26
+ # since we are tokenizing with indicnlp separately, we are setting tokenize to none here
27
+ sacrebleu --tokenize none $ref_fname.tok < $pred_fname.tok -m bleu chrf
28
+ sacrebleu --tokenize none $ref_fname.tok < $pred_fname.tok -m chrf --chrf-word-order 2
29
+ fi
IndicTrans2/compute_metrics_significance.sh ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # This script performs significance testing for metrics such as BLEU, chrF++ using sacrebleu on the evaluation set
4
+ # where each subdirectory contains En-X pair
5
+
6
+
7
+ echo `date`
8
+ devtest_data_dir=$1 # path to the evaluation directory
9
+
10
+ # we predefine a set of systems which we consider for evaluation
11
+ # feel free to change the below line in case you want to add or remove any system
12
+ system=(google azure nllb mbart50 m2m100 it1 it2)
13
+
14
+
15
+ # get a list of language pairs in the `devtest_data_dir`
16
+ pairs=$(ls -d $devtest_data_dir/eng_Latn-* | sort)
17
+
18
+
19
+ # iterate over each language pair
20
+ for pair in ${pairs[@]}; do
21
+ # extract the source and target languages from the pair name
22
+ pair=$(basename $pair)
23
+ src_lang=$(echo "$pair" | cut -d "-" -f 1)
24
+ tgt_lang=$(echo "$pair" | cut -d "-" -f 2)
25
+
26
+ if [[ $src_lang == "eng_Latn" ]]; then
27
+
28
+ # ----------------------------------------------------------------------
29
+ # en - indic direction
30
+ # ----------------------------------------------------------------------
31
+ echo "${src_lang} - ${tgt_lang}"
32
+
33
+ # find all the prediction files for different systems and tokenize it using IndicNLP
34
+ pred_fnames=$devtest_data_dir/$pair/test.${tgt_lang}.pred.*
35
+ ref_fname=$devtest_data_dir/$pair/test.${tgt_lang}
36
+
37
+ for pred_fname in $(find . -type f -name $pred_fnames); do
38
+ input_size=`python scripts/preprocess_translate.py $pred_fname $pred_fname.tok $tgt_lang false false`
39
+ done
40
+
41
+ input_size=`python scripts/preprocess_translate.py $ref_fname $ref_fname.tok $tgt_lang false false`
42
+
43
+ ref_fname=$devtest_data_dir/$pair/test.${tgt_lang}.tok
44
+ it2_fname=$devtest_data_dir/$pair/test.${tgt_lang}.pred.it2.tok
45
+ sys_fnames=$devtest_data_dir/$pair/test.${tgt_lang}.pred.*.tok
46
+ bleu_out_fname=$devtest_data_dir/$pair/${src_lang}_${tgt_lang}_bleu_significance.txt
47
+ chrF_out_fname=$devtest_data_dir/$pair/${src_lang}_${tgt_lang}_chrF++_significance.txt
48
+
49
+ sacrebleu --tokenize none $ref_fname -i $it2_fname $sys_fnames --paired-bs -m bleu --format text > $bleu_out_fname
50
+ sacrebleu --tokenize none $it2_fname $sys_fnames --paired-bs -m chrf --chrf-word-order 2 --format text > $chrF_out_fname
51
+
52
+ # ----------------------------------------------------------------------
53
+ # indic - en direction
54
+ # ----------------------------------------------------------------------
55
+ echo "${tgt_lang} - ${src_lang}"
56
+
57
+ ref_fname=$devtest_data_dir/$pair/test.${src_lang}
58
+ it2_fname=$devtest_data_dir/$pair/test.${src_lang}.pred.it2
59
+ sys_fnames=$devtest_data_dir/$pair/test.${src_lang}.pred.*
60
+ bleu_out_fname=$devtest_data_dir/$pair/${tgt_lang}_${src_lang}_bleu_significance.txt
61
+ chrF_out_fname=$devtest_data_dir/$pair/${tgt_lang}_${src_lang}_chrF++_significance.txt
62
+
63
+ sacrebleu --tokenize none $ref_fname -i $it2_fname $sys_fnames --paired-bs -m bleu --format text > $bleu_out_fname
64
+ sacrebleu --tokenize none $it2_fname $sys_fnames --paired-bs -m chrf --chrf-word-order 2 --format text > $chrF_out_fname
65
+
66
+ fi
IndicTrans2/eval.sh ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # This script evaluates the performance of a machine translation system
4
+ # on a evaluation set in forward direction. For example, if the evaluation set
5
+ # consists of language pairs, such as En-X, where En represents the English language
6
+ # and X represents the target Indic language then this script accesses the translation
7
+ # system from the English language (En) to the target Indic language (X) direction.
8
+
9
+
10
+ echo `date`
11
+ devtest_data_dir=$1 # path to the evaluation directory
12
+ ckpt_dir=$2 # path to the checkpoint directory
13
+ system=${3:-"it2"} # name of the machine translation system
14
+
15
+
16
+ # get a list of language pairs in the `devtest_data_dir`
17
+ pairs=$(ls -d $devtest_data_dir/* | sort)
18
+
19
+
20
+ # iterate over each language pair
21
+ for pair in ${pairs[@]}; do
22
+ # extract the source and target languages from the pair name
23
+ pair=$(basename $pair)
24
+ src_lang=$(echo "$pair" | cut -d "-" -f 1)
25
+ tgt_lang=$(echo "$pair" | cut -d "-" -f 2)
26
+
27
+ src_fname=$devtest_data_dir/$src_lang-$tgt_lang/test.$src_lang
28
+ tgt_fname=$devtest_data_dir/$src_lang-$tgt_lang/test.$tgt_lang
29
+
30
+ # check if the source and target files exists
31
+ if [ -f "$src_fname" ] && [ -f "$tgt_fname" ]; then
32
+ echo "Evaluating $src_lang-$tgt_lang ..."
33
+ else
34
+ echo "Skipping $src_lang-$tgt_lang ..."
35
+ continue
36
+ fi
37
+
38
+ # generate translations if the system name contains "it2"
39
+ if [[ $system == *"it2"* ]]; then
40
+ echo "Generating Translations"
41
+ bash joint_translate.sh $src_fname $tgt_fname.pred.$system $src_lang $tgt_lang $ckpt_dir
42
+ fi
43
+
44
+ # compute automatic string-based metrics if the prediction exists for the system
45
+ if [[ -f "${tgt_fname}.pred.${system}" ]]; then
46
+ echo "Computing Metrics"
47
+ bash compute_metrics.sh $tgt_fname.pred.$system $tgt_fname $tgt_lang > $devtest_data_dir/$src_lang-$tgt_lang/${src_lang}_${tgt_lang}_${system}_scores.txt
48
+ fi
49
+
50
+ # remove the intermediate files
51
+ rm -rf $tgt_fname.pred.$system.*
52
+ rm -rf $devtest_data_dir/$src_lang-$tgt_lang/*.tok
53
+
54
+ done
IndicTrans2/eval_rev.sh ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # This script evaluates the performance of a machine translation system
4
+ # on a evaluation set in forward direction. For example, if the evaluation set
5
+ # consists of language pairs, such as En-X, where En represents the English language
6
+ # and X represents the target Indic language then this script accesses the translation
7
+ # system from the target Indic language (X) to the English language (En) direction.
8
+
9
+
10
+ echo `date`
11
+ devtest_data_dir=$1 # path to the evaluation directory
12
+ ckpt_dir=$2 # path to the checkpoint directory
13
+ system=${3:-"it2"} # name of the machine translation system
14
+
15
+
16
+ # get a list of language pairs in the `devtest_data_dir`
17
+ pairs=$(ls -d $devtest_data_dir/* | sort)
18
+
19
+
20
+ # iterate over each language pair
21
+ for pair in ${pairs[@]}; do
22
+ # extract the source and target languages from the pair name
23
+ pair=$(basename $pair)
24
+ src_lang=$(echo "$pair" | cut -d "-" -f 1)
25
+ tgt_lang=$(echo "$pair" | cut -d "-" -f 2)
26
+
27
+ src_fname=$devtest_data_dir/$src_lang-$tgt_lang/test.$tgt_lang
28
+ tgt_fname=$devtest_data_dir/$src_lang-$tgt_lang/test.$src_lang
29
+
30
+ # check if the source and target files exists
31
+ # in this case, we flip the actual target file as source and vice-versa
32
+ if [ -f "$src_fname" ] && [ -f "$tgt_fname" ]; then
33
+ echo "Evaluating $src_lang-$tgt_lang ..."
34
+ else
35
+ echo "Skipping $src_lang-$tgt_lang ..."
36
+ continue
37
+ fi
38
+
39
+ # generate translations if the system name contains "it2"
40
+ if [[ $system == *"it2"* ]]; then
41
+ echo "Generating Translations"
42
+ bash joint_translate.sh $src_fname $tgt_fname.pred.$system $tgt_lang $src_lang $ckpt_dir
43
+ fi
44
+
45
+ # compute automatic string-based metrics if the prediction exists for the system
46
+ if [[ -f "${tgt_fname}.pred.${system}" ]]; then
47
+ echo "Computing Metrics"
48
+ bash compute_metrics.sh $tgt_fname.pred.$system $tgt_fname $src_lang > $devtest_data_dir/$src_lang-$tgt_lang/${tgt_lang}_${src_lang}_${system}_scores.txt
49
+ fi
50
+
51
+ # remove the intermediate files
52
+ rm -rf $tgt_fname.pred.$system.*
53
+ rm -rf $devtest_data_dir/$src_lang-$tgt_lang/*.tok
54
+
55
+ done
IndicTrans2/finetune.sh ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # This script finetunes the pretrained translation model on the binarized data using fairseq.
4
+
5
+
6
+ echo `date`
7
+ exp_dir=$1 # path of the experiment directory
8
+ model_arch=${2:-"transformer_18_18"} # model architecture (defaults to `transformer_18_18`)
9
+ pretrained_ckpt=$3 # path to the pretrained checkpoint `.pt` file
10
+
11
+
12
+ fairseq-train $exp_dir/final_bin \
13
+ --max-source-positions=256 \
14
+ --max-target-positions=256 \
15
+ --source-lang=SRC \
16
+ --target-lang=TGT \
17
+ --max-update=1000000 \
18
+ --save-interval-updates=1000 \
19
+ --arch=$model_arch \
20
+ --activation-fn gelu \
21
+ --criterion=label_smoothed_cross_entropy \
22
+ --label-smoothing=0.1 \
23
+ --optimizer adam \
24
+ --adam-betas "(0.9, 0.98)" \
25
+ --lr-scheduler=inverse_sqrt \
26
+ --clip-norm 1.0 \
27
+ --warmup-init-lr 1e-07 \
28
+ --lr 3e-5 \
29
+ --warmup-updates 2000 \
30
+ --dropout 0.2 \
31
+ --save-dir $exp_dir/model \
32
+ --keep-last-epochs 5 \
33
+ --keep-interval-updates 3 \
34
+ --patience 10 \
35
+ --skip-invalid-size-inputs-valid-test \
36
+ --fp16 \
37
+ --user-dir model_configs \
38
+ --update-freq=4 \
39
+ --distributed-world-size 8 \
40
+ --num-workers 24 \
41
+ --max-tokens 1024 \
42
+ --eval-bleu \
43
+ --eval-bleu-args "{\"beam\": 1, \"lenpen\": 1.0, \"max_len_a\": 1.2, \"max_len_b\": 10}" \
44
+ --eval-bleu-detok moses \
45
+ --eval-bleu-remove-bpe sentencepiece \
46
+ --eval-bleu-print-samples \
47
+ --best-checkpoint-metric bleu \
48
+ --maximize-best-checkpoint-metric \
49
+ --restore-file $pretrained_ckpt \
50
+ --reset-lr-scheduler \
51
+ --reset-meters \
52
+ --reset-dataloader \
53
+ --reset-optimizer \
54
+ --task translation
IndicTrans2/huggingface_interface/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ IndicTransTokenizer
IndicTrans2/huggingface_interface/README.md ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # IndicTrans2 HF Compatible Models
2
+
3
+ [![colab link](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI4Bharat/IndicTrans2/blob/main/huggingface_interface/colab_inference.ipynb)
4
+
5
+ In this section, we provide details on how to use our [IndicTrans2](https://github.com/AI4Bharat/IndicTrans2) models which were originally trained with the [fairseq](https://github.com/facebookresearch/fairseq) to [HuggingFace transformers](https://huggingface.co/docs/transformers/index) for inference purpose. Our scripts for HuggingFace compatible models are adapted from [M2M100 repository](https://github.com/huggingface/transformers/tree/main/src/transformers/models/m2m_100).
6
+
7
+ > Note: We have migrated IndicTrans2 tokenizer for HF compatible IndicTrans2 models to [IndicTransToolkit](https://github.com/VarunGumma/IndicTransToolkit) and will be maintained separately there from now onwards. This is automatically installed when you call `install.sh` script in `huggingface_interface`.
8
+
9
+ ### Setup
10
+
11
+ To get started, follow these steps to set up the environment:
12
+
13
+ ```
14
+ # Clone the github repository and navigate to the project directory.
15
+ git clone https://github.com/AI4Bharat/IndicTrans2
16
+ cd IndicTrans2/huggingface_interface
17
+
18
+ # Install all the dependencies and requirements associated with the project for running HF compatible models.
19
+ source install.sh
20
+ ```
21
+
22
+ > Note: The `install.sh` script in this directory is specifically for running HF compatible models for inference.
23
+
24
+ ### Converting
25
+
26
+ In order to convert the fairseq checkpoint to a PyTorch checkpoint that is compatible with HuggingFace Transformers, use the following command:
27
+
28
+ ```bash
29
+ python3 convert_indictrans_checkpoint_to_pytorch.py --fairseq_path <fairseq_checkpoint_best.pt> --pytorch_dump_folder_path <hf_output_dir>
30
+ ```
31
+
32
+ - `<fairseq_checkpoint_best.pt>`: path to the fairseq `checkpoint_best.pt` that needs to be converted to HF compatible models
33
+ - `<hf_output_dir>`: path to the output directory where the HF compatible models will be saved
34
+
35
+ ### Models
36
+
37
+ | Model | 🤗 HuggingFace Checkpoints |
38
+ | -------------------------------- | ----------------------------------------------------------------------------------------------------------------- |
39
+ | En-Indic | [ai4bharat/indictrans2-en-indic-1B](https://huggingface.co/ai4bharat/indictrans2-en-indic-1B) |
40
+ | Indic-En | [ai4bharat/indictrans2-indic-en-1B](https://huggingface.co/ai4bharat/indictrans2-indic-en-1B) |
41
+ | Distilled En-Indic | [ai4bharat/indictrans2-en-indic-dist-200M](https://huggingface.co/ai4bharat/indictrans2-en-indic-dist-200M) |
42
+ | Distilled Indic-En | [ai4bharat/indictrans2-indic-en-dist-200M](https://huggingface.co/ai4bharat/indictrans2-indic-en-dist-200M) |
43
+ | Indic-Indic (Stitched) | [ai4bharat/indictrans2-indic-indic-1B](https://huggingface.co/ai4bharat/indictrans2-indic-indic-1B) |
44
+ | Distilled Indic-Indic (Stitched) | [ai4bharat/indictrans2-indic-indic-dist-320M](https://huggingface.co/ai4bharat/indictrans2-indic-indic-dist-320M) |
45
+
46
+ ### Inference
47
+
48
+ With the conversion complete, you can now perform inference using the HuggingFace Transformers.
49
+
50
+ You can start with the provided `example.py` script and customize it for your specific translation use case:
51
+
52
+ ```bash
53
+ python3 example.py
54
+ ```
55
+
56
+ Feel free to modify the `example.py` script to suit your translation needs.
57
+
58
+ ### Fine-tuning with LoRA
59
+
60
+ Before starting with fine-tuning IndicTrans2 models, you will need to restructure the training data in the following format.
61
+
62
+ ```
63
+ en-indic-exp
64
+ ├── train
65
+ │ ├── eng_Latn-asm_Beng
66
+ │ │ ├── train.eng_Latn
67
+ │ │ └── train.asm_Beng
68
+ │ ├── eng_Latn-ben_Beng
69
+ │ │ └── ...
70
+ │ └── {src_lang}-{tgt_lang}
71
+ │ ├── train.{src_lang}
72
+ │ └── train.{tgt_lang}
73
+ └── dev
74
+ ├── eng_Latn-asm_Beng
75
+ │ ├── dev.eng_Latn
76
+ │ └── dev.asm_Beng
77
+ ├── eng_Latn-ben_Beng
78
+ │ └── ...
79
+ └── {src_lang}-{tgt_lang}
80
+ ├── dev.{src_lang}
81
+ └── dev.{tgt_lang}
82
+ ```
83
+
84
+ Once you have data ready in above specified format, use the following command.
85
+
86
+ ```bash
87
+ bash train_lora.sh <data_dir> <model_name> <output_dir> <direction> <src_lang_list> <tgt_lang_list>
88
+ ```
89
+
90
+ We recommend you to refer to `train_lora.sh` for defaults arguments for fine-tuning. Please note that the specified hyperparameters may not be optimal and might require tuning for your use case.
91
+
92
+ ### Inference with LoRA
93
+
94
+ You can load the LoRA adapters with the base model for inference by modifying the model initialization in `example.py` script.
95
+
96
+ ```python
97
+ from transformers import AutoModelForSeq2SeqLM
98
+ from peft import PeftConfig, PeftModel
99
+
100
+ base_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B" # you will need to change as per your use case
101
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(base_ckpt_dir, trust_remote_code=True)
102
+ lora_model = PeftModel.from_pretrained(base_model, lora_ckpt_dir)
103
+ ```
104
+
105
+ > Note: Please feel free to open issues on the GitHub repo in case of any queries/issues.
106
+
107
+ ### Citation
108
+
109
+ ```bibtex
110
+ @article{gala2023indictrans,
111
+ title={IndicTrans2: Towards High-Quality and Accessible Machine Translation Models for all 22 Scheduled Indian Languages},
112
+ author={Jay Gala and Pranjal A Chitale and A K Raghavan and Varun Gumma and Sumanth Doddapaneni and Aswanth Kumar M and Janki Atul Nawale and Anupama Sujatha and Ratish Puduppully and Vivek Raghavan and Pratyush Kumar and Mitesh M Khapra and Raj Dabre and Anoop Kunchukuttan},
113
+ journal={Transactions on Machine Learning Research},
114
+ issn={2835-8856},
115
+ year={2023},
116
+ url={https://openreview.net/forum?id=vfT4YuzAYA},
117
+ note={}
118
+ }
119
+ ```
IndicTrans2/huggingface_interface/colab_inference.ipynb ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "8Aa-nRCzPVdF"
7
+ },
8
+ "source": [
9
+ "# IndicTrans2 HF Inference\n",
10
+ "\n",
11
+ "We provide an example notebook on how to use our IndicTrans2 models which were originally trained with the fairseq to HuggingFace transformers for inference purpose.\n"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "markdown",
16
+ "metadata": {
17
+ "id": "Cfsv02IeP2It"
18
+ },
19
+ "source": [
20
+ "## Setup\n",
21
+ "\n",
22
+ "Please run the cells below to install the necessary dependencies.\n"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "metadata": {
29
+ "id": "qKcYlUZYGLrt"
30
+ },
31
+ "outputs": [],
32
+ "source": [
33
+ "%%capture\n",
34
+ "!git clone https://github.com/AI4Bharat/IndicTrans2.git"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "metadata": {
41
+ "id": "U3vs7FkIGSxK"
42
+ },
43
+ "outputs": [],
44
+ "source": [
45
+ "%%capture\n",
46
+ "%cd /content/IndicTrans2/huggingface_interface"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": null,
52
+ "metadata": {
53
+ "id": "ddkRAXQ2Git0"
54
+ },
55
+ "outputs": [],
56
+ "source": [
57
+ "%%capture\n",
58
+ "!python3 -m pip install nltk sacremoses pandas regex mock transformers>=4.33.2 mosestokenizer\n",
59
+ "!python3 -c \"import nltk; nltk.download('punkt')\"\n",
60
+ "!python3 -m pip install bitsandbytes scipy accelerate datasets\n",
61
+ "!python3 -m pip install sentencepiece\n",
62
+ "\n",
63
+ "!git clone https://github.com/VarunGumma/IndicTransToolkit.git\n",
64
+ "%cd IndicTransToolkit\n",
65
+ "!python3 -m pip install --editable ./\n",
66
+ "%cd .."
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "markdown",
71
+ "metadata": {
72
+ "id": "hjN7ub1tO33H"
73
+ },
74
+ "source": [
75
+ "**IMPORTANT : Restart your run-time first and then run the cells below.**"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "markdown",
80
+ "metadata": {
81
+ "id": "_SLBIw6rQB-0"
82
+ },
83
+ "source": [
84
+ "## Inference\n"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "metadata": {
91
+ "id": "fYczM2U6G1Zv"
92
+ },
93
+ "outputs": [],
94
+ "source": [
95
+ "import torch\n",
96
+ "from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig, AutoTokenizer\n",
97
+ "from IndicTransToolkit import IndicProcessor\n",
98
+ "\n",
99
+ "BATCH_SIZE = 4\n",
100
+ "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
101
+ "quantization = None"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "metadata": {
108
+ "id": "xj1WCNjuHG-d"
109
+ },
110
+ "outputs": [],
111
+ "source": [
112
+ "def initialize_model_and_tokenizer(ckpt_dir, quantization):\n",
113
+ " if quantization == \"4-bit\":\n",
114
+ " qconfig = BitsAndBytesConfig(\n",
115
+ " load_in_4bit=True,\n",
116
+ " bnb_4bit_use_double_quant=True,\n",
117
+ " bnb_4bit_compute_dtype=torch.bfloat16,\n",
118
+ " )\n",
119
+ " elif quantization == \"8-bit\":\n",
120
+ " qconfig = BitsAndBytesConfig(\n",
121
+ " load_in_8bit=True,\n",
122
+ " bnb_8bit_use_double_quant=True,\n",
123
+ " bnb_8bit_compute_dtype=torch.bfloat16,\n",
124
+ " )\n",
125
+ " else:\n",
126
+ " qconfig = None\n",
127
+ "\n",
128
+ " tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)\n",
129
+ " model = AutoModelForSeq2SeqLM.from_pretrained(\n",
130
+ " ckpt_dir,\n",
131
+ " trust_remote_code=True,\n",
132
+ " low_cpu_mem_usage=True,\n",
133
+ " quantization_config=qconfig,\n",
134
+ " )\n",
135
+ "\n",
136
+ " if qconfig == None:\n",
137
+ " model = model.to(DEVICE)\n",
138
+ " if DEVICE == \"cuda\":\n",
139
+ " model.half()\n",
140
+ "\n",
141
+ " model.eval()\n",
142
+ "\n",
143
+ " return tokenizer, model\n",
144
+ "\n",
145
+ "\n",
146
+ "def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):\n",
147
+ " translations = []\n",
148
+ " for i in range(0, len(input_sentences), BATCH_SIZE):\n",
149
+ " batch = input_sentences[i : i + BATCH_SIZE]\n",
150
+ "\n",
151
+ " # Preprocess the batch and extract entity mappings\n",
152
+ " batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)\n",
153
+ "\n",
154
+ " # Tokenize the batch and generate input encodings\n",
155
+ " inputs = tokenizer(\n",
156
+ " batch,\n",
157
+ " truncation=True,\n",
158
+ " padding=\"longest\",\n",
159
+ " return_tensors=\"pt\",\n",
160
+ " return_attention_mask=True,\n",
161
+ " ).to(DEVICE)\n",
162
+ "\n",
163
+ " # Generate translations using the model\n",
164
+ " with torch.no_grad():\n",
165
+ " generated_tokens = model.generate(\n",
166
+ " **inputs,\n",
167
+ " use_cache=True,\n",
168
+ " min_length=0,\n",
169
+ " max_length=256,\n",
170
+ " num_beams=5,\n",
171
+ " num_return_sequences=1,\n",
172
+ " )\n",
173
+ "\n",
174
+ " # Decode the generated tokens into text\n",
175
+ "\n",
176
+ " with tokenizer.as_target_tokenizer():\n",
177
+ " generated_tokens = tokenizer.batch_decode(\n",
178
+ " generated_tokens.detach().cpu().tolist(),\n",
179
+ " skip_special_tokens=True,\n",
180
+ " clean_up_tokenization_spaces=True,\n",
181
+ " )\n",
182
+ "\n",
183
+ " # Postprocess the translations, including entity replacement\n",
184
+ " translations += ip.postprocess_batch(generated_tokens, lang=tgt_lang)\n",
185
+ "\n",
186
+ " del inputs\n",
187
+ " torch.cuda.empty_cache()\n",
188
+ "\n",
189
+ " return translations"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "markdown",
194
+ "metadata": {
195
+ "id": "erNCuZTEMt49"
196
+ },
197
+ "source": [
198
+ "### English to Indic Example\n"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": null,
204
+ "metadata": {
205
+ "colab": {
206
+ "base_uri": "https://localhost:8080/"
207
+ },
208
+ "id": "6OG3Bw-sHnf3",
209
+ "outputId": "a204f50e-9456-4fb1-900a-e60680b97b99"
210
+ },
211
+ "outputs": [
212
+ {
213
+ "name": "stdout",
214
+ "output_type": "stream",
215
+ "text": [
216
+ "\n",
217
+ "eng_Latn - hin_Deva\n",
218
+ "eng_Latn: When I was young, I used to go to the park every day.\n",
219
+ "hin_Deva: जब मैं छोटा था, मैं हर दिन पार्क जाता था।\n",
220
+ "eng_Latn: He has many old books, which he inherited from his ancestors.\n",
221
+ "hin_Deva: उनके पास कई पुरानी किताबें हैं, जो उन्हें अपने पूर्वजों से विरासत में मिली हैं।\n",
222
+ "eng_Latn: I can't figure out how to solve my problem.\n",
223
+ "hin_Deva: मुझे समझ नहीं आ रहा है कि मैं अपनी समस्या का समाधान कैसे करूं।\n",
224
+ "eng_Latn: She is very hardworking and intelligent, which is why she got all the good marks.\n",
225
+ "hin_Deva: वह बहुत मेहनती और बुद्धिमान है, यही कारण है कि उसे सभी अच्छे अंक मिले।\n",
226
+ "eng_Latn: We watched a new movie last week, which was very inspiring.\n",
227
+ "hin_Deva: हमने पिछले हफ्ते एक नई फिल्म देखी, जो बहुत प्रेरणादायक थी।\n",
228
+ "eng_Latn: If you had met me at that time, we would have gone out to eat.\n",
229
+ "hin_Deva: अगर आप उस समय मुझसे मिलते तो हम बाहर खाना खाने जाते।\n",
230
+ "eng_Latn: She went to the market with her sister to buy a new sari.\n",
231
+ "hin_Deva: वह अपनी बहन के साथ नई साड़ी खरीदने के लिए बाजार गई थी।\n",
232
+ "eng_Latn: Raj told me that he is going to his grandmother's house next month.\n",
233
+ "hin_Deva: राज ने मुझे बताया कि वह अगले महीने अपनी दादी के घर जा रहा है।\n",
234
+ "eng_Latn: All the kids were having fun at the party and were eating lots of sweets.\n",
235
+ "hin_Deva: पार्टी में सभी बच्चे खूब मस्ती कर रहे थे और खूब मिठाइयां खा रहे थे।\n",
236
+ "eng_Latn: My friend has invited me to his birthday party, and I will give him a gift.\n",
237
+ "hin_Deva: मेरे दोस्त ने मुझे अपने जन्मदिन की पार्टी में आमंत्रित किया है, और मैं उसे एक उपहार दूंगा।\n"
238
+ ]
239
+ }
240
+ ],
241
+ "source": [
242
+ "en_indic_ckpt_dir = \"ai4bharat/indictrans2-en-indic-1B\" # ai4bharat/indictrans2-en-indic-dist-200M\n",
243
+ "en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir, quantization)\n",
244
+ "\n",
245
+ "ip = IndicProcessor(inference=True)\n",
246
+ "\n",
247
+ "en_sents = [\n",
248
+ " \"When I was young, I used to go to the park every day.\",\n",
249
+ " \"He has many old books, which he inherited from his ancestors.\",\n",
250
+ " \"I can't figure out how to solve my problem.\",\n",
251
+ " \"She is very hardworking and intelligent, which is why she got all the good marks.\",\n",
252
+ " \"We watched a new movie last week, which was very inspiring.\",\n",
253
+ " \"If you had met me at that time, we would have gone out to eat.\",\n",
254
+ " \"She went to the market with her sister to buy a new sari.\",\n",
255
+ " \"Raj told me that he is going to his grandmother's house next month.\",\n",
256
+ " \"All the kids were having fun at the party and were eating lots of sweets.\",\n",
257
+ " \"My friend has invited me to his birthday party, and I will give him a gift.\",\n",
258
+ "]\n",
259
+ "\n",
260
+ "src_lang, tgt_lang = \"eng_Latn\", \"hin_Deva\"\n",
261
+ "hi_translations = batch_translate(en_sents, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer, ip)\n",
262
+ "\n",
263
+ "print(f\"\\n{src_lang} - {tgt_lang}\")\n",
264
+ "for input_sentence, translation in zip(en_sents, hi_translations):\n",
265
+ " print(f\"{src_lang}: {input_sentence}\")\n",
266
+ " print(f\"{tgt_lang}: {translation}\")\n",
267
+ "\n",
268
+ "# flush the models to free the GPU memory\n",
269
+ "del en_indic_tokenizer, en_indic_model"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "markdown",
274
+ "metadata": {
275
+ "id": "OM_1pbPtMpV9"
276
+ },
277
+ "source": [
278
+ "### Indic to English Example"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": null,
284
+ "metadata": {
285
+ "colab": {
286
+ "base_uri": "https://localhost:8080/"
287
+ },
288
+ "id": "PLCEWJKvGG9I",
289
+ "outputId": "ab9d8726-67c7-490b-ecb3-208df1c0f741"
290
+ },
291
+ "outputs": [
292
+ {
293
+ "name": "stdout",
294
+ "output_type": "stream",
295
+ "text": [
296
+ "\n",
297
+ "hin_Deva - eng_Latn\n",
298
+ "hin_Deva: जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।\n",
299
+ "eng_Latn: When I was young, I used to go to the park every day.\n",
300
+ "hin_Deva: उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।\n",
301
+ "eng_Latn: She has a lot of old books, which she inherited from her grandparents.\n",
302
+ "hin_Deva: मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।\n",
303
+ "eng_Latn: I don't know how to find a solution to my problem.\n",
304
+ "hin_Deva: वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।\n",
305
+ "eng_Latn: He is very hardworking and understanding, so he got all the good marks.\n",
306
+ "hin_Deva: हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।\n",
307
+ "eng_Latn: We saw a new movie last week that was very inspiring.\n",
308
+ "hin_Deva: अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।\n",
309
+ "eng_Latn: If you'd given me a pass at that time, we'd have gone out to eat.\n",
310
+ "hin_Deva: वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।\n",
311
+ "eng_Latn: She had gone to the market with her sister so that she could buy a new sari.\n",
312
+ "hin_Deva: राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।\n",
313
+ "eng_Latn: Raj told me that he was going to his grandmother's house next month.\n",
314
+ "hin_Deva: सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।\n",
315
+ "eng_Latn: All the children were having fun at the party and eating a lot of sweets.\n",
316
+ "hin_Deva: मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।\n",
317
+ "eng_Latn: My friend has invited me to her birthday party, and I'll give her a present.\n"
318
+ ]
319
+ }
320
+ ],
321
+ "source": [
322
+ "indic_en_ckpt_dir = \"ai4bharat/indictrans2-indic-en-1B\" # ai4bharat/indictrans2-indic-en-dist-200M\n",
323
+ "indic_en_tokenizer, indic_en_model = initialize_model_and_tokenizer(indic_en_ckpt_dir, quantization)\n",
324
+ "\n",
325
+ "ip = IndicProcessor(inference=True)\n",
326
+ "\n",
327
+ "hi_sents = [\n",
328
+ " \"जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।\",\n",
329
+ " \"उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।\",\n",
330
+ " \"मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।\",\n",
331
+ " \"वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।\",\n",
332
+ " \"हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।\",\n",
333
+ " \"अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।\",\n",
334
+ " \"वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।\",\n",
335
+ " \"राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।\",\n",
336
+ " \"सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।\",\n",
337
+ " \"मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।\",\n",
338
+ "]\n",
339
+ "src_lang, tgt_lang = \"hin_Deva\", \"eng_Latn\"\n",
340
+ "en_translations = batch_translate(hi_sents, src_lang, tgt_lang, indic_en_model, indic_en_tokenizer, ip)\n",
341
+ "\n",
342
+ "\n",
343
+ "print(f\"\\n{src_lang} - {tgt_lang}\")\n",
344
+ "for input_sentence, translation in zip(hi_sents, en_translations):\n",
345
+ " print(f\"{src_lang}: {input_sentence}\")\n",
346
+ " print(f\"{tgt_lang}: {translation}\")\n",
347
+ "\n",
348
+ "# flush the models to free the GPU memory\n",
349
+ "del indic_en_tokenizer, indic_en_model"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "markdown",
354
+ "metadata": {
355
+ "id": "7VCAkyKBGtnV"
356
+ },
357
+ "source": [
358
+ "### Indic to Indic Example\n"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "code",
363
+ "execution_count": null,
364
+ "metadata": {
365
+ "colab": {
366
+ "base_uri": "https://localhost:8080/"
367
+ },
368
+ "id": "_7TxTTCoKjti",
369
+ "outputId": "df1a750b-0f32-478d-cfc9-e445f669f3ee"
370
+ },
371
+ "outputs": [
372
+ {
373
+ "name": "stdout",
374
+ "output_type": "stream",
375
+ "text": [
376
+ "\n",
377
+ "hin_Deva - mar_Deva\n",
378
+ "hin_Deva: जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।\n",
379
+ "mar_Deva: मी लहान होतो तेव्हा मी दररोज उद्यानाला जायचे.\n",
380
+ "hin_Deva: उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।\n",
381
+ "mar_Deva: तिच्याकडे बरीच जुनी पुस्तके आहेत, जी तिला तिच्या आजोबांकडून वारशाने मिळाली आहेत.\n",
382
+ "hin_Deva: मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।\n",
383
+ "mar_Deva: माझ्या समस्येवर तोडगा कसा काढायचा हे मला समजत नाही.\n",
384
+ "hin_Deva: वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।\n",
385
+ "mar_Deva: तो खूप मेहनती आणि बुद्धिमान आहे, त्यामुळे त्याला सर्व चांगले गुण मिळाले.\n",
386
+ "hin_Deva: हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।\n",
387
+ "mar_Deva: आम्ही गेल्या आठवड्यात एक नवीन चित्रपट पाहिला जो खूप प्रेरणादायी होता.\n",
388
+ "hin_Deva: अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।\n",
389
+ "mar_Deva: जर तुम्हाला त्या वेळी मला पास मिळाला तर आम्ही बाहेर जेवायला जाऊ.\n",
390
+ "hin_Deva: वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।\n",
391
+ "mar_Deva: ती तिच्या बहिणीसोबत बाजारात गेली होती जेणेकरून ती नवीन साडी खरेदी करू शकेल.\n",
392
+ "hin_Deva: राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।\n",
393
+ "mar_Deva: राजने मला सांगितले की तो पुढच्या महिन्यात त्याच्या आजीच्या घरी जात आहे.\n",
394
+ "hin_Deva: सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।\n",
395
+ "mar_Deva: सर्व मुले पार्टीचा आनंद घेत होती आणि भरपूर मिठाई खात होती.\n",
396
+ "hin_Deva: मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।\n",
397
+ "mar_Deva: माझ्या मित्राने मला त्याच्या वाढदिवसाच्या मेजवानीसाठी आमंत्रित केले आहे आणि मी त्याला भेटवस्तू देईन.\n"
398
+ ]
399
+ }
400
+ ],
401
+ "source": [
402
+ "indic_indic_ckpt_dir = \"ai4bharat/indictrans2-indic-indic-1B\" # ai4bharat/indictrans2-indic-indic-dist-320M\n",
403
+ "indic_indic_tokenizer, indic_indic_model = initialize_model_and_tokenizer(indic_indic_ckpt_dir, quantization)\n",
404
+ "\n",
405
+ "ip = IndicProcessor(inference=True)\n",
406
+ "\n",
407
+ "hi_sents = [\n",
408
+ " \"जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।\",\n",
409
+ " \"उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।\",\n",
410
+ " \"मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।\",\n",
411
+ " \"वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।\",\n",
412
+ " \"हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।\",\n",
413
+ " \"अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।\",\n",
414
+ " \"वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।\",\n",
415
+ " \"राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।\",\n",
416
+ " \"सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।\",\n",
417
+ " \"मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।\",\n",
418
+ "]\n",
419
+ "src_lang, tgt_lang = \"hin_Deva\", \"mar_Deva\"\n",
420
+ "mr_translations = batch_translate(hi_sents, src_lang, tgt_lang, indic_indic_model, indic_indic_tokenizer, ip)\n",
421
+ "\n",
422
+ "print(f\"\\n{src_lang} - {tgt_lang}\")\n",
423
+ "for input_sentence, translation in zip(hi_sents, mr_translations):\n",
424
+ " print(f\"{src_lang}: {input_sentence}\")\n",
425
+ " print(f\"{tgt_lang}: {translation}\")\n",
426
+ "\n",
427
+ "# flush the models to free the GPU memory\n",
428
+ "del indic_indic_tokenizer, indic_indic_model"
429
+ ]
430
+ },
431
+ {
432
+ "cell_type": "code",
433
+ "execution_count": null,
434
+ "metadata": {
435
+ "id": "uyxXpt--Ma6n"
436
+ },
437
+ "outputs": [],
438
+ "source": []
439
+ }
440
+ ],
441
+ "metadata": {
442
+ "accelerator": "GPU",
443
+ "colab": {
444
+ "gpuType": "T4",
445
+ "provenance": [],
446
+ "toc_visible": true
447
+ },
448
+ "kernelspec": {
449
+ "display_name": "Python 3",
450
+ "name": "python3"
451
+ },
452
+ "language_info": {
453
+ "name": "python"
454
+ }
455
+ },
456
+ "nbformat": 4,
457
+ "nbformat_minor": 0
458
+ }
IndicTrans2/huggingface_interface/configuration_indictrans.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch IndicTrans config."""
16
+
17
+
18
+ from collections import OrderedDict
19
+ from typing import Any, Mapping, Optional
20
+
21
+ from transformers import PreTrainedTokenizer
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
24
+ from transformers.onnx.utils import compute_effective_axis_dimension
25
+ from transformers.utils import TensorType, is_torch_available
26
+
27
+
28
+ # Copied from transformers.models.m2m_100.configuration_m2m_100.M2M100Config->IndicTrans
29
+ class IndicTransConfig(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`IT2Model`]. It is used to instantiate an
32
+ IT2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
33
+ with the defaults will yield a similar configuration to that of the IT2
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 50265):
41
+ Vocabulary size of the IT2 model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`IT2Model`] or
43
+ d_model (`int`, *optional*, defaults to 1024):
44
+ Dimensionality of the layers and the pooler layer.
45
+ encoder_layers (`int`, *optional*, defaults to 12):
46
+ Number of encoder layers.
47
+ decoder_layers (`int`, *optional*, defaults to 12):
48
+ Number of decoder layers.
49
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
52
+ Number of attention heads for each attention layer in the Transformer decoder.
53
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
54
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
55
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
56
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
57
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
60
+ dropout (`float`, *optional*, defaults to 0.1):
61
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
+ attention_dropout (`float`, *optional*, defaults to 0.0):
63
+ The dropout ratio for the attention probabilities.
64
+ activation_dropout (`float`, *optional*, defaults to 0.0):
65
+ The dropout ratio for activations inside the fully connected layer.
66
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
67
+ The dropout ratio for classifier.
68
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
69
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
70
+ just in case (e.g., 512 or 1024 or 2048).
71
+ init_std (`float`, *optional*, defaults to 0.02):
72
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
73
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
74
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
75
+ for more details.
76
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
77
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
78
+ for more details.
79
+ use_cache (`bool`, *optional*, defaults to `True`):
80
+ Whether or not the model should return the last key/values attentions (not used by all models).
81
+ ```"""
82
+ model_type = "IndicTrans"
83
+ keys_to_ignore_at_inference = ["past_key_values"]
84
+ attribute_map = {
85
+ "num_attention_heads": "encoder_attention_heads",
86
+ "hidden_size": "d_model",
87
+ }
88
+
89
+ def __init__(
90
+ self,
91
+ encoder_vocab_size=None,
92
+ decoder_vocab_size=None,
93
+ encoder_embed_dim=512,
94
+ decoder_embed_dim=512,
95
+ max_source_positions=210,
96
+ max_target_positions=210,
97
+ encoder_layers=6,
98
+ encoder_ffn_dim=2048,
99
+ encoder_attention_heads=8,
100
+ decoder_layers=6,
101
+ decoder_ffn_dim=2048,
102
+ decoder_attention_heads=8,
103
+ encoder_layerdrop=0.00,
104
+ decoder_layerdrop=0.00,
105
+ use_cache=True,
106
+ is_encoder_decoder=True,
107
+ activation_function="relu",
108
+ encoder_normalize_before=False,
109
+ decoder_normalize_before=False,
110
+ layernorm_embedding=False,
111
+ share_decoder_input_output_embed=False,
112
+ dropout=0.1,
113
+ attention_dropout=0.0,
114
+ activation_dropout=0.0,
115
+ init_std=0.02,
116
+ scale_embedding=True,
117
+ decoder_start_token_id=2,
118
+ pad_token_id=1,
119
+ bos_token_id=0,
120
+ eos_token_id=2,
121
+ attn_implementation="eager",
122
+ **kwargs,
123
+ ):
124
+ self.encoder_vocab_size = encoder_vocab_size
125
+ self.decoder_vocab_size = decoder_vocab_size
126
+ self.encoder_normalize_before = encoder_normalize_before
127
+ self.decoder_normalize_before = decoder_normalize_before
128
+ self.layernorm_embedding = layernorm_embedding
129
+ self.max_source_positions = max_source_positions
130
+ self.max_target_positions = max_target_positions
131
+ self.encoder_embed_dim = encoder_embed_dim
132
+ self.decoder_embed_dim = decoder_embed_dim
133
+ self.encoder_ffn_dim = encoder_ffn_dim
134
+ self.encoder_layers = encoder_layers
135
+ self.encoder_attention_heads = encoder_attention_heads
136
+ self.decoder_ffn_dim = decoder_ffn_dim
137
+ self.decoder_layers = decoder_layers
138
+ self.decoder_attention_heads = decoder_attention_heads
139
+ self.dropout = dropout
140
+ self.attention_dropout = attention_dropout
141
+ self.activation_dropout = activation_dropout
142
+ self.activation_function = activation_function
143
+ self.init_std = init_std
144
+ self.encoder_layerdrop = encoder_layerdrop
145
+ self.decoder_layerdrop = decoder_layerdrop
146
+ self.use_cache = use_cache
147
+ self.num_hidden_layers = encoder_layers
148
+ self.scale_embedding = scale_embedding
149
+ self.share_decoder_input_output_embed = share_decoder_input_output_embed
150
+ self.attn_implementation = attn_implementation
151
+
152
+ super().__init__(
153
+ pad_token_id=pad_token_id,
154
+ bos_token_id=bos_token_id,
155
+ eos_token_id=eos_token_id,
156
+ is_encoder_decoder=is_encoder_decoder,
157
+ decoder_start_token_id=decoder_start_token_id,
158
+ **kwargs,
159
+ )
160
+
161
+
162
+ class IndicTransOnnxConfig(OnnxSeq2SeqConfigWithPast):
163
+ @property
164
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
165
+ common_inputs = OrderedDict(
166
+ [
167
+ ("input_ids", {0: "batch", 1: "encoder_sequence"}),
168
+ ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
169
+ ]
170
+ )
171
+
172
+ if self.use_past:
173
+ common_inputs["decoder_input_ids"] = {0: "batch"}
174
+ common_inputs["decoder_attention_mask"] = {
175
+ 0: "batch",
176
+ 1: "past_decoder_sequence + sequence",
177
+ }
178
+ else:
179
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
180
+ common_inputs["decoder_attention_mask"] = {
181
+ 0: "batch",
182
+ 1: "decoder_sequence",
183
+ }
184
+
185
+ if self.use_past:
186
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
187
+ return common_inputs
188
+
189
+ # Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
190
+ # A better name would be _generate_dummy_inputs_for_encoder_and_decoder because sequence classification and question
191
+ # answering are not supported for IT2, but this name is preserved to be able to check that the copy matches what
192
+ # was done for BART so that it can be updated if need be.
193
+ def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
194
+ self,
195
+ tokenizer: PreTrainedTokenizer,
196
+ batch_size: int = -1,
197
+ seq_length: int = -1,
198
+ is_pair: bool = False,
199
+ framework: Optional[TensorType] = None,
200
+ ) -> Mapping[str, Any]:
201
+ # Copied from OnnxConfig.generate_dummy_inputs
202
+ # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
203
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
204
+ batch_size = compute_effective_axis_dimension(
205
+ batch_size,
206
+ fixed_dimension=OnnxConfig.default_fixed_batch,
207
+ num_token_to_add=0,
208
+ )
209
+
210
+ # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
211
+ token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
212
+ seq_length = compute_effective_axis_dimension(
213
+ seq_length,
214
+ fixed_dimension=OnnxConfig.default_fixed_sequence,
215
+ num_token_to_add=token_to_add,
216
+ )
217
+
218
+ # Generate dummy inputs according to compute batch and sequence
219
+ dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
220
+ common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
221
+ return common_inputs
222
+
223
+ # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_default_and_seq2seq_lm
224
+ def _generate_dummy_inputs_for_default_and_seq2seq_lm(
225
+ self,
226
+ tokenizer: PreTrainedTokenizer,
227
+ batch_size: int = -1,
228
+ seq_length: int = -1,
229
+ is_pair: bool = False,
230
+ framework: Optional[TensorType] = None,
231
+ ) -> Mapping[str, Any]:
232
+ encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
233
+ tokenizer, batch_size, seq_length, is_pair, framework
234
+ )
235
+
236
+ # Generate decoder inputs
237
+ decoder_seq_length = seq_length if not self.use_past else 1
238
+ decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
239
+ tokenizer, batch_size, decoder_seq_length, is_pair, framework
240
+ )
241
+ decoder_inputs = {
242
+ f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()
243
+ }
244
+ common_inputs = dict(**encoder_inputs, **decoder_inputs)
245
+
246
+ if self.use_past:
247
+ if not is_torch_available():
248
+ raise ValueError(
249
+ "Cannot generate dummy past_keys inputs without PyTorch installed."
250
+ )
251
+ else:
252
+ import torch
253
+ batch, encoder_seq_length = common_inputs["input_ids"].shape
254
+ decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
255
+ (
256
+ num_encoder_attention_heads,
257
+ num_decoder_attention_heads,
258
+ ) = self.num_attention_heads
259
+ encoder_shape = (
260
+ batch,
261
+ num_encoder_attention_heads,
262
+ encoder_seq_length,
263
+ self._config.hidden_size // num_encoder_attention_heads,
264
+ )
265
+ decoder_past_length = decoder_seq_length + 3
266
+ decoder_shape = (
267
+ batch,
268
+ num_decoder_attention_heads,
269
+ decoder_past_length,
270
+ self._config.hidden_size // num_decoder_attention_heads,
271
+ )
272
+
273
+ common_inputs["decoder_attention_mask"] = torch.cat(
274
+ [
275
+ common_inputs["decoder_attention_mask"],
276
+ torch.ones(batch, decoder_past_length),
277
+ ],
278
+ dim=1,
279
+ )
280
+
281
+ common_inputs["past_key_values"] = []
282
+ # If the number of encoder and decoder layers are present in the model configuration, both are considered
283
+ num_encoder_layers, num_decoder_layers = self.num_layers
284
+ min_num_layers = min(num_encoder_layers, num_decoder_layers)
285
+ max_num_layers = (
286
+ max(num_encoder_layers, num_decoder_layers) - min_num_layers
287
+ )
288
+ remaining_side_name = (
289
+ "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
290
+ )
291
+
292
+ for _ in range(min_num_layers):
293
+ common_inputs["past_key_values"].append(
294
+ (
295
+ torch.zeros(decoder_shape),
296
+ torch.zeros(decoder_shape),
297
+ torch.zeros(encoder_shape),
298
+ torch.zeros(encoder_shape),
299
+ )
300
+ )
301
+ # TODO: test this.
302
+ shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
303
+ for _ in range(min_num_layers, max_num_layers):
304
+ common_inputs["past_key_values"].append(
305
+ (torch.zeros(shape), torch.zeros(shape))
306
+ )
307
+ return common_inputs
308
+
309
+ generate_dummy_inputs = _generate_dummy_inputs_for_default_and_seq2seq_lm
IndicTrans2/huggingface_interface/convert_indictrans_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from configuration_indictrans import IndicTransConfig
21
+ from modeling_indictrans import IndicTransForConditionalGeneration
22
+
23
+
24
+ def remove_ignore_keys_(state_dict):
25
+ ignore_keys = [
26
+ "encoder.version",
27
+ "decoder.version",
28
+ "model.encoder.version",
29
+ "model.decoder.version",
30
+ "_float_tensor",
31
+ "encoder.embed_positions._float_tensor",
32
+ "decoder.embed_positions._float_tensor",
33
+ ]
34
+ for k in ignore_keys:
35
+ state_dict.pop(k, None)
36
+
37
+
38
+ def make_linear_from_emb(emb):
39
+ vocab_size, emb_size = emb.shape
40
+ lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
41
+ lin_layer.weight.data = emb.data
42
+ return lin_layer
43
+
44
+
45
+ def convert_fairseq_IT2_checkpoint_from_disk(checkpoint_path):
46
+ model = torch.load(checkpoint_path, map_location="cpu")
47
+ args = model["args"] or model["cfg"]["model"]
48
+ state_dict = model["model"]
49
+ remove_ignore_keys_(state_dict)
50
+ encoder_vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
51
+ decoder_vocab_size = state_dict["decoder.embed_tokens.weight"].shape[0]
52
+
53
+ config = IndicTransConfig(
54
+ encoder_vocab_size=encoder_vocab_size,
55
+ decoder_vocab_size=decoder_vocab_size,
56
+ max_source_positions=args.max_source_positions,
57
+ max_target_positions=args.max_target_positions,
58
+ encoder_layers=args.encoder_layers,
59
+ decoder_layers=args.decoder_layers,
60
+ layernorm_embedding=args.layernorm_embedding,
61
+ encoder_normalize_before=args.encoder_normalize_before,
62
+ decoder_normalize_before=args.decoder_normalize_before,
63
+ encoder_attention_heads=args.encoder_attention_heads,
64
+ decoder_attention_heads=args.decoder_attention_heads,
65
+ encoder_ffn_dim=args.encoder_ffn_embed_dim,
66
+ decoder_ffn_dim=args.decoder_ffn_embed_dim,
67
+ encoder_embed_dim=args.encoder_embed_dim,
68
+ decoder_embed_dim=args.decoder_embed_dim,
69
+ encoder_layerdrop=args.encoder_layerdrop,
70
+ decoder_layerdrop=args.decoder_layerdrop,
71
+ dropout=args.dropout,
72
+ attention_dropout=args.attention_dropout,
73
+ activation_dropout=args.activation_dropout,
74
+ activation_function=args.activation_fn,
75
+ share_decoder_input_output_embed=args.share_decoder_input_output_embed,
76
+ scale_embedding=not args.no_scale_embedding,
77
+ )
78
+
79
+ model = IndicTransForConditionalGeneration(config)
80
+ model.model.load_state_dict(state_dict, strict=False)
81
+ if not args.share_decoder_input_output_embed:
82
+ model.lm_head = make_linear_from_emb(
83
+ state_dict["decoder.output_projection.weight"]
84
+ )
85
+ print(model)
86
+ return model
87
+
88
+
89
+ if __name__ == "__main__":
90
+ parser = argparse.ArgumentParser()
91
+ # Required parameters
92
+ parser.add_argument(
93
+ "--fairseq_path",
94
+ default="indic-en/model/checkpoint_best.pt",
95
+ type=str,
96
+ help="path to a model.pt on local filesystem.",
97
+ )
98
+ parser.add_argument(
99
+ "--pytorch_dump_folder_path",
100
+ default="indic-en/hf_model",
101
+ type=str,
102
+ help="Path to the output PyTorch model.",
103
+ )
104
+
105
+ args = parser.parse_args()
106
+ model = convert_fairseq_IT2_checkpoint_from_disk(args.fairseq_path)
107
+ model.save_pretrained(args.pytorch_dump_folder_path)
IndicTrans2/huggingface_interface/example.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
4
+ from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
5
+ from IndicTransToolkit import IndicProcessor
6
+ from mosestokenizer import MosesSentenceSplitter
7
+ from nltk import sent_tokenize
8
+ from indicnlp.tokenize.sentence_tokenize import sentence_split, DELIM_PAT_NO_DANDA
9
+
10
+
11
+ en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B" # ai4bharat/indictrans2-en-indic-dist-200M
12
+ indic_en_ckpt_dir = "ai4bharat/indictrans2-indic-en-1B" # ai4bharat/indictrans2-indic-en-dist-200M
13
+ indic_indic_ckpt_dir = (
14
+ "ai4bharat/indictrans2-indic-indic-dist-320M" # ai4bharat/indictrans2-indic-indic-dist-320M
15
+ )
16
+ BATCH_SIZE = 4
17
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+ if len(sys.argv) > 1:
20
+ quantization = sys.argv[1]
21
+ attn_implementation = sys.argv[2]
22
+ else:
23
+ quantization = ""
24
+ attn_implementation = "eager"
25
+
26
+
27
+ # FLORES language code mapping to 2 letter ISO language code for compatibility
28
+ # with Indic NLP Library (https://github.com/anoopkunchukuttan/indic_nlp_library)
29
+ flores_codes = {
30
+ "asm_Beng": "as",
31
+ "awa_Deva": "hi",
32
+ "ben_Beng": "bn",
33
+ "bho_Deva": "hi",
34
+ "brx_Deva": "hi",
35
+ "doi_Deva": "hi",
36
+ "eng_Latn": "en",
37
+ "gom_Deva": "kK",
38
+ "guj_Gujr": "gu",
39
+ "hin_Deva": "hi",
40
+ "hne_Deva": "hi",
41
+ "kan_Knda": "kn",
42
+ "kas_Arab": "ur",
43
+ "kas_Deva": "hi",
44
+ "kha_Latn": "en",
45
+ "lus_Latn": "en",
46
+ "mag_Deva": "hi",
47
+ "mai_Deva": "hi",
48
+ "mal_Mlym": "ml",
49
+ "mar_Deva": "mr",
50
+ "mni_Beng": "bn",
51
+ "mni_Mtei": "hi",
52
+ "npi_Deva": "ne",
53
+ "ory_Orya": "or",
54
+ "pan_Guru": "pa",
55
+ "san_Deva": "hi",
56
+ "sat_Olck": "or",
57
+ "snd_Arab": "ur",
58
+ "snd_Deva": "hi",
59
+ "tam_Taml": "ta",
60
+ "tel_Telu": "te",
61
+ "urd_Arab": "ur",
62
+ }
63
+
64
+
65
+ def split_sentences(input_text, lang):
66
+ if lang == "eng_Latn":
67
+ input_sentences = sent_tokenize(input_text)
68
+ with MosesSentenceSplitter(flores_codes[lang]) as splitter:
69
+ sents_moses = splitter([input_text])
70
+ sents_nltk = sent_tokenize(input_text)
71
+ if len(sents_nltk) < len(sents_moses):
72
+ input_sentences = sents_nltk
73
+ else:
74
+ input_sentences = sents_moses
75
+ input_sentences = [sent.replace("\xad", "") for sent in input_sentences]
76
+ else:
77
+ input_sentences = sentence_split(
78
+ input_text, lang=flores_codes[lang], delim_pat=DELIM_PAT_NO_DANDA
79
+ )
80
+ return input_sentences
81
+
82
+
83
+ def initialize_model_and_tokenizer(ckpt_dir, quantization, attn_implementation):
84
+ if quantization == "4-bit":
85
+ qconfig = BitsAndBytesConfig(
86
+ load_in_4bit=True,
87
+ bnb_4bit_use_double_quant=True,
88
+ bnb_4bit_compute_dtype=torch.bfloat16,
89
+ )
90
+ elif quantization == "8-bit":
91
+ qconfig = BitsAndBytesConfig(
92
+ load_in_8bit=True,
93
+ bnb_8bit_use_double_quant=True,
94
+ bnb_8bit_compute_dtype=torch.bfloat16,
95
+ )
96
+ else:
97
+ qconfig = None
98
+
99
+ if attn_implementation == "flash_attention_2":
100
+ if is_flash_attn_2_available() and is_flash_attn_greater_or_equal_2_10():
101
+ attn_implementation = "flash_attention_2"
102
+ else:
103
+ attn_implementation = "eager"
104
+
105
+ tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
106
+ model = AutoModelForSeq2SeqLM.from_pretrained(
107
+ ckpt_dir,
108
+ trust_remote_code=True,
109
+ attn_implementation=attn_implementation,
110
+ low_cpu_mem_usage=True,
111
+ quantization_config=qconfig,
112
+ )
113
+
114
+ if qconfig == None:
115
+ model = model.to(DEVICE)
116
+ model.half()
117
+
118
+ model.eval()
119
+
120
+ return tokenizer, model
121
+
122
+
123
+ def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
124
+ translations = []
125
+ for i in range(0, len(input_sentences), BATCH_SIZE):
126
+ batch = input_sentences[i : i + BATCH_SIZE]
127
+
128
+ # Preprocess the batch and extract entity mappings
129
+ batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)
130
+
131
+ # Tokenize the batch and generate input encodings
132
+ inputs = tokenizer(
133
+ batch,
134
+ truncation=True,
135
+ padding="longest",
136
+ return_tensors="pt",
137
+ return_attention_mask=True,
138
+ ).to(DEVICE)
139
+
140
+ # Generate translations using the model
141
+ with torch.no_grad():
142
+ generated_tokens = model.generate(
143
+ **inputs,
144
+ use_cache=True,
145
+ min_length=0,
146
+ max_length=256,
147
+ num_beams=5,
148
+ num_return_sequences=1,
149
+ )
150
+
151
+ # Decode the generated tokens into text
152
+ with tokenizer.as_target_tokenizer():
153
+ generated_tokens = tokenizer.batch_decode(
154
+ generated_tokens.detach().cpu().tolist(),
155
+ skip_special_tokens=True,
156
+ clean_up_tokenization_spaces=True,
157
+ )
158
+
159
+ # Postprocess the translations, including entity replacement
160
+ translations += ip.postprocess_batch(generated_tokens, lang=tgt_lang)
161
+
162
+ del inputs
163
+ torch.cuda.empty_cache()
164
+
165
+ return translations
166
+
167
+
168
+ def translate_paragraph(input_text, src_lang, tgt_lang, model, tokenizer, ip):
169
+ input_sentences = split_sentences(input_text, src_lang)
170
+ translated_text = batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip)
171
+ return " ".join(translated_text)
172
+
173
+
174
+ ip = IndicProcessor(inference=True)
175
+
176
+ en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(
177
+ en_indic_ckpt_dir, quantization, attn_implementation
178
+ )
179
+
180
+ indic_en_tokenizer, indic_en_model = initialize_model_and_tokenizer(
181
+ indic_en_ckpt_dir, quantization, attn_implementation
182
+ )
183
+
184
+ indic_indic_tokenizer, indic_indic_model = initialize_model_and_tokenizer(
185
+ indic_indic_ckpt_dir, quantization, attn_implementation
186
+ )
187
+
188
+ # ---------------------------------------------------------------------------
189
+ # Hindi to English
190
+ # ---------------------------------------------------------------------------
191
+ hi_sents = [
192
+ "जब मैं छोटा था, मैं हर रोज़ पार्क जाता था।",
193
+ "उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।",
194
+ "मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।",
195
+ "वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।",
196
+ "हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।",
197
+ "अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।",
198
+ "वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।",
199
+ "राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।",
200
+ "सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।",
201
+ "मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।",
202
+ ]
203
+ src_lang, tgt_lang = "hin_Deva", "eng_Latn"
204
+ en_translations = batch_translate(
205
+ hi_sents, src_lang, tgt_lang, indic_en_model, indic_en_tokenizer, ip
206
+ )
207
+
208
+ print(f"\n{src_lang} - {tgt_lang}")
209
+ for input_sentence, translation in zip(hi_sents, en_translations):
210
+ print(f"{src_lang}: {input_sentence}")
211
+ print(f"{tgt_lang}: {translation}")
212
+
213
+
214
+ # ---------------------------------------------------------------------------
215
+ # English to Hindi
216
+ # ---------------------------------------------------------------------------
217
+ en_sents = [
218
+ "When I was young, I used to go to the park every day.",
219
+ "He has many old books, which he inherited from his ancestors.",
220
+ "I can't figure out how to solve my problem.",
221
+ "She is very hardworking and intelligent, which is why she got all the good marks.",
222
+ "We watched a new movie last week, which was very inspiring.",
223
+ "If you had met me at that time, we would have gone out to eat.",
224
+ "She went to the market with her sister to buy a new sari.",
225
+ "Raj told me that he is going to his grandmother's house next month.",
226
+ "All the kids were having fun at the party and were eating lots of sweets.",
227
+ "My friend has invited me to his birthday party, and I will give him a gift.",
228
+ ]
229
+ src_lang, tgt_lang = "eng_Latn", "hin_Deva"
230
+ hi_translations = batch_translate(
231
+ en_sents, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer, ip
232
+ )
233
+
234
+ print(f"\n{src_lang} - {tgt_lang}")
235
+ for input_sentence, translation in zip(en_sents, hi_translations):
236
+ print(f"{src_lang}: {input_sentence}")
237
+ print(f"{tgt_lang}: {translation}")
238
+
239
+
240
+ # ---------------------------------------------------------------------------
241
+ # Hindi to Marathi
242
+ # ---------------------------------------------------------------------------
243
+ hi_sents = [
244
+ "��ब मैं छोटा था, मैं हर रोज़ पार्क जाता था।",
245
+ "उसके पास बहुत सारी पुरानी किताबें हैं, जिन्हें उसने अपने दादा-परदादा से विरासत में पाया।",
246
+ "मुझे समझ में नहीं आ रहा कि मैं अपनी समस्या का समाधान कैसे ढूंढूं।",
247
+ "वह बहुत मेहनती और समझदार है, इसलिए उसे सभी अच्छे मार्क्स मिले।",
248
+ "हमने पिछले सप्ताह एक नई फिल्म देखी जो कि बहुत प्रेरणादायक थी।",
249
+ "अगर तुम मुझे उस समय पास मिलते, तो हम बाहर खाना खाने चलते।",
250
+ "वह अपनी दीदी के साथ बाजार गयी थी ताकि वह नई साड़ी खरीद सके।",
251
+ "राज ने मुझसे कहा कि वह अगले महीने अपनी नानी के घर जा रहा है।",
252
+ "सभी बच्चे पार्टी में मज़ा कर रहे थे और खूब सारी मिठाइयाँ खा रहे थे।",
253
+ "मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।",
254
+ ]
255
+ src_lang, tgt_lang = "hin_Deva", "mar_Deva"
256
+ mr_translations = batch_translate(
257
+ hi_sents, src_lang, tgt_lang, indic_indic_model, indic_indic_tokenizer, ip
258
+ )
259
+
260
+ print(f"\n{src_lang} - {tgt_lang}")
261
+ for input_sentence, translation in zip(hi_sents, mr_translations):
262
+ print(f"{src_lang}: {input_sentence}")
263
+ print(f"{tgt_lang}: {translation}")
264
+
265
+
266
+ # ---------------------------------------------------------------------------
267
+ # Paragraph translation
268
+ # ---------------------------------------------------------------------------
269
+ src_lang, tgt_lang = "hin_Deva", "eng_Latn"
270
+ hi_text = "यहाँ एक पाराग्राफ है जो हिंदी में लिखा गया है। हिंदी एक सुंदर भाषा है और भारत की राष्ट्रीय भाषा है। इसका विकास विभिन्न कालों में हुआ है और यह विशेषतः भारतीय उपमहाद्वीप में बोली जाती है। हिंदी भाषा का साहित्य, संस्कृति और इतिहास भी बहुत गर्वनीय है।"
271
+ en_translated_text = translate_paragraph(
272
+ hi_text, src_lang, tgt_lang, indic_en_model, indic_en_tokenizer, ip
273
+ )
274
+ print(f"{src_lang}: {hi_text}")
275
+ print(f"{tgt_lang}: {en_translated_text}")
IndicTrans2/huggingface_interface/install.sh ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #/bin/bash
2
+
3
+ root_dir=$(pwd)
4
+ echo "Setting up the environment in the $root_dir"
5
+
6
+ # --------------------------------------------------------------
7
+ # create and activate the virtual environment
8
+ # --------------------------------------------------------------
9
+ echo "Creating a virtual environment with python3"
10
+ conda create -n itv2_hf python=3.9 -y
11
+ conda activate itv2_hf
12
+
13
+ echo "Installing all the dependencies"
14
+ conda install pip
15
+ python3 -m pip install --upgrade pip
16
+
17
+
18
+ # --------------------------------------------------------------
19
+ # PyTorch Installation
20
+ # --------------------------------------------------------------
21
+ python3 -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu118
22
+
23
+
24
+ # --------------------------------------------------------------
25
+ # Install additional utility packages
26
+ # --------------------------------------------------------------
27
+ python3 -m pip install nltk sacremoses pandas regex mock transformers>=4.33.2 mosestokenizer
28
+ python3 -c "import nltk; nltk.download('punkt')"
29
+ python3 -m pip install bitsandbytes scipy accelerate datasets flash-attn>=2.1
30
+
31
+
32
+ # --------------------------------------------------------------
33
+ # Sentencepiece for tokenization
34
+ # --------------------------------------------------------------
35
+ # build the cpp binaries from the source repo in order to use the command line utility
36
+ # source repo: https://github.com/google/sentencepiece
37
+ python3 -m pip install sentencepiece
38
+
39
+
40
+ # -----------------------------------------------------------------
41
+ # Install IndicTrans2 tokenizer and its dependencies
42
+ # -----------------------------------------------------------------
43
+ git clone https://github.com/VarunGumma/IndicTransToolkit
44
+ cd IndicTransToolkit
45
+ python3 -m pip install --editable ./
46
+ cd $root_dir
47
+
48
+
49
+ echo "Setup completed!"
IndicTrans2/huggingface_interface/modeling_indictrans.py ADDED
@@ -0,0 +1,1801 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch IndicTrans model."""
16
+
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch.nn import functional as F
24
+
25
+ from transformers.activations import ACT2FN
26
+
27
+ from transformers.modeling_attn_mask_utils import (
28
+ _prepare_4d_attention_mask,
29
+ _prepare_4d_attention_mask_for_sdpa,
30
+ _prepare_4d_causal_attention_mask,
31
+ _prepare_4d_causal_attention_mask_for_sdpa,
32
+ )
33
+
34
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
35
+ from transformers.modeling_outputs import (
36
+ BaseModelOutput,
37
+ BaseModelOutputWithPastAndCrossAttentions,
38
+ Seq2SeqLMOutput,
39
+ Seq2SeqModelOutput
40
+ )
41
+
42
+ from transformers.utils import (
43
+ logging,
44
+ is_flash_attn_2_available,
45
+ is_flash_attn_greater_or_equal_2_10,
46
+ )
47
+
48
+ from transformers.modeling_utils import PreTrainedModel
49
+
50
+ from .configuration_indictrans import IndicTransConfig
51
+
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+ INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
56
+
57
+ try:
58
+ if is_flash_attn_2_available():
59
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
60
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
61
+ except:
62
+ pass
63
+
64
+
65
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
66
+ def _get_unpad_data(attention_mask):
67
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
68
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
69
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
70
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
71
+ return (
72
+ indices,
73
+ cu_seqlens,
74
+ max_seqlen_in_batch,
75
+ )
76
+
77
+
78
+ # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
79
+ def shift_tokens_right(
80
+ input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
81
+ ):
82
+ """
83
+ Shift input ids one token to the right.
84
+ """
85
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
86
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
87
+ shifted_input_ids[:, 0] = decoder_start_token_id
88
+
89
+ if pad_token_id is None:
90
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
91
+ # replace possible -100 values in labels by `pad_token_id`
92
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
93
+
94
+ return shifted_input_ids
95
+
96
+
97
+ def create_position_ids_from_input_ids(
98
+ input_ids, padding_idx, past_key_values_length=0
99
+ ):
100
+ """
101
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
102
+ are ignored. This is modified from fairseq's `utils.make_positions`.
103
+ """
104
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
105
+ mask = input_ids.ne(padding_idx).int()
106
+ incremental_indices = (
107
+ torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
108
+ ) * mask
109
+ return incremental_indices.long() + padding_idx
110
+
111
+
112
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding->IndicTrans
113
+ class IndicTransSinusoidalPositionalEmbedding(nn.Module):
114
+ """This module produces sinusoidal positional embeddings of any length."""
115
+
116
+ def __init__(
117
+ self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None
118
+ ):
119
+ super().__init__()
120
+ self.offset = 2
121
+ self.embedding_dim = embedding_dim
122
+ self.padding_idx = padding_idx
123
+ self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
124
+
125
+ def make_weights(
126
+ self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
127
+ ):
128
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
129
+ if hasattr(self, "weights"):
130
+ # in forward put the weights on the correct dtype and device of the param
131
+ emb_weights = emb_weights.to(
132
+ dtype=self.weights.dtype, device=self.weights.device
133
+ )
134
+
135
+ self.register_buffer("weights", emb_weights, persistent=False)
136
+
137
+ @staticmethod
138
+ def get_embedding(
139
+ num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
140
+ ):
141
+ """
142
+ Build sinusoidal embeddings.
143
+
144
+ This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
145
+ "Attention Is All You Need".
146
+ """
147
+ half_dim = embedding_dim // 2
148
+ emb = math.log(10000) / (half_dim - 1)
149
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
150
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
151
+ 1
152
+ ) * emb.unsqueeze(0)
153
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
154
+ num_embeddings, -1
155
+ )
156
+ if embedding_dim % 2 == 1:
157
+ # zero pad
158
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
159
+ if padding_idx is not None:
160
+ emb[padding_idx, :] = 0
161
+
162
+ return emb.to(torch.get_default_dtype())
163
+
164
+ @torch.no_grad()
165
+ def forward(
166
+ self,
167
+ input_ids: torch.Tensor = None,
168
+ inputs_embeds: torch.Tensor = None,
169
+ past_key_values_length: int = 0,
170
+ ):
171
+ if input_ids is not None:
172
+ bsz, seq_len = input_ids.size()
173
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
174
+ position_ids = create_position_ids_from_input_ids(
175
+ input_ids, self.padding_idx, past_key_values_length
176
+ ).to(input_ids.device)
177
+ else:
178
+ bsz, seq_len = inputs_embeds.size()[:-1]
179
+ position_ids = self.create_position_ids_from_inputs_embeds(
180
+ inputs_embeds, past_key_values_length
181
+ )
182
+
183
+ # expand embeddings if needed
184
+ max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
185
+ if max_pos > self.weights.size(0):
186
+ self.make_weights(
187
+ max_pos + self.offset, self.embedding_dim, self.padding_idx
188
+ )
189
+
190
+ return (
191
+ self.weights.index_select(0, position_ids.view(-1))
192
+ .view(bsz, seq_len, self.weights.shape[-1])
193
+ .detach()
194
+ )
195
+
196
+ def create_position_ids_from_inputs_embeds(
197
+ self, inputs_embeds, past_key_values_length
198
+ ):
199
+ """
200
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
201
+
202
+ Args:
203
+ inputs_embeds: torch.Tensor
204
+
205
+ Returns: torch.Tensor
206
+ """
207
+ input_shape = inputs_embeds.size()[:-1]
208
+ sequence_length = input_shape[1]
209
+
210
+ position_ids = torch.arange(
211
+ self.padding_idx + 1,
212
+ sequence_length + self.padding_idx + 1,
213
+ dtype=torch.long,
214
+ device=inputs_embeds.device,
215
+ )
216
+ return (
217
+ position_ids.unsqueeze(0).expand(input_shape).contiguous()
218
+ + past_key_values_length
219
+ )
220
+
221
+
222
+ # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->IndicTrans
223
+ class IndicTransAttention(nn.Module):
224
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
225
+
226
+ def __init__(
227
+ self,
228
+ embed_dim: int,
229
+ num_heads: int,
230
+ dropout: float = 0.0,
231
+ is_decoder: bool = False,
232
+ bias: bool = True,
233
+ is_causal: bool = False,
234
+ config: Optional[IndicTransConfig] = None,
235
+ ):
236
+ super().__init__()
237
+ self.embed_dim = embed_dim
238
+ self.num_heads = num_heads
239
+ self.dropout = dropout
240
+ self.head_dim = embed_dim // num_heads
241
+ self.config = config
242
+
243
+ if (self.head_dim * num_heads) != self.embed_dim:
244
+ raise ValueError(
245
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
246
+ f" and `num_heads`: {num_heads})."
247
+ )
248
+ self.scaling = self.head_dim**-0.5
249
+ self.is_decoder = is_decoder
250
+ self.is_causal = is_causal
251
+
252
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
253
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
254
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
255
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
256
+
257
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
258
+ return (
259
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
260
+ .transpose(1, 2)
261
+ .contiguous()
262
+ )
263
+
264
+ def forward(
265
+ self,
266
+ hidden_states: torch.Tensor,
267
+ key_value_states: Optional[torch.Tensor] = None,
268
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
269
+ attention_mask: Optional[torch.Tensor] = None,
270
+ layer_head_mask: Optional[torch.Tensor] = None,
271
+ output_attentions: bool = False,
272
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
273
+ """Input shape: Batch x Time x Channel"""
274
+
275
+ # if key_value_states are provided this layer is used as a cross-attention layer
276
+ # for the decoder
277
+ is_cross_attention = key_value_states is not None
278
+
279
+ bsz, tgt_len, _ = hidden_states.size()
280
+
281
+ # get query proj
282
+ query_states = self.q_proj(hidden_states) * self.scaling
283
+ # get key, value proj
284
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
285
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
286
+ # the provided `key_value_states` to support prefix tuning
287
+ if (
288
+ is_cross_attention
289
+ and past_key_value is not None
290
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
291
+ ):
292
+ # reuse k,v, cross_attentions
293
+ key_states = past_key_value[0]
294
+ value_states = past_key_value[1]
295
+ elif is_cross_attention:
296
+ # cross_attentions
297
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
298
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
299
+ elif past_key_value is not None:
300
+ # reuse k, v, self_attention
301
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
302
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
303
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
304
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
305
+ else:
306
+ # self_attention
307
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
308
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
309
+
310
+ if self.is_decoder:
311
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
312
+ # Further calls to cross_attention layer can then reuse all cross-attention
313
+ # key/value_states (first "if" case)
314
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
315
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
316
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
317
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
318
+ past_key_value = (key_states, value_states)
319
+
320
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
321
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
322
+ key_states = key_states.reshape(*proj_shape)
323
+ value_states = value_states.reshape(*proj_shape)
324
+
325
+ src_len = key_states.size(1)
326
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
327
+
328
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
329
+ raise ValueError(
330
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
331
+ f" {attn_weights.size()}"
332
+ )
333
+
334
+ if attention_mask is not None:
335
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
336
+ raise ValueError(
337
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
338
+ )
339
+ attn_weights = (
340
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
341
+ + attention_mask
342
+ )
343
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
344
+
345
+ attn_weights = F.softmax(attn_weights, dim=-1)
346
+
347
+ if layer_head_mask is not None:
348
+ if layer_head_mask.size() != (self.num_heads,):
349
+ raise ValueError(
350
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
351
+ f" {layer_head_mask.size()}"
352
+ )
353
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
354
+ bsz, self.num_heads, tgt_len, src_len
355
+ )
356
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
357
+
358
+ if output_attentions:
359
+ # this operation is a bit awkward, but it's required to
360
+ # make sure that attn_weights keeps its gradient.
361
+ # In order to do so, attn_weights have to be reshaped
362
+ # twice and have to be reused in the following
363
+ attn_weights_reshaped = attn_weights.view(
364
+ bsz, self.num_heads, tgt_len, src_len
365
+ )
366
+ attn_weights = attn_weights_reshaped.view(
367
+ bsz * self.num_heads, tgt_len, src_len
368
+ )
369
+ else:
370
+ attn_weights_reshaped = None
371
+
372
+ attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
373
+
374
+ attn_output = torch.bmm(attn_probs, value_states)
375
+
376
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
377
+ raise ValueError(
378
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
379
+ f" {attn_output.size()}"
380
+ )
381
+
382
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
383
+ attn_output = attn_output.transpose(1, 2)
384
+
385
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
386
+ # partitioned across GPUs when using tensor-parallelism.
387
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
388
+
389
+ attn_output = self.out_proj(attn_output)
390
+
391
+ return attn_output, attn_weights_reshaped, past_key_value
392
+
393
+
394
+ class IndicTransFlashAttention2(IndicTransAttention):
395
+ """
396
+ IndicTrans flash attention module. This module inherits from `IndicTransAttention` as the weights of the module stays
397
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
398
+ flash attention and deal with padding tokens in case the input contains any of them.
399
+ """
400
+
401
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
402
+ def __init__(self, *args, **kwargs):
403
+ super().__init__(*args, **kwargs)
404
+
405
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
406
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
407
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
408
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
409
+
410
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
411
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
412
+
413
+ def forward(
414
+ self,
415
+ hidden_states: torch.Tensor,
416
+ key_value_states: Optional[torch.Tensor] = None,
417
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
418
+ attention_mask: Optional[torch.Tensor] = None,
419
+ layer_head_mask: Optional[torch.Tensor] = None,
420
+ output_attentions: bool = False,
421
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
422
+ # IndicTransFlashAttention2 attention does not support output_attentions
423
+ if output_attentions:
424
+ raise ValueError("IndicTransFlashAttention2 attention does not support output_attentions")
425
+
426
+ # if key_value_states are provided this layer is used as a cross-attention layer
427
+ # for the decoder
428
+ is_cross_attention = key_value_states is not None
429
+
430
+ bsz, q_len, _ = hidden_states.size()
431
+
432
+ # get query proj
433
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
434
+ # get key, value proj
435
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
436
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
437
+ # the provided `key_value_states` to support prefix tuning
438
+ if (
439
+ is_cross_attention
440
+ and past_key_value is not None
441
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
442
+ ):
443
+ # reuse k,v, cross_attentions
444
+ key_states = past_key_value[0].transpose(1, 2)
445
+ value_states = past_key_value[1].transpose(1, 2)
446
+ elif is_cross_attention:
447
+ # cross_attentions
448
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
449
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
450
+ elif past_key_value is not None:
451
+ # reuse k, v, self_attention
452
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
453
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
454
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
455
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
456
+ else:
457
+ # self_attention
458
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
459
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
460
+
461
+ if self.is_decoder:
462
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
463
+ # Further calls to cross_attention layer can then reuse all cross-attention
464
+ # key/value_states (first "if" case)
465
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
466
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
467
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
468
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
469
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
470
+
471
+ kv_seq_len = key_states.shape[-2]
472
+ if past_key_value is not None:
473
+ kv_seq_len += past_key_value[0].shape[-2]
474
+
475
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
476
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
477
+ # cast them back in the correct dtype just to be sure everything works as expected.
478
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
479
+ # in fp32. (LlamaRMSNorm handles it correctly)
480
+
481
+ input_dtype = query_states.dtype
482
+ if input_dtype == torch.float32:
483
+ if torch.is_autocast_enabled():
484
+ target_dtype = torch.get_autocast_gpu_dtype()
485
+ # Handle the case where the model is quantized
486
+ elif hasattr(self.config, "_pre_quantization_dtype"):
487
+ target_dtype = self.config._pre_quantization_dtype
488
+ else:
489
+ target_dtype = self.q_proj.weight.dtype
490
+
491
+ logger.warning_once(
492
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
493
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
494
+ f" {target_dtype}."
495
+ )
496
+
497
+ query_states = query_states.to(target_dtype)
498
+ key_states = key_states.to(target_dtype)
499
+ value_states = value_states.to(target_dtype)
500
+
501
+ attn_output = self._flash_attention_forward(
502
+ query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
503
+ )
504
+
505
+ attn_output = attn_output.reshape(bsz, q_len, -1)
506
+ attn_output = self.out_proj(attn_output)
507
+
508
+ if not output_attentions:
509
+ attn_weights = None
510
+
511
+ return attn_output, attn_weights, past_key_value
512
+
513
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
514
+ def _flash_attention_forward(
515
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
516
+ ):
517
+ """
518
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
519
+ first unpad the input, then computes the attention scores and pad the final attention scores.
520
+
521
+ Args:
522
+ query_states (`torch.Tensor`):
523
+ Input query states to be passed to Flash Attention API
524
+ key_states (`torch.Tensor`):
525
+ Input key states to be passed to Flash Attention API
526
+ value_states (`torch.Tensor`):
527
+ Input value states to be passed to Flash Attention API
528
+ attention_mask (`torch.Tensor`):
529
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
530
+ position of padding tokens and 1 for the position of non-padding tokens.
531
+ dropout (`float`):
532
+ Attention dropout
533
+ softmax_scale (`float`, *optional*):
534
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
535
+ """
536
+ if not self._flash_attn_uses_top_left_mask:
537
+ causal = self.is_causal
538
+ else:
539
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
540
+ causal = self.is_causal and query_length != 1
541
+
542
+ # Contains at least one padding token in the sequence
543
+ if attention_mask is not None:
544
+ batch_size = query_states.shape[0]
545
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
546
+ query_states, key_states, value_states, attention_mask, query_length
547
+ )
548
+
549
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
550
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
551
+
552
+ attn_output_unpad = flash_attn_varlen_func(
553
+ query_states,
554
+ key_states,
555
+ value_states,
556
+ cu_seqlens_q=cu_seqlens_q,
557
+ cu_seqlens_k=cu_seqlens_k,
558
+ max_seqlen_q=max_seqlen_in_batch_q,
559
+ max_seqlen_k=max_seqlen_in_batch_k,
560
+ dropout_p=dropout,
561
+ softmax_scale=softmax_scale,
562
+ causal=causal,
563
+ )
564
+
565
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
566
+ else:
567
+ attn_output = flash_attn_func(
568
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
569
+ )
570
+
571
+ return attn_output
572
+
573
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
574
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
575
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
576
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
577
+
578
+ key_layer = index_first_axis(
579
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
580
+ )
581
+ value_layer = index_first_axis(
582
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
583
+ )
584
+ if query_length == kv_seq_len:
585
+ query_layer = index_first_axis(
586
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
587
+ )
588
+ cu_seqlens_q = cu_seqlens_k
589
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
590
+ indices_q = indices_k
591
+ elif query_length == 1:
592
+ max_seqlen_in_batch_q = 1
593
+ cu_seqlens_q = torch.arange(
594
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
595
+ ) # There is a memcpy here, that is very bad.
596
+ indices_q = cu_seqlens_q[:-1]
597
+ query_layer = query_layer.squeeze(1)
598
+ else:
599
+ # The -q_len: slice assumes left padding.
600
+ attention_mask = attention_mask[:, -query_length:]
601
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
602
+
603
+ return (
604
+ query_layer,
605
+ key_layer,
606
+ value_layer,
607
+ indices_q,
608
+ (cu_seqlens_q, cu_seqlens_k),
609
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
610
+ )
611
+
612
+
613
+ class IndicTransSdpaAttention(IndicTransAttention):
614
+ def forward(
615
+ self,
616
+ hidden_states: torch.Tensor,
617
+ key_value_states: Optional[torch.Tensor] = None,
618
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
619
+ attention_mask: Optional[torch.Tensor] = None,
620
+ layer_head_mask: Optional[torch.Tensor] = None,
621
+ output_attentions: bool = False,
622
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
623
+ """Input shape: Batch x Time x Channel"""
624
+ if output_attentions or layer_head_mask is not None:
625
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
626
+ logger.warning_once(
627
+ "IndicTransModel is using IndicTransSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
628
+ ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
629
+ )
630
+ return super().forward(
631
+ hidden_states,
632
+ key_value_states=key_value_states,
633
+ past_key_value=past_key_value,
634
+ attention_mask=attention_mask,
635
+ layer_head_mask=layer_head_mask,
636
+ output_attentions=output_attentions,
637
+ )
638
+
639
+ # if key_value_states are provided this layer is used as a cross-attention layer
640
+ # for the decoder
641
+ is_cross_attention = key_value_states is not None
642
+
643
+ bsz, tgt_len, _ = hidden_states.size()
644
+
645
+ # get query proj
646
+ query_states = self.q_proj(hidden_states)
647
+ # get key, value proj
648
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
649
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
650
+ # the provided `key_value_states` to support prefix tuning
651
+ if (
652
+ is_cross_attention
653
+ and past_key_value is not None
654
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
655
+ ):
656
+ # reuse k,v, cross_attentions
657
+ key_states = past_key_value[0]
658
+ value_states = past_key_value[1]
659
+ elif is_cross_attention:
660
+ # cross_attentions
661
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
662
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
663
+ elif past_key_value is not None:
664
+ # reuse k, v, self_attention
665
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
666
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
667
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
668
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
669
+ else:
670
+ # self_attention
671
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
672
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
673
+
674
+ if self.is_decoder:
675
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
676
+ # Further calls to cross_attention layer can then reuse all cross-attention
677
+ # key/value_states (first "if" case)
678
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
679
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
680
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
681
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
682
+ past_key_value = (key_states, value_states)
683
+
684
+ query_states = self._shape(query_states, tgt_len, bsz)
685
+
686
+ # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
687
+ # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
688
+ attn_output = F.scaled_dot_product_attention(
689
+ query_states,
690
+ key_states,
691
+ value_states,
692
+ attn_mask=attention_mask,
693
+ dropout_p=self.dropout if self.training else 0.0,
694
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
695
+ is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
696
+ )
697
+
698
+ if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
699
+ raise ValueError(
700
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
701
+ f" {attn_output.size()}"
702
+ )
703
+
704
+ attn_output = attn_output.transpose(1, 2)
705
+
706
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
707
+ # partitioned across GPUs when using tensor-parallelism.
708
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
709
+
710
+ attn_output = self.out_proj(attn_output)
711
+
712
+ return attn_output, None, past_key_value
713
+
714
+
715
+ INDICTRANS_ATTENTION_CLASSES = {
716
+ "eager": IndicTransAttention,
717
+ "sdpa": IndicTransSdpaAttention,
718
+ "flash_attention_2": IndicTransFlashAttention2,
719
+ }
720
+
721
+ # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTrans
722
+ class IndicTransEncoderLayer(nn.Module):
723
+ def __init__(self, config: IndicTransConfig):
724
+ super().__init__()
725
+ self.embed_dim = config.encoder_embed_dim
726
+ self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
727
+ embed_dim=self.embed_dim,
728
+ num_heads=config.encoder_attention_heads,
729
+ dropout=config.attention_dropout,
730
+ config=config,
731
+ )
732
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
733
+ self.dropout = config.dropout
734
+ self.activation_fn = ACT2FN[config.activation_function]
735
+ self.activation_dropout = config.activation_dropout
736
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
737
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
738
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
739
+ self.normalize_before = config.encoder_normalize_before
740
+
741
+ def forward(
742
+ self,
743
+ hidden_states: torch.Tensor,
744
+ attention_mask: torch.Tensor,
745
+ layer_head_mask: torch.Tensor,
746
+ output_attentions: bool = False,
747
+ ) -> torch.Tensor:
748
+ """
749
+ Args:
750
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
751
+ attention_mask (`torch.FloatTensor`): attention mask of size
752
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
753
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
754
+ `(encoder_attention_heads,)`.
755
+ output_attentions (`bool`, *optional*):
756
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
757
+ returned tensors for more detail.
758
+ """
759
+ residual = hidden_states
760
+ if self.normalize_before:
761
+ hidden_states = self.self_attn_layer_norm(hidden_states)
762
+ hidden_states, attn_weights, _ = self.self_attn(
763
+ hidden_states=hidden_states,
764
+ attention_mask=attention_mask,
765
+ layer_head_mask=layer_head_mask,
766
+ output_attentions=output_attentions,
767
+ )
768
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
769
+ hidden_states = residual + hidden_states
770
+ if not self.normalize_before:
771
+ hidden_states = self.self_attn_layer_norm(hidden_states)
772
+
773
+ residual = hidden_states
774
+ if self.normalize_before:
775
+ hidden_states = self.final_layer_norm(hidden_states)
776
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
777
+ hidden_states = F.dropout(
778
+ hidden_states, p=self.activation_dropout, training=self.training
779
+ )
780
+ hidden_states = self.fc2(hidden_states)
781
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
782
+ hidden_states = residual + hidden_states
783
+ if not self.normalize_before:
784
+ hidden_states = self.final_layer_norm(hidden_states)
785
+
786
+ if hidden_states.dtype == torch.float16 and (
787
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
788
+ ):
789
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
790
+ hidden_states = torch.clamp(
791
+ hidden_states, min=-clamp_value, max=clamp_value
792
+ )
793
+
794
+ outputs = (hidden_states,)
795
+
796
+ if output_attentions:
797
+ outputs += (attn_weights,)
798
+
799
+ return outputs
800
+
801
+
802
+ # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->IndicTrans
803
+ class IndicTransDecoderLayer(nn.Module):
804
+ def __init__(self, config: IndicTransConfig):
805
+ super().__init__()
806
+ self.embed_dim = config.decoder_embed_dim
807
+
808
+ self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
809
+ embed_dim=self.embed_dim,
810
+ num_heads=config.decoder_attention_heads,
811
+ dropout=config.attention_dropout,
812
+ is_decoder=True,
813
+ is_causal=True,
814
+ config=config,
815
+ )
816
+ self.dropout = config.dropout
817
+ self.activation_fn = ACT2FN[config.activation_function]
818
+ self.activation_dropout = config.activation_dropout
819
+
820
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
821
+ self.encoder_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
822
+ self.embed_dim,
823
+ config.decoder_attention_heads,
824
+ dropout=config.attention_dropout,
825
+ is_decoder=True,
826
+ config=config,
827
+ )
828
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
829
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
830
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
831
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
832
+ self.normalize_before = config.decoder_normalize_before
833
+
834
+ def forward(
835
+ self,
836
+ hidden_states: torch.Tensor,
837
+ attention_mask: Optional[torch.Tensor] = None,
838
+ encoder_hidden_states: Optional[torch.Tensor] = None,
839
+ encoder_attention_mask: Optional[torch.Tensor] = None,
840
+ layer_head_mask: Optional[torch.Tensor] = None,
841
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
842
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
843
+ output_attentions: Optional[bool] = False,
844
+ use_cache: Optional[bool] = True,
845
+ ) -> torch.Tensor:
846
+ """
847
+ Args:
848
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
849
+ attention_mask (`torch.FloatTensor`): attention mask of size
850
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
851
+ encoder_hidden_states (`torch.FloatTensor`):
852
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
853
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
854
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
855
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
856
+ `(encoder_attention_heads,)`.
857
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
858
+ size `(decoder_attention_heads,)`.
859
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
860
+ output_attentions (`bool`, *optional*):
861
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
862
+ returned tensors for more detail.
863
+ """
864
+ residual = hidden_states
865
+ if self.normalize_before:
866
+ hidden_states = self.self_attn_layer_norm(hidden_states)
867
+
868
+ # Self Attention
869
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
870
+ self_attn_past_key_value = (
871
+ past_key_value[:2] if past_key_value is not None else None
872
+ )
873
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
874
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
875
+ hidden_states=hidden_states,
876
+ past_key_value=self_attn_past_key_value,
877
+ attention_mask=attention_mask,
878
+ layer_head_mask=layer_head_mask,
879
+ output_attentions=output_attentions,
880
+ )
881
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
882
+ hidden_states = residual + hidden_states
883
+ if not self.normalize_before:
884
+ hidden_states = self.self_attn_layer_norm(hidden_states)
885
+
886
+ # Cross-Attention Block
887
+ cross_attn_present_key_value = None
888
+ cross_attn_weights = None
889
+ if encoder_hidden_states is not None:
890
+ residual = hidden_states
891
+ if self.normalize_before:
892
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
893
+
894
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
895
+ cross_attn_past_key_value = (
896
+ past_key_value[-2:] if past_key_value is not None else None
897
+ )
898
+ (
899
+ hidden_states,
900
+ cross_attn_weights,
901
+ cross_attn_present_key_value,
902
+ ) = self.encoder_attn(
903
+ hidden_states=hidden_states,
904
+ key_value_states=encoder_hidden_states,
905
+ attention_mask=encoder_attention_mask,
906
+ layer_head_mask=cross_attn_layer_head_mask,
907
+ past_key_value=cross_attn_past_key_value,
908
+ output_attentions=output_attentions,
909
+ )
910
+ hidden_states = F.dropout(
911
+ hidden_states, p=self.dropout, training=self.training
912
+ )
913
+ hidden_states = residual + hidden_states
914
+ if not self.normalize_before:
915
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
916
+
917
+ # add cross-attn to positions 3,4 of present_key_value tuple
918
+ present_key_value = present_key_value + cross_attn_present_key_value
919
+
920
+ # Fully Connected
921
+ residual = hidden_states
922
+ if self.normalize_before:
923
+ hidden_states = self.final_layer_norm(hidden_states)
924
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
925
+ hidden_states = F.dropout(
926
+ hidden_states, p=self.activation_dropout, training=self.training
927
+ )
928
+ hidden_states = self.fc2(hidden_states)
929
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
930
+ hidden_states = residual + hidden_states
931
+ if not self.normalize_before:
932
+ hidden_states = self.final_layer_norm(hidden_states)
933
+
934
+ outputs = (hidden_states,)
935
+
936
+ if output_attentions:
937
+ outputs += (self_attn_weights, cross_attn_weights)
938
+
939
+ if use_cache:
940
+ outputs += (present_key_value,)
941
+
942
+ return outputs
943
+
944
+
945
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100PretrainedModel->IndicTrans
946
+ class IndicTransPreTrainedModel(PreTrainedModel):
947
+ config_class = IndicTransConfig
948
+ base_model_prefix = "model"
949
+ supports_gradient_checkpointing = True
950
+ _no_split_modules = ["IndicTransAttention"]
951
+
952
+ def _init_weights(self, module):
953
+ std = self.config.init_std
954
+ if isinstance(module, nn.Linear):
955
+ module.weight.data.normal_(mean=0.0, std=std)
956
+ if module.bias is not None:
957
+ module.bias.data.zero_()
958
+ elif isinstance(module, nn.Embedding):
959
+ module.weight.data.normal_(mean=0.0, std=std)
960
+ if module.padding_idx is not None:
961
+ module.weight.data[module.padding_idx].zero_()
962
+
963
+ def _set_gradient_checkpointing(self, module, value=False):
964
+ if isinstance(module, (IndicTransDecoder, IndicTransEncoder)):
965
+ module.gradient_checkpointing = value
966
+
967
+
968
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100EncoderLayer->IndicTrans
969
+ class IndicTransEncoder(IndicTransPreTrainedModel):
970
+ """
971
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
972
+ [`IndicTransEncoderLayer`].
973
+
974
+ Args:
975
+ config: IndicTransConfig
976
+ embed_tokens (nn.Embedding): output embedding
977
+ """
978
+
979
+ def __init__(
980
+ self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
981
+ ):
982
+ super().__init__(config)
983
+
984
+ self.dropout = config.dropout
985
+ self.layerdrop = config.encoder_layerdrop
986
+
987
+ embed_dim = config.encoder_embed_dim
988
+ self.padding_idx = config.pad_token_id
989
+ self.max_source_positions = config.max_source_positions
990
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
991
+
992
+ self.embed_tokens = nn.Embedding(
993
+ config.encoder_vocab_size, embed_dim, self.padding_idx
994
+ )
995
+
996
+ if embed_tokens is not None:
997
+ self.embed_tokens.weight = embed_tokens.weight
998
+
999
+ self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
1000
+ config.max_source_positions,
1001
+ embed_dim,
1002
+ self.padding_idx,
1003
+ )
1004
+ self.layers = nn.ModuleList(
1005
+ [IndicTransEncoderLayer(config) for _ in range(config.encoder_layers)]
1006
+ )
1007
+ self.layer_norm = (
1008
+ nn.LayerNorm(embed_dim) if config.encoder_normalize_before else None
1009
+ )
1010
+ self.layernorm_embedding = (
1011
+ nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
1012
+ )
1013
+
1014
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1015
+ self._use_sdpa = config._attn_implementation == "sdpa"
1016
+
1017
+ self.gradient_checkpointing = False
1018
+ # Initialize weights and apply final processing
1019
+ self.post_init()
1020
+
1021
+ def forward(
1022
+ self,
1023
+ input_ids: Optional[torch.Tensor] = None,
1024
+ attention_mask: Optional[torch.Tensor] = None,
1025
+ head_mask: Optional[torch.Tensor] = None,
1026
+ inputs_embeds: Optional[torch.Tensor] = None,
1027
+ output_attentions: Optional[bool] = None,
1028
+ output_hidden_states: Optional[bool] = None,
1029
+ return_dict: Optional[bool] = None,
1030
+ ):
1031
+ r"""
1032
+ Args:
1033
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1034
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1035
+ provide it.
1036
+
1037
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1038
+ [`PreTrainedTokenizer.__call__`] for details.
1039
+
1040
+ [What are input IDs?](../glossary#input-ids)
1041
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1042
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1043
+
1044
+ - 1 for tokens that are **not masked**,
1045
+ - 0 for tokens that are **masked**.
1046
+
1047
+ [What are attention masks?](../glossary#attention-mask)
1048
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
1049
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1050
+
1051
+ - 1 indicates the head is **not masked**,
1052
+ - 0 indicates the head is **masked**.
1053
+
1054
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1055
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
1056
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
1057
+ than the model's internal embedding lookup matrix.
1058
+ output_attentions (`bool`, *optional*):
1059
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1060
+ returned tensors for more detail.
1061
+ output_hidden_states (`bool`, *optional*):
1062
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1063
+ for more detail.
1064
+ return_dict (`bool`, *optional*):
1065
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1066
+ """
1067
+ output_attentions = (
1068
+ output_attentions
1069
+ if output_attentions is not None
1070
+ else self.config.output_attentions
1071
+ )
1072
+ output_hidden_states = (
1073
+ output_hidden_states
1074
+ if output_hidden_states is not None
1075
+ else self.config.output_hidden_states
1076
+ )
1077
+ return_dict = (
1078
+ return_dict if return_dict is not None else self.config.use_return_dict
1079
+ )
1080
+
1081
+ # retrieve input_ids and inputs_embeds
1082
+ if input_ids is not None and inputs_embeds is not None:
1083
+ raise ValueError(
1084
+ "You cannot specify both input_ids and inputs_embeds at the same time"
1085
+ )
1086
+ elif input_ids is not None:
1087
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1088
+ input_shape = input_ids.size()
1089
+ input_ids = input_ids.view(-1, input_shape[-1])
1090
+ elif inputs_embeds is not None:
1091
+ input_shape = inputs_embeds.size()[:-1]
1092
+ else:
1093
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1094
+
1095
+ if inputs_embeds is None:
1096
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1097
+
1098
+ embed_pos = self.embed_positions(input_ids, inputs_embeds)
1099
+ embed_pos = embed_pos.to(inputs_embeds.device)
1100
+
1101
+ hidden_states = inputs_embeds + embed_pos
1102
+ if self.layernorm_embedding is not None:
1103
+ hidden_states = self.layernorm_embedding(hidden_states)
1104
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
1105
+
1106
+ if attention_mask is not None:
1107
+ if self._use_flash_attention_2:
1108
+ attention_mask = attention_mask if 0 in attention_mask else None
1109
+ elif self._use_sdpa and head_mask is None and not output_attentions:
1110
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
1111
+ # the manual implementation that requires a 4D causal mask in all cases.
1112
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1113
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
1114
+ else:
1115
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1116
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
1117
+
1118
+
1119
+ encoder_states = () if output_hidden_states else None
1120
+ all_attentions = () if output_attentions else None
1121
+
1122
+ # check if head_mask has a correct number of layers specified if desired
1123
+ if head_mask is not None:
1124
+ if head_mask.size()[0] != len(self.layers):
1125
+ raise ValueError(
1126
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
1127
+ f" {head_mask.size()[0]}."
1128
+ )
1129
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
1130
+
1131
+ for idx, encoder_layer in enumerate(self.layers):
1132
+ if output_hidden_states:
1133
+ encoder_states = encoder_states + (hidden_states,)
1134
+
1135
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1136
+ dropout_probability = torch.rand([])
1137
+
1138
+ skip_the_layer = (
1139
+ True
1140
+ if self.training and (dropout_probability < self.layerdrop)
1141
+ else False
1142
+ )
1143
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
1144
+ # under deepspeed zero3 all gpus must run in sync
1145
+
1146
+ if self.gradient_checkpointing and self.training:
1147
+ # create gradient checkpointing function
1148
+ def create_custom_forward(module):
1149
+ def custom_forward(*inputs):
1150
+ return module(*inputs, output_attentions)
1151
+
1152
+ return custom_forward
1153
+
1154
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1155
+ create_custom_forward(encoder_layer),
1156
+ hidden_states,
1157
+ attention_mask,
1158
+ (head_mask[idx] if head_mask is not None else None),
1159
+ )
1160
+ else:
1161
+ layer_outputs = encoder_layer(
1162
+ hidden_states,
1163
+ attention_mask,
1164
+ layer_head_mask=(
1165
+ head_mask[idx] if head_mask is not None else None
1166
+ ),
1167
+ output_attentions=output_attentions,
1168
+ )
1169
+
1170
+ hidden_states = layer_outputs[0]
1171
+
1172
+ if skip_the_layer:
1173
+ layer_outputs = (None, None)
1174
+
1175
+ if output_attentions:
1176
+ all_attentions = all_attentions + (layer_outputs[1],)
1177
+
1178
+ if self.layer_norm is not None:
1179
+ hidden_states = self.layer_norm(hidden_states)
1180
+
1181
+ if output_hidden_states:
1182
+ encoder_states = encoder_states + (hidden_states,)
1183
+
1184
+ if not return_dict:
1185
+ return tuple(
1186
+ v
1187
+ for v in [hidden_states, encoder_states, all_attentions]
1188
+ if v is not None
1189
+ )
1190
+ return BaseModelOutput(
1191
+ last_hidden_state=hidden_states,
1192
+ hidden_states=encoder_states,
1193
+ attentions=all_attentions,
1194
+ )
1195
+
1196
+
1197
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100DecoderLayer->IndicTrans
1198
+ class IndicTransDecoder(IndicTransPreTrainedModel):
1199
+ """
1200
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`IndicTransDecoderLayer`]
1201
+
1202
+ Args:
1203
+ config: IndicTransConfig
1204
+ embed_tokens (nn.Embedding): output embedding
1205
+ """
1206
+
1207
+ def __init__(
1208
+ self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
1209
+ ):
1210
+ super().__init__(config)
1211
+ self.dropout = config.dropout
1212
+ self.layerdrop = config.decoder_layerdrop
1213
+
1214
+ embed_dim = config.encoder_embed_dim
1215
+ self.padding_idx = config.pad_token_id
1216
+ self.max_target_positions = config.max_target_positions
1217
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
1218
+
1219
+ self.embed_tokens = nn.Embedding(
1220
+ config.decoder_vocab_size, embed_dim, self.padding_idx
1221
+ )
1222
+
1223
+ if embed_tokens is not None:
1224
+ self.embed_tokens.weight = embed_tokens.weight
1225
+
1226
+ self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
1227
+ config.max_target_positions,
1228
+ embed_dim,
1229
+ self.padding_idx,
1230
+ )
1231
+ self.layers = nn.ModuleList(
1232
+ [IndicTransDecoderLayer(config) for _ in range(config.decoder_layers)]
1233
+ )
1234
+ self.layer_norm = (
1235
+ nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None
1236
+ )
1237
+ self.layernorm_embedding = (
1238
+ nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
1239
+ )
1240
+
1241
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1242
+ self._use_sdpa = config._attn_implementation == "sdpa"
1243
+
1244
+ self.gradient_checkpointing = False
1245
+ # Initialize weights and apply final processing
1246
+ self.post_init()
1247
+
1248
+ def forward(
1249
+ self,
1250
+ input_ids: Optional[torch.Tensor] = None,
1251
+ attention_mask: Optional[torch.Tensor] = None,
1252
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1253
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1254
+ head_mask: Optional[torch.Tensor] = None,
1255
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1256
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1257
+ inputs_embeds: Optional[torch.Tensor] = None,
1258
+ use_cache: Optional[bool] = None,
1259
+ output_attentions: Optional[bool] = None,
1260
+ output_hidden_states: Optional[bool] = None,
1261
+ return_dict: Optional[bool] = None,
1262
+ ):
1263
+ r"""
1264
+ Args:
1265
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1266
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1267
+ provide it.
1268
+
1269
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1270
+ [`PreTrainedTokenizer.__call__`] for details.
1271
+
1272
+ [What are input IDs?](../glossary#input-ids)
1273
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1274
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1275
+
1276
+ - 1 for tokens that are **not masked**,
1277
+ - 0 for tokens that are **masked**.
1278
+
1279
+ [What are attention masks?](../glossary#attention-mask)
1280
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
1281
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1282
+ of the decoder.
1283
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
1284
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
1285
+ selected in `[0, 1]`:
1286
+
1287
+ - 1 for tokens that are **not masked**,
1288
+ - 0 for tokens that are **masked**.
1289
+
1290
+ [What are attention masks?](../glossary#attention-mask)
1291
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1292
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1293
+
1294
+ - 1 indicates the head is **not masked**,
1295
+ - 0 indicates the head is **masked**.
1296
+
1297
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1298
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
1299
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
1300
+
1301
+ - 1 indicates the head is **not masked**,
1302
+ - 0 indicates the head is **masked**.
1303
+
1304
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1305
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1306
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
1307
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
1308
+
1309
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1310
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1311
+
1312
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1313
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1314
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
1315
+ shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
1316
+ `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
1317
+ control over how to convert `input_ids` indices into associated vectors than the model's internal
1318
+ embedding lookup matrix.
1319
+ output_attentions (`bool`, *optional*):
1320
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1321
+ returned tensors for more detail.
1322
+ output_hidden_states (`bool`, *optional*):
1323
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1324
+ for more detail.
1325
+ return_dict (`bool`, *optional*):
1326
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1327
+ """
1328
+ output_attentions = (
1329
+ output_attentions
1330
+ if output_attentions is not None
1331
+ else self.config.output_attentions
1332
+ )
1333
+ output_hidden_states = (
1334
+ output_hidden_states
1335
+ if output_hidden_states is not None
1336
+ else self.config.output_hidden_states
1337
+ )
1338
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1339
+ return_dict = (
1340
+ return_dict if return_dict is not None else self.config.use_return_dict
1341
+ )
1342
+
1343
+ # retrieve input_ids and inputs_embeds
1344
+ if input_ids is not None and inputs_embeds is not None:
1345
+ raise ValueError(
1346
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
1347
+ )
1348
+ elif input_ids is not None:
1349
+ input_shape = input_ids.size()
1350
+ input_ids = input_ids.view(-1, input_shape[-1])
1351
+ elif inputs_embeds is not None:
1352
+ input_shape = inputs_embeds.size()[:-1]
1353
+ else:
1354
+ raise ValueError(
1355
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1356
+ )
1357
+
1358
+ # past_key_values_length
1359
+ past_key_values_length = (
1360
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
1361
+ )
1362
+
1363
+ if inputs_embeds is None:
1364
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1365
+
1366
+
1367
+ if self._use_flash_attention_2:
1368
+ # 2d mask is passed through the layers
1369
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1370
+ elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
1371
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1372
+ # the manual implementation that requires a 4D causal mask in all cases.
1373
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1374
+ attention_mask,
1375
+ input_shape,
1376
+ inputs_embeds,
1377
+ past_key_values_length,
1378
+ )
1379
+ else:
1380
+ # 4d mask is passed through the layers
1381
+ attention_mask = _prepare_4d_causal_attention_mask(
1382
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
1383
+ )
1384
+
1385
+ # expand encoder attention mask
1386
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
1387
+ if self._use_flash_attention_2:
1388
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
1389
+ elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
1390
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1391
+ # the manual implementation that requires a 4D causal mask in all cases.
1392
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1393
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1394
+ encoder_attention_mask,
1395
+ inputs_embeds.dtype,
1396
+ tgt_len=input_shape[-1],
1397
+ )
1398
+ else:
1399
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1400
+ encoder_attention_mask = _prepare_4d_attention_mask(
1401
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1402
+ )
1403
+
1404
+ # embed positions
1405
+ positions = self.embed_positions(
1406
+ input_ids, inputs_embeds, past_key_values_length
1407
+ )
1408
+ positions = positions.to(inputs_embeds.device)
1409
+
1410
+ hidden_states = inputs_embeds + positions
1411
+ if self.layernorm_embedding is not None:
1412
+ hidden_states = self.layernorm_embedding(hidden_states)
1413
+
1414
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
1415
+
1416
+ if self.gradient_checkpointing and self.training:
1417
+ if use_cache:
1418
+ logger.warning_once(
1419
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting"
1420
+ " `use_cache=False`..."
1421
+ )
1422
+ use_cache = False
1423
+
1424
+ # decoder layers
1425
+ all_hidden_states = () if output_hidden_states else None
1426
+ all_self_attns = () if output_attentions else None
1427
+ all_cross_attentions = () if output_attentions else None
1428
+ next_decoder_cache = () if use_cache else None
1429
+
1430
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1431
+ for attn_mask, mask_name in zip(
1432
+ [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
1433
+ ):
1434
+ if attn_mask is not None:
1435
+ if attn_mask.size()[0] != len(self.layers):
1436
+ raise ValueError(
1437
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
1438
+ f" {head_mask.size()[0]}."
1439
+ )
1440
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
1441
+
1442
+ for idx, decoder_layer in enumerate(self.layers):
1443
+ if output_hidden_states:
1444
+ all_hidden_states += (hidden_states,)
1445
+
1446
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1447
+ dropout_probability = torch.rand([])
1448
+
1449
+ skip_the_layer = (
1450
+ True
1451
+ if self.training and (dropout_probability < self.layerdrop)
1452
+ else False
1453
+ )
1454
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
1455
+ # under deepspeed zero3 all gpus must run in sync
1456
+
1457
+ past_key_value = (
1458
+ past_key_values[idx] if past_key_values is not None else None
1459
+ )
1460
+
1461
+ if self.gradient_checkpointing and self.training:
1462
+
1463
+ def create_custom_forward(module):
1464
+ def custom_forward(*inputs):
1465
+ # None for past_key_value
1466
+ return module(*inputs, output_attentions, use_cache)
1467
+
1468
+ return custom_forward
1469
+
1470
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1471
+ create_custom_forward(decoder_layer),
1472
+ hidden_states,
1473
+ attention_mask,
1474
+ encoder_hidden_states,
1475
+ encoder_attention_mask,
1476
+ head_mask[idx] if head_mask is not None else None,
1477
+ cross_attn_head_mask[idx]
1478
+ if cross_attn_head_mask is not None
1479
+ else None,
1480
+ None,
1481
+ )
1482
+ else:
1483
+ layer_outputs = decoder_layer(
1484
+ hidden_states,
1485
+ attention_mask=attention_mask,
1486
+ encoder_hidden_states=encoder_hidden_states,
1487
+ encoder_attention_mask=encoder_attention_mask,
1488
+ layer_head_mask=(
1489
+ head_mask[idx] if head_mask is not None else None
1490
+ ),
1491
+ cross_attn_layer_head_mask=(
1492
+ cross_attn_head_mask[idx]
1493
+ if cross_attn_head_mask is not None
1494
+ else None
1495
+ ),
1496
+ past_key_value=past_key_value,
1497
+ output_attentions=output_attentions,
1498
+ use_cache=use_cache,
1499
+ )
1500
+
1501
+ hidden_states = layer_outputs[0]
1502
+
1503
+ if skip_the_layer:
1504
+ continue
1505
+
1506
+ if use_cache:
1507
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
1508
+
1509
+ if output_attentions:
1510
+ all_self_attns += (layer_outputs[1],)
1511
+ all_cross_attentions += (layer_outputs[2],)
1512
+
1513
+ if self.layer_norm is not None:
1514
+ hidden_states = self.layer_norm(hidden_states)
1515
+
1516
+ # add hidden states from the last decoder layer
1517
+ if output_hidden_states:
1518
+ all_hidden_states += (hidden_states,)
1519
+
1520
+ next_cache = next_decoder_cache if use_cache else None
1521
+ if not return_dict:
1522
+ return tuple(
1523
+ v
1524
+ for v in [
1525
+ hidden_states,
1526
+ next_cache,
1527
+ all_hidden_states,
1528
+ all_self_attns,
1529
+ all_cross_attentions,
1530
+ ]
1531
+ if v is not None
1532
+ )
1533
+ return BaseModelOutputWithPastAndCrossAttentions(
1534
+ last_hidden_state=hidden_states,
1535
+ past_key_values=next_cache,
1536
+ hidden_states=all_hidden_states,
1537
+ attentions=all_self_attns,
1538
+ cross_attentions=all_cross_attentions,
1539
+ )
1540
+
1541
+
1542
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->IndicTrans
1543
+ class IndicTransModel(IndicTransPreTrainedModel):
1544
+ _tied_weights_keys = None
1545
+
1546
+ def __init__(self, config: IndicTransConfig):
1547
+ super().__init__(config)
1548
+
1549
+ self.encoder = IndicTransEncoder(config)
1550
+ self.decoder = IndicTransDecoder(config)
1551
+
1552
+ # Initialize weights and apply final processing
1553
+ self.post_init()
1554
+
1555
+ def get_encoder(self):
1556
+ return self.encoder
1557
+
1558
+ def get_decoder(self):
1559
+ return self.decoder
1560
+
1561
+ def forward(
1562
+ self,
1563
+ input_ids: Optional[torch.LongTensor] = None,
1564
+ attention_mask: Optional[torch.Tensor] = None,
1565
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1566
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1567
+ head_mask: Optional[torch.Tensor] = None,
1568
+ decoder_head_mask: Optional[torch.Tensor] = None,
1569
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1570
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1571
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1572
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1573
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1574
+ use_cache: Optional[bool] = None,
1575
+ output_attentions: Optional[bool] = None,
1576
+ output_hidden_states: Optional[bool] = None,
1577
+ return_dict: Optional[bool] = None,
1578
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
1579
+ output_attentions = (
1580
+ output_attentions
1581
+ if output_attentions is not None
1582
+ else self.config.output_attentions
1583
+ )
1584
+ output_hidden_states = (
1585
+ output_hidden_states
1586
+ if output_hidden_states is not None
1587
+ else self.config.output_hidden_states
1588
+ )
1589
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1590
+ return_dict = (
1591
+ return_dict if return_dict is not None else self.config.use_return_dict
1592
+ )
1593
+
1594
+ if encoder_outputs is None:
1595
+ encoder_outputs = self.encoder(
1596
+ input_ids=input_ids,
1597
+ attention_mask=attention_mask,
1598
+ head_mask=head_mask,
1599
+ inputs_embeds=inputs_embeds,
1600
+ output_attentions=output_attentions,
1601
+ output_hidden_states=output_hidden_states,
1602
+ return_dict=return_dict,
1603
+ )
1604
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1605
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1606
+ encoder_outputs = BaseModelOutput(
1607
+ last_hidden_state=encoder_outputs[0],
1608
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1609
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1610
+ )
1611
+
1612
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1613
+ decoder_outputs = self.decoder(
1614
+ input_ids=decoder_input_ids,
1615
+ attention_mask=decoder_attention_mask,
1616
+ encoder_hidden_states=encoder_outputs[0],
1617
+ encoder_attention_mask=attention_mask,
1618
+ head_mask=decoder_head_mask,
1619
+ cross_attn_head_mask=cross_attn_head_mask,
1620
+ past_key_values=past_key_values,
1621
+ inputs_embeds=decoder_inputs_embeds,
1622
+ use_cache=use_cache,
1623
+ output_attentions=output_attentions,
1624
+ output_hidden_states=output_hidden_states,
1625
+ return_dict=return_dict,
1626
+ )
1627
+
1628
+ if not return_dict:
1629
+ return decoder_outputs + encoder_outputs
1630
+
1631
+ return Seq2SeqModelOutput(
1632
+ last_hidden_state=decoder_outputs.last_hidden_state,
1633
+ past_key_values=decoder_outputs.past_key_values,
1634
+ decoder_hidden_states=decoder_outputs.hidden_states,
1635
+ decoder_attentions=decoder_outputs.attentions,
1636
+ cross_attentions=decoder_outputs.cross_attentions,
1637
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1638
+ encoder_hidden_states=encoder_outputs.hidden_states,
1639
+ encoder_attentions=encoder_outputs.attentions,
1640
+ )
1641
+
1642
+
1643
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
1644
+ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
1645
+ base_model_prefix = "model"
1646
+ _tied_weights_keys = None
1647
+ _label_smoothing = 0.0
1648
+
1649
+ def __init__(self, config: IndicTransConfig):
1650
+ super().__init__(config)
1651
+ self.model = IndicTransModel(config)
1652
+ self.lm_head = nn.Linear(
1653
+ config.decoder_embed_dim, config.decoder_vocab_size, bias=False
1654
+ )
1655
+
1656
+ if config.share_decoder_input_output_embed:
1657
+ self.lm_head.weight = self.model.decoder.embed_tokens.weight
1658
+
1659
+ self.post_init()
1660
+
1661
+ def tie_weights(self):
1662
+ pass
1663
+
1664
+ def get_encoder(self):
1665
+ return self.model.get_encoder()
1666
+
1667
+ def get_decoder(self):
1668
+ return self.model.get_decoder()
1669
+
1670
+ def get_output_embeddings(self):
1671
+ return self.lm_head
1672
+
1673
+ def set_output_embeddings(self, new_embeddings):
1674
+ self.lm_head = new_embeddings
1675
+
1676
+ def set_label_smoothing(self, label_smoothing):
1677
+ self._label_smoothing = label_smoothing
1678
+
1679
+ def forward(
1680
+ self,
1681
+ input_ids: Optional[torch.LongTensor] = None,
1682
+ attention_mask: Optional[torch.Tensor] = None,
1683
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1684
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1685
+ head_mask: Optional[torch.Tensor] = None,
1686
+ decoder_head_mask: Optional[torch.Tensor] = None,
1687
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1688
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1689
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1690
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1691
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1692
+ labels: Optional[torch.LongTensor] = None,
1693
+ use_cache: Optional[bool] = None,
1694
+ output_attentions: Optional[bool] = None,
1695
+ output_hidden_states: Optional[bool] = None,
1696
+ return_dict: Optional[bool] = None,
1697
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
1698
+ r"""
1699
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1700
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1701
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1702
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1703
+
1704
+ Returns:
1705
+ """
1706
+ return_dict = (
1707
+ return_dict if return_dict is not None else self.config.use_return_dict
1708
+ )
1709
+
1710
+ if labels is not None:
1711
+ if decoder_input_ids is None:
1712
+ decoder_input_ids = shift_tokens_right(
1713
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
1714
+ )
1715
+
1716
+ outputs = self.model(
1717
+ input_ids,
1718
+ attention_mask=attention_mask,
1719
+ decoder_input_ids=decoder_input_ids,
1720
+ encoder_outputs=encoder_outputs,
1721
+ decoder_attention_mask=decoder_attention_mask,
1722
+ head_mask=head_mask,
1723
+ decoder_head_mask=decoder_head_mask,
1724
+ cross_attn_head_mask=cross_attn_head_mask,
1725
+ past_key_values=past_key_values,
1726
+ inputs_embeds=inputs_embeds,
1727
+ decoder_inputs_embeds=decoder_inputs_embeds,
1728
+ use_cache=use_cache,
1729
+ output_attentions=output_attentions,
1730
+ output_hidden_states=output_hidden_states,
1731
+ return_dict=return_dict,
1732
+ )
1733
+ lm_logits = self.lm_head(outputs[0])
1734
+
1735
+ masked_lm_loss = None
1736
+ if labels is not None:
1737
+ # move labels to the correct device to enable PP
1738
+ labels = labels.to(lm_logits.device)
1739
+ masked_lm_loss = F.cross_entropy(
1740
+ input=lm_logits.view(-1, self.config.decoder_vocab_size),
1741
+ target=labels.view(-1),
1742
+ ignore_index=-100,
1743
+ label_smoothing=self._label_smoothing,
1744
+ )
1745
+
1746
+ if not return_dict:
1747
+ output = (lm_logits,) + outputs[1:]
1748
+ return (
1749
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1750
+ )
1751
+
1752
+ return Seq2SeqLMOutput(
1753
+ loss=masked_lm_loss,
1754
+ logits=lm_logits,
1755
+ past_key_values=outputs.past_key_values,
1756
+ decoder_hidden_states=outputs.decoder_hidden_states,
1757
+ decoder_attentions=outputs.decoder_attentions,
1758
+ cross_attentions=outputs.cross_attentions,
1759
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1760
+ encoder_hidden_states=outputs.encoder_hidden_states,
1761
+ encoder_attentions=outputs.encoder_attentions,
1762
+ )
1763
+
1764
+ def prepare_inputs_for_generation(
1765
+ self,
1766
+ decoder_input_ids,
1767
+ past_key_values=None,
1768
+ attention_mask=None,
1769
+ head_mask=None,
1770
+ decoder_head_mask=None,
1771
+ cross_attn_head_mask=None,
1772
+ use_cache=None,
1773
+ encoder_outputs=None,
1774
+ **kwargs,
1775
+ ):
1776
+ # cut decoder_input_ids if past is used
1777
+ if past_key_values is not None:
1778
+ decoder_input_ids = decoder_input_ids[:, -1:]
1779
+
1780
+ return {
1781
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
1782
+ "encoder_outputs": encoder_outputs,
1783
+ "past_key_values": past_key_values,
1784
+ "decoder_input_ids": decoder_input_ids,
1785
+ "attention_mask": attention_mask,
1786
+ "head_mask": head_mask,
1787
+ "decoder_head_mask": decoder_head_mask,
1788
+ "cross_attn_head_mask": cross_attn_head_mask,
1789
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1790
+ }
1791
+
1792
+ @staticmethod
1793
+ def _reorder_cache(past_key_values, beam_idx):
1794
+ reordered_past = ()
1795
+ for layer_past in past_key_values:
1796
+ reordered_past += (
1797
+ tuple(
1798
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1799
+ ),
1800
+ )
1801
+ return reordered_past
IndicTrans2/huggingface_interface/train_lora.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import pandas as pd
4
+ from datasets import Dataset
5
+ from sacrebleu.metrics import BLEU, CHRF
6
+ from peft import LoraConfig, get_peft_model
7
+ from IndicTransToolkit import IndicProcessor, IndicDataCollator
8
+
9
+ from transformers import (
10
+ Seq2SeqTrainer,
11
+ Seq2SeqTrainingArguments,
12
+ AutoModelForSeq2SeqLM,
13
+ AutoTokenizer,
14
+ EarlyStoppingCallback,
15
+ )
16
+
17
+ bleu_metric = BLEU()
18
+ chrf_metric = CHRF()
19
+
20
+
21
+ def get_arg_parse():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument(
24
+ "--model",
25
+ type=str,
26
+ )
27
+ parser.add_argument(
28
+ "--src_lang_list",
29
+ type=str,
30
+ help="comma separated list of source languages",
31
+ )
32
+ parser.add_argument(
33
+ "--tgt_lang_list",
34
+ type=str,
35
+ help="comma separated list of target languages",
36
+ )
37
+ parser.add_argument("--data_dir", type=str)
38
+ parser.add_argument("--output_dir", type=str)
39
+ parser.add_argument("--save_steps", type=int, default=1000)
40
+ parser.add_argument("--eval_steps", type=int, default=1000)
41
+ parser.add_argument("--batch_size", type=int, default=32)
42
+ parser.add_argument("--num_train_epochs", type=int, default=100)
43
+ parser.add_argument("--max_steps", type=int, default=1000000)
44
+ parser.add_argument("--grad_accum_steps", type=int, default=4)
45
+ parser.add_argument("--warmup_steps", type=int, default=4000)
46
+ parser.add_argument("--warmup_ratio", type=int, default=0.0)
47
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
48
+ parser.add_argument("--learning_rate", type=float, default=5e-4)
49
+ parser.add_argument("--weight_decay", type=float, default=0.0)
50
+ parser.add_argument("--adam_beta1", type=float, default=0.9)
51
+ parser.add_argument("--adam_beta2", type=float, default=0.98)
52
+ parser.add_argument("--dropout", type=float, default=0.0)
53
+ parser.add_argument("--print_samples", action="store_true")
54
+ parser.add_argument(
55
+ "--optimizer",
56
+ type=str,
57
+ default="adamw_torch",
58
+ choices=[
59
+ "adam_hf",
60
+ "adamw_torch",
61
+ "adamw_torch_fused",
62
+ "adamw_apex_fused",
63
+ "adafactor",
64
+ ],
65
+ )
66
+ parser.add_argument(
67
+ "--lr_scheduler",
68
+ type=str,
69
+ default="inverse_sqrt",
70
+ choices=[
71
+ "inverse_sqrt",
72
+ "linear",
73
+ "polynomial",
74
+ "cosine",
75
+ "constant",
76
+ "constant_with_warmup",
77
+ ],
78
+ )
79
+ parser.add_argument("--label_smoothing", type=float, default=0.0)
80
+ parser.add_argument("--num_workers", type=int, default=8)
81
+ parser.add_argument("--metric_for_best_model", type=str, default="eval_loss")
82
+ parser.add_argument("--greater_is_better", action="store_true")
83
+ parser.add_argument("--lora_target_modules", type=str, default="q_proj,k_proj")
84
+ parser.add_argument("--lora_dropout", type=float, default=0.1)
85
+ parser.add_argument("--lora_r", type=int, default=16)
86
+ parser.add_argument("--lora_alpha", type=int, default=32)
87
+ parser.add_argument(
88
+ "--report_to",
89
+ type=str,
90
+ default="none",
91
+ choices=["wandb", "tensorboard", "azure_ml", "none"],
92
+ )
93
+ parser.add_argument("--patience", type=int, default=5),
94
+ parser.add_argument("--threshold", type=float, default=1e-3)
95
+ return parser
96
+
97
+
98
+ def load_and_process_translation_dataset(
99
+ data_dir,
100
+ split="train",
101
+ tokenizer=None,
102
+ processor=None,
103
+ src_lang_list=None,
104
+ tgt_lang_list=None,
105
+ num_proc=8,
106
+ seed=42
107
+ ):
108
+ complete_dataset = {
109
+ "sentence_SRC": [],
110
+ "sentence_TGT": [],
111
+ }
112
+
113
+ for src_lang in src_lang_list:
114
+ for tgt_lang in tgt_lang_list:
115
+ if src_lang == tgt_lang:
116
+ continue
117
+ src_path = os.path.join(
118
+ data_dir, split, f"{src_lang}-{tgt_lang}", f"{split}.{src_lang}"
119
+ )
120
+ tgt_path = os.path.join(
121
+ data_dir, split, f"{src_lang}-{tgt_lang}", f"{split}.{tgt_lang}"
122
+ )
123
+ if not os.path.exists(src_path) or not os.path.exists(tgt_path):
124
+ raise FileNotFoundError(
125
+ f"Source ({split}.{src_lang}) or Target ({split}.{tgt_lang}) file not found in {data_dir}"
126
+ )
127
+ with open(src_path, encoding="utf-8") as src_file, open(
128
+ tgt_path, encoding="utf-8"
129
+ ) as tgt_file:
130
+ src_lines = src_file.readlines()
131
+ tgt_lines = tgt_file.readlines()
132
+
133
+ # Ensure both files have the same number of lines
134
+ assert len(src_lines) == len(
135
+ tgt_lines
136
+ ), f"Source and Target files have different number of lines for {split}.{src_lang} and {split}.{tgt_lang}"
137
+
138
+ complete_dataset["sentence_SRC"] += processor.preprocess_batch(
139
+ src_lines, src_lang=src_lang, tgt_lang=tgt_lang, is_target=False
140
+ )
141
+
142
+ complete_dataset["sentence_TGT"] += processor.preprocess_batch(
143
+ tgt_lines, src_lang=tgt_lang, tgt_lang=src_lang, is_target=True
144
+ )
145
+
146
+ complete_dataset = Dataset.from_dict(complete_dataset).shuffle(seed=seed)
147
+
148
+ return complete_dataset.map(
149
+ lambda example: preprocess_fn(
150
+ example,
151
+ tokenizer=tokenizer
152
+ ),
153
+ batched=True,
154
+ num_proc=num_proc,
155
+ )
156
+
157
+
158
+ def compute_metrics_factory(
159
+ tokenizer, metric_dict=None, print_samples=False, n_samples=10
160
+ ):
161
+ def compute_metrics(eval_preds):
162
+ preds, labels = eval_preds
163
+
164
+ labels[labels == -100] = tokenizer.pad_token_id
165
+ preds[preds == -100] = tokenizer.pad_token_id
166
+
167
+ with tokenizer.as_target_tokenizer():
168
+ preds = [
169
+ x.strip()
170
+ for x in tokenizer.batch_decode(
171
+ preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
172
+ )
173
+ ]
174
+ labels = [
175
+ x.strip()
176
+ for x in tokenizer.batch_decode(
177
+ labels, skip_special_tokens=True, clean_up_tokenization_spaces=True
178
+ )
179
+ ]
180
+
181
+ assert len(preds) == len(
182
+ labels
183
+ ), "Predictions and Labels have different lengths"
184
+
185
+ df = pd.DataFrame({"Predictions": preds, "References": labels}).sample(
186
+ n=n_samples
187
+ )
188
+
189
+ if print_samples:
190
+ for pred, label in zip(df["Predictions"].values, df["References"].values):
191
+ print(f" | > Prediction: {pred}")
192
+ print(f" | > Reference: {label}\n")
193
+
194
+ return {
195
+ metric_name: metric.corpus_score(preds, [labels]).score
196
+ for (metric_name, metric) in metric_dict.items()
197
+ }
198
+
199
+ return compute_metrics
200
+
201
+
202
+ def preprocess_fn(example, tokenizer, **kwargs):
203
+ model_inputs = tokenizer(
204
+ example["sentence_SRC"], truncation=True, padding=False, max_length=256
205
+ )
206
+
207
+ with tokenizer.as_target_tokenizer():
208
+ labels = tokenizer(
209
+ example["sentence_TGT"], truncation=True, padding=False, max_length=256
210
+ )
211
+
212
+ model_inputs["labels"] = labels["input_ids"]
213
+ return model_inputs
214
+
215
+
216
+ def main(args):
217
+ print(f" | > Loading {args.model} and tokenizer ...")
218
+ model = AutoModelForSeq2SeqLM.from_pretrained(
219
+ args.model,
220
+ trust_remote_code=True,
221
+ attn_implementation="eager",
222
+ dropout=args.dropout
223
+ )
224
+
225
+ tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
226
+ processor = IndicProcessor(inference=False) # pre-process before tokenization
227
+
228
+ data_collator = IndicDataCollator(
229
+ tokenizer=tokenizer,
230
+ model=model,
231
+ padding="longest", # saves padding tokens
232
+ pad_to_multiple_of=8, # better to have it as 8 when using fp16
233
+ label_pad_token_id=-100
234
+ )
235
+
236
+ if args.data_dir is not None:
237
+ train_dataset = load_and_process_translation_dataset(
238
+ args.data_dir,
239
+ split="train",
240
+ tokenizer=tokenizer,
241
+ processor=processor,
242
+ src_lang_list=args.src_lang_list.split(","),
243
+ tgt_lang_list=args.tgt_lang_list.split(","),
244
+ )
245
+ print(f" | > Loaded train dataset from {args.data_dir}. Size: {len(train_dataset)} ...")
246
+
247
+ eval_dataset = load_and_process_translation_dataset(
248
+ args.data_dir,
249
+ split="dev",
250
+ tokenizer=tokenizer,
251
+ processor=processor,
252
+ src_lang_list=args.src_lang_list.split(","),
253
+ tgt_lang_list=args.tgt_lang_list.split(","),
254
+ )
255
+ print(f" | > Loaded eval dataset from {args.data_dir}. Size: {len(eval_dataset)} ...")
256
+ else:
257
+ raise ValueError(" | > Data directory not provided")
258
+
259
+ lora_config = LoraConfig(
260
+ r=args.lora_r,
261
+ bias="none",
262
+ inference_mode=False,
263
+ task_type="SEQ_2_SEQ_LM",
264
+ lora_alpha=args.lora_alpha,
265
+ lora_dropout=args.lora_dropout,
266
+ target_modules=args.lora_target_modules.split(","),
267
+ )
268
+
269
+ model.set_label_smoothing(args.label_smoothing)
270
+
271
+ model = get_peft_model(model, lora_config)
272
+ model.print_trainable_parameters()
273
+
274
+ print(f" | > Loading metrics factory with BLEU and chrF ...")
275
+ seq2seq_compute_metrics = compute_metrics_factory(
276
+ tokenizer=tokenizer,
277
+ print_samples=args.print_samples,
278
+ metric_dict={"BLEU": bleu_metric, "chrF": chrf_metric},
279
+ )
280
+
281
+ training_args = Seq2SeqTrainingArguments(
282
+ output_dir=args.output_dir,
283
+ do_train=True,
284
+ do_eval=True,
285
+ fp16=True, # use fp16 for faster training
286
+ logging_strategy="steps",
287
+ evaluation_strategy="steps",
288
+ save_strategy="steps",
289
+ logging_steps=100,
290
+ save_total_limit=1,
291
+ predict_with_generate=True,
292
+ load_best_model_at_end=True,
293
+ max_steps=args.max_steps, # max_steps overrides num_train_epochs
294
+ per_device_train_batch_size=args.batch_size,
295
+ per_device_eval_batch_size=args.batch_size,
296
+ gradient_accumulation_steps=args.grad_accum_steps,
297
+ eval_accumulation_steps=args.grad_accum_steps,
298
+ weight_decay=args.weight_decay,
299
+ adam_beta1=args.adam_beta1,
300
+ adam_beta2=args.adam_beta2,
301
+ max_grad_norm=args.max_grad_norm,
302
+ optim=args.optimizer,
303
+ lr_scheduler_type=args.lr_scheduler,
304
+ warmup_ratio=args.warmup_ratio,
305
+ warmup_steps=args.warmup_steps,
306
+ learning_rate=args.learning_rate,
307
+ num_train_epochs=args.num_train_epochs,
308
+ save_steps=args.save_steps,
309
+ eval_steps=args.eval_steps,
310
+ dataloader_num_workers=args.num_workers,
311
+ metric_for_best_model=args.metric_for_best_model,
312
+ greater_is_better=args.greater_is_better,
313
+ report_to=args.report_to,
314
+ generation_max_length=256,
315
+ generation_num_beams=5,
316
+ sortish_sampler=True,
317
+ group_by_length=True,
318
+ include_tokens_per_second=True,
319
+ include_num_input_tokens_seen=True,
320
+ dataloader_prefetch_factor=2,
321
+ )
322
+
323
+ # Create Trainer instance
324
+ trainer = Seq2SeqTrainer(
325
+ model=model,
326
+ args=training_args,
327
+ data_collator=data_collator,
328
+ train_dataset=train_dataset,
329
+ eval_dataset=eval_dataset,
330
+ compute_metrics=seq2seq_compute_metrics,
331
+ callbacks=[
332
+ EarlyStoppingCallback(
333
+ early_stopping_patience=args.patience,
334
+ early_stopping_threshold=args.threshold,
335
+ )
336
+ ],
337
+ )
338
+
339
+ print(f" | > Starting training ...")
340
+
341
+ try:
342
+ trainer.train()
343
+ except KeyboardInterrupt:
344
+ print(f" | > Training interrupted ...")
345
+
346
+ # this will only save the LoRA adapter weights
347
+ model.save_pretrained(args.output_dir)
348
+
349
+
350
+
351
+ if __name__ == "__main__":
352
+ parser = get_arg_parse()
353
+ args = parser.parse_args()
354
+
355
+ main(args)
IndicTrans2/huggingface_interface/train_lora.sh ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export CUDA_VISIBLE_DEVICES=0
2
+
3
+ data_dir=${1:-"en-indic-exp"}
4
+ model_name=${2:-"ai4bharat/indictrans2-en-indic-dist-200M"}
5
+ output_dir=${3:-"output"}
6
+ src_lang_list=${4:-"eng_Latn"}
7
+ tgt_lang_list=${5:-"asm_Beng,ben_Beng,guj_Gujr,hin_Deva,kan_Knda,mal_Mlym,mar_Deva,npi_Deva,ory_Orya,pan_Guru,tam_Taml,tel_Telu,urd_Arab"}
8
+
9
+ python3 train_lora.py \
10
+ --data_dir $data_dir \
11
+ --model_name $model_name \
12
+ --output_dir $output_dir \
13
+ --src_lang_list $src_lang_list \
14
+ --tgt_lang_list $tgt_lang_list \
15
+ --save_steps 1000 \
16
+ --max_steps 1000000 \
17
+ --batch_size 32 \
18
+ --grad_accum_steps 4 \
19
+ --warmup_steps 4000 \
20
+ --max_grad_norm 1.0 \
21
+ --learning_rate 2e-4 \
22
+ --adam_beta1 0.9 \
23
+ --adam_beta2 0.98 \
24
+ --optimizer adamw_torch \
25
+ --lr_scheduler inverse_sqrt \
26
+ --num_workers 16 \
27
+ --metric_for_best_model eval_BLEU \
28
+ --greater_is_better \
29
+ --patience 10 \
30
+ --weight_decay 0.01 \
31
+ --lora_target_modules "q_proj,k_proj" \
32
+ --lora_dropout 0.1 \
33
+ --lora_r 16 \
34
+ --lora_alpha 32 \
35
+ --print_samples
IndicTrans2/inference/__init__.py ADDED
File without changes
IndicTrans2/inference/custom_interactive.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python wrapper for fairseq-interactive command line tool
2
+
3
+ #!/usr/bin/env python3 -u
4
+ # Copyright (c) Facebook, Inc. and its affiliates.
5
+ #
6
+ # This source code is licensed under the MIT license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+ """
9
+ Translate raw text with a trained model. Batches data on-the-fly.
10
+ """
11
+
12
+ import os
13
+ import ast
14
+ from collections import namedtuple
15
+
16
+ import torch
17
+ from fairseq import checkpoint_utils, options, tasks, utils
18
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
19
+ from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
20
+ from fairseq_cli.generate import get_symbols_to_strip_from_output
21
+
22
+ import codecs
23
+
24
+ PWD = os.path.dirname(__file__)
25
+ Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
26
+ Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")
27
+
28
+
29
+ def make_batches(
30
+ lines, cfg, task, max_positions, encode_fn, constrainted_decoding=False
31
+ ):
32
+ def encode_fn_target(x):
33
+ return encode_fn(x)
34
+
35
+ if constrainted_decoding:
36
+ # Strip (tab-delimited) contraints, if present, from input lines,
37
+ # store them in batch_constraints
38
+ batch_constraints = [list() for _ in lines]
39
+ for i, line in enumerate(lines):
40
+ if "\t" in line:
41
+ lines[i], *batch_constraints[i] = line.split("\t")
42
+
43
+ # Convert each List[str] to List[Tensor]
44
+ for i, constraint_list in enumerate(batch_constraints):
45
+ batch_constraints[i] = [
46
+ task.target_dictionary.encode_line(
47
+ encode_fn_target(constraint),
48
+ append_eos=False,
49
+ add_if_not_exist=False,
50
+ )
51
+ for constraint in constraint_list
52
+ ]
53
+
54
+ if constrainted_decoding:
55
+ constraints_tensor = pack_constraints(batch_constraints)
56
+ else:
57
+ constraints_tensor = None
58
+
59
+ tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)
60
+
61
+ itr = task.get_batch_iterator(
62
+ dataset=task.build_dataset_for_inference(
63
+ tokens, lengths, constraints=constraints_tensor
64
+ ),
65
+ max_tokens=cfg.dataset.max_tokens,
66
+ max_sentences=cfg.dataset.batch_size,
67
+ max_positions=max_positions,
68
+ ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
69
+ ).next_epoch_itr(shuffle=False)
70
+ for batch in itr:
71
+ ids = batch["id"]
72
+ src_tokens = batch["net_input"]["src_tokens"]
73
+ src_lengths = batch["net_input"]["src_lengths"]
74
+ constraints = batch.get("constraints", None)
75
+
76
+ yield Batch(
77
+ ids=ids,
78
+ src_tokens=src_tokens,
79
+ src_lengths=src_lengths,
80
+ constraints=constraints,
81
+ )
82
+
83
+
84
+ class Translator:
85
+ """
86
+ Wrapper class to handle the interaction with fairseq model class for translation
87
+ """
88
+
89
+ def __init__(
90
+ self, data_dir, checkpoint_path, batch_size=25, constrained_decoding=False
91
+ ):
92
+
93
+ self.constrained_decoding = constrained_decoding
94
+ self.parser = options.get_generation_parser(interactive=True)
95
+ # buffer_size is currently not used but we just initialize it to batch
96
+ # size + 1 to avoid any assertion errors.
97
+ if self.constrained_decoding:
98
+ self.parser.set_defaults(
99
+ path=checkpoint_path,
100
+ num_workers=-1,
101
+ constraints="ordered",
102
+ batch_size=batch_size,
103
+ buffer_size=batch_size + 1,
104
+ )
105
+ else:
106
+ self.parser.set_defaults(
107
+ path=checkpoint_path,
108
+ remove_bpe="subword_nmt",
109
+ num_workers=-1,
110
+ batch_size=batch_size,
111
+ buffer_size=batch_size + 1,
112
+ )
113
+ args = options.parse_args_and_arch(self.parser, input_args=[data_dir])
114
+ # we are explictly setting src_lang and tgt_lang here
115
+ # generally the data_dir we pass contains {split}-{src_lang}-{tgt_lang}.*.idx files from
116
+ # which fairseq infers the src and tgt langs(if these are not passed). In deployment we dont
117
+ # use any idx files and only store the SRC and TGT dictionaries.
118
+ args.source_lang = "SRC"
119
+ args.target_lang = "TGT"
120
+ # since we are truncating sentences to max_seq_len in engine, we can set it to False here
121
+ args.skip_invalid_size_inputs_valid_test = False
122
+
123
+ # we have custom architechtures in this folder and we will let fairseq
124
+ # import this
125
+ args.user_dir = os.path.join(PWD, "model_configs")
126
+ self.cfg = convert_namespace_to_omegaconf(args)
127
+
128
+ utils.import_user_module(self.cfg.common)
129
+
130
+ if self.cfg.interactive.buffer_size < 1:
131
+ self.cfg.interactive.buffer_size = 1
132
+ if self.cfg.dataset.max_tokens is None and self.cfg.dataset.batch_size is None:
133
+ self.cfg.dataset.batch_size = 1
134
+
135
+ assert (
136
+ not self.cfg.generation.sampling
137
+ or self.cfg.generation.nbest == self.cfg.generation.beam
138
+ ), "--sampling requires --nbest to be equal to --beam"
139
+ assert (
140
+ not self.cfg.dataset.batch_size
141
+ or self.cfg.dataset.batch_size <= self.cfg.interactive.buffer_size
142
+ ), "--batch-size cannot be larger than --buffer-size"
143
+
144
+ # Fix seed for stochastic decoding
145
+ # if self.cfg.common.seed is not None and not self.cfg.generation.no_seed_provided:
146
+ # np.random.seed(self.cfg.common.seed)
147
+ # utils.set_torch_seed(self.cfg.common.seed)
148
+
149
+ # if not self.constrained_decoding:
150
+ # self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu
151
+ # else:
152
+ # self.use_cuda = False
153
+
154
+ self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu
155
+
156
+ # Setup task, e.g., translation
157
+ self.task = tasks.setup_task(self.cfg.task)
158
+
159
+ # Load ensemble
160
+ overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
161
+ self.models, self._model_args = checkpoint_utils.load_model_ensemble(
162
+ utils.split_paths(self.cfg.common_eval.path),
163
+ arg_overrides=overrides,
164
+ task=self.task,
165
+ suffix=self.cfg.checkpoint.checkpoint_suffix,
166
+ strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
167
+ num_shards=self.cfg.checkpoint.checkpoint_shard_count,
168
+ )
169
+
170
+ # Set dictionaries
171
+ self.src_dict = self.task.source_dictionary
172
+ self.tgt_dict = self.task.target_dictionary
173
+
174
+ # Optimize ensemble for generation
175
+ for model in self.models:
176
+ if model is None:
177
+ continue
178
+ if self.cfg.common.fp16:
179
+ model.half()
180
+ if (
181
+ self.use_cuda
182
+ and not self.cfg.distributed_training.pipeline_model_parallel
183
+ ):
184
+ model.cuda()
185
+ model.prepare_for_inference_(self.cfg)
186
+
187
+ # Initialize generator
188
+ self.generator = self.task.build_generator(self.models, self.cfg.generation)
189
+
190
+ self.tokenizer = None
191
+ self.bpe = None
192
+ # # Handle tokenization and BPE
193
+ # self.tokenizer = self.task.build_tokenizer(self.cfg.tokenizer)
194
+ # self.bpe = self.task.build_bpe(self.cfg.bpe)
195
+
196
+ # Load alignment dictionary for unknown word replacement
197
+ # (None if no unknown word replacement, empty if no path to align dictionary)
198
+ self.align_dict = utils.load_align_dict(self.cfg.generation.replace_unk)
199
+
200
+ self.max_positions = utils.resolve_max_positions(
201
+ self.task.max_positions(), *[model.max_positions() for model in self.models]
202
+ )
203
+
204
+ def encode_fn(self, x):
205
+ if self.tokenizer is not None:
206
+ x = self.tokenizer.encode(x)
207
+ if self.bpe is not None:
208
+ x = self.bpe.encode(x)
209
+ return x
210
+
211
+ def decode_fn(self, x):
212
+ if self.bpe is not None:
213
+ x = self.bpe.decode(x)
214
+ if self.tokenizer is not None:
215
+ x = self.tokenizer.decode(x)
216
+ return x
217
+
218
+ def translate(self, inputs, constraints=None):
219
+ if self.constrained_decoding and constraints is None:
220
+ raise ValueError("Constraints cant be None in constrained decoding mode")
221
+ if not self.constrained_decoding and constraints is not None:
222
+ raise ValueError("Cannot pass constraints during normal translation")
223
+ if constraints:
224
+ constrained_decoding = True
225
+ modified_inputs = []
226
+ for _input, constraint in zip(inputs, constraints):
227
+ modified_inputs.append(_input + f"\t{constraint}")
228
+ inputs = modified_inputs
229
+ else:
230
+ constrained_decoding = False
231
+
232
+ start_id = 0
233
+ results = []
234
+ final_translations = []
235
+ for batch in make_batches(
236
+ inputs,
237
+ self.cfg,
238
+ self.task,
239
+ self.max_positions,
240
+ self.encode_fn,
241
+ constrained_decoding,
242
+ ):
243
+ bsz = batch.src_tokens.size(0)
244
+ src_tokens = batch.src_tokens
245
+ src_lengths = batch.src_lengths
246
+ constraints = batch.constraints
247
+ if self.use_cuda:
248
+ src_tokens = src_tokens.cuda()
249
+ src_lengths = src_lengths.cuda()
250
+ if constraints is not None:
251
+ constraints = constraints.cuda()
252
+
253
+ sample = {
254
+ "net_input": {
255
+ "src_tokens": src_tokens,
256
+ "src_lengths": src_lengths,
257
+ },
258
+ }
259
+
260
+ translations = self.task.inference_step(
261
+ self.generator, self.models, sample, constraints=constraints
262
+ )
263
+
264
+ list_constraints = [[] for _ in range(bsz)]
265
+ if constrained_decoding:
266
+ list_constraints = [unpack_constraints(c) for c in constraints]
267
+ for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
268
+ src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
269
+ constraints = list_constraints[i]
270
+ results.append(
271
+ (
272
+ start_id + id,
273
+ src_tokens_i,
274
+ hypos,
275
+ {
276
+ "constraints": constraints,
277
+ },
278
+ )
279
+ )
280
+
281
+ # sort output to match input order
282
+ for id_, src_tokens, hypos, _ in sorted(results, key=lambda x: x[0]):
283
+ src_str = ""
284
+ if self.src_dict is not None:
285
+ src_str = self.src_dict.string(
286
+ src_tokens, self.cfg.common_eval.post_process
287
+ )
288
+
289
+ # Process top predictions
290
+ for hypo in hypos[: min(len(hypos), self.cfg.generation.nbest)]:
291
+ hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
292
+ hypo_tokens=hypo["tokens"].int().cpu(),
293
+ src_str=src_str,
294
+ alignment=hypo["alignment"],
295
+ align_dict=self.align_dict,
296
+ tgt_dict=self.tgt_dict,
297
+
298
+ extra_symbols_to_ignore=get_symbols_to_strip_from_output(
299
+ self.generator
300
+ ),
301
+ )
302
+ detok_hypo_str = self.decode_fn(hypo_str)
303
+ final_translations.append(detok_hypo_str)
304
+ return final_translations
IndicTrans2/inference/download.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import urduhack
2
+ urduhack.download()
3
+
4
+ import nltk
5
+ nltk.download('punkt')
IndicTrans2/inference/engine.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import uuid
4
+ from typing import List, Tuple, Union, Dict
5
+
6
+ import regex as re
7
+ import sentencepiece as spm
8
+ from indicnlp.normalize import indic_normalize
9
+ from indicnlp.tokenize import indic_detokenize, indic_tokenize
10
+ from indicnlp.tokenize.sentence_tokenize import DELIM_PAT_NO_DANDA, sentence_split
11
+ from indicnlp.transliterate import unicode_transliterate
12
+ from mosestokenizer import MosesSentenceSplitter
13
+ from nltk.tokenize import sent_tokenize
14
+ from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer
15
+ from tqdm import tqdm
16
+
17
+ from .flores_codes_map_indic import flores_codes, iso_to_flores
18
+ from .normalize_punctuation import punc_norm
19
+ from .normalize_regex_inference import EMAIL_PATTERN, normalize
20
+
21
+
22
+ def split_sentences(paragraph: str, lang: str) -> List[str]:
23
+ """
24
+ Splits the input text paragraph into sentences. It uses `moses` for English and
25
+ `indic-nlp` for Indic languages.
26
+
27
+ Args:
28
+ paragraph (str): input text paragraph.
29
+ lang (str): flores language code.
30
+
31
+ Returns:
32
+ List[str] -> list of sentences.
33
+ """
34
+ if lang == "eng_Latn":
35
+ with MosesSentenceSplitter(flores_codes[lang]) as splitter:
36
+ sents_moses = splitter([paragraph])
37
+ sents_nltk = sent_tokenize(paragraph)
38
+ if len(sents_nltk) < len(sents_moses):
39
+ sents = sents_nltk
40
+ else:
41
+ sents = sents_moses
42
+ return [sent.replace("\xad", "") for sent in sents]
43
+ else:
44
+ return sentence_split(paragraph, lang=flores_codes[lang], delim_pat=DELIM_PAT_NO_DANDA)
45
+
46
+
47
+ def add_token(sent: str, src_lang: str, tgt_lang: str, delimiter: str = " ") -> str:
48
+ """
49
+ Add special tokens indicating source and target language to the start of the input sentence.
50
+ The resulting string will have the format: "`{src_lang} {tgt_lang} {input_sentence}`".
51
+
52
+ Args:
53
+ sent (str): input sentence to be translated.
54
+ src_lang (str): flores lang code of the input sentence.
55
+ tgt_lang (str): flores lang code in which the input sentence will be translated.
56
+ delimiter (str): separator to add between language tags and input sentence (default: " ").
57
+
58
+ Returns:
59
+ str: input sentence with the special tokens added to the start.
60
+ """
61
+ return src_lang + delimiter + tgt_lang + delimiter + sent
62
+
63
+
64
+ def apply_lang_tags(sents: List[str], src_lang: str, tgt_lang: str) -> List[str]:
65
+ """
66
+ Add special tokens indicating source and target language to the start of the each input sentence.
67
+ Each resulting input sentence will have the format: "`{src_lang} {tgt_lang} {input_sentence}`".
68
+
69
+ Args:
70
+ sent (str): input sentence to be translated.
71
+ src_lang (str): flores lang code of the input sentence.
72
+ tgt_lang (str): flores lang code in which the input sentence will be translated.
73
+
74
+ Returns:
75
+ List[str]: list of input sentences with the special tokens added to the start.
76
+ """
77
+ tagged_sents = []
78
+ for sent in sents:
79
+ tagged_sent = add_token(sent.strip(), src_lang, tgt_lang)
80
+ tagged_sents.append(tagged_sent)
81
+ return tagged_sents
82
+
83
+
84
+ def truncate_long_sentences(
85
+ sents: List[str], placeholder_entity_map_sents: List[Dict]
86
+ ) -> Tuple[List[str], List[Dict]]:
87
+ """
88
+ Truncates the sentences that exceed the maximum sequence length.
89
+ The maximum sequence for the IndicTrans2 model is limited to 256 tokens.
90
+
91
+ Args:
92
+ sents (List[str]): list of input sentences to truncate.
93
+
94
+ Returns:
95
+ Tuple[List[str], List[Dict]]: tuple containing the list of sentences with truncation applied and the updated placeholder entity maps.
96
+ """
97
+ MAX_SEQ_LEN = 256
98
+ new_sents = []
99
+ placeholders = []
100
+
101
+ for j, sent in enumerate(sents):
102
+ words = sent.split()
103
+ num_words = len(words)
104
+ if num_words > MAX_SEQ_LEN:
105
+ sents = []
106
+ i = 0
107
+ while i <= len(words):
108
+ sents.append(" ".join(words[i : i + MAX_SEQ_LEN]))
109
+ i += MAX_SEQ_LEN
110
+ placeholders.extend([placeholder_entity_map_sents[j]] * (len(sents)))
111
+ new_sents.extend(sents)
112
+ else:
113
+ placeholders.append(placeholder_entity_map_sents[j])
114
+ new_sents.append(sent)
115
+ return new_sents, placeholders
116
+
117
+
118
+ class Model:
119
+ """
120
+ Model class to run the IndicTransv2 models using python interface.
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ ckpt_dir: str,
126
+ device: str = "cuda",
127
+ input_lang_code_format: str = "flores",
128
+ model_type: str = "ctranslate2",
129
+ ):
130
+ """
131
+ Initialize the model class.
132
+
133
+ Args:
134
+ ckpt_dir (str): path of the model checkpoint directory.
135
+ device (str, optional): where to load the model (defaults: cuda).
136
+ """
137
+ self.ckpt_dir = ckpt_dir
138
+ self.en_tok = MosesTokenizer(lang="en")
139
+ self.en_normalizer = MosesPunctNormalizer()
140
+ self.en_detok = MosesDetokenizer(lang="en")
141
+ self.xliterator = unicode_transliterate.UnicodeIndicTransliterator()
142
+
143
+ print("Initializing sentencepiece model for SRC and TGT")
144
+ self.sp_src = spm.SentencePieceProcessor(
145
+ model_file=os.path.join(ckpt_dir, "vocab", "model.SRC")
146
+ )
147
+ self.sp_tgt = spm.SentencePieceProcessor(
148
+ model_file=os.path.join(ckpt_dir, "vocab", "model.TGT")
149
+ )
150
+
151
+ self.input_lang_code_format = input_lang_code_format
152
+
153
+ print("Initializing model for translation")
154
+ # initialize the model
155
+ if model_type == "ctranslate2":
156
+ import ctranslate2
157
+
158
+ self.translator = ctranslate2.Translator(
159
+ self.ckpt_dir, device=device
160
+ ) # , compute_type="auto")
161
+ self.translate_lines = self.ctranslate2_translate_lines
162
+ elif model_type == "fairseq":
163
+ from .custom_interactive import Translator
164
+
165
+ self.translator = Translator(
166
+ data_dir=os.path.join(self.ckpt_dir, "final_bin"),
167
+ checkpoint_path=os.path.join(self.ckpt_dir, "model", "checkpoint_best.pt"),
168
+ batch_size=100,
169
+ )
170
+ self.translate_lines = self.fairseq_translate_lines
171
+ else:
172
+ raise NotImplementedError(f"Unknown model_type: {model_type}")
173
+
174
+ def ctranslate2_translate_lines(self, lines: List[str]) -> List[str]:
175
+ tokenized_sents = [x.strip().split(" ") for x in lines]
176
+ translations = self.translator.translate_batch(
177
+ tokenized_sents,
178
+ max_batch_size=9216,
179
+ batch_type="tokens",
180
+ max_input_length=160,
181
+ max_decoding_length=256,
182
+ beam_size=5,
183
+ )
184
+ translations = [" ".join(x.hypotheses[0]) for x in translations]
185
+ return translations
186
+
187
+ def fairseq_translate_lines(self, lines: List[str]) -> List[str]:
188
+ return self.translator.translate(lines)
189
+
190
+ def paragraphs_batch_translate__multilingual(self, batch_payloads: List[tuple]) -> List[str]:
191
+ """
192
+ Translates a batch of input paragraphs (including pre/post processing)
193
+ from any language to any language.
194
+
195
+ Args:
196
+ batch_payloads (List[tuple]): batch of long input-texts to be translated, each in format: (paragraph, src_lang, tgt_lang)
197
+
198
+ Returns:
199
+ List[str]: batch of paragraph-translations in the respective languages.
200
+ """
201
+ paragraph_id_to_sentence_range = []
202
+ global__sents = []
203
+ global__preprocessed_sents = []
204
+ global__preprocessed_sents_placeholder_entity_map = []
205
+
206
+ for i in range(len(batch_payloads)):
207
+ paragraph, src_lang, tgt_lang = batch_payloads[i]
208
+ if self.input_lang_code_format == "iso":
209
+ src_lang, tgt_lang = iso_to_flores[src_lang], iso_to_flores[tgt_lang]
210
+
211
+ batch = split_sentences(paragraph, src_lang)
212
+ global__sents.extend(batch)
213
+
214
+ preprocessed_sents, placeholder_entity_map_sents = self.preprocess_batch(
215
+ batch, src_lang, tgt_lang
216
+ )
217
+
218
+ global_sentence_start_index = len(global__preprocessed_sents)
219
+ global__preprocessed_sents.extend(preprocessed_sents)
220
+ global__preprocessed_sents_placeholder_entity_map.extend(placeholder_entity_map_sents)
221
+ paragraph_id_to_sentence_range.append(
222
+ (global_sentence_start_index, len(global__preprocessed_sents))
223
+ )
224
+
225
+ translations = self.translate_lines(global__preprocessed_sents)
226
+
227
+ translated_paragraphs = []
228
+ for paragraph_id, sentence_range in enumerate(paragraph_id_to_sentence_range):
229
+ tgt_lang = batch_payloads[paragraph_id][2]
230
+ if self.input_lang_code_format == "iso":
231
+ tgt_lang = iso_to_flores[tgt_lang]
232
+
233
+ postprocessed_sents = self.postprocess(
234
+ translations[sentence_range[0] : sentence_range[1]],
235
+ global__preprocessed_sents_placeholder_entity_map[
236
+ sentence_range[0] : sentence_range[1]
237
+ ],
238
+ tgt_lang,
239
+ )
240
+ translated_paragraph = " ".join(postprocessed_sents)
241
+ translated_paragraphs.append(translated_paragraph)
242
+
243
+ return translated_paragraphs
244
+
245
+ # translate a batch of sentences from src_lang to tgt_lang
246
+ def batch_translate(self, batch: List[str], src_lang: str, tgt_lang: str) -> List[str]:
247
+ """
248
+ Translates a batch of input sentences (including pre/post processing)
249
+ from source language to target language.
250
+
251
+ Args:
252
+ batch (List[str]): batch of input sentences to be translated.
253
+ src_lang (str): flores source language code.
254
+ tgt_lang (str): flores target language code.
255
+
256
+ Returns:
257
+ List[str]: batch of translated-sentences generated by the model.
258
+ """
259
+
260
+ assert isinstance(batch, list)
261
+
262
+ if self.input_lang_code_format == "iso":
263
+ src_lang, tgt_lang = iso_to_flores[src_lang], iso_to_flores[tgt_lang]
264
+
265
+ preprocessed_sents, placeholder_entity_map_sents = self.preprocess_batch(
266
+ batch, src_lang, tgt_lang
267
+ )
268
+ translations = self.translate_lines(preprocessed_sents)
269
+ return self.postprocess(translations, placeholder_entity_map_sents, tgt_lang)
270
+
271
+ # translate a paragraph from src_lang to tgt_lang
272
+ def translate_paragraph(self, paragraph: str, src_lang: str, tgt_lang: str) -> str:
273
+ """
274
+ Translates an input text paragraph (including pre/post processing)
275
+ from source language to target language.
276
+
277
+ Args:
278
+ paragraph (str): input text paragraph to be translated.
279
+ src_lang (str): flores source language code.
280
+ tgt_lang (str): flores target language code.
281
+
282
+ Returns:
283
+ str: paragraph translation generated by the model.
284
+ """
285
+
286
+ assert isinstance(paragraph, str)
287
+
288
+ if self.input_lang_code_format == "iso":
289
+ flores_src_lang = iso_to_flores[src_lang]
290
+ else:
291
+ flores_src_lang = src_lang
292
+
293
+ sents = split_sentences(paragraph, flores_src_lang)
294
+ postprocessed_sents = self.batch_translate(sents, src_lang, tgt_lang)
295
+ translated_paragraph = " ".join(postprocessed_sents)
296
+
297
+ return translated_paragraph
298
+
299
+ def preprocess_batch(self, batch: List[str], src_lang: str, tgt_lang: str) -> List[str]:
300
+ """
301
+ Preprocess an array of sentences by normalizing, tokenization, and possibly transliterating it. It also tokenizes the
302
+ normalized text sequences using sentence piece tokenizer and also adds language tags.
303
+
304
+ Args:
305
+ batch (List[str]): input list of sentences to preprocess.
306
+ src_lang (str): flores language code of the input text sentences.
307
+ tgt_lang (str): flores language code of the output text sentences.
308
+
309
+ Returns:
310
+ Tuple[List[str], List[Dict]]: a tuple of list of preprocessed input text sentences and also a corresponding list of dictionary
311
+ mapping placeholders to their original values.
312
+ """
313
+ preprocessed_sents, placeholder_entity_map_sents = self.preprocess(batch, lang=src_lang)
314
+ tokenized_sents = self.apply_spm(preprocessed_sents)
315
+ tokenized_sents, placeholder_entity_map_sents = truncate_long_sentences(
316
+ tokenized_sents, placeholder_entity_map_sents
317
+ )
318
+ tagged_sents = apply_lang_tags(tokenized_sents, src_lang, tgt_lang)
319
+ return tagged_sents, placeholder_entity_map_sents
320
+
321
+ def apply_spm(self, sents: List[str]) -> List[str]:
322
+ """
323
+ Applies sentence piece encoding to the batch of input sentences.
324
+
325
+ Args:
326
+ sents (List[str]): batch of the input sentences.
327
+
328
+ Returns:
329
+ List[str]: batch of encoded sentences with sentence piece model
330
+ """
331
+ return [" ".join(self.sp_src.encode(sent, out_type=str)) for sent in sents]
332
+
333
+ def preprocess_sent(
334
+ self,
335
+ sent: str,
336
+ normalizer: Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory],
337
+ lang: str,
338
+ ) -> Tuple[str, Dict]:
339
+ """
340
+ Preprocess an input text sentence by normalizing, tokenization, and possibly transliterating it.
341
+
342
+ Args:
343
+ sent (str): input text sentence to preprocess.
344
+ normalizer (Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory]): an object that performs normalization on the text.
345
+ lang (str): flores language code of the input text sentence.
346
+
347
+ Returns:
348
+ Tuple[str, Dict]: A tuple containing the preprocessed input text sentence and a corresponding dictionary
349
+ mapping placeholders to their original values.
350
+ """
351
+ iso_lang = flores_codes[lang]
352
+ sent = punc_norm(sent, iso_lang)
353
+ sent, placeholder_entity_map = normalize(sent)
354
+
355
+ transliterate = True
356
+ if lang.split("_")[1] in ["Arab", "Aran", "Olck", "Mtei", "Latn"]:
357
+ transliterate = False
358
+
359
+ if iso_lang == "en":
360
+ processed_sent = " ".join(
361
+ self.en_tok.tokenize(self.en_normalizer.normalize(sent.strip()), escape=False)
362
+ )
363
+ elif transliterate:
364
+ # transliterates from the any specific language to devanagari
365
+ # which is why we specify lang2_code as "hi".
366
+ processed_sent = self.xliterator.transliterate(
367
+ " ".join(
368
+ indic_tokenize.trivial_tokenize(normalizer.normalize(sent.strip()), iso_lang)
369
+ ),
370
+ iso_lang,
371
+ "hi",
372
+ ).replace(" ् ", "्")
373
+ else:
374
+ # we only need to transliterate for joint training
375
+ processed_sent = " ".join(
376
+ indic_tokenize.trivial_tokenize(normalizer.normalize(sent.strip()), iso_lang)
377
+ )
378
+
379
+ return processed_sent, placeholder_entity_map
380
+
381
+ def preprocess(self, sents: List[str], lang: str):
382
+ """
383
+ Preprocess an array of sentences by normalizing, tokenization, and possibly transliterating it.
384
+
385
+ Args:
386
+ batch (List[str]): input list of sentences to preprocess.
387
+ lang (str): flores language code of the input text sentences.
388
+
389
+ Returns:
390
+ Tuple[List[str], List[Dict]]: a tuple of list of preprocessed input text sentences and also a corresponding list of dictionary
391
+ mapping placeholders to their original values.
392
+ """
393
+ processed_sents, placeholder_entity_map_sents = [], []
394
+
395
+ if lang == "eng_Latn":
396
+ normalizer = None
397
+ else:
398
+ normfactory = indic_normalize.IndicNormalizerFactory()
399
+ normalizer = normfactory.get_normalizer(flores_codes[lang])
400
+
401
+ for sent in sents:
402
+ sent, placeholder_entity_map = self.preprocess_sent(sent, normalizer, lang)
403
+ processed_sents.append(sent)
404
+ placeholder_entity_map_sents.append(placeholder_entity_map)
405
+
406
+ return processed_sents, placeholder_entity_map_sents
407
+
408
+ def postprocess(
409
+ self,
410
+ sents: List[str],
411
+ placeholder_entity_map: List[Dict],
412
+ lang: str,
413
+ common_lang: str = "hin_Deva",
414
+ ) -> List[str]:
415
+ """
416
+ Postprocesses a batch of input sentences after the translation generations.
417
+
418
+ Args:
419
+ sents (List[str]): batch of translated sentences to postprocess.
420
+ placeholder_entity_map (List[Dict]): dictionary mapping placeholders to the original entity values.
421
+ lang (str): flores language code of the input sentences.
422
+ common_lang (str, optional): flores language code of the transliterated language (defaults: hin_Deva).
423
+
424
+ Returns:
425
+ List[str]: postprocessed batch of input sentences.
426
+ """
427
+
428
+ lang_code, script_code = lang.split("_")
429
+ # SPM decode
430
+ for i in range(len(sents)):
431
+ # sent_tokens = sents[i].split(" ")
432
+ # sents[i] = self.sp_tgt.decode(sent_tokens)
433
+
434
+ sents[i] = sents[i].replace(" ", "").replace("▁", " ").strip()
435
+
436
+ # Fixes for Perso-Arabic scripts
437
+ # TODO: Move these normalizations inside indic-nlp-library
438
+ if script_code in {"Arab", "Aran"}:
439
+ # UrduHack adds space before punctuations. Since the model was trained without fixing this issue, let's fix it now
440
+ sents[i] = sents[i].replace(" ؟", "؟").replace(" ۔", "۔").replace(" ،", "،")
441
+ # Kashmiri bugfix for palatalization: https://github.com/AI4Bharat/IndicTrans2/issues/11
442
+ sents[i] = sents[i].replace("ٮ۪", "ؠ")
443
+
444
+ assert len(sents) == len(placeholder_entity_map)
445
+
446
+ for i in range(0, len(sents)):
447
+ for key in placeholder_entity_map[i].keys():
448
+ sents[i] = sents[i].replace(key, placeholder_entity_map[i][key])
449
+
450
+ # Detokenize and transliterate to native scripts if applicable
451
+ postprocessed_sents = []
452
+
453
+ if lang == "eng_Latn":
454
+ for sent in sents:
455
+ postprocessed_sents.append(self.en_detok.detokenize(sent.split(" ")))
456
+ else:
457
+ for sent in sents:
458
+ outstr = indic_detokenize.trivial_detokenize(
459
+ self.xliterator.transliterate(
460
+ sent, flores_codes[common_lang], flores_codes[lang]
461
+ ),
462
+ flores_codes[lang],
463
+ )
464
+
465
+ # Oriya bug: indic-nlp-library produces ଯ଼ instead of ୟ when converting from Devanagari to Odia
466
+ # TODO: Find out what's the issue with unicode transliterator for Oriya and fix it
467
+ if lang_code == "ory":
468
+ outstr = outstr.replace("ଯ଼", 'ୟ')
469
+
470
+ postprocessed_sents.append(outstr)
471
+
472
+ return postprocessed_sents
IndicTrans2/inference/flores_codes_map_indic.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FLORES language code mapping to 2 letter ISO language code for compatibility
3
+ with Indic NLP Library (https://github.com/anoopkunchukuttan/indic_nlp_library)
4
+ """
5
+ flores_codes = {
6
+ "asm_Beng": "as",
7
+ "awa_Deva": "hi",
8
+ "ben_Beng": "bn",
9
+ "bho_Deva": "hi",
10
+ "brx_Deva": "hi",
11
+ "doi_Deva": "hi",
12
+ "eng_Latn": "en",
13
+ "gom_Deva": "kK",
14
+ "guj_Gujr": "gu",
15
+ "hin_Deva": "hi",
16
+ "hne_Deva": "hi",
17
+ "kan_Knda": "kn",
18
+ "kas_Arab": "ur",
19
+ "kas_Deva": "hi",
20
+ "kha_Latn": "en",
21
+ "lus_Latn": "en",
22
+ "mag_Deva": "hi",
23
+ "mai_Deva": "hi",
24
+ "mal_Mlym": "ml",
25
+ "mar_Deva": "mr",
26
+ "mni_Beng": "bn",
27
+ "mni_Mtei": "hi",
28
+ "npi_Deva": "ne",
29
+ "ory_Orya": "or",
30
+ "pan_Guru": "pa",
31
+ "san_Deva": "hi",
32
+ "sat_Olck": "or",
33
+ "snd_Arab": "ur",
34
+ "snd_Deva": "hi",
35
+ "tam_Taml": "ta",
36
+ "tel_Telu": "te",
37
+ "urd_Arab": "ur",
38
+ }
39
+
40
+
41
+ flores_to_iso = {
42
+ "asm_Beng": "as",
43
+ "awa_Deva": "awa",
44
+ "ben_Beng": "bn",
45
+ "bho_Deva": "bho",
46
+ "brx_Deva": "brx",
47
+ "doi_Deva": "doi",
48
+ "eng_Latn": "en",
49
+ "gom_Deva": "gom",
50
+ "guj_Gujr": "gu",
51
+ "hin_Deva": "hi",
52
+ "hne_Deva": "hne",
53
+ "kan_Knda": "kn",
54
+ "kas_Arab": "ksa",
55
+ "kas_Deva": "ksd",
56
+ "kha_Latn": "kha",
57
+ "lus_Latn": "lus",
58
+ "mag_Deva": "mag",
59
+ "mai_Deva": "mai",
60
+ "mal_Mlym": "ml",
61
+ "mar_Deva": "mr",
62
+ "mni_Beng": "mnib",
63
+ "mni_Mtei": "mnim",
64
+ "npi_Deva": "ne",
65
+ "ory_Orya": "or",
66
+ "pan_Guru": "pa",
67
+ "san_Deva": "sa",
68
+ "sat_Olck": "sat",
69
+ "snd_Arab": "sda",
70
+ "snd_Deva": "sdd",
71
+ "tam_Taml": "ta",
72
+ "tel_Telu": "te",
73
+ "urd_Arab": "ur",
74
+ }
75
+
76
+ iso_to_flores = {iso_code: flores_code for flores_code, iso_code in flores_to_iso.items()}
77
+ # Patch for digraphic langs.
78
+ iso_to_flores["ks"] = "kas_Arab"
79
+ iso_to_flores["ks_Deva"] = "kas_Deva"
80
+ iso_to_flores["mni"] = "mni_Mtei"
81
+ iso_to_flores["mni_Beng"] = "mni_Beng"
82
+ iso_to_flores["sd"] = "snd_Arab"
83
+ iso_to_flores["sd_Deva"] = "snd_Deva"
IndicTrans2/inference/indic_num_map.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A dictionary mapping intended to normalize the numerals in Indic languages from
3
+ native script to Roman script. This is done to ensure that the figures / numbers
4
+ mentioned in native script are perfectly preserved during translation.
5
+ """
6
+ INDIC_NUM_MAP = {
7
+ "\u09e6": "0",
8
+ "0": "0",
9
+ "\u0ae6": "0",
10
+ "\u0ce6": "0",
11
+ "\u0966": "0",
12
+ "\u0660": "0",
13
+ "\uabf0": "0",
14
+ "\u0b66": "0",
15
+ "\u0a66": "0",
16
+ "\u1c50": "0",
17
+ "\u06f0": "0",
18
+ "\u09e7": "1",
19
+ "1": "1",
20
+ "\u0ae7": "1",
21
+ "\u0967": "1",
22
+ "\u0ce7": "1",
23
+ "\u06f1": "1",
24
+ "\uabf1": "1",
25
+ "\u0b67": "1",
26
+ "\u0a67": "1",
27
+ "\u1c51": "1",
28
+ "\u0c67": "1",
29
+ "\u09e8": "2",
30
+ "2": "2",
31
+ "\u0ae8": "2",
32
+ "\u0968": "2",
33
+ "\u0ce8": "2",
34
+ "\u06f2": "2",
35
+ "\uabf2": "2",
36
+ "\u0b68": "2",
37
+ "\u0a68": "2",
38
+ "\u1c52": "2",
39
+ "\u0c68": "2",
40
+ "\u09e9": "3",
41
+ "3": "3",
42
+ "\u0ae9": "3",
43
+ "\u0969": "3",
44
+ "\u0ce9": "3",
45
+ "\u06f3": "3",
46
+ "\uabf3": "3",
47
+ "\u0b69": "3",
48
+ "\u0a69": "3",
49
+ "\u1c53": "3",
50
+ "\u0c69": "3",
51
+ "\u09ea": "4",
52
+ "4": "4",
53
+ "\u0aea": "4",
54
+ "\u096a": "4",
55
+ "\u0cea": "4",
56
+ "\u06f4": "4",
57
+ "\uabf4": "4",
58
+ "\u0b6a": "4",
59
+ "\u0a6a": "4",
60
+ "\u1c54": "4",
61
+ "\u0c6a": "4",
62
+ "\u09eb": "5",
63
+ "5": "5",
64
+ "\u0aeb": "5",
65
+ "\u096b": "5",
66
+ "\u0ceb": "5",
67
+ "\u06f5": "5",
68
+ "\uabf5": "5",
69
+ "\u0b6b": "5",
70
+ "\u0a6b": "5",
71
+ "\u1c55": "5",
72
+ "\u0c6b": "5",
73
+ "\u09ec": "6",
74
+ "6": "6",
75
+ "\u0aec": "6",
76
+ "\u096c": "6",
77
+ "\u0cec": "6",
78
+ "\u06f6": "6",
79
+ "\uabf6": "6",
80
+ "\u0b6c": "6",
81
+ "\u0a6c": "6",
82
+ "\u1c56": "6",
83
+ "\u0c6c": "6",
84
+ "\u09ed": "7",
85
+ "7": "7",
86
+ "\u0aed": "7",
87
+ "\u096d": "7",
88
+ "\u0ced": "7",
89
+ "\u06f7": "7",
90
+ "\uabf7": "7",
91
+ "\u0b6d": "7",
92
+ "\u0a6d": "7",
93
+ "\u1c57": "7",
94
+ "\u0c6d": "7",
95
+ "\u09ee": "8",
96
+ "8": "8",
97
+ "\u0aee": "8",
98
+ "\u096e": "8",
99
+ "\u0cee": "8",
100
+ "\u06f8": "8",
101
+ "\uabf8": "8",
102
+ "\u0b6e": "8",
103
+ "\u0a6e": "8",
104
+ "\u1c58": "8",
105
+ "\u0c6e": "8",
106
+ "\u09ef": "9",
107
+ "9": "9",
108
+ "\u0aef": "9",
109
+ "\u096f": "9",
110
+ "\u0cef": "9",
111
+ "\u06f9": "9",
112
+ "\uabf9": "9",
113
+ "\u0b6f": "9",
114
+ "\u0a6f": "9",
115
+ "\u1c59": "9",
116
+ "\u0c6f": "9",
117
+ }
IndicTrans2/inference/model_configs/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import custom_transformer
IndicTrans2/inference/model_configs/custom_transformer.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fairseq.models import register_model_architecture
2
+ from fairseq.models.transformer import base_architecture
3
+
4
+
5
+ @register_model_architecture("transformer", "transformer_2x")
6
+ def transformer_big(args):
7
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
8
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
9
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
10
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
11
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
12
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
13
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
14
+ base_architecture(args)
15
+
16
+
17
+ @register_model_architecture("transformer", "transformer_4x")
18
+ def transformer_huge(args):
19
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1536)
20
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
21
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
22
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
23
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1536)
24
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
25
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
26
+ base_architecture(args)
27
+
28
+
29
+ @register_model_architecture("transformer", "transformer_9x")
30
+ def transformer_xlarge(args):
31
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 2048)
32
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 8192)
33
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
34
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
35
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
36
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192)
37
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
38
+ base_architecture(args)
39
+
40
+
41
+ @register_model_architecture("transformer", "transformer_12e12d_9xeq")
42
+ def transformer_vxlarge(args):
43
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1536)
44
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
45
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
46
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
47
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1536)
48
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
49
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
50
+ args.encoder_layers = getattr(args, "encoder_layers", 12)
51
+ args.decoder_layers = getattr(args, "decoder_layers", 12)
52
+ base_architecture(args)
53
+
54
+
55
+ @register_model_architecture("transformer", "transformer_18_18")
56
+ def transformer_deep(args):
57
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
58
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 8 * 1024)
59
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
60
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
61
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
62
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
63
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8 * 1024)
64
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
65
+ args.encoder_layers = getattr(args, "encoder_layers", 18)
66
+ args.decoder_layers = getattr(args, "decoder_layers", 18)
67
+ base_architecture(args)
68
+
69
+
70
+ @register_model_architecture("transformer", "transformer_24_24")
71
+ def transformer_xdeep(args):
72
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
73
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 8 * 1024)
74
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
75
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
76
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
77
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
78
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8 * 1024)
79
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
80
+ args.encoder_layers = getattr(args, "encoder_layers", 24)
81
+ args.decoder_layers = getattr(args, "decoder_layers", 24)
82
+ base_architecture(args)
IndicTrans2/inference/normalize-punctuation.perl ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env perl
2
+ #
3
+ # This file is part of moses. Its use is licensed under the GNU Lesser General
4
+ # Public License version 2.1 or, at your option, any later version.
5
+
6
+ use warnings;
7
+ use strict;
8
+
9
+ my $language = "en";
10
+ my $PENN = 0;
11
+
12
+ while (@ARGV) {
13
+ $_ = shift;
14
+ /^-b$/ && ($| = 1, next); # not buffered (flush each line)
15
+ /^-l$/ && ($language = shift, next);
16
+ /^[^\-]/ && ($language = $_, next);
17
+ /^-penn$/ && ($PENN = 1, next);
18
+ }
19
+
20
+ while(<STDIN>) {
21
+ s/\r//g;
22
+ # remove extra spaces
23
+ s/\(/ \(/g;
24
+ s/\)/\) /g; s/ +/ /g;
25
+ s/\) ([\.\!\:\?\;\,])/\)$1/g;
26
+ s/\( /\(/g;
27
+ s/ \)/\)/g;
28
+ s/(\d) \%/$1\%/g;
29
+ s/ :/:/g;
30
+ s/ ;/;/g;
31
+ # normalize unicode punctuation
32
+ if ($PENN == 0) {
33
+ s/\`/\'/g;
34
+ s/\'\'/ \" /g;
35
+ }
36
+
37
+ s/„/\"/g;
38
+ s/“/\"/g;
39
+ s/”/\"/g;
40
+ s/–/-/g;
41
+ s/—/ - /g; s/ +/ /g;
42
+ s/´/\'/g;
43
+ s/([a-z])‘([a-z])/$1\'$2/gi;
44
+ s/([a-z])’([a-z])/$1\'$2/gi;
45
+ s/‘/\'/g;
46
+ s/‚/\'/g;
47
+ s/’/\"/g;
48
+ s/''/\"/g;
49
+ s/´´/\"/g;
50
+ s/…/.../g;
51
+ # French quotes
52
+ s/ « / \"/g;
53
+ s/« /\"/g;
54
+ s/«/\"/g;
55
+ s/ » /\" /g;
56
+ s/ »/\"/g;
57
+ s/»/\"/g;
58
+ # handle pseudo-spaces
59
+ s/ \%/\%/g;
60
+ s/nº /nº /g;
61
+ s/ :/:/g;
62
+ s/ ºC/ ºC/g;
63
+ s/ cm/ cm/g;
64
+ s/ \?/\?/g;
65
+ s/ \!/\!/g;
66
+ s/ ;/;/g;
67
+ s/, /, /g; s/ +/ /g;
68
+
69
+ # English "quotation," followed by comma, style
70
+ if ($language eq "en") {
71
+ s/\"([,\.]+)/$1\"/g;
72
+ }
73
+ # Czech is confused
74
+ elsif ($language eq "cs" || $language eq "cz") {
75
+ }
76
+ # German/Spanish/French "quotation", followed by comma, style
77
+ else {
78
+ s/,\"/\",/g;
79
+ s/(\.+)\"(\s*[^<])/\"$1$2/g; # don't fix period at end of sentence
80
+ }
81
+
82
+
83
+ if ($language eq "de" || $language eq "es" || $language eq "cz" || $language eq "cs" || $language eq "fr") {
84
+ s/(\d) (\d)/$1,$2/g;
85
+ }
86
+ else {
87
+ s/(\d) (\d)/$1.$2/g;
88
+ }
89
+ print $_;
90
+ }
IndicTrans2/inference/normalize_punctuation.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # IMPORTANT NOTE: DO NOT DIRECTLY EDIT THIS FILE
2
+ # This file was manually ported from `normalize-punctuation.perl`
3
+ # TODO: Only supports English, add others
4
+
5
+ import regex as re
6
+ multispace_regex = re.compile("[ ]{2,}")
7
+ multidots_regex = re.compile(r"\.{2,}")
8
+ end_bracket_space_punc_regex = re.compile(r"\) ([\.!:?;,])")
9
+ digit_space_percent = re.compile(r"(\d) %")
10
+ double_quot_punc = re.compile(r"\"([,\.]+)")
11
+ digit_nbsp_digit = re.compile(r"(\d) (\d)")
12
+
13
+ def punc_norm(text, lang="en"):
14
+ text = text.replace('\r', '') \
15
+ .replace('(', " (") \
16
+ .replace(')', ") ") \
17
+ \
18
+ .replace("( ", "(") \
19
+ .replace(" )", ")") \
20
+ \
21
+ .replace(" :", ':') \
22
+ .replace(" ;", ';') \
23
+ .replace('`', "'") \
24
+ \
25
+ .replace('„', '"') \
26
+ .replace('“', '"') \
27
+ .replace('”', '"') \
28
+ .replace('–', '-') \
29
+ .replace('—', " - ") \
30
+ .replace('´', "'") \
31
+ .replace('‘', "'") \
32
+ .replace('‚', "'") \
33
+ .replace('’', "'") \
34
+ .replace("''", "\"") \
35
+ .replace("´´", '"') \
36
+ .replace('…', "...") \
37
+ .replace(" « ", " \"") \
38
+ .replace("« ", '"') \
39
+ .replace('«', '"') \
40
+ .replace(" » ", "\" ") \
41
+ .replace(" »", '"') \
42
+ .replace('»', '"') \
43
+ .replace(" %", '%') \
44
+ .replace("nº ", "nº ") \
45
+ .replace(" :", ':') \
46
+ .replace(" ºC", " ºC") \
47
+ .replace(" cm", " cm") \
48
+ .replace(" ?", '?') \
49
+ .replace(" !", '!') \
50
+ .replace(" ;", ';') \
51
+ .replace(", ", ", ") \
52
+
53
+
54
+ text = multispace_regex.sub(' ', text)
55
+ text = multidots_regex.sub('.', text)
56
+ text = end_bracket_space_punc_regex.sub(r")\1", text)
57
+ text = digit_space_percent.sub(r"\1%", text)
58
+ text = double_quot_punc.sub(r'\1"', text) # English "quotation," followed by comma, style
59
+ text = digit_nbsp_digit.sub(r"\1.\2", text) # What does it mean?
60
+ return text.strip(' ')
IndicTrans2/inference/normalize_punctuation.sh ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ set -euo pipefail
2
+
3
+ root=$(dirname $0)
4
+
5
+ lang_map_path=$root/utils.map_token_lang.tsv
6
+
7
+ usage () {
8
+ echo "usage: $0 lang" >&2
9
+ exit 1
10
+ }
11
+
12
+ [ $# -eq 1 ] || usage
13
+
14
+ lang=$1
15
+
16
+ declare -A lang_map
17
+
18
+ while read line; do
19
+ key=$(cut -f1 <<< "$line")
20
+ val=$(cut -f2 <<< "$line")
21
+ lang_map[$key]=$val
22
+ done < $lang_map_path
23
+
24
+ if [ -v "lang_map[$lang]" ]; then
25
+ lang=${lang_map[$lang]}
26
+ elif [ -v "lang_map[${lang:0:3}]" ]; then
27
+ lang=${lang_map[${lang:0:3}]}
28
+ else
29
+ echo "undefined mapping: ${lang}, falling back to: en" >&2
30
+ lang=en
31
+ fi
32
+
33
+ perl $root/normalize-punctuation.perl $lang
IndicTrans2/inference/normalize_regex_inference.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import regex as re
3
+ import sys
4
+ from tqdm import tqdm
5
+ from .indic_num_map import INDIC_NUM_MAP
6
+
7
+
8
+ URL_PATTERN = r'\b(?<![\w/.])(?:(?:https?|ftp)://)?(?:(?:[\w-]+\.)+(?!\.))(?:[\w/\-?#&=%.]+)+(?!\.\w+)\b'
9
+ EMAIL_PATTERN = r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}'
10
+ # handles dates, time, percentages, proportion, ratio, etc
11
+ NUMERAL_PATTERN = r"(~?\d+\.?\d*\s?%?\s?-?\s?~?\d+\.?\d*\s?%|~?\d+%|\d+[-\/.,:']\d+[-\/.,:'+]\d+(?:\.\d+)?|\d+[-\/.:'+]\d+(?:\.\d+)?)"
12
+ # handles upi, social media handles and hashtags
13
+ OTHER_PATTERN = r'[A-Za-z0-9]*[#|@]\w+'
14
+
15
+
16
+ def normalize_indic_numerals(line: str):
17
+ """
18
+ Normalize the numerals in Indic languages from native script to Roman script (if present).
19
+
20
+ Args:
21
+ line (str): an input string with Indic numerals to be normalized.
22
+
23
+ Returns:
24
+ str: an input string with the all Indic numerals normalized to Roman script.
25
+ """
26
+ return "".join([INDIC_NUM_MAP.get(c, c) for c in line])
27
+
28
+
29
+ def wrap_with_placeholders(text: str, patterns: list) -> Tuple[str, dict]:
30
+ """
31
+ Wraps substrings with matched patterns in the given text with placeholders and returns
32
+ the modified text along with a mapping of the placeholders to their original value.
33
+
34
+ Args:
35
+ text (str): an input string which needs to be wrapped with the placeholders.
36
+ pattern (list): list of patterns to search for in the input string.
37
+
38
+ Returns:
39
+ Tuple[str, dict]: a tuple containing the modified text and a dictionary mapping
40
+ placeholders to their original values.
41
+ """
42
+ serial_no = 1
43
+
44
+ placeholder_entity_map = dict()
45
+
46
+ for pattern in patterns:
47
+ matches = set(re.findall(pattern, text))
48
+
49
+ # wrap common match with placeholder tags
50
+ for match in matches:
51
+ if pattern==URL_PATTERN :
52
+ #Avoids false positive URL matches for names with initials.
53
+ temp = match.replace(".",'')
54
+ if len(temp)<4:
55
+ continue
56
+ if pattern==NUMERAL_PATTERN :
57
+ #Short numeral patterns do not need placeholder based handling.
58
+ temp = match.replace(" ",'').replace(".",'').replace(":",'')
59
+ if len(temp)<4:
60
+ continue
61
+
62
+ #Set of Translations of "ID" in all the suppported languages have been collated.
63
+ #This has been added to deal with edge cases where placeholders might get translated.
64
+ indic_failure_cases = ['آی ڈی ', 'ꯑꯥꯏꯗꯤ', 'आईडी', 'आई . डी . ', 'ऐटि', 'آئی ڈی ', 'ᱟᱭᱰᱤ ᱾', 'आयडी', 'ऐडि', 'आइडि']
65
+ placeholder = "<ID{}>".format(serial_no)
66
+ alternate_placeholder = "< ID{} >".format(serial_no)
67
+ placeholder_entity_map[placeholder] = match
68
+ placeholder_entity_map[alternate_placeholder] = match
69
+
70
+ for i in indic_failure_cases:
71
+ placeholder_temp = "<{}{}>".format(i,serial_no)
72
+ placeholder_entity_map[placeholder_temp] = match
73
+ placeholder_temp = "< {}{} >".format(i, serial_no)
74
+ placeholder_entity_map[placeholder_temp] = match
75
+ placeholder_temp = "< {} {} >".format(i, serial_no)
76
+ placeholder_entity_map[placeholder_temp] = match
77
+
78
+ text = text.replace(match, placeholder)
79
+ serial_no+=1
80
+
81
+ text = re.sub("\s+", " ", text)
82
+
83
+ #Regex has failure cases in trailing "/" in URLs, so this is a workaround.
84
+ text = text.replace(">/",">")
85
+
86
+ return text, placeholder_entity_map
87
+
88
+
89
+ def normalize(text: str, patterns: list = [EMAIL_PATTERN, URL_PATTERN, NUMERAL_PATTERN, OTHER_PATTERN]) -> Tuple[str, dict]:
90
+ """
91
+ Normalizes and wraps the spans of input string with placeholder tags. It first normalizes
92
+ the Indic numerals in the input string to Roman script. Later, it uses the input string with normalized
93
+ Indic numerals to wrap the spans of text matching the pattern with placeholder tags.
94
+
95
+ Args:
96
+ text (str): input string.
97
+ pattern (list): list of patterns to search for in the input string.
98
+
99
+ Returns:
100
+ Tuple[str, dict]: a tuple containing the modified text and a dictionary mapping
101
+ placeholders to their original values.
102
+ """
103
+ text = normalize_indic_numerals(text.strip("\n"))
104
+ text, placeholder_entity_map = wrap_with_placeholders(text, patterns)
105
+ return text, placeholder_entity_map
IndicTrans2/inference/requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/anoopkunchukuttan/indic_nlp_library
2
+ git+https://github.com/pytorch/fairseq
3
+ sacremoses
4
+ pandas
5
+ mock
6
+ nltk
7
+ sacrebleu
8
+ urduhack[tf]
9
+ mosestokenizer
10
+ ctranslate2
11
+ sentencepiece
IndicTrans2/inference/triton_server/Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver:22.12-py3
2
+ FROM ${BASE_IMAGE}
3
+
4
+ # Ensure apt-get won't prompt for selecting options
5
+ ENV DEBIAN_FRONTEND=noninteractive
6
+ ENV PYTHONIOENCODING=utf8
7
+
8
+ WORKDIR /home
9
+
10
+ WORKDIR /home/indicTrans2
11
+ COPY requirements.txt .
12
+ RUN pip install -r requirements.txt
13
+
14
+ COPY download.py .
15
+ RUN python3 download.py
16
+
17
+ COPY . ./inference
18
+
19
+ WORKDIR /home/
20
+ COPY ./triton_server/triton_repo ./triton_repo
21
+
22
+ CMD ["tritonserver", "--model-repository=/home/triton_repo", "--log-verbose=2", "--strict-model-config=false", "--http-port=8000", "--grpc-port=8001", "--metrics-port=8002"]
23
+ EXPOSE 8000
24
+ EXPOSE 8001
25
+ EXPOSE 8002
IndicTrans2/inference/triton_server/README.md ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Triton server
2
+
3
+ ## Building the image
4
+
5
+ ```
6
+ cd indicTrans2/inference/
7
+ docker build -f triton_server/Dockerfile -t indictrans2_triton .
8
+ ```
9
+
10
+ ## Running the container
11
+
12
+ Place the `en-indic` and `indic-en` checkpoint folders into `indicTrans2/checkpoints` directory
13
+
14
+ Then start the server by:
15
+ ```
16
+ docker run --shm-size=256m --gpus=1 --rm -v ${PWD}/../checkpoints/:/models/checkpoints -p 8000:8000 -t indictrans2_triton
17
+ ```
18
+
19
+ ## Sample client
20
+
21
+ - Do `pip install tritonclient[all] gevent` first.
22
+ - Then `python3 triton_server/client.py`
IndicTrans2/inference/triton_server/azure_ml/README.md ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Deployment on Azure Machine Learning
2
+
3
+ ## Pre-requisites
4
+
5
+ ```
6
+ cd inference/triton_server
7
+ ```
8
+
9
+ Set the environment for AML:
10
+ ```
11
+ export RESOURCE_GROUP=Dhruva-prod
12
+ export WORKSPACE_NAME=dhruva--central-india
13
+ export DOCKER_REGISTRY=dhruvaprod
14
+ ```
15
+
16
+ Also remember to edit the `yml` files accordingly.
17
+
18
+ ## Registering the model
19
+
20
+ ```
21
+ az ml model create --file azure_ml/model.yml --resource-group $RESOURCE_GROUP --workspace-name $WORKSPACE_NAME
22
+ ```
23
+
24
+ ## Pushing the docker image to Container Registry
25
+
26
+ ```
27
+ az acr login --name $DOCKER_REGISTRY
28
+ docker tag indictrans2_triton $DOCKER_REGISTRY.azurecr.io/nmt/triton-indictrans-v2:latest
29
+ docker push $DOCKER_REGISTRY.azurecr.io/nmt/triton-indictrans-v2:latest
30
+ ```
31
+
32
+ ## Creating the execution environment
33
+
34
+ ```
35
+ az ml environment create -f azure_ml/environment.yml -g $RESOURCE_GROUP -w $WORKSPACE_NAME
36
+ ```
37
+
38
+ ## Publishing the endpoint for online inference
39
+
40
+ ```
41
+ az ml online-endpoint create -f azure_ml/endpoint.yml -g $RESOURCE_GROUP -w $WORKSPACE_NAME
42
+ ```
43
+
44
+ Now from the Azure Portal, open the Container Registry, and grant ACR_PULL permission for the above endpoint, so that it is allowed to download the docker image.
45
+
46
+ ## Attaching a deployment
47
+
48
+ ```
49
+ az ml online-deployment create -f azure_ml/deployment.yml --all-traffic -g $RESOURCE_GROUP -w $WORKSPACE_NAME
50
+ ```
51
+
52
+ ## Testing if inference works
53
+
54
+ 1. From Azure ML Studio, go to the "Consume" tab, and get the endpoint domain (without `https://` or trailing `/`) and an authentication key.
55
+ 2. In `client.py`, enable `ENABLE_SSL = True`, and then set the `ENDPOINT_URL` variable as well as `Authorization` value inside `HTTP_HEADERS`.
56
+ 3. Run `python3 client.py`
IndicTrans2/inference/triton_server/azure_ml/deployment.yml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ $schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json
2
+ name: ai4b-indictransv2--t4-piv--gpu
3
+ endpoint_name: ai4b-indictransv2--t4
4
+ model: azureml:indictrans-v2--models:1
5
+ model_mount_path: /models
6
+ environment: azureml:triton-indictrans-v2-env:1
7
+ instance_type: Standard_NC4as_T4_v3
8
+ instance_count: 1
9
+ request_settings:
10
+ request_timeout_ms: 90000
11
+ max_concurrent_requests_per_instance: 100
12
+ max_queue_wait_ms: 2000
13
+ app_insights_enabled: true
IndicTrans2/inference/triton_server/azure_ml/endpoint.yml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ $schema: https://azuremlschemas.azureedge.net/latest/managedOnlineEndpoint.schema.json
2
+ name: ai4b-indictransv2--t4
3
+ auth_mode: key
IndicTrans2/inference/triton_server/azure_ml/environment.yml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ $schema: https://azuremlschemas.azureedge.net/latest/environment.schema.json
2
+ name: triton-indictrans-v2-env
3
+ image: dhruvaprod.azurecr.io/nmt/triton-indictrans-v2:latest
4
+ version: 1
5
+ inference_config:
6
+ liveness_route:
7
+ path: /v2/health/live
8
+ port: 8000
9
+ readiness_route:
10
+ path: /v2/health/ready
11
+ port: 8000
12
+ scoring_route:
13
+ path: /
14
+ port: 8000
IndicTrans2/inference/triton_server/azure_ml/model.yml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ $schema: https://azuremlschemas.azureedge.net/latest/model.schema.json
2
+ name: indictrans-v2--models
3
+ version: 1
4
+ path: ../../../checkpoints
5
+ type: triton_model
IndicTrans2/inference/triton_server/client.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tritonclient.http as http_client
2
+ from tritonclient.utils import *
3
+ import numpy as np
4
+
5
+ ENABLE_SSL = False
6
+ ENDPOINT_URL = 'localhost:8000'
7
+ HTTP_HEADERS = {"Authorization": "Bearer __PASTE_KEY_HERE__"}
8
+
9
+ # Connect to the server
10
+ if ENABLE_SSL:
11
+ import gevent.ssl
12
+ triton_http_client = http_client.InferenceServerClient(
13
+ url=ENDPOINT_URL, verbose=False,
14
+ ssl=True, ssl_context_factory=gevent.ssl._create_default_https_context,
15
+ )
16
+ else:
17
+ triton_http_client = http_client.InferenceServerClient(
18
+ url=ENDPOINT_URL, verbose=False,
19
+ )
20
+
21
+ print("Is server ready - {}".format(triton_http_client.is_server_ready(headers=HTTP_HEADERS)))
22
+
23
+ def get_string_tensor(string_values, tensor_name):
24
+ string_obj = np.array(string_values, dtype="object")
25
+ input_obj = http_client.InferInput(tensor_name, string_obj.shape, np_to_triton_dtype(string_obj.dtype))
26
+ input_obj.set_data_from_numpy(string_obj)
27
+ return input_obj
28
+
29
+ def get_translation_input_for_triton(texts: list, src_lang: str, tgt_lang: str):
30
+ return [
31
+ get_string_tensor([[text] for text in texts], "INPUT_TEXT"),
32
+ get_string_tensor([[src_lang]] * len(texts), "INPUT_LANGUAGE_ID"),
33
+ get_string_tensor([[tgt_lang]] * len(texts), "OUTPUT_LANGUAGE_ID"),
34
+ ]
35
+
36
+ # Prepare input and output tensors
37
+ input_sentences = ["Hello world, I am Ram and I am from Ayodhya.", "How are you Ravan bro?"]
38
+ inputs = get_translation_input_for_triton(input_sentences, "en", "hi")
39
+ output0 = http_client.InferRequestedOutput("OUTPUT_TEXT")
40
+
41
+ # Send request
42
+ response = triton_http_client.infer(
43
+ "nmt",
44
+ model_version='1',
45
+ inputs=inputs,
46
+ outputs=[output0],
47
+ headers=HTTP_HEADERS,
48
+ )#.get_response()
49
+
50
+ # Decode the response
51
+ output_batch = response.as_numpy('OUTPUT_TEXT').tolist()
52
+ for input_sentence, translation in zip(input_sentences, output_batch):
53
+ print()
54
+ print(input_sentence)
55
+ print(translation[0].decode("utf-8"))
IndicTrans2/inference/triton_server/dhruva/ulca_model.json ADDED
The diff for this file is too large to render. See raw diff
 
IndicTrans2/inference/triton_server/triton_repo/nmt/1/model.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
+ import numpy as np
5
+ import triton_python_backend_utils as pb_utils
6
+
7
+ PWD = os.path.dirname(__file__)
8
+
9
+ INFERENCE_MODULE_DIR = "/home/indicTrans2/"
10
+ sys.path.insert(0, INFERENCE_MODULE_DIR)
11
+ from inference.engine import Model, iso_to_flores
12
+ INDIC_LANGUAGES = set(iso_to_flores)
13
+
14
+ ALLOWED_DIRECTION_STRINGS = {"en-indic", "indic-en", "indic-indic"}
15
+ FORCE_PIVOTING = False
16
+ DEFAULT_PIVOT_LANG = "en"
17
+
18
+ class TritonPythonModel:
19
+ def initialize(self, args):
20
+ self.model_config = json.loads(args['model_config'])
21
+ self.model_instance_device_id = json.loads(args['model_instance_device_id'])
22
+ self.output_name = "OUTPUT_TEXT"
23
+ self.output_dtype = pb_utils.triton_string_to_numpy(
24
+ pb_utils.get_output_config_by_name(self.model_config, self.output_name)["data_type"])
25
+
26
+
27
+ # checkpoints_root_dir = os.path.join(PWD, "checkpoints")
28
+ checkpoints_root_dir = "/models/checkpoints"
29
+ checkpoint_folders = [ f.path for f in os.scandir(checkpoints_root_dir) if f.is_dir() ]
30
+ # The assumption is that, each folder name is `<src_direction>-to-<tgt_direction>`
31
+
32
+ if not checkpoint_folders:
33
+ raise RuntimeError(f"No checkpoint folders in: {checkpoints_root_dir}")
34
+
35
+ self.models = {}
36
+ for checkpoint_folder in checkpoint_folders:
37
+ direction_string = os.path.basename(checkpoint_folder)
38
+ assert direction_string in ALLOWED_DIRECTION_STRINGS, f"Checkpoint folder-name `{direction_string}` not allowed"
39
+ self.models[direction_string] = Model(os.path.join(checkpoint_folder, "ct2_fp16_model"), input_lang_code_format="iso", model_type="ctranslate2")
40
+ # self.models[direction_string] = Model(checkpoint_folder, input_lang_code_format="iso", model_type="fairseq")
41
+
42
+ self.pivot_lang = None
43
+ if "en-indic" in self.models and "indic-en" in self.models:
44
+ if "indic-indic" not in self.models:
45
+ self.pivot_lang = DEFAULT_PIVOT_LANG
46
+ elif FORCE_PIVOTING:
47
+ del self.models["indic-indic"]
48
+ self.pivot_lang = DEFAULT_PIVOT_LANG
49
+
50
+ def get_direction_string(self, input_language_id, output_language_id):
51
+ direction_string = None
52
+ if input_language_id == DEFAULT_PIVOT_LANG and output_language_id in INDIC_LANGUAGES:
53
+ direction_string = "en-indic"
54
+ elif input_language_id in INDIC_LANGUAGES:
55
+ if output_language_id == DEFAULT_PIVOT_LANG:
56
+ direction_string = "indic-en"
57
+ elif output_language_id in INDIC_LANGUAGES:
58
+ direction_string = "indic-indic"
59
+ return direction_string
60
+
61
+ def get_model(self, input_language_id, output_language_id):
62
+ direction_string = self.get_direction_string(input_language_id, output_language_id)
63
+
64
+ if direction_string in self.models:
65
+ return self.models[direction_string]
66
+ raise RuntimeError(f"Language-pair not supported: {input_language_id}-{output_language_id}")
67
+
68
+ def execute(self,requests):
69
+ # print("REQ_COUNT", len(requests))
70
+ modelwise_batches = {}
71
+ responses = []
72
+ for request_id, request in enumerate(requests):
73
+ input_text_batch = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT").as_numpy()
74
+ input_language_id_batch = pb_utils.get_input_tensor_by_name(request, "INPUT_LANGUAGE_ID").as_numpy()
75
+ output_language_id_batch = pb_utils.get_input_tensor_by_name(request, "OUTPUT_LANGUAGE_ID").as_numpy()
76
+
77
+ input_text_batch = [input_text[0].decode("utf-8", "ignore") for input_text in input_text_batch]
78
+ input_language_id_batch = [input_language_id[0].decode("utf-8", "ignore") for input_language_id in input_language_id_batch]
79
+ output_language_id_batch = [output_language_id[0].decode("utf-8", "ignore") for output_language_id in output_language_id_batch]
80
+
81
+ responses.append([['']] * len(input_text_batch))
82
+
83
+ for input_id, (input_text, input_language_id, output_language_id) in enumerate(zip(input_text_batch, input_language_id_batch, output_language_id_batch)):
84
+ direction_string = self.get_direction_string(input_language_id, output_language_id)
85
+ if direction_string not in self.models:
86
+ if direction_string == "indic-indic" and self.pivot_lang:
87
+ pass
88
+ else:
89
+ raise RuntimeError(f"Language-pair not supported: {input_language_id}-{output_language_id}")
90
+
91
+ if direction_string not in modelwise_batches:
92
+ modelwise_batches[direction_string] = {
93
+ "payloads": [],
94
+ "text_id_to_req_id_input_id": [],
95
+ }
96
+
97
+ modelwise_batches[direction_string]["payloads"].append([input_text, input_language_id, output_language_id])
98
+ modelwise_batches[direction_string]["text_id_to_req_id_input_id"].append((request_id, input_id))
99
+
100
+ for direction_string, batch in modelwise_batches.items():
101
+ if direction_string == "indic-indic" and self.pivot_lang:
102
+ model = self.get_model("hi", self.pivot_lang)
103
+ original_langs = []
104
+ for i in range(len(batch["payloads"])):
105
+ original_langs.append(batch["payloads"][i][2])
106
+ batch["payloads"][i][2] = self.pivot_lang
107
+
108
+ pivot_texts = model.paragraphs_batch_translate__multilingual(batch["payloads"])
109
+
110
+ for i in range(len(batch["payloads"])):
111
+ batch["payloads"][i][0] = pivot_texts[i]
112
+ batch["payloads"][i][1] = self.pivot_lang
113
+ batch["payloads"][i][2] = original_langs[i]
114
+
115
+ model = self.get_model(self.pivot_lang, "hi")
116
+ translations = model.paragraphs_batch_translate__multilingual(batch["payloads"])
117
+ else:
118
+ model = self.models[direction_string]
119
+ translations = model.paragraphs_batch_translate__multilingual(batch["payloads"])
120
+ # translations = ["bro"] * len(batch["payloads"])
121
+
122
+ for translation, (request_id, output_id) in zip(translations, batch["text_id_to_req_id_input_id"]):
123
+ responses[request_id][output_id] = [translation]
124
+
125
+ for i in range(len(responses)):
126
+ responses[i] = pb_utils.InferenceResponse(output_tensors=[
127
+ pb_utils.Tensor(
128
+ self.output_name,
129
+ np.array(responses[i], dtype=self.output_dtype),
130
+ )
131
+ ])
132
+ return responses
133
+
134
+ def execute_sequential(self,requests):
135
+ # print("REQ_COUNT", len(requests))
136
+ responses = []
137
+ for request in requests:
138
+ input_text_batch = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT").as_numpy()
139
+ input_language_id_batch = pb_utils.get_input_tensor_by_name(request, "INPUT_LANGUAGE_ID").as_numpy()
140
+ output_language_id_batch = pb_utils.get_input_tensor_by_name(request, "OUTPUT_LANGUAGE_ID").as_numpy()
141
+
142
+ input_text_batch = [input_text[0].decode("utf-8", "ignore") for input_text in input_text_batch]
143
+ input_language_id_batch = [input_language_id[0].decode("utf-8", "ignore") for input_language_id in input_language_id_batch]
144
+ output_language_id_batch = [output_language_id[0].decode("utf-8", "ignore") for output_language_id in output_language_id_batch]
145
+
146
+ generated_outputs = []
147
+
148
+ for input_text, input_language_id, output_language_id in zip(input_text_batch, input_language_id_batch, output_language_id_batch):
149
+ if self.pivot_lang and (input_language_id != self.pivot_lang and output_language_id != self.pivot_lang):
150
+ model = self.get_model(input_language_id, self.pivot_lang)
151
+ pivot_text = model.translate_paragraph(input_text, input_language_id, self.pivot_lang)
152
+
153
+ model = self.get_model(self.pivot_lang, output_language_id)
154
+ translation = model.translate_paragraph(pivot_text, self.pivot_lang, output_language_id)
155
+ else:
156
+ model = self.get_model(input_language_id, output_language_id)
157
+ translation = model.translate_paragraph(input_text, input_language_id, output_language_id)
158
+ generated_outputs.append([translation])
159
+
160
+ inference_response = pb_utils.InferenceResponse(output_tensors=[
161
+ pb_utils.Tensor(
162
+ self.output_name,
163
+ np.array(generated_outputs, dtype=self.output_dtype),
164
+ )
165
+ ])
166
+ responses.append(inference_response)
167
+ return responses
IndicTrans2/inference/triton_server/triton_repo/nmt/config.pbtxt ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ backend: "python"
2
+ max_batch_size: 512
3
+ input [{
4
+ name: "INPUT_TEXT"
5
+ data_type: TYPE_STRING
6
+ dims: 1
7
+ },
8
+ {
9
+ name: "INPUT_LANGUAGE_ID"
10
+ data_type: TYPE_STRING
11
+ dims: 1
12
+ },
13
+ {
14
+ name: "OUTPUT_LANGUAGE_ID"
15
+ data_type: TYPE_STRING
16
+ dims: 1
17
+ }]
18
+
19
+ output {
20
+ name: "OUTPUT_TEXT"
21
+ data_type: TYPE_STRING
22
+ dims: 1
23
+ }
24
+
25
+ dynamic_batching {
26
+
27
+ }
28
+
29
+ instance_group [{
30
+ count: 1
31
+ kind: KIND_GPU
32
+ }]
IndicTrans2/inference/utils.map_token_lang.tsv ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ asm_Beng hi
2
+ ben_Beng hi
3
+ brx_Deva hi
4
+ doi_Deva hi
5
+ gom_Deva hi
6
+ eng_Latn en
7
+ guj_Gujr hi
8
+ hin_Deva hi
9
+ kan_Knda hi
10
+ kas_Arab ar
11
+ kas_Deva hi
12
+ mai_Deva hi
13
+ mar_Deva hi
14
+ mal_Mlym hi
15
+ mni_Beng hi
16
+ mni_Mtei en
17
+ npi_Deva hi
18
+ ory_Orya hi
19
+ pan_Guru hi
20
+ san_Deva hi
21
+ sat_Olck hi
22
+ snd_Arab ar
23
+ snd_Deva hi
24
+ tam_Taml hi
25
+ tel_Telu hi
26
+ urd_Arab ar