cointegrated
commited on
Create README.md
Browse files
README.md
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
This is a version of NLLB fine-tuned to translate sentences between eng and azj languages.
|
2 |
+
|
3 |
+
Example inference code (with the correct NLLB preprocessing!):
|
4 |
+
|
5 |
+
```Python
|
6 |
+
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM, AutoConfig
|
7 |
+
# this code is adapted from the Stopes repo of the NLLB team
|
8 |
+
# https://github.com/facebookresearch/stopes/blob/main/stopes/pipelines/monolingual/monolingual_line_processor.py#L214
|
9 |
+
|
10 |
+
import re
|
11 |
+
import sys
|
12 |
+
import typing as tp
|
13 |
+
import unicodedata
|
14 |
+
from sacremoses import MosesPunctNormalizer
|
15 |
+
|
16 |
+
|
17 |
+
mpn = MosesPunctNormalizer(lang="en")
|
18 |
+
mpn.substitutions = [
|
19 |
+
(re.compile(r), sub) for r, sub in mpn.substitutions
|
20 |
+
]
|
21 |
+
|
22 |
+
|
23 |
+
def get_non_printing_char_replacer(replace_by: str = " ") -> tp.Callable[[str], str]:
|
24 |
+
non_printable_map = {
|
25 |
+
ord(c): replace_by
|
26 |
+
for c in (chr(i) for i in range(sys.maxunicode + 1))
|
27 |
+
# same as \p{C} in perl
|
28 |
+
# see https://www.unicode.org/reports/tr44/#General_Category_Values
|
29 |
+
if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
|
30 |
+
}
|
31 |
+
|
32 |
+
def replace_non_printing_char(line) -> str:
|
33 |
+
return line.translate(non_printable_map)
|
34 |
+
|
35 |
+
return replace_non_printing_char
|
36 |
+
|
37 |
+
replace_nonprint = get_non_printing_char_replacer(" ")
|
38 |
+
|
39 |
+
def preproc(text):
|
40 |
+
clean = mpn.normalize(text)
|
41 |
+
clean = replace_nonprint(clean)
|
42 |
+
# replace 𝓕𝔯𝔞𝔫𝔠𝔢𝔰𝔠𝔞 by Francesca
|
43 |
+
clean = unicodedata.normalize("NFKC", clean)
|
44 |
+
return clean
|
45 |
+
|
46 |
+
# loading the model
|
47 |
+
model_name = "slone/nllb-600M-azj-eng-v1"
|
48 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
|
49 |
+
tokenizer = NllbTokenizer.from_pretrained(model_name)
|
50 |
+
|
51 |
+
def translate(text, src_lang='eng_Latn', tgt_lang='azj_Latn', a=32, b=3, max_input_length=1024, num_beams=4, **kwargs):
|
52 |
+
tokenizer.src_lang = src_lang
|
53 |
+
tokenizer.tgt_lang = tgt_lang
|
54 |
+
if isinstance(text, str):
|
55 |
+
text = [text]
|
56 |
+
text = [preproc(t) for t in text]
|
57 |
+
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=max_input_length)
|
58 |
+
result = model.generate(
|
59 |
+
**inputs.to(model.device),
|
60 |
+
forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
|
61 |
+
max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
|
62 |
+
num_beams=num_beams,
|
63 |
+
**kwargs
|
64 |
+
)
|
65 |
+
return tokenizer.batch_decode(result, skip_special_tokens=True)
|
66 |
+
|
67 |
+
# Example of translating a couple of texts:
|
68 |
+
texts = translate(["To be, or not to be, that is the question.", "Hello, how are you?"], src_lang='eng_Latn', tgt_lang='azj_Latn')
|
69 |
+
print(texts)
|
70 |
+
# ['Olmaq və ya olmamaq sualdır.', 'Salam, necə var?']
|
71 |
+
```
|
72 |
+
|
73 |
+
If you want to translate too many sentences, you will need to put them in small batches
|
74 |
+
(batch size can be chosen as the largest that fits into your GPU memory).
|
75 |
+
An efficient way would be to batch them by similar length, like below:
|
76 |
+
|
77 |
+
```Python
|
78 |
+
def batched_translate(texts, batch_size=16, **kwargs):
|
79 |
+
"""Translate texts in batches of similar length"""
|
80 |
+
idxs, texts2 = zip(*sorted(enumerate(texts), key=lambda p: len(p[1]), reverse=True))
|
81 |
+
results = []
|
82 |
+
for i in trange(0, len(texts2), batch_size):
|
83 |
+
results.extend(translate(texts2[i: i+batch_size], **kwargs))
|
84 |
+
return [p for i, p in sorted(zip(idxs, results))]
|
85 |
+
```
|
86 |
+
|