# Hugging Face Spaces page header captured with this file (not program code): "Spaces: Build error"
import streamlit as st | |
import PyPDF2 | |
import docx | |
import io | |
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration | |
import torch | |
from pathlib import Path | |
import tempfile | |
from typing import Union, Tuple | |
import os | |
# Hugging Face access token, read once from the process environment.
HF_TOKEN = os.getenv('HF_TOKEN')

# Without a token the app cannot continue: report the problem in the UI and
# halt the script run (a missing variable and an empty one are treated alike).
if HF_TOKEN is None or HF_TOKEN == "":
    st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
    st.stop()
# UI display name -> language code handed to the translation step.
# The keys populate the source/target select boxes; the values are passed to
# translate_text() and correct_grammar().
SUPPORTED_LANGUAGES = {
    "English": "eng_Latn",
    "Hindi": "hin_Deva",
    "Marathi": "mar_Deva",
}

# Translation language code -> short code used to pick the grammar-correction
# prompt. Every value of SUPPORTED_LANGUAGES must appear here as a key.
MT5_LANG_CODES = {
    "eng_Latn": "en",
    "hin_Deva": "hi",
    "mar_Deva": "mr",
}
@st.cache_resource(show_spinner=False)
def load_models():
    """Load and cache the context, translation, and grammar-correction models.

    Streamlit re-executes the whole script on every widget interaction, so the
    loaded models are memoised with ``st.cache_resource``: they are created on
    the first call and the same objects are returned on later reruns.

    Returns:
        Three ``(tokenizer, model)`` pairs, in this order:
        Gemma (context interpretation), NLLB (translation), and
        mT5 (grammar correction).

    Raises:
        Whatever ``from_pretrained`` raises (for example when the token is
        rejected or a download fails); the caller is expected to handle it.
    """
    # Keyword arguments shared by all three model loads.
    # NOTE(review): float16 weights are requested regardless of the device the
    # model ends up on; confirm half precision behaves acceptably on the
    # target hardware (CPU-only hosts and T5-family models in particular).
    model_kwargs = {
        "device_map": "auto",
        "torch_dtype": torch.float16,
        "token": HF_TOKEN,
    }

    # Gemma: context interpretation
    gemma_name = "google/gemma-2b"
    gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_name, token=HF_TOKEN)
    gemma_model = AutoModelForCausalLM.from_pretrained(gemma_name, **model_kwargs)

    # NLLB: translation
    nllb_name = "facebook/nllb-200-distilled-600M"
    nllb_tokenizer = AutoTokenizer.from_pretrained(nllb_name, token=HF_TOKEN)
    nllb_model = AutoModelForSeq2SeqLM.from_pretrained(nllb_name, **model_kwargs)

    # mT5: grammar correction
    mt5_name = "google/mt5-small"
    mt5_tokenizer = AutoTokenizer.from_pretrained(mt5_name, token=HF_TOKEN)
    mt5_model = T5ForConditionalGeneration.from_pretrained(mt5_name, **model_kwargs)

    return (
        (gemma_tokenizer, gemma_model),
        (nllb_tokenizer, nllb_model),
        (mt5_tokenizer, mt5_model),
    )
def extract_text_from_file(uploaded_file) -> str:
    """Extract the text content of an uploaded file, dispatching on its extension.

    Args:
        uploaded_file: Object exposing ``name`` (the file name) and, for text
            files, ``getvalue()`` returning the raw bytes. PDF and DOCX uploads
            are passed through unchanged to the dedicated extractors.

    Returns:
        The extracted text.

    Raises:
        ValueError: If the extension is not ``.pdf``, ``.docx`` or ``.txt``
            (compared case-insensitively).
        UnicodeDecodeError: If a ``.txt`` upload is not valid UTF-8.
    """
    file_extension = Path(uploaded_file.name).suffix.lower()

    if file_extension == '.pdf':
        return extract_from_pdf(uploaded_file)
    if file_extension == '.docx':
        return extract_from_docx(uploaded_file)
    if file_extension == '.txt':
        # "utf-8-sig" decodes plain UTF-8 exactly like "utf-8" but also drops a
        # leading byte-order mark, which would otherwise leak into the text as
        # an invisible U+FEFF character.
        return uploaded_file.getvalue().decode('utf-8-sig')

    raise ValueError(f"Unsupported file format: {file_extension}")
