Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
import nltk | |
from transformers import T5ForConditionalGeneration, T5Tokenizer, ElectraTokenizer, ElectraForTokenClassification | |
import torch.nn as nn | |
from tqdm import tqdm | |
import numpy as np | |
from huggingface_hub import hf_hub_download | |
import re | |
import difflib | |
# Fetch NLTK's Punkt sentence-tokenizer data at startup, so the first request
# does not have to download it. process_text falls back to "punkt_tab" when
# sent_tokenize raises LookupError (newer NLTK releases need that resource).
# nltk.download reports failures and returns False by default, so asking an
# older NLTK for an id it does not know is harmless.
nltk.download('punkt')
nltk.download('punkt_tab')
class T5WithGED(nn.Module):
    """T5 grammatical-error-correction model with optional GED-signal fusion.

    A T5 seq2seq model rewrites the input sentence. If the auxiliary file
    ``ged_components.pt`` (weights for a second T5 encoder plus a gating layer)
    can be downloaded from ``model_path``, the T5 encoding of the sentence is
    blended with an encoding of the per-token tags predicted by a separate
    ELECTRA token classifier (``ged_model_path``) before decoding.

    Loading failures are printed, and the model falls back to plain T5
    generation.
    """

    # ELECTRA class index -> GED tag. "C" is the no-error tag; any other index
    # is also treated as "C".
    _GED_LABELS = {0: "C", 1: "R", 2: "M", 3: "U"}
    # Tags that start or extend an error span.
    _ERROR_TAGS = frozenset({"R", "M", "U"})
    # Tokenizer markers that never receive a tag.
    _SPECIAL_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]"})

    def __init__(self, model_path="Zlovoblachko/REAEC_GEC_2step_test", ged_model_path="Zlovoblachko/4tag-electra-grammar-error-detection"):
        super().__init__()
        import copy  # local import: only needed to clone the encoder below

        self.t5 = T5ForConditionalGeneration.from_pretrained(model_path)
        self.t5_tokenizer = T5Tokenizer.from_pretrained(model_path)

        self.has_ged = False
        try:
            # Placeholder until the trained GED weights load; it is unused
            # while has_ged is False.
            self.ged_encoder = self.t5.encoder
            self.gate = nn.Linear(2 * self.t5.config.d_model, 1)
            try:
                ged_components_path = hf_hub_download(
                    repo_id=model_path,
                    filename="ged_components.pt"
                )
                # NOTE(review): torch.load unpickles a file downloaded from
                # the Hub. Only point model_path at a trusted repo, or pass
                # weights_only=True on torch versions that support it.
                ged_components = torch.load(ged_components_path, map_location=torch.device('cpu'))
                # Load into a *copy* of the T5 encoder. Loading into
                # self.t5.encoder itself would overwrite the weights used to
                # encode the source sentence.
                ged_encoder = copy.deepcopy(self.t5.encoder)
                ged_encoder.load_state_dict(ged_components["ged_encoder"])
                self.gate.load_state_dict(ged_components["gate"])
                self.ged_encoder = ged_encoder
                self.has_ged = True
            except Exception as e:
                print(f"Could not load GED components: {e}")
        except Exception as e:
            print(f"Error setting up GED integration: {e}")

        self.ged_model = None
        self.ged_tokenizer = None
        try:
            self.ged_tokenizer = ElectraTokenizer.from_pretrained(ged_model_path)
            self.ged_model = ElectraForTokenClassification.from_pretrained(ged_model_path)
            self.ged_model.eval()
        except Exception as e:
            print(f"Could not load GED model: {e}")

    def get_ged_predictions(self, text):
        """Get GED predictions for a sentence.

        Returns:
            None if the GED model or tokenizer failed to load. Otherwise a
            tuple ``(ged_tags, error_spans, input_tokens)``:

            * ``ged_tags``: space-joined tags, one per kept token. Word
              pieces starting with ``##`` and ``[CLS]``/``[SEP]``/``[PAD]``
              are skipped.
            * ``error_spans``: dicts with keys ``text``, ``type``, ``tokens``
              and ``token_indices``, one per run of consecutive kept tokens
              that share the same error tag (R, M or U).
            * ``input_tokens``: every ELECTRA token; ``token_indices``
              index into this list.
        """
        if self.ged_model is None or self.ged_tokenizer is None:
            return None
        inputs = self.ged_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            logits = self.ged_model(**inputs).logits
        predictions = torch.argmax(logits, dim=2)[0].cpu().numpy().tolist()
        input_tokens = self.ged_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

        # (token, tag, position in input_tokens) for every word-initial token.
        tagged = [
            (token, self._GED_LABELS.get(pred, "C"), idx)
            for idx, (token, pred) in enumerate(zip(input_tokens, predictions))
            if not token.startswith("##") and token not in self._SPECIAL_TOKENS
        ]
        ged_tags = " ".join(tag for _, tag, _ in tagged)
        return ged_tags, self._group_error_spans(tagged), input_tokens

    @classmethod
    def _group_error_spans(cls, tagged):
        """Merge consecutive (token, tag, index) triples with the same error tag into spans."""
        spans = []
        current = None
        for token, tag, token_idx in tagged:
            if tag not in cls._ERROR_TAGS:
                current = None
                continue
            if current is None or current["type"] != tag:
                current = {"text": "", "type": tag, "tokens": [], "token_indices": []}
                spans.append(current)
            current["tokens"].append(token)
            current["token_indices"].append(token_idx)
        for span in spans:
            span["text"] = " ".join(span["tokens"])
        return spans

    def correct(self, text, use_ged=True, max_length=128):
        """Correct grammatical errors in text.

        Returns:
            ``(corrected_text, ged_tags, error_spans)``. The last two come
            from :meth:`get_ged_predictions`. Both are None whenever GED
            fusion is not used: ``use_ged`` is false, or the GED components
            or GED model did not load.
        """
        inputs = self.t5_tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)

        ged_info = None
        if self.has_ged and use_ged and self.ged_model is not None:
            ged_info = self.get_ged_predictions(text)

        if ged_info is None:
            # Plain T5 generation without GED fusion.
            output_ids = self.t5.generate(input_ids=inputs.input_ids,
                                          attention_mask=inputs.attention_mask,
                                          max_length=max_length)
            return self.t5_tokenizer.decode(output_ids[0], skip_special_tokens=True), None, None

        ged_tags, error_spans, _input_tokens = ged_info
        # Inference only: do not build an autograd graph through the encoders/gate.
        with torch.no_grad():
            ged_inputs = self.t5_tokenizer(ged_tags, return_tensors="pt", truncation=True, max_length=max_length)
            src_encoder_outputs = self.t5.encoder(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, return_dict=True)
            ged_encoder_outputs = self.ged_encoder(input_ids=ged_inputs.input_ids, attention_mask=ged_inputs.attention_mask, return_dict=True)
            src_hidden_states = src_encoder_outputs.last_hidden_state
            ged_hidden_states = ged_encoder_outputs.last_hidden_state

            # NOTE(review): both sequences are cut to the shorter length, so
            # source positions beyond the GED-tag sequence length are dropped
            # before decoding. Kept as-is because it presumably mirrors how
            # the gate was trained; confirm against the training code.
            min_len = min(src_hidden_states.size(1), ged_hidden_states.size(1))
            src_hidden_states = src_hidden_states[:, :min_len, :]
            ged_hidden_states = ged_hidden_states[:, :min_len, :]

            combined = torch.cat([src_hidden_states, ged_hidden_states], dim=2)
            gate_scores = torch.sigmoid(self.gate(combined))
            # formula: λ*src_hidden + (1-λ)*ged_hidden
            src_encoder_outputs.last_hidden_state = gate_scores * src_hidden_states + (1 - gate_scores) * ged_hidden_states
            output_ids = self.t5.generate(encoder_outputs=src_encoder_outputs,
                                          attention_mask=inputs.attention_mask[:, :min_len],
                                          max_length=max_length)
        corrected_text = self.t5_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return corrected_text, ged_tags, error_spans
