Zlovoblachko committed on
Commit
3acb933
·
1 Parent(s): 93ddda0

added files

Browse files
Files changed (2) hide show
  1. app.py +292 -0
  2. requirements.txt +80 -0
app.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import nltk
4
+ from transformers import T5ForConditionalGeneration, T5Tokenizer, ElectraTokenizer, ElectraForTokenClassification
5
+ import torch.nn as nn
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ from huggingface_hub import hf_hub_download
9
+ import re
10
+ import difflib
11
+
12
# Fetch the sentence-tokenizer data once at startup. process_text() falls back
# to downloading 'punkt_tab' when sent_tokenize raises LookupError, so fetch it
# here as well instead of paying for the download during the first request.
# 'punkt' is kept for NLTK versions that still read the older resource.
nltk.download('punkt')
nltk.download('punkt_tab')
13
+
14
class T5WithGED(nn.Module):
    """T5 correction model optionally guided by a grammatical-error-detection tagger.

    An ELECTRA token classifier tags every word of the input. The tag sequence
    is encoded by a second encoder (``ged_encoder``) and a learned sigmoid gate
    mixes its hidden states with those of the T5 encoder before decoding.
    When any GED component cannot be loaded the model degrades to plain T5
    generation.
    """

    # Label id -> tag emitted by the GED classifier. Unknown ids fall back to "C".
    # NOTE(review): "C" presumably means "correct"; R/M/U are the error tags.
    _GED_TAGS = {0: "C", 1: "R", 2: "M", 3: "U"}
    _ERROR_TAGS = frozenset(("R", "M", "U"))
    _SPECIAL_TOKENS = frozenset(("[CLS]", "[SEP]", "[PAD]"))

    def __init__(self, model_path="Zlovoblachko/REAEC_GEC_2step_test", ged_model_path="Zlovoblachko/4tag-electra-grammar-error-detection"):
        """Load the T5 model, the optional GED fusion weights and the GED tagger.

        Args:
            model_path: Hub id (or path) of the T5 model; ``ged_components.pt``
                is fetched from the same repository.
            ged_model_path: Hub id (or path) of the ELECTRA token classifier.

        Loading failures of the optional parts are printed and leave
        ``has_ged`` False / ``ged_model`` None instead of raising.
        """
        import copy  # local import keeps this block self-contained

        super().__init__()
        self.t5 = T5ForConditionalGeneration.from_pretrained(model_path)
        self.t5_tokenizer = T5Tokenizer.from_pretrained(model_path)
        self.has_ged = False
        try:
            # Must be an independent copy: aliasing self.t5.encoder would make
            # load_state_dict() below overwrite the T5 encoder's own weights,
            # leaving a single encoder where correct() expects two.
            self.ged_encoder = copy.deepcopy(self.t5.encoder)
            self.gate = nn.Linear(2 * self.t5.config.d_model, 1)
            try:
                ged_components_path = hf_hub_download(
                    repo_id=model_path,
                    filename="ged_components.pt",
                )
                # NOTE(review): torch.load unpickles a downloaded file; confirm
                # the installed PyTorch defaults to weights_only=True.
                ged_components = torch.load(ged_components_path, map_location=torch.device('cpu'))
                self.ged_encoder.load_state_dict(ged_components["ged_encoder"])
                self.gate.load_state_dict(ged_components["gate"])
                self.has_ged = True
            except Exception as e:
                print(f"Could not load GED components: {e}")
        except Exception as e:
            print(f"Error setting up GED integration: {e}")
        self.ged_model = None
        self.ged_tokenizer = None
        try:
            self.ged_tokenizer = ElectraTokenizer.from_pretrained(ged_model_path)
            self.ged_model = ElectraForTokenClassification.from_pretrained(ged_model_path)
            self.ged_model.eval()
        except Exception as e:
            print(f"Could not load GED model: {e}")

    def get_ged_predictions(self, text):
        """Get GED predictions for a sentence.

        Returns:
            ``None`` when the GED model or tokenizer is unavailable, otherwise
            a tuple ``(ged_tags, error_spans, input_tokens)``:

            * ``ged_tags`` - space-joined tags, one per word.
            * ``error_spans`` - list of dicts with keys ``text``, ``type``,
              ``tokens`` and ``token_indices``; consecutive words sharing the
              same error tag form one span.
            * ``input_tokens`` - the tokenizer's tokens for the input,
              including special tokens and ``##`` continuation pieces.
        """
        if self.ged_model is None or self.ged_tokenizer is None:
            return None
        inputs = self.ged_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            logits = self.ged_model(**inputs).logits
        label_ids = torch.argmax(logits, dim=2)[0].cpu().tolist()
        input_tokens = self.ged_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

        # One entry per word: (word text, tag, index of its first word piece).
        words = []
        for piece_idx, (piece, label_id) in enumerate(zip(input_tokens, label_ids)):
            if piece in self._SPECIAL_TOKENS:
                continue
            if piece.startswith("##"):
                # Continuation piece: glue it back onto its word so span text
                # is the whole word rather than only its first piece. The
                # word keeps the tag predicted for its first piece.
                if words:
                    word, tag, first_idx = words[-1]
                    words[-1] = (word + piece[2:], tag, first_idx)
                continue
            words.append((piece, self._GED_TAGS.get(label_id, "C"), piece_idx))

        error_spans = []
        open_span = None
        for word, tag, first_idx in words:
            if tag not in self._ERROR_TAGS:
                open_span = None
                continue
            if open_span is None or open_span["type"] != tag:
                open_span = {"text": "", "type": tag, "tokens": [], "token_indices": []}
                error_spans.append(open_span)
            open_span["tokens"].append(word)
            open_span["token_indices"].append(first_idx)
        for span in error_spans:
            span["text"] = " ".join(span["tokens"])

        ged_tags = " ".join(tag for _, tag, _ in words)
        return ged_tags, error_spans, input_tokens

    def correct(self, text, use_ged=True, max_length=128):
        """Correct grammatical errors in text.

        Args:
            text: Sentence to correct.
            use_ged: Use the GED-guided path when all its components loaded.
            max_length: Truncation length for tokenisation and generation cap.

        Returns:
            ``(corrected_text, ged_tags, error_spans)``; the last two are
            ``None`` whenever plain T5 generation was used.
        """
        inputs = self.t5_tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
        ged_tags = None
        error_spans = None
        if self.has_ged and use_ged and self.ged_model is not None:
            ged_info = self.get_ged_predictions(text)
            if ged_info is not None:
                ged_tags, error_spans, _ = ged_info

        # Inference only: avoid building an autograd graph for the encoder
        # passes and the gate.
        with torch.no_grad():
            if ged_tags is None:
                # Plain T5 path (GED unavailable or disabled).
                output_ids = self.t5.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                )
            else:
                ged_inputs = self.t5_tokenizer(ged_tags, return_tensors="pt", truncation=True, max_length=max_length)
                src_encoder_outputs = self.t5.encoder(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    return_dict=True,
                )
                ged_encoder_outputs = self.ged_encoder(
                    input_ids=ged_inputs.input_ids,
                    attention_mask=ged_inputs.attention_mask,
                    return_dict=True,
                )
                # The two sequences are tokenised independently, so truncate
                # both to the shorter length before combining them.
                min_len = min(
                    src_encoder_outputs.last_hidden_state.size(1),
                    ged_encoder_outputs.last_hidden_state.size(1),
                )
                src_hidden = src_encoder_outputs.last_hidden_state[:, :min_len, :]
                ged_hidden = ged_encoder_outputs.last_hidden_state[:, :min_len, :]
                gate_scores = torch.sigmoid(self.gate(torch.cat([src_hidden, ged_hidden], dim=2)))
                # formula: λ*src_hidden + (1-λ)*ged_hidden
                src_encoder_outputs.last_hidden_state = gate_scores * src_hidden + (1 - gate_scores) * ged_hidden
                output_ids = self.t5.generate(encoder_outputs=src_encoder_outputs, max_length=max_length)
        corrected_text = self.t5_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return corrected_text, ged_tags, error_spans
