gauravchand11 committed on
Commit
698647f
·
verified ·
1 Parent(s): 90c759f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -142
app.py CHANGED
@@ -12,53 +12,53 @@ import sys
12
  from datetime import datetime, timezone
13
  import warnings
14
 
15
- # Filter out specific warnings
16
- warnings.filterwarnings('ignore', category=UserWarning, module='transformers.convert_slow_tokenizer')
17
- warnings.filterwarnings('ignore', category=UserWarning, module='transformers.tokenization_utils_base')
18
 
19
- # Custom styling
20
  st.set_page_config(
21
  page_title="Document Translation App",
22
  page_icon="🌐",
23
  layout="wide"
24
  )
25
 
26
- # Display current information in sidebar
27
- current_time = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
28
- st.sidebar.markdown("""
29
  ### System Information
30
- **Current UTC Time:** {}
31
- **User:** {}
32
- """.format(current_time, os.environ.get('USER', 'gauravchand')))
33
 
34
- # Get Hugging Face token from environment variables
35
  HF_TOKEN = os.environ.get('HF_TOKEN')
36
  if not HF_TOKEN:
37
  st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
38
  st.stop()
39
 
40
- # Define supported languages and their codes
41
  SUPPORTED_LANGUAGES = {
42
  'English': 'eng_Latn',
43
  'Hindi': 'hin_Deva',
44
  'Marathi': 'mar_Deva'
45
  }
46
 
47
- # Language codes for MT5
48
  MT5_LANG_CODES = {
49
  'eng_Latn': 'en',
50
  'hin_Deva': 'hi',
51
  'mar_Deva': 'mr'
52
  }
53
 
 
 
 
 
54
  @st.cache_resource
55
  def load_models():
56
  """Load and cache the translation and context interpretation models."""
57
  try:
58
- # Set device
59
  device = "cuda" if torch.cuda.is_available() else "cpu"
60
 
61
- # Load Gemma model for context interpretation
62
  gemma_tokenizer = AutoTokenizer.from_pretrained(
63
  "google/gemma-2b",
64
  token=HF_TOKEN,
@@ -72,11 +72,11 @@ def load_models():
72
  trust_remote_code=True
73
  )
74
 
75
- # Load NLLB model for translation
76
  nllb_tokenizer = AutoTokenizer.from_pretrained(
77
  "facebook/nllb-200-distilled-600M",
78
  token=HF_TOKEN,
79
- src_lang="eng_Latn",
80
  trust_remote_code=True
81
  )
82
  nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
@@ -87,21 +87,20 @@ def load_models():
87
  trust_remote_code=True
88
  )
89
 
90
- # Load MT5 model for grammar correction
91
  mt5_tokenizer = AutoTokenizer.from_pretrained(
92
- "google/mt5-base", # Changed to base model for better performance
93
  token=HF_TOKEN,
94
  trust_remote_code=True
95
  )
96
  mt5_model = MT5ForConditionalGeneration.from_pretrained(
97
- "google/mt5-base", # Changed to base model for better performance
98
  token=HF_TOKEN,
99
  torch_dtype=torch.float16,
100
  device_map="auto" if torch.cuda.is_available() else None,
101
  trust_remote_code=True
102
  )
103
 
104
- # Move models to device if not using device_map="auto"
105
  if not torch.cuda.is_available():
106
  gemma_model = gemma_model.to(device)
107
  nllb_model = nllb_model.to(device)
@@ -111,90 +110,11 @@ def load_models():
111
 
112
  except Exception as e:
113
  st.error(f"Error loading models: {str(e)}")
114
- st.error("Detailed error information:")
115
  st.error(f"Python version: {sys.version}")
116
  st.error(f"PyTorch version: {torch.__version__}")
117
  raise e
118
 
119
def extract_text_from_file(uploaded_file) -> str:
    """Return the text of an uploaded file, picking a reader by file extension.

    Args:
        uploaded_file: Object exposing ``.name`` (used for the extension) and,
            for ``.txt`` files, ``.getvalue()`` returning bytes.

    Returns:
        The extracted text. ``.txt`` content is decoded as UTF-8; ``.pdf`` and
        ``.docx`` are delegated to ``extract_from_pdf`` / ``extract_from_docx``.

    Raises:
        ValueError: If the (lower-cased) extension is not .pdf, .docx or .txt.
    """
    suffix = Path(uploaded_file.name).suffix.lower()

    if suffix == '.pdf':
        return extract_from_pdf(uploaded_file)
    if suffix == '.docx':
        return extract_from_docx(uploaded_file)
    if suffix == '.txt':
        return uploaded_file.getvalue().decode('utf-8')

    # Anything else is rejected explicitly rather than guessed at.
    raise ValueError(f"Unsupported file format: {suffix}")
131
-
132
def extract_from_pdf(file) -> str:
    """Extract the text of every page of a PDF, one page per line group.

    Args:
        file: File-like object (or path) accepted by ``PyPDF2.PdfReader``.

    Returns:
        The page texts, each followed by a newline, concatenated and then
        stripped of leading/trailing whitespace.
    """
    reader = PyPDF2.PdfReader(file)
    # Each page contributes its text plus a trailing newline.
    combined = "".join(page.extract_text() + "\n" for page in reader.pages)
    return combined.strip()
139
-
140
def extract_from_docx(file) -> str:
    """Extract the paragraph text of a DOCX document.

    Args:
        file: File-like object (or path) accepted by ``docx.Document``.

    Returns:
        The text of each paragraph followed by a newline, concatenated and
        then stripped of leading/trailing whitespace.
    """
    document = docx.Document(file)
    # Only document.paragraphs is read here.
    combined = "".join(paragraph.text + "\n" for paragraph in document.paragraphs)
    return combined.strip()
147
-
148
def batch_process_text(text: str, max_length: int = 512) -> list:
    """Split text into whitespace-delimited batches of bounded character length.

    Words are never split. Each batch is the longest run of consecutive words
    whose space-joined length does not exceed ``max_length``. A single word
    longer than ``max_length`` becomes a batch on its own.

    Args:
        text: Input text; it is split on any whitespace.
        max_length: Maximum number of characters per batch (characters, not
            model tokens).

    Returns:
        A list of non-empty batch strings, in order. Empty or whitespace-only
        input yields an empty list.
    """
    batches = []
    current_batch = []
    current_length = 0  # length of " ".join(current_batch)

    for word in text.split():
        # Length the batch would have if this word were appended to it.
        if current_batch:
            projected = current_length + 1 + len(word)
        else:
            projected = len(word)

        # Flush only a non-empty batch, so an oversized first word can never
        # produce an empty "" batch.
        if current_batch and projected > max_length:
            batches.append(" ".join(current_batch))
            current_batch = [word]
            current_length = len(word)
        else:
            current_batch.append(word)
            current_length = projected

    if current_batch:
        batches.append(" ".join(current_batch))

    return batches
168
-
169
@torch.no_grad()
def interpret_context(text: str, gemma_tuple: Tuple) -> str:
    """Run each text batch through the Gemma model and return its continuations.

    The text is split with ``batch_process_text``; each batch is embedded in a
    fixed prompt, generated from, and only the newly generated part is kept.

    Args:
        text: Source text to interpret.
        gemma_tuple: ``(tokenizer, model)`` pair for the Gemma model.

    Returns:
        The generated text for every batch, joined with single spaces.
    """
    tokenizer, model = gemma_tuple
    interpreted_batches = []

    for batch in batch_process_text(text):
        prompt = f"""Analyze and maintain the core meaning of this text: {batch}"""

        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        prompt_token_count = inputs["input_ids"].shape[1]

        outputs = model.generate(
            **inputs,
            # max_length counts the prompt too, so a prompt near 512 tokens left
            # no room to generate; bound only the newly generated tokens.
            max_new_tokens=512,
            do_sample=True,
            temperature=0.3,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
        )

        # Drop the echoed prompt by token position instead of str.replace(),
        # which silently fails when decoding does not reproduce the prompt
        # verbatim. NOTE(review): assumes a decoder-only model whose output
        # starts with the input tokens - confirm against load_models().
        generated_tokens = outputs[0][prompt_token_count:]
        interpreted_text = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
        interpreted_batches.append(interpreted_text)

    return " ".join(interpreted_batches)
198
 
199
  @torch.no_grad()
200
  def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
@@ -204,13 +124,20 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tu
204
  batches = batch_process_text(text)
205
  translated_batches = []
206
 
 
 
 
207
  for batch in batches:
 
208
  inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
209
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
210
 
 
 
 
211
  outputs = model.generate(
212
  **inputs,
213
- forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
214
  max_length=512,
215
  do_sample=True,
216
  temperature=0.7,
@@ -225,22 +152,20 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tu
225
 
226
  @torch.no_grad()
227
  def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
228
- """Correct grammar using MT5 model for all supported languages."""
229
  tokenizer, model = mt5_tuple
230
  lang_code = MT5_LANG_CODES[target_lang]
231
 
232
- # Language-specific prompts for grammar correction
233
  prompts = {
234
  'en': "Fix grammar: ",
235
- 'hi': "व्याकरण: ",
236
- 'mr': "व्याकरण: "
237
  }
238
 
239
  batches = batch_process_text(text)
240
  corrected_batches = []
241
 
242
  for batch in batches:
243
- # Prepare input with target language prefix
244
  input_text = f"{prompts[lang_code]}{batch}"
245
  inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
246
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
@@ -251,29 +176,20 @@ def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
251
  num_beams=5,
252
  length_penalty=1.0,
253
  early_stopping=True,
254
- do_sample=False # Disable sampling for more stable output
 
255
  )
256
 
257
  corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
258
- # Clean up the output
259
  for prefix in prompts.values():
260
  corrected_text = corrected_text.replace(prefix, "")
261
- corrected_text = corrected_text.replace("<extra_id_0>", "").replace("<extra_id_1>", "").strip()
 
 
262
  corrected_batches.append(corrected_text)
263
 
264
  return " ".join(corrected_batches)
265
 
266
def save_as_docx(text: str) -> io.BytesIO:
    """Build a DOCX document containing ``text`` and return it as a buffer.

    Args:
        text: Text placed in a single paragraph of the new document.

    Returns:
        An ``io.BytesIO`` holding the saved document, rewound to position 0 so
        it can be read from the start.
    """
    buffer = io.BytesIO()

    document = docx.Document()
    document.add_paragraph(text)
    document.save(buffer)

    # Rewind so the caller reads from the beginning of the document.
    buffer.seek(0)
    return buffer
276
-
277
  def main():
278
  st.title("🌐 Document Translation App")
279
 
@@ -283,7 +199,6 @@ def main():
283
  gemma_tuple, nllb_tuple, mt5_tuple = load_models()
284
  except Exception as e:
285
  st.error(f"Error loading models: {str(e)}")
286
- st.error("Please check if the HF_TOKEN is valid and has the necessary permissions.")
287
  return
288
 
289
  # File upload
@@ -312,39 +227,34 @@ def main():
312
  try:
313
  progress_bar = st.progress(0)
314
 
315
- # Extract text
316
- text = extract_text_from_file(uploaded_file)
317
- progress_bar.progress(20)
318
-
319
- # Interpret context
320
- with st.spinner("Interpreting context..."):
321
  interpreted_text = interpret_context(text, gemma_tuple)
322
- progress_bar.progress(40)
323
-
324
- # Translate
325
- with st.spinner("Translating..."):
326
  translated_text = translate_text(
327
  interpreted_text,
328
  SUPPORTED_LANGUAGES[source_language],
329
  SUPPORTED_LANGUAGES[target_language],
330
  nllb_tuple
331
  )
332
- progress_bar.progress(70)
333
-
334
- # Grammar correction
335
- with st.spinner("Correcting grammar..."):
336
- corrected_text = correct_grammar(
337
  translated_text,
338
  SUPPORTED_LANGUAGES[target_language],
339
  mt5_tuple
340
  )
341
- progress_bar.progress(90)
342
 
343
  # Display result
344
  st.markdown("### Translation Result")
345
  st.text_area(
346
  label="Translated Text",
347
- value=corrected_text,
348
  height=200,
349
  key="translation_result"
350
  )
@@ -356,7 +266,7 @@ def main():
356
  with col1:
357
  # Text file download
358
  text_buffer = io.BytesIO()
359
- text_buffer.write(corrected_text.encode())
360
  text_buffer.seek(0)
361
 
362
  st.download_button(
@@ -368,7 +278,7 @@ def main():
368
 
369
  with col2:
370
  # DOCX file download
371
- docx_buffer = save_as_docx(corrected_text)
372
  st.download_button(
373
  label="Download as DOCX",
374
  data=docx_buffer,
@@ -380,6 +290,6 @@ def main():
380
 
381
  except Exception as e:
382
  st.error(f"An error occurred: {str(e)}")
383
-
384
  if __name__ == "__main__":
385
  main()
 
12
  from datetime import datetime, timezone
13
  import warnings
14
 
15
+ # Filter warnings
16
+ warnings.filterwarnings('ignore', category=UserWarning)
 
17
 
18
+ # Page config
19
  st.set_page_config(
20
  page_title="Document Translation App",
21
  page_icon="🌐",
22
  layout="wide"
23
  )
24
 
25
+ # Display system info
26
+ st.sidebar.markdown(f"""
 
27
  ### System Information
28
+ **Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
29
+ **User:** {os.environ.get('USER', 'gauravchand')}
30
+ """)
31
 
32
+ # Get Hugging Face token
33
  HF_TOKEN = os.environ.get('HF_TOKEN')
34
  if not HF_TOKEN:
35
  st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
36
  st.stop()
37
 
38
+ # Language configurations
39
  SUPPORTED_LANGUAGES = {
40
  'English': 'eng_Latn',
41
  'Hindi': 'hin_Deva',
42
  'Marathi': 'mar_Deva'
43
  }
44
 
 
45
  MT5_LANG_CODES = {
46
  'eng_Latn': 'en',
47
  'hin_Deva': 'hi',
48
  'mar_Deva': 'mr'
49
  }
50
 
51
def get_nllb_lang_token(lang_code: str) -> str:
    """Return the NLLB tokenizer token for a language code.

    The NLLB tokenizer's language tokens are the bare FLORES-200 codes
    (e.g. ``"hin_Deva"``), so the code is returned unchanged. Wrapping it in
    underscores yields a string that is not in the vocabulary, which makes
    ``convert_tokens_to_ids`` fall back to the unknown-token id and breaks
    ``forced_bos_token_id``.

    Args:
        lang_code: FLORES-200 language code such as ``"eng_Latn"``.

    Returns:
        The token string to pass to ``tokenizer.convert_tokens_to_ids``.
    """
    return lang_code
54
+
55
  @st.cache_resource
56
  def load_models():
57
  """Load and cache the translation and context interpretation models."""
58
  try:
 
59
  device = "cuda" if torch.cuda.is_available() else "cpu"
60
 
61
+ # Load Gemma model
62
  gemma_tokenizer = AutoTokenizer.from_pretrained(
63
  "google/gemma-2b",
64
  token=HF_TOKEN,
 
72
  trust_remote_code=True
73
  )
74
 
75
+ # Load NLLB model
76
  nllb_tokenizer = AutoTokenizer.from_pretrained(
77
  "facebook/nllb-200-distilled-600M",
78
  token=HF_TOKEN,
79
+ use_fast=False, # Use slow tokenizer for better compatibility
80
  trust_remote_code=True
81
  )
82
  nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
 
87
  trust_remote_code=True
88
  )
89
 
90
+ # Load MT5 model
91
  mt5_tokenizer = AutoTokenizer.from_pretrained(
92
+ "google/mt5-base",
93
  token=HF_TOKEN,
94
  trust_remote_code=True
95
  )
96
  mt5_model = MT5ForConditionalGeneration.from_pretrained(
97
+ "google/mt5-base",
98
  token=HF_TOKEN,
99
  torch_dtype=torch.float16,
100
  device_map="auto" if torch.cuda.is_available() else None,
101
  trust_remote_code=True
102
  )
103
 
 
104
  if not torch.cuda.is_available():
105
  gemma_model = gemma_model.to(device)
106
  nllb_model = nllb_model.to(device)
 
110
 
111
  except Exception as e:
112
  st.error(f"Error loading models: {str(e)}")
 
113
  st.error(f"Python version: {sys.version}")
114
  st.error(f"PyTorch version: {torch.__version__}")
115
  raise e
116
 
117
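+ # NOTE: the helper functions removed above (extract_text_from_file, extract_from_pdf, extract_from_docx, batch_process_text, interpret_context, save_as_docx) must still be defined here; this placeholder comment does not define them, and the remaining code still calls them.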
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  @torch.no_grad()
120
  def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
 
124
  batches = batch_process_text(text)
125
  translated_batches = []
126
 
127
+ # Get target language token
128
+ target_lang_token = get_nllb_lang_token(target_lang)
129
+
130
  for batch in batches:
131
+ # Prepare input text
132
  inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
133
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
134
 
135
+ # Get target language token ID
136
+ target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)
137
+
138
  outputs = model.generate(
139
  **inputs,
140
+ forced_bos_token_id=target_lang_id,
141
  max_length=512,
142
  do_sample=True,
143
  temperature=0.7,
 
152
 
153
  @torch.no_grad()
154
  def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
155
+ """Correct grammar using MT5 model."""
156
  tokenizer, model = mt5_tuple
157
  lang_code = MT5_LANG_CODES[target_lang]
158
 
 
159
  prompts = {
160
  'en': "Fix grammar: ",
161
+ 'hi': "व्याकरण सुधार: ",
162
+ 'mr': "व्याकरण सुधार: "
163
  }
164
 
165
  batches = batch_process_text(text)
166
  corrected_batches = []
167
 
168
  for batch in batches:
 
169
  input_text = f"{prompts[lang_code]}{batch}"
170
  inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
171
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
176
  num_beams=5,
177
  length_penalty=1.0,
178
  early_stopping=True,
179
+ no_repeat_ngram_size=2,
180
+ do_sample=False
181
  )
182
 
183
  corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
184
  for prefix in prompts.values():
185
  corrected_text = corrected_text.replace(prefix, "")
186
+ corrected_text = (corrected_text.replace("<extra_id_0>", "")
187
+ .replace("<extra_id_1>", "")
188
+ .strip())
189
  corrected_batches.append(corrected_text)
190
 
191
  return " ".join(corrected_batches)
192
 
 
 
 
 
 
 
 
 
 
 
 
193
  def main():
194
  st.title("🌐 Document Translation App")
195
 
 
199
  gemma_tuple, nllb_tuple, mt5_tuple = load_models()
200
  except Exception as e:
201
  st.error(f"Error loading models: {str(e)}")
 
202
  return
203
 
204
  # File upload
 
227
  try:
228
  progress_bar = st.progress(0)
229
 
230
+ # Process document
231
+ with st.spinner("Processing document..."):
232
+ text = extract_text_from_file(uploaded_file)
233
+ progress_bar.progress(25)
234
+
 
235
  interpreted_text = interpret_context(text, gemma_tuple)
236
+ progress_bar.progress(50)
237
+
 
 
238
  translated_text = translate_text(
239
  interpreted_text,
240
  SUPPORTED_LANGUAGES[source_language],
241
  SUPPORTED_LANGUAGES[target_language],
242
  nllb_tuple
243
  )
244
+ progress_bar.progress(75)
245
+
246
+ final_text = correct_grammar(
 
 
247
  translated_text,
248
  SUPPORTED_LANGUAGES[target_language],
249
  mt5_tuple
250
  )
251
+ progress_bar.progress(90)
252
 
253
  # Display result
254
  st.markdown("### Translation Result")
255
  st.text_area(
256
  label="Translated Text",
257
+ value=final_text,
258
  height=200,
259
  key="translation_result"
260
  )
 
266
  with col1:
267
  # Text file download
268
  text_buffer = io.BytesIO()
269
+ text_buffer.write(final_text.encode())
270
  text_buffer.seek(0)
271
 
272
  st.download_button(
 
278
 
279
  with col2:
280
  # DOCX file download
281
+ docx_buffer = save_as_docx(final_text)
282
  st.download_button(
283
  label="Download as DOCX",
284
  data=docx_buffer,
 
290
 
291
  except Exception as e:
292
  st.error(f"An error occurred: {str(e)}")
293
+
294
  if __name__ == "__main__":
295
  main()