"""Streamlit app that translates uploaded PDF/DOCX/TXT documents.

Pipeline per text batch: a Gemma model produces an interpretation of the
source text, an NLLB model translates it, and an mT5 model is asked to
polish the result.  Supported languages are listed in ``CONFIG``.
"""

import io
import json
import os
import re
import sys
import tempfile
import warnings
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Tuple, Union

import docx
import PyPDF2
import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    MT5ForConditionalGeneration,
)

# Filter warnings
warnings.filterwarnings('ignore', category=UserWarning)

# Page config (must run before any other Streamlit command)
st.set_page_config(
    page_title="Enhanced Document Translation App",
    page_icon="🌐",
    layout="wide"
)

# Constants and Configurations
CONFIG = {
    # Upper bound used both for batch size (characters) and for tokenizer /
    # generation limits (tokens).
    "MAX_BATCH_LENGTH": 512,
    "MIN_BATCH_LENGTH": 50,
    "TRANSLATION_TEMPERATURE": 0.7,
    "CONTEXT_TEMPERATURE": 0.3,
    "NUM_BEAMS": 5,
    # UI label -> NLLB language code
    "SUPPORTED_LANGUAGES": {
        'English': 'eng_Latn',
        'Hindi': 'hin_Deva',
        'Marathi': 'mar_Deva'
    },
    # NLLB language code -> short code used to pick a grammar prompt
    "MT5_LANG_CODES": {
        'eng_Latn': 'en',
        'hin_Deva': 'hi',
        'mar_Deva': 'mr'
    },
    # The Hindi/Marathi prompts were stored as mojibake (UTF-8 bytes decoded
    # with a Cyrillic code page); they are restored to Devanagari here.
    "GRAMMAR_PROMPTS": {
        'en': "Fix grammar and improve fluency: ",
        'hi': "व्याकरण और प्रवाह सुधारें: ",
        'mr': "व्याकरण आणि प्रवाह सुधारा: "
    }
}

# st.session_state key under which the latest translation is kept so that it
# survives the script rerun triggered by clicking a download button.
_RESULT_STATE_KEY = "translated_text"


class DocumentProcessor:
    """Handles document processing and text extraction"""

    @staticmethod
    def extract_text_from_file(uploaded_file) -> str:
        """Return the text content of an uploaded PDF, DOCX or TXT file.

        Args:
            uploaded_file: File-like object exposing ``name`` (used to pick
                the extractor by extension) and the read API required by the
                selected extractor.

        Raises:
            ValueError: If the file extension is not ``.pdf``, ``.docx`` or
                ``.txt``.
        """
        file_extension = Path(uploaded_file.name).suffix.lower()
        extractors = {
            '.pdf': DocumentProcessor._extract_from_pdf,
            '.docx': DocumentProcessor._extract_from_docx,
            '.txt': DocumentProcessor._extract_from_txt,
        }
        extractor = extractors.get(file_extension)
        if extractor is None:
            raise ValueError(f"Unsupported file format: {file_extension}")
        return extractor(uploaded_file)

    @staticmethod
    def _extract_from_pdf(file) -> str:
        """Concatenate the extracted text of every PDF page, one per line."""
        pdf_reader = PyPDF2.PdfReader(file)
        # extract_text() may yield None/empty for pages without a text layer;
        # treat those as empty instead of crashing the join.
        page_texts = (page.extract_text() or "" for page in pdf_reader.pages)
        return "\n".join(page_texts).strip()

    @staticmethod
    def _extract_from_docx(file) -> str:
        """Concatenate the text of every DOCX paragraph, one per line."""
        doc = docx.Document(file)
        return "\n".join(paragraph.text for paragraph in doc.paragraphs).strip()

    @staticmethod
    def _extract_from_txt(file) -> str:
        """Decode a plain-text upload as UTF-8, dropping a leading BOM if any."""
        return file.getvalue().decode('utf-8-sig')


