try / app.py
gauravchand11's picture
Update app.py
bcda6d5 verified
raw
history blame
12.2 kB
import io
import os
import re
import sys
import tempfile
import warnings
from datetime import datetime, timezone
from pathlib import Path
from typing import Tuple, Union

import docx
import PyPDF2
import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    MT5ForConditionalGeneration,
)
# Hide every UserWarning emitted while the app runs.
warnings.filterwarnings("ignore", category=UserWarning)

# Streamlit page configuration (Streamlit requires this to be the first st.* call).
st.set_page_config(page_title="Document Translation App", page_icon="🌐", layout="wide")

# Sidebar panel: current UTC timestamp and the USER environment variable
# (falls back to a fixed name when USER is unset).
st.sidebar.markdown(
    "\n".join(
        [
            "",
            "### System Information",
            f"**Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}",
            f"**User:** {os.environ.get('USER', 'gauravchand')}",
            "",
        ]
    )
)

# Hugging Face access token, passed to every from_pretrained call in load_models.
# Without it the app shows an error and halts this script run.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
    st.stop()
# One row per supported language: (UI display name, NLLB-200 code, mT5 prompt key).
# Both lookup tables below are derived from this single source so they cannot drift.
_LANGUAGE_TABLE = (
    ("English", "eng_Latn", "en"),
    ("Hindi", "hin_Deva", "hi"),
    ("Marathi", "mar_Deva", "mr"),
)

# Display name -> NLLB language code (insertion order drives the selectbox order).
SUPPORTED_LANGUAGES = {name: nllb_code for name, nllb_code, _ in _LANGUAGE_TABLE}

# NLLB language code -> key used to pick the grammar-correction prompt for mT5.
MT5_LANG_CODES = {nllb_code: mt5_code for _, nllb_code, mt5_code in _LANGUAGE_TABLE}
def get_nllb_lang_token(lang_code: str) -> str:
    """Return the NLLB tokenizer token for a language code.

    NLLB-200 tokenizers register each language code itself (e.g. "hin_Deva")
    as a vocabulary token, so the code is returned unchanged. A decorated form
    such as "___hin_Deva___" is not in the vocabulary and would be converted
    to the <unk> id by ``convert_tokens_to_ids``.

    Args:
        lang_code: NLLB language code such as "eng_Latn" or "mar_Deva".

    Returns:
        The token string to look up with ``tokenizer.convert_tokens_to_ids``.
    """
    return lang_code
def extract_text_from_file(uploaded_file) -> str:
    """Return the text content of an uploaded PDF, DOCX or TXT file.

    The format is chosen from the file extension, compared case-insensitively.

    Args:
        uploaded_file: File-like object exposing ``.name``. For ``.txt`` it must
            also expose ``.getvalue()`` returning bytes (a Streamlit upload or a
            named ``io.BytesIO`` both qualify).

    Returns:
        The extracted text.

    Raises:
        ValueError: If the extension is not .pdf, .docx or .txt.
        UnicodeDecodeError: If a .txt file is not valid UTF-8.
    """
    file_extension = Path(uploaded_file.name).suffix.lower()
    if file_extension == '.pdf':
        return extract_from_pdf(uploaded_file)
    if file_extension == '.docx':
        return extract_from_docx(uploaded_file)
    if file_extension == '.txt':
        # 'utf-8-sig' decodes plain UTF-8 identically but also drops a leading
        # byte-order mark, which would otherwise reach the models as '\ufeff'.
        return uploaded_file.getvalue().decode('utf-8-sig')
    raise ValueError(f"Unsupported file format: {file_extension}")
def extract_from_pdf(file) -> str:
    """Return the text of every page of a PDF, pages separated by newlines.

    Args:
        file: Path or binary file-like object accepted by ``PyPDF2.PdfReader``.

    Returns:
        Page texts joined with "\\n" and stripped of surrounding whitespace.
    """
    pdf_reader = PyPDF2.PdfReader(file)
    # `or ""` keeps the join safe if extract_text() yields None for a page.
    page_texts = [page.extract_text() or "" for page in pdf_reader.pages]
    # A single join avoids rebuilding the string once per page.
    return "\n".join(page_texts).strip()
