try / app.py
gauravchand11's picture
Update app.py
1337d1b verified
raw
history blame
14.3 kB
import streamlit as st
import PyPDF2
import docx
import io
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration
import torch
from pathlib import Path
from typing import Union, Tuple, List, Dict
import os
import sys
from datetime import datetime, timezone
import warnings
import re
# Suppress every warning whose category is UserWarning.
warnings.filterwarnings(action='ignore', category=UserWarning)

# Page-level Streamlit settings: tab title, tab icon and layout mode.
_PAGE_SETTINGS = {
    "page_title": "Enhanced Document Translation App",
    "page_icon": "🌐",
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)
# Constants and Configurations

# Display name -> language code. Insertion order is the order the names are
# listed in the language select boxes.
_LANGUAGE_CODES = {
    'English': 'eng_Latn',
    'Hindi': 'hin_Deva',
    'Marathi': 'mar_Deva',
}

# Language code (as stored in _LANGUAGE_CODES) -> two-letter code used to
# look up the grammar prompt.
_SHORT_LANGUAGE_CODES = {
    'eng_Latn': 'en',
    'hin_Deva': 'hi',
    'mar_Deva': 'mr',
}

# Two-letter code -> instruction text prepended to the text that is sent to
# the grammar-correction step.
_GRAMMAR_INSTRUCTIONS = {
    'en': "Fix grammar and improve fluency: ",
    'hi': "व्याकरण और प्रवाह सुधारें: ",
    'mr': "व्याकरण आणि प्रवाह सुधारा: ",
}

# Single settings table read by the rest of the module.
CONFIG = {
    "MAX_BATCH_LENGTH": 512,
    "MIN_BATCH_LENGTH": 50,
    "TRANSLATION_TEMPERATURE": 0.7,
    "CONTEXT_TEMPERATURE": 0.3,
    "NUM_BEAMS": 5,
    "SUPPORTED_LANGUAGES": _LANGUAGE_CODES,
    "MT5_LANG_CODES": _SHORT_LANGUAGE_CODES,
    "GRAMMAR_PROMPTS": _GRAMMAR_INSTRUCTIONS,
}
class DocumentProcessor:
    """Handles document processing and text extraction.

    The extractor is chosen from the (case-insensitive) extension of the
    uploaded file's name; ``.pdf``, ``.docx`` and ``.txt`` are supported.
    """

    @staticmethod
    def extract_text_from_file(uploaded_file) -> str:
        """Return the text contained in an uploaded document.

        Args:
            uploaded_file: Object exposing a ``name`` attribute. It is handed
                unchanged to the format-specific extractor; the ``.txt``
                extractor additionally calls ``getvalue()`` on it.

        Returns:
            The extracted text.

        Raises:
            ValueError: If the file extension is not one of the supported
                formats.
        """
        file_extension = Path(uploaded_file.name).suffix.lower()
        extractors = {
            '.pdf': DocumentProcessor._extract_from_pdf,
            '.docx': DocumentProcessor._extract_from_docx,
            '.txt': DocumentProcessor._extract_from_txt,
        }
        try:
            extractor = extractors[file_extension]
        except KeyError:
            raise ValueError(f"Unsupported file format: {file_extension}") from None
        return extractor(uploaded_file)

    @staticmethod
    def _extract_from_txt(file) -> str:
        """Decode the raw bytes of a text upload as UTF-8."""
        # 'utf-8-sig' decodes plain UTF-8 exactly like 'utf-8' but also drops
        # a leading byte-order mark, which would otherwise end up as a stray
        # U+FEFF character at the start of the text sent for translation.
        return file.getvalue().decode('utf-8-sig')

    @staticmethod
    def _extract_from_pdf(file) -> str:
        """Join the text of every PDF page with newlines and strip the result."""
        pdf_reader = PyPDF2.PdfReader(file)
        # A page that yields no text (None or "") contributes an empty line
        # instead of making str.join fail on a non-string item.
        return "\n".join(
            page.extract_text() or "" for page in pdf_reader.pages
        ).strip()

    @staticmethod
    def _extract_from_docx(file) -> str:
        """Join the text of every paragraph in ``doc.paragraphs`` with newlines."""
        doc = docx.Document(file)
        return "\n".join(paragraph.text for paragraph in doc.paragraphs).strip()
