Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -52,6 +52,56 @@ def get_nllb_lang_token(lang_code: str) -> str:
|
|
52 |
"""Get the correct token format for NLLB model."""
|
53 |
return f"___{lang_code}___"
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
@st.cache_resource
|
56 |
def load_models():
|
57 |
"""Load and cache the translation and context interpretation models."""
|
@@ -76,7 +126,7 @@ def load_models():
|
|
76 |
nllb_tokenizer = AutoTokenizer.from_pretrained(
|
77 |
"facebook/nllb-200-distilled-600M",
|
78 |
token=HF_TOKEN,
|
79 |
-
use_fast=False,
|
80 |
trust_remote_code=True
|
81 |
)
|
82 |
nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
|
@@ -114,7 +164,34 @@ def load_models():
|
|
114 |
st.error(f"PyTorch version: {torch.__version__}")
|
115 |
raise e
|
116 |
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
@torch.no_grad()
|
120 |
def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
|
@@ -124,15 +201,12 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tu
|
|
124 |
batches = batch_process_text(text)
|
125 |
translated_batches = []
|
126 |
|
127 |
-
# Get target language token
|
128 |
target_lang_token = get_nllb_lang_token(target_lang)
|
129 |
|
130 |
for batch in batches:
|
131 |
-
# Prepare input text
|
132 |
inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
|
133 |
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
134 |
|
135 |
-
# Get target language token ID
|
136 |
target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)
|
137 |
|
138 |
outputs = model.generate(
|
@@ -190,6 +264,17 @@ def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
|
|
190 |
|
191 |
return " ".join(corrected_batches)
|
192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
def main():
|
194 |
st.title("π Document Translation App")
|
195 |
|
|
|
52 |
"""Get the correct token format for NLLB model."""
|
53 |
return f"___{lang_code}___"
|
54 |
|
55 |
+
def extract_text_from_file(uploaded_file) -> str:
    """Extract text content from uploaded file based on its type.

    Dispatches on the file extension of the uploaded file's name.
    Supported: .pdf, .docx, and .txt (decoded as UTF-8).

    Raises:
        ValueError: If the file extension is not one of the supported types.
    """
    suffix = Path(uploaded_file.name).suffix.lower()

    # Plain text needs no parser — decode the raw bytes directly.
    if suffix == '.txt':
        return uploaded_file.getvalue().decode('utf-8')
    if suffix == '.pdf':
        return extract_from_pdf(uploaded_file)
    if suffix == '.docx':
        return extract_from_docx(uploaded_file)

    raise ValueError(f"Unsupported file format: {suffix}")
|
67 |
+
|
68 |
+
def extract_from_pdf(file) -> str:
    """Extract text from PDF file.

    Concatenates the extracted text of every page, one trailing newline
    per page, and strips surrounding whitespace from the final result.
    """
    reader = PyPDF2.PdfReader(file)
    pages = [page.extract_text() + "\n" for page in reader.pages]
    return "".join(pages).strip()
|
75 |
+
|
76 |
+
def extract_from_docx(file) -> str:
    """Extract text from DOCX file.

    Joins every paragraph's text with a trailing newline and strips
    surrounding whitespace from the final result.
    """
    document = docx.Document(file)
    lines = [paragraph.text + "\n" for paragraph in document.paragraphs]
    return "".join(lines).strip()
|
83 |
+
|
84 |
+
def batch_process_text(text: str, max_length: int = 512) -> list:
    """Split text into whitespace-delimited batches for processing.

    Each batch is capped at roughly ``max_length`` characters, counting the
    single spaces re-inserted between words. A single word longer than
    ``max_length`` becomes its own batch rather than being split.

    NOTE(review): the cap is in *characters*, not model tokens — downstream
    tokenizer calls still rely on their own ``truncation=True`` as a backstop.

    Args:
        text: The input text to split.
        max_length: Approximate maximum character length of each batch.

    Returns:
        A list of non-empty batch strings; empty list for empty/whitespace-only
        input.
    """
    batches = []
    current_batch = []
    current_length = 0

    for word in text.split():
        # Only flush when there is something to flush: the original code
        # appended "" (an empty batch) whenever the very first word after a
        # flush already exceeded max_length, because " ".join([]) == "".
        if current_batch and current_length + len(word) + 1 > max_length:
            batches.append(" ".join(current_batch))
            current_batch = [word]
            current_length = len(word)
        else:
            current_batch.append(word)
            current_length += len(word) + 1

    if current_batch:
        batches.append(" ".join(current_batch))

    return batches
|
104 |
+
|
105 |
@st.cache_resource
|
106 |
def load_models():
|
107 |
"""Load and cache the translation and context interpretation models."""
|
|
|
126 |
nllb_tokenizer = AutoTokenizer.from_pretrained(
|
127 |
"facebook/nllb-200-distilled-600M",
|
128 |
token=HF_TOKEN,
|
129 |
+
use_fast=False,
|
130 |
trust_remote_code=True
|
131 |
)
|
132 |
nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
|
|
|
164 |
st.error(f"PyTorch version: {torch.__version__}")
|
165 |
raise e
|
166 |
|
167 |
+
@torch.no_grad()
def interpret_context(text: str, gemma_tuple: Tuple) -> str:
    """Use Gemma model to interpret context and understand regional nuances."""
    tokenizer, model = gemma_tuple
    interpreted = []

    for chunk in batch_process_text(text):
        prompt = f"""Analyze and maintain the core meaning of this text: {chunk}"""

        # Tokenize and move tensors onto the model's device.
        encoded = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

        generated = model.generate(
            **encoded,
            max_length=512,
            do_sample=True,
            temperature=0.3,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1
        )

        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        # Strip the echoed prompt so only the model's continuation remains.
        interpreted.append(decoded.replace(prompt, "").strip())

    return " ".join(interpreted)
|
195 |
|
196 |
@torch.no_grad()
|
197 |
def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
|
|
|
201 |
batches = batch_process_text(text)
|
202 |
translated_batches = []
|
203 |
|
|
|
204 |
target_lang_token = get_nllb_lang_token(target_lang)
|
205 |
|
206 |
for batch in batches:
|
|
|
207 |
inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
|
208 |
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
209 |
|
|
|
210 |
target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)
|
211 |
|
212 |
outputs = model.generate(
|
|
|
264 |
|
265 |
return " ".join(corrected_batches)
|
266 |
|
267 |
+
def save_as_docx(text: str) -> io.BytesIO:
    """Save translated text as a DOCX file.

    Writes the text as a single paragraph into an in-memory DOCX document
    and returns the buffer rewound to position 0, ready for download.
    """
    document = docx.Document()
    document.add_paragraph(text)

    buffer = io.BytesIO()
    document.save(buffer)
    buffer.seek(0)  # rewind so the caller reads from the start
    return buffer
|
277 |
+
|
278 |
def main():
|
279 |
st.title("π Document Translation App")
|
280 |
|