from apted import APTED, Config
from apted.helpers import Tree
from lxml import etree, html
from rapidfuzz.distance import Levenshtein
from tqdm import tqdm

from .parallel import parallel_process


class TableTree(Tree):
    """Tree node (tag, colspan/rowspan, cell content) passed to APTED."""

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        self.children = list(children)

    def bracket(self):
        """Show tree using brackets notation"""
        if self.tag == 'td':
            result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
                     (self.tag, self.colspan, self.rowspan, self.content)
        else:
            result = '"tag": %s' % self.tag
        for child in self.children:
            result += child.bracket()
        return "{{{}}}".format(result)
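
    # Illustrative example: TableTree('td', 1, 1, ['a']).bracket() returns
    # '{"tag": td, "colspan": 1, "rowspan": 1, "text": ['a']}', with any
    # children's bracket strings appended before the closing brace.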


class CustomConfig(Config):
    def rename(self, node1, node2):
        """Relabel cost: 1 if tag, colspan or rowspan differ, otherwise the
        normalized Levenshtein distance between the cell contents."""
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                return Levenshtein.normalized_distance(node1.content, node2.content)
        return 0.


class CustomConfig_del_short(Config):
    def rename(self, node1, node2):
        """Same as CustomConfig, but cell contents with fewer than 3 tokens are
        replaced by a ['####'] sentinel before comparison."""
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                node1_content = node1.content
                node2_content = node2.content
                if len(node1_content) < 3:
                    node1_content = ['####']
                if len(node2_content) < 3:
                    node2_content = ['####']
                return Levenshtein.normalized_distance(node1_content, node2_content)
        return 0.


class CustomConfig_del_block(Config):
    def rename(self, node1, node2):
        """Same as CustomConfig, but space tokens are removed from both cell
        contents before comparison."""
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                node1_content = [token for token in node1.content if token != ' ']
                node2_content = [token for token in node2.content if token != ' ']
                return Levenshtein.normalized_distance(node1_content, node2_content)
        return 0.
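

# The three Config variants above only customize APTED's relabel (rename) cost
# for <td> cells; insert and delete costs are left to the apted package defaults.
#   * CustomConfig:           normalized Levenshtein distance over raw cell tokens
#   * CustomConfig_del_short: cells with fewer than 3 tokens collapse to ['####'] first
#   * CustomConfig_del_block: space tokens are stripped from both cells first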


class TEDS(object):
    ''' Tree Edit Distance based Similarity
    '''

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        self.__tokens__ = []

    def tokenize(self, node):
        ''' Tokenizes table cells
        '''
        self.__tokens__.append('<%s>' % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        for n in node.getchildren():
            self.tokenize(n)
        if node.tag != 'unk':
            self.__tokens__.append('</%s>' % node.tag)
        if node.tag != 'td' and node.tail is not None:
            self.__tokens__ += list(node.tail)
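
    # Illustrative example: for a cell parsed from '<td>a<b>c</b></td>',
    # tokenize() yields ['<td>', 'a', '<b>', 'c', '</b>', '</td>'];
    # load_html_tree() below keeps the slice [1:-1] as the cell content.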

    def load_html_tree(self, node, parent=None):
        ''' Converts HTML tree to the format required by apted
        '''
        if node.tag == 'td':
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                cell = self.__tokens__[1:-1].copy()
            new_node = TableTree(node.tag,
                                 int(node.attrib.get('colspan', '1')),
                                 int(node.attrib.get('rowspan', '1')),
                                 cell)
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != 'td':
            for n in node.getchildren():
                self.load_html_tree(n, new_node)
        if parent is None:
            return new_node
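
    # Illustrative example: for '<table><tr><td>a</td></tr></table>',
    # load_html_tree(table_element) returns
    #   TableTree('table') -> TableTree('tr') -> TableTree('td', 1, 1, ['a']).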

    def evaluate(self, pred, true):
        ''' Computes TEDS score between the prediction and the ground truth of a
            given sample
        '''
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        if pred.xpath('body/table') and true.xpath('body/table'):
            pred = pred.xpath('body/table')[0]
            true = true.xpath('body/table')[0]
            if self.ignore_nodes:
                etree.strip_tags(pred, *self.ignore_nodes)
                etree.strip_tags(true, *self.ignore_nodes)
            n_nodes_pred = len(pred.xpath(".//*"))
            n_nodes_true = len(true.xpath(".//*"))
            n_nodes = max(n_nodes_pred, n_nodes_true)
            tree_pred = self.load_html_tree(pred)
            tree_true = self.load_html_tree(true)
            distance = APTED(tree_pred, tree_true,
                             CustomConfig()).compute_edit_distance()
            return 1.0 - (float(distance) / n_nodes)
        else:
            return 0.0
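
    # Minimal single-sample usage sketch (the HTML string is illustrative):
    #   teds = TEDS()
    #   doc = '<html><body><table><tr><td>1</td></tr></table></body></html>'
    #   teds.evaluate(doc, doc)  # identical inputs -> 1.0
    # Inputs must be full documents containing <body><table>; otherwise the
    # method falls back to 0.0.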

    def batch_evaluate(self, pred_json, true_json):
        ''' Computes TEDS score between the prediction and the ground truth of
            a batch of samples
            @params pred_json: {'FILENAME': 'HTML CODE', ...}
            @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
            @output: {'FILENAME': TEDS_SCORE (float), ...}
        '''
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_json.get(filename, ''),
                                    true_json[filename]['html'])
                      for filename in tqdm(samples)]
        else:
            inputs = [{'pred': pred_json.get(filename, ''),
                       'true': true_json[filename]['html']}
                      for filename in samples]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        scores = dict(zip(samples, scores))
        return scores

    def batch_evaluate_html(self, pred_htmls, true_htmls):
        ''' Computes TEDS scores for a batch of samples given directly as
            parallel lists of prediction and ground-truth HTML strings
        '''
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_html, true_html)
                      for (pred_html, true_html) in zip(pred_htmls, true_htmls)]
        else:
            inputs = [{"pred": pred_html, "true": true_html}
                      for (pred_html, true_html) in zip(pred_htmls, true_htmls)]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        return scores


if __name__ == '__main__':
    import json
    import pprint

    with open('sample_pred.json') as fp:
        pred_json = json.load(fp)
    with open('sample_gt.json') as fp:
        true_json = json.load(fp)
    teds = TEDS(n_jobs=4)
    scores = teds.batch_evaluate(pred_json, true_json)
    pp = pprint.PrettyPrinter()
    pp.pprint(scores)
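
    # Structure-only variant (sketch; reuses the JSON files loaded above):
    #   teds_struct = TEDS(structure_only=True, n_jobs=4)
    #   pp.pprint(teds_struct.batch_evaluate(pred_json, true_json))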