# REALEC_GEC_test / app.py
# Author: Zlovoblachko
# Last commit: 504d04a "Update app.py" (verified)
import copy
import difflib
import html
import re

import gradio as gr
import nltk
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from tqdm import tqdm
from transformers import T5ForConditionalGeneration, T5Tokenizer, ElectraTokenizer, ElectraForTokenClassification
# Sentence splitting in process_text() needs NLTK's Punkt data. Recent NLTK
# releases read it from the "punkt_tab" resource, which process_text() otherwise
# only downloads after a LookupError inside a request. Fetch both at startup so
# the first request does not pay for the download.
nltk.download('punkt')
nltk.download('punkt_tab')
class T5WithGED(nn.Module):
    """T5 grammatical-error-correction model optionally guided by ELECTRA GED tags.

    The GEC model is the T5 checkpoint at ``model_path``. If that repository also
    provides ``ged_components.pt`` (weights for a second encoder and a gating
    layer), and the ELECTRA token-classification model at ``ged_model_path``
    loads, ``correct`` blends the encoder states of the sentence with the encoder
    states of its GED tag string before decoding. Otherwise it falls back to
    plain T5 generation.
    """

    # ELECTRA GED class id -> tag; any other id is treated as "C".
    # Per the app legend: C = correct, R = replace, M = missing, U = unnecessary.
    _GED_LABELS = {0: "C", 1: "R", 2: "M", 3: "U"}
    # Tags that mark an error and therefore open or extend an error span.
    _ERROR_TAGS = ("R", "M", "U")

    def __init__(self, model_path="Zlovoblachko/REAEC_GEC_2step_test", ged_model_path="Zlovoblachko/4tag-electra-grammar-error-detection"):
        super().__init__()
        self.t5 = T5ForConditionalGeneration.from_pretrained(model_path)
        self.t5_tokenizer = T5Tokenizer.from_pretrained(model_path)
        self.has_ged = False
        try:
            # The GED encoder must be an independent copy of the T5 encoder.
            # Aliasing ``self.t5.encoder`` would make the load_state_dict below
            # overwrite the source encoder too, so both gate inputs would come
            # from the same weights.
            self.ged_encoder = copy.deepcopy(self.t5.encoder)
            self.gate = nn.Linear(2 * self.t5.config.d_model, 1)
            try:
                ged_components_path = hf_hub_download(
                    repo_id=model_path,
                    filename="ged_components.pt",
                )
                ged_components = torch.load(ged_components_path, map_location=torch.device('cpu'))
                self.ged_encoder.load_state_dict(ged_components["ged_encoder"])
                self.gate.load_state_dict(ged_components["gate"])
                self.has_ged = True
            except Exception as e:
                # Best effort: without these weights correct() uses plain T5.
                print(f"Could not load GED components: {e}")
        except Exception as e:
            print(f"Error setting up GED integration: {e}")
        self.ged_model = None
        self.ged_tokenizer = None
        try:
            self.ged_tokenizer = ElectraTokenizer.from_pretrained(ged_model_path)
            self.ged_model = ElectraForTokenClassification.from_pretrained(ged_model_path)
            self.ged_model.eval()
        except Exception as e:
            print(f"Could not load GED model: {e}")
        # This wrapper is only used for inference.
        self.eval()

    @classmethod
    def _merge_word_pieces(cls, tokens, predictions):
        """Collapse GED word pieces into ``(word_text, tag, token_index)`` tuples.

        Special tokens ([CLS], [SEP], [PAD]) are dropped. Each word is tagged by
        the prediction of its first piece, and ``token_index`` is that piece's
        position in ``tokens``. "##" continuation pieces are appended to the
        preceding word's text so spans carry whole words.
        """
        words = []
        current = None  # mutable [text, tag, index] of the word being extended
        for index, (token, pred) in enumerate(zip(tokens, predictions)):
            if token in ("[CLS]", "[SEP]", "[PAD]"):
                current = None
                continue
            if token.startswith("##"):
                if current is not None:
                    current[0] += token[2:]
                continue
            current = [token, cls._GED_LABELS.get(pred, "C"), index]
            words.append(current)
        return [tuple(word) for word in words]

    @classmethod
    def _group_error_spans(cls, words):
        """Group runs of consecutive words sharing the same error tag into spans.

        ``words`` is a list of ``(word_text, tag, token_index)`` tuples. Each span
        is a dict with ``text`` (its words joined by spaces), ``type`` (the tag),
        ``tokens`` (the words) and ``token_indices``.
        """
        spans = []
        current = None
        for word, tag, token_index in words:
            if tag not in cls._ERROR_TAGS:
                current = None
                continue
            if current is None or current["type"] != tag:
                current = {"text": "", "type": tag, "tokens": [], "token_indices": []}
                spans.append(current)
            current["tokens"].append(word)
            current["token_indices"].append(token_index)
        for span in spans:
            span["text"] = " ".join(span["tokens"])
        return spans

    def get_ged_predictions(self, text):
        """Get GED predictions for a sentence.

        Returns None when the GED model or tokenizer is unavailable. Otherwise it
        returns ``(ged_tags, error_spans, input_tokens)``:

        - ``ged_tags`` is a space-separated string with one tag per word.
        - ``error_spans`` is the output of ``_group_error_spans``.
        - ``input_tokens`` is every word piece from the GED tokenizer, including
          special tokens.
        """
        if self.ged_model is None or self.ged_tokenizer is None:
            return None
        inputs = self.ged_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            logits = self.ged_model(**inputs).logits
        token_predictions = torch.argmax(logits, dim=2)[0].cpu().tolist()
        input_tokens = self.ged_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
        words = self._merge_word_pieces(input_tokens, token_predictions)
        ged_tags = " ".join(tag for _, tag, _ in words)
        return ged_tags, self._group_error_spans(words), input_tokens

    @torch.no_grad()
    def correct(self, text, use_ged=True, max_length=128):
        """Correct grammatical errors in text.

        Args:
            text: sentence to correct.
            use_ged: fuse GED information into the encoder states, when the GED
                components and the GED model are both available.
            max_length: truncation length for the T5 inputs and the maximum
                generation length.

        Returns:
            ``(corrected_text, ged_tags, error_spans)``. The last two are None
            when plain T5 decoding was used.
        """
        inputs = self.t5_tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
        ged_info = None
        if self.has_ged and use_ged and self.ged_model is not None:
            ged_info = self.get_ged_predictions(text)
        if ged_info is None:
            # Plain T5 decoding: GED is disabled or unavailable.
            output_ids = self.t5.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=max_length,
            )
            return self.t5_tokenizer.decode(output_ids[0], skip_special_tokens=True), None, None

        ged_tags, error_spans, _ = ged_info
        # The GED tag string is encoded with the T5 tokenizer and the GED encoder.
        ged_inputs = self.t5_tokenizer(ged_tags, return_tensors="pt", truncation=True, max_length=max_length)
        src_encoder_outputs = self.t5.encoder(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            return_dict=True,
        )
        ged_encoder_outputs = self.ged_encoder(
            input_ids=ged_inputs.input_ids,
            attention_mask=ged_inputs.attention_mask,
            return_dict=True,
        )
        # The source and the tag string are tokenised independently, so their
        # lengths usually differ. Only the common prefix is fused and decoded.
        min_len = min(
            src_encoder_outputs.last_hidden_state.size(1),
            ged_encoder_outputs.last_hidden_state.size(1),
        )
        src_hidden = src_encoder_outputs.last_hidden_state[:, :min_len, :]
        ged_hidden = ged_encoder_outputs.last_hidden_state[:, :min_len, :]
        gate_scores = torch.sigmoid(self.gate(torch.cat([src_hidden, ged_hidden], dim=2)))
        # formula: λ*src_hidden + (1-λ)*ged_hidden
        src_encoder_outputs.last_hidden_state = gate_scores * src_hidden + (1 - gate_scores) * ged_hidden
        # Trim the source mask to the fused length so that it matches encoder_outputs.
        output_ids = self.t5.generate(
            encoder_outputs=src_encoder_outputs,
            attention_mask=inputs.attention_mask[:, :min_len],
            max_length=max_length,
        )
        corrected_text = self.t5_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return corrected_text, ged_tags, error_spans
