File size: 6,207 Bytes
a0b78f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from __future__ import annotations

from collections import defaultdict
from string import punctuation

import Levenshtein
from errant.edit import Edit


def edit_to_tuple(edit: Edit, idx: int = 0) -> list[int | str]:
    """Flatten an Edit into a ``[o_start, o_end, type, correction, idx]`` row.

    Despite the function name, the row is a *list*: `classify` overwrites
    the correction string at index 3 after calling this function, so the
    result has to stay mutable.

    Args:
        edit: Edit providing `o_start`, `o_end`, `type` and `c_toks`.
        idx: Value stored as the last element of the row.

    Returns:
        A new five-element list; the correction is the texts of
        `edit.c_toks` joined with single spaces.
    """
    cor_toks_str = " ".join(tok.text for tok in edit.c_toks)
    return [edit.o_start, edit.o_end, edit.type, cor_toks_str, idx]


def classify(edit: Edit) -> list[list]:
    """Classify an Edit and return one row per detected error type.

    The edit's `type` attribute is overwritten with each detected category
    in turn (so after the call it holds the last one), and for every
    category a row produced by `edit_to_tuple` is emitted with its
    correction string (index 3) replaced by the string associated with
    that category.

    Args:
        edit: Edit whose `o_toks` / `c_toks` are the original and the
            corrected tokens.

    Returns:
        A list of ``[o_start, o_end, type, correction, idx]`` rows; empty
        when no category was produced.
    """
    has_orig = bool(edit.o_toks)
    has_corr = bool(edit.c_toks)
    if has_orig != has_corr:
        # Exactly one side is empty: pure insertion or pure deletion.
        error_cats = get_one_sided_type(edit.o_toks, edit.c_toks)
    elif edit.o_toks != edit.c_toks:
        error_cats = get_two_sided_type(edit.o_toks, edit.c_toks)
    else:
        # No change. Guard the lookup so that an edit with both sides empty
        # does not raise IndexError (same fallback as get_one_sided_type).
        error_cats = {"NA": edit.c_toks[0].text if edit.c_toks else ""}

    new_edit_list = []
    for error_cat, correct_str in (error_cats or {}).items():
        edit.type = error_cat
        edit_row = edit_to_tuple(edit)
        edit_row[3] = correct_str
        new_edit_list.append(edit_row)
    return new_edit_list


def get_edit_info(toks):
    """Collect tags, dependency labels and morphological features of tokens.

    Returns a triple ``(tags, deps, features)``: the `tag_` and `dep_`
    values of every token in order, and a dict built from the
    ``key=value`` entries of each token's `morph` string (entries are
    separated by ``|``). When several tokens carry the same key, the value
    from the later token wins.
    """
    tags, deps, features = [], [], {}
    for token in toks:
        tags.append(token.tag_)
        deps.append(token.dep_)
        for raw_entry in str(token.morph).split('|'):
            entry = raw_entry.strip()
            if not entry:
                # An empty morph string yields a single empty entry.
                continue
            key, value = entry.split('=')
            features[key] = value
    return tags, deps, features


def get_one_sided_type(o_toks, c_toks):
    """Label a pure insertion or pure deletion as PUNCT or SPELL.

    The tags of whichever side is non-empty are inspected: if any of them
    is "PUNCT" or "SPACE" the edit is labelled "PUNCT", otherwise "SPELL".
    The returned single-entry dict maps that label to the text of the
    first corrected token, or to an empty string for a deletion.
    """
    present_side = o_toks if o_toks else c_toks
    tags = get_edit_info(present_side)[0]
    corrected_text = c_toks[0].text if c_toks else ""
    if "PUNCT" in tags or "SPACE" in tags:
        label = "PUNCT"
    else:
        label = "SPELL"
    return {label: corrected_text}


