"""Streamlit app that translates uploaded PDF/DOCX/TXT documents between English, Hindi and Marathi."""

import io
import os
import tempfile
from pathlib import Path
from typing import Tuple, Union

import docx
import PyPDF2
import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
)

# Get Hugging Face token from environment variables
HF_TOKEN = os.environ.get('HF_TOKEN')
if not HF_TOKEN:
    st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
    st.stop()

# Display name -> NLLB-200 language code (used as src_lang and as the forced BOS token).
SUPPORTED_LANGUAGES = {
    'English': 'eng_Latn',
    'Hindi': 'hin_Deva',
    'Marathi': 'mar_Deva',
}

# NLLB language code -> short code selecting the grammar-correction prompt.
MT5_LANG_CODES = {
    'eng_Latn': 'en',
    'hin_Deva': 'hi',
    'mar_Deva': 'mr',
}

# Per-call token limits. 1024 matches the max_length this app always used for
# NLLB (the model's maximum input length); 512 is the limit it used for mT5.
NLLB_MAX_TOKENS = 1024
MT5_MAX_TOKENS = 512

# Session-state key for the last pipeline result. Streamlit reruns the whole
# script when any widget (including a download button) is clicked, so results
# kept only inside the `if st.button(...)` branch would vanish after the first
# download.
_RESULT_KEY = "translation_result"


@st.cache_resource
def load_models():
    """Load and cache the context-interpretation, translation and grammar models.

    ``st.cache_resource`` keeps the loaded objects for the lifetime of the
    server process, so the weights are loaded once and reused on every rerun.

    Returns:
        Three ``(tokenizer, model)`` pairs, in order: Gemma (context
        interpretation), NLLB-200 (translation), mT5 (grammar correction).
    """
    # Gemma model for context interpretation
    gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", token=HF_TOKEN)
    gemma_model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b",
        device_map="auto",
        torch_dtype=torch.float16,
        token=HF_TOKEN,
    )

    # NLLB model for translation
    nllb_tokenizer = AutoTokenizer.from_pretrained(
        "facebook/nllb-200-distilled-600M", token=HF_TOKEN
    )
    nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
        "facebook/nllb-200-distilled-600M",
        device_map="auto",
        torch_dtype=torch.float16,
        token=HF_TOKEN,
    )

    # mT5 model for grammar correction. The Auto class resolves the checkpoint's
    # own architecture instead of forcing it into the plain T5 class.
    mt5_tokenizer = AutoTokenizer.from_pretrained("google/mt5-small", token=HF_TOKEN)
    mt5_model = AutoModelForSeq2SeqLM.from_pretrained(
        "google/mt5-small",
        device_map="auto",
        torch_dtype=torch.float16,
        token=HF_TOKEN,
    )

    return (gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model), (mt5_tokenizer, mt5_model)


def extract_text_from_file(uploaded_file) -> str:
    """Return the text content of an uploaded PDF, DOCX or TXT file.

    The format is chosen from the file-name extension (case-insensitive).

    Raises:
        ValueError: if the extension is not ``.pdf``, ``.docx`` or ``.txt``.
        UnicodeDecodeError: if a ``.txt`` upload is not valid UTF-8.
    """
    file_extension = Path(uploaded_file.name).suffix.lower()
    if file_extension == '.pdf':
        return extract_from_pdf(uploaded_file)
    if file_extension == '.docx':
        return extract_from_docx(uploaded_file)
    if file_extension == '.txt':
        # 'utf-8-sig' decodes plain UTF-8 identically and also drops a leading BOM.
        return uploaded_file.getvalue().decode('utf-8-sig')
    raise ValueError(f"Unsupported file format: {file_extension}")


def extract_from_pdf(file) -> str:
    """Extract text from every page of a PDF, separating pages with a newline."""
    pdf_reader = PyPDF2.PdfReader(file)
    # `or ""` guards against a page yielding no text (None/empty).
    pages = [page.extract_text() or "" for page in pdf_reader.pages]
    return "\n".join(pages).strip()


def extract_from_docx(file) -> str:
    """Extract text from a DOCX file, one paragraph per line."""
    doc = docx.Document(file)
    return "\n".join(paragraph.text for paragraph in doc.paragraphs).strip()


