"""Compare diacritics from base/ctxt systems with partial diacritization metrics."""

from typing import NamedTuple
from argparse import ArgumentParser
import logging

import numpy as np
import torch as T
from torch.nn import functional as F

import diac_utils as du

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)


def logln(*texts: str):
    print(*texts)


class PartialDiacMetrics(NamedTuple):
    diff_total: float
    worse_total: float
    diff_relative: float
    der_total: float
    selectivity: float
    hidden_der: float
    partial_der: float
    reader_error: float


def load_data(path: str):
    """Load a plain-text file (one sample per line) or a torch dump."""
    if path.endswith('.txt'):
        with open(path, 'r', encoding='utf-8') as fin:
            return fin.readlines()
    else:
        return T.load(path)


def parse_data(
    data,
    logits: bool = False,
    side=None,
):
    """Extract (diac_pred, diac_gt[, diac_logits]) from a model dump,
    a list of text lines, or a raw array of diacritic ids."""
    if logits:
        ld = data['line_data']
        diac_logits = T.tensor(ld[f'diac_logits_{side}'])
        diac_pred: T.Tensor = diac_logits.argmax(dim=-1)
        diac_gt: T.Tensor = ld['diac_gt']
        return diac_pred, diac_gt, diac_logits

    if isinstance(data, dict):
        ld = data.get('line_data_fix', data['line_data'])
        if side is None:
            diac_pred: T.Tensor = ld['diac_pred']
        else:
            diac_pred: T.Tensor = ld[f'diac_logits_{side}'].argmax(axis=-1)
        diac_gt: T.Tensor = ld['diac_gt']
        return diac_pred, diac_gt
    elif isinstance(data, list):
        # Plain text: map each line to its per-character diacritic ids,
        # then pad to a rectangular array.
        data_indices = [
            du.diac_ids_of_line(du.strip_tatweel(du.normalize_spaces(line)))
            for line in data
        ]
        max_len = max(map(len, data_indices))
        out = np.full((len(data), max_len), fill_value=du.DIAC_PAD_IDX)
        for i_line, line_indices in enumerate(data_indices):
            out[i_line][:len(line_indices)] = line_indices
        return out, None
    elif isinstance(data, (T.Tensor, np.ndarray)):
        return data, None
    else:
        raise NotImplementedError


def make_mask_hard(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
):
    # Mark the characters where the contextual and base predictions disagree.
    selection = (pred_c != pred_m)
    return selection


def make_mask_logits(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    threshold: float = 0.1,
    version: str = '2',
) -> T.BoolTensor:
    logger.warning(f"{version=}, {threshold=}")
    pred_c = T.softmax(T.tensor(pred_c), dim=-1)
    pred_m = T.softmax(T.tensor(pred_m), dim=-1)

    if version == 'hard':
        # Plain disagreement between the argmax predictions.
        selection = pred_c.argmax(-1) != pred_m.argmax(-1)
    elif version == '0':
        # ctxt is more confident than base, and base clears the threshold.
        selection = pred_c.max(dim=-1).values > pred_m.max(dim=-1).values
        selection = selection & (pred_m.max(dim=-1).values > threshold)
    elif version == '1':
        # Confidence gap between the two models exceeds the threshold.
        pred_c_conf = pred_c.max(dim=-1).values
        pred_m_conf = pred_m.max(dim=-1).values
        selection = (pred_c_conf - pred_m_conf) > threshold
    elif version == '1.1':
        pred_c_conf = pred_c.max(dim=-1).values
        pred_m_conf = pred_m.max(dim=-1).values
        selection = (pred_c_conf - pred_m_conf).abs() > threshold
    elif version.startswith('2'):
        # Probability margin at the class one of the models considers most likely.
        if version == '2':
            max_c = pred_c.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c) > threshold
        elif version == '2.1':
            max_c = pred_m.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c) > threshold
        elif version == '2.abs':
            max_c = pred_c.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c).abs() > threshold
        elif version == '2.1.abs':
            max_c = pred_m.argmax(dim=-1, keepdim=True)
            selection = T.gather(pred_c - pred_m, dim=-1, index=max_c).abs() > threshold
        else:
            raise ValueError(f"Unknown mask version: {version!r}")
    elif version == '3':
        selection = (pred_c - pred_m).max(dim=-1).values > threshold
    elif version == '4':
        # Hard disagreement gated by the probability margin.
        selection_hard = (pred_c.argmax(-1) != pred_m.argmax(-1))
        selection_logits = T.gather(pred_c - pred_m, dim=-1, index=pred_c.argmax(-1, keepdim=True)) > threshold
        selection = selection_hard & selection_logits.squeeze()
    else:
        raise ValueError(f"Unknown mask version: {version!r}")

    return selection.squeeze()


