import streamlit as st
import PyPDF2
import docx
import io
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from pathlib import Path
from typing import Tuple
import language_tool_python
# Lazily create the LanguageTool instance for grammar correction. Streamlit
# re-runs this script on every interaction, so cache the (Java-backed,
# expensive-to-start) tool once per session instead of rebuilding it at import.
@st.cache_resource
def get_language_tool():
    return language_tool_python.LanguageTool('en-US')

# Define supported languages and their NLLB-200 language codes
SUPPORTED_LANGUAGES = {
    'English': 'eng_Latn',
    'Hindi': 'hin_Deva',
    'Marathi': 'mar_Deva'
}
@st.cache_resource
def load_models():
    """Load and cache the translation and context-interpretation models."""
    # Load Gemma for context interpretation. Note: google/gemma-2b is gated on
    # the Hugging Face Hub, so access must be granted and a token configured.
    gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
    gemma_model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b",
        device_map="auto",
        torch_dtype=torch.float16
    )
    # Load NLLB-200 for translation
    nllb_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
        "facebook/nllb-200-distilled-600M",
        device_map="auto",
        torch_dtype=torch.float16
    )
    return (gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model)
def extract_text_from_file(uploaded_file) -> str:
    """Extract text content from an uploaded file based on its type."""
    file_extension = Path(uploaded_file.name).suffix.lower()
    if file_extension == '.pdf':
        return extract_from_pdf(uploaded_file)
    elif file_extension == '.docx':
        return extract_from_docx(uploaded_file)
    elif file_extension == '.txt':
        return uploaded_file.getvalue().decode('utf-8')
    else:
        raise ValueError(f"Unsupported file format: {file_extension}")
def extract_from_pdf(file) -> str:
    """Extract text from a PDF file."""
    pdf_reader = PyPDF2.PdfReader(file)
    text = ""
    for page in pdf_reader.pages:
        # Guard against pages that yield no extractable text
        text += (page.extract_text() or "") + "\n"
    return text.strip()
def extract_from_docx(file) -> str:
    """Extract text from a DOCX file."""
    doc = docx.Document(file)
    return "\n".join(paragraph.text for paragraph in doc.paragraphs).strip()
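# Note: doc.paragraphs only yields top-level paragraphs, so text inside DOCX
# tables is skipped. A minimal sketch for pulling table text as well
# (extract_docx_tables is an illustrative helper, not part of the original app):
def extract_docx_tables(doc) -> str:
    """Return tab-separated text for every table row in a python-docx Document."""
    rows = []
    for table in doc.tables:
        for row in table.rows:
            rows.append("\t".join(cell.text for cell in row.cells))
    return "\n".join(rows)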
def interpret_context(text: str, gemma_tuple: Tuple) -> str:
    """Use the Gemma model to interpret context and regional nuances."""
    tokenizer, model = gemma_tuple
    prompt = (
        "Analyze the following text for context and cultural nuances, "
        "maintaining the core meaning while identifying any idiomatic "
        f"expressions or cultural references: {text}"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,  # temperature only takes effect when sampling is enabled
        temperature=0.3,
        pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens so the prompt is not echoed back
    return tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
    """Translate text using the NLLB model."""
    tokenizer, model = nllb_tuple
    # Tell the tokenizer the input language (source_lang was previously unused,
    # so every input was tokenized as the tokenizer's default language)
    tokenizer.src_lang = source_lang
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    # Force the decoder to start generating in the target language
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(target_lang)
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=forced_bos_token_id,
        max_length=1024,
        num_beams=5  # temperature is ignored under beam search, so it is dropped
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
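# NLLB truncates inputs past its maximum source length, so translate_text alone
# would silently drop the tail of a long document. Below is a minimal sketch of
# paragraph-level chunking; translate_long_text and the 3000-character budget
# are illustrative assumptions, not part of the original app.
def translate_long_text(text: str, source_lang: str, target_lang: str,
                        nllb_tuple: Tuple, max_chars: int = 3000) -> str:
    """Translate a long document by splitting it into paragraph-sized chunks."""
    chunks, current = [], ""
    for paragraph in text.split("\n"):
        # Start a new chunk once adding this paragraph would exceed the budget
        if current and len(current) + len(paragraph) + 1 > max_chars:
            chunks.append(current)
            current = paragraph
        else:
            current = f"{current}\n{paragraph}" if current else paragraph
    if current:
        chunks.append(current)
    return "\n".join(
        translate_text(chunk, source_lang, target_lang, nllb_tuple)
        for chunk in chunks
    )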
def correct_grammar(text: str, target_lang: str) -> str:
    """Correct grammar and ensure tense consistency in the translated text."""
    # For English, use LanguageTool; for other languages, return the text as-is
    # (a production system would need language-specific tools for Hindi and Marathi)
    if target_lang == 'eng_Latn':
        return get_language_tool().correct(text)
    return text
def save_as_docx(text: str) -> io.BytesIO:
    """Save translated text as a DOCX file."""
    doc = docx.Document()
    doc.add_paragraph(text)
    docx_buffer = io.BytesIO()
    doc.save(docx_buffer)
    docx_buffer.seek(0)
    return docx_buffer
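# save_as_docx collapses the whole translation into a single paragraph. A
# minimal sketch of a variant that keeps the source's paragraph breaks
# (save_as_docx_paragraphs is an illustrative helper, not part of the original app):
def save_as_docx_paragraphs(text: str) -> io.BytesIO:
    """Save translated text as a DOCX file, one paragraph per input line."""
    doc = docx.Document()
    for line in text.split("\n"):
        doc.add_paragraph(line)
    docx_buffer = io.BytesIO()
    doc.save(docx_buffer)
    docx_buffer.seek(0)
    return docx_buffer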
def main():
    st.title("Document Translation App")
    # Load models (cached across reruns by st.cache_resource)
    with st.spinner("Loading models... This may take a few minutes."):
        gemma_tuple, nllb_tuple = load_models()
    # File upload
    uploaded_file = st.file_uploader(
        "Upload your document (PDF, DOCX, or TXT)",
        type=['pdf', 'docx', 'txt']
    )
    # Language selection
    col1, col2 = st.columns(2)
    with col1:
        source_language = st.selectbox(
            "Source Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=0
        )
    with col2:
        target_language = st.selectbox(
            "Target Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=1
        )
    if uploaded_file and st.button("Translate"):
        try:
            with st.spinner("Processing document..."):
                # Extract text
                text = extract_text_from_file(uploaded_file)
                st.text_area("Extracted Text:", value=text, height=150)
                # Interpret context
                with st.spinner("Interpreting context..."):
                    interpreted_text = interpret_context(text, gemma_tuple)
                # Translate
                with st.spinner("Translating..."):
                    translated_text = translate_text(
                        interpreted_text,
                        SUPPORTED_LANGUAGES[source_language],
                        SUPPORTED_LANGUAGES[target_language],
                        nllb_tuple
                    )
                # Grammar correction
                with st.spinner("Correcting grammar..."):
                    corrected_text = correct_grammar(
                        translated_text,
                        SUPPORTED_LANGUAGES[target_language]
                    )
                # Display result
                st.subheader("Translation Result:")
                st.text_area("Translated Text:", value=corrected_text, height=150)
                # Download options
                st.subheader("Download Translation:")
                # Text file download
                text_buffer = io.BytesIO()
                text_buffer.write(corrected_text.encode('utf-8'))
                text_buffer.seek(0)
                st.download_button(
                    label="Download as TXT",
                    data=text_buffer,
                    file_name="translated_document.txt",
                    mime="text/plain"
                )
                # DOCX file download
                docx_buffer = save_as_docx(corrected_text)
                st.download_button(
                    label="Download as DOCX",
                    data=docx_buffer,
                    file_name="translated_document.docx",
                    mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
                )
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
if __name__ == "__main__":
    main()
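
# A plausible requirements.txt for deploying this app as a Space (a sketch;
# the exact package set is an assumption, and language_tool_python also needs
# a Java runtime on the host):
#
#   streamlit
#   torch
#   transformers
#   accelerate            # needed for device_map="auto"
#   sentencepiece         # tokenizer backend for NLLB
#   PyPDF2
#   python-docx
#   language-tool-python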