gauravchand11 committed on
Commit 2ea2438 · verified · 1 Parent(s): ed75acb

Update app.py

Files changed (1)
  1. app.py +159 -189
app.py CHANGED
@@ -1,96 +1,96 @@
- from transformers import (
-     AutoTokenizer,
-     AutoModelForSeq2SeqLM,
-     BertTokenizer,
-     BertModel,
-     AutoModelForTokenClassification
- )
  import streamlit as st
  from PyPDF2 import PdfReader
  import docx
  import os
  import re
- import torch
- import numpy as np
- from datetime import datetime, timezone

- # Load models and tokenizers
  @st.cache_resource
- def load_models():
-     try:
-         # BERT model for context understanding
-         context_tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
-         context_model = BertModel.from_pretrained('bert-base-multilingual-cased')
-
-         # NLLB model for translation
-         nllb_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
-         nllb_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
-
-         # Grammar correction model
-         grammar_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
-         grammar_model = AutoModelForTokenClassification.from_pretrained(
-             'bert-base-cased',
-             num_labels=3  # Assuming 3 labels: keep, delete, replace
-         )
-
-         return {
-             "context": (context_tokenizer, context_model),
-             "nllb": (nllb_tokenizer, nllb_model),
-             "grammar": (grammar_tokenizer, grammar_model)
-         }
-     except Exception as e:
-         st.error(f"Error loading models: {str(e)}")
-         raise e
-
- def get_bert_embeddings(text, models):
-     """Get contextual embeddings from BERT"""
-     tokenizer, model = models["context"]
-
-     # Split text into smaller chunks
-     max_length = 512
-     chunks = [text[i:i + max_length] for i in range(0, len(text), max_length)]
-     contextual_embeddings = []
-
-     for chunk in chunks:
-         inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
-         with torch.no_grad():
-             outputs = model(**inputs)
-         embeddings = outputs.last_hidden_state.mean(dim=1)
-         contextual_embeddings.append(embeddings)
-
-     # Combine embeddings from all chunks
-     combined_embedding = torch.cat(contextual_embeddings, dim=0).mean(dim=0)
-     return combined_embedding
-
- def apply_grammar_correction(text, models):
-     """Basic grammar correction using BERT"""
-     tokenizer, model = models["grammar"]
-
-     sentences = re.split('([.!?।]+)', text)
-     corrected_sentences = []
-
-     for sentence in sentences:
-         if sentence.strip():
-             # Basic tokenization and prediction
-             inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
-             with torch.no_grad():
-                 outputs = model(**inputs)
-                 predictions = torch.argmax(outputs.logits, dim=2)
-
-             tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
-             corrected_tokens = []
-
-             for token, pred in zip(tokens, predictions[0]):
-                 if pred == 0 or token in ['[CLS]', '[SEP]', '[PAD]']:
-                     if token not in ['[CLS]', '[SEP]', '[PAD]']:
-                         corrected_tokens.append(token)
-
-             corrected_text = tokenizer.convert_tokens_to_string(corrected_tokens)
-             if corrected_text.strip():
-                 corrected_sentences.append(corrected_text)

-     return " ".join(corrected_sentences)

  def extract_text(file):
      ext = os.path.splitext(file.name)[1].lower()

@@ -114,10 +114,12 @@ def extract_text(file):
      else:
          raise ValueError("Unsupported file format. Please upload PDF, DOCX, or TXT files.")

  def translate_text(text, src_lang, tgt_lang, models):
      if src_lang == tgt_lang:
          return text

      lang_map = {"en": "eng_Latn", "hi": "hin_Deva", "mr": "mar_Deva"}

      if src_lang not in lang_map or tgt_lang not in lang_map:
@@ -126,81 +128,61 @@ def translate_text(text, src_lang, tgt_lang, models):
      tgt_lang_code = lang_map[tgt_lang]
      tokenizer, model = models["nllb"]

-     try:
-         # Get contextual embeddings
-         context_embedding = get_bert_embeddings(text, models)
-
-         # Split into chunks for translation
-         chunks = []
-         current_chunk = ""
-
-         for sentence in re.split('([.!?।]+)', text):
-             if sentence.strip():
-                 if len(current_chunk) + len(sentence) < 450:
-                     current_chunk += sentence
-                 else:
-                     if current_chunk:
-                         chunks.append(current_chunk)
-                     current_chunk = sentence
-
-         if current_chunk:
-             chunks.append(current_chunk)
-
-         translated_text = ""
-
-         for chunk in chunks:
-             if chunk.strip():
-                 inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
-
-                 # Use context embedding to modify attention
-                 attention_mask = inputs['attention_mask'].float()
-                 context_weight = 0.1 * torch.sigmoid(context_embedding.mean())
-                 attention_mask = attention_mask * (1 + context_weight)
-
-                 # Get target language token ID
-                 tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang_code)
-
-                 with torch.no_grad():
-                     translated = model.generate(
-                         input_ids=inputs['input_ids'],
-                         attention_mask=attention_mask,
-                         forced_bos_token_id=tgt_lang_id,
-                         max_length=512,
-                         num_beams=5,
-                         length_penalty=1.0,
-                         no_repeat_ngram_size=3,
-                         do_sample=True,
-                         temperature=0.7
-                     )
-                 translated_chunk = tokenizer.decode(translated[0], skip_special_tokens=True)
-                 translated_text += translated_chunk + " "
-
-         # Apply basic grammar correction
-         corrected_text = apply_grammar_correction(translated_text.strip(), models)
-
-         return corrected_text

-     except Exception as e:
-         st.error(f"Translation error: {str(e)}")
-         return f"Error during translation: {str(e)}"

  def save_text_to_file(text, original_filename, prefix="translated"):
-     timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
-     output_filename = f"{prefix}_{timestamp}_{os.path.basename(original_filename)}.txt"
      with open(output_filename, "w", encoding="utf-8") as f:
          f.write(text)
      return output_filename

  def process_document(file, source_lang, target_lang, models):
      try:
          # Extract text from uploaded file
          text = extract_text(file)

-         # Add debugging information
-         st.sidebar.write("Processing document...")
-         st.sidebar.write(f"Source language: {source_lang}")
-         st.sidebar.write(f"Target language: {target_lang}")
-
          # Translate the text
          translated_text = translate_text(text, source_lang, target_lang, models)

@@ -211,56 +193,44 @@ def process_document(file, source_lang, target_lang, models):
          output_file = save_text_to_file(translated_text, file.name)

          return output_file, translated_text
-
      except Exception as e:
          error_message = f"Error: {str(e)}"
-         st.error(error_message)
          output_file = save_text_to_file(error_message, file.name, prefix="error")
          return output_file, error_message

  def main():
-     st.title("Advanced Document Translator")
-     st.write("Upload a document (PDF, DOCX, or TXT) and select source and target languages.")
-
-     # Display current user and timestamp
-     st.sidebar.write(f"Current User: {os.getenv('USER', 'gauravchand')}")
-     st.sidebar.write(f"UTC Time: {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}")
-
-     try:
-         # Initialize models with error handling
-         with st.spinner("Loading models..."):
-             models = load_models()
-             st.success("Models loaded successfully!")
-
-         # File uploader
-         uploaded_file = st.file_uploader("Upload Document", type=["pdf", "docx", "txt"])
-
-         # Language selection
-         col1, col2 = st.columns(2)
-         with col1:
-             source_lang = st.selectbox("Source Language", ["en", "hi", "mr"], index=0)
-         with col2:
-             target_lang = st.selectbox("Target Language", ["en", "hi", "mr"], index=1)
-
-         if uploaded_file is not None and st.button("Translate"):
-             with st.spinner("Processing and Translating..."):
-                 output_file, result_text = process_document(uploaded_file, source_lang, target_lang, models)
-
-             # Display result
-             st.text_area("Translated Text", result_text, height=300)
-
-             # Provide download button
-             with open(output_file, "rb") as file:
-                 st.download_button(
-                     label="Download Translated Document",
-                     data=file,
-                     file_name=os.path.basename(output_file),
-                     mime="text/plain"
-                 )
-
-     except Exception as e:
-         st.error(f"Application error: {str(e)}")
-         st.warning("Please try refreshing the page or contact support.")

- if __name__ == "__main__":
      main()
 
@@ -1,96 +1,96 @@
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  import streamlit as st
  from PyPDF2 import PdfReader
  import docx
  import os
  import re

+ # Load NLLB model and tokenizer
  @st.cache_resource
+ def load_translation_model():
+     model_name = "facebook/nllb-200-distilled-600M"
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+     return tokenizer, model
+
+ # Initialize model
+ @st.cache_resource
+ def initialize_models():
+     tokenizer, model = load_translation_model()
+     return {"nllb": (tokenizer, model)}
+
+ # Enhanced idiom mapping with more comprehensive translations
+ def preprocess_idioms(text, src_lang, tgt_lang):
+     if src_lang == "en" and tgt_lang == "hi":
+         idiom_map = {
+             # Basic phrases
+             "no piece of cake": "कोई आसान काम नहीं",
+             "piece of cake": "बहुत आसान काम",
+             "bite the bullet": "दांतों तले उंगली दबाना",
+             "tackle it head-on": "सीधे मुकाबला करना",
+             "fell into place": "सब कुछ ठीक हो गया",
+             "see the light at the end of the tunnel": "मुश्किलों के अंत में उम्मीद की किरण दिखना",
+             "with a little perseverance": "थोड़े से धैर्य से",
+
+             # Additional common idioms
+             "break a leg": "बहुत बहुत शुभकामनाएं",
+             "hit the nail on the head": "बिल्कुल सही बात कहना",
+             "once in a blue moon": "बहुत कम, कभी-कभार",
+             "under the weather": "तबीयत ठीक नहीं",
+             "cost an arm and a leg": "बहुत महंगा",
+             "beating around the bush": "इधर-उधर की बात करना",
+             "call it a day": "काम समाप्त करना",
+             "burn the midnight oil": "रात-रात भर जागकर काम करना",
+             "get the ball rolling": "शुरुआत करना",
+             "pull yourself together": "खुद को संभालो",
+             "shoot yourself in the foot": "अपना ही नुकसान करना",
+             "take it with a grain of salt": "संदेह से लेना",
+             "the last straw": "सहनशीलता की आखिरी सीमा",
+             "time flies": "समय पंख लगाकर उड़ता है",
+             "wrap your head around": "समझने की कोशिश करना",
+             "cut corners": "काम में छोटा रास्ता अपनाना",
+             "back to square one": "फिर से शुरू से",
+             "blessing in disguise": "छिपा हुआ वरदान",
+             "cry over spilled milk": "बीती बात पर पछताना",
+             "keep your chin up": "हिम्मत रखना",
+
+             # Work-related idioms
+             "think outside the box": "नए तरीके से सोचना",
+             "raise the bar": "मानक ऊंचा करना",
+             "learning curve": "सीखने की प्रक्रिया",
+             "up and running": "चालू और कार्यरत",
+             "back to the drawing board": "फिर से योजना बनाना",
+
+             # Project-related phrases
+             "running into issues": "समस्याओं का सामना करना",
+             "iron out the bugs": "खामियां दूर करना",
+             "in the pipeline": "विचाराधीन",
+             "moving forward": "आगे बढ़ते हुए",
+             "touch base": "संपर्क में रहना",
+
+             # Technical phrases
+             "user-friendly": "उपयोगकर्ता के अनुकूल",
+             "cutting-edge": "अत्याधुनिक",
+             "state of the art": "अत्याधुनिक तकनीक",
+             "proof of concept": "व्यवहार्यता का प्रमाण",
+             "game changer": "खेल बदलने वाला"
+         }

+         # Sort idioms by length (longest first) to handle overlapping phrases
+         sorted_idioms = sorted(idiom_map.keys(), key=len, reverse=True)

+         # Create a single regex pattern for all idioms
+         pattern = '|'.join(map(re.escape, sorted_idioms))

+         def replace_idiom(match):
+             return idiom_map[match.group(0).lower()]
+
+         # Replace all idioms in one pass, case-insensitive
+         text = re.sub(pattern, replace_idiom, text, flags=re.IGNORECASE)

+     return text

+ # Function to extract text from different file types
  def extract_text(file):
      ext = os.path.splitext(file.name)[1].lower()

@@ -114,10 +114,12 @@ def extract_text(file):
      else:
          raise ValueError("Unsupported file format. Please upload PDF, DOCX, or TXT files.")

+ # Translation function with improved chunking and fixed tokenizer issue
  def translate_text(text, src_lang, tgt_lang, models):
      if src_lang == tgt_lang:
          return text

+     # Language codes for NLLB
      lang_map = {"en": "eng_Latn", "hi": "hin_Deva", "mr": "mar_Deva"}

      if src_lang not in lang_map or tgt_lang not in lang_map:
@@ -126,81 +128,61 @@ def translate_text(text, src_lang, tgt_lang, models):
      tgt_lang_code = lang_map[tgt_lang]
      tokenizer, model = models["nllb"]

+     # Preprocess for idioms
+     preprocessed_text = preprocess_idioms(text, src_lang, tgt_lang)

+     # Improved chunking: Split by sentences while preserving context
+     chunks = []
+     current_chunk = ""
+
+     for sentence in re.split('([.!?।]+)', preprocessed_text):
+         if sentence.strip():
+             if len(current_chunk) + len(sentence) < 450:  # Leave room for tokenization
+                 current_chunk += sentence
+             else:
+                 if current_chunk:
+                     chunks.append(current_chunk)
+                 current_chunk = sentence
+
+     if current_chunk:
+         chunks.append(current_chunk)
+
+     translated_text = ""
+
+     for chunk in chunks:
+         if chunk.strip():
+             # Add target language token to the beginning of the input
+             inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
+
+             # Get the token ID for the target language
+             tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang_code)
+
+             translated = model.generate(
+                 **inputs,
+                 forced_bos_token_id=tgt_lang_id,  # Fixed: Using convert_tokens_to_ids instead of lang_code_to_id
+                 max_length=512,
+                 num_beams=5,
+                 length_penalty=1.0,
+                 no_repeat_ngram_size=3
+             )
+             translated_chunk = tokenizer.decode(translated[0], skip_special_tokens=True)
+             translated_text += translated_chunk + " "
+
+     return translated_text.strip()

+ # Function to save text as a file
  def save_text_to_file(text, original_filename, prefix="translated"):
+     output_filename = f"{prefix}_{os.path.basename(original_filename)}.txt"
      with open(output_filename, "w", encoding="utf-8") as f:
          f.write(text)
      return output_filename

+ # Main processing function
  def process_document(file, source_lang, target_lang, models):
      try:
          # Extract text from uploaded file
          text = extract_text(file)

          # Translate the text
          translated_text = translate_text(text, source_lang, target_lang, models)

@@ -211,56 +193,44 @@ def process_document(file, source_lang, target_lang, models):
          output_file = save_text_to_file(translated_text, file.name)

          return output_file, translated_text
      except Exception as e:
          error_message = f"Error: {str(e)}"
          output_file = save_text_to_file(error_message, file.name, prefix="error")
          return output_file, error_message

+ # Streamlit interface
  def main():
+     st.title("Document Translator (NLLB-200)")
+     st.write("Upload a document (PDF, DOCX, or TXT) and select source and target languages (English, Hindi, Marathi).")
+
+     # Initialize models
+     models = initialize_models()
+
+     # File uploader
+     uploaded_file = st.file_uploader("Upload Document", type=["pdf", "docx", "txt"])
+
+     # Language selection
+     col1, col2 = st.columns(2)
+     with col1:
+         source_lang = st.selectbox("Source Language", ["en", "hi", "mr"], index=0)
+     with col2:
+         target_lang = st.selectbox("Target Language", ["en", "hi", "mr"], index=1)
+
+     if uploaded_file is not None and st.button("Translate"):
+         with st.spinner("Translating..."):
+             output_file, result_text = process_document(uploaded_file, source_lang, target_lang, models)
+
+         # Display result
+         st.text_area("Translated Text", result_text, height=300)
+
+         # Provide download button
+         with open(output_file, "rb") as file:
+             st.download_button(
+                 label="Download Translated Document",
+                 data=file,
+                 file_name=os.path.basename(output_file),
+                 mime="text/plain"
+             )

+ if __name__ == "__main__":
      main()
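
For reference, a minimal sketch of how the updated translation path might be exercised outside the Streamlit UI. It assumes the committed file is importable as app, that the NLLB checkpoint can be downloaded, and the sample sentence is purely illustrative:

    # Minimal sketch (assumption: app.py is on the import path and
    # facebook/nllb-200-distilled-600M is available or downloadable).
    from app import initialize_models, translate_text

    # Outside `streamlit run` this should still execute; st.cache_resource
    # falls back to a plain in-memory cache and only prints a warning.
    models = initialize_models()

    # translate_text() first rewrites known English idioms via preprocess_idioms(),
    # then translates sentence-sized chunks with NLLB.
    sample = "Getting the new build up and running was no piece of cake."
    print(translate_text(sample, "en", "hi", models))

The app itself is launched the usual way for Streamlit scripts, e.g. streamlit run app.py.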