def extract_from_docx(file) -> str:
    """Return the body-paragraph text of a DOCX file, one paragraph per line.

    Only ``Document.paragraphs`` is read, so text inside tables is not included.

    Args:
        file: Path or binary file-like object accepted by ``docx.Document``.

    Returns:
        Paragraph texts joined with "\\n" and stripped of surrounding whitespace.
    """
    doc = docx.Document(file)
    # A single join avoids rebuilding the string once per paragraph.
    return "\n".join(paragraph.text for paragraph in doc.paragraphs).strip()
def batch_process_text(text: str, max_length: int = 512) -> list:
    """Split text into space-joined word batches of at most ``max_length`` characters.

    Words are never split. A single word longer than ``max_length`` becomes a
    batch of its own; no empty batches are produced.

    Args:
        text: Input text; any run of whitespace is treated as one separator.
        max_length: Maximum character length of each joined batch.

    Returns:
        List of batch strings. Empty or whitespace-only input yields ``[]``.
    """
    batches = []
    current_batch = []
    current_length = 0  # always equals len(" ".join(current_batch))
    for word in text.split():
        # A joining space is needed only when the batch already has words.
        added = len(word) + (1 if current_batch else 0)
        # Only flush a non-empty batch; flushing an empty one would emit "".
        if current_batch and current_length + added > max_length:
            batches.append(" ".join(current_batch))
            current_batch = [word]
            current_length = len(word)
        else:
            current_batch.append(word)
            current_length += added
    if current_batch:
        batches.append(" ".join(current_batch))
    return batches
@st.cache_resource
def load_models():
    """Load the Gemma, NLLB and mT5 tokenizer/model pairs once and cache them.

    ``st.cache_resource`` keeps the returned objects across reruns, so the
    download and initialisation happen only on the first call.

    Returns:
        Three ``(tokenizer, model)`` tuples, in order:
        Gemma (context interpretation), NLLB (translation), mT5 (grammar correction).

    Raises:
        Exception: Any loading error is reported in the UI (with Python and
            PyTorch versions) and then re-raised.
    """
    try:
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        # Half precision only on GPU. Many CPU kernels lack float16
        # implementations (or are very slow), so generation on CPU needs float32.
        dtype = torch.float16 if use_cuda else torch.float32
        device_map = "auto" if use_cuda else None

        def _load_pair(repo_id, model_cls, **tokenizer_kwargs):
            """Load one tokenizer/model pair from the Hub with the shared settings."""
            # NOTE(review): trust_remote_code=True lets the Hub repo execute its own
            # Python code. These checkpoints appear to load with built-in
            # transformers classes, so the flag can likely be dropped — confirm.
            tokenizer = AutoTokenizer.from_pretrained(
                repo_id,
                token=HF_TOKEN,
                trust_remote_code=True,
                **tokenizer_kwargs
            )
            model = model_cls.from_pretrained(
                repo_id,
                token=HF_TOKEN,
                torch_dtype=dtype,
                device_map=device_map,
                trust_remote_code=True
            )
            if not use_cuda:
                # Without device_map the model must be placed explicitly.
                model = model.to(device)
            return tokenizer, model

        gemma_pair = _load_pair("google/gemma-2b", AutoModelForCausalLM)
        # The slow (sentencepiece) NLLB tokenizer is requested explicitly.
        nllb_pair = _load_pair(
            "facebook/nllb-200-distilled-600M", AutoModelForSeq2SeqLM, use_fast=False
        )
        mt5_pair = _load_pair("google/mt5-base", MT5ForConditionalGeneration)
        return gemma_pair, nllb_pair, mt5_pair
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error(f"Python version: {sys.version}")
        st.error(f"PyTorch version: {torch.__version__}")
        raise
