try / app.py
gauravchand11's picture
Update app.py
c98f2e3 verified
raw
history blame
9.22 kB
import streamlit as st
import PyPDF2
import docx
import io
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
import torch
from pathlib import Path
import tempfile
from typing import Union, Tuple
import os
# Hugging Face access token, passed as `token=` to every from_pretrained call
# in load_models. An unset or empty variable halts the app with guidance.
HF_TOKEN = os.getenv('HF_TOKEN')
if HF_TOKEN is None or HF_TOKEN == "":
    st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
    st.stop()
# One row per supported language, in UI order:
# (display name shown in the selectboxes, NLLB-200 language code,
#  short code used by correct_grammar to pick its prompt).
_LANGUAGE_TABLE = (
    ('English', 'eng_Latn', 'en'),
    ('Hindi', 'hin_Deva', 'hi'),
    ('Marathi', 'mar_Deva', 'mr'),
)

# Display name -> NLLB-200 language code.
SUPPORTED_LANGUAGES = {display: nllb for display, nllb, _ in _LANGUAGE_TABLE}

# NLLB-200 language code -> short code used for the MT5 grammar prompts.
MT5_LANG_CODES = {nllb: short for _, nllb, short in _LANGUAGE_TABLE}
@st.cache_resource
def load_models():
    """Load and cache the context-interpretation, translation and grammar models.

    ``st.cache_resource`` keeps the returned objects across Streamlit reruns,
    so downloading and loading the weights happens once per process.

    Returns:
        Three ``(tokenizer, model)`` pairs, in this order: Gemma (context
        interpretation), NLLB (translation), MT5 (grammar correction).
    """
    # Half precision is a GPU optimisation. On CPU it is slow, and on some
    # PyTorch builds common ops (e.g. matmul) have no float16 kernel at all.
    # So only use float16 when a CUDA device is available.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    def _load_pair(repo_id, model_cls):
        """Load the tokenizer and the `model_cls` weights for `repo_id`."""
        tokenizer = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN)
        model = model_cls.from_pretrained(
            repo_id,
            device_map="auto",
            torch_dtype=dtype,
            token=HF_TOKEN,
        )
        return tokenizer, model

    # Gemma model for context interpretation
    gemma_pair = _load_pair("google/gemma-2b", AutoModelForCausalLM)
    # NLLB model for translation
    nllb_pair = _load_pair("facebook/nllb-200-distilled-600M", AutoModelForSeq2SeqLM)
    # MT5 model for grammar correction.
    # NOTE(review): T5-family checkpoints are reported to be numerically
    # fragile in float16. If GPU output looks broken, try float32 for this model.
    mt5_pair = _load_pair("google/mt5-small", T5ForConditionalGeneration)
    return gemma_pair, nllb_pair, mt5_pair
def extract_text_from_file(uploaded_file) -> str:
    """Extract the text content of an uploaded document.

    The parser is chosen from the file-name extension (case-insensitive).

    Args:
        uploaded_file: Uploaded-file object exposing ``name`` and
            ``getvalue()``. For PDF/DOCX it is handed to the parser as a
            file-like object.

    Returns:
        The extracted text.

    Raises:
        ValueError: If the extension is not ``.pdf``, ``.docx`` or ``.txt``.
        UnicodeDecodeError: If a ``.txt`` upload is not valid UTF-8.
    """
    file_extension = Path(uploaded_file.name).suffix.lower()
    if file_extension == '.pdf':
        return extract_from_pdf(uploaded_file)
    if file_extension == '.docx':
        return extract_from_docx(uploaded_file)
    if file_extension == '.txt':
        # "utf-8-sig" decodes plain UTF-8 exactly like "utf-8". It also drops a
        # leading byte-order mark, which would otherwise reach the models as a
        # stray U+FEFF character.
        return uploaded_file.getvalue().decode('utf-8-sig')
    raise ValueError(f"Unsupported file format: {file_extension}")
