import streamlit as st
import PyPDF2
import docx
import io
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration
import torch
from pathlib import Path
from typing import Tuple
import os
import sys
from datetime import datetime, timezone
import warnings

# Filter warnings
warnings.filterwarnings('ignore', category=UserWarning)
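# Dependencies implied by the imports above (assumed pip package names):
# streamlit, PyPDF2, python-docx, transformers, torch, and sentencepiece
# (required by the slow NLLB and mT5 tokenizers). Pin versions in the
# Space's requirements.txt as appropriate.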
# Page config
st.set_page_config(
    page_title="Document Translation App",
    page_icon="📄",
    layout="wide"
)
# Display system info
st.sidebar.markdown(f"""
### System Information
**Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
**User:** {os.environ.get('USER', 'gauravchand')}
""")

# Get Hugging Face token
HF_TOKEN = os.environ.get('HF_TOKEN')
if not HF_TOKEN:
    st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
    st.stop()
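# Translation pipeline: Gemma (context interpretation) -> NLLB (translation)
# -> mT5 (grammar correction). Each stage is defined below.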
# Language configurations
SUPPORTED_LANGUAGES = {
    'English': 'eng_Latn',
    'Hindi': 'hin_Deva',
    'Marathi': 'mar_Deva'
}

MT5_LANG_CODES = {
    'eng_Latn': 'en',
    'hin_Deva': 'hi',
    'mar_Deva': 'mr'
}
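# NLLB identifies languages by FLORES-200 codes (e.g. "eng_Latn"), while the
# mT5 grammar prompts below are keyed by ISO-639-1 codes, hence the mapping.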
def get_nllb_lang_token(lang_code: str) -> str:
    """Get the language token expected by the NLLB tokenizer.

    The Hugging Face NLLB tokenizer registers the FLORES-200 codes
    themselves (e.g. "hin_Deva") as special tokens, so the code is
    used as-is; no underscore wrapping is needed.
    """
    return lang_code
def extract_text_from_file(uploaded_file) -> str:
    """Extract text content from uploaded file based on its type."""
    file_extension = Path(uploaded_file.name).suffix.lower()
    if file_extension == '.pdf':
        return extract_from_pdf(uploaded_file)
    elif file_extension == '.docx':
        return extract_from_docx(uploaded_file)
    elif file_extension == '.txt':
        return uploaded_file.getvalue().decode('utf-8')
    else:
        raise ValueError(f"Unsupported file format: {file_extension}")
def extract_from_pdf(file) -> str:
    """Extract text from PDF file."""
    pdf_reader = PyPDF2.PdfReader(file)
    text = ""
    for page in pdf_reader.pages:
        # extract_text() can return None for pages without a text layer
        text += (page.extract_text() or "") + "\n"
    return text.strip()
def extract_from_docx(file) -> str:
    """Extract text from DOCX file."""
    doc = docx.Document(file)
    text = ""
    for paragraph in doc.paragraphs:
        text += paragraph.text + "\n"
    return text.strip()
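# The models below truncate input at 512 tokens, so long documents are split
# into chunks first. Chunk size is measured in characters as a rough proxy
# for token count.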
def batch_process_text(text: str, max_length: int = 512) -> list:
    """Split text into batches for processing."""
    words = text.split()
    batches = []
    current_batch = []
    current_length = 0
    for word in words:
        if current_length + len(word) + 1 > max_length:
            batches.append(" ".join(current_batch))
            current_batch = [word]
            current_length = len(word)
        else:
            current_batch.append(word)
            current_length += len(word) + 1
    if current_batch:
        batches.append(" ".join(current_batch))
    return batches
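# For example, batch_process_text("a b c", max_length=3) yields ["a", "b c"]:
# adding "b" would push the first chunk past 3 characters, so a new chunk
# starts there.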
@st.cache_resource
def load_models():
    """Load and cache the translation and context interpretation models."""
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # float16 is only reliable on GPU; fall back to float32 on CPU
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        device_map = "auto" if torch.cuda.is_available() else None

        # Load Gemma model (context interpretation)
        gemma_tokenizer = AutoTokenizer.from_pretrained(
            "google/gemma-2b",
            token=HF_TOKEN
        )
        gemma_model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2b",
            token=HF_TOKEN,
            torch_dtype=dtype,
            device_map=device_map
        )

        # Load NLLB model (translation)
        nllb_tokenizer = AutoTokenizer.from_pretrained(
            "facebook/nllb-200-distilled-600M",
            token=HF_TOKEN,
            use_fast=False
        )
        nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
            "facebook/nllb-200-distilled-600M",
            token=HF_TOKEN,
            torch_dtype=dtype,
            device_map=device_map
        )

        # Load MT5 model (grammar correction)
        mt5_tokenizer = AutoTokenizer.from_pretrained(
            "google/mt5-base",
            token=HF_TOKEN
        )
        mt5_model = MT5ForConditionalGeneration.from_pretrained(
            "google/mt5-base",
            token=HF_TOKEN,
            torch_dtype=dtype,
            device_map=device_map
        )

        if not torch.cuda.is_available():
            gemma_model = gemma_model.to(device)
            nllb_model = nllb_model.to(device)
            mt5_model = mt5_model.to(device)

        return (gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model), (mt5_tokenizer, mt5_model)
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error(f"Python version: {sys.version}")
        st.error(f"PyTorch version: {torch.__version__}")
        raise
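# Note: google/gemma-2b is a base (non-instruction-tuned) LM, so this
# "interpretation" pass is a heuristic; its output replaces the raw source
# text as input to the translator.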
def interpret_context(text: str, gemma_tuple: Tuple) -> str:
    """Use Gemma model to interpret context and understand regional nuances."""
    tokenizer, model = gemma_tuple
    batches = batch_process_text(text)
    interpreted_batches = []
    for batch in batches:
        prompt = f"Analyze and maintain the core meaning of this text: {batch}"
        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(
            **inputs,
            # max_new_tokens counts only generated tokens, so the prompt
            # cannot eat into the generation budget
            max_new_tokens=512,
            do_sample=True,
            temperature=0.3,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1
        )
        interpreted_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip the echoed prompt from the decoded output
        interpreted_text = interpreted_text.replace(prompt, "").strip()
        interpreted_batches.append(interpreted_text)
    return " ".join(interpreted_batches)
def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
    """Translate text using NLLB model."""
    tokenizer, model = nllb_tuple
    tokenizer.src_lang = source_lang
    batches = batch_process_text(text)
    translated_batches = []
    target_lang_id = tokenizer.convert_tokens_to_ids(get_nllb_lang_token(target_lang))
    for batch in batches:
        inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=target_lang_id,
            max_length=512,
            # deterministic beam search; sampling adds variance that is
            # rarely wanted for translation
            num_beams=5,
            num_return_sequences=1
        )
        translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        translated_batches.append(translated_text)
    return " ".join(translated_batches)
def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
    """Correct grammar using MT5 model."""
    tokenizer, model = mt5_tuple
    lang_code = MT5_LANG_CODES[target_lang]
    prompts = {
        'en': "Fix grammar: ",
        'hi': "व्याकरण सुधार: ",
        'mr': "व्याकरण सुधार: "
    }
    batches = batch_process_text(text)
    corrected_batches = []
    for batch in batches:
        input_text = f"{prompts[lang_code]}{batch}"
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(
            **inputs,
            max_length=512,
            num_beams=5,
            length_penalty=1.0,
            early_stopping=True,
            no_repeat_ngram_size=2,
            do_sample=False
        )
        corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Remove any echoed prompt prefix and leftover mT5 sentinel tokens
        for prefix in prompts.values():
            corrected_text = corrected_text.replace(prefix, "")
        corrected_text = (corrected_text.replace("<extra_id_0>", "")
                          .replace("<extra_id_1>", "")
                          .strip())
        corrected_batches.append(corrected_text)
    return " ".join(corrected_batches)
def save_as_docx(text: str) -> io.BytesIO:
    """Save translated text as a DOCX file."""
    doc = docx.Document()
    doc.add_paragraph(text)
    docx_buffer = io.BytesIO()
    doc.save(docx_buffer)
    docx_buffer.seek(0)
    return docx_buffer
def main():
    st.title("📄 Document Translation App")

    # Load models
    with st.spinner("Loading models... This may take a few minutes."):
        try:
            gemma_tuple, nllb_tuple, mt5_tuple = load_models()
        except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            return

    # File upload
    uploaded_file = st.file_uploader(
        "Upload your document (PDF, DOCX, or TXT)",
        type=['pdf', 'docx', 'txt']
    )

    # Language selection
    col1, col2 = st.columns(2)
    with col1:
        source_language = st.selectbox(
            "Source Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=0
        )
    with col2:
        target_language = st.selectbox(
            "Target Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=1
        )

    if uploaded_file and st.button("Translate", type="primary"):
        try:
            progress_bar = st.progress(0)

            # Process document
            with st.spinner("Processing document..."):
                text = extract_text_from_file(uploaded_file)
                progress_bar.progress(25)

                interpreted_text = interpret_context(text, gemma_tuple)
                progress_bar.progress(50)

                translated_text = translate_text(
                    interpreted_text,
                    SUPPORTED_LANGUAGES[source_language],
                    SUPPORTED_LANGUAGES[target_language],
                    nllb_tuple
                )
                progress_bar.progress(75)

                final_text = correct_grammar(
                    translated_text,
                    SUPPORTED_LANGUAGES[target_language],
                    mt5_tuple
                )
                progress_bar.progress(90)

            # Display result
            st.markdown("### Translation Result")
            st.text_area(
                label="Translated Text",
                value=final_text,
                height=200,
                key="translation_result"
            )

            # Download options
            st.markdown("### Download Options")
            col1, col2 = st.columns(2)
            with col1:
                # Text file download
                text_buffer = io.BytesIO()
                text_buffer.write(final_text.encode('utf-8'))
                text_buffer.seek(0)
                st.download_button(
                    label="Download as TXT",
                    data=text_buffer,
                    file_name="translated_document.txt",
                    mime="text/plain"
                )
            with col2:
                # DOCX file download
                docx_buffer = save_as_docx(final_text)
                st.download_button(
                    label="Download as DOCX",
                    data=docx_buffer,
                    file_name="translated_document.docx",
                    mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
                )

            progress_bar.progress(100)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()