class TextBatcher:
    """Handles text batching with improved sentence boundary detection"""

    # One alternation over every recognised sentence terminator. The capture
    # group makes re.split keep the terminator that ended each fragment.
    _SENTENCE_BOUNDARY = re.compile(r'(\. |! |\? |।|॥|\n)')

    @staticmethod
    def batch_process_text(text: str, max_length: int = CONFIG["MAX_BATCH_LENGTH"]) -> List[str]:
        """Group the sentences of ``text`` into space-joined batches.

        Sentences are appended to the current batch until adding the next one
        (including the joining space) would push the batch past ``max_length``
        characters; a new batch is then started.

        Args:
            text: Text to split.
            max_length: Upper bound, in characters, for a batch built from
                more than one sentence.

        Returns:
            The batches in document order. A single sentence longer than
            ``max_length`` is not split and forms a batch of its own. An
            empty or whitespace-only ``text`` yields an empty list.
        """
        batches = []
        current_batch = []
        current_length = 0
        for sentence in TextBatcher._split_into_sentences(text):
            # " ".join adds one separator before every sentence but the first,
            # so count it; otherwise a batch can exceed max_length.
            needed = len(sentence) + (1 if current_batch else 0)
            if current_batch and current_length + needed > max_length:
                batches.append(" ".join(current_batch))
                current_batch = [sentence]
                current_length = len(sentence)
            else:
                current_batch.append(sentence)
                current_length += needed
        if current_batch:
            batches.append(" ".join(current_batch))
        return batches

    @staticmethod
    def _split_into_sentences(text: str) -> List[str]:
        """Split text into sentences with improved boundary detection.

        Boundaries are ``". "``, ``"! "``, ``"? "``, the danda ``।``, the
        double danda ``॥`` and a newline. Every boundary is honoured wherever
        it occurs, so text that mixes terminators is split at each of them.
        Each returned sentence is stripped and keeps its terminator (without
        the trailing space); empty fragments are dropped.
        """
        # With a capturing group re.split returns
        # [fragment, terminator, fragment, terminator, ..., fragment].
        pieces = TextBatcher._SENTENCE_BOUNDARY.split(text)
        sentences = []
        for index in range(0, len(pieces), 2):
            body = pieces[index].strip()
            if not body:
                continue
            has_terminator = index + 1 < len(pieces)
            terminator = pieces[index + 1].strip() if has_terminator else ""
            sentences.append(body + terminator)
        return sentences
