import streamlit as st
import PyPDF2
import docx
import io
import re
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration
import torch
from pathlib import Path
from typing import List, Dict
import os
import sys
from datetime import datetime, timezone
import warnings

# Filter warnings
warnings.filterwarnings('ignore', category=UserWarning)
# Page config
st.set_page_config(
    page_title="Enhanced Document Translation App",
    page_icon="🌐",
    layout="wide"
)
# Constants and configuration
CONFIG = {
    "MAX_BATCH_LENGTH": 512,
    "MIN_BATCH_LENGTH": 50,
    "TRANSLATION_TEMPERATURE": 0.7,
    "CONTEXT_TEMPERATURE": 0.3,
    "NUM_BEAMS": 5,
    "SUPPORTED_LANGUAGES": {
        'English': 'eng_Latn',
        'Hindi': 'hin_Deva',
        'Marathi': 'mar_Deva'
    },
    "MT5_LANG_CODES": {
        'eng_Latn': 'en',
        'hin_Deva': 'hi',
        'mar_Deva': 'mr'
    },
    "GRAMMAR_PROMPTS": {
        'en': "Fix grammar and improve fluency: ",
        'hi': "व्याकरण और प्रवाह सुधारें: ",
        'mr': "व्याकरण आणि प्रवाह सुधारा: "
    }
}
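
# The SUPPORTED_LANGUAGES values are FLORES-200 codes, which is the format the
# NLLB tokenizer expects, e.g.:
#   CONFIG["SUPPORTED_LANGUAGES"]["Hindi"]   # -> 'hin_Deva' (FLORES-200)
#   CONFIG["MT5_LANG_CODES"]['hin_Deva']     # -> 'hi' (ISO 639-1, for prompts)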
class DocumentProcessor:
    """Handles document processing and text extraction"""

    @staticmethod
    def extract_text_from_file(uploaded_file) -> str:
        file_extension = Path(uploaded_file.name).suffix.lower()
        extractors = {
            '.pdf': DocumentProcessor._extract_from_pdf,
            '.docx': DocumentProcessor._extract_from_docx,
            '.txt': lambda f: f.getvalue().decode('utf-8')
        }
        if file_extension not in extractors:
            raise ValueError(f"Unsupported file format: {file_extension}")
        return extractors[file_extension](uploaded_file)

    @staticmethod
    def _extract_from_pdf(file) -> str:
        pdf_reader = PyPDF2.PdfReader(file)
        # extract_text() can return None for pages without a text layer
        return "\n".join(page.extract_text() or "" for page in pdf_reader.pages).strip()

    @staticmethod
    def _extract_from_docx(file) -> str:
        doc = docx.Document(file)
        return "\n".join(paragraph.text for paragraph in doc.paragraphs).strip()
class TextBatcher:
    """Handles text batching with sentence boundary detection"""

    @staticmethod
    def batch_process_text(text: str, max_length: int = CONFIG["MAX_BATCH_LENGTH"]) -> List[str]:
        sentences = TextBatcher._split_into_sentences(text)
        batches = []
        current_batch = []
        current_length = 0
        for sentence in sentences:
            sentence_length = len(sentence)
            if current_length + sentence_length > max_length:
                if current_batch:
                    batches.append(" ".join(current_batch))
                current_batch = [sentence]
                current_length = sentence_length
            else:
                current_batch.append(sentence)
                current_length += sentence_length
        if current_batch:
            batches.append(" ".join(current_batch))
        return batches

    @staticmethod
    def _split_into_sentences(text: str) -> List[str]:
        """Split text into sentences, keeping terminal punctuation.

        Splits after Latin sentence enders (., !, ?) and the Devanagari
        danda (।, ॥), and additionally on newlines.
        """
        parts = re.split(r'(?<=[.!?।॥])\s+|\n+', text)
        return [part.strip() for part in parts if part and part.strip()]
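
# Illustrative example of the batching behavior above: with max_length=4,
#   TextBatcher.batch_process_text("A. B. C.", max_length=4)
# splits into sentences ["A.", "B.", "C."] and packs them greedily into
# ["A. B.", "C."] (lengths are counted per sentence, not per joined batch).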
class ModelManager:
    """Manages loading and caching of AI models"""

    @staticmethod
    @st.cache_resource  # cache across Streamlit reruns so models load once
    def load_models():
        try:
            # Each loader below picks its own dtype/device; without CUDA the
            # models stay on CPU in float32, so no explicit .to() is needed.
            models = {
                "gemma": ModelManager._load_gemma_model(),
                "nllb": ModelManager._load_nllb_model(),
                "mt5": ModelManager._load_mt5_model()
            }
            return models
        except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            st.error(f"Python version: {sys.version}")
            st.error(f"PyTorch version: {torch.__version__}")
            raise
    @staticmethod
    def _load_gemma_model():
        tokenizer = AutoTokenizer.from_pretrained(
            "google/gemma-2b",
            token=os.environ.get('HF_TOKEN'),
            trust_remote_code=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2b",
            token=os.environ.get('HF_TOKEN'),
            # float16 only makes sense on GPU; CPU inference is safer in float32
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True
        )
        return (tokenizer, model)

    @staticmethod
    def _load_nllb_model():
        tokenizer = AutoTokenizer.from_pretrained(
            "facebook/nllb-200-distilled-600M",
            token=os.environ.get('HF_TOKEN'),
            use_fast=False,
            trust_remote_code=True
        )
        model = AutoModelForSeq2SeqLM.from_pretrained(
            "facebook/nllb-200-distilled-600M",
            token=os.environ.get('HF_TOKEN'),
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True
        )
        return (tokenizer, model)

    @staticmethod
    def _load_mt5_model():
        tokenizer = AutoTokenizer.from_pretrained(
            "google/mt5-base",
            token=os.environ.get('HF_TOKEN'),
            trust_remote_code=True
        )
        model = MT5ForConditionalGeneration.from_pretrained(
            "google/mt5-base",
            token=os.environ.get('HF_TOKEN'),
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True
        )
        return (tokenizer, model)
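
# Note: google/gemma-2b is a gated checkpoint, so HF_TOKEN must belong to an
# account that has accepted the Gemma license. Rough half-precision footprints
# (ballpark assumptions, not measurements): Gemma-2B ≈ 5 GB, NLLB-600M ≈ 1.2 GB,
# mT5-base ≈ 1.2 GB, so a small CPU Space may not fit all three at once.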
class TranslationPipeline:
    """Manages the translation pipeline with context understanding"""

    def __init__(self, models: Dict):
        self.models = models

    def process_text(self, text: str, source_lang: str, target_lang: str) -> str:
        # Split text into manageable batches
        batches = TextBatcher.batch_process_text(text)
        final_results = []
        for batch in batches:
            # Step 1: Context understanding. NLLB cannot condition on auxiliary
            # context, so this analysis is advisory only; the original batch
            # text (not the analysis) is what gets translated.
            context = self._understand_context(batch)
            # Step 2: Translation of the original batch text
            translated = self._translate_with_context(
                batch,
                source_lang,
                target_lang
            )
            # Step 3: Grammar correction
            corrected = self._correct_grammar(
                translated,
                target_lang
            )
            final_results.append(corrected)
        return " ".join(final_results)
    def _understand_context(self, text: str) -> str:
        """Context analysis using the Gemma model"""
        tokenizer, model = self.models["gemma"]
        prompt = f"""Analyze and provide context for translation:

Text: {text}

Key points to consider:
- Main topic and subject matter
- Cultural context and nuances
- Technical terminology if any
- Tone and style of writing

Provide a clear and concise interpretation that maintains:
1. Original meaning
2. Cultural context
3. Technical accuracy
4. Tone and style"""
        inputs = tokenizer(prompt, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(
            **inputs,
            # max_new_tokens bounds the completion itself; with max_length, a
            # long prompt could leave no room for generated tokens
            max_new_tokens=CONFIG["MAX_BATCH_LENGTH"],
            do_sample=True,
            temperature=CONFIG["CONTEXT_TEMPERATURE"],
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1
        )
        # Decode only the newly generated tokens, not the echoed prompt
        completion = outputs[0][inputs["input_ids"].shape[1]:]
        return tokenizer.decode(completion, skip_special_tokens=True).strip()
    def _translate_with_context(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translation using the NLLB model"""
        tokenizer, model = self.models["nllb"]
        # NLLB expects plain FLORES-200 codes (e.g. 'eng_Latn'); the
        # '___xxx___' wrapping belongs to other multilingual models
        tokenizer.src_lang = source_lang
        inputs = tokenizer(text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        target_lang_id = tokenizer.convert_tokens_to_ids(target_lang)
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=target_lang_id,
            max_length=CONFIG["MAX_BATCH_LENGTH"],
            do_sample=True,
            temperature=CONFIG["TRANSLATION_TEMPERATURE"],
            num_beams=CONFIG["NUM_BEAMS"],
            num_return_sequences=1,
            length_penalty=1.0,
            repetition_penalty=1.2
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
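
    # Illustrative usage (assumes the models are loaded):
    #   pipeline._translate_with_context("Hello, world!", "eng_Latn", "hin_Deva")
    # forced_bos_token_id makes the decoder start in the target language;
    # convert_tokens_to_ids works because the NLLB tokenizer registers the
    # FLORES-200 codes as special tokens.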
    def _correct_grammar(self, text: str, target_lang: str) -> str:
        """Grammar correction using the MT5 model"""
        tokenizer, model = self.models["mt5"]
        lang_code = CONFIG["MT5_LANG_CODES"][target_lang]
        prompt = CONFIG["GRAMMAR_PROMPTS"][lang_code]
        input_text = f"{prompt}{text}"
        inputs = tokenizer(input_text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(
            **inputs,
            max_length=CONFIG["MAX_BATCH_LENGTH"],
            num_beams=CONFIG["NUM_BEAMS"],
            length_penalty=1.0,
            early_stopping=True,
            no_repeat_ngram_size=2,
            do_sample=False
        )
        corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip any echoed instruction prefix from the output
        for prefix in CONFIG["GRAMMAR_PROMPTS"].values():
            corrected = corrected.replace(prefix, "")
        return corrected.strip()
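
# Caveat: google/mt5-base is a span-corruption pretrained checkpoint with no
# instruction tuning, so this prompt-based correction step assumes a
# grammar-correction fine-tune; the raw checkpoint may emit <extra_id_N>
# sentinel tokens instead of corrected text.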
class DocumentExporter:
    """Handles document export operations"""

    @staticmethod
    def save_as_docx(text: str) -> io.BytesIO:
        doc = docx.Document()
        doc.add_paragraph(text)
        buffer = io.BytesIO()
        doc.save(buffer)
        buffer.seek(0)
        return buffer

    @staticmethod
    def save_as_text(text: str) -> io.BytesIO:
        buffer = io.BytesIO()
        buffer.write(text.encode('utf-8'))
        buffer.seek(0)
        return buffer
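
# Illustrative usage: both exporters return an in-memory buffer that
# st.download_button accepts directly as its data argument, e.g.
#   buffer = DocumentExporter.save_as_text("नमस्ते")  # UTF-8 bytes, no temp file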
def main():
    st.title("🌐 Enhanced Document Translation App")

    # Check for HF_TOKEN
    if not os.environ.get('HF_TOKEN'):
        st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
        st.stop()

    # Display system info
    st.sidebar.markdown(f"""
    ### System Information
    **Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
    **User:** {os.environ.get('USER', 'unknown')}
    """)

    # Load models
    with st.spinner("Loading models... This may take a few minutes."):
        try:
            models = ModelManager.load_models()
            pipeline = TranslationPipeline(models)
        except Exception as e:
            st.error(f"Error initializing translation pipeline: {str(e)}")
            return

    # File upload
    uploaded_file = st.file_uploader(
        "Upload your document (PDF, DOCX, or TXT)",
        type=['pdf', 'docx', 'txt']
    )

    # Language selection
    col1, col2 = st.columns(2)
    with col1:
        source_language = st.selectbox(
            "Source Language",
            options=list(CONFIG["SUPPORTED_LANGUAGES"].keys()),
            index=0
        )
    with col2:
        target_language = st.selectbox(
            "Target Language",
            options=list(CONFIG["SUPPORTED_LANGUAGES"].keys()),
            index=1
        )

    if uploaded_file and st.button("Translate", type="primary"):
        if source_language == target_language:
            st.warning("Source and target languages are the same; nothing to translate.")
            return
        try:
            progress_bar = st.progress(0)
            status_text = st.empty()

            # Process document
            status_text.text("Extracting text from document...")
            text = DocumentProcessor.extract_text_from_file(uploaded_file)
            progress_bar.progress(20)

            # Perform translation
            status_text.text("Translating document with context understanding...")
            final_text = pipeline.process_text(
                text,
                CONFIG["SUPPORTED_LANGUAGES"][source_language],
                CONFIG["SUPPORTED_LANGUAGES"][target_language]
            )
            progress_bar.progress(90)

            # Display result
            st.markdown("### Translation Result")
            st.text_area(
                label="Translated Text",
                value=final_text,
                height=200,
                key="translation_result"
            )

            # Download options
            st.markdown("### Download Options")
            col1, col2 = st.columns(2)
            with col1:
                st.download_button(
                    label="Download as TXT",
                    data=DocumentExporter.save_as_text(final_text),
                    file_name="translated_document.txt",
                    mime="text/plain"
                )
            with col2:
                st.download_button(
                    label="Download as DOCX",
                    data=DocumentExporter.save_as_docx(final_text),
                    file_name="translated_document.docx",
                    mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
                )

            status_text.text("Translation completed successfully!")
            progress_bar.progress(100)

        except Exception as e:
            st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()