|
from __future__ import annotations |
|
|
|
from collections import defaultdict |
|
from string import punctuation |
|
|
|
import Levenshtein |
|
from errant.edit import Edit |
|
|
|
|
|
def edit_to_tuple(edit: Edit, idx: int = 0) -> list[int | str]:
    """Flatten an Edit into ``[o_start, o_end, type, correction, idx]``.

    Despite the name, a *list* is returned: callers (e.g. ``classify``)
    overwrite the correction slot (index 3) in place.

    Args:
        edit: edit exposing ``o_start``, ``o_end``, ``type`` and ``c_toks``
            (tokens with a ``.text`` attribute).
        idx: value stored in the last slot.

    Returns:
        ``[o_start, o_end, type, correction, idx]`` where ``correction`` is
        the corrected tokens' texts joined by single spaces.
    """
    correction = " ".join(tok.text for tok in edit.c_toks)
    return [edit.o_start, edit.o_end, edit.type, correction, idx]
|
|
|
|
|
def classify(edit: Edit) -> list[list]:
    """Classify an Edit into one or more error categories.

    Categories come from ``get_one_sided_type`` (pure insertion/deletion),
    ``get_two_sided_type`` (substitution) or ``"NA"`` when the original and
    corrected texts are identical.

    Side effect: ``edit.type`` is set to each category in turn, so after the
    call it holds the last category produced.

    Args:
        edit: edit whose ``o_toks``/``c_toks`` are token sequences exposing
            ``.text`` (presumably spaCy spans from ERRANT -- confirm).

    Returns:
        One ``[o_start, o_end, category, correction, 0]`` list per category
        (see ``edit_to_tuple``), where ``correction`` is the string produced
        for that category alone.
    """
    o_str = " ".join(tok.text for tok in edit.o_toks)
    c_str = " ".join(tok.text for tok in edit.c_toks)

    if bool(edit.o_toks) != bool(edit.c_toks):
        # Exactly one side is empty: a pure insertion or deletion.
        error_cats = get_one_sided_type(edit.o_toks, edit.c_toks)
    elif o_str != c_str:
        # Compare texts, not the token containers: tokens of the original and
        # corrected parses are distinct objects, so container inequality does
        # not mean the text changed.
        error_cats = get_two_sided_type(edit.o_toks, edit.c_toks)
    else:
        # Unchanged text, including the empty-to-empty edit (indexing
        # c_toks[0] there would raise IndexError).
        error_cats = {"NA": c_str}

    new_edit_list = []
    for error_cat, correct_str in (error_cats or {}).items():
        edit.type = error_cat
        edit_tuple = edit_to_tuple(edit)
        # Replace the full joined correction with the category-specific one.
        edit_tuple[3] = correct_str
        new_edit_list.append(edit_tuple)
    return new_edit_list
|
|
|
|
|
def get_edit_info(toks):
    """Collect per-token tags and dependency labels plus merged morphology.

    Args:
        toks: iterable of tokens exposing ``tag_``, ``dep_`` and ``morph``;
            it is traversed exactly once.

    Returns:
        ``(pos, dep, morph)``: ``pos`` and ``dep`` list each token's ``tag_``
        and ``dep_``; ``morph`` maps feature names to values parsed from
        ``str(tok.morph)`` in ``Feat=Val|Feat=Val`` form. When several tokens
        carry the same feature, the later token's value wins.
    """
    tags, deps, features = [], [], {}
    for token in toks:
        tags.append(token.tag_)
        deps.append(token.dep_)
        for pair in str(token.morph).split("|"):
            cleaned = pair.strip()
            if not cleaned:
                # Tokens without morphology render as an empty string.
                continue
            name, value = cleaned.split("=")
            features[name] = value
    return tags, deps, features
|
|
|
|
|
def get_one_sided_type(o_toks, c_toks):
    """Classifies a zero-to-many or many-to-zero (insertion/deletion) error.

    The edit is ``PUNCT`` when any token on the non-empty side is tagged
    ``PUNCT`` or ``SPACE`` (via ``tag_``), otherwise ``SPELL``.

    Returns:
        ``{category: correction}`` where ``correction`` is the whole inserted
        text (all corrected tokens joined by single spaces, as in the
        multi-token branch of ``get_two_sided_type``), or ``""`` for a deletion.
    """
    toks = o_toks if o_toks else c_toks
    # Only the tags are needed; reading them directly avoids also parsing
    # morphology, which is irrelevant here.
    tags = {tok.tag_ for tok in toks}
    correction = " ".join(tok.text for tok in c_toks)
    if tags & {"PUNCT", "SPACE"}:
        return {"PUNCT": correction}
    return {"SPELL": correction}
