# try / app.py
# Author: gauravchand11
# Commit: "Create app.py" (67419d9, verified)
# Size: 7.52 kB
import streamlit as st
import PyPDF2
import docx
import io
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from pathlib import Path
import tempfile
from typing import Union, Tuple
import language_tool_python
# Initialize language tool for grammar correction.
# Streamlit re-executes this script from top to bottom on every user
# interaction, so a bare module-level LanguageTool(...) would be rebuilt on
# each rerun (for a local install that typically means launching a new
# LanguageTool server). st.cache_resource builds it once and reuses it.
@st.cache_resource
def _load_language_tool():
    """Create the shared en-US LanguageTool instance (cached across reruns)."""
    return language_tool_python.LanguageTool('en-US')


language_tool = _load_language_tool()

# Supported UI language names mapped to their NLLB-200 language codes.
SUPPORTED_LANGUAGES = {
    'English': 'eng_Latn',
    'Hindi': 'hin_Deva',
    'Marathi': 'mar_Deva',
}
@st.cache_resource
def load_models():
    """Load and cache the context-interpretation and translation models.

    Returns:
        ``((gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model))``.
        ``st.cache_resource`` keeps the loaded objects alive across Streamlit
        reruns, so the weights are loaded only once per server process.
    """
    # float16 halves memory on a GPU, but on a CPU-only host half precision
    # is slow and, in some PyTorch releases, unsupported for ops used during
    # generation. Fall back to float32 when CUDA is not available.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # Load Gemma model for context interpretation
    gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
    gemma_model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b",
        device_map="auto",
        torch_dtype=dtype,
    )

    # Load NLLB model for translation
    nllb_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
        "facebook/nllb-200-distilled-600M",
        device_map="auto",
        torch_dtype=dtype,
    )
    return (gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model)
def extract_text_from_file(uploaded_file) -> str:
    """Extract the text content of an uploaded file based on its extension.

    Args:
        uploaded_file: Upload object exposing ``.name``. For ``.txt`` files
            ``.getvalue()`` must return the raw bytes (as Streamlit's
            ``UploadedFile`` does). PDF and DOCX uploads are passed to
            ``extract_from_pdf`` / ``extract_from_docx``.

    Returns:
        The extracted text.

    Raises:
        ValueError: If the extension (compared case-insensitively) is not
            ``.pdf``, ``.docx`` or ``.txt``.
    """
    file_extension = Path(uploaded_file.name).suffix.lower()
    if file_extension == '.pdf':
        return extract_from_pdf(uploaded_file)
    if file_extension == '.docx':
        return extract_from_docx(uploaded_file)
    if file_extension == '.txt':
        # 'utf-8-sig' decodes plain UTF-8 exactly like 'utf-8' but also drops
        # a leading byte-order mark (common in files saved by Windows
        # editors), which would otherwise leak into the text as '\ufeff'.
        return uploaded_file.getvalue().decode('utf-8-sig')
    raise ValueError(f"Unsupported file format: {file_extension}")
def extract_from_pdf(file) -> str:
    """Extract the text of every page of a PDF.

    Page texts are joined with newlines, and leading/trailing whitespace of
    the whole result is stripped.
    """
    pdf_reader = PyPDF2.PdfReader(file)
    # ``or ""`` treats a page whose extract_text() returns None as empty
    # instead of crashing on ``None + "\n"``.
    # NOTE(review): confirm whether the pinned PyPDF2 version can return None.
    # join() also avoids repeated string concatenation on long documents.
    page_texts = [page.extract_text() or "" for page in pdf_reader.pages]
    return "\n".join(page_texts).strip()
def extract_from_docx(file) -> str:
    """Extract body-paragraph text from a DOCX file, one paragraph per line.

    Only ``Document.paragraphs`` is read. Leading/trailing whitespace of the
    joined result is stripped.
    """
    doc = docx.Document(file)
    # join() instead of repeated ``+=`` concatenation; the result is the same.
    return "\n".join(paragraph.text for paragraph in doc.paragraphs).strip()