146
+
147
def find_differences(source, corrected):
    """Return the word-level edits that turn *source* into *corrected*.

    Both texts are split on whitespace and compared with ``difflib.ndiff``.
    Each removed or added word becomes a dict with keys ``type``
    ("deletion" or "addition"), ``text`` (the word) and ``position`` (the
    entry's index in the ndiff output). Unchanged words and ndiff's "? "
    hint lines are ignored.
    """
    edit_kinds = {"- ": "deletion", "+ ": "addition"}
    word_diff = difflib.ndiff(source.split(), corrected.split())
    return [
        {"type": edit_kinds[entry[:2]], "text": entry[2:], "position": index}
        for index, entry in enumerate(word_diff)
        if entry[:2] in edit_kinds
    ]
157
+
158
# Highlight colour and tooltip label for each GED error tag.
_ERROR_SPAN_STYLES = {
    "R": ("#FFCCCC", "Replace"),      # light red
    "M": ("#CCFFCC", "Missing"),      # light green
    "U": ("#CCCCFF", "Unnecessary"),  # light blue
}


def _highlight_error_spans(original, error_spans):
    """Return *original* as escaped HTML with every error span wrapped in a <span>."""
    import html

    found = []
    for span in error_spans:
        # Unknown tags get a neutral style instead of leaving color/label unbound.
        color, label = _ERROR_SPAN_STYLES.get(span["type"], ("#EEEEEE", "Error"))
        pieces = [re.escape(piece) for piece in span["text"].split()]
        if not pieces:
            continue
        # Escape each token separately and only then join with the whitespace
        # pattern; escaping the joined string would turn the separator into a
        # literal. The gap is optional because the GED tokens are joined with
        # spaces while the tokenizer presumably also splits punctuation off
        # words, so the source text may have no whitespace there.
        pattern = r"\s*".join(pieces)
        for match in re.finditer(pattern, original, re.IGNORECASE):
            found.append((match.start(), match.end(), color, label))

    # Walk the matches left to right (longest first on ties) and drop any that
    # overlap an already highlighted region, so markup is never spliced into
    # the middle of other markup.
    found.sort(key=lambda item: (item[0], -item[1]))
    parts = []
    cursor = 0
    for start, end, color, label in found:
        if start < cursor:
            continue
        parts.append(html.escape(original[cursor:start]))
        parts.append(
            f"<span style='background-color: {color}; padding: 2px; border-radius: 3px;' "
            f"title='{html.escape(label)}'>{html.escape(original[start:end])}</span>"
        )
        cursor = end
    parts.append(html.escape(original[cursor:]))
    return "".join(parts)


