|
from __future__ import annotations |
|
|
|
import itertools |
|
import re |
|
from string import punctuation |
|
|
|
import Levenshtein |
|
from errant.alignment import Alignment |
|
from errant.edit import Edit |
|
|
|
|
|
def get_rule_edits(alignment: Alignment) -> list[Edit]:
    """Turn a word-level alignment into a list of edits using the merging rules.

    The alignment is first split into groups by `group_alignment` (mode "new").
    Groups labelled "M" produce no edits, every op of a "T" group becomes its
    own edit, and every other group is passed through `process_seq` before its
    resulting spans are turned into edits.
    """
    def to_edit(span):
        # Drop the leading op label; the rest of the tuple is handed to Edit.
        return Edit(alignment.orig, alignment.cor, span[1:])

    edits = []
    for op, group in group_alignment(alignment, "new"):
        if op == "M":
            continue
        spans = list(group)
        if op != "T":
            spans = process_seq(spans, alignment)
        edits.extend(to_edit(span) for span in spans)
    return edits
|
|
|
|
|
def group_alignment(alignment: Alignment, mode: str = "default") -> list[tuple[str, list[tuple]]]:
    """
    Does initial alignment grouping.

    With ``mode == "new"``:

    1. Make groups of MDM, MIM or MSM (each such group is labelled "MSM").
    2. In remaining operations, make groups of Ms ("M"), groups of Ts ("T"),
       and D/I/Ss ("DIS").
       Do not group what was on the sides of M[DIS]M:
       SSMDMS -> [SS, MDM, S], not [MDM, SSS].
       Runs of Ms (or Ts) are joined even when an M[DIS]M group lies between them.
    3. Sort groups by the position of their first operation in the alignment.

    With any other mode, consecutive operations of `alignment.align_seq` are
    grouped with `itertools.groupby`; the label is "M", "T" or False.

    Only the first character of each operation label (`op[0][0]`) is used to
    classify it.

    Returns a list of `(label, operations)` pairs.
    """
    align_seq = alignment.align_seq

    if mode != "new":
        grouped = itertools.groupby(align_seq,
                                    lambda x: x[0][0] if x[0][0] in {"M", "T"} else False)
        return [(op, list(group)) for op, group in grouped]

    all_ops_seq = "".join(op[0][0] for op in align_seq)

    # Every group is stored with the alignment index of its first operation.
    # Sorting on that index (rather than on a token offset) keeps an
    # insertion-only group in front of an M[DIS]M group that directly follows
    # it: both start at the same original-token offset.
    indexed_groups = []
    grouped_ids = set()
    for match in re.finditer(r"M[DIS]M", all_ops_seq):
        start, end = match.span()
        indexed_groups.append((start, "MSM", align_seq[start:end]))
        grouped_ids.update(range(start, end))

    ungrouped_ids = [idx for idx in range(len(align_seq)) if idx not in grouped_ids]

    if ungrouped_ids:
        def get_group_type(operation):
            return operation if operation in {"M", "T"} else "DIS"

        prev_idx = group_start = ungrouped_ids[0]
        curr_group = [align_seq[group_start]]
        last_oper_type = get_group_type(curr_group[0][0][0])

        for idx in ungrouped_ids[1:]:
            operation = align_seq[idx]
            oper_type = get_group_type(operation[0][0])
            # D/I/S operations only join the current group when they are
            # directly adjacent to the previous ungrouped operation; M and T
            # operations join regardless of a gap.
            if (oper_type == last_oper_type and
                    (idx - prev_idx == 1 or oper_type in {"M", "T"})):
                curr_group.append(operation)
            else:
                indexed_groups.append((group_start, last_oper_type, curr_group))
                group_start, curr_group, last_oper_type = idx, [operation], oper_type
            prev_idx = idx

        indexed_groups.append((group_start, last_oper_type, curr_group))

    indexed_groups.sort(key=lambda item: item[0])
    return [(label, operations) for _, label, operations in indexed_groups]
|
|
|
|
|
def _merge_span(seq: list[tuple], lo: int, hi: int, alignment: Alignment) -> list[tuple]:
    """Merge `seq[lo:hi + 1]` into one edit and process both sides recursively."""
    return (process_seq(seq[:lo], alignment)
            + merge_edits(seq[lo:hi + 1])
            + process_seq(seq[hi + 1:], alignment))


