# synclm-demo / utils/translator.py
# Last commit: 6f26572 "Update utils/translator.py" by SCBconsulting (verified)
# utils/translator.py
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from docx import Document
# ========== Model Loading (done once, at import time) ==========
def load_model_and_tokenizer(model_name):
    """Load the tokenizer and seq2seq model stored under ``model_name``.

    The tokenizer is loaded first, then the model.

    Returns:
        A ``(tokenizer, model)`` tuple, in that order.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return loaded_tokenizer, loaded_model
# English → Portuguese
# Both tokenizer/model pairs are created at module import time by
# load_model_and_tokenizer and kept as module-level globals for the
# translate_* helpers to use.
tokenizer_en_pt, model_en_pt = load_model_and_tokenizer("unicamp-dl/translation-en-pt-t5")
# Portuguese → English
tokenizer_pt_en, model_pt_en = load_model_and_tokenizer("unicamp-dl/translation-pt-en-t5")
# ========== Preprocessing ==========
def clean_text(text: str) -> str:
    """Collapse every run of whitespace in ``text`` into a single space.

    Newlines, tabs, carriage returns, non-breaking spaces and repeated
    spaces all become one ``" "``. Leading and trailing whitespace is
    removed, so blank input yields ``""``.
    """
    # str.split() with no separator splits on any run of whitespace and drops
    # empty pieces; re-joining with one space normalises everything in one pass
    # (a chained .replace(" ", " ") is a no-op and left double spaces intact).
    return " ".join(text.split())
def chunk_text(text: str, max_chunk_chars: int = 500):
    """Split ``text`` into word-aligned chunks of at most ``max_chunk_chars`` characters.

    Words are the whitespace-separated pieces of ``text``. They are kept in
    order and joined with single spaces. A chunk is closed as soon as the next
    word would push it past ``max_chunk_chars``. A single word longer than the
    limit is never split; it becomes a chunk of its own.

    Args:
        text: Text to split.
        max_chunk_chars: Maximum length of each chunk, in characters
            (only exceeded by a single over-long word).

    Returns:
        A list of non-empty chunk strings; ``[]`` for empty or
        whitespace-only input.
    """
    chunks = []
    current_words = []
    current_len = 0
    for word in text.split():
        # Cost of appending this word: the word itself plus a joining space
        # if the chunk already holds something.
        added = len(word) + (1 if current_words else 0)
        if current_words and current_len + added > max_chunk_chars:
            # Only flush non-empty chunks, so an over-long first word never
            # produces an empty "" chunk.
            chunks.append(" ".join(current_words))
            current_words, current_len = [word], len(word)
        else:
            current_words.append(word)
            current_len += added
    if current_words:
        chunks.append(" ".join(current_words))
    return chunks
# ========== Translation Core Logic ==========
def translate_chunks(chunks, tokenizer, model, *, prefix: str = "", max_length: int = 512, num_beams: int = 4):
    """Translate each chunk with a seq2seq ``model`` and join the results.

    Args:
        chunks: Iterable of text pieces, translated one at a time in order.
            Blank (empty or whitespace-only) pieces are skipped.
        tokenizer: Tokenizer paired with ``model``. It is called on each
            chunk and used to decode the generated ids.
        model: Model exposing ``generate``.
        prefix: Text prepended to every chunk before tokenisation, e.g. a
            task prefix. Defaults to ``""`` (no prefix).
        max_length: Forwarded to ``model.generate``.
        num_beams: Forwarded to ``model.generate``.

    Returns:
        The decoded translations joined with single spaces, or ``""`` when
        there is nothing to translate.
    """
    translated = []
    for chunk in chunks:
        if not chunk.strip():
            # Skip blank pieces so they neither reach the model nor leave
            # stray spaces in the joined result.
            continue
        inputs = tokenizer(prefix + chunk, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_length=max_length, num_beams=num_beams)
        # One chunk is encoded per call, so row 0 holds its generated sequence.
        translated.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return " ".join(translated)
def translate_to_portuguese(text: str) -> str:
    """Translate English ``text`` into Portuguese.

    Empty or whitespace-only input short-circuits to the message
    ``"No input provided."`` without touching the model. Otherwise the text
    is whitespace-cleaned, chunked, and run through the EN→PT model.
    """
    # NOTE(review): the unicamp-dl T5 model cards appear to prefix inputs with
    # a task string ("translate English to Portuguese: "); none is sent here —
    # confirm against the model card.
    if text.strip():
        prepared = clean_text(text)
        return translate_chunks(chunk_text(prepared), tokenizer_en_pt, model_en_pt)
    return "No input provided."
def translate_to_english(text: str) -> str:
    """Translate Portuguese ``text`` into English.

    Empty or whitespace-only input returns ``"No input provided."``.
    Otherwise the cleaned text is chunked and passed through the PT→EN model.
    """
    # NOTE(review): the unicamp-dl T5 model cards appear to prefix inputs with
    # a task string ("translate Portuguese to English: "); none is sent here —
    # confirm against the model card.
    if text.strip():
        pieces = chunk_text(clean_text(text))
        return translate_chunks(pieces, tokenizer_pt_en, model_pt_en)
    return "No input provided."
def translate_text(text: str, direction: str = "en-pt") -> str:
    """Dispatch ``text`` to the translator selected by ``direction``.

    Args:
        text: Source text to translate.
        direction: ``"en-pt"`` (English → Portuguese, the default) or
            ``"pt-en"`` (Portuguese → English).

    Returns:
        The translation, or ``"Unsupported translation direction."`` for any
        other ``direction`` value.
    """
    if direction not in ("en-pt", "pt-en"):
        return "Unsupported translation direction."
    if direction == "pt-en":
        return translate_to_english(text)
    return translate_to_portuguese(text)
# ========== Bilingual View ==========
def bilingual_clauses(text: str, max_clause_chars: int = 300) -> str:
    """Build an interleaved English/Portuguese rendering of ``text``.

    ``text`` is whitespace-normalised with ``clean_text`` and split by
    ``chunk_text`` into segments of at most ``max_clause_chars`` characters.
    These are length-based segments, not grammatical clauses. Each segment
    is translated to Portuguese and emitted as an ``EN:`` line, a ``PT:``
    line and a 60-character dashed rule. Segments are separated by a blank
    line.

    Args:
        text: English source text.
        max_clause_chars: Maximum characters per segment (default 300).

    Returns:
        The formatted bilingual text, or ``"No input provided."`` when
        ``text`` is empty or whitespace-only.
    """
    if not text.strip():
        return "No input provided."
    clauses_en = chunk_text(clean_text(text), max_chunk_chars=max_clause_chars)
    separator = "-" * 60
    bilingual_output = []
    for clause in clauses_en:
        translated = translate_to_portuguese(clause)
        # 📘 labels the English source, 📗 the Portuguese translation
        # (the labels were previously mis-encoded as "πŸ“˜"/"πŸ“—").
        bilingual_output.append(f"📘 EN: {clause}\n📗 PT: {translated}\n{separator}")
    return "\n\n".join(bilingual_output)
# ========== Export to DOCX ==========
def export_to_word(text: str, filename: str = "translated_contract.docx") -> str:
    """Save ``text`` as a new Word (.docx) document at ``filename``.

    The document starts with the level-1 heading "Legal Translation Output".
    ``text`` is then split on double newlines (``"\\n\\n"``), and each
    resulting piece is added as its own paragraph.

    Returns:
        ``filename``, i.e. the path the document was saved to.
    """
    document = Document()
    document.add_heading("Legal Translation Output", level=1)
    blocks = text.split("\n\n")
    for block in blocks:
        document.add_paragraph(block)
    document.save(filename)
    return filename