Spaces:
Build error
Update app.py
app.py
CHANGED
Removed from the previous version:

@@ -6,205 +6,274 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
  import torch
  from pathlib import Path
  import tempfile
- from typing import Union, Tuple
  import os
  import sys
  from datetime import datetime, timezone
  import warnings

  # Filter warnings
  warnings.filterwarnings('ignore', category=UserWarning)

  # Page config
  st.set_page_config(
-     page_title="Document Translation App",
      page_icon="📄",
      layout="wide"
  )

- #
-
-     ""
-
-         'hin_Deva': 'hi',
-         'mar_Deva': 'mr'
  }

-     """
-     return f"___{lang_code}___"
-
- def extract_text_from_file(uploaded_file) -> str:
-     """Extract text content from uploaded file based on its type."""
-     file_extension = Path(uploaded_file.name).suffix.lower()
-
-     """
-     words = text.split()
-     batches = []
-     current_batch = []
-     current_length = 0
-
              batches.append(" ".join(current_batch))
-         else:
-             current_batch.append(word)
-             current_length += len(word) + 1
-
-     if current_batch:
-         batches.append(" ".join(current_batch))
-
      "google/gemma-2b",
-     token=HF_TOKEN,
      trust_remote_code=True
  )
-
      "google/gemma-2b",
-     token=HF_TOKEN,
      torch_dtype=torch.float16,
      device_map="auto" if torch.cuda.is_available() else None,
      trust_remote_code=True
  )
-
      "facebook/nllb-200-distilled-600M",
-     token=HF_TOKEN,
      use_fast=False,
      trust_remote_code=True
  )
-
      "facebook/nllb-200-distilled-600M",
-     token=HF_TOKEN,
      torch_dtype=torch.float16,
      device_map="auto" if torch.cuda.is_available() else None,
      trust_remote_code=True
  )
-
      "google/mt5-base",
-     token=HF_TOKEN,
      trust_remote_code=True
  )
-
      "google/mt5-base",
-     token=HF_TOKEN,
      torch_dtype=torch.float16,
      device_map="auto" if torch.cuda.is_available() else None,
      trust_remote_code=True
  )
-
-     return
-
-     batches = batch_process_text(text)
-     interpreted_batches = []
-
-     for batch in batches:
-         prompt = f"""Analyze and maintain the core meaning of this text: {batch}"""

-         inputs = tokenizer(prompt, return_tensors="pt", max_length=
          inputs = {k: v.to(model.device) for k, v in inputs.items()}

          outputs = model.generate(
              **inputs,
-             max_length=
              do_sample=True,
-             temperature=
              pad_token_id=tokenizer.eos_token_id,
              num_return_sequences=1
          )

-         interpreted_batches.append(interpreted_text)
-
-     return " ".join(interpreted_batches)
-
- @torch.no_grad()
- def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
-     """Translate text using NLLB model."""
-     tokenizer, model = nllb_tuple
-
-     batches = batch_process_text(text)
-     translated_batches = []

          inputs = {k: v.to(model.device) for k, v in inputs.items()}

          target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)
@@ -212,78 +281,85 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
          outputs = model.generate(
              **inputs,
              forced_bos_token_id=target_lang_id,
-             max_length=
              do_sample=True,
-             temperature=
-             num_beams=
-             num_return_sequences=1
          )

-         translated_batches.append(translated_text)
-
-     return " ".join(translated_batches)
-
- @torch.no_grad()
- def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
-     """Correct grammar using MT5 model."""
-     tokenizer, model = mt5_tuple
-     lang_code = MT5_LANG_CODES[target_lang]
-
-     prompts = {
-         'en': "Fix grammar: ",
-         'hi': "व्याकरण सुधार: ",
-         'mr': "व्याकरण सुधार: "
-     }

-
          inputs = {k: v.to(model.device) for k, v in inputs.items()}

          outputs = model.generate(
              **inputs,
-             max_length=
-             num_beams=
              length_penalty=1.0,
              early_stopping=True,
              no_repeat_ngram_size=2,
              do_sample=False
          )

-         for prefix in
-             .replace("<extra_id_1>", "")
-             .strip())
-         corrected_batches.append(corrected_text)
-
-     return " ".join(corrected_batches)

-
-     """
-     doc = docx.Document()
-     doc.add_paragraph(text)
-

  def main():
-     st.title("📄 Document Translation App")

      # Load models
      with st.spinner("Loading models... This may take a few minutes."):
          try:
-
          except Exception as e:
-             st.error(f"Error
              return

      # File upload
@@ -297,43 +373,35 @@ def main():
      with col1:
          source_language = st.selectbox(
              "Source Language",
-             options=list(SUPPORTED_LANGUAGES.keys()),
              index=0
          )

      with col2:
          target_language = st.selectbox(
              "Target Language",
-             options=list(SUPPORTED_LANGUAGES.keys()),
              index=1
          )

      if uploaded_file and st.button("Translate", type="primary"):
          try:
              progress_bar = st.progress(0)

              # Process document
-
-             )
-             progress_bar.progress(75)
-
-             final_text = correct_grammar(
-                 translated_text,
-                 SUPPORTED_LANGUAGES[target_language],
-                 mt5_tuple
-             )
-             progress_bar.progress(90)

              # Display result
              st.markdown("### Translation Result")
@@ -349,28 +417,22 @@
              col1, col2 = st.columns(2)

              with col1:
-                 # Text file download
-                 text_buffer = io.BytesIO()
-                 text_buffer.write(final_text.encode())
-                 text_buffer.seek(0)
-
                  st.download_button(
                      label="Download as TXT",
-                     data=
                      file_name="translated_document.txt",
                      mime="text/plain"
                  )

              with col2:
-                 # DOCX file download
-                 docx_buffer = save_as_docx(final_text)
                  st.download_button(
                      label="Download as DOCX",
-                     data=
                      file_name="translated_document.docx",
                      mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
                  )

              progress_bar.progress(100)

          except Exception as e:
app.py after this commit (added lines marked with +):

  import torch
  from pathlib import Path
  import tempfile
+ from typing import Union, Tuple, List, Dict
  import os
  import sys
  from datetime import datetime, timezone
  import warnings
+ import json

  # Filter warnings
  warnings.filterwarnings('ignore', category=UserWarning)

  # Page config
  st.set_page_config(
+     page_title="Enhanced Document Translation App",
      page_icon="📄",
      layout="wide"
  )

+ # Constants and Configurations
+ CONFIG = {
+     "MAX_BATCH_LENGTH": 512,
+     "MIN_BATCH_LENGTH": 50,
+     "TRANSLATION_TEMPERATURE": 0.7,
+     "CONTEXT_TEMPERATURE": 0.3,
+     "NUM_BEAMS": 5,
+     "SUPPORTED_LANGUAGES": {
+         'English': 'eng_Latn',
+         'Hindi': 'hin_Deva',
+         'Marathi': 'mar_Deva'
+     },
+     "MT5_LANG_CODES": {
+         'eng_Latn': 'en',
+         'hin_Deva': 'hi',
+         'mar_Deva': 'mr'
+     },
+     "GRAMMAR_PROMPTS": {
+         'en': "Fix grammar and improve fluency: ",
+         'hi': "व्याकरण और प्रवाह सुधारें: ",
+         'mr': "व्याकरण आणि प्रवाह सुधारा: "
+     }
  }
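As a quick illustration of how a UI selection flows through these maps (a sketch using only the CONFIG entries above):

    choice = 'Hindi'                                   # value from the selectbox
    nllb_code = CONFIG["SUPPORTED_LANGUAGES"][choice]  # 'hin_Deva' (NLLB language code)
    mt5_code = CONFIG["MT5_LANG_CODES"][nllb_code]     # 'hi'
    prompt = CONFIG["GRAMMAR_PROMPTS"][mt5_code]       # Hindi grammar-correction prefix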
+ class DocumentProcessor:
+     """Handles document processing and text extraction"""

+     @staticmethod
+     def extract_text_from_file(uploaded_file) -> str:
+         file_extension = Path(uploaded_file.name).suffix.lower()
+
+         extractors = {
+             '.pdf': DocumentProcessor._extract_from_pdf,
+             '.docx': DocumentProcessor._extract_from_docx,
+             '.txt': lambda f: f.getvalue().decode('utf-8')
+         }
+
+         if file_extension not in extractors:
+             raise ValueError(f"Unsupported file format: {file_extension}")
+
+         return extractors[file_extension](uploaded_file)
+
+     @staticmethod
+     def _extract_from_pdf(file) -> str:
+         pdf_reader = PyPDF2.PdfReader(file)
+         return "\n".join(page.extract_text() for page in pdf_reader.pages).strip()
+
+     @staticmethod
+     def _extract_from_docx(file) -> str:
+         doc = docx.Document(file)
+         return "\n".join(paragraph.text for paragraph in doc.paragraphs).strip()
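A minimal usage sketch for the extractor; the SimpleNamespace stand-in is an assumption for illustration, mimicking the .name and .getvalue() interface of Streamlit's UploadedFile:

    from types import SimpleNamespace

    # Hypothetical stand-in for a Streamlit upload of a .txt file
    fake_upload = SimpleNamespace(name="notes.txt",
                                  getvalue=lambda: "Hello, world".encode("utf-8"))
    print(DocumentProcessor.extract_text_from_file(fake_upload))  # -> Hello, world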
+ class TextBatcher:
+     """Handles text batching with improved sentence boundary detection"""

+     @staticmethod
+     def batch_process_text(text: str, max_length: int = CONFIG["MAX_BATCH_LENGTH"]) -> List[str]:
+         sentences = TextBatcher._split_into_sentences(text)
+         batches = []
+         current_batch = []
+         current_length = 0
+
+         for sentence in sentences:
+             sentence_length = len(sentence)
+
+             if current_length + sentence_length > max_length:
+                 if current_batch:
+                     batches.append(" ".join(current_batch))
+                 current_batch = [sentence]
+                 current_length = sentence_length
+             else:
+                 current_batch.append(sentence)
+                 current_length += sentence_length
+
+         if current_batch:
              batches.append(" ".join(current_batch))
+
+         return batches

+     @staticmethod
+     def _split_into_sentences(text: str) -> List[str]:
+         """Split text into sentences with improved boundary detection"""
+         # Basic sentence boundary detection
+         delimiters = ['. ', '! ', '? ', '।', '॥', '\n']
+         sentences = []
+         current = text
+
+         for delimiter in delimiters:
+             parts = current.split(delimiter)
+             current = parts[0]
+             for part in parts[1:]:
+                 if len(current.strip()) > 0:
+                     sentences.append(current.strip() + delimiter.strip())
+                 current = part
+
+         if len(current.strip()) > 0:
+             sentences.append(current.strip())

+         return sentences
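A quick sanity check of the batcher (a sketch; assumes the classes above are importable from app.py). The Devanagari danda '।' is handled alongside Latin punctuation:

    sample = "First sentence. Second one! तिसरे वाक्य। Done?"
    for batch in TextBatcher.batch_process_text(sample, max_length=25):
        print(repr(batch))

Worth noting: each delimiter pass only re-splits the trailing remainder of the previous pass, so a '!' inside an already-committed '. '-separated chunk is never split out; this is a heuristic rather than a full sentence segmenter.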
+ class ModelManager:
+     """Manages loading and caching of AI models"""
+
+     @st.cache_resource
+     def load_models():
+         try:
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+
+             # Load models with improved error handling
+             models = {
+                 "gemma": ModelManager._load_gemma_model(),
+                 "nllb": ModelManager._load_nllb_model(),
+                 "mt5": ModelManager._load_mt5_model()
+             }
+
+             # Move models to appropriate device
+             if not torch.cuda.is_available():
+                 for model_tuple in models.values():
+                     model_tuple[1].to(device)
+
+             return models
+
+         except Exception as e:
+             st.error(f"Error loading models: {str(e)}")
+             st.error(f"Python version: {sys.version}")
+             st.error(f"PyTorch version: {torch.__version__}")
+             raise e
+
+     @staticmethod
+     def _load_gemma_model():
+         tokenizer = AutoTokenizer.from_pretrained(
              "google/gemma-2b",
+             token=os.environ.get('HF_TOKEN'),
              trust_remote_code=True
          )
+         model = AutoModelForCausalLM.from_pretrained(
              "google/gemma-2b",
+             token=os.environ.get('HF_TOKEN'),
              torch_dtype=torch.float16,
              device_map="auto" if torch.cuda.is_available() else None,
              trust_remote_code=True
          )
+         return (tokenizer, model)
+
+     @staticmethod
+     def _load_nllb_model():
+         tokenizer = AutoTokenizer.from_pretrained(
              "facebook/nllb-200-distilled-600M",
+             token=os.environ.get('HF_TOKEN'),
              use_fast=False,
              trust_remote_code=True
          )
+         model = AutoModelForSeq2SeqLM.from_pretrained(
              "facebook/nllb-200-distilled-600M",
+             token=os.environ.get('HF_TOKEN'),
              torch_dtype=torch.float16,
              device_map="auto" if torch.cuda.is_available() else None,
              trust_remote_code=True
          )
+         return (tokenizer, model)
+
+     @staticmethod
+     def _load_mt5_model():
+         tokenizer = AutoTokenizer.from_pretrained(
              "google/mt5-base",
+             token=os.environ.get('HF_TOKEN'),
              trust_remote_code=True
          )
+         model = MT5ForConditionalGeneration.from_pretrained(
              "google/mt5-base",
+             token=os.environ.get('HF_TOKEN'),
              torch_dtype=torch.float16,
              device_map="auto" if torch.cuda.is_available() else None,
              trust_remote_code=True
          )
+         return (tokenizer, model)
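load_models is written without @staticmethod but is only ever invoked on the class, which works in Python 3 (attribute lookup on the class returns a plain function), and @st.cache_resource keeps the returned dict alive across Streamlit reruns. A usage sketch:

    models = ModelManager.load_models()          # loaded once, then served from cache
    nllb_tokenizer, nllb_model = models["nllb"]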
+ class TranslationPipeline:
+     """Manages the translation pipeline with context understanding"""
+
+     def __init__(self, models: Dict):
+         self.models = models
+
+     @torch.no_grad()
+     def process_text(self, text: str, source_lang: str, target_lang: str) -> str:
+         # Split text into manageable batches
+         batches = TextBatcher.batch_process_text(text)
+         final_results = []

+         for batch in batches:
+             # Step 1: Context Understanding
+             context = self._understand_context(batch)
+
+             # Step 2: Context-aware Translation
+             translated = self._translate_with_context(
+                 context,
+                 source_lang,
+                 target_lang
+             )
+
+             # Step 3: Grammar Correction
+             corrected = self._correct_grammar(
+                 translated,
+                 target_lang
+             )
+
+             final_results.append(corrected)

+         return " ".join(final_results)

+     def _understand_context(self, text: str) -> str:
+         """Enhanced context understanding using Gemma model"""
+         tokenizer, model = self.models["gemma"]
+
+         prompt = f"""Analyze and provide context for translation:
+ Text: {text}
+ Key points to consider:
+ - Main topic and subject matter
+ - Cultural context and nuances
+ - Technical terminology if any
+ - Tone and style of writing

+ Provide a clear and concise interpretation that maintains:
+ 1. Original meaning
+ 2. Cultural context
+ 3. Technical accuracy
+ 4. Tone and style"""

+         inputs = tokenizer(prompt, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
          inputs = {k: v.to(model.device) for k, v in inputs.items()}

          outputs = model.generate(
              **inputs,
+             max_length=CONFIG["MAX_BATCH_LENGTH"],
              do_sample=True,
+             temperature=CONFIG["CONTEXT_TEMPERATURE"],
              pad_token_id=tokenizer.eos_token_id,
              num_return_sequences=1
          )

+         context = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return context.replace(prompt, "").strip()
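Gemma is decoder-only, so generate() returns the prompt followed by the continuation; the replace() above only strips the echo when the decoded text reproduces the prompt exactly. A more robust variant (a sketch, not what this commit does) slices off the prompt tokens instead:

    prompt_len = inputs["input_ids"].shape[1]
    context = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()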
+     def _translate_with_context(self, text: str, source_lang: str, target_lang: str) -> str:
+         """Enhanced translation using NLLB model with context awareness"""
+         tokenizer, model = self.models["nllb"]
+
+         source_lang_token = f"___{source_lang}___"
+         target_lang_token = f"___{target_lang}___"
+
+         inputs = tokenizer(text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
          inputs = {k: v.to(model.device) for k, v in inputs.items()}

          target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)

          outputs = model.generate(
              **inputs,
              forced_bos_token_id=target_lang_id,
+             max_length=CONFIG["MAX_BATCH_LENGTH"],
              do_sample=True,
+             temperature=CONFIG["TRANSLATION_TEMPERATURE"],
+             num_beams=CONFIG["NUM_BEAMS"],
+             num_return_sequences=1,
+             length_penalty=1.0,
+             repetition_penalty=1.2
          )

+         return tokenizer.decode(outputs[0], skip_special_tokens=True)
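One caveat worth flagging: in the Hugging Face NLLB tokenizer, language codes such as hin_Deva are registered as special tokens in that bare form, so the triple-underscore wrapper built above may resolve to the unknown-token id and silently break forced_bos_token_id. A defensive lookup (a sketch under that assumption):

    target_lang_id = tokenizer.convert_tokens_to_ids(target_lang)  # e.g. "hin_Deva"
    if target_lang_id == tokenizer.unk_token_id:
        raise ValueError(f"Unrecognized NLLB language code: {target_lang}")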
+     def _correct_grammar(self, text: str, target_lang: str) -> str:
+         """Enhanced grammar correction using MT5 model"""
+         tokenizer, model = self.models["mt5"]
+         lang_code = CONFIG["MT5_LANG_CODES"][target_lang]
+         prompt = CONFIG["GRAMMAR_PROMPTS"][lang_code]
+
+         input_text = f"{prompt}{text}"
+         inputs = tokenizer(input_text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
          inputs = {k: v.to(model.device) for k, v in inputs.items()}

          outputs = model.generate(
              **inputs,
+             max_length=CONFIG["MAX_BATCH_LENGTH"],
+             num_beams=CONFIG["NUM_BEAMS"],
              length_penalty=1.0,
              early_stopping=True,
              no_repeat_ngram_size=2,
              do_sample=False
          )

+         corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         for prefix in CONFIG["GRAMMAR_PROMPTS"].values():
+             corrected = corrected.replace(prefix, "")
+         return corrected.strip()
+ class DocumentExporter:
+     """Handles document export operations"""

+     @staticmethod
+     def save_as_docx(text: str) -> io.BytesIO:
+         doc = docx.Document()
+         doc.add_paragraph(text)
+
+         buffer = io.BytesIO()
+         doc.save(buffer)
+         buffer.seek(0)
+
+         return buffer

+     @staticmethod
+     def save_as_text(text: str) -> io.BytesIO:
+         buffer = io.BytesIO()
+         buffer.write(text.encode())
+         buffer.seek(0)
+         return buffer
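The exporters also work outside Streamlit; a small sketch writing the DOCX buffer to disk:

    buffer = DocumentExporter.save_as_docx("translated text")
    with open("translated_document.docx", "wb") as f:
        f.write(buffer.getvalue())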
  def main():
+     st.title("📄 Enhanced Document Translation App")
+
+     # Check for HF_TOKEN
+     if not os.environ.get('HF_TOKEN'):
+         st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
+         st.stop()
+
+     # Display system info
+     st.sidebar.markdown(f"""
+     ### System Information
+     **Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
+     **User:** {os.environ.get('USER', 'unknown')}
+     """)

      # Load models
      with st.spinner("Loading models... This may take a few minutes."):
          try:
+             models = ModelManager.load_models()
+             pipeline = TranslationPipeline(models)
          except Exception as e:
+             st.error(f"Error initializing translation pipeline: {str(e)}")
              return

      # File upload

      with col1:
          source_language = st.selectbox(
              "Source Language",
+             options=list(CONFIG["SUPPORTED_LANGUAGES"].keys()),
              index=0
          )

      with col2:
          target_language = st.selectbox(
              "Target Language",
+             options=list(CONFIG["SUPPORTED_LANGUAGES"].keys()),
              index=1
          )

      if uploaded_file and st.button("Translate", type="primary"):
          try:
              progress_bar = st.progress(0)
+             status_text = st.empty()

              # Process document
+             status_text.text("Extracting text from document...")
+             text = DocumentProcessor.extract_text_from_file(uploaded_file)
+             progress_bar.progress(20)
+
+             # Perform translation
+             status_text.text("Translating document with context understanding...")
+             final_text = pipeline.process_text(
+                 text,
+                 CONFIG["SUPPORTED_LANGUAGES"][source_language],
+                 CONFIG["SUPPORTED_LANGUAGES"][target_language]
+             )
+             progress_bar.progress(90)

              # Display result
              st.markdown("### Translation Result")

              col1, col2 = st.columns(2)

              with col1:
                  st.download_button(
                      label="Download as TXT",
+                     data=DocumentExporter.save_as_text(final_text),
                      file_name="translated_document.txt",
                      mime="text/plain"
                  )

              with col2:
                  st.download_button(
                      label="Download as DOCX",
+                     data=DocumentExporter.save_as_docx(final_text),
                      file_name="translated_document.docx",
                      mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
                  )

+             status_text.text("Translation completed successfully!")
              progress_bar.progress(100)

          except Exception as e: