gauravchand11 commited on
Commit
bcda6d5
·
verified ·
1 Parent(s): 698647f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -5
app.py CHANGED
@@ -52,6 +52,56 @@ def get_nllb_lang_token(lang_code: str) -> str:
52
  """Get the correct token format for NLLB model."""
53
  return f"___{lang_code}___"
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  @st.cache_resource
56
  def load_models():
57
  """Load and cache the translation and context interpretation models."""
@@ -76,7 +126,7 @@ def load_models():
76
  nllb_tokenizer = AutoTokenizer.from_pretrained(
77
  "facebook/nllb-200-distilled-600M",
78
  token=HF_TOKEN,
79
- use_fast=False, # Use slow tokenizer for better compatibility
80
  trust_remote_code=True
81
  )
82
  nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
@@ -114,7 +164,34 @@ def load_models():
114
  st.error(f"PyTorch version: {torch.__version__}")
115
  raise e
116
 
117
- # [Previous file handling functions remain the same]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  @torch.no_grad()
120
  def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
@@ -124,15 +201,12 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tu
124
  batches = batch_process_text(text)
125
  translated_batches = []
126
 
127
- # Get target language token
128
  target_lang_token = get_nllb_lang_token(target_lang)
129
 
130
  for batch in batches:
131
- # Prepare input text
132
  inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
133
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
134
 
135
- # Get target language token ID
136
  target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)
137
 
138
  outputs = model.generate(
@@ -190,6 +264,17 @@ def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
190
 
191
  return " ".join(corrected_batches)
192
 
 
 
 
 
 
 
 
 
 
 
 
193
  def main():
194
  st.title("🌐 Document Translation App")
195
 
 
52
  """Get the correct token format for NLLB model."""
53
  return f"___{lang_code}___"
54
 
55
def extract_text_from_file(uploaded_file) -> str:
    """Pull the raw text out of an uploaded file, dispatching on its extension.

    Supported extensions: .pdf, .docx, .txt (assumed UTF-8).
    Raises ValueError for anything else.
    """
    suffix = Path(uploaded_file.name).suffix.lower()

    # Guard-clause dispatch; branches are mutually exclusive.
    if suffix == '.txt':
        return uploaded_file.getvalue().decode('utf-8')
    if suffix == '.pdf':
        return extract_from_pdf(uploaded_file)
    if suffix == '.docx':
        return extract_from_docx(uploaded_file)
    raise ValueError(f"Unsupported file format: {suffix}")
67
+
68
def extract_from_pdf(file) -> str:
    """Extract and concatenate the text of every page in a PDF upload."""
    reader = PyPDF2.PdfReader(file)
    # One extracted string per page, joined with newlines; outer whitespace
    # (including the original trailing newline) is stripped.
    pages = [page.extract_text() for page in reader.pages]
    return "\n".join(pages).strip()
75
+
76
def extract_from_docx(file) -> str:
    """Extract and concatenate the text of every paragraph in a DOCX upload."""
    document = docx.Document(file)
    # One line per paragraph, newline-joined, with outer whitespace stripped.
    lines = [paragraph.text for paragraph in document.paragraphs]
    return "\n".join(lines).strip()
83
+
84
def batch_process_text(text: str, max_length: int = 512) -> list:
    """Greedily split text into whitespace-delimited batches of at most
    ~``max_length`` characters.

    Words are packed left-to-right; the running length budgets one separator
    character per word. A single word longer than ``max_length`` becomes its
    own (oversized) batch rather than being split mid-word.

    Args:
        text: Input text; empty or whitespace-only input yields ``[]``.
        max_length: Soft cap on the character length of each batch.

    Returns:
        List of non-empty, space-joined batch strings.
    """
    batches = []
    current_batch = []
    current_length = 0

    for word in text.split():
        # Only flush when the current batch is non-empty. Without this guard,
        # a first word of length >= max_length triggered
        # batches.append(" ".join([])), emitting a spurious "" batch.
        if current_batch and current_length + len(word) + 1 > max_length:
            batches.append(" ".join(current_batch))
            current_batch = [word]
            current_length = len(word)
        else:
            current_batch.append(word)
            current_length += len(word) + 1

    if current_batch:
        batches.append(" ".join(current_batch))

    return batches
104
+
105
  @st.cache_resource
106
  def load_models():
107
  """Load and cache the translation and context interpretation models."""
 
126
  nllb_tokenizer = AutoTokenizer.from_pretrained(
127
  "facebook/nllb-200-distilled-600M",
128
  token=HF_TOKEN,
129
+ use_fast=False,
130
  trust_remote_code=True
131
  )
132
  nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
 
164
  st.error(f"PyTorch version: {torch.__version__}")
165
  raise e
166
 
167
@torch.no_grad()
def interpret_context(text: str, gemma_tuple: Tuple) -> str:
    """Run each batch of *text* through the Gemma model to interpret context
    and regional nuances, returning the space-joined interpreted batches.

    gemma_tuple is (tokenizer, model); sampling (do_sample, temperature=0.3)
    makes the output non-deterministic.
    """
    tokenizer, model = gemma_tuple

    results = []
    for chunk in batch_process_text(text):
        prompt = f"""Analyze and maintain the core meaning of this text: {chunk}"""

        encoded = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        encoded = {key: tensor.to(model.device) for key, tensor in encoded.items()}

        generated = model.generate(
            **encoded,
            max_length=512,
            do_sample=True,
            temperature=0.3,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1
        )

        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        # Strip the echoed prompt from the decoded output.
        # NOTE(review): this assumes decode() reproduces the prompt verbatim —
        # if tokenization alters it, the replace is a no-op; confirm in practice.
        results.append(decoded.replace(prompt, "").strip())

    return " ".join(results)
195
 
196
  @torch.no_grad()
197
  def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
 
201
  batches = batch_process_text(text)
202
  translated_batches = []
203
 
 
204
  target_lang_token = get_nllb_lang_token(target_lang)
205
 
206
  for batch in batches:
 
207
  inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
208
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
209
 
 
210
  target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)
211
 
212
  outputs = model.generate(
 
264
 
265
  return " ".join(corrected_batches)
266
 
267
def save_as_docx(text: str) -> io.BytesIO:
    """Serialize *text* as a single-paragraph DOCX document in memory.

    Returns a BytesIO buffer rewound to position 0, ready for download.
    """
    buffer = io.BytesIO()

    document = docx.Document()
    document.add_paragraph(text)
    document.save(buffer)

    # Rewind so the caller can read the document from the start.
    buffer.seek(0)
    return buffer
277
+
278
  def main():
279
  st.title("🌐 Document Translation App")
280