class ModelManager:
    """Manages loading and caching of AI models"""

    @staticmethod
    @st.cache_resource
    def load_models():
        """Load every model used by the translation pipeline.

        The function is wrapped in ``st.cache_resource``.

        Returns:
            Dict with the keys ``"gemma"``, ``"nllb"`` and ``"mt5"``, each
            mapping to a ``(tokenizer, model)`` tuple.

        Raises:
            Exception: Any failure while loading is reported through
                ``st.error`` (with the Python and PyTorch versions) and then
                re-raised unchanged.
        """
        try:
            models = {
                "gemma": ModelManager._load_gemma_model(),
                "nllb": ModelManager._load_nllb_model(),
                "mt5": ModelManager._load_mt5_model(),
            }
            if not torch.cuda.is_available():
                # Without CUDA the models are loaded with device_map=None, so
                # place them on the CPU explicitly.
                for _tokenizer, model in models.values():
                    model.to("cpu")
            return models
        except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            st.error(f"Python version: {sys.version}")
            st.error(f"PyTorch version: {torch.__version__}")
            # Bare raise keeps the original traceback intact.
            raise

    @staticmethod
    def _load_pair(model_cls, model_name, **tokenizer_kwargs):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_cls: Model class whose ``from_pretrained`` is called.
            model_name: Hub identifier passed to both ``from_pretrained`` calls.
            **tokenizer_kwargs: Extra keyword arguments for the tokenizer only.

        Returns:
            ``(tokenizer, model)`` tuple.
        """
        token = os.environ.get('HF_TOKEN')
        use_cuda = torch.cuda.is_available()
        # trust_remote_code is intentionally not enabled: it allows code from
        # the model repository to run locally, and these three checkpoints are
        # loaded through classes that ship with transformers.
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=token,
            **tokenizer_kwargs,
        )
        model = model_cls.from_pretrained(
            model_name,
            token=token,
            # Half precision is only requested on CUDA; on CPU many float16
            # operations are slow or not implemented, so use float32 there.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
        )
        return (tokenizer, model)

    @staticmethod
    def _load_gemma_model():
        """Return ``(tokenizer, model)`` for ``google/gemma-2b``."""
        return ModelManager._load_pair(AutoModelForCausalLM, "google/gemma-2b")

    @staticmethod
    def _load_nllb_model():
        """Return ``(tokenizer, model)`` for ``facebook/nllb-200-distilled-600M``."""
        return ModelManager._load_pair(
            AutoModelForSeq2SeqLM,
            "facebook/nllb-200-distilled-600M",
            use_fast=False,
        )

    @staticmethod
    def _load_mt5_model():
        """Return ``(tokenizer, model)`` for ``google/mt5-base``."""
        # NOTE(review): T5-family models are reported to be numerically
        # unstable in float16 — confirm output quality on GPU.
        return ModelManager._load_pair(MT5ForConditionalGeneration, "google/mt5-base")
class TranslationPipeline:
    """Manages the translation pipeline with context understanding.

    Each batch of text goes through three steps: a context/interpretation
    pass (``models["gemma"]``), translation (``models["nllb"]``) and grammar
    correction (``models["mt5"]``).
    """

    def __init__(self, models: Dict):
        """Store the model table.

        Args:
            models: Dict with the keys ``"gemma"``, ``"nllb"`` and ``"mt5"``,
                each mapping to a ``(tokenizer, model)`` tuple.
        """
        self.models = models

    @torch.no_grad()
    def process_text(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate ``text`` batch by batch and return the cleaned result.

        Args:
            text: Full document text.
            source_lang: Language code of ``text`` (a value of
                ``CONFIG["SUPPORTED_LANGUAGES"]``).
            target_lang: Language code to translate into (same code set).

        Returns:
            The per-batch results joined with single spaces and passed
            through ``_clean_text``. Empty input yields an empty string.
        """
        batches = TextBatcher.batch_process_text(text)
        final_results = []
        for batch in batches:
            # Step 1: Context Understanding. If the model produces nothing,
            # translate the original batch instead of losing its content.
            context = self._understand_context(batch) or batch
            # Step 2: Context-aware Translation
            translated = self._translate_with_context(
                context,
                source_lang,
                target_lang
            )
            # Step 3: Grammar Correction
            corrected = self._correct_grammar(
                translated,
                target_lang
            )
            final_results.append(corrected)
        # Clean up the final text
        final_text = " ".join(final_results)
        return self._clean_text(final_text)

    def _understand_context(self, text: str) -> str:
        """Generate an interpretation of ``text`` with the "gemma" model.

        Returns:
            Only the newly generated continuation, stripped; the prompt is
            never part of the return value. May be an empty string.
        """
        tokenizer, model = self.models["gemma"]
        prompt = f"""Analyze and provide context for translation:
Text: {text}
Key points to consider:
- Main topic and subject matter
- Cultural context and nuances
- Technical terminology if any
- Tone and style of writing
Provide a clear and concise interpretation that maintains:
1. Original meaning
2. Cultural context
3. Technical accuracy
4. Tone and style"""
        inputs = tokenizer(prompt, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        prompt_token_count = inputs["input_ids"].shape[1]
        outputs = model.generate(
            **inputs,
            # max_new_tokens budgets the continuation only. max_length would
            # also count the prompt, leaving no room to generate when the
            # prompt already fills MAX_BATCH_LENGTH tokens.
            max_new_tokens=CONFIG["MAX_BATCH_LENGTH"],
            do_sample=True,
            temperature=CONFIG["CONTEXT_TEMPERATURE"],
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1
        )
        # Drop the prompt by token position. Removing it by string match
        # fails whenever the prompt was truncated or decodes differently,
        # which would send the prompt itself on to the translator.
        generated_tokens = outputs[0][prompt_token_count:]
        return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    def _translate_with_context(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate ``text`` from ``source_lang`` to ``target_lang`` with the "nllb" model.

        Raises:
            ValueError: If ``target_lang`` is not a token known to the
                tokenizer.
        """
        tokenizer, model = self.models["nllb"]
        # Tell the tokenizer which language the input is written in.
        tokenizer.src_lang = source_lang
        inputs = tokenizer(text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        # The language code itself (e.g. "hin_Deva") is the vocabulary token;
        # wrapping it in underscores resolves to the unknown-token id.
        target_lang_id = tokenizer.convert_tokens_to_ids(target_lang)
        if target_lang_id is None or target_lang_id == tokenizer.unk_token_id:
            raise ValueError(f"Unsupported target language code: {target_lang}")
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=target_lang_id,
            max_length=CONFIG["MAX_BATCH_LENGTH"],
            do_sample=True,
            temperature=CONFIG["TRANSLATION_TEMPERATURE"],
            num_beams=CONFIG["NUM_BEAMS"],
            num_return_sequences=1,
            length_penalty=1.0,
            repetition_penalty=1.2
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    def _correct_grammar(self, text: str, target_lang: str) -> str:
        """Run ``text`` through the "mt5" model with a grammar-fix prompt.

        Returns:
            The cleaned model output, or the cleaned input ``text`` when the
            model output is empty after cleaning.

        Raises:
            KeyError: If ``target_lang`` has no entry in
                ``CONFIG["MT5_LANG_CODES"]``.
        """
        # NOTE(review): whether the loaded checkpoint follows this
        # instruction-style prompt cannot be verified from this file.
        tokenizer, model = self.models["mt5"]
        lang_code = CONFIG["MT5_LANG_CODES"][target_lang]
        prompt = CONFIG["GRAMMAR_PROMPTS"][lang_code]
        input_text = f"{prompt}{text}"
        inputs = tokenizer(input_text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(
            **inputs,
            max_length=CONFIG["MAX_BATCH_LENGTH"],
            num_beams=CONFIG["NUM_BEAMS"],
            length_penalty=1.0,
            early_stopping=True,
            no_repeat_ngram_size=2,
            do_sample=False
        )
        corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
        corrected = self._clean_text(corrected.replace(prompt, "").strip())
        # Never let an empty correction erase an existing translation.
        return corrected or self._clean_text(text)

    def _clean_text(self, text: str) -> str:
        """Clean up the text by removing special tokens and fixing formatting"""
        # Remove MT5 special tokens
        text = re.sub(r'<extra_id_\d+>', '', text)
        # Fix multiple spaces
        text = re.sub(r'\s+', ' ', text)
        # Fix punctuation spacing
        text = re.sub(r'\s+([.,!?।॥])', r'\1', text)
        return text.strip()
class DocumentExporter:
    """Handles document export operations"""

    @staticmethod
    def save_as_docx(text: str) -> io.BytesIO:
        """Write ``text`` as a single paragraph into an in-memory DOCX file.

        Returns:
            A ``BytesIO`` holding the saved document, positioned at offset 0.
        """
        document = docx.Document()
        document.add_paragraph(text)

        stream = io.BytesIO()
        document.save(stream)
        # Rewind so the consumer reads from the first byte.
        stream.seek(0)
        return stream
def main():
    """Render the Streamlit page and run the translation workflow.

    Shows the title and a sidebar info panel, stops if ``HF_TOKEN`` is not
    set, loads the models, then offers a file upload and source/target
    language pickers. When a file is uploaded and "Translate" is pressed the
    document text is extracted, translated, displayed, and offered as a DOCX
    download. Errors are reported on the page with ``st.error``.
    """
    st.title("🌐 Enhanced Document Translation App")
    # Display system info
    st.sidebar.markdown(f"""
### System Information
**Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
**User:** {os.environ.get('USER', 'gauravchand')}
""")
    # Check for HF_TOKEN
    if not os.environ.get('HF_TOKEN'):
        st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
        st.stop()
    # Load models
    with st.spinner("Loading models... This may take a few minutes."):
        try:
            models = ModelManager.load_models()
            pipeline = TranslationPipeline(models)
        except Exception as e:
            st.error(f"Error initializing translation pipeline: {str(e)}")
            return
    # File upload
    uploaded_file = st.file_uploader(
        "Upload your document (PDF, DOCX, or TXT)",
        type=['pdf', 'docx', 'txt']
    )
    # Language selection
    col1, col2 = st.columns(2)
    with col1:
        source_language = st.selectbox(
            "Source Language",
            options=list(CONFIG["SUPPORTED_LANGUAGES"].keys()),
            index=0
        )
    with col2:
        target_language = st.selectbox(
            "Target Language",
            options=list(CONFIG["SUPPORTED_LANGUAGES"].keys()),
            index=1
        )
    if uploaded_file and st.button("Translate", type="primary"):
        try:
            progress_bar = st.progress(0)
            status_text = st.empty()
            # Process document
            status_text.text("Extracting text from document...")
            text = DocumentProcessor.extract_text_from_file(uploaded_file)
            # A document without extractable text would otherwise run the
            # whole pipeline and report success with an empty result.
            if not text.strip():
                progress_bar.empty()
                status_text.empty()
                st.warning("No extractable text was found in the uploaded document.")
                return
            progress_bar.progress(20)
            # Perform translation
            status_text.text("Translating document with context understanding...")
            final_text = pipeline.process_text(
                text,
                CONFIG["SUPPORTED_LANGUAGES"][source_language],
                CONFIG["SUPPORTED_LANGUAGES"][target_language]
            )
            progress_bar.progress(90)
            # Display result
            st.markdown("### Translation Result")
            st.text_area(
                label="Translated Text",
                value=final_text,
                height=200,
                key="translation_result"
            )
            # Download option
            st.markdown("### Download Option")
            st.download_button(
                label="Download as DOCX",
                data=DocumentExporter.save_as_docx(final_text),
                file_name="translated_document.docx",
                mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
            )
            status_text.text("Translation completed successfully!")
            progress_bar.progress(100)
        except Exception as e:
            # Top-level UI boundary: show the failure instead of crashing the page.
            st.error(f"An error occurred: {str(e)}")


# Run the app when this file is executed as the entry script.
if __name__ == "__main__":
    main()