def find_differences(source, corrected):
    """List word-level deletions and additions between two whitespace-split texts.

    Each entry is a dict with ``type`` ("deletion" or "addition"), the affected
    word as ``text``, and ``position``. ``position`` is the index of that line in
    the raw ``difflib.ndiff`` output, which also counts unchanged and "?" hint lines.
    """
    prefix_kinds = {"- ": "deletion", "+ ": "addition"}
    diff_lines = difflib.ndiff(source.split(), corrected.split())
    return [
        {"type": prefix_kinds[line[:2]], "text": line[2:], "position": index}
        for index, line in enumerate(diff_lines)
        if line[:2] in prefix_kinds
    ]
# Highlight colour and tooltip label for each GED error tag shown in the UI.
_SPAN_STYLES = {
    "R": ("#FFCCCC", "Replace"),  # light red: replacement needed
    "M": ("#CCFFCC", "Missing"),  # light green: missing word
    "U": ("#CCCCFF", "Unnecessary"),  # light blue: unnecessary word
}


def _highlight_error_spans(original, error_spans):
    """Return *original* as HTML with each GED error span wrapped in a coloured <span>.

    Spans are processed in token order. Each one marks the first match found
    after the end of the previous highlight, so repeated words elsewhere are not
    highlighted and highlights never overlap. The words of a span are matched
    case-insensitively with optional whitespace between them, so tokenizer
    splits such as "didn ' t" still match "didn't". Spans with an unknown tag,
    or that cannot be found, are skipped. All sentence text is HTML-escaped.
    """
    pieces = []
    cursor = 0
    for span in sorted(error_spans, key=lambda s: s["token_indices"][0]):
        style = _SPAN_STYLES.get(span["type"])
        if style is None:
            continue
        color, label = style
        words = span.get("tokens") or span["text"].split()
        # Escape each word first, then join with a whitespace pattern. Escaping
        # after inserting the pattern would turn it into a literal.
        pattern = re.compile(r"\s*".join(re.escape(word) for word in words), re.IGNORECASE)
        match = pattern.search(original, cursor)
        if match is None or match.end() == match.start():
            continue
        pieces.append(html.escape(original[cursor:match.start()]))
        pieces.append(
            f"<span style='background-color: {color}; padding: 2px; border-radius: 3px;' title='{label}'>{html.escape(match.group(0))}</span>"
        )
        cursor = match.end()
    pieces.append(html.escape(original[cursor:]))
    return "".join(pieces)