def extract_from_pdf(file) -> str:
    """Extract the text of every page of a PDF, in page order.

    Args:
        file: A path or file-like object accepted by ``PyPDF2.PdfReader``.

    Returns:
        The page texts joined with newlines, with surrounding whitespace
        stripped. Pages without a text layer contribute an empty string.
    """
    pdf_reader = PyPDF2.PdfReader(file)
    # `or ""` guards against a page yielding no text object.
    # NOTE(review): some PyPDF2 versions may return None for text-less pages.
    # One join replaces repeated string concatenation.
    page_texts = [page.extract_text() or "" for page in pdf_reader.pages]
    return "\n".join(page_texts).strip()
def extract_from_docx(file) -> str:
    """Return the body paragraphs of a DOCX document, one paragraph per line.

    Args:
        file: A path or file-like object accepted by ``docx.Document``.

    Returns:
        Paragraph texts separated by newlines, with surrounding whitespace
        stripped.

    NOTE(review): ``Document.paragraphs`` does not cover text inside tables,
    headers or footers, so that content is not extracted.
    """
    document = docx.Document(file)
    paragraph_texts = (paragraph.text for paragraph in document.paragraphs)
    return "\n".join(paragraph_texts).strip()
def interpret_context(text: str, gemma_tuple: Tuple, *, max_new_tokens: int = 512,
                      return_full_text: bool = True) -> str:
    """Ask the Gemma model to analyse *text* for context and cultural nuances.

    Args:
        text: Document text to analyse. It is embedded verbatim in the prompt.
        gemma_tuple: ``(tokenizer, model)`` pair for the causal LM, as
            returned by ``load_models``.
        max_new_tokens: Maximum number of tokens generated after the prompt.
        return_full_text: If True (the default), return the decoded prompt
            followed by the model's continuation. If False, return only the
            continuation.

    Returns:
        The decoded generation, with special tokens removed.
    """
    tokenizer, model = gemma_tuple
    prompt = f"""Analyze the following text for context and cultural nuances,
maintaining the core meaning while identifying any idiomatic expressions or
cultural references: {text}"""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        # max_new_tokens bounds only the continuation. The previous
        # max_length=1024 also counted the prompt, so a long document left no
        # room to generate, which transformers warns about or (depending on
        # version) rejects.
        max_new_tokens=max_new_tokens,
        # Greedy decoding, stated explicitly. The previous temperature=0.3 had
        # no effect without do_sample=True; transformers only warned about it.
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated = outputs[0]
    if not return_full_text:
        # For a causal LM, generate() returns the prompt tokens followed by
        # the new tokens, so skip the prompt part.
        generated = generated[inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)