def _render_sentence(original, corrected, error_spans):
    """Build the HTML card for one sentence: original, correction and word changes."""
    import html

    parts = ["<div style='margin-bottom: 20px; padding: 15px; border-radius: 5px; background-color: #f8f9fa;'>"]
    if error_spans:
        parts.append("<p><strong>Original sentence:</strong></p>")
        parts.append(f"<p>{_highlight_error_spans(original, error_spans)}</p>")
    else:
        parts.append(f"<p><strong>Original sentence:</strong> {html.escape(original)}</p>")

    parts.append(f"<p><strong>Corrected:</strong> {html.escape(corrected)}</p>")

    changes = find_differences(original, corrected)
    if changes:
        parts.append("<p><strong>Changes:</strong></p><ul>")
        for change in changes:
            word = html.escape(change["text"])
            if change["type"] == "deletion":
                parts.append(f"<li>Removed: <span style='color: red;'>{word}</span></li>")
            else:
                parts.append(f"<li>Added: <span style='color: green;'>{word}</span></li>")
        parts.append("</ul>")

    parts.append("</div>")
    return "".join(parts)


def process_text(text, model):
    """Process input text by splitting into sentences and applying the model.

    Args:
        text: Raw user input; ``None``, empty or whitespace-only input yields
            a prompt message instead of HTML.
        model: Object whose ``correct(sentence)`` returns
            ``(corrected_text, ged_tags, error_spans)``.

    Returns:
        An HTML string with one card per sentence. All user- and
        model-provided text is HTML-escaped before being embedded.
    """
    if not text or not text.strip():
        return "Please enter some text."

    try:
        sentences = nltk.sent_tokenize(text)
    except LookupError:
        # Tokenizer data missing: fetch it and retry once.
        nltk.download('punkt_tab')
        sentences = nltk.sent_tokenize(text)

    cards = []
    for sentence in sentences:
        corrected, _ged_tags, error_spans = model.correct(sentence)
        cards.append(_render_sentence(sentence, corrected, error_spans))

    return "<div style='font-family: Arial, sans-serif;'>" + "".join(cards) + "</div>"
256
+
257
def create_gradio_app(model_path="Zlovoblachko/REAEC_GEC_2step_test", ged_model_path="Zlovoblachko/4tag-electra-grammar-error-detection"):
    """Load the correction model and build the Gradio interface around it.

    Args:
        model_path: Hub id (or path) of the T5 correction model.
        ged_model_path: Hub id (or path) of the ELECTRA error-detection model.

    Returns:
        A ``gr.Interface`` that maps the input textbox to the HTML produced by
        ``process_text``. The interface is not launched here.
    """
    model = T5WithGED(model_path, ged_model_path)

    def correct_text(text):
        # Bind the already loaded model so it is shared across requests.
        return process_text(text, model)

    description = """
This app corrects grammatical errors in text using an ensemble of models:
1. An ELECTRA-based Grammatical Error Detection (GED) model identifies error spans
2. A T5-based Grammatical Error Correction (GEC) model corrects the errors

Enter your text and see the corrections with highlighted error spans:
- <span style='background-color: #FFCCCC; padding: 2px;'>Red</span>: Replacement needed
- <span style='background-color: #CCFFCC; padding: 2px;'>Green</span>: Missing word
- <span style='background-color: #CCCCFF; padding: 2px;'>Blue</span>: Unnecessary word
"""

    # The example sentences are deliberately ungrammatical: they are demo inputs.
    examples = [
        ["First of all, we can see increasing tendency of overweighting during the hole period."],
        ["Food products were mostly transportaded by the road."],
        ["I have went to the store yesterday. She dont like to study for exams."],
        ["The company have announced a new policy. I am living in London since 2010."],
        ["He didnt studied for the test. They was at the party last night."],
        ["The chart illustrates the number in percents of overweight children in Canada throughout a 20-years period from 1985 to 2005, while the table demonstrates the percentage of children doing sport exercises regulary over the period from 1990 to 2005. Overall, it can be seen that despite the fact that the number of boys and girls performing exercises has grown considerably by the end of the period, percent of overweight children has increased too. According to the graph, boys are more likely to have extra weight in period of 2000-2005, a quater of them had problems with weight in 2005. Girls were going ahead of boys in 1985-1990, then they maintained the same level in 1995, but then the number of outweight boys went up more rapidly. The table allows to see that interest in physical activity has grown by more than 25% both within boys and girls by 2005."],
    ]

    iface = gr.Interface(
        fn=correct_text,
        inputs=gr.Textbox(
            lines=5,
            placeholder="Enter text to correct grammatical errors...",
            label="Input Text",
        ),
        outputs=gr.HTML(label="Corrected Text"),
        title="Grammar Error Correction with Detection",
        description=description,
        examples=examples,
        # `allow_flagging` is deprecated in Gradio 5; `flagging_mode` replaces it.
        flagging_mode="never",
    )

    return iface