def _render_result(result):
    """Render one sentence's original, corrected and changes section as an HTML card."""
    original = result["original"]
    corrected = result["corrected"]
    error_spans = result["error_spans"]
    parts = ["<div style='margin-bottom: 20px; padding: 15px; border-radius: 5px; background-color: #f8f9fa;'>"]
    if error_spans:
        parts.append("<p><strong>Original sentence:</strong></p>")
        parts.append(f"<p>{_highlight_error_spans(original, error_spans)}</p>")
    else:
        parts.append(f"<p><strong>Original sentence:</strong> {html.escape(original)}</p>")
    parts.append(f"<p><strong>Corrected:</strong> {html.escape(corrected)}</p>")
    # Word-level diff as an additional visualisation.
    changes = find_differences(original, corrected)
    if changes:
        parts.append("<p><strong>Changes:</strong></p><ul>")
        for change in changes:
            if change["type"] == "deletion":
                parts.append(f"<li>Removed: <span style='color: red;'>{html.escape(change['text'])}</span></li>")
            else:
                parts.append(f"<li>Added: <span style='color: green;'>{html.escape(change['text'])}</span></li>")
        parts.append("</ul>")
    parts.append("</div>")
    return "".join(parts)


def process_text(text, model):
    """Process input text by splitting into sentences and applying the model.

    Args:
        text: raw user input. It is split into sentences with ``nltk.sent_tokenize``.
        model: object whose ``correct(sentence)`` returns
            ``(corrected, ged_tags, error_spans)``. In this app it is a ``T5WithGED``.

    Returns:
        An HTML string with one card per sentence. If the input is empty or None,
        returns the plain message "Please enter some text." User and model text
        is HTML-escaped before being embedded.
    """
    if not text or not text.strip():
        return "Please enter some text."
    try:
        sentences = nltk.sent_tokenize(text)
    except LookupError:
        # Newer NLTK releases look up the "punkt_tab" resource.
        nltk.download('punkt_tab')
        sentences = nltk.sent_tokenize(text)
    results = []
    for sentence in sentences:
        corrected, ged_tags, error_spans = model.correct(sentence)
        results.append({
            "original": sentence,
            "corrected": corrected,
            "ged_tags": ged_tags,
            "error_spans": error_spans,
        })
    cards = "".join(_render_result(result) for result in results)
    return f"<div style='font-family: Arial, sans-serif;'>{cards}</div>"
