# try / app.py
# gauravchand11's picture
# Update app.py
# 5e3207d verified
# raw
# history blame
# 14.9 kB
import streamlit as st
import PyPDF2
import docx
import io
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration
import torch
from pathlib import Path
import tempfile
from typing import Union, Tuple, List, Dict
import os
import sys
from datetime import datetime, timezone
import warnings
import json
# Suppress every UserWarning process-wide (the filter is installed at import time).
warnings.simplefilter('ignore', category=UserWarning)

# Page configuration: issued before any other Streamlit element is rendered.
st.set_page_config(
    page_title="Enhanced Document Translation App",
    page_icon="🌐",
    layout="wide",
)
# Application-wide settings shared by the batching, translation and UI code.
CONFIG = {
    # Batch size limit: used as a character budget by TextBatcher and as a
    # token limit for tokenization/generation in TranslationPipeline.
    "MAX_BATCH_LENGTH": 512,
    # NOTE(review): not referenced anywhere in this file; confirm or remove.
    "MIN_BATCH_LENGTH": 50,

    # Generation settings.
    "TRANSLATION_TEMPERATURE": 0.7,  # NLLB sampling temperature
    "CONTEXT_TEMPERATURE": 0.3,      # Gemma sampling temperature
    "NUM_BEAMS": 5,                  # beams for NLLB and mT5 generation

    # UI display name -> NLLB language code.
    "SUPPORTED_LANGUAGES": {
        'English': 'eng_Latn',
        'Hindi': 'hin_Deva',
        'Marathi': 'mar_Deva',
    },

    # NLLB language code -> short code used to pick a grammar prompt.
    "MT5_LANG_CODES": {
        'eng_Latn': 'en',
        'hin_Deva': 'hi',
        'mar_Deva': 'mr',
    },

    # Instruction prefix prepended to text before mT5 grammar correction
    # (each translates to "fix grammar and improve fluency").
    "GRAMMAR_PROMPTS": {
        'en': "Fix grammar and improve fluency: ",
        'hi': "व्याकरण और प्रवाह सुधारें: ",
        'mr': "व्याकरण आणि प्रवाह सुधारा: ",
    },
}
class DocumentProcessor:
    """Extract plain text from uploaded PDF, DOCX and TXT files."""

    @staticmethod
    def extract_text_from_file(uploaded_file) -> str:
        """Return the text content of ``uploaded_file``.

        Args:
            uploaded_file: A file-like object with a ``name`` attribute; the
                extractor is chosen from its lower-cased file suffix.

        Returns:
            The extracted text.

        Raises:
            ValueError: If the suffix is not ``.pdf``, ``.docx`` or ``.txt``.
        """
        file_extension = Path(uploaded_file.name).suffix.lower()
        extractors = {
            '.pdf': DocumentProcessor._extract_from_pdf,
            '.docx': DocumentProcessor._extract_from_docx,
            '.txt': DocumentProcessor._extract_from_txt,
        }
        extractor = extractors.get(file_extension)
        if extractor is None:
            raise ValueError(f"Unsupported file format: {file_extension}")
        return extractor(uploaded_file)

    @staticmethod
    def _extract_from_txt(file) -> str:
        """Decode a UTF-8 text upload, dropping a leading byte-order mark if present."""
        # 'utf-8-sig' decodes plain UTF-8 identically but strips a BOM, which
        # would otherwise survive as an invisible U+FEFF in the translated text.
        return file.getvalue().decode('utf-8-sig')

    @staticmethod
    def _extract_from_pdf(file) -> str:
        """Concatenate the text of every PDF page, one page per line."""
        pdf_reader = PyPDF2.PdfReader(file)
        # Defensive: treat a None result from extract_text() as an empty page
        # so str.join cannot raise TypeError on pages without a text layer.
        return "\n".join(page.extract_text() or "" for page in pdf_reader.pages).strip()

    @staticmethod
    def _extract_from_docx(file) -> str:
        """Concatenate the text of every DOCX paragraph, one paragraph per line."""
        doc = docx.Document(file)
        return "\n".join(paragraph.text for paragraph in doc.paragraphs).strip()