def analysis_summary(
    pred_c: T.LongTensor,
    pred_m: T.LongTensor,
    labels: T.LongTensor,
    padding_mask: T.BoolTensor,
    *,
    selection: T.Tensor = None,
    random: bool = False,
    logits: tuple = None,
):
    """Compute partial-diacritization metrics for a ctxt/base model pair.

    `selection` marks the characters whose diacritics are shown to the reader
    (taken from the ctxt model); the rest fall back to the base model.
    """
    padding_mask = T.tensor(padding_mask)
    nonpad_mask = ~padding_mask
    num_chars = nonpad_mask.sum()

    if logits is not None:
        # Ensemble: average the ctxt and base probabilities, then re-predict.
        logits = tuple(map(T.tensor, logits))
        pred_c = (T.softmax(logits[0], dim=-1) + T.softmax(logits[1], dim=-1)).argmax(-1)

    pred_c = T.tensor(pred_c)[nonpad_mask]
    pred_m = T.tensor(pred_m)[nonpad_mask]
    labels = T.tensor(labels)[nonpad_mask]

    ctxt_match = (pred_c == labels).float()
    base_match = (pred_m == labels).float()

    selection = T.tensor(selection)[nonpad_mask]
    if random:
        # Control condition: mark a random subset of the same expected size.
        selection = pred_c.new_empty(pred_c.shape).bernoulli_(p=selection.float().mean()).to(bool)
    unselected = ~selection

    assert num_chars > 0
    assert selection.sum() > 0
    base_accuracy = base_match[unselected].sum() / unselected.sum()
    ctxt_accuracy = ctxt_match[selection].sum() / selection.sum()
    correct_total = ctxt_match.sum() / num_chars
    der_total = 1 - correct_total

    cmp = (ctxt_match - base_match)[selection]
    diff = T.sum(cmp)
    diff_total = diff / num_chars
    diff_relative = diff / selection.sum()

    selectivity = selection.sum() / num_chars
    worse_total = base_match[selection].sum() / num_chars

    hidden_der = 1.0 - base_accuracy
    partial_der = 1.0 - ctxt_accuracy
    # Error seen by a reader: ctxt predictions on selected characters,
    # base predictions on the rest.
    reader_error = selectivity * partial_der + (1 - selectivity) * hidden_der

    return PartialDiacMetrics(
        diff_total=round(diff_total.item() * 100, 2),
        worse_total=round(worse_total.item() * 100, 2),
        diff_relative=round(diff_relative.item() * 100, 2),
        der_total=round(der_total.item() * 100, 2),
        selectivity=round(selectivity.item() * 100, 2),
        hidden_der=round(hidden_der.item() * 100, 2),
        partial_der=round(partial_der.item() * 100, 2),
        reader_error=round(reader_error.item() * 100, 2),
    )


def relative_improvement_soft(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    labels: T.LongTensor,
    padding_mask: T.Tensor,
):
    """Soft variant: compare the probability mass each model assigns to the
    ground-truth diacritic. `pred_c`/`pred_m` are probability distributions."""
    padding_mask = T.tensor(padding_mask)
    nonpad_mask = 1 - padding_mask.float()
    num_chars = nonpad_mask.sum()

    pred_c = T.tensor(pred_c)[~padding_mask]
    pred_m = T.tensor(pred_m)[~padding_mask]
    labels = T.tensor(labels)[~padding_mask]

    # Probability assigned to the ground-truth class by each model.
    ctxt_match = T.gather(pred_c, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    base_match = T.gather(pred_m, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    selection = (pred_c.argmax(-1) != pred_m.argmax(-1))

    better = T.sum(ctxt_match - base_match) / num_chars
    selectivity = selection.sum() / num_chars
    worse = base_match[selection].sum() / num_chars
    return better, worse, selectivity


def relative_improvement_masked_soft(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    ground_truth: T.LongTensor,
    padding_mask: T.Tensor,
):
    raise NotImplementedError

    # Unreachable draft: mean probability the ctxt model assigns to the
    # ground truth over the characters where the two models disagree.
    nonpad_mask = ~padding_mask
    selection_mask = pred_c.argmax(3) != pred_m.argmax(3)
    probs = F.softmax(pred_c.clone(), dim=-1)
    probs_gt = T.gather(probs, dim=-1, index=ground_truth.unsqueeze(-1)).squeeze(-1)
    result = probs_gt[selection_mask & nonpad_mask].mean()
    return result


def coverage_confidence(
    pred_c: T.Tensor,
    pred_m: T.Tensor,
    padding_mask: T.Tensor,
):
    raise NotImplementedError

    # Unreachable draft: fraction of non-pad characters where the two
    # models disagree.
    pred_c_id = pred_c.argmax(3)
    pred_m_id = pred_m.argmax(3)
    selected = (pred_c_id != pred_m_id)
    nonpad_mask = ~padding_mask
    result = selected.sum() / nonpad_mask.sum()
    return result


def cli():
    parser = ArgumentParser(
        description='Compare diacritics from base/ctxt systems with partial diac metrics.')
    parser.add_argument('-m', '--model-output-base', help="Path to tensor.pt dump files of base diacs.")
    parser.add_argument('-c', '--model-output-ctxt', help="Path to tensor.pt dump files of ctxt diacs.")
    parser.add_argument('--gt', default=None, help="Path to tensor.pt for gt only.")
    parser.add_argument('--mode', choices=['hard', 'logits'], default='hard')
    args = parser.parse_args()

    model_output_base = parse_data(
        load_data(args.model_output_base),
        logits=True,
        side='base',
    )
    model_output_ctxt = parse_data(
        load_data(args.model_output_ctxt),
        logits=True,
        side='ctxt',
    )

    logln(f"{model_output_base[0].shape=} , {model_output_ctxt[0].shape=}")
    assert len(model_output_base[0]) == len(model_output_ctxt[0])

    xc = model_output_ctxt
    xm = model_output_base

    # Ground truth comes from whichever dump carries it.
    ground_truth = xm[1] if xm[1] is not None else xc[1]
    assert ground_truth is not None

    if args.mode == 'hard':
        selection = make_mask_hard(xc[0], xm[0])
    elif args.mode == 'logits':
        selection = make_mask_logits(xc[2], xm[2])

    metrics = analysis_summary(
        xc[0], xm[0], ground_truth, ground_truth == -1,
        selection=selection,
        logits=(xc[2], xm[2]),
    )
    logln("Actual Totals:", metrics)
    metrics = analysis_summary(
        xc[0], xm[0], ground_truth, ground_truth == -1, random=True,
        selection=selection,
        logits=(xc[2], xm[2]),
    )
    logln("Random Marked Chars:", metrics)