# Markdown/HTML description shown above the interface (colour legend for the highlights).
_APP_DESCRIPTION = """
This app corrects grammatical errors in text using an ensemble of models:
1. An ELECTRA-based Grammatical Error Detection (GED) model identifies error spans
2. A T5-based Grammatical Error Correction (GEC) model corrects the errors
Enter your text and see the corrections with highlighted error spans:
- <span style='background-color: #FFCCCC; padding: 2px;'>Red</span>: Replacement needed
- <span style='background-color: #CCFFCC; padding: 2px;'>Green</span>: Missing word
- <span style='background-color: #CCCCFF; padding: 2px;'>Blue</span>: Unnecessary word
"""

# Example inputs offered in the UI; their errors are intentional.
_APP_EXAMPLES = [
    ["First of all, we can see increasing tendency of overweighting during the hole period."],
    ["Food products were mostly transportaded by the road."],
    ["I have went to the store yesterday. She dont like to study for exams."],
    ["The company have announced a new policy. I am living in London since 2010."],
    ["He didnt studied for the test. They was at the party last night."],
    ["The chart illustrates the number in percents of overweight children in Canada throughout a 20-years period from 1985 to 2005, while the table demonstrates the percentage of children doing sport exercises regulary over the period from 1990 to 2005. Overall, it can be seen that despite the fact that the number of boys and girls performing exercises has grown considerably by the end of the period, percent of overweight children has increased too. According to the graph, boys are more likely to have extra weight in period of 2000-2005, a quater of them had problems with weight in 2005. Girls were going ahead of boys in 1985-1990, then they maintained the same level in 1995, but then the number of outweight boys went up more rapidly. The table allows to see that interest in physical activity has grown by more than 25% both within boys and girls by 2005."],
]


def create_gradio_app():
    """Load the GED-guided GEC model and build the Gradio interface around it.

    Returns the ``gr.Interface`` without launching it.
    """
    model = T5WithGED("Zlovoblachko/REAEC_GEC_2step_test", "Zlovoblachko/4tag-electra-grammar-error-detection")

    def run_correction(text):
        # Bind the loaded model so Gradio only passes the textbox value.
        return process_text(text, model)

    input_box = gr.Textbox(
        lines=5,
        placeholder="Enter text to correct grammatical errors...",
        label="Input Text",
    )
    output_panel = gr.HTML(label="Corrected Text")
    return gr.Interface(
        fn=run_correction,
        inputs=input_box,
        outputs=output_panel,
        title="Grammar Error Correction with Detection",
        description=_APP_DESCRIPTION,
        examples=_APP_EXAMPLES,
        allow_flagging="never",
    )
# Build the interface at module level so ``iface`` can still be imported.
# Start the server only when this file is run as a script.
iface = create_gradio_app()
if __name__ == "__main__":
    iface.launch()