gauravchand11 committed on
Commit
67419d9
·
verified ·
1 Parent(s): 5d2f328

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +220 -0
app.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import PyPDF2
3
+ import docx
4
+ import io
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
6
+ import torch
7
+ from pathlib import Path
8
+ import tempfile
9
+ from typing import Union, Tuple
10
+ import language_tool_python
11
+
12
# Initialize language tool for grammar correction
@st.cache_resource
def _load_language_tool():
    """Create the English LanguageTool instance once per server process.

    Streamlit re-executes the script body on every user interaction, so a
    bare module-level constructor call would build a new tool on each rerun.
    Routing the construction through ``st.cache_resource`` reuses one instance.

    Returns:
        The ``language_tool_python.LanguageTool`` configured for ``'en-US'``.
    """
    return language_tool_python.LanguageTool('en-US')


# Kept as a module-level name because correct_grammar() reads this global.
language_tool = _load_language_tool()
14
+
15
# Languages offered in the UI, keyed by display name; each value is the
# language code that is handed to the translation step.
SUPPORTED_LANGUAGES = dict(
    English="eng_Latn",
    Hindi="hin_Deva",
    Marathi="mar_Deva",
)
21
+
22
@st.cache_resource
def load_models():
    """Load the Gemma and NLLB tokenizer/model pairs, cached by Streamlit.

    Returns:
        A 2-tuple ``((gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model))``.
        The first pair is used for context interpretation, the second for
        translation.
    """
    gemma_checkpoint = "google/gemma-2b"
    nllb_checkpoint = "facebook/nllb-200-distilled-600M"

    # Both models are loaded with the same placement and precision settings.
    model_kwargs = {"device_map": "auto", "torch_dtype": torch.float16}

    # Gemma: causal LM used for context interpretation.
    gemma_pair = (
        AutoTokenizer.from_pretrained(gemma_checkpoint),
        AutoModelForCausalLM.from_pretrained(gemma_checkpoint, **model_kwargs),
    )

    # NLLB: sequence-to-sequence model used for translation.
    nllb_pair = (
        AutoTokenizer.from_pretrained(nllb_checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(nllb_checkpoint, **model_kwargs),
    )

    return gemma_pair, nllb_pair
42
+
43
def extract_text_from_file(uploaded_file) -> str:
    """Extract text content from an uploaded file based on its extension.

    Args:
        uploaded_file: File-like upload exposing ``name`` and ``getvalue()``.
            The extension of ``name`` (case-insensitive) selects the extractor.

    Returns:
        The extracted text.

    Raises:
        ValueError: If the extension is not ``.pdf``, ``.docx`` or ``.txt``.
        UnicodeDecodeError: If a ``.txt`` file is not valid UTF-8.
    """
    file_extension = Path(uploaded_file.name).suffix.lower()

    if file_extension == '.pdf':
        return extract_from_pdf(uploaded_file)
    if file_extension == '.docx':
        return extract_from_docx(uploaded_file)
    if file_extension == '.txt':
        # 'utf-8-sig' decodes plain UTF-8 identically but also drops a leading
        # byte-order mark, which would otherwise leak into the text as U+FEFF.
        return uploaded_file.getvalue().decode('utf-8-sig')

    raise ValueError(f"Unsupported file format: {file_extension}")
55
+
56
def extract_from_pdf(file) -> str:
    """Extract text from a PDF file.

    Args:
        file: A path or binary file-like object accepted by ``PyPDF2.PdfReader``.

    Returns:
        The text of all pages joined by newlines, with leading and trailing
        whitespace stripped. Pages that yield no text contribute an empty line.
    """
    pdf_reader = PyPDF2.PdfReader(file)
    # `or ""` guards against a page yielding no text object (e.g. image-only
    # pages on some library versions), which would otherwise break the join.
    page_texts = [page.extract_text() or "" for page in pdf_reader.pages]
    return "\n".join(page_texts).strip()
63
+
64
def extract_from_docx(file) -> str:
    """Return the paragraph text of a DOCX file, one paragraph per line.

    The result has leading and trailing whitespace stripped.
    """
    document = docx.Document(file)
    paragraph_lines = [para.text for para in document.paragraphs]
    return "\n".join(paragraph_lines).strip()
71
+
72
def interpret_context(text: str, gemma_tuple: Tuple, max_new_tokens: int = 512) -> str:
    """Use the Gemma model to interpret context and regional nuances.

    Args:
        text: Source document text to analyze.
        gemma_tuple: ``(tokenizer, model)`` pair for the causal language model.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt.

    Returns:
        The decoded first output sequence with special tokens removed.

    NOTE(review): decoder-only models typically return the prompt tokens
    followed by the continuation, so the returned string likely starts with
    the instruction prompt and the original text. Confirm that this is what
    the downstream translation step should receive.
    """
    tokenizer, model = gemma_tuple

    prompt = (
        "Analyze the following text for context and cultural nuances,\n"
        "maintaining the core meaning while identifying any idiomatic expressions or\n"
        f"cultural references: {text}"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        # Bound only the generated part. A total `max_length` also counts the
        # prompt, so a long document would leave no room for any output.
        max_new_tokens=max_new_tokens,
        # No `temperature`: sampling is not enabled, so the value would be
        # ignored and decoding stays deterministic.
        pad_token_id=tokenizer.eos_token_id
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
90
+
91
def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
    """Translate text using the NLLB model.

    Args:
        text: Text to translate.
        source_lang: Language code of ``text`` (e.g. ``'hin_Deva'``).
        target_lang: Language code to translate into (e.g. ``'eng_Latn'``).
        nllb_tuple: ``(tokenizer, model)`` pair for the seq2seq model.

    Returns:
        The first decoded output sequence with special tokens removed.
    """
    tokenizer, model = nllb_tuple

    # Tell the tokenizer which language the input is written in. Without this
    # `source_lang` was never used and the tokenizer kept its default source
    # language regardless of the user's selection.
    tokenizer.src_lang = source_lang

    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    # Look the target language token up through the vocabulary. The older
    # `lang_code_to_id` mapping is deprecated/removed in newer transformers
    # releases.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(target_lang)

    outputs = model.generate(
        **inputs,
        forced_bos_token_id=forced_bos_token_id,
        max_length=1024,
        # No `temperature`: it has no effect under beam search without sampling.
        num_beams=5
    )

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
108
+
109
def correct_grammar(text: str, target_lang: str) -> str:
    """Correct grammar in the translated text where a checker is available.

    Args:
        text: Translated text.
        target_lang: Language code of ``text``; only ``'eng_Latn'`` is corrected.

    Returns:
        The corrected text for English, otherwise ``text`` unchanged.
    """
    # Only English is handled by LanguageTool here. Other languages are
    # returned as-is (you may want to add specific grammar correction for
    # Hindi and Marathi in a production environment). Empty input needs no
    # round trip to the checker.
    if target_lang != 'eng_Latn' or not text:
        return text

    # The separate `check` call was dropped: its result was never used, so the
    # text was being analyzed twice.
    return language_tool.correct(text)
120
+
121
def save_as_docx(text: str) -> io.BytesIO:
    """Build a DOCX document holding ``text`` and return it as an in-memory buffer.

    The whole text is added as a single paragraph. The returned buffer is
    rewound to the start so it can be read immediately.
    """
    buffer = io.BytesIO()

    document = docx.Document()
    document.add_paragraph(text)
    document.save(buffer)

    # Rewind so consumers read from the first byte.
    buffer.seek(0)
    return buffer
131
+
132
def _render_downloads(corrected_text: str) -> None:
    """Render the TXT and DOCX download buttons for the final text."""
    st.subheader("Download Translation:")

    # Text file download
    st.download_button(
        label="Download as TXT",
        data=corrected_text.encode("utf-8"),
        file_name="translated_document.txt",
        mime="text/plain"
    )

    # DOCX file download
    st.download_button(
        label="Download as DOCX",
        data=save_as_docx(corrected_text),
        file_name="translated_document.docx",
        mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    )


def main():
    """Run the Streamlit UI: upload, interpret, translate, correct, download.

    Any exception raised while processing the document is shown to the user
    through ``st.error`` instead of propagating.
    """
    st.title("Document Translation App")

    # Load models
    with st.spinner("Loading models... This may take a few minutes."):
        gemma_tuple, nllb_tuple = load_models()

    # File upload
    uploaded_file = st.file_uploader(
        "Upload your document (PDF, DOCX, or TXT)",
        type=['pdf', 'docx', 'txt']
    )

    # Language selection
    col1, col2 = st.columns(2)
    with col1:
        source_language = st.selectbox(
            "Source Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=0
        )

    with col2:
        target_language = st.selectbox(
            "Target Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=1
        )

    # The button is only rendered once a file has been uploaded.
    if not (uploaded_file and st.button("Translate")):
        return

    # Translating a language into itself is almost certainly a mis-selection.
    if source_language == target_language:
        st.warning("Source and target languages are the same. Please choose a different target language.")
        return

    # NOTE(review): the results below are drawn only on the run triggered by
    # the "Translate" click. Clicking a download button presumably triggers a
    # rerun in which the button reads False, so the results would disappear.
    # Consider keeping them in st.session_state. Confirm against the
    # Streamlit version in use.
    try:
        with st.spinner("Processing document..."):
            # Extract text
            text = extract_text_from_file(uploaded_file)

            # Scanned or empty documents yield no text. Stop here instead of
            # running the models on nothing but the instruction prompt.
            if not text.strip():
                st.warning("No text could be extracted from the uploaded document.")
                return

            st.text_area("Extracted Text:", value=text, height=150)

            # Interpret context
            with st.spinner("Interpreting context..."):
                interpreted_text = interpret_context(text, gemma_tuple)

            # Translate
            with st.spinner("Translating..."):
                translated_text = translate_text(
                    interpreted_text,
                    SUPPORTED_LANGUAGES[source_language],
                    SUPPORTED_LANGUAGES[target_language],
                    nllb_tuple
                )

            # Grammar correction
            with st.spinner("Correcting grammar..."):
                corrected_text = correct_grammar(
                    translated_text,
                    SUPPORTED_LANGUAGES[target_language]
                )

            # Display result
            st.subheader("Translation Result:")
            st.text_area("Translated Text:", value=corrected_text, height=150)

            # Download options
            _render_downloads(corrected_text)

    except Exception as e:
        # Top-level UI boundary: surface the failure to the user.
        st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()