def find_differences(source, corrected):
    """Find differences between source and corrected text.

    Both strings are split on whitespace and compared word by word with
    ``difflib.ndiff``. Returns a list of dicts with keys ``type``
    (``"deletion"`` or ``"addition"``), ``text`` (the word) and ``position``
    (the index of that line in the ndiff output, which also counts unchanged
    and ``?`` hint lines).
    """
    change_kinds = {"- ": "deletion", "+ ": "addition"}
    diff_lines = difflib.ndiff(source.split(), corrected.split())
    return [
        {"type": change_kinds[line[:2]], "text": line[2:], "position": position}
        for position, line in enumerate(diff_lines)
        if line[:2] in change_kinds
    ]
# Highlight colour and tooltip label for each GED error tag shown in the HTML output.
_ERROR_STYLES = {
    "R": ("#FFCCCC", "Replace"),
    "M": ("#CCFFCC", "Missing"),
    "U": ("#CCCCFF", "Unnecessary"),
}


def _error_span_pattern(tokens):
    """Return a regex locating a GED span (a list of ELECTRA tokens) in raw text, or None if empty."""
    if not tokens:
        return None
    # Escape each token separately and allow optional whitespace between them.
    # Punctuation tokens such as "," or "'" are usually glued to a neighbouring word.
    pattern = r"\s*".join(re.escape(token) for token in tokens)
    # Start at a word boundary so a short token like "a" cannot match inside "that".
    if re.match(r"\w", tokens[0]):
        pattern = r"(?<!\w)" + pattern
    # "##" word pieces are excluded from spans, so stretch over the rest of the word.
    if re.search(r"\w$", tokens[-1]):
        pattern += r"\w*"
    return pattern