@torch.no_grad()
def interpret_context(text: str, gemma_tuple: Tuple) -> str:
    """Run each text batch through Gemma with a "maintain the core meaning" prompt.

    NOTE(review): google/gemma-2b is the base (not instruction-tuned) checkpoint,
    so the output is a sampled continuation of the prompt rather than a
    guaranteed restatement. Confirm this step is wanted before translation.

    Args:
        text: Source text; split into word-boundary batches by batch_process_text.
        gemma_tuple: ``(tokenizer, model)`` pair as returned by load_models().

    Returns:
        Generated text for all batches, joined by single spaces. A batch whose
        generation decodes to an empty string is passed through unchanged, so
        source content is not silently dropped.
    """
    tokenizer, model = gemma_tuple
    interpreted_batches = []
    for batch in batch_process_text(text):
        prompt = f"""Analyze and maintain the core meaning of this text: {batch}"""
        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[-1]
        outputs = model.generate(
            **inputs,
            # max_new_tokens bounds only the generated part. max_length would also
            # count the (up to 512-token) prompt and could leave no room to generate.
            max_new_tokens=512,
            do_sample=True,
            temperature=0.3,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1
        )
        # A causal LM's output starts with the prompt tokens. Slice them off by
        # position instead of string-replacing the prompt, which fails whenever
        # decode() does not reproduce the prompt text exactly (e.g. after truncation).
        generated_ids = outputs[0][prompt_len:]
        interpreted_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        interpreted_batches.append(interpreted_text or batch)
    return " ".join(interpreted_batches)
@torch.no_grad()
def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
    """Translate text from ``source_lang`` to ``target_lang`` with NLLB.

    Args:
        text: Text to translate; split into word-boundary batches by batch_process_text.
        source_lang: NLLB language code of the input, e.g. "eng_Latn".
        target_lang: NLLB language code to translate into, e.g. "hin_Deva".
        nllb_tuple: ``(tokenizer, model)`` pair as returned by load_models().

    Returns:
        Translated batches joined by single spaces.

    Raises:
        ValueError: If ``target_lang`` is not a token in the tokenizer vocabulary.
    """
    tokenizer, model = nllb_tuple
    # NLLB language codes are themselves vocabulary tokens ("hin_Deva").
    # Look the id up once, outside the batch loop, and reject unknown codes
    # instead of silently forcing <unk> as the first decoder token.
    target_lang_id = tokenizer.convert_tokens_to_ids(target_lang)
    if target_lang_id == tokenizer.unk_token_id:
        raise ValueError(f"Unknown NLLB target language code: {target_lang}")
    # The NLLB tokenizer tags each input with its src_lang token. Set it to the
    # real source language so non-English input is not tagged with the default.
    # NOTE(review): this mutates the tokenizer shared through load_models' cache.
    tokenizer.src_lang = source_lang
    translated_batches = []
    for batch in batch_process_text(text):
        inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=target_lang_id,
            max_length=512,
            # Plain beam search: with sampling enabled, the same document would
            # translate differently on every run.
            num_beams=5,
            do_sample=False,
            num_return_sequences=1
        )
        translated_batches.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return " ".join(translated_batches)
@torch.no_grad()
def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
    """Run each batch through mT5 with a language-specific "fix grammar" prefix.

    NOTE(review): google/mt5-base is a pretrained-only checkpoint, not fine-tuned
    for grammar correction, so its output may not be a corrected sentence.
    Batches whose output is empty after cleanup are kept unchanged.

    Args:
        text: Text in the target language.
        target_lang: NLLB language code; mapped to a prompt key via MT5_LANG_CODES.
        mt5_tuple: ``(tokenizer, model)`` pair as returned by load_models().

    Returns:
        Corrected batches joined by single spaces.

    Raises:
        KeyError: If ``target_lang`` has no entry in MT5_LANG_CODES.
    """
    tokenizer, model = mt5_tuple
    lang_code = MT5_LANG_CODES[target_lang]
    prompts = {
        'en': "Fix grammar: ",
        # "व्याकरण सुधार" means "grammar correction" in Hindi and in Marathi.
        'hi': "व्याकरण सुधार: ",
        'mr': "व्याकरण सुधार: "
    }
    # Any mT5 sentinel (<extra_id_0>, <extra_id_1>, ...) left in the decoded text.
    sentinel_pattern = re.compile(r"<extra_id_\d+>")
    corrected_batches = []
    for batch in batch_process_text(text):
        input_text = f"{prompts[lang_code]}{batch}"
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(
            **inputs,
            max_length=512,
            num_beams=5,
            length_penalty=1.0,
            early_stopping=True,
            no_repeat_ngram_size=2,
            do_sample=False
        )
        corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Remove any echoed prompt prefix before cleanup.
        for prefix in prompts.values():
            corrected_text = corrected_text.replace(prefix, "")
        corrected_text = sentinel_pattern.sub("", corrected_text).strip()
        # Never let an empty model output wipe out the translated batch.
        corrected_batches.append(corrected_text or batch)
    return " ".join(corrected_batches)