def process_seq(seq: list[tuple], alignment: Alignment) -> list[tuple]:
    """Applies merging rules to previously formed alignment groups (`seq`).

    Each element of `seq` is a tuple `(op, orig_start, orig_end, cor_start,
    cor_end)`; the offsets index into `alignment.orig` and `alignment.cor`.

    Every `(start, end)` sub-span of `seq` is tried, widest first. The first
    rule that fires decides the result: the span (or its last two operations)
    is merged into a single edit, or the sequence is split, and the remaining
    parts are processed recursively. If no rule fires, `seq` is returned
    unchanged. Sequences with at most one element are returned as they are.
    """
    if len(seq) <= 1:
        return seq

    ops = [op[0] for op in seq]

    # All index pairs (start < end), ordered from the widest span to the narrowest.
    combos = list(itertools.combinations(range(0, len(seq)), 2))
    combos.sort(key=lambda x: x[1] - x[0], reverse=True)

    for start, end in combos:
        span_ops = ops[start:end + 1]

        # Skip spans without any deletion, insertion or substitution.
        if not any(type_ in span_ops for type_ in ["D", "I", "S"]):
            continue

        # A span made only of deletions, or only of insertions, is merged.
        if set(span_ops) == {"D"} or set(span_ops) == {"I"}:
            return _merge_span(seq, start, end, alignment)

        # Tokens covered by the span on the original and corrected side.
        o = alignment.orig[seq[start][1]:seq[end][2]]
        c = alignment.cor[seq[start][3]:seq[end][4]]

        # A single D/I/S surrounded by matches.
        if span_ops in [["M", "D", "M"], ["M", "I", "M"], ["M", "S", "M"]]:
            # `o` and `c` already begin at `start`, so the token after the
            # leading match sits at index 1 (indexing with `start + 1` would
            # look past it whenever `start > 0`).
            if (o[1].text == "-" or c[1].text == "-") and len(o) != len(c):
                return _merge_span(seq, start, end, alignment)
            # Otherwise keep only the middle operation.
            return seq[start + 1: end]

        # Last token on either side is tagged "POS": merge the last two
        # operations of the span.
        if o[-1].tag_ == "POS" or c[-1].tag_ == "POS":
            return _merge_span(seq, end - 1, end, alignment)

        # The last tokens are equal according to their `.lower` attribute.
        if o[-1].lower == c[-1].lower:
            # NOTE(review): the condition groups as `(start == 0 and A) or B`,
            # so the second alternative can fire with `start > 0`. The truth
            # value is kept as it was, but the operations before `start` are
            # now processed and kept instead of being dropped from the result.
            if ((start == 0 and len(o) == 1 and c[0].text[0].isupper()) or
                    (len(c) == 1 and o[0].text[0].isupper())):
                return _merge_span(seq, start, end, alignment)

            # The token before the last one is punctuation: merge the last
            # two operations of the span.
            if (len(o) > 1 and is_punct(o[-2])) or \
                    (len(c) > 1 and is_punct(c[-2])):
                return _merge_span(seq, end - 1, end, alignment)

        # Both sides are the same text once tokens are concatenated and
        # apostrophes, hyphens (and spaces) are ignored.
        s_str = re.sub(r"['-]", "", "".join([tok.lower_ for tok in o]))
        t_str = re.sub(r"['-]", "", "".join([tok.lower_ for tok in c]))
        if s_str == t_str or s_str.replace(" ", "") == t_str.replace(" ", ""):
            return _merge_span(seq, start, end, alignment)

        # Different token counts, and all tokens share one `.pos` value or
        # only AUX/PART/VERB values occur.
        # NOTE(review): assumes `tok.pos` holds string labels - confirm.
        pos_set = {tok.pos for tok in o} | {tok.pos for tok in c}
        if len(o) != len(c) and (len(pos_set) == 1 or pos_set.issubset({"AUX", "PART", "VERB"})):
            return _merge_span(seq, start, end, alignment)

        # Split rules only apply to spans of two adjacent operations.
        if end - start < 2:
            # Two tokens on each side: split between the two operations.
            if len(o) == len(c) == 2:
                return (process_seq(seq[:start + 1], alignment)
                        + process_seq(seq[start + 1:], alignment))

            # A substitution at the span boundary whose two tokens have a
            # character similarity above 0.75: split after the first operation.
            if ((ops[start] == "S" and char_cost(o[0].text, c[0].text) > 0.75) or
                    (ops[end] == "S" and char_cost(o[-1].text, c[-1].text) > 0.75)):
                return (process_seq(seq[:start + 1], alignment)
                        + process_seq(seq[start + 1:], alignment))

            # The span ends the sequence and its last token is a DET:
            # keep the last operation on its own.
            if (end == len(seq) - 1 and
                    ((ops[-1] in {"D", "S"} and o[-1].pos == "DET") or
                     (ops[-1] in {"I", "S"} and c[-1].pos == "DET"))):
                return process_seq(seq[:-1], alignment) + [seq[-1]]

    return seq
|
|
|
|
|
def is_punct(token) -> bool:
    """Return True when `token.text` is a non-empty substring of `string.punctuation`."""
    text = token.text
    # The empty string is contained in every string, so it must be excluded
    # explicitly; otherwise an empty token text would count as punctuation.
    return bool(text) and text in punctuation
|
|
|
|
|
def char_cost(a: str, b: str) -> float:
    """Return the character-level similarity of `a` and `b`.

    The value is whatever `Levenshtein.ratio` reports for the two strings.
    """
    similarity = Levenshtein.ratio(a, b)
    return similarity
|
|
|
|
|
def merge_edits(seq: list[tuple]) -> list[tuple]:
    """Collapse an alignment sequence into one "X" operation spanning all of it.

    The merged tuple takes its start offsets from the first element and its
    end offsets from the last one. An empty sequence is returned unchanged.
    """
    if not seq:
        return seq
    first, last = seq[0], seq[-1]
    return [("X", first[1], last[2], first[3], last[4])]
|
|