def _highlight_error_spans(original, error_spans):
    """Return *original* as escaped HTML, with each located error span wrapped in a coloured <span>."""
    import html

    pieces = []
    cursor = 0
    for span in sorted(error_spans, key=lambda s: s["token_indices"][0]):
        style = _ERROR_STYLES.get(span["type"])
        pattern = _error_span_pattern(span.get("tokens") or span["text"].split())
        if style is None or pattern is None:
            continue
        # Spans are processed in token order, so only search after the previous
        # highlight. This also guarantees highlights never overlap.
        match = re.compile(pattern, re.IGNORECASE).search(original, cursor)
        if match is None:
            continue
        color, label = style
        pieces.append(html.escape(original[cursor:match.start()]))
        pieces.append(
            f"<span style='background-color: {color}; padding: 2px; border-radius: 3px;' "
            f"title='{label}'>{html.escape(match.group(0))}</span>"
        )
        cursor = match.end()
    pieces.append(html.escape(original[cursor:]))
    return "".join(pieces)


def process_text(text, model):
    """Process input text by splitting into sentences and applying the model.

    Args:
        text: Raw user input; may contain several sentences.
        model: Object with ``correct(sentence)`` returning
            ``(corrected, ged_tags, error_spans)``, as ``T5WithGED.correct`` does.

    Returns:
        An HTML string with one card per sentence: the original sentence
        (error spans highlighted), the correction, and the word-level changes.
        For blank input, the plain message "Please enter some text.".
        All user and model text is HTML-escaped.
    """
    import html

    if not text.strip():
        return "Please enter some text."
    try:
        sentences = nltk.sent_tokenize(text)
    except LookupError:
        # Newer NLTK releases need the "punkt_tab" resource for sent_tokenize.
        nltk.download('punkt_tab')
        sentences = nltk.sent_tokenize(text)

    parts = ["<div style='font-family: Arial, sans-serif;'>"]
    for sentence in sentences:
        corrected, _ged_tags, error_spans = model.correct(sentence)
        parts.append("<div style='margin-bottom: 20px; padding: 15px; border-radius: 5px; background-color: #f8f9fa;'>")

        if error_spans:
            parts.append("<p><strong>Original sentence:</strong></p>")
            parts.append(f"<p>{_highlight_error_spans(sentence, error_spans)}</p>")
        else:
            parts.append(f"<p><strong>Original sentence:</strong> {html.escape(sentence)}</p>")

        parts.append(f"<p><strong>Corrected:</strong> {html.escape(corrected)}</p>")

        changes = find_differences(sentence, corrected)
        if changes:
            parts.append("<p><strong>Changes:</strong></p><ul>")
            for change in changes:
                word = html.escape(change["text"])
                if change["type"] == "deletion":
                    parts.append(f"<li>Removed: <span style='color: red;'>{word}</span></li>")
                else:
                    parts.append(f"<li>Added: <span style='color: green;'>{word}</span></li>")
            parts.append("</ul>")

        parts.append("</div>")
    parts.append("</div>")
    return "".join(parts)
def create_gradio_app():
    """Build the Gradio interface that runs every submission through T5WithGED.

    The model is loaded once here; the returned ``gr.Interface`` maps a
    multi-line textbox to the HTML report produced by ``process_text``.
    """
    gec_model = T5WithGED("Zlovoblachko/REAEC_GEC_2step_test", "Zlovoblachko/4tag-electra-grammar-error-detection")

    text_box = gr.Textbox(
        lines=5,
        placeholder="Enter text to correct grammatical errors...",
        label="Input Text"
    )
    report_view = gr.HTML(label="Corrected Text")

    app_description = """
This app corrects grammatical errors in text using an ensemble of models:
1. An ELECTRA-based Grammatical Error Detection (GED) model identifies error spans
2. A T5-based Grammatical Error Correction (GEC) model corrects the errors
Enter your text and see the corrections with highlighted error spans:
- <span style='background-color: #FFCCCC; padding: 2px;'>Red</span>: Replacement needed
- <span style='background-color: #CCFFCC; padding: 2px;'>Green</span>: Missing word
- <span style='background-color: #CCCCFF; padding: 2px;'>Blue</span>: Unnecessary word
"""

    sample_inputs = [
        ["First of all, we can see increasing tendency of overweighting during the hole period."],
        ["Food products were mostly transportaded by the road."],
        ["I have went to the store yesterday. She dont like to study for exams."],
        ["The company have announced a new policy. I am living in London since 2010."],
        ["He didnt studied for the test. They was at the party last night."],
        ["The chart illustrates the number in percents of overweight children in Canada throughout a 20-years period from 1985 to 2005, while the table demonstrates the percentage of children doing sport exercises regulary over the period from 1990 to 2005. Overall, it can be seen that despite the fact that the number of boys and girls performing exercises has grown considerably by the end of the period, percent of overweight children has increased too. According to the graph, boys are more likely to have extra weight in period of 2000-2005, a quater of them had problems with weight in 2005. Girls were going ahead of boys in 1985-1990, then they maintained the same level in 1995, but then the number of outweight boys went up more rapidly. The table allows to see that interest in physical activity has grown by more than 25% both within boys and girls by 2005."],
    ]

    return gr.Interface(
        fn=lambda text: process_text(text, gec_model),
        inputs=text_box,
        outputs=report_view,
        title="Grammar Error Correction with Detection",
        description=app_description,
        examples=sample_inputs,
        allow_flagging="never",
    )
iface = create_gradio_app()

if __name__ == "__main__":
    # Start the web server only when this file is executed as a script, not when it is imported.
    iface.launch()