def get_two_sided_type(o_toks, c_toks) -> dict[str, str]:
    """Label a substitution (1-to-1, 1-to-many or many-to-1) edit.

    Returns a dict mapping each detected error type to its correction
    string. Single-token pairs may yield several types at once (see the
    char-level analysis below); every other case yields exactly one entry.

    NOTE(review): the multi-token branch reads `.text` on the token
    containers themselves, so they are presumably span-like objects rather
    than plain lists - confirm against the caller.
    """
    punct_chars = punctuation + " "

    def only_punct(text):
        # True when every character is ASCII punctuation or a space
        # (and therefore also for an empty string).
        return all(ch in punct_chars for ch in text)

    # --- one-to-one ---------------------------------------------------
    if len(o_toks) == 1 and len(c_toks) == 1:
        source_w = o_toks[0].text
        correct_w = c_toks[0].text
        if only_punct(source_w) and only_punct(correct_w):
            return {"PUNCT": correct_w}
        if source_w != correct_w:
            same_case = (
                (source_w.islower() and correct_w.islower())
                or (source_w.isupper() and correct_w.isupper())
            )
            # Same casing on both sides and no "ё" involved: nothing but a
            # spelling error is possible.
            if same_case and "ё" not in source_w + correct_w:
                return {"SPELL": correct_w}
            # Possibly several error types in one word (e.g. SPELL + CASE):
            # align the words char by char, label every char operation,
            # then rebuild one corrected string per error type.
            char_ops = Levenshtein.editops(source_w, correct_w)
            labelled_ops = classify_char_edits(char_ops, source_w, correct_w)
            return get_edit_strings(source_w, correct_w, labelled_ops)

    # --- one-to-many / many-to-one (and identical single tokens) -------
    if only_punct(o_toks.text + c_toks.text):
        return {"PUNCT": c_toks.text}
    joined = " ".join(tok.text for tok in c_toks)
    # Glue hyphens back to their neighbours after the space-join.
    for spaced_hyphen in ("- ", " -"):
        joined = joined.replace(spaced_hyphen, "-")
    return {"SPELL": joined}


def classify_char_edits(char_edits, source_w, correct_w):
    """Attach an error label (SPELL, YO or CASE) to char-level edit operations.

    Each input operation is ``(op, source_pos, correct_pos)``. The result
    holds the same operations with the label appended as a fourth element.
    A replacement that changes both the letter and its case appears twice:
    once as "CASE" and once as "SPELL". Insertions and deletions are always
    "SPELL".
    """
    labelled = []
    for op in char_edits:
        if op[0] != "replace":
            labelled.append((*op, "SPELL"))
            continue

        src_char = source_w[op[1]]
        cor_char = correct_w[op[2]]
        if "ё" in (src_char, cor_char):
            labels = ("YO",)
        elif src_char.lower() == cor_char.lower():
            # Same letter, different case.
            labels = ("CASE",)
        elif (
            (src_char.islower() and cor_char.isupper())
            or (src_char.isupper() and cor_char.islower())
        ):
            # Different letter *and* different case: report both.
            labels = ("CASE", "SPELL")
        else:
            labels = ("SPELL",)
        labelled.extend((*op, label) for label in labels)
    return labelled


def get_edit_strings(source: str, correction: str,
                     edits_classified: list[tuple]) -> dict[str, str]:
    """Apply classified char operations to the source word, one copy per type.

    Every error type (SPELL, YO, CASE) gets its own working copy of
    `source`, to which only the operations of that type are applied.

    Args:
        source: The original word.
        correction: The corrected word.
        edits_classified: ``(op, source_pos, correct_pos, error_type)``
            tuples, where `op` is "replace", "insert" or "delete" and the
            positions index `source` and `correction` respectively.

    Returns:
        A dict mapping each error type present in `edits_classified` to
        the source word with only the corrections of that type applied.
    """
    separated_edits: dict[str, str] = {}
    # Positions in the operations refer to the ORIGINAL source word, so every
    # earlier insertion/deletion applied to a working copy displaces later
    # positions in that copy. The displacement is tracked per error type,
    # because each type edits its own copy: an insertion into the SPELL copy
    # must not move a later replacement in the CASE or YO copy.
    shifts: defaultdict = defaultdict(int)

    for edit in edits_classified:
        op, src_pos, cor_pos, edit_type = edit[:4]
        curr_src = separated_edits.setdefault(edit_type, source)
        pos = src_pos + shifts[edit_type]

        if op == "replace":
            src_char = source[src_pos]
            cor_char = correction[cor_pos]
            if edit_type == "CASE":
                # Keep the SOURCE letter, written in the CORRECTION's case.
                new_char = src_char.upper() if cor_char.isupper() else src_char.lower()
            else:
                # Take the CORRECTION letter, written in the SOURCE's case, so
                # that a case difference is not reported under this type.
                new_char = cor_char.upper() if src_char.isupper() else cor_char.lower()
            separated_edits[edit_type] = curr_src[:pos] + new_char + curr_src[pos + 1:]
        elif op == "delete":
            separated_edits[edit_type] = curr_src[:pos] + curr_src[pos + 1:]
            shifts[edit_type] -= 1
        elif op == "insert":
            separated_edits[edit_type] = curr_src[:pos] + correction[cor_pos] + curr_src[pos:]
            shifts[edit_type] += 1
    return separated_edits