SCBconsulting committed on
Commit 15b6b6b · verified · 1 Parent(s): 8aa4a08

Update utils/translator.py

Files changed (1)
  1. utils/translator.py +11 -9
utils/translator.py CHANGED
@@ -1,14 +1,15 @@
 # utils/translator.py
 
 import os
-from transformers import MarianMTModel, MarianTokenizer
+import torch
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")
 
-# Fallback model (HF)
-model_name = "Helsinki-NLP/opus-mt-en-ROMANCE"
-tokenizer = MarianTokenizer.from_pretrained(model_name)
-model = MarianMTModel.from_pretrained(model_name)
+# ✅ Better fallback model (Brazilian Portuguese)
+model_name = "unicamp-dl/translation-en-pt-t5"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
 def translate_text(text):
     if not text.strip():
@@ -22,13 +23,14 @@ def translate_text(text):
             data={
                 "auth_key": DEEPL_API_KEY,
                 "text": text,
-                "target_lang": "PT"
+                "target_lang": "PT-BR"  # 🟢 Brazil-specific
             },
         )
         return response.json()["translations"][0]["text"]
 
     except Exception:
-        # Use Hugging Face fallback model
-        inputs = tokenizer.encode(text, return_tensors="pt", truncation=True)
-        outputs = model.generate(inputs, max_length=512, num_beams=4)
+        # 🔁 Use HF fallback model
+        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+        with torch.no_grad():
+            outputs = model.generate(**inputs, max_length=512, num_beams=4)
         return tokenizer.decode(outputs[0], skip_special_tokens=True)
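
Not part of the commit, but a quick way to sanity-check the new fallback path. The snippet below is a minimal sketch: it assumes the module is importable as utils.translator and that the DeepL branch fails to return a translation (for example, with DEEPL_API_KEY unset), so the except branch runs the unicamp-dl/translation-en-pt-t5 model locally.

# Hypothetical smoke test for the local fallback (not part of this commit).
# With DEEPL_API_KEY unset, the DeepL request in the try block cannot return a
# translation, so the except branch runs the unicamp-dl/translation-en-pt-t5 model.
import os

os.environ.pop("DEEPL_API_KEY", None)  # clear before the module reads it at import time

from utils.translator import translate_text

print(translate_text("The invoice was sent yesterday."))
# Should print a Brazilian Portuguese rendering of the sentence, produced by the fallback model.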