|
|
|
|
|
def get_two_sided_type(o_toks, c_toks) -> dict[str, str]:
    """Classifies a one-to-one, one-to-many or many-to-one error.

    Args:
        o_toks: original tokens (non-empty), each exposing ``.text``.
        c_toks: corrected tokens (non-empty), each exposing ``.text``.

    Returns:
        A mapping from error category to the corrected string for that
        category only. One-to-one word edits may be split into several
        categories (``SPELL``, ``YO``, ``CASE``) by the character-level helpers.
    """
    punct_chars = set(punctuation + " ")

    if len(o_toks) == len(c_toks) == 1:
        source_w, correct_w = o_toks[0].text, c_toks[0].text
        if all(char in punct_chars for char in source_w + correct_w):
            return {"PUNCT": correct_w}

        if source_w != correct_w:
            same_case = ((source_w.islower() and correct_w.islower()) or
                         (source_w.isupper() and correct_w.isupper()))
            # If the casing agrees and no ё/Ё is involved, the whole change is
            # a spelling error; otherwise analyse it character by character so
            # CASE and YO changes can be separated from SPELL.
            if same_case and "ё" not in (source_w + correct_w).lower():
                return {"SPELL": correct_w}

            char_edits = Levenshtein.editops(source_w, correct_w)
            edits_classified = classify_char_edits(char_edits, source_w, correct_w)
            return get_edit_strings(source_w, correct_w, edits_classified)

    # Multi-token edits, or one-to-one edits whose text did not change.
    # NOTE(review): relies on o_toks/c_toks exposing `.text` themselves
    # (e.g. a spaCy Span); a plain list of tokens would fail here.
    if all(char in punct_chars for char in o_toks.text + c_toks.text):
        return {"PUNCT": c_toks.text}

    joint_corr_str = " ".join(tok.text for tok in c_toks)
    # Presumably hyphenated words arrive as separate tokens ("кто", "-", "то");
    # glue the hyphen back to its neighbours.
    joint_corr_str = joint_corr_str.replace("- ", "-").replace(" -", "-")
    return {"SPELL": joint_corr_str}
|
|
|
|
|
def classify_char_edits(char_edits, source_w, correct_w):
    """Classifies char-level Levenshtein operations into SPELL, YO and CASE.

    Args:
        char_edits: ``(op, src_pos, dst_pos)`` operations such as those
            returned by ``Levenshtein.editops(source_w, correct_w)``.
        source_w: the original word.
        correct_w: the corrected word.

    Returns:
        The operations with an error type appended:
        ``(op, src_pos, dst_pos, error_type)``. A replacement that changes
        both the letter and its case is emitted twice, first as CASE and
        then as SPELL.
    """
    edits_classified = []
    for edit in char_edits:
        if edit[0] != "replace":
            # Inserted or deleted letters are always spelling errors.
            edits_classified.append((*edit, "SPELL"))
            continue

        src_char = source_w[edit[1]]
        cor_char = correct_w[edit[2]]
        src_lower, cor_lower = src_char.lower(), cor_char.lower()

        if src_lower == cor_lower:
            # Same letter, different capitalisation. Checked before YO so that
            # ё <-> Ё is a case change (as YO it would be rewritten with the
            # source's case and yield an unchanged word).
            edits_classified.append((*edit, "CASE"))
        elif "ё" in (src_lower, cor_lower):
            # A letter replaced by, or replacing, ё/Ё (in either case).
            edits_classified.append((*edit, "YO"))
        else:
            if ((src_char.islower() and cor_char.isupper()) or
                    (src_char.isupper() and cor_char.islower())):
                # Different letter AND different case: record both aspects.
                edits_classified.append((*edit, "CASE"))
            edits_classified.append((*edit, "SPELL"))
    return edits_classified
|
|
|
|
|
def get_edit_strings(source: str, correction: str,
                     edits_classified: list[tuple]) -> dict[str, str]:
    """
    Applies classified (SPELL, YO and CASE) char operations to source word separately.

    Returns a dict mapping error type to source string with corrections of this type only.

    Args:
        source: the original word.
        correction: the corrected word.
        edits_classified: ``(op, src_pos, dst_pos, error_type)`` tuples, where
            ``op`` is ``"replace"``, ``"insert"`` or ``"delete"`` and the
            positions index into ``source`` and ``correction`` respectively.

    Each error type is applied to its own copy of ``source``. Replacement
    letters take their case from the source letter; CASE edits keep the
    source letter and take only its case from the correction.
    """
    separated_edits: dict[str, str] = {}
    # Inserts/deletes shift the positions of later edits, but only within the
    # copy of `source` they were applied to -- so the offset must be tracked
    # per error type. A single shared offset misplaces edits of one type that
    # follow an insert/delete of another type.
    shifts: defaultdict[str, int] = defaultdict(int)

    for edit in edits_classified:
        op, src_pos, dst_pos, edit_type = edit[:4]
        curr_src = separated_edits.setdefault(edit_type, source)
        pos = src_pos + shifts[edit_type]

        if edit_type == "CASE":
            if correction[dst_pos].isupper():
                correction_char = source[src_pos].upper()
            else:
                correction_char = source[src_pos].lower()
        elif op == "delete":
            correction_char = ""
        elif op == "insert":
            correction_char = correction[dst_pos]
        elif source[src_pos].isupper():
            correction_char = correction[dst_pos].upper()
        else:
            correction_char = correction[dst_pos].lower()

        if op == "replace":
            separated_edits[edit_type] = curr_src[:pos] + correction_char + curr_src[pos + 1:]
        elif op == "delete":
            separated_edits[edit_type] = curr_src[:pos] + curr_src[pos + 1:]
            shifts[edit_type] -= 1
        elif op == "insert":
            separated_edits[edit_type] = curr_src[:pos] + correction_char + curr_src[pos:]
            shifts[edit_type] += 1

    return separated_edits
|
|