def save_as_docx(text: str) -> io.BytesIO:
    """Build an in-memory .docx file that holds ``text`` as a single paragraph.

    Returns:
        A BytesIO with the saved document, rewound to position 0 so it can be
        read immediately (for example by a download button).
    """
    buffer = io.BytesIO()
    document = docx.Document()
    document.add_paragraph(text)
    document.save(buffer)
    buffer.seek(0)
    return buffer
# st.session_state key holding the most recent successful translation.
_LAST_TRANSLATION_KEY = "last_translation"


def _translate_document(uploaded_file, source_code, target_code,
                        gemma_tuple, nllb_tuple, mt5_tuple, progress_bar):
    """Run extract -> interpret -> translate -> correct on one uploaded file.

    Returns:
        The grammar-corrected translation, or None when the document yields
        no non-whitespace text.
    """
    text = extract_text_from_file(uploaded_file)
    progress_bar.progress(25)
    if not text.strip():
        return None
    interpreted_text = interpret_context(text, gemma_tuple)
    progress_bar.progress(50)
    translated_text = translate_text(interpreted_text, source_code, target_code, nllb_tuple)
    progress_bar.progress(75)
    final_text = correct_grammar(translated_text, target_code, mt5_tuple)
    progress_bar.progress(90)
    return final_text


def _render_translation_result(final_text: str) -> None:
    """Show the translation in a text area with TXT and DOCX download buttons."""
    st.markdown("### Translation Result")
    st.text_area(
        label="Translated Text",
        value=final_text,
        height=200,
        key="translation_result"
    )
    st.markdown("### Download Options")
    col1, col2 = st.columns(2)
    with col1:
        st.download_button(
            label="Download as TXT",
            data=final_text.encode("utf-8"),
            file_name="translated_document.txt",
            mime="text/plain"
        )
    with col2:
        st.download_button(
            label="Download as DOCX",
            data=save_as_docx(final_text),
            file_name="translated_document.docx",
            mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
        )


def main():
    """Render the Streamlit UI and translate the uploaded document on demand.

    The most recent translation is stored in ``st.session_state`` and rendered
    on every run. Clicking a download button reruns the script, and on that
    rerun ``st.button("Translate")`` is False, so without this the result
    would disappear.
    """
    st.title("🌐 Document Translation App")

    # Load models (cached across reruns by load_models' st.cache_resource).
    with st.spinner("Loading models... This may take a few minutes."):
        try:
            gemma_tuple, nllb_tuple, mt5_tuple = load_models()
        except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            return

    # File upload
    uploaded_file = st.file_uploader(
        "Upload your document (PDF, DOCX, or TXT)",
        type=['pdf', 'docx', 'txt']
    )
    if uploaded_file is None:
        # No document selected: do not keep showing an old result.
        st.session_state.pop(_LAST_TRANSLATION_KEY, None)

    # Language selection
    col1, col2 = st.columns(2)
    with col1:
        source_language = st.selectbox(
            "Source Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=0
        )
    with col2:
        target_language = st.selectbox(
            "Target Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=1
        )

    if uploaded_file and st.button("Translate", type="primary"):
        # Clear the previous result so a failure is not masked by stale output.
        st.session_state.pop(_LAST_TRANSLATION_KEY, None)
        if source_language == target_language:
            st.warning("Source and target languages are the same. Please choose two different languages.")
        else:
            try:
                progress_bar = st.progress(0)
                with st.spinner("Processing document..."):
                    final_text = _translate_document(
                        uploaded_file,
                        SUPPORTED_LANGUAGES[source_language],
                        SUPPORTED_LANGUAGES[target_language],
                        gemma_tuple,
                        nllb_tuple,
                        mt5_tuple,
                        progress_bar,
                    )
                if final_text is None:
                    st.warning("No text could be extracted from the uploaded document.")
                else:
                    st.session_state[_LAST_TRANSLATION_KEY] = final_text
                    progress_bar.progress(100)
            except Exception as e:
                st.error(f"An error occurred: {str(e)}")

    # Rendered outside the button branch so it survives download-button reruns.
    final_text = st.session_state.get(_LAST_TRANSLATION_KEY)
    if final_text:
        try:
            _render_translation_result(final_text)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")


# `streamlit run` executes the script with __name__ == "__main__".
if __name__ == "__main__":
    main()