# Spaces page status captured with this file (not program code): Build error
import os
import re

import docx
import streamlit as st
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Load NLLB model and tokenizer
@st.cache_resource
def load_translation_model():
    """Load the NLLB-200 distilled 600M tokenizer and seq2seq model.

    Streamlit re-executes the script on every widget interaction, so the
    loaded objects are memoised with ``st.cache_resource`` (available in
    Streamlit 1.18 and later); without it the checkpoint would be loaded
    again on each rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    model_name = "facebook/nllb-200-distilled-600M"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model
# Initialize model
def initialize_models():
    """Build the model registry consumed by the translation functions.

    Returns:
        A dict whose ``"nllb"`` entry is the ``(tokenizer, model)`` pair
        returned by ``load_translation_model``.
    """
    nllb_tokenizer, nllb_model = load_translation_model()
    registry = dict(nllb=(nllb_tokenizer, nllb_model))
    return registry
# Function to extract text from different file types
def extract_text(file):
    """Return the plain text of an uploaded PDF, DOCX or TXT file.

    Args:
        file: A binary file-like object with a ``name`` attribute. The
            extension of ``name`` (compared case-insensitively) selects the
            parser.

    Returns:
        The extracted text. For PDF and DOCX input every page / paragraph is
        followed by a newline; TXT input is returned as decoded.

    Raises:
        ValueError: If the extension is not ``.pdf``, ``.docx`` or ``.txt``.
        UnicodeDecodeError: If a ``.txt`` file is not valid UTF-8.
    """
    ext = os.path.splitext(file.name)[1].lower()

    if ext == ".pdf":
        reader = PdfReader(file)
        # ``or ""`` keeps a page that yields no text from breaking the
        # concatenation with the newline.
        return "".join((page.extract_text() or "") + "\n" for page in reader.pages)

    if ext == ".docx":
        document = docx.Document(file)
        return "".join(para.text + "\n" for para in document.paragraphs)

    if ext == ".txt":
        # "utf-8-sig" decodes plain UTF-8 unchanged but drops a leading byte
        # order mark, which would otherwise reach the translator as a stray
        # character at the start of the first line.
        return file.read().decode("utf-8-sig")

    raise ValueError("Unsupported file format. Please upload PDF, DOCX, or TXT files.")
# Translation function
def translate_text(text, src_lang, tgt_lang, models):
    """Translate ``text`` line by line with the NLLB model in ``models``.

    Args:
        text: Source text. It is split on newlines and every non-blank line
            is translated on its own; blank lines are dropped.
        src_lang: Source language key: ``"en"``, ``"hi"`` or ``"mr"``.
        tgt_lang: Target language key, same choices as ``src_lang``.
        models: Mapping whose ``"nllb"`` entry is a ``(tokenizer, model)``
            pair.

    Returns:
        ``text`` unchanged when both languages are equal, the string
        ``"Error: Unsupported language combination"`` when either key is
        unknown, otherwise the translated lines, each followed by a newline.
    """
    if src_lang == tgt_lang:
        return text

    # Language codes for NLLB
    lang_map = {"en": "eng_Latn", "hi": "hin_Deva", "mr": "mar_Deva"}
    if src_lang not in lang_map or tgt_lang not in lang_map:
        return "Error: Unsupported language combination"

    tokenizer, model = models["nllb"]

    # Tell the tokenizer which language the input is in; otherwise Hindi and
    # Marathi input would be tagged with the tokenizer's default source code.
    tokenizer.src_lang = lang_map[src_lang]

    # Resolve the target-language token once, outside the loop.
    # ``convert_tokens_to_ids`` is used because the ``lang_code_to_id``
    # mapping is not available on recent transformers releases.
    target_token_id = tokenizer.convert_tokens_to_ids(lang_map[tgt_lang])

    # Preprocess for idioms. No ``preprocess_idioms`` is defined or imported
    # in the visible part of this module, so an unconditional call raises
    # NameError; apply the hook only when the module actually provides one.
    idiom_hook = globals().get("preprocess_idioms")
    if callable(idiom_hook):
        preprocessed_text = idiom_hook(text, src_lang, tgt_lang)
    else:
        preprocessed_text = text

    translated_lines = []
    for line in preprocessed_text.split("\n"):
        if not line.strip():
            continue
        # truncation=True with max_length=512 cuts over-long lines to 512
        # tokens rather than failing.
        inputs = tokenizer(
            line,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        generated = model.generate(
            **inputs,
            forced_bos_token_id=target_token_id,
            max_length=512,
        )
        translated_lines.append(
            tokenizer.decode(generated[0], skip_special_tokens=True) + "\n"
        )
    return "".join(translated_lines)
# Function to save text as a file
def save_text_to_file(text, original_filename, prefix="translated"):
    """Write ``text`` to a UTF-8 encoded ``.txt`` file and return its name.

    The file name is ``<prefix>_<basename of original_filename>.txt``; being a
    bare name, it is created relative to the current working directory.

    Returns:
        The name of the file that was written.
    """
    source_name = os.path.basename(original_filename)
    output_filename = "{}_{}.txt".format(prefix, source_name)
    with open(output_filename, mode="w", encoding="utf-8") as handle:
        handle.write(text)
    return output_filename
# Main processing function
def process_document(file, source_lang, target_lang, models):
    """Extract, translate and save an uploaded document.

    An exception raised while extracting, translating or saving the result
    is converted into an ``"Error: ..."`` message, which is then saved with
    the ``error`` prefix.

    Returns:
        An ``(output_file, text)`` tuple: the name of the written file and
        either the translated text or the error message.
    """
    try:
        source_text = extract_text(file)
        outcome = translate_text(source_text, source_lang, target_lang, models)
        # translate_text reports failure as an "Error:"-prefixed string; such
        # results go to an "error" file, everything else uses the default
        # prefix of save_text_to_file.
        save_options = {"prefix": "error"} if outcome.startswith("Error:") else {}
        saved_as = save_text_to_file(outcome, file.name, **save_options)
        return saved_as, outcome
    except Exception as exc:
        # Save error message to a file
        failure = "Error: " + str(exc)
        saved_as = save_text_to_file(failure, file.name, prefix="error")
        return saved_as, failure
# Streamlit interface
def main():
    """Render the page: upload, language pickers, translate and download."""
    st.title("Document Translator (NLLB-200)")
    st.write("Upload a document (PDF, DOCX, or TXT) and select source and target languages (English, Hindi, Marathi).")

    # Initialize models
    models = initialize_models()

    # File uploader
    uploaded_file = st.file_uploader("Upload Document", type=["pdf", "docx", "txt"])

    # Language selection, side by side
    language_keys = ["en", "hi", "mr"]
    source_column, target_column = st.columns(2)
    with source_column:
        source_lang = st.selectbox("Source Language", language_keys, index=0)
    with target_column:
        target_lang = st.selectbox("Target Language", language_keys, index=1)

    # The Translate button is only rendered once a file has been uploaded.
    if uploaded_file is None:
        return
    if not st.button("Translate"):
        return

    with st.spinner("Translating..."):
        output_file, result_text = process_document(uploaded_file, source_lang, target_lang, models)

    # Display result
    st.text_area("Translated Text", result_text, height=300)

    # Provide download button
    with open(output_file, "rb") as saved:
        st.download_button(
            label="Download Translated Document",
            data=saved,
            file_name=os.path.basename(output_file),
            mime="text/plain",
        )
# Script entry point: build the page when the file is executed directly.
if __name__ == "__main__":
    main()