290
+
291
# Build the interface when the module is loaded (this loads the models) and
# keep it in the module-level name `iface`.
iface = create_gradio_app()

# Start serving the app with Gradio's debug mode enabled.
iface.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0
2
+ annotated-types==0.7.0
3
+ anyio==4.9.0
4
+ certifi==2025.1.31
5
+ charset-normalizer==3.4.1
6
+ click==8.1.8
7
+ fastapi==0.115.12
8
+ ffmpy==0.5.0
9
+ filelock==3.18.0
10
+ fsspec==2025.3.2
11
+ gradio==5.25.0
12
+ gradio_client==1.8.0
13
+ groovy==0.1.2
14
+ h11==0.14.0
15
+ httpcore==1.0.8
16
+ httpx==0.28.1
17
+ huggingface-hub==0.30.2
18
+ idna==3.10
19
+ Jinja2==3.1.6
20
+ joblib==1.4.2
21
+ markdown-it-py==3.0.0
22
+ MarkupSafe==3.0.2
23
+ mdurl==0.1.2
24
+ mpmath==1.3.0
25
+ networkx==3.4.2
26
+ nltk==3.9.1
27
+ numpy==2.2.4
28
+ nvidia-cublas-cu12==12.4.5.8
29
+ nvidia-cuda-cupti-cu12==12.4.127
30
+ nvidia-cuda-nvrtc-cu12==12.4.127
31
+ nvidia-cuda-runtime-cu12==12.4.127
32
+ nvidia-cudnn-cu12==9.1.0.70
33
+ nvidia-cufft-cu12==11.2.1.3
34
+ nvidia-curand-cu12==10.3.5.147
35
+ nvidia-cusolver-cu12==11.6.1.9
36
+ nvidia-cusparse-cu12==12.3.1.170
37
+ nvidia-cusparselt-cu12==0.6.2
38
+ nvidia-nccl-cu12==2.21.5
39
+ nvidia-nvjitlink-cu12==12.4.127
40
+ nvidia-nvtx-cu12==12.4.127
41
+ orjson==3.10.16
42
+ packaging==24.2
43
+ pandas==2.2.3
44
+ pillow==11.2.1
45
+ pydantic==2.11.3
46
+ pydantic_core==2.33.1
47
+ pydub==0.25.1
48
+ Pygments==2.19.1
49
+ python-dateutil==2.9.0.post0
50
+ python-multipart==0.0.20
51
+ pytz==2025.2
52
+ PyYAML==6.0.2
53
+ regex==2024.11.6
54
+ requests==2.32.3
55
+ rich==14.0.0
56
+ ruff==0.11.5
57
+ safehttpx==0.1.6
58
+ safetensors==0.5.3
59
+ semantic-version==2.10.0
60
+ sentencepiece==0.2.0
61
+ setuptools==75.8.0
62
+ shellingham==1.5.4
63
+ six==1.17.0
64
+ sniffio==1.3.1
65
+ starlette==0.46.2
66
+ sympy==1.13.1
67
+ tokenizers==0.21.1
68
+ tomlkit==0.13.2
69
+ torch==2.6.0
70
+ tqdm==4.67.1
71
+ transformers==4.51.3
72
+ triton==3.2.0
73
+ typer==0.15.2
74
+ typing-inspection==0.4.0
75
+ typing_extensions==4.13.2
76
+ tzdata==2025.2
77
+ urllib3==2.4.0
78
+ uvicorn==0.34.1
79
+ websockets==15.0.1
80
+ wheel==0.45.1