def interpret_context(text: str, gemma_tuple: Tuple, max_new_tokens: int = 1024) -> str:
    """Ask Gemma to analyse *text* for context, idioms and cultural references.

    Args:
        text: Document text to analyse.
        gemma_tuple: ``(tokenizer, model)`` pair, as returned by ``load_models()``.
        max_new_tokens: Upper bound on the number of tokens Gemma may generate.

    Returns:
        Only the text Gemma generated after the prompt, stripped of
        surrounding whitespace. The prompt (which embeds the whole document)
        is not echoed back.
    """
    tokenizer, model = gemma_tuple
    prompt = (
        "Analyze the following text for context and cultural nuances,\n"
        "maintaining the core meaning while identifying any idiomatic expressions or\n"
        f"cultural references: {text}"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]
    outputs = model.generate(
        **inputs,
        # max_new_tokens bounds only the generated continuation. max_length
        # also counted the prompt, so a ~1k-token document left no room to
        # generate anything.
        max_new_tokens=max_new_tokens,
        # Greedy decoding. The previous temperature=0.3 had no effect because
        # sampling was never enabled.
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # A causal LM's output sequence starts with the prompt tokens; slice them
    # off so callers receive only the model's answer.
    generated_ids = outputs[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
    """Translate *text* from *source_lang* to *target_lang* using NLLB.

    Args:
        text: Text to translate.
        source_lang: NLLB language code of the input, e.g. ``'hin_Deva'``.
        target_lang: NLLB language code to translate into.
        nllb_tuple: ``(tokenizer, model)`` pair, as returned by ``load_models()``.

    Returns:
        The decoded best beam-search translation.
    """
    tokenizer, model = nllb_tuple
    # NLLB prefixes the encoder input with a source-language token. Without
    # setting it, the tokenizer keeps its default (eng_Latn), so Hindi or
    # Marathi input would be labelled as English.
    # NOTE(review): this mutates the tokenizer object that st.cache_resource
    # shares between sessions; concurrent requests with different source
    # languages could interfere.
    tokenizer.src_lang = source_lang
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    # Language codes are ordinary special tokens. lang_code_to_id is
    # deprecated and absent from newer NLLB tokenizers, but this lookup works
    # in both old and new versions.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(target_lang)
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=forced_bos_token_id,
        max_length=1024,
        # Beam search is deterministic. The previous temperature=0.7 only
        # applies when sampling, so it was ignored.
        num_beams=5,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
def correct_grammar(text: str, target_lang: str) -> str:
    """Apply automatic grammar corrections to translated text.

    Args:
        text: Translated text.
        target_lang: NLLB language code of *text*.

    Returns:
        For ``'eng_Latn'``, *text* with LanguageTool's suggested replacements
        applied. For any other language, *text* unchanged.
    """
    # The module-level LanguageTool is configured for en-US only.
    # TODO: add grammar correction for Hindi and Marathi in production.
    if target_lang != 'eng_Latn':
        return text
    # correct() runs its own check() internally. The separate check() call
    # whose result was never used only doubled the work.
    return language_tool.correct(text)
def save_as_docx(text: str) -> io.BytesIO:
    """Render *text* as a one-paragraph Word document held in memory.

    Returns:
        An in-memory buffer containing the .docx bytes, rewound to the start
        so it can be passed directly to a download widget.
    """
    out = io.BytesIO()
    document = docx.Document()
    document.add_paragraph(text)
    document.save(out)
    out.seek(0)  # rewind so readers start from the first byte
    return out
def _render_downloads(translated_text: str) -> None:
    """Offer *translated_text* for download as a UTF-8 .txt and as a .docx."""
    st.subheader("Download Translation:")
    # Text file download
    text_buffer = io.BytesIO(translated_text.encode("utf-8"))
    st.download_button(
        label="Download as TXT",
        data=text_buffer,
        file_name="translated_document.txt",
        mime="text/plain",
    )
    # DOCX file download
    st.download_button(
        label="Download as DOCX",
        data=save_as_docx(translated_text),
        file_name="translated_document.docx",
        mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    )


def main():
    """Render the Streamlit UI: upload a document, choose languages, translate.

    When "Translate" is pressed, the app runs these steps:

    1. Extract the document text.
    2. Show Gemma's context analysis.
    3. Translate the extracted text with NLLB.
    4. Apply grammar correction (English only).
    5. Display the result and offer downloads.

    Any exception is reported with ``st.error`` rather than crashing the page.
    """
    st.title("Document Translation App")

    # Load models
    with st.spinner("Loading models... This may take a few minutes."):
        gemma_tuple, nllb_tuple = load_models()

    # File upload
    uploaded_file = st.file_uploader(
        "Upload your document (PDF, DOCX, or TXT)",
        type=['pdf', 'docx', 'txt'],
    )

    # Language selection
    col1, col2 = st.columns(2)
    with col1:
        source_language = st.selectbox(
            "Source Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=0,
        )
    with col2:
        target_language = st.selectbox(
            "Target Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=1,
        )

    # The button is only rendered once a file has been uploaded.
    if not (uploaded_file and st.button("Translate")):
        return

    try:
        with st.spinner("Processing document..."):
            # Extract text
            text = extract_text_from_file(uploaded_file)
            st.text_area("Extracted Text:", value=text, height=150)
            if not text.strip():
                st.warning(
                    "No text could be extracted from the document "
                    "(it may be scanned or image-only)."
                )
                return

            # Interpret context. interpret_context answers an "Analyze..."
            # prompt; its output is commentary about the document, not the
            # document itself, so it is shown to the user but not translated.
            with st.spinner("Interpreting context..."):
                interpreted_text = interpret_context(text, gemma_tuple)
            st.text_area("Context Interpretation:", value=interpreted_text, height=150)

            # Translate the extracted document text.
            with st.spinner("Translating..."):
                translated_text = translate_text(
                    text,
                    SUPPORTED_LANGUAGES[source_language],
                    SUPPORTED_LANGUAGES[target_language],
                    nllb_tuple,
                )

            # Grammar correction
            with st.spinner("Correcting grammar..."):
                corrected_text = correct_grammar(
                    translated_text,
                    SUPPORTED_LANGUAGES[target_language],
                )

            # Display result
            st.subheader("Translation Result:")
            st.text_area("Translated Text:", value=corrected_text, height=150)

            # Download options
            _render_downloads(corrected_text)
    except Exception as e:  # top-level UI boundary: report instead of crashing
        st.error(f"An error occurred: {str(e)}")
# Entry point. `streamlit run` executes this file as the __main__ module, so
# main() runs on every Streamlit rerun as well as under plain `python`.
if __name__ == "__main__":
    main()