def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
    """Translate *text* from ``source_lang`` into ``target_lang`` using NLLB.

    The text is translated one line at a time. Each model call stays short and
    the output keeps the input's line structure. Blank or whitespace-only lines
    become empty lines without calling the model.

    Args:
        text: Text to translate.
        source_lang: NLLB code of the input language, e.g. ``'eng_Latn'``.
        target_lang: NLLB code of the output language, e.g. ``'hin_Deva'``.
        nllb_tuple: ``(tokenizer, model)`` pair as returned by ``load_models``.

    Returns:
        The translation, one output line per input line.
    """
    tokenizer, model = nllb_tuple
    # The NLLB tokenizer tags the input with the language held in `src_lang`.
    # `source_lang` was previously never applied, so every document carried
    # the tokenizer's default source tag (eng_Latn for NLLB unless configured
    # otherwise).
    # NOTE(review): this mutates the tokenizer shared via st.cache_resource.
    # Concurrent sessions with different source languages could race.
    tokenizer.src_lang = source_lang
    # convert_tokens_to_ids works on slow and fast NLLB tokenizers alike.
    # `lang_code_to_id` is deprecated and absent in newer transformers releases.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(target_lang)

    translated_lines = []
    for line in text.split("\n"):
        if not line.strip():
            translated_lines.append("")
            continue
        # One line per call. Translating the whole document at once cut the
        # output off at max_length=1024 tokens, and NLLB was trained on
        # inputs of at most 512 tokens.
        inputs = tokenizer(line, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_token_id,
            max_length=1024,
            # temperature=0.7 was dropped. Without do_sample=True it never
            # affected beam search; transformers only warned about it.
            num_beams=5,
        )
        translated_lines.append(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
    return "\n".join(translated_lines)
def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
    """
    Run an MT5 grammar-correction pass over *text*, one line at a time.

    Uses a text-to-text approach with language-specific prompts.

    Args:
        text: Text written in ``target_lang`` (here, the NLLB translation).
        target_lang: NLLB language code; must be a key of ``MT5_LANG_CODES``.
        mt5_tuple: ``(tokenizer, model)`` pair as returned by ``load_models``.

    Returns:
        The corrected text, with the same line structure as the input. These
        lines are passed through unchanged:
        - blank lines;
        - lines whose prompt exceeds the 512-token input budget;
        - lines for which the model returns nothing.

    Raises:
        KeyError: If ``target_lang`` is not a key of ``MT5_LANG_CODES``.

    NOTE(review): ``google/mt5-small`` is a pre-trained-only checkpoint, and
    the mT5 docs state it must be fine-tuned before downstream use. These
    prompts may therefore not yield genuine corrections; consider a
    fine-tuned model.
    """
    tokenizer, model = mt5_tuple
    lang_code = MT5_LANG_CODES[target_lang]
    # Language-specific prompt prefixes for grammar correction
    prompt_prefixes = {
        'en': "grammar: ",
        'hi': "व्याकरण सुधार: ",
        'mr': "व्याकरण सुधारणा: ",
    }
    prefix = prompt_prefixes[lang_code]
    # Prompt markers the model may echo at the start of its output.
    echoed_markers = ("grammar:", "व्याकरण सुधार:", "व्याकरण सुधारणा:")
    max_input_tokens = 512

    def _correct_line(line: str) -> str:
        """Correct one non-blank line, falling back to the line itself."""
        inputs = tokenizer(prefix + line, return_tensors="pt").to(model.device)
        if inputs["input_ids"].shape[-1] > max_input_tokens:
            # Truncating would silently drop the tail of the line from the
            # output, so keep the line uncorrected instead.
            return line
        outputs = model.generate(
            **inputs,
            max_length=512,
            # Deterministic beam search. The previous do_sample=True turned this
            # into beam sampling, so identical text could be "corrected"
            # differently on every run.
            num_beams=5,
        )
        corrected = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        # Strip only a leading echoed marker. Replacing markers anywhere in the
        # output also deleted legitimate occurrences inside the text.
        for marker in echoed_markers:
            if corrected.startswith(marker):
                corrected = corrected[len(marker):].strip()
                break
        # An empty output is never a valid correction of a non-blank line.
        return corrected or line

    return "\n".join(
        _correct_line(line) if line.strip() else line
        for line in text.split("\n")
    )
def save_as_docx(text: str) -> io.BytesIO:
    """Save translated text as an in-memory DOCX file.

    Each line of *text* becomes its own paragraph. This mirrors
    ``extract_from_docx``, which joins paragraphs with newlines.

    Args:
        text: Text to write. Empty text yields a document with one empty
            paragraph.

    Returns:
        A ``BytesIO`` containing the document, rewound to position 0 so it can
        be read or downloaded directly.
    """
    doc = docx.Document()
    # A single add_paragraph(text) kept everything in one paragraph, because
    # python-docx turns embedded newlines into line breaks. Separate paragraphs
    # preserve the paragraph structure (spacing, styles) in the output.
    for line in text.splitlines() or [""]:
        doc.add_paragraph(line)
    docx_buffer = io.BytesIO()
    doc.save(docx_buffer)
    docx_buffer.seek(0)
    return docx_buffer
# Session-state key under which the last successful pipeline result is kept.
_TRANSLATION_STATE_KEY = "translation_result"


def _result_inputs(uploaded_file, source_language, target_language):
    """Identify the inputs a result was computed from (file name/size, languages)."""
    return (uploaded_file.name, len(uploaded_file.getvalue()), source_language, target_language)


def _run_pipeline(uploaded_file, source_language, target_language,
                  gemma_tuple, nllb_tuple, mt5_tuple):
    """Extract, analyse, translate and grammar-correct the uploaded document.

    Returns:
        A dict holding the texts to display and the ``inputs`` they were
        computed from. Returns None, after showing a warning, when no text
        could be extracted.
    """
    with st.spinner("Processing document..."):
        text = extract_text_from_file(uploaded_file)
    if not text.strip():
        # e.g. a scanned PDF without a text layer: nothing for the models to do.
        st.warning("No text could be extracted from the document.")
        return None

    source_code = SUPPORTED_LANGUAGES[source_language]
    target_code = SUPPORTED_LANGUAGES[target_language]

    with st.spinner("Interpreting context..."):
        context_analysis = interpret_context(text, gemma_tuple)
    with st.spinner("Translating..."):
        # Translate the extracted text itself. interpret_context decodes the
        # whole generated sequence, which for a causal LM begins with the
        # English analysis prompt. Translating its output therefore put that
        # prompt and the model's commentary into the translated document.
        translated_text = translate_text(text, source_code, target_code, nllb_tuple)
    with st.spinner("Correcting grammar..."):
        corrected_text = correct_grammar(translated_text, target_code, mt5_tuple)

    # Never let an empty "correction" replace a non-empty translation.
    final_text = corrected_text if corrected_text.strip() else translated_text
    return {
        "inputs": _result_inputs(uploaded_file, source_language, target_language),
        "source_text": text,
        "context_analysis": context_analysis,
        "translation": final_text,
    }


def _render_result(result):
    """Display a pipeline result together with its download buttons."""
    st.text_area("Extracted Text:", value=result["source_text"], height=150)
    with st.expander("Context analysis"):
        st.text_area("Context Analysis:", value=result["context_analysis"], height=150)

    st.subheader("Translation Result:")
    st.text_area("Translated Text:", value=result["translation"], height=150)

    st.subheader("Download Translation:")
    st.download_button(
        label="Download as TXT",
        data=result["translation"].encode("utf-8"),
        file_name="translated_document.txt",
        mime="text/plain",
    )
    st.download_button(
        label="Download as DOCX",
        data=save_as_docx(result["translation"]),
        file_name="translated_document.docx",
        mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    )


def main():
    """Render the Streamlit UI and run the document-translation pipeline.

    Pipeline: extract text -> Gemma context analysis (shown to the user) ->
    NLLB translation of the extracted text -> MT5 grammar pass. The result
    is stored in ``st.session_state`` and re-rendered on later reruns while
    the uploaded file and language selection are unchanged.
    """
    st.title("Document Translation App")

    # Load models
    with st.spinner("Loading models... This may take a few minutes."):
        try:
            gemma_tuple, nllb_tuple, mt5_tuple = load_models()
        except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            st.error("Please check if the HF_TOKEN is valid and has the necessary permissions.")
            st.stop()

    # File upload
    uploaded_file = st.file_uploader(
        "Upload your document (PDF, DOCX, or TXT)",
        type=['pdf', 'docx', 'txt']
    )

    # Language selection
    col1, col2 = st.columns(2)
    with col1:
        source_language = st.selectbox(
            "Source Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=0
        )
    with col2:
        target_language = st.selectbox(
            "Target Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=1
        )

    if uploaded_file and st.button("Translate"):
        if source_language == target_language:
            st.warning("Source and target languages are the same. Please choose a different target language.")
        else:
            try:
                result = _run_pipeline(
                    uploaded_file, source_language, target_language,
                    gemma_tuple, nllb_tuple, mt5_tuple,
                )
            except Exception as e:
                st.error(f"An error occurred: {str(e)}")
            else:
                if result is not None:
                    st.session_state[_TRANSLATION_STATE_KEY] = result

    # Streamlit reruns the script on every interaction, including a download
    # click, and st.button is only True on its own click's run. Rendering from
    # session state keeps the result and the other download button on screen.
    # A stale result is not shown once the file or languages change.
    result = st.session_state.get(_TRANSLATION_STATE_KEY)
    if (result is not None and uploaded_file is not None
            and result["inputs"] == _result_inputs(uploaded_file, source_language, target_language)):
        _render_result(result)


if __name__ == "__main__":
    main()