class TextBatcher:
    """Handles text batching with improved sentence boundary detection"""

    # Boundaries: whitespace after Latin terminators, the position right
    # after a danda / double danda (with optional whitespace), or newlines.
    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+|(?<=[।॥])\s*|\n+")

    @staticmethod
    def batch_process_text(text: str, max_length: int = CONFIG["MAX_BATCH_LENGTH"]) -> List[str]:
        """Group the sentences of ``text`` into batches of bounded size.

        Args:
            text: Text to split.
            max_length: Maximum number of characters per batch, including
                the single spaces used to join sentences.

        Returns:
            Batches in document order; empty list for blank input.

        Raises:
            ValueError: If ``max_length`` is smaller than 1.
        """
        if max_length < 1:
            raise ValueError("max_length must be at least 1")

        batches: List[str] = []
        current_batch: List[str] = []
        current_length = 0

        for sentence in TextBatcher._split_into_sentences(text):
            # A "sentence" without punctuation can be arbitrarily long; break
            # it up so the tokenizer does not silently truncate it later.
            for piece in TextBatcher._split_long_sentence(sentence, max_length):
                separator = 1 if current_batch else 0
                projected = current_length + separator + len(piece)
                if current_batch and projected > max_length:
                    batches.append(" ".join(current_batch))
                    current_batch = [piece]
                    current_length = len(piece)
                else:
                    current_batch.append(piece)
                    current_length = projected

        if current_batch:
            batches.append(" ".join(current_batch))
        return batches

    @staticmethod
    def _split_into_sentences(text: str) -> List[str]:
        """Split text into sentences with improved boundary detection

        Every boundary in the whole text is honoured in a single pass; the
        terminating punctuation stays attached to its sentence.  Blank
        fragments are dropped.
        """
        fragments = (part.strip() for part in TextBatcher._SENTENCE_BOUNDARY.split(text))
        return [fragment for fragment in fragments if fragment]

    @staticmethod
    def _split_long_sentence(sentence: str, max_length: int) -> List[str]:
        """Break a sentence longer than ``max_length`` into smaller pieces.

        Splits on whitespace where possible and hard-slices single words
        that are themselves longer than ``max_length``.
        """
        if len(sentence) <= max_length:
            return [sentence]

        pieces: List[str] = []
        current = ""
        for word in sentence.split():
            while len(word) > max_length:
                if current:
                    pieces.append(current)
                    current = ""
                pieces.append(word[:max_length])
                word = word[max_length:]
            candidate = f"{current} {word}" if current else word
            if len(candidate) > max_length:
                pieces.append(current)
                current = word
            else:
                current = candidate
        if current:
            pieces.append(current)
        return pieces


class ModelManager:
    """Manages loading and caching of AI models"""

    @staticmethod
    @st.cache_resource
    def load_models():
        """Load the three models once and cache them via Streamlit.

        Returns:
            Dict mapping ``"gemma"``, ``"nllb"`` and ``"mt5"`` to
            ``(tokenizer, model)`` tuples.

        Raises:
            Exception: Whatever the underlying loaders raise; diagnostics
                are shown in the UI before re-raising.
        """
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"

            models = {
                "gemma": ModelManager._load_gemma_model(),
                "nllb": ModelManager._load_nllb_model(),
                "mt5": ModelManager._load_mt5_model()
            }

            # With CUDA the loaders place the weights through device_map;
            # otherwise make the CPU placement explicit.
            if not torch.cuda.is_available():
                for _tokenizer, model in models.values():
                    model.to(device)

            return models
        except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            st.error(f"Python version: {sys.version}")
            st.error(f"PyTorch version: {torch.__version__}")
            raise

    @staticmethod
    def _load_pretrained(model_name: str, model_cls, **tokenizer_kwargs) -> Tuple:
        """Load ``(tokenizer, model)`` for ``model_name`` using ``model_cls``."""
        has_cuda = torch.cuda.is_available()
        token = os.environ.get('HF_TOKEN')
        # trust_remote_code is intentionally not passed: these checkpoints
        # use architectures shipped with transformers, so there is no reason
        # to allow execution of code downloaded from the hub.
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=token,
            **tokenizer_kwargs
        )
        model = model_cls.from_pretrained(
            model_name,
            token=token,
            # Half precision only on GPU; fall back to float32 on CPU where
            # float16 inference is poorly supported.
            # NOTE(review): T5-family models are reported to be numerically
            # unstable in float16 - confirm mT5 output quality on GPU.
            torch_dtype=torch.float16 if has_cuda else torch.float32,
            device_map="auto" if has_cuda else None
        )
        return (tokenizer, model)

    @staticmethod
    def _load_gemma_model():
        """Load the Gemma causal LM used for the context step."""
        return ModelManager._load_pretrained("google/gemma-2b", AutoModelForCausalLM)

    @staticmethod
    def _load_nllb_model():
        """Load the NLLB seq2seq model used for translation."""
        return ModelManager._load_pretrained(
            "facebook/nllb-200-distilled-600M",
            AutoModelForSeq2SeqLM,
            use_fast=False
        )

    @staticmethod
    def _load_mt5_model():
        """Load the mT5 model used for the grammar step."""
        return ModelManager._load_pretrained("google/mt5-base", MT5ForConditionalGeneration)