class TextBatcher:
    """Split text into sentence-aligned batches of bounded character length."""

    @staticmethod
    def batch_process_text(text: str, max_length: int = CONFIG["MAX_BATCH_LENGTH"]) -> List[str]:
        """Group the sentences of ``text`` into batches of at most ``max_length`` characters.

        Sentences are joined with single spaces, and that separator counts
        toward the limit. A sentence that alone exceeds ``max_length`` is first
        broken at whitespace; a single word longer than ``max_length`` is kept
        whole and becomes its own (over-long) batch.

        Args:
            text: Input text.
            max_length: Maximum characters per batch (characters, not tokens).

        Returns:
            Batches in document order; ``[]`` for empty or whitespace-only text.
        """
        batches: List[str] = []
        current_batch: List[str] = []
        current_length = 0
        for sentence in TextBatcher._split_into_sentences(text):
            for piece in TextBatcher._split_long_sentence(sentence, max_length):
                # +1 accounts for the space that " ".join inserts before piece.
                added_length = len(piece) + (1 if current_batch else 0)
                if current_batch and current_length + added_length > max_length:
                    batches.append(" ".join(current_batch))
                    current_batch = [piece]
                    current_length = len(piece)
                else:
                    current_batch.append(piece)
                    current_length += added_length
        if current_batch:
            batches.append(" ".join(current_batch))
        return batches

    @staticmethod
    def _split_long_sentence(sentence: str, max_length: int) -> List[str]:
        """Break ``sentence`` at whitespace into pieces of at most ``max_length`` characters.

        A sentence that already fits is returned unchanged. Inside an
        over-long sentence, runs of whitespace collapse to single spaces.
        """
        if len(sentence) <= max_length:
            return [sentence]
        pieces: List[str] = []
        current = ""
        for word in sentence.split():
            candidate = f"{current} {word}" if current else word
            if current and len(candidate) > max_length:
                pieces.append(current)
                current = word
            else:
                current = candidate
        if current:
            pieces.append(current)
        return pieces

    @staticmethod
    def _split_into_sentences(text: str) -> List[str]:
        """Split ``text`` into stripped, non-empty sentences.

        Each delimiter is applied to every segment produced so far, so text
        before a ". " is still split on later delimiters such as "।" or a
        newline. The stripped delimiter stays attached to the sentence it ends
        (a newline strips to "", so it adds nothing).
        """
        delimiters = ['. ', '! ', '? ', '।', '॥', '\n']
        segments = [text]
        for delimiter in delimiters:
            terminator = delimiter.strip()
            refined: List[str] = []
            for segment in segments:
                *complete, tail = segment.split(delimiter)
                for part in complete:
                    cleaned = part.strip()
                    if cleaned:
                        refined.append(cleaned + terminator)
                tail = tail.strip()
                if tail:
                    refined.append(tail)
            segments = refined
        return segments
