gauravchand11 committed on
Commit 5e3207d · verified · 1 Parent(s): bcda6d5

Update app.py

Files changed (1)
  1. app.py +279 -217
app.py CHANGED
@@ -6,205 +6,274 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2Se
  import torch
  from pathlib import Path
  import tempfile
- from typing import Union, Tuple
  import os
  import sys
  from datetime import datetime, timezone
  import warnings

  # Filter warnings
  warnings.filterwarnings('ignore', category=UserWarning)

  # Page config
  st.set_page_config(
-     page_title="Document Translation App",
      page_icon="🌐",
      layout="wide"
  )

- # Display system info
- st.sidebar.markdown(f"""
- ### System Information
- **Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
- **User:** {os.environ.get('USER', 'gauravchand')}
- """)
-
- # Get Hugging Face token
- HF_TOKEN = os.environ.get('HF_TOKEN')
- if not HF_TOKEN:
-     st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
-     st.stop()
-
- # Language configurations
- SUPPORTED_LANGUAGES = {
-     'English': 'eng_Latn',
-     'Hindi': 'hin_Deva',
-     'Marathi': 'mar_Deva'
- }
-
- MT5_LANG_CODES = {
-     'eng_Latn': 'en',
-     'hin_Deva': 'hi',
-     'mar_Deva': 'mr'
  }

- def get_nllb_lang_token(lang_code: str) -> str:
-     """Get the correct token format for NLLB model."""
-     return f"___{lang_code}___"
-
- def extract_text_from_file(uploaded_file) -> str:
-     """Extract text content from uploaded file based on its type."""
-     file_extension = Path(uploaded_file.name).suffix.lower()

-     if file_extension == '.pdf':
-         return extract_from_pdf(uploaded_file)
-     elif file_extension == '.docx':
-         return extract_from_docx(uploaded_file)
-     elif file_extension == '.txt':
-         return uploaded_file.getvalue().decode('utf-8')
-     else:
-         raise ValueError(f"Unsupported file format: {file_extension}")
-
- def extract_from_pdf(file) -> str:
-     """Extract text from PDF file."""
-     pdf_reader = PyPDF2.PdfReader(file)
-     text = ""
-     for page in pdf_reader.pages:
-         text += page.extract_text() + "\n"
-     return text.strip()
-
- def extract_from_docx(file) -> str:
-     """Extract text from DOCX file."""
-     doc = docx.Document(file)
-     text = ""
-     for paragraph in doc.paragraphs:
-         text += paragraph.text + "\n"
-     return text.strip()

- def batch_process_text(text: str, max_length: int = 512) -> list:
-     """Split text into batches for processing."""
-     words = text.split()
-     batches = []
-     current_batch = []
-     current_length = 0

-     for word in words:
-         if current_length + len(word) + 1 > max_length:
              batches.append(" ".join(current_batch))
-             current_batch = [word]
-             current_length = len(word)
-         else:
-             current_batch.append(word)
-             current_length += len(word) + 1
-
-     if current_batch:
-         batches.append(" ".join(current_batch))

-     return batches
-
- @st.cache_resource
- def load_models():
-     """Load and cache the translation and context interpretation models."""
-     try:
-         device = "cuda" if torch.cuda.is_available() else "cpu"

-         # Load Gemma model
-         gemma_tokenizer = AutoTokenizer.from_pretrained(
              "google/gemma-2b",
-             token=HF_TOKEN,
              trust_remote_code=True
          )
-         gemma_model = AutoModelForCausalLM.from_pretrained(
              "google/gemma-2b",
-             token=HF_TOKEN,
              torch_dtype=torch.float16,
              device_map="auto" if torch.cuda.is_available() else None,
              trust_remote_code=True
          )
-
-         # Load NLLB model
-         nllb_tokenizer = AutoTokenizer.from_pretrained(
              "facebook/nllb-200-distilled-600M",
-             token=HF_TOKEN,
              use_fast=False,
              trust_remote_code=True
          )
-         nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
              "facebook/nllb-200-distilled-600M",
-             token=HF_TOKEN,
              torch_dtype=torch.float16,
              device_map="auto" if torch.cuda.is_available() else None,
              trust_remote_code=True
          )
-
-         # Load MT5 model
-         mt5_tokenizer = AutoTokenizer.from_pretrained(
              "google/mt5-base",
-             token=HF_TOKEN,
              trust_remote_code=True
          )
-         mt5_model = MT5ForConditionalGeneration.from_pretrained(
              "google/mt5-base",
-             token=HF_TOKEN,
              torch_dtype=torch.float16,
              device_map="auto" if torch.cuda.is_available() else None,
              trust_remote_code=True
          )

-         if not torch.cuda.is_available():
-             gemma_model = gemma_model.to(device)
-             nllb_model = nllb_model.to(device)
-             mt5_model = mt5_model.to(device)

-         return (gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model), (mt5_tokenizer, mt5_model)

-     except Exception as e:
-         st.error(f"Error loading models: {str(e)}")
-         st.error(f"Python version: {sys.version}")
-         st.error(f"PyTorch version: {torch.__version__}")
-         raise e

- @torch.no_grad()
- def interpret_context(text: str, gemma_tuple: Tuple) -> str:
-     """Use Gemma model to interpret context and understand regional nuances."""
-     tokenizer, model = gemma_tuple
-
-     batches = batch_process_text(text)
-     interpreted_batches = []
-
-     for batch in batches:
-         prompt = f"""Analyze and maintain the core meaning of this text: {batch}"""

-         inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
          inputs = {k: v.to(model.device) for k, v in inputs.items()}

          outputs = model.generate(
              **inputs,
-             max_length=512,
              do_sample=True,
-             temperature=0.3,
              pad_token_id=tokenizer.eos_token_id,
              num_return_sequences=1
          )

-         interpreted_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         interpreted_text = interpreted_text.replace(prompt, "").strip()
-         interpreted_batches.append(interpreted_text)
-
-     return " ".join(interpreted_batches)
-
- @torch.no_grad()
- def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
-     """Translate text using NLLB model."""
-     tokenizer, model = nllb_tuple
-
-     batches = batch_process_text(text)
-     translated_batches = []

-     target_lang_token = get_nllb_lang_token(target_lang)
-
-     for batch in batches:
-         inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
          inputs = {k: v.to(model.device) for k, v in inputs.items()}

          target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)
@@ -212,78 +281,85 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tu
          outputs = model.generate(
              **inputs,
              forced_bos_token_id=target_lang_id,
-             max_length=512,
              do_sample=True,
-             temperature=0.7,
-             num_beams=5,
-             num_return_sequences=1
          )

-         translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         translated_batches.append(translated_text)
-
-     return " ".join(translated_batches)
-
- @torch.no_grad()
- def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
-     """Correct grammar using MT5 model."""
-     tokenizer, model = mt5_tuple
-     lang_code = MT5_LANG_CODES[target_lang]
-
-     prompts = {
-         'en': "Fix grammar: ",
-         'hi': "व्याकरण सुधार: ",
-         'mr': "व्याकरण सुधार: "
-     }

-     batches = batch_process_text(text)
-     corrected_batches = []
-
-     for batch in batches:
-         input_text = f"{prompts[lang_code]}{batch}"
-         inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
          inputs = {k: v.to(model.device) for k, v in inputs.items()}

          outputs = model.generate(
              **inputs,
-             max_length=512,
-             num_beams=5,
              length_penalty=1.0,
              early_stopping=True,
              no_repeat_ngram_size=2,
              do_sample=False
          )

-         corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         for prefix in prompts.values():
-             corrected_text = corrected_text.replace(prefix, "")
-         corrected_text = (corrected_text.replace("<extra_id_0>", "")
-                           .replace("<extra_id_1>", "")
-                           .strip())
-         corrected_batches.append(corrected_text)
-
-     return " ".join(corrected_batches)

- def save_as_docx(text: str) -> io.BytesIO:
-     """Save translated text as a DOCX file."""
-     doc = docx.Document()
-     doc.add_paragraph(text)

-     docx_buffer = io.BytesIO()
-     doc.save(docx_buffer)
-     docx_buffer.seek(0)

-     return docx_buffer

  def main():
-     st.title("🌐 Document Translation App")

      # Load models
      with st.spinner("Loading models... This may take a few minutes."):
          try:
-             gemma_tuple, nllb_tuple, mt5_tuple = load_models()
          except Exception as e:
-             st.error(f"Error loading models: {str(e)}")
              return

      # File upload
@@ -297,43 +373,35 @@ def main():
      with col1:
          source_language = st.selectbox(
              "Source Language",
-             options=list(SUPPORTED_LANGUAGES.keys()),
              index=0
          )

      with col2:
          target_language = st.selectbox(
              "Target Language",
-             options=list(SUPPORTED_LANGUAGES.keys()),
              index=1
          )

      if uploaded_file and st.button("Translate", type="primary"):
          try:
              progress_bar = st.progress(0)

              # Process document
-             with st.spinner("Processing document..."):
-                 text = extract_text_from_file(uploaded_file)
-                 progress_bar.progress(25)
-
-                 interpreted_text = interpret_context(text, gemma_tuple)
-                 progress_bar.progress(50)
-
-                 translated_text = translate_text(
-                     interpreted_text,
-                     SUPPORTED_LANGUAGES[source_language],
-                     SUPPORTED_LANGUAGES[target_language],
-                     nllb_tuple
-                 )
-                 progress_bar.progress(75)
-
-                 final_text = correct_grammar(
-                     translated_text,
-                     SUPPORTED_LANGUAGES[target_language],
-                     mt5_tuple
-                 )
-                 progress_bar.progress(90)

              # Display result
              st.markdown("### Translation Result")
@@ -349,28 +417,22 @@
              col1, col2 = st.columns(2)

              with col1:
-                 # Text file download
-                 text_buffer = io.BytesIO()
-                 text_buffer.write(final_text.encode())
-                 text_buffer.seek(0)
-
                  st.download_button(
                      label="Download as TXT",
-                     data=text_buffer,
                      file_name="translated_document.txt",
                      mime="text/plain"
                  )

              with col2:
-                 # DOCX file download
-                 docx_buffer = save_as_docx(final_text)
                  st.download_button(
                      label="Download as DOCX",
-                     data=docx_buffer,
                      file_name="translated_document.docx",
                      mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
                  )

              progress_bar.progress(100)

          except Exception as e:
 
app.py (updated version):

  import torch
  from pathlib import Path
  import tempfile
+ from typing import Union, Tuple, List, Dict
  import os
  import sys
  from datetime import datetime, timezone
  import warnings
+ import json

  # Filter warnings
  warnings.filterwarnings('ignore', category=UserWarning)

  # Page config
  st.set_page_config(
+     page_title="Enhanced Document Translation App",
      page_icon="🌐",
      layout="wide"
  )

+ # Constants and Configurations
+ CONFIG = {
+     "MAX_BATCH_LENGTH": 512,
+     "MIN_BATCH_LENGTH": 50,
+     "TRANSLATION_TEMPERATURE": 0.7,
+     "CONTEXT_TEMPERATURE": 0.3,
+     "NUM_BEAMS": 5,
+     "SUPPORTED_LANGUAGES": {
+         'English': 'eng_Latn',
+         'Hindi': 'hin_Deva',
+         'Marathi': 'mar_Deva'
+     },
+     "MT5_LANG_CODES": {
+         'eng_Latn': 'en',
+         'hin_Deva': 'hi',
+         'mar_Deva': 'mr'
+     },
+     "GRAMMAR_PROMPTS": {
+         'en': "Fix grammar and improve fluency: ",
+         'hi': "व्याकरण और प्रवाह सुधारें: ",
+         'mr': "व्याकरण आणि प्रवाह सुधारा: "
+     }
  }

+ class DocumentProcessor:
+     """Handles document processing and text extraction"""

+     @staticmethod
+     def extract_text_from_file(uploaded_file) -> str:
+         file_extension = Path(uploaded_file.name).suffix.lower()
+
+         extractors = {
+             '.pdf': DocumentProcessor._extract_from_pdf,
+             '.docx': DocumentProcessor._extract_from_docx,
+             '.txt': lambda f: f.getvalue().decode('utf-8')
+         }
+
+         if file_extension not in extractors:
+             raise ValueError(f"Unsupported file format: {file_extension}")
+
+         return extractors[file_extension](uploaded_file)
+
+     @staticmethod
+     def _extract_from_pdf(file) -> str:
+         pdf_reader = PyPDF2.PdfReader(file)
+         return "\n".join(page.extract_text() for page in pdf_reader.pages).strip()
+
+     @staticmethod
+     def _extract_from_docx(file) -> str:
+         doc = docx.Document(file)
+         return "\n".join(paragraph.text for paragraph in doc.paragraphs).strip()

+ class TextBatcher:
+     """Handles text batching with improved sentence boundary detection"""

+     @staticmethod
+     def batch_process_text(text: str, max_length: int = CONFIG["MAX_BATCH_LENGTH"]) -> List[str]:
+         sentences = TextBatcher._split_into_sentences(text)
+         batches = []
+         current_batch = []
+         current_length = 0
+
+         for sentence in sentences:
+             sentence_length = len(sentence)
+
+             if current_length + sentence_length > max_length:
+                 if current_batch:
+                     batches.append(" ".join(current_batch))
+                 current_batch = [sentence]
+                 current_length = sentence_length
+             else:
+                 current_batch.append(sentence)
+                 current_length += sentence_length
+
+         if current_batch:
              batches.append(" ".join(current_batch))
+
+         return batches

+     @staticmethod
+     def _split_into_sentences(text: str) -> List[str]:
+         """Split text into sentences with improved boundary detection"""
+         # Devanagari dandas and newlines always end a sentence; '.', '!' and
+         # '?' end one only when followed by whitespace (or end of text), so
+         # decimals and abbreviations stay intact.
+         sentences = []
+         current = []
+
+         for i, char in enumerate(text):
+             current.append(char)
+             is_boundary = (
+                 char in '।॥\n'
+                 or (char in '.!?' and (i + 1 == len(text) or text[i + 1].isspace()))
+             )
+             if is_boundary:
+                 sentence = "".join(current).strip()
+                 if sentence:
+                     sentences.append(sentence)
+                 current = []
+
+         remainder = "".join(current).strip()
+         if remainder:
+             sentences.append(remainder)
+
+         return sentences
+
+ class ModelManager:
+     """Manages loading and caching of AI models"""
+
+     @st.cache_resource
+     def load_models():
+         try:
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+
+             # Load models with improved error handling
+             models = {
+                 "gemma": ModelManager._load_gemma_model(),
+                 "nllb": ModelManager._load_nllb_model(),
+                 "mt5": ModelManager._load_mt5_model()
+             }
+
+             # Move models to appropriate device
+             if not torch.cuda.is_available():
+                 for model_tuple in models.values():
+                     model_tuple[1].to(device)
+
+             return models
+
+         except Exception as e:
+             st.error(f"Error loading models: {str(e)}")
+             st.error(f"Python version: {sys.version}")
+             st.error(f"PyTorch version: {torch.__version__}")
+             raise e
+
+     @staticmethod
+     def _load_gemma_model():
+         tokenizer = AutoTokenizer.from_pretrained(
              "google/gemma-2b",
+             token=os.environ.get('HF_TOKEN'),
              trust_remote_code=True
          )
+         model = AutoModelForCausalLM.from_pretrained(
              "google/gemma-2b",
+             token=os.environ.get('HF_TOKEN'),
              torch_dtype=torch.float16,
              device_map="auto" if torch.cuda.is_available() else None,
              trust_remote_code=True
          )
+         return (tokenizer, model)
+
+     @staticmethod
+     def _load_nllb_model():
+         tokenizer = AutoTokenizer.from_pretrained(
              "facebook/nllb-200-distilled-600M",
+             token=os.environ.get('HF_TOKEN'),
              use_fast=False,
              trust_remote_code=True
          )
+         model = AutoModelForSeq2SeqLM.from_pretrained(
              "facebook/nllb-200-distilled-600M",
+             token=os.environ.get('HF_TOKEN'),
              torch_dtype=torch.float16,
              device_map="auto" if torch.cuda.is_available() else None,
              trust_remote_code=True
          )
+         return (tokenizer, model)
+
+     @staticmethod
+     def _load_mt5_model():
+         tokenizer = AutoTokenizer.from_pretrained(
              "google/mt5-base",
+             token=os.environ.get('HF_TOKEN'),
              trust_remote_code=True
          )
+         model = MT5ForConditionalGeneration.from_pretrained(
              "google/mt5-base",
+             token=os.environ.get('HF_TOKEN'),
              torch_dtype=torch.float16,
              device_map="auto" if torch.cuda.is_available() else None,
              trust_remote_code=True
          )
+         return (tokenizer, model)
+
+ class TranslationPipeline:
+     """Manages the translation pipeline with context understanding"""
+
+     def __init__(self, models: Dict):
+         self.models = models
+
+     @torch.no_grad()
+     def process_text(self, text: str, source_lang: str, target_lang: str) -> str:
+         # Split text into manageable batches
+         batches = TextBatcher.batch_process_text(text)
+         final_results = []

+         for batch in batches:
+             # Step 1: Context Understanding
+             context = self._understand_context(batch)
+
+             # Step 2: Context-aware Translation
+             translated = self._translate_with_context(
+                 context,
+                 source_lang,
+                 target_lang
+             )
+
+             # Step 3: Grammar Correction
+             corrected = self._correct_grammar(
+                 translated,
+                 target_lang
+             )
+
+             final_results.append(corrected)

+         return " ".join(final_results)

+     def _understand_context(self, text: str) -> str:
+         """Enhanced context understanding using Gemma model"""
+         tokenizer, model = self.models["gemma"]
+
+         prompt = f"""Analyze and provide context for translation:
+         Text: {text}
+         Key points to consider:
+         - Main topic and subject matter
+         - Cultural context and nuances
+         - Technical terminology if any
+         - Tone and style of writing

+         Provide a clear and concise interpretation that maintains:
+         1. Original meaning
+         2. Cultural context
+         3. Technical accuracy
+         4. Tone and style"""

+         inputs = tokenizer(prompt, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
          inputs = {k: v.to(model.device) for k, v in inputs.items()}

          outputs = model.generate(
              **inputs,
+             max_length=CONFIG["MAX_BATCH_LENGTH"],
              do_sample=True,
+             temperature=CONFIG["CONTEXT_TEMPERATURE"],
              pad_token_id=tokenizer.eos_token_id,
              num_return_sequences=1
          )

+         context = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return context.replace(prompt, "").strip()

+     def _translate_with_context(self, text: str, source_lang: str, target_lang: str) -> str:
+         """Enhanced translation using NLLB model with context awareness"""
+         tokenizer, model = self.models["nllb"]
+
+ source_lang_token = f"___{source_lang}___"
274
+ target_lang_token = f"___{target_lang}___"
275
+
+         inputs = tokenizer(text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
          inputs = {k: v.to(model.device) for k, v in inputs.items()}

          target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)

          outputs = model.generate(
              **inputs,
              forced_bos_token_id=target_lang_id,
+             max_length=CONFIG["MAX_BATCH_LENGTH"],
              do_sample=True,
+             temperature=CONFIG["TRANSLATION_TEMPERATURE"],
+             num_beams=CONFIG["NUM_BEAMS"],
+             num_return_sequences=1,
+             length_penalty=1.0,
+             repetition_penalty=1.2
          )

+         return tokenizer.decode(outputs[0], skip_special_tokens=True)

+     def _correct_grammar(self, text: str, target_lang: str) -> str:
+         """Enhanced grammar correction using MT5 model"""
+         tokenizer, model = self.models["mt5"]
+         lang_code = CONFIG["MT5_LANG_CODES"][target_lang]
+         prompt = CONFIG["GRAMMAR_PROMPTS"][lang_code]
+
+         input_text = f"{prompt}{text}"
+         inputs = tokenizer(input_text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
          inputs = {k: v.to(model.device) for k, v in inputs.items()}

          outputs = model.generate(
              **inputs,
+             max_length=CONFIG["MAX_BATCH_LENGTH"],
+             num_beams=CONFIG["NUM_BEAMS"],
              length_penalty=1.0,
              early_stopping=True,
              no_repeat_ngram_size=2,
              do_sample=False
          )

+         corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         for prefix in CONFIG["GRAMMAR_PROMPTS"].values():
+             corrected = corrected.replace(prefix, "")
+         return corrected.strip()

+ class DocumentExporter:
+     """Handles document export operations"""

+     @staticmethod
+     def save_as_docx(text: str) -> io.BytesIO:
+         doc = docx.Document()
+         doc.add_paragraph(text)
+
+         buffer = io.BytesIO()
+         doc.save(buffer)
+         buffer.seek(0)
+
+         return buffer

+     @staticmethod
+     def save_as_text(text: str) -> io.BytesIO:
+         buffer = io.BytesIO()
+         buffer.write(text.encode())
+         buffer.seek(0)
+         return buffer

  def main():
+     st.title("🌐 Enhanced Document Translation App")
+
+     # Check for HF_TOKEN
+     if not os.environ.get('HF_TOKEN'):
+         st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
+         st.stop()
+
+     # Display system info
+     st.sidebar.markdown(f"""
+     ### System Information
+     **Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
+     **User:** {os.environ.get('USER', 'unknown')}
+     """)

      # Load models
      with st.spinner("Loading models... This may take a few minutes."):
          try:
+             models = ModelManager.load_models()
+             pipeline = TranslationPipeline(models)
          except Exception as e:
+             st.error(f"Error initializing translation pipeline: {str(e)}")
              return

      # File upload

      with col1:
          source_language = st.selectbox(
              "Source Language",
+             options=list(CONFIG["SUPPORTED_LANGUAGES"].keys()),
              index=0
          )

      with col2:
          target_language = st.selectbox(
              "Target Language",
+             options=list(CONFIG["SUPPORTED_LANGUAGES"].keys()),
              index=1
          )

      if uploaded_file and st.button("Translate", type="primary"):
          try:
              progress_bar = st.progress(0)
+             status_text = st.empty()

              # Process document
+             status_text.text("Extracting text from document...")
+             text = DocumentProcessor.extract_text_from_file(uploaded_file)
+             progress_bar.progress(20)
+
+             # Perform translation
+             status_text.text("Translating document with context understanding...")
+             final_text = pipeline.process_text(
+                 text,
+                 CONFIG["SUPPORTED_LANGUAGES"][source_language],
+                 CONFIG["SUPPORTED_LANGUAGES"][target_language]
+             )
+             progress_bar.progress(90)

              # Display result
              st.markdown("### Translation Result")

              col1, col2 = st.columns(2)

              with col1:
                  st.download_button(
                      label="Download as TXT",
+                     data=DocumentExporter.save_as_text(final_text),
                      file_name="translated_document.txt",
                      mime="text/plain"
                  )

              with col2:
                  st.download_button(
                      label="Download as DOCX",
+                     data=DocumentExporter.save_as_docx(final_text),
                      file_name="translated_document.docx",
                      mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
                  )

+             status_text.text("Translation completed successfully!")
              progress_bar.progress(100)

          except Exception as e:
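
Note on the NLLB language tokens used above: the `___{lang_code}___` wrapping from the original version of this file does not match any entry in the NLLB vocabulary; in the Hugging Face tokenizer, the FLORES-200 codes such as `eng_Latn` or `hin_Deva` are themselves special tokens. A minimal standalone sketch of the corrected lookup, assuming the same `facebook/nllb-200-distilled-600M` checkpoint and slow tokenizer that the app loads:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", use_fast=False
)
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

# The source language is set on the tokenizer before encoding,
tokenizer.src_lang = "eng_Latn"
inputs = tokenizer("Hello, how are you?", return_tensors="pt")

# and the target code is passed as a plain token; wrapping it in extra
# underscores would resolve to the unknown-token id and break generation.
target_id = tokenizer.convert_tokens_to_ids("hin_Deva")
outputs = model.generate(**inputs, forced_bos_token_id=target_id, max_length=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))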