def extract_from_pdf(file) -> str:
    """Return the text of every page of a PDF, one page per line group.

    Each page's extracted text is followed by a newline; the combined result
    has leading and trailing whitespace removed.
    """
    reader = PyPDF2.PdfReader(file)
    page_chunks = [page.extract_text() + "\n" for page in reader.pages]
    return "".join(page_chunks).strip()
def extract_from_docx(file) -> str:
    """Return the text of all body paragraphs of a DOCX file.

    Each paragraph's text is followed by a newline; the combined result has
    leading and trailing whitespace removed.
    """
    document = docx.Document(file)
    paragraph_chunks = [para.text + "\n" for para in document.paragraphs]
    return "".join(paragraph_chunks).strip()
def interpret_context(text: str, gemma_tuple: Tuple) -> str:
    """Run the Gemma model over a context-analysis prompt built from ``text``.

    Args:
        text: The document text to analyse.
        gemma_tuple: ``(tokenizer, model)`` pair for the causal language model.

    Returns:
        The decoded output sequence with special tokens removed.

        NOTE(review): for a causal LM the generated sequence normally starts
        with the prompt itself, so the returned string presumably contains the
        instruction text and the original document followed by the model's
        continuation. The caller feeds this string to the translator; confirm
        that this is the intended input for translation.
    """
    tokenizer, model = gemma_tuple

    prompt = (
        "Analyze the following text for context and cultural nuances,\n"
        "maintaining the core meaning while identifying any idiomatic expressions or\n"
        f"cultural references: {text}"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        # Budget the *generated* tokens. The previous ``max_length=1024`` also
        # counted the prompt, so a long document left no room to generate.
        max_new_tokens=1024,
        # No ``temperature``: sampling is not enabled, so decoding is greedy
        # and a temperature value would be ignored (with a warning).
        pad_token_id=tokenizer.eos_token_id,
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
    """Translate ``text`` from ``source_lang`` to ``target_lang`` with NLLB.

    Args:
        text: Text to translate.
        source_lang: NLLB code of the input language (e.g. ``"hin_Deva"``).
        target_lang: NLLB code of the output language (e.g. ``"eng_Latn"``).
        nllb_tuple: ``(tokenizer, model)`` pair for the NLLB model.

    Returns:
        The translated text with special tokens removed.

    Raises:
        ValueError: If ``target_lang`` is not a token known to the tokenizer.
    """
    tokenizer, model = nllb_tuple

    # Tell the tokenizer which language the input is in; previously
    # ``source_lang`` was accepted but never used, so the tokenizer's default
    # source language was applied to every document.
    # NOTE(review): this mutates the tokenizer object; if it is shared between
    # concurrent sessions, confirm that is acceptable.
    tokenizer.src_lang = source_lang
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    # Language codes are ordinary vocabulary tokens, so look the id up through
    # the generic token API instead of the ``lang_code_to_id`` mapping, which
    # is not available in recent transformers releases.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(target_lang)
    if forced_bos_token_id is None or forced_bos_token_id == tokenizer.unk_token_id:
        raise ValueError(f"Unknown target language code: {target_lang}")

    outputs = model.generate(
        **inputs,
        forced_bos_token_id=forced_bos_token_id,
        max_length=1024,
        # Plain beam search: sampling is off, so a temperature would be ignored.
        num_beams=5,
    )

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
    """Run the mT5 model over a language-specific grammar-correction prompt.

    Args:
        text: Text to correct.
        target_lang: Translation language code; must be a key of
            ``MT5_LANG_CODES`` (a ``KeyError`` is raised otherwise).
        mt5_tuple: ``(tokenizer, model)`` pair for the mT5 model.

    Returns:
        The decoded model output with any echoed instruction prefix removed
        and surrounding whitespace stripped.

    NOTE(review): the prompt prefixes presume the loaded checkpoint has been
    trained to follow them; confirm the model in use actually supports this.
    """
    mt5_tokenizer, mt5_model = mt5_tuple
    short_code = MT5_LANG_CODES[target_lang]

    # Instruction prefix per language; the model input is "<prefix> <text>".
    instruction_by_lang = {
        'en': "grammar:",
        'hi': "व्याकरण सुधार:",
        'mr': "व्याकरण सुधारणा:",
    }
    model_input = f"{instruction_by_lang[short_code]} {text}"

    encoded = mt5_tokenizer(
        model_input,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    ).to(mt5_model.device)

    generated = mt5_model.generate(
        **encoded,
        max_length=512,
        num_beams=5,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )
    cleaned = mt5_tokenizer.decode(generated[0], skip_special_tokens=True)

    # Remove every known instruction prefix the model may have echoed back.
    for instruction in instruction_by_lang.values():
        cleaned = cleaned.replace(instruction, "")
    return cleaned.strip()
def save_as_docx(text: str) -> io.BytesIO:
    """Build an in-memory DOCX document containing ``text`` as one paragraph.

    Returns:
        A ``BytesIO`` holding the serialized document, rewound to position 0
        so it can be read from the start.
    """
    buffer = io.BytesIO()

    document = docx.Document()
    document.add_paragraph(text)
    document.save(buffer)

    buffer.seek(0)
    return buffer
def main():
    """Render the Streamlit UI and run the document translation pipeline.

    Flow: load the models, let the user upload a PDF/DOCX/TXT file and pick a
    source and target language, then on "Translate" extract the text, run the
    context-interpretation step, translate, and grammar-correct the result.

    The finished translation is kept in ``st.session_state`` because Streamlit
    reruns the script on every widget interaction: clicking a download button
    triggers a rerun in which the "Translate" button reads as not pressed, so
    results rendered only inside the button branch would vanish.
    """
    st.title("Document Translation App")

    # Load models
    with st.spinner("Loading models... This may take a few minutes."):
        try:
            gemma_tuple, nllb_tuple, mt5_tuple = load_models()
        except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            st.error("Please check if the HF_TOKEN is valid and has the necessary permissions.")
            st.stop()

    # File upload
    uploaded_file = st.file_uploader(
        "Upload your document (PDF, DOCX, or TXT)",
        type=['pdf', 'docx', 'txt'],
    )

    # Language selection
    col1, col2 = st.columns(2)
    with col1:
        source_language = st.selectbox(
            "Source Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=0,
        )
    with col2:
        target_language = st.selectbox(
            "Target Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=1,
        )

    # Session-state slot holding the most recent successful translation.
    result_key = "translation_result"

    if uploaded_file and st.button("Translate"):
        # Forget the previous result so a failed run never shows stale output.
        if result_key in st.session_state:
            del st.session_state[result_key]

        if source_language == target_language:
            st.warning("Source and target languages must be different.")
        else:
            try:
                with st.spinner("Processing document..."):
                    # Extract text
                    text = extract_text_from_file(uploaded_file)

                    if not text.strip():
                        # e.g. a scanned PDF with no text layer: nothing to translate.
                        st.warning("No text could be extracted from the uploaded document.")
                    else:
                        st.text_area("Extracted Text:", value=text, height=150)

                        # Interpret context
                        # NOTE(review): the translator receives the output of
                        # the interpretation step, not the raw extracted text;
                        # confirm that output is what should be translated.
                        with st.spinner("Interpreting context..."):
                            interpreted_text = interpret_context(text, gemma_tuple)

                        # Translate
                        with st.spinner("Translating..."):
                            translated_text = translate_text(
                                interpreted_text,
                                SUPPORTED_LANGUAGES[source_language],
                                SUPPORTED_LANGUAGES[target_language],
                                nllb_tuple,
                            )

                        # Grammar correction
                        with st.spinner("Correcting grammar..."):
                            corrected_text = correct_grammar(
                                translated_text,
                                SUPPORTED_LANGUAGES[target_language],
                                mt5_tuple,
                            )

                        # Remember which upload the result belongs to so it is
                        # not shown next to a different file later on.
                        st.session_state[result_key] = {
                            "file_name": uploaded_file.name,
                            "text": corrected_text,
                        }
            except Exception as e:
                st.error(f"An error occurred: {str(e)}")

    # Render the stored result on every rerun (including the rerun caused by
    # clicking a download button), as long as it matches the current upload.
    result = st.session_state.get(result_key)
    if uploaded_file and result and result["file_name"] == uploaded_file.name:
        corrected_text = result["text"]
        try:
            # Display result
            st.subheader("Translation Result:")
            st.text_area("Translated Text:", value=corrected_text, height=150)

            # Download options
            st.subheader("Download Translation:")

            # Text file download
            st.download_button(
                label="Download as TXT",
                data=corrected_text.encode("utf-8"),
                file_name="translated_document.txt",
                mime="text/plain",
            )

            # DOCX file download
            docx_buffer = save_as_docx(corrected_text)
            st.download_button(
                label="Download as DOCX",
                data=docx_buffer,
                file_name="translated_document.docx",
                mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            )
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()