class ModelManager:
    """Load the Gemma, NLLB and mT5 models used by the translation pipeline."""

    @st.cache_resource
    def load_models():
        """Load every model once; st.cache_resource reuses the result across reruns.

        Returns:
            Dict mapping "gemma", "nllb" and "mt5" to ``(tokenizer, model)``
            tuples.

        Raises:
            Exception: Re-raises whatever a loader raised, after showing the
                error plus the Python and PyTorch versions in the UI.
        """
        try:
            # With CUDA, device_map="auto" places the weights. Without CUDA,
            # from_pretrained leaves them on CPU, so no explicit move is needed.
            return {
                "gemma": ModelManager._load_gemma_model(),
                "nllb": ModelManager._load_nllb_model(),
                "mt5": ModelManager._load_mt5_model(),
            }
        except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            st.error(f"Python version: {sys.version}")
            st.error(f"PyTorch version: {torch.__version__}")
            raise

    @staticmethod
    def _torch_dtype():
        """Return float16 on CUDA, bfloat16 on CPU."""
        # float16 kernels for CPU are missing in older PyTorch releases (and
        # slow where present); bfloat16 keeps the same memory footprint with
        # broader CPU support. NOTE(review): switch to float32 if CPU memory
        # allows and numerical quality matters more than footprint.
        return torch.float16 if torch.cuda.is_available() else torch.bfloat16

    @staticmethod
    def _load_pair(model_cls, repo_id: str, **tokenizer_kwargs):
        """Load ``(tokenizer, model)`` for ``repo_id`` with the shared options.

        Args:
            model_cls: transformers model class whose ``from_pretrained`` is used.
            repo_id: Hugging Face Hub repository id.
            **tokenizer_kwargs: Extra arguments for ``AutoTokenizer.from_pretrained``.
        """
        token = os.environ.get('HF_TOKEN')
        tokenizer = AutoTokenizer.from_pretrained(
            repo_id,
            token=token,
            trust_remote_code=True,
            **tokenizer_kwargs,
        )
        model = model_cls.from_pretrained(
            repo_id,
            token=token,
            torch_dtype=ModelManager._torch_dtype(),
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True,
        )
        return (tokenizer, model)

    @staticmethod
    def _load_gemma_model():
        """Load the Gemma 2B causal LM and its tokenizer."""
        return ModelManager._load_pair(AutoModelForCausalLM, "google/gemma-2b")

    @staticmethod
    def _load_nllb_model():
        """Load the distilled NLLB-200 (600M) seq2seq model with its slow tokenizer."""
        return ModelManager._load_pair(
            AutoModelForSeq2SeqLM,
            "facebook/nllb-200-distilled-600M",
            use_fast=False,
        )

    @staticmethod
    def _load_mt5_model():
        """Load the mT5-base model and its tokenizer."""
        return ModelManager._load_pair(MT5ForConditionalGeneration, "google/mt5-base")
