gauravchand11 committed
Commit 4174664 · verified · 1 Parent(s): c98f2e3

Update app.py

Files changed (1):
  1. app.py +164 -80
app.py CHANGED
@@ -2,12 +2,17 @@ import streamlit as st
  import PyPDF2
  import docx
  import io
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
  import torch
  from pathlib import Path
  import tempfile
  from typing import Union, Tuple
  import os

  # Get Hugging Face token from environment variables
  HF_TOKEN = os.environ.get('HF_TOKEN')
@@ -32,43 +37,67 @@ MT5_LANG_CODES = {
  @st.cache_resource
  def load_models():
      """Load and cache the translation, context interpretation, and grammar correction models."""
-     # Load Gemma model for context interpretation
-     gemma_tokenizer = AutoTokenizer.from_pretrained(
-         "google/gemma-2b",
-         token=HF_TOKEN
-     )
-     gemma_model = AutoModelForCausalLM.from_pretrained(
-         "google/gemma-2b",
-         device_map="auto",
-         torch_dtype=torch.float16,
-         token=HF_TOKEN
-     )
-
-     # Load NLLB model for translation
-     nllb_tokenizer = AutoTokenizer.from_pretrained(
-         "facebook/nllb-200-distilled-600M",
-         token=HF_TOKEN
-     )
-     nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
-         "facebook/nllb-200-distilled-600M",
-         device_map="auto",
-         torch_dtype=torch.float16,
-         token=HF_TOKEN
-     )
-
-     # Load MT5 model for grammar correction
-     mt5_tokenizer = AutoTokenizer.from_pretrained(
-         "google/mt5-small",
-         token=HF_TOKEN
-     )
-     mt5_model = T5ForConditionalGeneration.from_pretrained(
-         "google/mt5-small",
-         device_map="auto",
-         torch_dtype=torch.float16,
-         token=HF_TOKEN
-     )

-     return (gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model), (mt5_tokenizer, mt5_model)

  def extract_text_from_file(uploaded_file) -> str:
      """Extract text content from uploaded file based on its type."""
@@ -99,43 +128,87 @@ def extract_from_docx(file) -> str:
          text += paragraph.text + "\n"
      return text.strip()

  def interpret_context(text: str, gemma_tuple: Tuple) -> str:
      """Use Gemma model to interpret context and understand regional nuances."""
      tokenizer, model = gemma_tuple

-     prompt = f"""Analyze the following text for context and cultural nuances,
-     maintaining the core meaning while identifying any idiomatic expressions or
-     cultural references: {text}"""

-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     outputs = model.generate(
-         **inputs,
-         max_length=1024,
-         temperature=0.3,
-         pad_token_id=tokenizer.eos_token_id
-     )

-     interpreted_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return interpreted_text

  def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
      """Translate text using NLLB model."""
      tokenizer, model = nllb_tuple

-     inputs = tokenizer(text, return_tensors="pt").to(model.device)
-     forced_bos_token_id = tokenizer.lang_code_to_id[target_lang]

-     outputs = model.generate(
-         **inputs,
-         forced_bos_token_id=forced_bos_token_id,
-         max_length=1024,
-         temperature=0.7,
-         num_beams=5
-     )

-     translated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-     return translated_text

  def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
      """
      Correct grammar using MT5 model for all supported languages.
@@ -146,29 +219,40 @@ def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
      # Language-specific prompts for grammar correction
      prompts = {
-         'en': f"grammar: {text}",
-         'hi': f"व्याकरण सुधार: {text}",
-         'mr': f"व्याकरण सुधारणा: {text}"
      }

-     prompt = prompts[lang_code]
-     inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(model.device)
-
-     outputs = model.generate(
-         **inputs,
-         max_length=512,
-         num_beams=5,
-         temperature=0.7,
-         top_p=0.9,
-         do_sample=True
-     )
-
-     corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

-     # Clean up any artifacts from the model output
-     corrected_text = corrected_text.replace("grammar:", "").replace("व्याकरण सुधार:", "").replace("व्याकरण सुधारणा:", "").strip()

-     return corrected_text

  def save_as_docx(text: str) -> io.BytesIO:
      """Save translated text as a DOCX file."""
@@ -191,7 +275,7 @@ def main():
      except Exception as e:
          st.error(f"Error loading models: {str(e)}")
          st.error("Please check if the HF_TOKEN is valid and has the necessary permissions.")
-         st.stop()

      # File upload
      uploaded_file = st.file_uploader(
 
  import PyPDF2
  import docx
  import io
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration
+ import transformers  # needed for transformers.__version__ in load_models' error handler
+ import sys  # needed for sys.version in load_models' error handler
  import torch
  from pathlib import Path
  import tempfile
  from typing import Union, Tuple
  import os
+ from datetime import datetime, timezone
+
+ # Display current information
+ st.sidebar.text(f"Current Time (UTC): {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}")
+ st.sidebar.text(f"User: {os.environ.get('USER', 'gauravchand')}")

  # Get Hugging Face token from environment variables
  HF_TOKEN = os.environ.get('HF_TOKEN')
 
  @st.cache_resource
  def load_models():
      """Load and cache the translation, context interpretation, and grammar correction models."""
+     try:
+         # Set device
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # Load Gemma model for context interpretation
+         gemma_tokenizer = AutoTokenizer.from_pretrained(
+             "google/gemma-2b",
+             token=HF_TOKEN,
+             trust_remote_code=True
+         )
+         gemma_model = AutoModelForCausalLM.from_pretrained(
+             "google/gemma-2b",
+             token=HF_TOKEN,
+             torch_dtype=torch.float16,
+             device_map="auto" if torch.cuda.is_available() else None,
+             trust_remote_code=True
+         )
+
+         # Load NLLB model for translation
+         nllb_tokenizer = AutoTokenizer.from_pretrained(
+             "facebook/nllb-200-distilled-600M",
+             token=HF_TOKEN,
+             trust_remote_code=True
+         )
+         nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
+             "facebook/nllb-200-distilled-600M",
+             token=HF_TOKEN,
+             torch_dtype=torch.float16,
+             device_map="auto" if torch.cuda.is_available() else None,
+             trust_remote_code=True
+         )
+
+         # Load MT5 model for grammar correction
+         mt5_tokenizer = AutoTokenizer.from_pretrained(
+             "google/mt5-small",
+             token=HF_TOKEN,
+             trust_remote_code=True
+         )
+         mt5_model = MT5ForConditionalGeneration.from_pretrained(
+             "google/mt5-small",
+             token=HF_TOKEN,
+             torch_dtype=torch.float16,
+             device_map="auto" if torch.cuda.is_available() else None,
+             trust_remote_code=True
+         )
+
+         # Move models to device if not using device_map="auto"
+         if not torch.cuda.is_available():
+             gemma_model = gemma_model.to(device)
+             nllb_model = nllb_model.to(device)
+             mt5_model = mt5_model.to(device)
+
+         return (gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model), (mt5_tokenizer, mt5_model)

+     except Exception as e:
+         st.error(f"Error loading models: {str(e)}")
+         st.error("Detailed error information:")
+         st.error(f"Python version: {sys.version}")
+         st.error(f"PyTorch version: {torch.__version__}")
+         st.error(f"Transformers version: {transformers.__version__}")
+         raise e

  def extract_text_from_file(uploaded_file) -> str:
      """Extract text content from uploaded file based on its type."""
 
          text += paragraph.text + "\n"
      return text.strip()

+ def batch_process_text(text: str, max_length: int = 512) -> list:
+     """Split text into batches for processing."""
+     words = text.split()
+     batches = []
+     current_batch = []
+     current_length = 0
+
+     for word in words:
+         if current_length + len(word) + 1 > max_length:
+             batches.append(" ".join(current_batch))
+             current_batch = [word]
+             current_length = len(word)
+         else:
+             current_batch.append(word)
+             current_length += len(word) + 1
+
+     if current_batch:
+         batches.append(" ".join(current_batch))
+
+     return batches
+
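For reference, a quick usage sketch of the batch_process_text helper added above. The sample input is hypothetical, and note that max_length is a character budget over the re-joined words, not a token count.

```python
# Hypothetical input: 300 copies of "word" (~1,500 characters once joined).
sample = "word " * 300

chunks = batch_process_text(sample, max_length=512)

# The helper packs whole words until the next word would push a chunk past
# max_length characters, so every chunk stays within the budget.
assert all(len(chunk) <= 512 for chunk in chunks)
print(len(chunks))  # -> 3 chunks for this input
```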
+ @torch.no_grad()
  def interpret_context(text: str, gemma_tuple: Tuple) -> str:
      """Use Gemma model to interpret context and understand regional nuances."""
      tokenizer, model = gemma_tuple

+     # Split text into batches
+     batches = batch_process_text(text)
+     interpreted_batches = []

+     for batch in batches:
+         prompt = f"""Analyze the following text for context and cultural nuances,
+         maintaining the core meaning while identifying any idiomatic expressions or
+         cultural references: {batch}"""
+
+         inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
+         inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+         outputs = model.generate(
+             **inputs,
+             max_length=512,
+             temperature=0.3,
+             pad_token_id=tokenizer.eos_token_id,
+             num_return_sequences=1
+         )
+
+         interpreted_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         interpreted_batches.append(interpreted_text)

+     return " ".join(interpreted_batches)

+ @torch.no_grad()
  def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
      """Translate text using NLLB model."""
      tokenizer, model = nllb_tuple

+     # Split text into batches
+     batches = batch_process_text(text)
+     translated_batches = []

+     for batch in batches:
+         inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
+         inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+         forced_bos_token_id = tokenizer.lang_code_to_id[target_lang]
+
+         outputs = model.generate(
+             **inputs,
+             forced_bos_token_id=forced_bos_token_id,
+             max_length=512,
+             temperature=0.7,
+             num_beams=5,
+             num_return_sequences=1
+         )
+
+         translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         translated_batches.append(translated_text)

+     return " ".join(translated_batches)
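Because translate_text builds forced_bos_token_id from tokenizer.lang_code_to_id, target_lang must be an NLLB (FLORES-200) code. A hedged call sketch follows; the codes are standard NLLB-200 identifiers, but how main() maps the UI language selection to them is outside this diff. Note also that source_lang is accepted but never applied to the tokenizer in the function above, so NLLB falls back to the tokenizer's default source language unless it is set elsewhere in code not shown.

```python
# Hypothetical call; "eng_Latn" and "hin_Deva" are standard NLLB-200 codes.
gemma_tuple, nllb_tuple, mt5_tuple = load_models()

hindi_text = translate_text(
    "The committee will meet every Monday.",
    source_lang="eng_Latn",
    target_lang="hin_Deva",
    nllb_tuple=nllb_tuple,
)
```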
 
+ @torch.no_grad()
  def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
      """
      Correct grammar using MT5 model for all supported languages.
 
      # Language-specific prompts for grammar correction
      prompts = {
+         'en': "grammar: ",
+         'hi': "व्याकरण सुधार: ",
+         'mr': "व्याकरण सुधारणा: "
      }

+     # Split text into batches
+     batches = batch_process_text(text)
+     corrected_batches = []

+     for batch in batches:
+         prompt = prompts[lang_code] + batch
+         inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
+         inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+         outputs = model.generate(
+             **inputs,
+             max_length=512,
+             num_beams=5,
+             temperature=0.7,
+             top_p=0.9,
+             do_sample=True,
+             num_return_sequences=1
+         )
+
+         corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         # Clean up any artifacts from the model output
+         for prefix in prompts.values():
+             corrected_text = corrected_text.replace(prefix, "")
+         corrected_text = corrected_text.strip()
+
+         corrected_batches.append(corrected_text)

+     return " ".join(corrected_batches)

  def save_as_docx(text: str) -> io.BytesIO:
      """Save translated text as a DOCX file."""
 
      except Exception as e:
          st.error(f"Error loading models: {str(e)}")
          st.error("Please check if the HF_TOKEN is valid and has the necessary permissions.")
+         return

      # File upload
      uploaded_file = st.file_uploader(
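Taken together, the changed functions form a single document pipeline. Since most of main() falls outside this diff, the end-to-end sketch below is an assumption about how the pieces are wired; the language arguments are illustrative and would need to match the NLLB/MT5 code mappings defined earlier in app.py.

```python
# Hedged sketch of the overall flow; main()'s real body is not shown in this diff.
def run_pipeline(uploaded_file, source_lang, target_lang):
    # Cached by @st.cache_resource, so repeated calls are cheap.
    gemma_tuple, nllb_tuple, mt5_tuple = load_models()

    text = extract_text_from_file(uploaded_file)          # uploaded PDF / DOCX -> plain text
    interpreted = interpret_context(text, gemma_tuple)    # Gemma pass for context and nuance
    translated = translate_text(interpreted, source_lang, target_lang, nllb_tuple)
    corrected = correct_grammar(translated, target_lang, mt5_tuple)  # relies on the internal lang_code mapping (not shown)

    return save_as_docx(corrected)                        # io.BytesIO ready for st.download_button
```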