class TranslationPipeline:
    """Manages the translation pipeline with context understanding"""

    def __init__(self, models: Dict):
        """Store the ``{"gemma", "nllb", "mt5"} -> (tokenizer, model)`` mapping."""
        self.models = models

    @torch.no_grad()
    def process_text(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate ``text`` batch by batch and join the results with spaces.

        Args:
            text: Source document text.
            source_lang: NLLB code of the source language (e.g. ``eng_Latn``).
            target_lang: NLLB code of the target language.

        Returns:
            The translated text; empty string when ``text`` has no content.
        """
        # Split text into manageable batches
        batches = TextBatcher.batch_process_text(text)
        final_results = []

        for batch in batches:
            # Step 1: Context Understanding
            # NOTE(review): the model's interpretation, not the original
            # wording, is what gets translated - confirm this is intended.
            # If the model produced nothing, translate the source batch so
            # that no content is dropped.
            context = self._understand_context(batch) or batch

            # Step 2: Context-aware Translation
            translated = self._translate_with_context(context, source_lang, target_lang)

            # Step 3: Grammar Correction
            corrected = self._correct_grammar(translated, target_lang)

            final_results.append(corrected)

        return " ".join(final_results)

    def _understand_context(self, text: str) -> str:
        """Enhanced context understanding using Gemma model

        Returns only the text generated after the prompt (possibly empty).
        """
        tokenizer, model = self.models["gemma"]

        prompt = "\n".join([
            "Analyze and provide context for translation:",
            f"Text: {text}",
            "Key points to consider:",
            "- Main topic and subject matter",
            "- Cultural context and nuances",
            "- Technical terminology if any",
            "- Tone and style of writing",
            "Provide a clear and concise interpretation that maintains:",
            "1. Original meaning",
            "2. Cultural context",
            "3. Technical accuracy",
            "4. Tone and style",
        ])

        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=CONFIG["MAX_BATCH_LENGTH"],
            truncation=True
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        prompt_token_count = inputs["input_ids"].shape[1]

        outputs = model.generate(
            **inputs,
            # max_length would count the prompt tokens too and could leave
            # no room for generation; budget the new tokens explicitly.
            max_new_tokens=CONFIG["MAX_BATCH_LENGTH"],
            do_sample=True,
            temperature=CONFIG["CONTEXT_TEMPERATURE"],
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1
        )

        # Drop the echoed prompt by position rather than by string replace,
        # which fails whenever the prompt was truncated or re-normalised.
        generated_tokens = outputs[0][prompt_token_count:]
        return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    def _translate_with_context(self, text: str, source_lang: str, target_lang: str) -> str:
        """Enhanced translation using NLLB model with context awareness

        Args:
            text: Text to translate.
            source_lang: NLLB source language code.
            target_lang: NLLB target language code.

        Raises:
            ValueError: If ``target_lang`` is not a token known to the
                tokenizer.
        """
        tokenizer, model = self.models["nllb"]

        # NLLB language tokens are the bare codes (e.g. "hin_Deva").
        # NOTE(review): this mutates the shared cached tokenizer; concurrent
        # sessions with different source languages could interfere.
        tokenizer.src_lang = source_lang
        target_lang_id = tokenizer.convert_tokens_to_ids(target_lang)
        if target_lang_id is None or target_lang_id == tokenizer.unk_token_id:
            raise ValueError(f"Unknown target language code: {target_lang}")

        inputs = tokenizer(
            text,
            return_tensors="pt",
            max_length=CONFIG["MAX_BATCH_LENGTH"],
            truncation=True
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        outputs = model.generate(
            **inputs,
            forced_bos_token_id=target_lang_id,
            max_length=CONFIG["MAX_BATCH_LENGTH"],
            do_sample=True,
            temperature=CONFIG["TRANSLATION_TEMPERATURE"],
            num_beams=CONFIG["NUM_BEAMS"],
            num_return_sequences=1,
            length_penalty=1.0,
            repetition_penalty=1.2
        )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    def _correct_grammar(self, text: str, target_lang: str) -> str:
        """Enhanced grammar correction using MT5 model

        Returns the model output with any grammar prompt removed, or the
        stripped input text when the model yields nothing usable.
        """
        if not text.strip():
            return text.strip()

        tokenizer, model = self.models["mt5"]
        lang_code = CONFIG["MT5_LANG_CODES"][target_lang]
        prompt = CONFIG["GRAMMAR_PROMPTS"][lang_code]

        input_text = f"{prompt}{text}"
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            max_length=CONFIG["MAX_BATCH_LENGTH"],
            truncation=True
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        outputs = model.generate(
            **inputs,
            max_length=CONFIG["MAX_BATCH_LENGTH"],
            num_beams=CONFIG["NUM_BEAMS"],
            length_penalty=1.0,
            early_stopping=True,
            no_repeat_ngram_size=2,
            do_sample=False
        )

        corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
        for prefix in CONFIG["GRAMMAR_PROMPTS"].values():
            corrected = corrected.replace(prefix, "")

        # NOTE(review): google/mt5-base is presumably a pre-trained-only
        # checkpoint without instruction tuning - verify it actually improves
        # the text.  Never let an empty result erase the translation.
        return corrected.strip() or text.strip()


class DocumentExporter:
    """Handles document export operations"""

    @staticmethod
    def save_as_docx(text: str) -> io.BytesIO:
        """Return an in-memory DOCX file containing ``text`` as one paragraph."""
        doc = docx.Document()
        doc.add_paragraph(text)
        buffer = io.BytesIO()
        doc.save(buffer)
        buffer.seek(0)
        return buffer

    @staticmethod
    def save_as_text(text: str) -> io.BytesIO:
        """Return an in-memory UTF-8 encoded text file containing ``text``."""
        return io.BytesIO(text.encode("utf-8"))


def _render_result(final_text: str) -> None:
    """Show the translated text together with the TXT/DOCX download buttons."""
    st.markdown("### Translation Result")
    st.text_area(
        label="Translated Text",
        value=final_text,
        height=200
    )

    st.markdown("### Download Options")
    txt_col, docx_col = st.columns(2)

    with txt_col:
        st.download_button(
            label="Download as TXT",
            data=DocumentExporter.save_as_text(final_text),
            file_name="translated_document.txt",
            mime="text/plain"
        )

    with docx_col:
        st.download_button(
            label="Download as DOCX",
            data=DocumentExporter.save_as_docx(final_text),
            file_name="translated_document.docx",
            mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        )


def main():
    """Render the Streamlit UI and run the translation on demand."""
    st.title("🌐 Enhanced Document Translation App")

    # Check for HF_TOKEN
    if not os.environ.get('HF_TOKEN'):
        st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
        st.stop()

    # Display system info
    current_utc = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
    st.sidebar.markdown(
        "### System Information\n\n"
        f"**Current UTC Time:** {current_utc}\n\n"
        f"**User:** {os.environ.get('USER', 'unknown')}\n"
    )

    # Load models
    with st.spinner("Loading models... This may take a few minutes."):
        try:
            models = ModelManager.load_models()
            pipeline = TranslationPipeline(models)
        except Exception as e:
            st.error(f"Error initializing translation pipeline: {str(e)}")
            return

    # File upload
    uploaded_file = st.file_uploader(
        "Upload your document (PDF, DOCX, or TXT)",
        type=['pdf', 'docx', 'txt']
    )

    # Language selection
    language_names = list(CONFIG["SUPPORTED_LANGUAGES"].keys())
    col1, col2 = st.columns(2)

    with col1:
        source_language = st.selectbox(
            "Source Language",
            options=language_names,
            index=0
        )

    with col2:
        target_language = st.selectbox(
            "Target Language",
            options=language_names,
            index=1
        )

    if not uploaded_file:
        # A removed upload invalidates any translation shown for it.
        st.session_state.pop(_RESULT_STATE_KEY, None)

    if uploaded_file and st.button("Translate", type="primary"):
        # Forget the previous result so a failed run never shows stale text.
        st.session_state.pop(_RESULT_STATE_KEY, None)

        if source_language == target_language:
            st.warning("Source and target languages are the same. Please choose different languages.")
        else:
            try:
                progress_bar = st.progress(0)
                status_text = st.empty()

                # Process document
                status_text.text("Extracting text from document...")
                text = DocumentProcessor.extract_text_from_file(uploaded_file)
                progress_bar.progress(20)

                if not text.strip():
                    status_text.empty()
                    st.warning("No extractable text was found in the uploaded document.")
                else:
                    # Perform translation
                    status_text.text("Translating document with context understanding...")
                    final_text = pipeline.process_text(
                        text,
                        CONFIG["SUPPORTED_LANGUAGES"][source_language],
                        CONFIG["SUPPORTED_LANGUAGES"][target_language]
                    )
                    progress_bar.progress(90)

                    # Keep the result across reruns: clicking a download
                    # button reruns the script with the Translate button
                    # reported as not pressed.
                    st.session_state[_RESULT_STATE_KEY] = final_text

                    status_text.text("Translation completed successfully!")
                    progress_bar.progress(100)

            except Exception as e:
                st.error(f"An error occurred: {str(e)}")

    stored_result = st.session_state.get(_RESULT_STATE_KEY)
    if stored_result is not None:
        _render_result(stored_result)


if __name__ == "__main__":
    main()