def interpret_context(text: str, gemma_tuple: Tuple, max_new_tokens: int = 512) -> str:
    """Ask Gemma for notes on the context and cultural nuances of ``text``.

    Only the newly generated tokens are decoded. Decoding the full output would
    echo the instruction and the whole document back to the caller.

    Args:
        text: Source document text.
        gemma_tuple: ``(tokenizer, model)`` pair as returned by ``load_models``.
        max_new_tokens: Upper bound on generated tokens. Unlike ``max_length``,
            this does not count the prompt, so a long document cannot exhaust
            the budget before any notes are produced.

    Returns:
        The generated notes, stripped; may be an empty string.
    """
    tokenizer, model = gemma_tuple
    # The trailing "Analysis:" cue asks the model to continue with notes rather
    # than with more document text ("google/gemma-2b" appears to be the base,
    # non-instruction-tuned checkpoint — confirm if the model is changed).
    prompt = (
        "Analyze the following text for context and cultural nuances, maintaining the core "
        "meaning while identifying any idiomatic expressions or cultural references:\n\n"
        f"{text}\n\n"
        "Analysis:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        # Greedy decoding; the previous temperature=0.3 had no effect without sampling.
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()


def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
    """Translate ``text`` from ``source_lang`` to ``target_lang`` with NLLB-200.

    Both languages are NLLB codes (values of ``SUPPORTED_LANGUAGES``). The text
    is translated one line at a time, which keeps each input under the model's
    token limit and preserves the document's line structure. Blank lines are
    copied through unchanged.
    """
    tokenizer, model = nllb_tuple
    # NLLB prefixes each input with a source-language token taken from
    # `src_lang`. If it is left unset, the tokenizer's default (eng_Latn) is
    # used for every source language.
    # NOTE(review): the tokenizer is shared across sessions via
    # st.cache_resource, so concurrent sessions could race on this attribute.
    tokenizer.src_lang = source_lang
    # convert_tokens_to_ids works on both slow and fast NLLB tokenizers;
    # `lang_code_to_id` is missing from recent NllbTokenizerFast releases.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(target_lang)

    translated_lines = []
    for line in text.split("\n"):
        if not line.strip():
            translated_lines.append(line)
            continue
        # Truncation guards against a single very long line overflowing the model.
        inputs = tokenizer(
            line, return_tensors="pt", truncation=True, max_length=NLLB_MAX_TOKENS
        ).to(model.device)
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_token_id,
            max_length=NLLB_MAX_TOKENS,
            num_beams=5,
        )
        translated_lines.append(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
    return "\n".join(translated_lines)


def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
    """
    Correct grammar using the mT5 model for all supported languages.

    Uses a text-to-text approach with language-specific prompt prefixes. Each
    non-empty line is processed separately, so no text is lost to the
    512-token truncation. Decoding is deterministic beam search. A line for
    which the model returns nothing keeps its uncorrected text.

    Args:
        text: Text to correct, typically the translation.
        target_lang: NLLB code of the text's language (key of ``MT5_LANG_CODES``).
        mt5_tuple: ``(tokenizer, model)`` pair as returned by ``load_models``.

    Raises:
        KeyError: if ``target_lang`` is not in ``MT5_LANG_CODES``.
    """
    tokenizer, model = mt5_tuple
    lang_code = MT5_LANG_CODES[target_lang]

    # Language-specific prompt prefixes for grammar correction
    prompt_prefixes = {
        'en': "grammar: ",
        'hi': "व्याकरण सुधार: ",
        'mr': "व्याकरण सुधारणा: ",
    }
    prefix = prompt_prefixes[lang_code]
    # Prompt markers the model sometimes copies into its output.
    artifacts = [marker.strip() for marker in prompt_prefixes.values()]

    corrected_lines = []
    for line in text.split("\n"):
        if not line.strip():
            corrected_lines.append(line)
            continue
        inputs = tokenizer(
            prefix + line, return_tensors="pt", max_length=MT5_MAX_TOKENS, truncation=True
        ).to(model.device)
        # Plain beam search: sampling (the old do_sample/top_p/temperature)
        # made corrections differ from run to run.
        outputs = model.generate(**inputs, max_length=MT5_MAX_TOKENS, num_beams=5)
        corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
        for artifact in artifacts:
            corrected = corrected.replace(artifact, "")
        corrected_lines.append(corrected.strip() or line)
    return "\n".join(corrected_lines)


def save_as_docx(text: str) -> io.BytesIO:
    """Save text as a single-paragraph DOCX document in an in-memory buffer.

    Returns:
        A ``BytesIO`` rewound to position 0, ready to be read or downloaded.
    """
    doc = docx.Document()
    doc.add_paragraph(text)
    docx_buffer = io.BytesIO()
    doc.save(docx_buffer)
    docx_buffer.seek(0)
    return docx_buffer


def _run_pipeline(uploaded_file, source_lang: str, target_lang: str, apply_grammar: bool,
                  gemma_tuple: Tuple, nllb_tuple: Tuple, mt5_tuple: Tuple) -> dict:
    """Run extraction, context notes, translation and optional grammar correction.

    ``source_lang`` and ``target_lang`` are NLLB codes. The returned dict holds
    everything ``_render_result`` needs.

    Raises:
        ValueError: for unsupported files or when no text could be extracted.
        Exception: anything raised by the translation or grammar models.
    """
    with st.spinner("Processing document..."):
        text = extract_text_from_file(uploaded_file)
    if not text.strip():
        raise ValueError("No text could be extracted from the document.")

    # Context notes are auxiliary. If Gemma fails (e.g. the document is too
    # long), warn and continue with the translation instead of aborting.
    context_notes = ""
    with st.spinner("Interpreting context..."):
        try:
            context_notes = interpret_context(text, gemma_tuple)
        except Exception as e:
            st.warning(f"Context interpretation was skipped: {str(e)}")

    # Translate the document itself, not the model's commentary about it.
    with st.spinner("Translating..."):
        translated_text = translate_text(text, source_lang, target_lang, nllb_tuple)

    if apply_grammar:
        with st.spinner("Correcting grammar..."):
            translated_text = correct_grammar(translated_text, target_lang, mt5_tuple)

    return {
        "file_name": uploaded_file.name,
        "text": text,
        "context_notes": context_notes,
        "translation": translated_text,
    }


def _render_result(result: dict) -> None:
    """Show a stored pipeline result together with its TXT/DOCX download buttons."""
    st.text_area("Extracted Text:", value=result["text"], height=150)
    if result["context_notes"]:
        with st.expander("Context notes"):
            st.text_area("Context Notes:", value=result["context_notes"], height=150)

    st.subheader("Translation Result:")
    st.text_area("Translated Text:", value=result["translation"], height=150)

    st.subheader("Download Translation:")
    st.download_button(
        label="Download as TXT",
        data=result["translation"].encode("utf-8"),
        file_name="translated_document.txt",
        mime="text/plain",
    )
    st.download_button(
        label="Download as DOCX",
        data=save_as_docx(result["translation"]),
        file_name="translated_document.docx",
        mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    )


def main():
    """Render the app: load models, take the upload and languages, run and show the pipeline."""
    st.title("Document Translation App")

    # Load models
    with st.spinner("Loading models... This may take a few minutes."):
        try:
            gemma_tuple, nllb_tuple, mt5_tuple = load_models()
        except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            st.error("Please check if the HF_TOKEN is valid and has the necessary permissions.")
            st.stop()

    # File upload
    uploaded_file = st.file_uploader(
        "Upload your document (PDF, DOCX, or TXT)",
        type=['pdf', 'docx', 'txt'],
    )

    # Language selection
    col1, col2 = st.columns(2)
    with col1:
        source_language = st.selectbox(
            "Source Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=0,
        )
    with col2:
        target_language = st.selectbox(
            "Target Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=1,
        )

    # Opt-in: the mT5 checkpoint is not fine-tuned for grammar correction and
    # can make the translation worse.
    apply_grammar = st.checkbox(
        "Apply grammar correction (experimental)",
        value=False,
        help=(
            "Runs google/mt5-small over the translation. That checkpoint is pre-trained only "
            "(not fine-tuned for grammar correction), so it may degrade the output."
        ),
    )

    if uploaded_file and st.button("Translate"):
        if source_language == target_language:
            st.warning("Please choose different source and target languages.")
        else:
            try:
                st.session_state[_RESULT_KEY] = _run_pipeline(
                    uploaded_file,
                    SUPPORTED_LANGUAGES[source_language],
                    SUPPORTED_LANGUAGES[target_language],
                    apply_grammar,
                    gemma_tuple,
                    nllb_tuple,
                    mt5_tuple,
                )
            except Exception as e:
                # Drop any stale result so it is not mistaken for this run's output.
                st.session_state.pop(_RESULT_KEY, None)
                st.error(f"An error occurred: {str(e)}")

    # Rendered outside the button branch so results and both download buttons
    # survive the rerun triggered by clicking a download button.
    result = st.session_state.get(_RESULT_KEY)
    if result is not None and uploaded_file is not None and result["file_name"] == uploaded_file.name:
        _render_result(result)


if __name__ == "__main__":
    main()