gauravchand11 committed
Commit f88b938 · verified · 1 Parent(s): 97c654f

Update app.py

Files changed (1)
  1. app.py +16 -15
app.py CHANGED
@@ -8,12 +8,13 @@ from pathlib import Path
 import tempfile
 from typing import Union, Tuple
 import os
-from datetime import datetime, timezone
 import sys
+from datetime import datetime, timezone
 
-# Display current information
-st.sidebar.text(f"Current Time (UTC): {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}")
-st.sidebar.text(f"User: {os.environ.get('USER', 'gauravchand')}")
+# Display current information in sidebar
+st.sidebar.text(f"Current Date and Time (UTC):")
+st.sidebar.text(datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S'))
+st.sidebar.text(f"Current User's Login: {os.environ.get('USER', 'gauravchand')}")
 
 # Get Hugging Face token from environment variables
 HF_TOKEN = os.environ.get('HF_TOKEN')
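A side note on the sidebar block added above (a sketch under assumptions, not part of the commit): Streamlit re-executes the whole script on every interaction, so the timestamp shows the moment of the last rerun rather than a live clock, and the first label's f-string has no placeholders, so a plain string would do. A minimal standalone equivalent:

    # Minimal sketch of the sidebar block; runnable via `streamlit run`.
    # The timestamp reflects the last script rerun, not a ticking clock.
    import os
    from datetime import datetime, timezone

    import streamlit as st

    st.sidebar.text("Current Date and Time (UTC):")  # plain string suffices
    st.sidebar.text(datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S'))
    st.sidebar.text(f"Current User's Login: {os.environ.get('USER', 'gauravchand')}")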
@@ -37,7 +38,7 @@ MT5_LANG_CODES = {
 
 @st.cache_resource
 def load_models():
-    """Load and cache the translation, context interpretation, and grammar correction models."""
+    """Load and cache the translation and context interpretation models."""
     try:
         # Set device
         device = "cuda" if torch.cuda.is_available() else "cpu"
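For context on the decorator visible in this hunk: `@st.cache_resource` keeps the loaded models in one shared copy across reruns and sessions, so the heavyweight `from_pretrained` calls run once per process. A minimal sketch of the pattern, with the model name taken from the diff and error handling trimmed:

    import streamlit as st
    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    @st.cache_resource  # loaded once per process, shared across reruns/sessions
    def load_models():
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
        model = AutoModelForSeq2SeqLM.from_pretrained(
            "facebook/nllb-200-distilled-600M"
        ).to(device)
        return tokenizer, model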
@@ -60,7 +61,6 @@ def load_models():
         nllb_tokenizer = AutoTokenizer.from_pretrained(
             "facebook/nllb-200-distilled-600M",
             token=HF_TOKEN,
-            src_lang="eng_Latn",  # Default source language
             trust_remote_code=True
         )
         nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
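Dropping `src_lang` here (together with the per-batch `tokenizer.src_lang = source_lang` assignment removed in a later hunk) means inputs are encoded with the tokenizer's default source code, which for this checkpoint is `eng_Latn` as far as the transformers NLLB tokenizer documents. The conventional way to tag a non-English source, as a hedged sketch:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    tokenizer.src_lang = "hin_Deva"  # e.g. the FLORES-200 code for Hindi
    inputs = tokenizer("नमस्ते दुनिया", return_tensors="pt")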
@@ -98,7 +98,6 @@ def load_models():
         st.error("Detailed error information:")
         st.error(f"Python version: {sys.version}")
         st.error(f"PyTorch version: {torch.__version__}")
-        st.error(f"Transformers version: {transformers.__version__}")
         raise e
 
 def extract_text_from_file(uploaded_file) -> str:
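A plausible reason the dropped line above was failing: if app.py imports only classes (`from transformers import ...`), the bare module name `transformers` is undefined and this error line would itself raise a NameError. That reading of the surrounding imports is an assumption; the working variant needs the module itself:

    # Reporting the library version requires importing the module itself;
    # `from transformers import AutoTokenizer` alone leaves the name
    # `transformers` unbound.
    import transformers

    print(f"Transformers version: {transformers.__version__}")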
@@ -171,6 +170,7 @@ def interpret_context(text: str, gemma_tuple: Tuple) -> str:
         outputs = model.generate(
             **inputs,
             max_length=512,
+            do_sample=True,
             temperature=0.3,
             pad_token_id=tokenizer.eos_token_id,
             num_return_sequences=1
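The `do_sample=True` added in this hunk (and echoed in the two generation hunks below) matters because transformers applies `temperature` and `top_p` only when sampling is enabled; without it, generation is greedy or pure beam search and recent versions warn that the flags are ignored. A self-contained sketch using the same checkpoint as the app:

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

    inputs = tokenizer("Hello, world!", return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_length=64,
        do_sample=True,    # without this, temperature=0.3 is ignored
        temperature=0.3,   # low temperature: close to greedy, slightly varied
    )
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

In the two hunks below, `do_sample=True` is combined with `num_beams=5`, which selects transformers' beam-search multinomial sampling rather than plain sampling; a valid mode, just a less common one.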
@@ -191,18 +191,20 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
     translated_batches = []
 
     for batch in batches:
-        # Set the source language for the tokenizer
-        tokenizer.src_lang = source_lang
-
-        # Prepare the input text
+        # Prepare the input text with source language token
         inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
         inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
-        # Generate translation with forced target language
+        # Get target language token ID
+        target_lang_token = f"___{target_lang}___"
+        target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)
+
+        # Generate translation
         outputs = model.generate(
             **inputs,
-            forced_bos_token_id=tokenizer.get_lang_id(target_lang),
+            forced_bos_token_id=target_lang_id,
             max_length=512,
+            do_sample=True,
             temperature=0.7,
             num_beams=5,
             num_return_sequences=1
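The token format introduced here is worth double-checking: with the stock `facebook/nllb-200-distilled-600M` tokenizer in transformers, the FLORES-200 code itself (e.g. `fra_Latn`) is the language token, so wrapping it as `___fra_Latn___` would, as far as I can tell, resolve to the unknown token. (`get_lang_id`, the call being replaced, is an M2M100 tokenizer method, which is presumably why it failed on the NLLB tokenizer.) A defensive sketch with a hypothetical helper `lang_token_id`:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

    def lang_token_id(tok, code: str) -> int:
        """Resolve a language token id, failing loudly on a bad code."""
        token_id = tok.convert_tokens_to_ids(code)
        if token_id == tok.unk_token_id:
            raise ValueError(f"{code!r} is not a language token in this vocabulary")
        return token_id

    print(lang_token_id(tokenizer, "fra_Latn"))    # resolves to a real id
    # lang_token_id(tokenizer, "___fra_Latn___")   # would raise with this checkpoint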
@@ -242,9 +244,9 @@ def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
             **inputs,
             max_length=512,
             num_beams=5,
+            do_sample=True,
             temperature=0.7,
             top_p=0.9,
-            do_sample=True,
             num_return_sequences=1
         )
 
@@ -309,7 +311,6 @@ def main():
         with st.spinner("Processing document..."):
             # Extract text
             text = extract_text_from_file(uploaded_file)
-            st.text_area("Extracted Text:", value=text, height=150)
 
             # Interpret context
             with st.spinner("Interpreting context..."):
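The extracted-text preview removed in this hunk disappears from the UI entirely; if a lighter-weight preview were still wanted, a collapsed expander is one option (a hypothetical sketch, not part of the commit):

    import streamlit as st

    text = "example extracted text"  # stand-in for extract_text_from_file()
    with st.expander("Extracted Text (preview)", expanded=False):
        st.text_area("Extracted Text:", value=text, height=150)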
 