Update app.py

app.py CHANGED
@@ -2,12 +2,17 @@ import streamlit as st
 import PyPDF2
 import docx
 import io
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM,
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration
 import torch
 from pathlib import Path
 import tempfile
 from typing import Union, Tuple
 import os
+from datetime import datetime, timezone
+
+# Display current information
+st.sidebar.text(f"Current Time (UTC): {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}")
+st.sidebar.text(f"User: {os.environ.get('USER', 'gauravchand')}")
 
 # Get Hugging Face token from environment variables
 HF_TOKEN = os.environ.get('HF_TOKEN')
@@ -32,43 +37,67 @@ MT5_LANG_CODES = {
 @st.cache_resource
 def load_models():
     """Load and cache the translation, context interpretation, and grammar correction models."""
+    try:
+        # Set device
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        # Load Gemma model for context interpretation
+        gemma_tokenizer = AutoTokenizer.from_pretrained(
+            "google/gemma-2b",
+            token=HF_TOKEN,
+            trust_remote_code=True
+        )
+        gemma_model = AutoModelForCausalLM.from_pretrained(
+            "google/gemma-2b",
+            token=HF_TOKEN,
+            torch_dtype=torch.float16,
+            device_map="auto" if torch.cuda.is_available() else None,
+            trust_remote_code=True
+        )
+
+        # Load NLLB model for translation
+        nllb_tokenizer = AutoTokenizer.from_pretrained(
+            "facebook/nllb-200-distilled-600M",
+            token=HF_TOKEN,
+            trust_remote_code=True
+        )
+        nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
+            "facebook/nllb-200-distilled-600M",
+            token=HF_TOKEN,
+            torch_dtype=torch.float16,
+            device_map="auto" if torch.cuda.is_available() else None,
+            trust_remote_code=True
+        )
+
+        # Load MT5 model for grammar correction
+        mt5_tokenizer = AutoTokenizer.from_pretrained(
+            "google/mt5-small",
+            token=HF_TOKEN,
+            trust_remote_code=True
+        )
+        mt5_model = MT5ForConditionalGeneration.from_pretrained(
+            "google/mt5-small",
+            token=HF_TOKEN,
+            torch_dtype=torch.float16,
+            device_map="auto" if torch.cuda.is_available() else None,
+            trust_remote_code=True
+        )
+
+        # Move models to device if not using device_map="auto"
+        if not torch.cuda.is_available():
+            gemma_model = gemma_model.to(device)
+            nllb_model = nllb_model.to(device)
+            mt5_model = mt5_model.to(device)
+
+        return (gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model), (mt5_tokenizer, mt5_model)
 
+    except Exception as e:
+        st.error(f"Error loading models: {str(e)}")
+        st.error("Detailed error information:")
+        st.error(f"Python version: {sys.version}")
+        st.error(f"PyTorch version: {torch.__version__}")
+        st.error(f"Transformers version: {transformers.__version__}")
+        raise e
 
 def extract_text_from_file(uploaded_file) -> str:
     """Extract text content from uploaded file based on its type."""
@@ -99,43 +128,87 @@ def extract_from_docx(file) -> str:
         text += paragraph.text + "\n"
     return text.strip()
 
+def batch_process_text(text: str, max_length: int = 512) -> list:
+    """Split text into batches for processing."""
+    words = text.split()
+    batches = []
+    current_batch = []
+    current_length = 0
+
+    for word in words:
+        if current_length + len(word) + 1 > max_length:
+            batches.append(" ".join(current_batch))
+            current_batch = [word]
+            current_length = len(word)
+        else:
+            current_batch.append(word)
+            current_length += len(word) + 1
+
+    if current_batch:
+        batches.append(" ".join(current_batch))
+
+    return batches
+
+@torch.no_grad()
 def interpret_context(text: str, gemma_tuple: Tuple) -> str:
     """Use Gemma model to interpret context and understand regional nuances."""
     tokenizer, model = gemma_tuple
 
-    return interpreted_text
+    # Split text into batches
+    batches = batch_process_text(text)
+    interpreted_batches = []
+
+    for batch in batches:
+        prompt = f"""Analyze the following text for context and cultural nuances,
+        maintaining the core meaning while identifying any idiomatic expressions or
+        cultural references: {batch}"""
+
+        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
+        inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+        outputs = model.generate(
+            **inputs,
+            max_length=512,
+            temperature=0.3,
+            pad_token_id=tokenizer.eos_token_id,
+            num_return_sequences=1
+        )
+
+        interpreted_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        interpreted_batches.append(interpreted_text)
+
+    return " ".join(interpreted_batches)
 
+@torch.no_grad()
 def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
     """Translate text using NLLB model."""
     tokenizer, model = nllb_tuple
 
-    return translated_text
+    # Split text into batches
+    batches = batch_process_text(text)
+    translated_batches = []
+
+    for batch in batches:
+        inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
+        inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+        forced_bos_token_id = tokenizer.lang_code_to_id[target_lang]
+
+        outputs = model.generate(
+            **inputs,
+            forced_bos_token_id=forced_bos_token_id,
+            max_length=512,
+            temperature=0.7,
+            num_beams=5,
+            num_return_sequences=1
+        )
+
+        translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        translated_batches.append(translated_text)
+
+    return " ".join(translated_batches)
 
+@torch.no_grad()
 def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
     """
     Correct grammar using MT5 model for all supported languages.
@@ -146,29 +219,40 @@ def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
 
     # Language-specific prompts for grammar correction
     prompts = {
-        'en':
-        'hi':
-        'mr':
+        'en': "grammar: ",
+        'hi': "व्याकरण सुधार: ",
+        'mr': "व्याकरण सुधारणा: "
     }
 
-    outputs = model.generate(
-        **inputs,
-        max_length=512,
-        num_beams=5,
-        temperature=0.7,
-        top_p=0.9,
-        do_sample=True
-    )
-
-    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    return
+    # Split text into batches
+    batches = batch_process_text(text)
+    corrected_batches = []
+
+    for batch in batches:
+        prompt = prompts[lang_code] + batch
+        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
+        inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+        outputs = model.generate(
+            **inputs,
+            max_length=512,
+            num_beams=5,
+            temperature=0.7,
+            top_p=0.9,
+            do_sample=True,
+            num_return_sequences=1
+        )
+
+        corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Clean up any artifacts from the model output
+        for prefix in prompts.values():
+            corrected_text = corrected_text.replace(prefix, "")
+        corrected_text = corrected_text.strip()
+
+        corrected_batches.append(corrected_text)
+
+    return " ".join(corrected_batches)
 
 def save_as_docx(text: str) -> io.BytesIO:
     """Save translated text as a DOCX file."""
@@ -191,7 +275,7 @@ def main():
     except Exception as e:
         st.error(f"Error loading models: {str(e)}")
        st.error("Please check if the HF_TOKEN is valid and has the necessary permissions.")
+        return
 
     # File upload
     uploaded_file = st.file_uploader(