|
import torch |
|
from transformers import MarianMTModel, AutoTokenizer |
|
import ctranslate2 |
|
from colorize import align_words |
|
import logging |
|
|
|
|
|
# Configure the ROOT logger (no name argument) so every module in the
# process inherits this handler; records at INFO and above are appended
# to app.log with a timestamped format.
logger = logging.getLogger()

logger.setLevel(logging.INFO)

# mode='a' appends across restarts instead of truncating the log file.
file_handler = logging.FileHandler('app.log', mode='a')

file_handler.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

file_handler.setFormatter(formatter)

logger.addHandler(file_handler)
|
|
|
# HuggingFace Marian models loaded with output_attentions=True: they are
# used only to extract cross-attention weights for word colorization,
# not for the translation itself.
model_to_ar = MarianMTModel.from_pretrained("./he_ar/", output_attentions=True)

model_from_ar = MarianMTModel.from_pretrained("./ar_he/", output_attentions=True)

# CTranslate2 conversions of the same models; these perform the actual
# (faster) translation in translate().
model_to_ar_ct2 = ctranslate2.Translator("./he_ar_ct2/")

model_from_ar_ct2 = ctranslate2.Translator("./ar_he_ct2/")


# Tokenizers shared by both backends for each direction.
tokenizer_to_ar = AutoTokenizer.from_pretrained("./he_ar/")

tokenizer_from_ar = AutoTokenizer.from_pretrained("./ar_he/")

print("Done loading models")
|
|
|
# Maps a dialect name (English or Hebrew spelling) to the single-letter
# control token prefixed to the source text in run_translate():
# P = Palestinian, S = Syrian, L = Lebanese, E = Egyptian.
dialect_map = {

    "Palestinian": "P",

    "Syrian": "S",

    "Lebanese": "L",

    "Egyptian": "E",

    "פלסטיני": "P",   # Palestinian (Hebrew)

    "סורי": "S",      # Syrian (Hebrew)

    "לבנוני": "L",    # Lebanese (Hebrew)

    "מצרי": "E"       # Egyptian (Hebrew)

}
|
|
|
|
|
def translate(text, ct_model, hf_model, tokenizer, to_arabic=True,
              threshold=None, layer=2, head=6):
    """Translate *text* and build an attention-based alignment visualization.

    The CTranslate2 model produces the translation; the HuggingFace model is
    then re-run on the (input, output) pair solely to obtain cross-attention
    weights, which align_words() turns into colorized HTML.

    Args:
        text: Source sentence.
        ct_model: ctranslate2.Translator performing the translation.
        hf_model: MarianMTModel loaded with output_attentions=True.
        tokenizer: Tokenizer matching both models.
        to_arabic: True for Hebrew->Arabic; forwarded as skip_first_src so
            the leading dialect control token is skipped during alignment.
        threshold: Attention threshold for word alignment. When None, a
            heuristic based on the input length is used.
        layer: Attention layer index used for alignment.
        head: Attention head index used for alignment.

    Returns:
        (html, arabic): HTML showing source/target alignment, and the Arabic
        side of the pair (the output if it is Arabic, otherwise the input).
    """
    logger.info(f"Translating: {text}")
    inp_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
    out_tokens = ct_model.translate_batch([inp_tokens])[0].hypotheses[0]
    out_string = tokenizer.convert_tokens_to_string(out_tokens)

    # Rebuild encoder/decoder id tensors so the HF model reproduces the
    # exact token sequence chosen by CTranslate2 (teacher forcing).
    encoder_input_ids = torch.tensor(tokenizer.convert_tokens_to_ids(inp_tokens)).unsqueeze(0)
    decoder_input_ids = torch.tensor(tokenizer.convert_tokens_to_ids(["<pad>"] + out_tokens +
                                                                     ['</s>'])).unsqueeze(0)

    colorization_output = hf_model(input_ids=encoder_input_ids,
                                   decoder_input_ids=decoder_input_ids)

    # BUG FIX: was `if not threshold:`, which also overrode an explicitly
    # passed threshold of 0.0; compare against None instead.
    if threshold is None:
        if len(inp_tokens) < 10:
            threshold = 0.05
        elif len(inp_tokens) < 20:
            threshold = 0.10
        else:
            # NOTE(review): identical to the short-input value — confirm this
            # is intentional and not meant to be a larger threshold.
            threshold = 0.05

    srchtml, tgthtml = align_words(colorization_output,
                                   tokenizer,
                                   encoder_input_ids,
                                   decoder_input_ids,
                                   threshold,
                                   skip_first_src=to_arabic,
                                   skip_second_src=False,
                                   layer=layer,
                                   head=head)

    html = f"<div style='direction: rtl'>{srchtml}<br><br>{tgthtml}</div>"

    # Keep whichever side of the pair is Arabic (used downstream, e.g. TTS).
    arabic = out_string if is_arabic(out_string) else text
    return html, arabic
|
|
|
|
|
|
|
|
|
|
|
def is_arabic(text):
    """Return True if more than half of the non-space characters of *text*
    lie in the basic Arabic Unicode block (U+0600-U+06FF).

    An empty or all-space string returns False. (Previously this raised
    ZeroDivisionError because len(text) was 0 after removing spaces.)
    """
    text = text.replace(" ", "")
    # BUG FIX: guard against division by zero on empty/whitespace-only input.
    if not text:
        return False

    arabic_chars = sum(1 for c in text if "\u0600" <= c <= "\u06FF")
    return arabic_chars / len(text) > 0.5
|
|
|
def run_translate(text, dialect=None):
    """Dispatch *text* to the correct translation direction.

    Arabic input goes Arabic->Hebrew; anything else goes Hebrew->Arabic,
    optionally prefixed with a one-letter dialect control token resolved
    via dialect_map.

    Returns the (html, arabic) pair from translate(), or "" for empty input.
    """
    if not text:
        return ""

    if is_arabic(text):
        # Arabic -> Hebrew (no dialect token on this direction).
        return translate(text, model_from_ar_ct2, model_from_ar, tokenizer_from_ar,
                         to_arabic=False, threshold=None, layer=2, head=1)

    # Hebrew -> Arabic: resolve the dialect name to its control letter;
    # an unknown-but-truthy dialect string is used as-is.
    code = dialect_map.get(dialect, dialect)
    source = f"{code} {text}" if code else text
    return translate(source, model_to_ar_ct2, model_to_ar, tokenizer_to_ar,
                     to_arabic=True, threshold=None, layer=2, head=6)