Update README.md
#9
by
VarunGumma
- opened
README.md
CHANGED
@@ -64,17 +64,22 @@ Please refer to the [github repository](https://github.com/AI4Bharat/IndicTrans2
|
|
64 |
|
65 |
```python
|
66 |
import torch
|
67 |
-
from transformers import
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
|
|
|
74 |
model_name = "ai4bharat/indictrans2-en-indic-1B"
|
75 |
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
76 |
|
77 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
ip = IndicProcessor(inference=True)
|
80 |
|
@@ -85,16 +90,12 @@ input_sentences = [
|
|
85 |
"My friend has invited me to his birthday party, and I will give him a gift.",
|
86 |
]
|
87 |
|
88 |
-
src_lang, tgt_lang = "eng_Latn", "hin_Deva"
|
89 |
-
|
90 |
batch = ip.preprocess_batch(
|
91 |
input_sentences,
|
92 |
src_lang=src_lang,
|
93 |
tgt_lang=tgt_lang,
|
94 |
)
|
95 |
|
96 |
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
97 |
-
|
98 |
# Tokenize the sentences and generate input encodings
|
99 |
inputs = tokenizer(
|
100 |
batch,
|
@@ -131,7 +132,10 @@ for input_sentence, translation in zip(input_sentences, translations):
|
|
131 |
print(f"{tgt_lang}: {translation}")
|
132 |
```
|
133 |
|
134 |
-
|
|
|
|
|
|
|
135 |
|
136 |
|
137 |
### Citation
|
|
|
64 |
|
65 |
```python
|
66 |
import torch
|
67 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
68 |
+
from IndicTransToolkit import IndicProcessor
|
69 |
+
# recommended to run this on a gpu with flash_attn installed
|
70 |
+
# don't set attn_implemetation if you don't have flash_attn
|
71 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
72 |
|
73 |
+
src_lang, tgt_lang = "eng_Latn", "hin_Deva"
|
74 |
model_name = "ai4bharat/indictrans2-en-indic-1B"
|
75 |
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
76 |
|
77 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
78 |
+
model_name,
|
79 |
+
trust_remote_code=True,
|
80 |
+
torch_dtype=torch.float16, # performance might slightly vary for bfloat16
|
81 |
+
attn_implementation="flash_attention_2"
|
82 |
+
).to(DEVICE)
|
83 |
|
84 |
ip = IndicProcessor(inference=True)
|
85 |
|
|
|
90 |
"My friend has invited me to his birthday party, and I will give him a gift.",
|
91 |
]
|
92 |
|
|
|
|
|
93 |
batch = ip.preprocess_batch(
|
94 |
input_sentences,
|
95 |
src_lang=src_lang,
|
96 |
tgt_lang=tgt_lang,
|
97 |
)
|
98 |
|
|
|
|
|
99 |
# Tokenize the sentences and generate input encodings
|
100 |
inputs = tokenizer(
|
101 |
batch,
|
|
|
132 |
print(f"{tgt_lang}: {translation}")
|
133 |
```
|
134 |
|
135 |
+
### 📢 Long Context IT2 Models
|
136 |
+
- New RoPE based IndicTrans2 models which are capable of handling sequence lengths **upto 2048 tokens** are available [here](https://huggingface.co/collections/prajdabre/indictrans2-rope-6742ddac669a05db0804db35)
|
137 |
+
- These models can be used by just changing the `model_name` parameter. Please read the model card of the RoPE-IT2 models for more information about the generation.
|
138 |
+
- It is recommended to run these models with `flash_attention_2` for efficient generation.
|
139 |
|
140 |
|
141 |
### Citation
|