Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -12,53 +12,53 @@ import sys
|
|
12 |
from datetime import datetime, timezone
|
13 |
import warnings
|
14 |
|
15 |
-
# Filter
|
16 |
-
warnings.filterwarnings('ignore', category=UserWarning
|
17 |
-
warnings.filterwarnings('ignore', category=UserWarning, module='transformers.tokenization_utils_base')
|
18 |
|
19 |
-
#
|
20 |
st.set_page_config(
|
21 |
page_title="Document Translation App",
|
22 |
page_icon="🌐",
|
23 |
layout="wide"
|
24 |
)
|
25 |
|
26 |
-
# Display
|
27 |
-
|
28 |
-
st.sidebar.markdown("""
|
29 |
### System Information
|
30 |
-
**Current UTC Time:** {}
|
31 |
-
**User:** {}
|
32 |
-
"""
|
33 |
|
34 |
-
# Get Hugging Face token
|
35 |
HF_TOKEN = os.environ.get('HF_TOKEN')
|
36 |
if not HF_TOKEN:
|
37 |
st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
|
38 |
st.stop()
|
39 |
|
40 |
-
#
|
41 |
SUPPORTED_LANGUAGES = {
|
42 |
'English': 'eng_Latn',
|
43 |
'Hindi': 'hin_Deva',
|
44 |
'Marathi': 'mar_Deva'
|
45 |
}
|
46 |
|
47 |
-
# Language codes for MT5
|
48 |
MT5_LANG_CODES = {
|
49 |
'eng_Latn': 'en',
|
50 |
'hin_Deva': 'hi',
|
51 |
'mar_Deva': 'mr'
|
52 |
}
|
53 |
|
|
|
|
|
|
|
|
|
54 |
@st.cache_resource
|
55 |
def load_models():
|
56 |
"""Load and cache the translation and context interpretation models."""
|
57 |
try:
|
58 |
-
# Set device
|
59 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
60 |
|
61 |
-
# Load Gemma model
|
62 |
gemma_tokenizer = AutoTokenizer.from_pretrained(
|
63 |
"google/gemma-2b",
|
64 |
token=HF_TOKEN,
|
@@ -72,11 +72,11 @@ def load_models():
|
|
72 |
trust_remote_code=True
|
73 |
)
|
74 |
|
75 |
-
# Load NLLB model
|
76 |
nllb_tokenizer = AutoTokenizer.from_pretrained(
|
77 |
"facebook/nllb-200-distilled-600M",
|
78 |
token=HF_TOKEN,
|
79 |
-
|
80 |
trust_remote_code=True
|
81 |
)
|
82 |
nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
|
@@ -87,21 +87,20 @@ def load_models():
|
|
87 |
trust_remote_code=True
|
88 |
)
|
89 |
|
90 |
-
# Load MT5 model
|
91 |
mt5_tokenizer = AutoTokenizer.from_pretrained(
|
92 |
-
"google/mt5-base",
|
93 |
token=HF_TOKEN,
|
94 |
trust_remote_code=True
|
95 |
)
|
96 |
mt5_model = MT5ForConditionalGeneration.from_pretrained(
|
97 |
-
"google/mt5-base",
|
98 |
token=HF_TOKEN,
|
99 |
torch_dtype=torch.float16,
|
100 |
device_map="auto" if torch.cuda.is_available() else None,
|
101 |
trust_remote_code=True
|
102 |
)
|
103 |
|
104 |
-
# Move models to device if not using device_map="auto"
|
105 |
if not torch.cuda.is_available():
|
106 |
gemma_model = gemma_model.to(device)
|
107 |
nllb_model = nllb_model.to(device)
|
@@ -111,90 +110,11 @@ def load_models():
|
|
111 |
|
112 |
except Exception as e:
|
113 |
st.error(f"Error loading models: {str(e)}")
|
114 |
-
st.error("Detailed error information:")
|
115 |
st.error(f"Python version: {sys.version}")
|
116 |
st.error(f"PyTorch version: {torch.__version__}")
|
117 |
raise e
|
118 |
|
119 |
-
|
120 |
-
"""Extract text content from uploaded file based on its type."""
|
121 |
-
file_extension = Path(uploaded_file.name).suffix.lower()
|
122 |
-
|
123 |
-
if file_extension == '.pdf':
|
124 |
-
return extract_from_pdf(uploaded_file)
|
125 |
-
elif file_extension == '.docx':
|
126 |
-
return extract_from_docx(uploaded_file)
|
127 |
-
elif file_extension == '.txt':
|
128 |
-
return uploaded_file.getvalue().decode('utf-8')
|
129 |
-
else:
|
130 |
-
raise ValueError(f"Unsupported file format: {file_extension}")
|
131 |
-
|
132 |
-
def extract_from_pdf(file) -> str:
|
133 |
-
"""Extract text from PDF file."""
|
134 |
-
pdf_reader = PyPDF2.PdfReader(file)
|
135 |
-
text = ""
|
136 |
-
for page in pdf_reader.pages:
|
137 |
-
text += page.extract_text() + "\n"
|
138 |
-
return text.strip()
|
139 |
-
|
140 |
-
def extract_from_docx(file) -> str:
|
141 |
-
"""Extract text from DOCX file."""
|
142 |
-
doc = docx.Document(file)
|
143 |
-
text = ""
|
144 |
-
for paragraph in doc.paragraphs:
|
145 |
-
text += paragraph.text + "\n"
|
146 |
-
return text.strip()
|
147 |
-
|
148 |
-
def batch_process_text(text: str, max_length: int = 512) -> list:
|
149 |
-
"""Split text into batches for processing."""
|
150 |
-
words = text.split()
|
151 |
-
batches = []
|
152 |
-
current_batch = []
|
153 |
-
current_length = 0
|
154 |
-
|
155 |
-
for word in words:
|
156 |
-
if current_length + len(word) + 1 > max_length:
|
157 |
-
batches.append(" ".join(current_batch))
|
158 |
-
current_batch = [word]
|
159 |
-
current_length = len(word)
|
160 |
-
else:
|
161 |
-
current_batch.append(word)
|
162 |
-
current_length += len(word) + 1
|
163 |
-
|
164 |
-
if current_batch:
|
165 |
-
batches.append(" ".join(current_batch))
|
166 |
-
|
167 |
-
return batches
|
168 |
-
|
169 |
-
@torch.no_grad()
|
170 |
-
def interpret_context(text: str, gemma_tuple: Tuple) -> str:
|
171 |
-
"""Use Gemma model to interpret context and understand regional nuances."""
|
172 |
-
tokenizer, model = gemma_tuple
|
173 |
-
|
174 |
-
batches = batch_process_text(text)
|
175 |
-
interpreted_batches = []
|
176 |
-
|
177 |
-
for batch in batches:
|
178 |
-
prompt = f"""Analyze and maintain the core meaning of this text: {batch}"""
|
179 |
-
|
180 |
-
inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
|
181 |
-
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
182 |
-
|
183 |
-
outputs = model.generate(
|
184 |
-
**inputs,
|
185 |
-
max_length=512,
|
186 |
-
do_sample=True,
|
187 |
-
temperature=0.3,
|
188 |
-
pad_token_id=tokenizer.eos_token_id,
|
189 |
-
num_return_sequences=1
|
190 |
-
)
|
191 |
-
|
192 |
-
interpreted_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
193 |
-
# Remove the prompt from the output
|
194 |
-
interpreted_text = interpreted_text.replace(prompt, "").strip()
|
195 |
-
interpreted_batches.append(interpreted_text)
|
196 |
-
|
197 |
-
return " ".join(interpreted_batches)
|
198 |
|
199 |
@torch.no_grad()
|
200 |
def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
|
@@ -204,13 +124,20 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tu
|
|
204 |
batches = batch_process_text(text)
|
205 |
translated_batches = []
|
206 |
|
|
|
|
|
|
|
207 |
for batch in batches:
|
|
|
208 |
inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
|
209 |
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
210 |
|
|
|
|
|
|
|
211 |
outputs = model.generate(
|
212 |
**inputs,
|
213 |
-
forced_bos_token_id=
|
214 |
max_length=512,
|
215 |
do_sample=True,
|
216 |
temperature=0.7,
|
@@ -225,22 +152,20 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tu
|
|
225 |
|
226 |
@torch.no_grad()
|
227 |
def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
|
228 |
-
"""Correct grammar using MT5 model
|
229 |
tokenizer, model = mt5_tuple
|
230 |
lang_code = MT5_LANG_CODES[target_lang]
|
231 |
|
232 |
-
# Language-specific prompts for grammar correction
|
233 |
prompts = {
|
234 |
'en': "Fix grammar: ",
|
235 |
-
'hi': "
|
236 |
-
'mr': "
|
237 |
}
|
238 |
|
239 |
batches = batch_process_text(text)
|
240 |
corrected_batches = []
|
241 |
|
242 |
for batch in batches:
|
243 |
-
# Prepare input with target language prefix
|
244 |
input_text = f"{prompts[lang_code]}{batch}"
|
245 |
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
|
246 |
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
@@ -251,29 +176,20 @@ def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
|
|
251 |
num_beams=5,
|
252 |
length_penalty=1.0,
|
253 |
early_stopping=True,
|
254 |
-
|
|
|
255 |
)
|
256 |
|
257 |
corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
258 |
-
# Clean up the output
|
259 |
for prefix in prompts.values():
|
260 |
corrected_text = corrected_text.replace(prefix, "")
|
261 |
-
corrected_text = corrected_text.replace("<extra_id_0>", "")
|
|
|
|
|
262 |
corrected_batches.append(corrected_text)
|
263 |
|
264 |
return " ".join(corrected_batches)
|
265 |
|
266 |
-
def save_as_docx(text: str) -> io.BytesIO:
|
267 |
-
"""Save translated text as a DOCX file."""
|
268 |
-
doc = docx.Document()
|
269 |
-
doc.add_paragraph(text)
|
270 |
-
|
271 |
-
docx_buffer = io.BytesIO()
|
272 |
-
doc.save(docx_buffer)
|
273 |
-
docx_buffer.seek(0)
|
274 |
-
|
275 |
-
return docx_buffer
|
276 |
-
|
277 |
def main():
|
278 |
st.title("🌐 Document Translation App")
|
279 |
|
@@ -283,7 +199,6 @@ def main():
|
|
283 |
gemma_tuple, nllb_tuple, mt5_tuple = load_models()
|
284 |
except Exception as e:
|
285 |
st.error(f"Error loading models: {str(e)}")
|
286 |
-
st.error("Please check if the HF_TOKEN is valid and has the necessary permissions.")
|
287 |
return
|
288 |
|
289 |
# File upload
|
@@ -312,39 +227,34 @@ def main():
|
|
312 |
try:
|
313 |
progress_bar = st.progress(0)
|
314 |
|
315 |
-
#
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
with st.spinner("Interpreting context..."):
|
321 |
interpreted_text = interpret_context(text, gemma_tuple)
|
322 |
-
|
323 |
-
|
324 |
-
# Translate
|
325 |
-
with st.spinner("Translating..."):
|
326 |
translated_text = translate_text(
|
327 |
interpreted_text,
|
328 |
SUPPORTED_LANGUAGES[source_language],
|
329 |
SUPPORTED_LANGUAGES[target_language],
|
330 |
nllb_tuple
|
331 |
)
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
with st.spinner("Correcting grammar..."):
|
336 |
-
corrected_text = correct_grammar(
|
337 |
translated_text,
|
338 |
SUPPORTED_LANGUAGES[target_language],
|
339 |
mt5_tuple
|
340 |
)
|
341 |
-
|
342 |
|
343 |
# Display result
|
344 |
st.markdown("### Translation Result")
|
345 |
st.text_area(
|
346 |
label="Translated Text",
|
347 |
-
value=
|
348 |
height=200,
|
349 |
key="translation_result"
|
350 |
)
|
@@ -356,7 +266,7 @@ def main():
|
|
356 |
with col1:
|
357 |
# Text file download
|
358 |
text_buffer = io.BytesIO()
|
359 |
-
text_buffer.write(
|
360 |
text_buffer.seek(0)
|
361 |
|
362 |
st.download_button(
|
@@ -368,7 +278,7 @@ def main():
|
|
368 |
|
369 |
with col2:
|
370 |
# DOCX file download
|
371 |
-
docx_buffer = save_as_docx(
|
372 |
st.download_button(
|
373 |
label="Download as DOCX",
|
374 |
data=docx_buffer,
|
@@ -380,6 +290,6 @@ def main():
|
|
380 |
|
381 |
except Exception as e:
|
382 |
st.error(f"An error occurred: {str(e)}")
|
383 |
-
|
384 |
if __name__ == "__main__":
|
385 |
main()
|
|
|
12 |
from datetime import datetime, timezone
|
13 |
import warnings
|
14 |
|
15 |
+
# Filter warnings
|
16 |
+
warnings.filterwarnings('ignore', category=UserWarning)
|
|
|
17 |
|
18 |
+
# Page config
|
19 |
st.set_page_config(
|
20 |
page_title="Document Translation App",
|
21 |
page_icon="🌐",
|
22 |
layout="wide"
|
23 |
)
|
24 |
|
25 |
+
# Display system info
|
26 |
+
st.sidebar.markdown(f"""
|
|
|
27 |
### System Information
|
28 |
+
**Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
|
29 |
+
**User:** {os.environ.get('USER', 'gauravchand')}
|
30 |
+
""")
|
31 |
|
32 |
+
# Get Hugging Face token
|
33 |
HF_TOKEN = os.environ.get('HF_TOKEN')
|
34 |
if not HF_TOKEN:
|
35 |
st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
|
36 |
st.stop()
|
37 |
|
38 |
+
# Language configurations
|
39 |
SUPPORTED_LANGUAGES = {
|
40 |
'English': 'eng_Latn',
|
41 |
'Hindi': 'hin_Deva',
|
42 |
'Marathi': 'mar_Deva'
|
43 |
}
|
44 |
|
|
|
45 |
MT5_LANG_CODES = {
|
46 |
'eng_Latn': 'en',
|
47 |
'hin_Deva': 'hi',
|
48 |
'mar_Deva': 'mr'
|
49 |
}
|
50 |
|
51 |
+
def get_nllb_lang_token(lang_code: str) -> str:
|
52 |
+
"""Get the correct token format for NLLB model."""
|
53 |
+
return f"___{lang_code}___"
|
54 |
+
|
55 |
@st.cache_resource
|
56 |
def load_models():
|
57 |
"""Load and cache the translation and context interpretation models."""
|
58 |
try:
|
|
|
59 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
60 |
|
61 |
+
# Load Gemma model
|
62 |
gemma_tokenizer = AutoTokenizer.from_pretrained(
|
63 |
"google/gemma-2b",
|
64 |
token=HF_TOKEN,
|
|
|
72 |
trust_remote_code=True
|
73 |
)
|
74 |
|
75 |
+
# Load NLLB model
|
76 |
nllb_tokenizer = AutoTokenizer.from_pretrained(
|
77 |
"facebook/nllb-200-distilled-600M",
|
78 |
token=HF_TOKEN,
|
79 |
+
use_fast=False, # Use slow tokenizer for better compatibility
|
80 |
trust_remote_code=True
|
81 |
)
|
82 |
nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
|
|
|
87 |
trust_remote_code=True
|
88 |
)
|
89 |
|
90 |
+
# Load MT5 model
|
91 |
mt5_tokenizer = AutoTokenizer.from_pretrained(
|
92 |
+
"google/mt5-base",
|
93 |
token=HF_TOKEN,
|
94 |
trust_remote_code=True
|
95 |
)
|
96 |
mt5_model = MT5ForConditionalGeneration.from_pretrained(
|
97 |
+
"google/mt5-base",
|
98 |
token=HF_TOKEN,
|
99 |
torch_dtype=torch.float16,
|
100 |
device_map="auto" if torch.cuda.is_available() else None,
|
101 |
trust_remote_code=True
|
102 |
)
|
103 |
|
|
|
104 |
if not torch.cuda.is_available():
|
105 |
gemma_model = gemma_model.to(device)
|
106 |
nllb_model = nllb_model.to(device)
|
|
|
110 |
|
111 |
except Exception as e:
|
112 |
st.error(f"Error loading models: {str(e)}")
|
|
|
113 |
st.error(f"Python version: {sys.version}")
|
114 |
st.error(f"PyTorch version: {torch.__version__}")
|
115 |
raise e
|
116 |
|
117 |
+
# [Previous file handling functions remain the same]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
@torch.no_grad()
|
120 |
def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
|
|
|
124 |
batches = batch_process_text(text)
|
125 |
translated_batches = []
|
126 |
|
127 |
+
# Get target language token
|
128 |
+
target_lang_token = get_nllb_lang_token(target_lang)
|
129 |
+
|
130 |
for batch in batches:
|
131 |
+
# Prepare input text
|
132 |
inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
|
133 |
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
134 |
|
135 |
+
# Get target language token ID
|
136 |
+
target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)
|
137 |
+
|
138 |
outputs = model.generate(
|
139 |
**inputs,
|
140 |
+
forced_bos_token_id=target_lang_id,
|
141 |
max_length=512,
|
142 |
do_sample=True,
|
143 |
temperature=0.7,
|
|
|
152 |
|
153 |
@torch.no_grad()
|
154 |
def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
|
155 |
+
"""Correct grammar using MT5 model."""
|
156 |
tokenizer, model = mt5_tuple
|
157 |
lang_code = MT5_LANG_CODES[target_lang]
|
158 |
|
|
|
159 |
prompts = {
|
160 |
'en': "Fix grammar: ",
|
161 |
+
'hi': "व्याकरण सुधार: ",
|
162 |
+
'mr': "व्याकरण सुधार: "
|
163 |
}
|
164 |
|
165 |
batches = batch_process_text(text)
|
166 |
corrected_batches = []
|
167 |
|
168 |
for batch in batches:
|
|
|
169 |
input_text = f"{prompts[lang_code]}{batch}"
|
170 |
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
|
171 |
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
|
|
176 |
num_beams=5,
|
177 |
length_penalty=1.0,
|
178 |
early_stopping=True,
|
179 |
+
no_repeat_ngram_size=2,
|
180 |
+
do_sample=False
|
181 |
)
|
182 |
|
183 |
corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
184 |
for prefix in prompts.values():
|
185 |
corrected_text = corrected_text.replace(prefix, "")
|
186 |
+
corrected_text = (corrected_text.replace("<extra_id_0>", "")
|
187 |
+
.replace("<extra_id_1>", "")
|
188 |
+
.strip())
|
189 |
corrected_batches.append(corrected_text)
|
190 |
|
191 |
return " ".join(corrected_batches)
|
192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
def main():
|
194 |
st.title("🌐 Document Translation App")
|
195 |
|
|
|
199 |
gemma_tuple, nllb_tuple, mt5_tuple = load_models()
|
200 |
except Exception as e:
|
201 |
st.error(f"Error loading models: {str(e)}")
|
|
|
202 |
return
|
203 |
|
204 |
# File upload
|
|
|
227 |
try:
|
228 |
progress_bar = st.progress(0)
|
229 |
|
230 |
+
# Process document
|
231 |
+
with st.spinner("Processing document..."):
|
232 |
+
text = extract_text_from_file(uploaded_file)
|
233 |
+
progress_bar.progress(25)
|
234 |
+
|
|
|
235 |
interpreted_text = interpret_context(text, gemma_tuple)
|
236 |
+
progress_bar.progress(50)
|
237 |
+
|
|
|
|
|
238 |
translated_text = translate_text(
|
239 |
interpreted_text,
|
240 |
SUPPORTED_LANGUAGES[source_language],
|
241 |
SUPPORTED_LANGUAGES[target_language],
|
242 |
nllb_tuple
|
243 |
)
|
244 |
+
progress_bar.progress(75)
|
245 |
+
|
246 |
+
final_text = correct_grammar(
|
|
|
|
|
247 |
translated_text,
|
248 |
SUPPORTED_LANGUAGES[target_language],
|
249 |
mt5_tuple
|
250 |
)
|
251 |
+
progress_bar.progress(90)
|
252 |
|
253 |
# Display result
|
254 |
st.markdown("### Translation Result")
|
255 |
st.text_area(
|
256 |
label="Translated Text",
|
257 |
+
value=final_text,
|
258 |
height=200,
|
259 |
key="translation_result"
|
260 |
)
|
|
|
266 |
with col1:
|
267 |
# Text file download
|
268 |
text_buffer = io.BytesIO()
|
269 |
+
text_buffer.write(final_text.encode())
|
270 |
text_buffer.seek(0)
|
271 |
|
272 |
st.download_button(
|
|
|
278 |
|
279 |
with col2:
|
280 |
# DOCX file download
|
281 |
+
docx_buffer = save_as_docx(final_text)
|
282 |
st.download_button(
|
283 |
label="Download as DOCX",
|
284 |
data=docx_buffer,
|
|
|
290 |
|
291 |
except Exception as e:
|
292 |
st.error(f"An error occurred: {str(e)}")
|
293 |
+
|
294 |
if __name__ == "__main__":
|
295 |
main()
|