class TranslationPipeline:
    """Translate text batch by batch with NLLB, then post-edit it with mT5."""

    def __init__(self, models: Dict):
        """Store the model registry.

        Args:
            models: Mapping with "gemma", "nllb" and "mt5" keys, each holding a
                ``(tokenizer, model)`` tuple.
        """
        self.models = models

    @torch.no_grad()
    def process_text(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate ``text`` from ``source_lang`` into ``target_lang``.

        Args:
            text: Source text of any length; split via ``TextBatcher``.
            source_lang: NLLB code of the input (e.g. "eng_Latn").
            target_lang: NLLB code of the output (e.g. "hin_Deva").

        Returns:
            The translated batches joined with single spaces ("" for empty input).
        """
        final_results = []
        for batch in TextBatcher.batch_process_text(text):
            # The source batch itself is translated. NLLB has no input for
            # extra context, so passing Gemma's analysis (_understand_context)
            # here would translate that analysis instead of the document.
            translated = self._translate_with_context(batch, source_lang, target_lang)
            final_results.append(self._correct_grammar(translated, target_lang))
        return " ".join(final_results)

    def _understand_context(self, text: str) -> str:
        """Ask Gemma for an interpretation of ``text`` (topic, tone, terminology).

        Not called by ``process_text``. Returns only the newly generated text,
        never the prompt.
        """
        tokenizer, model = self.models["gemma"]
        prompt = (
            "Analyze and provide context for translation:\n"
            f"Text: {text}\n"
            "Key points to consider:\n"
            "- Main topic and subject matter\n"
            "- Cultural context and nuances\n"
            "- Technical terminology if any\n"
            "- Tone and style of writing\n"
            "Provide a clear and concise interpretation that maintains:\n"
            "1. Original meaning\n"
            "2. Cultural context\n"
            "3. Technical accuracy\n"
            "4. Tone and style"
        )
        inputs = tokenizer(prompt, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        prompt_length = inputs["input_ids"].shape[1]
        outputs = model.generate(
            **inputs,
            # max_new_tokens budgets only the continuation; max_length would
            # also count the prompt and could leave no room to generate.
            max_new_tokens=CONFIG["MAX_BATCH_LENGTH"],
            do_sample=True,
            temperature=CONFIG["CONTEXT_TEMPERATURE"],
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
        )
        # Decode only the tokens after the prompt. Removing the prompt from
        # the full decode by string replacement fails whenever the prompt was
        # truncated or does not round-trip through the tokenizer exactly.
        generated = outputs[0][prompt_length:]
        return tokenizer.decode(generated, skip_special_tokens=True).strip()

    def _translate_with_context(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate ``text`` with NLLB.

        Args:
            text: Text to translate.
            source_lang: NLLB code of ``text`` (e.g. "eng_Latn").
            target_lang: NLLB code to translate into (e.g. "mar_Deva").

        Returns:
            The decoded translation.

        Raises:
            ValueError: If either code is not a token in the NLLB vocabulary.
        """
        tokenizer, model = self.models["nllb"]
        # NLLB language tokens are the bare codes ("hin_Deva"); an unknown
        # string maps to <unk>, which would then be forced as the first
        # output token. Fail loudly instead.
        for code in (source_lang, target_lang):
            if tokenizer.convert_tokens_to_ids(code) == tokenizer.unk_token_id:
                raise ValueError(f"Unsupported NLLB language code: {code}")
        # src_lang must be set before tokenizing so the correct source-language
        # token is added to the input. NOTE(review): the tokenizer presumably
        # comes from a cached, shared loader, so this mutates shared state.
        tokenizer.src_lang = source_lang
        inputs = tokenizer(text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        target_lang_id = tokenizer.convert_tokens_to_ids(target_lang)
        # do_sample together with num_beams gives beam-search sampling, so
        # the translation is not deterministic across runs.
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=target_lang_id,
            max_length=CONFIG["MAX_BATCH_LENGTH"],
            do_sample=True,
            temperature=CONFIG["TRANSLATION_TEMPERATURE"],
            num_beams=CONFIG["NUM_BEAMS"],
            num_return_sequences=1,
            length_penalty=1.0,
            repetition_penalty=1.2,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    def _correct_grammar(self, text: str, target_lang: str) -> str:
        """Post-edit ``text`` with mT5 using a language-specific instruction prefix.

        Returns ``text`` unchanged when ``target_lang`` has no configured
        prompt, when ``text`` is blank, or when the model output is empty
        after cleanup.

        NOTE(review): the mT5 checkpoint loaded for this step ("google/mt5-base")
        is pre-trained only; its model card says it must be fine-tuned before
        downstream use, so its output may not be a corrected sentence.
        """
        lang_code = CONFIG["MT5_LANG_CODES"].get(target_lang)
        prompt = CONFIG["GRAMMAR_PROMPTS"].get(lang_code)
        if prompt is None or not text.strip():
            return text
        tokenizer, model = self.models["mt5"]
        input_text = f"{prompt}{text}"
        inputs = tokenizer(input_text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(
            **inputs,
            max_length=CONFIG["MAX_BATCH_LENGTH"],
            num_beams=CONFIG["NUM_BEAMS"],
            length_penalty=1.0,
            early_stopping=True,
            no_repeat_ngram_size=2,
            do_sample=False,
        )
        corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The model may echo an instruction prefix; strip any of them.
        for prefix in CONFIG["GRAMMAR_PROMPTS"].values():
            corrected = corrected.replace(prefix, "")
        corrected = corrected.strip()
        # An empty result would otherwise silently erase the translation.
        return corrected or text
class DocumentExporter:
    """Turn translated text into in-memory files ready for download."""

    @staticmethod
    def save_as_docx(text: str) -> io.BytesIO:
        """Build a DOCX whose single paragraph is ``text``; return it rewound to offset 0."""
        document = docx.Document()
        document.add_paragraph(text)
        output = io.BytesIO()
        document.save(output)
        output.seek(0)
        return output

    @staticmethod
    def save_as_text(text: str) -> io.BytesIO:
        """Return a buffer, positioned at the start, holding ``text`` encoded as UTF-8."""
        # A BytesIO built from initial bytes already starts at position 0.
        return io.BytesIO(text.encode())
def _render_downloads(final_text: str) -> None:
    """Offer ``final_text`` as TXT and DOCX downloads in two side-by-side columns."""
    st.markdown("### Download Options")
    col1, col2 = st.columns(2)
    with col1:
        st.download_button(
            label="Download as TXT",
            data=DocumentExporter.save_as_text(final_text),
            file_name="translated_document.txt",
            mime="text/plain",
        )
    with col2:
        st.download_button(
            label="Download as DOCX",
            data=DocumentExporter.save_as_docx(final_text),
            file_name="translated_document.docx",
            mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        )


def main():
    """Render the app: load models, accept an upload, translate it, offer downloads."""
    st.title("🌐 Enhanced Document Translation App")

    # Model downloads authenticate with HF_TOKEN; stop early with a clear
    # message rather than failing later inside the model loader.
    if not os.environ.get('HF_TOKEN'):
        st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
        st.stop()

    # Blank lines separate the entries: a single newline is a soft break in
    # Markdown and would render both entries on one line.
    st.sidebar.markdown(
        "\n"
        "### System Information\n"
        "\n"
        f"**Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}\n"
        "\n"
        f"**User:** {os.environ.get('USER', 'unknown')}\n"
    )

    with st.spinner("Loading models... This may take a few minutes."):
        try:
            models = ModelManager.load_models()
            pipeline = TranslationPipeline(models)
        except Exception as e:
            st.error(f"Error initializing translation pipeline: {str(e)}")
            return

    uploaded_file = st.file_uploader(
        "Upload your document (PDF, DOCX, or TXT)",
        type=['pdf', 'docx', 'txt'],
    )

    language_names = list(CONFIG["SUPPORTED_LANGUAGES"].keys())
    col1, col2 = st.columns(2)
    with col1:
        source_language = st.selectbox("Source Language", options=language_names, index=0)
    with col2:
        target_language = st.selectbox("Target Language", options=language_names, index=1)

    # The button is only rendered once a file has been uploaded.
    if uploaded_file and st.button("Translate", type="primary"):
        try:
            progress_bar = st.progress(0)
            status_text = st.empty()

            status_text.text("Extracting text from document...")
            text = DocumentProcessor.extract_text_from_file(uploaded_file)
            if not text.strip():
                # Without this guard an empty extraction (e.g. an image-only
                # PDF) would "translate" to nothing and report success.
                progress_bar.empty()
                status_text.empty()
                st.warning("No text could be extracted from this document (it may be a scanned or image-only file).")
                return
            progress_bar.progress(20)

            status_text.text("Translating document with context understanding...")
            final_text = pipeline.process_text(
                text,
                CONFIG["SUPPORTED_LANGUAGES"][source_language],
                CONFIG["SUPPORTED_LANGUAGES"][target_language],
            )
            progress_bar.progress(90)

            st.markdown("### Translation Result")
            st.text_area(
                label="Translated Text",
                value=final_text,
                height=200,
                key="translation_result",
            )

            _render_downloads(final_text)

            status_text.text("Translation completed successfully!")
            progress_bar.progress(100)
        except Exception as e:
            # Top-level UI boundary: report any failure instead of crashing the page.
            st.error(f"An error occurred: {str(e)}")
if __name__ == "__main__":
    # Run the app only when this file is executed as the main script.
    main()