pcback committed on
Commit
f4c516d
1 Parent(s): 2468c64

First commit

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.traineddata filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ import PIL
4
+ import numpy as np
5
+ import torch
6
+ from collections import defaultdict
7
+
8
+ import cv2
9
+ from doctr.io import DocumentFile
10
+ from doctr.models import ocr_predictor
11
+ from doctr.utils.visualization import visualize_page
12
+
13
+ import pytesseract
14
+ from pytesseract import Output
15
+
16
+ from bs4 import BeautifulSoup as bs
17
+
18
+ import sys, json
19
+
20
+ import postprocess
21
+
22
+
23
# Both models are loaded once at import time and reused by the helpers below.
# NOTE: this rebinds the imported `ocr_predictor` factory to the predictor
# instance it returns; `ocr()` relies on that rebinding.
ocr_predictor = ocr_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True)
structure_model = torch.hub.load(
    'ultralytics/yolov5',
    'custom',
    'weights/structure_wts.pt',
    force_reload=True,
)
# Inference size passed to `structure_model` by `table_structure()`.
imgsz = 640

# Class id -> class name (list position is the id used by the detector output).
structure_class_names = [
    'table',
    'table column',
    'table row',
    'table column header',
    'table projected row header',
    'table spanning cell',
    'no object',
]
# Class name -> class id.
structure_class_map = {name: index for index, name in enumerate(structure_class_names)}
# Minimum score per class; "no object" uses 10, a value the score filters in
# postprocess never reach for a probability-like score.
structure_class_thresholds = {
    "table": 0.5,
    "table column": 0.5,
    "table row": 0.5,
    "table column header": 0.25,
    "table projected row header": 0.25,
    "table spanning cell": 0.25,
    "no object": 10,
}
41
+
42
+
43
def table_structure(filename):
    """
    Run the table-structure detector on an image file.

    The image is read with OpenCV and passed to `structure_model` at size
    `imgsz`. Returns the first image's `xywhn` predictions as a NumPy array on
    the CPU; `convert_stucture` reads each row as
    (center x, center y, width, height, score, class id).
    """
    detections = structure_model(cv2.imread(filename), size=imgsz)
    return detections.xywhn[0].cpu().numpy()
49
+
50
+
51
def ocr(filename):
    """
    OCR an image file with the docTR predictor and return its words.

    Only the first exported page is used. Each returned item is a dict with
    'bbox' ([x1, y1, x2, y2] in integer pixels, obtained by scaling the
    relative word geometry by the page dimensions) and 'text'.
    """
    document = DocumentFile.from_images(filename)
    page = ocr_predictor(document).export()['pages'][0]
    page_height, page_width = page['dimensions']

    words = (
        word
        for block in page['blocks']
        for line in block['lines']
        for word in line['words']
    )

    tokens = []
    for word in words:
        geometry = word['geometry']
        first_corner, second_corner = geometry[0], geometry[1]
        tokens.append({
            'bbox': [
                int(first_corner[0] * page_width),
                int(first_corner[1] * page_height),
                int(second_corner[0] * page_width),
                int(second_corner[1] * page_height),
            ],
            'text': word['value'],
        })
    return tokens
67
+
68
+
69
def convert_stucture(page_tokens, filename, structure_result):
    """
    Turn raw structure detections plus OCR tokens into table cells.

    page_tokens: word dicts with a pixel 'bbox' and 'text' (see `ocr`).
    filename: path of the page image; only its width/height are used here.
    structure_result: iterable of rows read as
        (center x, center y, width, height, score, class id), with the box
        values relative to the image size.

    Returns the (table_structures, cells, confidence_score) tuple produced by
    `postprocess.objects_to_cells`.
    """
    image = cv2.imread(filename)
    height, width = image.shape[0], image.shape[1]

    # Convert each relative center/size box into an absolute corner box.
    table_objects = []
    for detection in structure_result:
        center_x, center_y = detection[0], detection[1]
        box_w, box_h = detection[2], detection[3]
        bbox = [
            int((center_x - box_w / 2) * width),
            int((center_y - box_h / 2) * height),
            int((center_x + box_w / 2) * width),
            int((center_y + box_h / 2) * height),
        ]
        table_objects.append({
            'bbox': bbox,
            'score': float(detection[4]),
            'label': int(detection[5]),
        })

    table = {'objects': table_objects, 'page_num': 0}

    # Use the highest-scoring 'table' detection as the token filter region.
    table_candidates = [
        obj for obj in table_objects
        if obj['label'] == structure_class_map['table']
    ]
    if table_candidates:
        best_table = max(table_candidates, key=lambda obj: obj['score'])
        table_bbox = list(best_table['bbox'])
    else:
        # No table box was detected: fall back to the whole image so that no
        # token is dropped because of an arbitrary fixed-size region.
        table_bbox = [0, 0, width, height]

    tokens_in_table = [
        token for token in page_tokens
        if postprocess.iob(token['bbox'], table_bbox) >= 0.5
    ]

    # `structure_class_names` is passed where postprocess expects the
    # id -> name lookup it indexes with each object's 'label'.
    table_structures, cells, confidence_score = postprocess.objects_to_cells(
        table,
        table_objects,
        tokens_in_table,
        structure_class_names,
        structure_class_thresholds,
    )

    return table_structures, cells, confidence_score
119
+
120
+
121
def visualize_cells(filename, cells, ax):
    """
    Draw every cell's bounding box on the page image and display it.

    filename: path of the page image.
    cells: iterable of dicts whose 'bbox' holds (x1, y1, x2, y2).
    ax: object exposing an `image(array)` method (a Streamlit column in
        `main`).
    """
    image = cv2.imread(filename)
    for cell in cells:
        bbox = cell['bbox']
        top_left = (int(bbox[0]), int(bbox[1]))
        bottom_right = (int(bbox[2]), int(bbox[3]))
        cv2.rectangle(image, top_left, bottom_right, color=(0, 255, 0))
    # cv2.imread yields BGR channel order; convert so the displayed page does
    # not have its red and blue channels swapped.
    ax.image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
131
+
132
+
133
def pytess(cell_pil_img):
    """
    OCR one cell image with Tesseract and return its text as a single string.

    The recognized entries of the 'text' field are joined with spaces and the
    result is stripped.
    """
    tess_config = '-c tessedit_char_blacklist=œ˜â€œï¬â™Ã©œ¢!|”?«“¥ --tessdata-dir tessdata --oem 3 --psm 6'
    data = pytesseract.image_to_data(cell_pil_img, output_type=Output.DICT, config=tess_config)
    return ' '.join(data['text']).strip()
135
+
136
+
137
def resize(pil_img, size=1800):
    """
    Upscale an image so that its width is at least `size` pixels.

    The scale factor is never below 1, so images already at least `size`
    wide keep their dimensions. Returns (resized image, scale factor).
    """
    width, height = pil_img.size
    factor = max(1, size / width)
    target_size = int(factor * width), int(factor * height)
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    pil_img = pil_img.resize(target_size, PIL.Image.LANCZOS)
    return pil_img, factor
143
+
144
+
145
def image_smoothening(img):
    """
    Binarize an image: fixed threshold at 180, Otsu threshold, a (1, 1)
    Gaussian blur, then a second Otsu threshold. Returns the final image.
    """
    otsu_binary = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    _, fixed = cv2.threshold(img, 180, 255, cv2.THRESH_BINARY)
    _, first_otsu = cv2.threshold(fixed, 0, 255, otsu_binary)
    blurred = cv2.GaussianBlur(first_otsu, (1, 1), 0)
    _, second_otsu = cv2.threshold(blurred, 0, 255, otsu_binary)
    return second_otsu
151
+
152
+
153
def remove_noise_and_smooth(pil_img):
    """
    Produce a cleaned black-and-white version of an RGB PIL image.

    Combines (bitwise OR) the output of `image_smoothening` on the grayscale
    image with an adaptive-threshold version that went through morphological
    opening and closing. Returns a new PIL image.
    """
    gray = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
    adaptive = cv2.adaptiveThreshold(
        gray.astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 41, 3
    )
    kernel = np.ones((1, 1), np.uint8)
    opened = cv2.morphologyEx(adaptive, cv2.MORPH_OPEN, kernel)
    closed = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, kernel)
    smoothed = image_smoothening(gray)
    return PIL.Image.fromarray(cv2.bitwise_or(smoothed, closed))
163
+
164
+
165
def extract_text_from_cells(filename, cells):
    """
    OCR every cell region of the page image and store the text on the cell.

    filename: path of the page image.
    cells: list of dicts with a pixel 'bbox'; each gets a 'text' entry.

    Returns the same `cells` list (modified in place).
    """
    # Close the source file as soon as the resized copy exists instead of
    # leaving the handle open for the lifetime of the image object.
    with PIL.Image.open(filename) as source_img:
        pil_img, factor = resize(source_img)

    for cell in cells:
        # Cell boxes are in original-image pixels; scale them to the resized image.
        scaled_bbox = [coord * factor for coord in cell['bbox']]
        cell['text'] = pytess(pil_img.crop(scaled_bbox))
    return cells
177
+
178
+
179
def cells_to_html(cells):
    """
    Render table cells as a prettified standalone HTML document.

    cells: list of dicts with 'row_nums', 'column_nums' (lists of the grid
        indices the cell covers) and 'text'.

    A cell is emitted in the row of its first row index; rowspan/colspan are
    derived from its first and last indices. An empty `cells` list yields a
    document with an empty table.
    """
    from html import escape  # local import keeps this block self-contained

    # default=-1 makes an empty cell list produce zero rows instead of raising.
    n_rows = max((cell['row_nums'][-1] for cell in cells), default=-1) + 1

    rows_html = []
    for r in range(n_rows):
        r_cells = [cell for cell in cells if cell['row_nums'][0] == r]
        r_cells.sort(key=lambda x: x['column_nums'][0])
        cells_html = []
        for cell in r_cells:
            rowspan = cell['row_nums'][-1] - cell['row_nums'][0] + 1
            colspan = cell['column_nums'][-1] - cell['column_nums'][0] + 1
            # OCR output is arbitrary text: escape it so characters such as
            # '<' or '&' cannot break or inject markup.
            text = escape(str(cell["text"]))
            cells_html.append(f'<td rowspan="{rowspan}" colspan="{colspan}">{text}</td>')
        rows_html.append(f'<tr>{"".join(cells_html)}</tr>')
    html_code = ''.join(rows_html)

    html_code = '''<html>
<head>
<meta charset="UTF-8">
<style>
table, th, td {
border: 1px solid black;
font-size: 10px;
}
</style>
</head>
<body>
<table frame="hsides" rules="groups" width="100%%">
%s
</table>
</body>
</html>''' % html_code
    soup = bs(html_code)
    html_code = soup.prettify()
    return html_code
211
+
212
+
213
def main():
    """
    Streamlit page: upload an image, show it, the detected cell grid and the
    reconstructed HTML table side by side.

    NOTE(review): nothing in the visible part of this file calls `main()`;
    confirm the app entry point invokes it.
    """
    import os
    import tempfile

    st.set_page_config(layout="wide")
    st.title("Table Structure Recognition Demo")
    st.write('\n')

    # `beta_columns` was the pre-1.0 name of `columns`; use whichever exists.
    make_columns = getattr(st, 'columns', None) or st.beta_columns
    cols = make_columns((1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Structure output")
    cols[2].subheader("HTML output")

    st.sidebar.title("Image upload")
    st.set_option('deprecation.showfileUploaderEncoding', False)
    uploaded = st.sidebar.file_uploader("Upload files", type=['png', 'jpeg', 'jpg'])

    # The uploader returns None until the user has picked a file.
    if uploaded is None:
        return

    # The helpers below read the image from a filesystem path, while the
    # uploader hands back an in-memory file object, so persist it first.
    suffix = os.path.splitext(uploaded.name)[1] or '.png'
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded.getvalue())
        filename = tmp.name

    try:
        # cv2.imread yields BGR data; tell Streamlit so colors are not swapped.
        cols[0].image(cv2.imread(filename), channels="BGR")

        ocr_res = ocr(filename)
        structure_result = table_structure(filename)
        table_structures, cells, confidence_score = convert_stucture(ocr_res, filename, structure_result)
        visualize_cells(filename, cells, cols[1])

        cells = extract_text_from_cells(filename, cells)
        html_code = cells_to_html(cells)

        cols[2].html(html_code)
    finally:
        os.remove(filename)
239
+
packages.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ ffmpeg
2
+ libsm6
3
+ libxext6
4
+ libgl1
5
+ tesseract-ocr-eng
6
+ python3-opencv
postprocess.py ADDED
@@ -0,0 +1,895 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (C) 2021 Microsoft Corporation
3
+ """
4
+ from collections import defaultdict
5
+
6
+ from fitz import Rect
7
+
8
+
9
def apply_threshold(objects, threshold):
    """
    Keep only the objects whose 'score' is at least `threshold`.

    Returns a new list; the input order is preserved.
    """
    kept = []
    for candidate in objects:
        if candidate['score'] >= threshold:
            kept.append(candidate)
    return kept
14
+
15
+
16
def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds):
    """
    Drop detections scoring below the threshold of their own class.

    The threshold for a detection is `class_thresholds[class_names[label]]`.
    Returns the surviving (bboxes, scores, labels) as three new lists, in the
    original order.
    """
    kept_bboxes, kept_scores, kept_labels = [], [], []
    for position, (score, label) in enumerate(zip(scores, labels)):
        if score >= class_thresholds[class_names[label]]:
            kept_bboxes.append(bboxes[position])
            kept_scores.append(scores[position])
            kept_labels.append(labels[position])

    return kept_bboxes, kept_scores, kept_labels
32
+
33
+
34
def iou(bbox1, bbox2):
    """
    Compute the intersection-over-union of two bounding boxes.

    The denominator is the area of the rectangle obtained by extending bbox1
    with `include_rect(bbox2)`. Returns 0 when that area is not positive.
    """
    # Each operation starts from its own Rect so neither affects the other.
    overlap = Rect(bbox1).intersect(bbox2)
    combined = Rect(bbox1).include_rect(bbox2)

    combined_area = combined.get_area()
    if not combined_area > 0:
        return 0

    return overlap.get_area() / combined_area
46
+
47
+
48
def iob(bbox1, bbox2):
    """
    Compute the fraction of bbox1's area that is covered by bbox2
    (intersection area over bbox1 area). Returns 0 when bbox1 has no
    positive area.
    """
    own_area = Rect(bbox1).get_area()
    overlap = Rect(bbox1).intersect(bbox2)

    if not own_area > 0:
        return 0

    return overlap.get_area() / own_area
59
+
60
+
61
def objects_to_cells(table, objects_in_table, tokens_in_table, class_map, class_thresholds):
    """
    Convert the structure model's boxes and the token/word/span boxes into
    table cells.

    `class_map` is forwarded to `objects_to_table_structures`, which indexes
    it with each object's 'label'.

    Returns (table_structures, cells, confidence_score). When no column or no
    row survives refinement, `cells` is an empty list and the score is 0.
    """
    table_structures = objects_to_table_structures(
        table, objects_in_table, tokens_in_table, class_map, class_thresholds
    )

    # A usable table needs at least one column and one row.
    has_grid = len(table_structures['columns']) >= 1 and len(table_structures['rows']) >= 1
    if has_grid:
        cells, confidence_score = table_structure_to_cells(
            table_structures, tokens_in_table, table['bbox']
        )
    else:
        cells, confidence_score = [], 0

    return table_structures, cells, confidence_score
81
+
82
+
83
def objects_to_table_structures(table_object, objects_in_table, tokens_in_table, class_names, class_thresholds):
    """
    Turn the structure model's boxes into a consistent set of table
    structures (rows, columns, supercells, headers).

    Side effects: detected objects receive 'subheader', 'header' and 'page'
    entries, and `table_object` receives 'row_column_bbox' and 'bbox' (both
    bound to the same list: the columns' horizontal extent combined with the
    rows' vertical extent).

    Returns a dict with 'rows', 'columns', 'headers' and 'supercells'.
    """

    def of_class(class_name):
        # Objects whose label maps to the given class name, in input order.
        return [obj for obj in objects_in_table if class_names[obj['label']] == class_name]

    page_num = table_object['page_num']

    columns = of_class('table column')
    rows = of_class('table row')
    headers = of_class('table column header')
    supercells = of_class('table spanning cell')
    subheaders = of_class('table projected row header')

    # Projected row headers are carried along with the spanning cells and
    # told apart by the 'subheader' flag.
    for spanning_cell in supercells:
        spanning_cell['subheader'] = False
    for projected in subheaders:
        projected['subheader'] = True
    supercells += subheaders

    # A row is a header row when at least half of it lies inside a header box.
    for row in rows:
        row['header'] = any(iob(row['bbox'], header['bbox']) >= 0.5 for header in headers)

    for row in rows:
        row['page'] = page_num
    for column in columns:
        column['page'] = page_num

    # Refine table structures
    rows = refine_rows(rows, tokens_in_table, class_thresholds['table row'])
    columns = refine_columns(columns, tokens_in_table, class_thresholds['table column'])

    # Shrink the table bbox to the total height of the rows and the total
    # width of the columns.
    row_extent = Rect()
    for row in rows:
        row_extent.include_rect(row['bbox'])
    column_extent = Rect()
    for column in columns:
        column_extent.include_rect(column['bbox'])
    grid_bbox = [column_extent[0], row_extent[1], column_extent[2], row_extent[3]]
    table_object['row_column_bbox'] = grid_bbox
    table_object['bbox'] = grid_bbox

    # Process the rows and columns into a complete segmented table
    columns = align_columns(columns, table_object['row_column_bbox'])
    rows = align_rows(rows, table_object['row_column_bbox'])

    table_structures = {
        'rows': rows,
        'columns': columns,
        'headers': headers,
        'supercells': supercells,
    }

    if len(rows) > 0 and len(columns) > 1:
        table_structures = refine_table_structures(
            table_object['bbox'], table_structures, tokens_in_table, class_thresholds
        )

    return table_structures
145
+
146
+
147
def refine_rows(rows, page_spans, score_threshold):
    """
    Filter, de-duplicate and order the detected rows.

    Rows are kept when they reach `score_threshold` or are flagged as header
    rows, then passed through containment-based NMS, and finally sorted top
    to bottom when more than one remains.
    """
    candidates = [row for row in rows if row['score'] >= score_threshold or row['header']]
    survivors = nms_by_containment(candidates, page_spans, overlap_threshold=0.5)
    # remove_objects_without_content(page_spans, rows) # TODO
    if len(survivors) > 1:
        return sort_objects_top_to_bottom(survivors)
    return survivors
163
+
164
+
165
def refine_columns(columns, page_spans, score_threshold):
    """
    Filter, de-duplicate and order the detected columns.

    Columns below `score_threshold` are dropped, the rest go through
    containment-based NMS and are sorted left to right when more than one
    remains.
    """
    candidates = [column for column in columns if column['score'] >= score_threshold]
    survivors = nms_by_containment(candidates, page_spans, overlap_threshold=0.5)
    # remove_objects_without_content(page_spans, columns) # TODO
    if len(survivors) > 1:
        return sort_objects_left_to_right(survivors)
    return survivors
181
+
182
+
183
def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5):
    """
    Non-maxima suppression (NMS) based on which packages each container holds.

    Containers are ranked by score and every package is slotted into a single
    container. A container other than the top-ranked one is suppressed when it
    holds no package, or when it shares a package with a higher-ranked
    container that is not itself suppressed.

    Returns the surviving containers in score order.
    """
    ranked = sort_objects_by_score(container_objects)
    contents, _, _ = slot_into_containers(
        ranked, package_objects, overlap_threshold=overlap_threshold,
        unique_assignment=True, forced_assignment=False,
    )

    suppressed = [False] * len(ranked)
    for lower in range(1, len(ranked)):
        lower_packages = set(contents[lower])
        if not lower_packages:
            suppressed[lower] = True
        for higher in range(lower):
            if suppressed[higher]:
                continue
            if lower_packages & set(contents[higher]):
                suppressed[lower] = True

    return [obj for obj, dropped in zip(ranked, suppressed) if not dropped]
206
+
207
+
208
def slot_into_containers(container_objects, package_objects, overlap_threshold=0.5,
                         unique_assignment=True, forced_assignment=False):
    """
    Slot a collection of objects into the container they occupy most (the
    container which holds the largest fraction of the object).

    container_objects / package_objects: sequences of dicts with a 'bbox'.
    overlap_threshold: minimum fraction of a package's area that must fall
        inside a container for an assignment.
    unique_assignment: when False, a package is also added to every further
        container reaching the threshold.
    forced_assignment: when True, the best container is used even below the
        threshold.

    Returns (container_assignments, package_assignments, best_match_scores):
    per-container lists of package indices, per-package lists of container
    indices, and the best overlap fraction of each scored package. Packages
    with no positive area are left unassigned and contribute no score.
    """
    best_match_scores = []

    container_assignments = [[] for _ in container_objects]
    package_assignments = [[] for _ in package_objects]

    if len(container_objects) == 0 or len(package_objects) == 0:
        return container_assignments, package_assignments, best_match_scores

    for package_num, package in enumerate(package_objects):
        package_area = Rect(package['bbox']).get_area()
        if not package_area > 0:
            # A degenerate box has no overlap fraction; skipping it avoids a
            # division by zero.
            continue

        match_scores = []
        for container_num, container in enumerate(container_objects):
            container_rect = Rect(container['bbox'])
            intersect_area = container_rect.intersect(package['bbox']).get_area()
            match_scores.append({
                'container': container,
                'container_num': container_num,
                'score': intersect_area / package_area,
            })

        sorted_match_scores = sort_objects_by_score(match_scores)

        best_match_score = sorted_match_scores[0]
        best_match_scores.append(best_match_score['score'])
        if forced_assignment or best_match_score['score'] >= overlap_threshold:
            container_assignments[best_match_score['container_num']].append(package_num)
            package_assignments[package_num].append(best_match_score['container_num'])

        if not unique_assignment:  # slot package into all eligible slots
            for match_score in sorted_match_scores[1:]:
                if match_score['score'] >= overlap_threshold:
                    container_assignments[match_score['container_num']].append(package_num)
                    package_assignments[package_num].append(match_score['container_num'])
                else:
                    # Scores are sorted high to low, so nothing further qualifies.
                    break

    return container_assignments, package_assignments, best_match_scores
249
+
250
+
251
def sort_objects_by_score(objects, reverse=True):
    """
    Return the objects as a new list ordered by 'score': high to low when
    `reverse` is truthy (the default), low to high otherwise. Objects with
    equal scores keep their input order.
    """
    def score_of(obj):
        return obj['score']

    return sorted(objects, key=score_of, reverse=bool(reverse))
260
+
261
+
262
def remove_objects_without_content(page_spans, objects):
    """
    Remove, in place, every object (row, column, supercell, ...) whose
    bounding box contains no non-blank text.
    """
    # Iterate over a snapshot because `objects` is modified during the loop.
    for candidate in objects[:]:
        text, _ = extract_text_inside_bbox(page_spans, candidate['bbox'])
        if not text.strip():
            objects.remove(candidate)
271
+
272
+
273
def extract_text_inside_bbox(spans, bbox):
    """
    Collect the spans falling inside `bbox` and the text they form.

    Returns (text, spans inside the box).
    """
    inside = get_bbox_span_subset(spans, bbox)
    text = extract_text_from_spans(inside, remove_integer_superscripts=True)
    return text, inside
281
+
282
+
283
def get_bbox_span_subset(spans, bbox, threshold=0.5):
    """
    Return the spans that fall within a bounding box, in input order.

    threshold: the fraction of the span that must overlap with the bbox.
    """
    return [span for span in spans if overlaps(span['bbox'], bbox, threshold)]
294
+
295
+
296
def overlaps(bbox1, bbox2, threshold=0.5):
    """
    Test whether at least `threshold` of bbox1's area overlaps bbox2.

    Returns False when bbox1 has zero area.
    """
    first = Rect(list(bbox1))
    first_area = first.get_area()
    if first_area == 0:
        return False
    shared_area = first.intersect(list(bbox2)).get_area()
    return shared_area / first_area >= threshold
305
+
306
+
307
def extract_text_from_spans(spans, join_with_space=True, remove_integer_superscripts=True):
    """
    Convert a collection of page tokens/words/spans into a single text string.

    Spans are ordered by block, line and span number. Spans of one line are
    joined with `join_char` (a space, or nothing when `join_with_space` is
    False), and lines are joined the same way.

    When `remove_integer_superscripts` is set, spans whose 'flags' have bit 0
    set are dropped if their text is an integer; other such spans are marked
    with span['superscript'] = True.
    """
    join_char = " " if join_with_space else ""
    remaining = spans[:]

    if remove_integer_superscripts:
        for span in spans:
            if not span['flags'] & 2**0:  # superscript flag
                continue
            if is_int(span['text']):
                remaining.remove(span)
            else:
                span['superscript'] = True

    if len(remaining) == 0:
        return ""

    # One sort on the composite key gives block -> line -> span reading order.
    remaining.sort(key=lambda span: (span['block_num'], span['line_num'], span['span_num']))

    # Force the span at the end of every line within a block to have exactly one space
    # unless the line ends with a space or ends with a non-space followed by a hyphen
    finished_lines = []
    current_line = [remaining[0]['text']]
    for previous, current in zip(remaining, remaining[1:]):
        same_line = (previous['block_num'] == current['block_num']
                     and previous['line_num'] == current['line_num'])
        if same_line:
            current_line.append(current['text'])
            continue

        line_text = join_char.join(current_line).strip()
        ends_with_joined_hyphen = (len(line_text) > 1
                                   and line_text[-1] == "-"
                                   and line_text[-2] != ' ')
        needs_separator = (len(line_text) > 0
                           and line_text[-1] != ' '
                           and not ends_with_joined_hyphen)
        if needs_separator and not join_with_space:
            line_text += ' '
        finished_lines.append(line_text)
        current_line = [current['text']]

    finished_lines.append(join_char.join(current_line))

    return join_char.join(finished_lines).strip()
354
+
355
+
356
def sort_objects_left_to_right(objs):
    """
    Return the objects as a new list ordered left to right, using the sum of
    each bbox's left and right x coordinates as the sort key.
    """
    def horizontal_key(obj):
        left, right = obj['bbox'][0], obj['bbox'][2]
        return left + right

    ordered = list(objs)
    ordered.sort(key=horizontal_key)
    return ordered
361
+
362
+
363
def sort_objects_top_to_bottom(objs):
    """
    Return the objects as a new list ordered top to bottom, using the sum of
    each bbox's top and bottom y coordinates as the sort key.
    """
    def vertical_key(obj):
        top, bottom = obj['bbox'][1], obj['bbox'][3]
        return top + bottom

    ordered = list(objs)
    ordered.sort(key=vertical_key)
    return ordered
368
+
369
+
370
def align_columns(columns, bbox):
    """
    Set every column's top and bottom edges to those of the final table
    bounding box (bbox[1] and bbox[3]).

    Columns are modified in place and the same list is returned. Any error is
    printed and otherwise ignored.
    """
    try:
        for col in columns:
            col_bbox = col['bbox']
            col_bbox[1] = bbox[1]
            col_bbox[3] = bbox[3]
    except Exception as err:
        print(f"Could not align columns: {err}")

    return columns
384
+
385
+
386
def align_rows(rows, bbox):
    """
    Set every row's left and right edges to those of the final table
    bounding box (bbox[0] and bbox[2]).

    Rows are modified in place and the same list is returned. Any error is
    printed and otherwise ignored.
    """
    try:
        for current_row in rows:
            row_bbox = current_row['bbox']
            row_bbox[0] = bbox[0]
            row_bbox[2] = bbox[2]
    except Exception as err:
        print(f"Could not align rows: {err}")

    return rows
400
+
401
+
402
def refine_table_structures(table_bbox, table_structures, page_spans, class_thresholds):
    """
    Threshold, de-duplicate and align the headers and supercells of a table.

    Rows and columns are taken as they are. `table_structures` is updated in
    place and returned.
    """
    rows = table_structures["rows"]
    columns = table_structures['columns']

    #columns = fill_column_gaps(columns, table_bbox)
    #rows = fill_row_gaps(rows, table_bbox)

    # Headers go first: the header alignment step rewrites the rows' 'header'
    # flags, which the supercell alignment below reads.
    confident_headers = apply_threshold(table_structures['headers'], class_thresholds["table column header"])
    headers = align_headers(nms(confident_headers), rows)

    # Split the supercells into spanning cells and projected row headers so
    # each group is filtered with its own threshold.
    spanning, projected = [], []
    for elem in table_structures['supercells']:
        if elem['subheader']:
            projected.append(elem)
        else:
            spanning.append(elem)
    supercells = apply_threshold(spanning, class_thresholds["table spanning cell"])
    supercells += apply_threshold(projected, class_thresholds["table projected row header"])

    # Align before NMS for supercells because alignment brings them into agreement
    # with rows and columns first; if supercells still overlap after this operation,
    # the threshold for NMS can basically be lowered to just above 0
    supercells = nms_supercells(align_supercells(supercells, rows, columns))

    header_supercell_tree(supercells)

    table_structures['columns'] = columns
    table_structures['rows'] = rows
    table_structures['supercells'] = supercells
    table_structures['headers'] = headers

    return table_structures
439
+
440
+
441
def nms(objects, match_criteria="object2_overlap", match_threshold=0.05, keep_metric="score", keep_higher=True):
    """
    A customizable version of non-maxima suppression (NMS).

    Default behavior: If a lower-confidence object overlaps more than 5% of its area
    with a higher-confidence object, remove the lower-confidence object.

    objects: set of dicts; each object dict must have a 'bbox' and a 'score' field
    match_criteria: how to measure how much two objects "overlap"
        ("object1_overlap", "object2_overlap" or "iou"); any other value
        results in no suppression
    match_threshold: the cutoff for determining that overlap requires suppression of one object
    keep_metric: which metric to use to determine the object to keep
        ("score" or "area"); any other value leaves the input order untouched
    keep_higher: if True, keep the object with the higher metric; otherwise, keep the lower

    Returns the surviving objects in ranked order.
    """
    if len(objects) == 0:
        return []

    if keep_metric == "score":
        objects = sort_objects_by_score(objects, reverse=keep_higher)
    elif keep_metric == "area":
        objects = sort_objects_by_area(objects, reverse=keep_higher)

    num_objects = len(objects)
    areas = [Rect(obj['bbox']).get_area() for obj in objects]
    suppression = [False] * num_objects

    for object2_num in range(1, num_objects):
        object2_area = areas[object2_num]
        for object1_num in range(object2_num):
            if suppression[object1_num]:
                continue
            object1_area = areas[object1_num]
            # Build fresh Rects for every pair: the intersection is taken on
            # a throwaway copy so no box is altered between comparisons.
            object1_rect = Rect(objects[object1_num]['bbox'])
            object2_rect = Rect(objects[object2_num]['bbox'])
            intersect_area = object1_rect.intersect(object2_rect).get_area()

            if match_criteria == "object1_overlap":
                denominator = object1_area
            elif match_criteria == "object2_overlap":
                denominator = object2_area
            elif match_criteria == "iou":
                denominator = object1_area + object2_area - intersect_area
            else:
                # Unknown criteria: no overlap metric is defined, so nothing
                # is suppressed.
                continue

            if denominator == 0:
                # Degenerate boxes have no meaningful overlap ratio.
                continue

            if intersect_area / denominator >= match_threshold:
                suppression[object2_num] = True
                break

    return [obj for idx, obj in enumerate(objects) if not suppression[idx]]
488
+
489
+
490
def align_headers(headers, rows):
    """
    Snap the header to whole rows: a row counts as part of a header when the
    header covers at least 50% of the row's height.

    Only one header is supported, so only the run of consecutive rows starting
    at the first row is kept. Every row's 'header' flag is reset and then set
    to True for the rows in that run.

    Returns a list holding a single {'bbox': [...]} header, or an empty list
    when no row qualifies.
    """
    aligned_headers = []

    for row in rows:
        row['header'] = False

    header_row_nums = []
    for header in headers:
        header_top, header_bottom = header['bbox'][1], header['bbox'][3]
        for row_num, row in enumerate(rows):
            row_top, row_bottom = row['bbox'][1], row['bbox'][3]
            overlap_height = min(row_bottom, header_bottom) - max(row_top, header_top)
            if overlap_height / (row_bottom - row_top) >= 0.5:
                header_row_nums.append(row_num)

    if not header_row_nums:
        return aligned_headers

    header_rect = Rect()
    # When the first matched row is not row 0, the rows from 0 up to and
    # including it are put in front of the list.
    first_matched = header_row_nums[0]
    if first_matched > 0:
        header_row_nums = list(range(first_matched + 1)) + header_row_nums

    last_row_num = -1
    for row_num in header_row_nums:
        if row_num != last_row_num + 1:
            # Break as soon as a non-header row is encountered.
            # This ignores any subsequent rows in the table labeled as a header.
            # Having more than 1 header is not supported currently.
            break
        header_row = rows[row_num]
        header_row['header'] = True
        header_rect = header_rect.include_rect(header_row['bbox'])
        last_row_num = row_num

    aligned_headers.append({'bbox': list(header_rect)})

    return aligned_headers
538
+
539
+
540
def align_supercells(supercells, rows, columns):
    """
    Snap each supercell onto the row/column grid.

    A supercell is aligned to every row that overlaps it by at least 50% of
    the row height, and to every column that overlaps it by at least 50% of
    the column width. Supercells carrying a 'span' key use the larger of the
    row/column-relative and supercell-relative overlap fractions, and inside
    the header their column threshold is effectively 0.25.
    Supercells that match no rows or no columns, or that end up covering a
    single grid cell, are dropped.

    Args:
        supercells: list of dicts, each with a 'bbox' indexed as
            [x_min, y_min, x_max, y_max]. A 'span' key is optional; a
            'score' key is read when a span supercell is propagated upward.
        rows: list of dicts with a 'bbox' and an optional truthy 'header'.
        columns: list of dicts with a 'bbox'.

    Returns:
        A new list holding the surviving supercells (the same dict objects
        that were passed in, updated in place with 'header', 'bbox',
        'row_numbers' and 'column_numbers') plus any propagated supercells
        created for header rows above a header span supercell.
    """
    aligned_supercells = []

    for supercell in supercells:
        supercell['header'] = False
        row_bbox_rect = None
        col_bbox_rect = None
        intersecting_header_rows = set()
        intersecting_data_rows = set()
        for row_num, row in enumerate(rows):
            row_height = row['bbox'][3] - row['bbox'][1]
            supercell_height = supercell['bbox'][3] - supercell['bbox'][1]
            min_row_overlap = max(row['bbox'][1], supercell['bbox'][1])
            max_row_overlap = min(row['bbox'][3], supercell['bbox'][3])
            overlap_height = max_row_overlap - min_row_overlap
            if 'span' in supercell:
                overlap_fraction = max(overlap_height / row_height,
                                       overlap_height / supercell_height)
            else:
                overlap_fraction = overlap_height / row_height
            if overlap_fraction >= 0.5:
                if 'header' in row and row['header']:
                    intersecting_header_rows.add(row_num)
                else:
                    intersecting_data_rows.add(row_num)

        # Supercell cannot span across the header boundary; eliminate whichever
        # group of rows is the smallest (ties keep the header rows)
        if intersecting_data_rows and intersecting_header_rows:
            if len(intersecting_data_rows) > len(intersecting_header_rows):
                intersecting_header_rows = set()
            else:
                intersecting_data_rows = set()
        if intersecting_header_rows:
            supercell['header'] = True
        elif 'span' in supercell:
            continue  # Require span supercell to be in the header
        intersecting_rows = intersecting_data_rows.union(intersecting_header_rows)

        # Determine vertical span of aligned supercell
        for row_num in intersecting_rows:
            if row_bbox_rect is None:
                row_bbox_rect = Rect(rows[row_num]['bbox'])
            else:
                row_bbox_rect = row_bbox_rect.include_rect(rows[row_num]['bbox'])
        if row_bbox_rect is None:
            continue

        # Determine horizontal span of aligned supercell
        intersecting_cols = []
        for col_num, col in enumerate(columns):
            col_width = col['bbox'][2] - col['bbox'][0]
            supercell_width = supercell['bbox'][2] - supercell['bbox'][0]
            min_col_overlap = max(col['bbox'][0], supercell['bbox'][0])
            max_col_overlap = min(col['bbox'][2], supercell['bbox'][2])
            overlap_width = max_col_overlap - min_col_overlap
            if 'span' in supercell:
                overlap_fraction = max(overlap_width / col_width,
                                       overlap_width / supercell_width)
                # Multiply by 2 effectively lowers the threshold to 0.25
                if supercell['header']:
                    overlap_fraction = overlap_fraction * 2
            else:
                overlap_fraction = overlap_width / col_width
            if overlap_fraction >= 0.5:
                intersecting_cols.append(col_num)
                if col_bbox_rect is None:
                    col_bbox_rect = Rect(col['bbox'])
                else:
                    col_bbox_rect = col_bbox_rect.include_rect(col['bbox'])
        if col_bbox_rect is None:
            continue

        # The aligned bbox is the intersection of the row band and column band
        supercell_bbox = list(row_bbox_rect.intersect(col_bbox_rect))
        supercell['bbox'] = supercell_bbox

        # Only a true supercell if it joins across multiple rows or columns
        if (len(intersecting_rows) > 0 and len(intersecting_cols) > 0
                and (len(intersecting_rows) > 1 or len(intersecting_cols) > 1)):
            supercell['row_numbers'] = list(intersecting_rows)
            supercell['column_numbers'] = intersecting_cols
            aligned_supercells.append(supercell)

            # A span supercell in the header means there must be supercells above it in the header
            if 'span' in supercell and supercell['header'] and len(supercell['column_numbers']) > 1:
                for row_num in range(0, min(supercell['row_numbers'])):
                    # Copy the column list: overlap resolution shrinks
                    # 'column_numbers' in place, so sharing one list object
                    # would make shrinking one supercell shrink the other too.
                    new_supercell = {'row_numbers': [row_num],
                                     'column_numbers': list(supercell['column_numbers']),
                                     'score': supercell['score'], 'propagated': True}
                    new_supercell_columns = [columns[idx] for idx in supercell['column_numbers']]
                    # NOTE(review): the vertical extent is taken from the source
                    # supercell's rows rather than rows[row_num]; confirm this is intended.
                    new_supercell_rows = [rows[idx] for idx in supercell['row_numbers']]
                    bbox = [min([column['bbox'][0] for column in new_supercell_columns]),
                            min([row['bbox'][1] for row in new_supercell_rows]),
                            max([column['bbox'][2] for column in new_supercell_columns]),
                            max([row['bbox'][3] for row in new_supercell_rows])]
                    new_supercell['bbox'] = bbox
                    aligned_supercells.append(new_supercell)

    return aligned_supercells
642
+
643
+
644
def nms_supercells(supercells):
    """
    Non-maximum suppression for supercells that prefers shrinking to deleting.

    Supercells are ordered with sort_objects_by_score; each one after the
    first is then shrunk against every supercell ranked before it so that
    the two no longer share a grid cell. A supercell that is left with no
    rows, no columns, or just a single grid cell is suppressed.
    Returns the list of supercells that were not suppressed.
    """
    ranked = sort_objects_by_score(supercells)
    suppressed = [False] * len(ranked)

    for position in range(1, len(ranked)):
        candidate = ranked[position]

        # Resolve overlap against everything ranked ahead of this candidate
        for stronger in ranked[:position]:
            remove_supercell_overlap(stronger, candidate)

        row_count = len(candidate['row_numbers'])
        column_count = len(candidate['column_numbers'])
        single_cell = row_count < 2 and column_count < 2
        if single_cell or row_count == 0 or column_count == 0:
            suppressed[position] = True

    return [kept for position, kept in enumerate(ranked) if not suppressed[position]]
666
+
667
+
668
def header_supercell_tree(supercells):
    """
    Enforce a tree structure among the supercells in the header.

    For a header supercell starting at row R, every row above R must contain
    exactly one header supercell whose columns cover all of its columns
    (its single ancestor in that row). Any header supercell for which some
    row above it has zero or several such ancestors is removed from
    `supercells` in place. Nothing is returned.
    """
    in_header = sort_objects_by_score(
        [candidate for candidate in supercells
         if 'header' in candidate and candidate['header']])

    for candidate in list(in_header):
        top_row = min(candidate['row_numbers'])
        candidate_columns = set(candidate['column_numbers'])

        # Count, per row, the header supercells lying entirely above the
        # candidate whose columns are a superset of the candidate's columns
        ancestor_counts = defaultdict(int)
        for other in in_header:
            if max(other['row_numbers']) >= top_row:
                continue
            if not candidate_columns.issubset(set(other['column_numbers'])):
                continue
            for covered_row in other['row_numbers']:
                ancestor_counts[covered_row] += 1

        if any(ancestor_counts[row] != 1 for row in range(top_row)):
            supercells.remove(candidate)
692
+
693
+
694
def table_structure_to_cells(table_structures, table_spans, table_bbox):
    """
    Assuming the row, column, supercell, and header bounding boxes have
    been refined into a set of consistent table structures, process these
    table structures into table cells. This is a universal representation
    format for the table, which can later be exported to Pandas or CSV formats.
    Classify the cells as header/access cells or data cells
    based on if they intersect with the header bounding box.

    Args:
        table_structures: dict with 'columns', 'rows' and 'supercells' lists;
            every entry is a dict with a 'bbox'. Rows may carry 'header',
            supercells may carry 'subheader' (treated as False when absent).
        table_spans: list of dicts with a 'bbox', slotted into the cells via
            slot_into_containers.
        table_bbox: currently unused (row/column dilation is disabled).

    Returns:
        (cells, confidence_score): the list of cell dicts ('bbox',
        'column_nums', 'row_nums', 'header', 'subheader', 'spans', and for
        grid cells 'subcell') and a score derived from the mean and minimum
        cell match scores, or 0 when no score can be computed.

    Side effects:
        The 'bbox' lists of the row and column dicts are tightened in place
        around the spans assigned to them.
    """
    columns = table_structures['columns']
    rows = table_structures['rows']
    supercells = table_structures['supercells']
    cells = []
    subcells = []

    # Identify complete cells and subcells
    for column_num, column in enumerate(columns):
        for row_num, row in enumerate(rows):
            column_rect = Rect(list(column['bbox']))
            row_rect = Rect(list(row['bbox']))
            cell_rect = row_rect.intersect(column_rect)
            header = 'header' in row and row['header']
            cell = {'bbox': list(cell_rect), 'column_nums': [column_num], 'row_nums': [row_num],
                    'header': header}

            # A grid cell more than half covered by a supercell is a subcell of it
            cell['subcell'] = False
            for supercell in supercells:
                supercell_rect = Rect(list(supercell['bbox']))
                if (supercell_rect.intersect(cell_rect).get_area()
                        / cell_rect.get_area()) > 0.5:
                    cell['subcell'] = True
                    break

            if cell['subcell']:
                subcells.append(cell)
            else:
                cell['subheader'] = False
                cells.append(cell)

    # Merge the subcells belonging to each supercell into one spanning cell
    for supercell in supercells:
        supercell_rect = Rect(list(supercell['bbox']))
        cell_columns = set()
        cell_rows = set()
        cell_rect = None
        header = True
        for subcell in subcells:
            subcell_rect = Rect(list(subcell['bbox']))
            # Area is taken before intersect() is called on the same rect
            subcell_rect_area = subcell_rect.get_area()
            if (subcell_rect.intersect(supercell_rect).get_area()
                    / subcell_rect_area) > 0.5:
                if cell_rect is None:
                    cell_rect = Rect(list(subcell['bbox']))
                else:
                    cell_rect.include_rect(Rect(list(subcell['bbox'])))
                cell_rows = cell_rows.union(set(subcell['row_nums']))
                cell_columns = cell_columns.union(set(subcell['column_nums']))
                # By convention here, all subcells must be classified
                # as header cells for a supercell to be classified as a header cell;
                # otherwise, this could lead to a non-rectangular header region
                header = header and 'header' in subcell and subcell['header']
        if len(cell_rows) > 0 and len(cell_columns) > 0:
            # Supercells without a 'subheader' key (e.g. ones created by
            # propagation rather than detection) are treated as non-subheaders
            # instead of raising KeyError.
            cell = {'bbox': list(cell_rect), 'column_nums': list(cell_columns), 'row_nums': list(cell_rows),
                    'header': header, 'subheader': supercell.get('subheader', False)}
            cells.append(cell)

    # Compute a confidence score based on how well the page tokens
    # slot into the cells reported by the model
    _, _, cell_match_scores = slot_into_containers(cells, table_spans)
    try:
        mean_match_score = sum(cell_match_scores) / len(cell_match_scores)
        min_match_score = min(cell_match_scores)
        confidence_score = (mean_match_score + min_match_score) / 2
    except (ZeroDivisionError, ValueError, TypeError):
        # No usable scores (e.g. no cells): report zero confidence
        confidence_score = 0

    # Rows and columns are used as-is; gap-filling dilation is disabled
    dilated_columns = columns
    dilated_rows = rows
    for cell in cells:
        column_rect = Rect()
        for column_num in cell['column_nums']:
            column_rect.include_rect(list(dilated_columns[column_num]['bbox']))
        row_rect = Rect()
        for row_num in cell['row_nums']:
            row_rect.include_rect(list(dilated_rows[row_num]['bbox']))
        cell_rect = column_rect.intersect(row_rect)
        cell['bbox'] = list(cell_rect)

    span_nums_by_cell, _, _ = slot_into_containers(cells, table_spans, overlap_threshold=0.001,
                                                   unique_assignment=True, forced_assignment=False)

    for cell, cell_span_nums in zip(cells, span_nums_by_cell):
        cell_spans = [table_spans[num] for num in cell_span_nums]
        # TODO: Refine how text is extracted; should be character-based, not span-based;
        # but need to associate
        cell['spans'] = cell_spans

    # Adjust the row, column, and cell bounding boxes to reflect the extracted text
    num_rows = len(rows)
    rows = sort_objects_top_to_bottom(rows)
    num_columns = len(columns)
    columns = sort_objects_left_to_right(columns)
    min_y_values_by_row = defaultdict(list)
    max_y_values_by_row = defaultdict(list)
    min_x_values_by_column = defaultdict(list)
    max_x_values_by_column = defaultdict(list)
    for cell in cells:
        min_row = min(cell["row_nums"])
        max_row = max(cell["row_nums"])
        min_column = min(cell["column_nums"])
        max_column = max(cell["column_nums"])
        for span in cell['spans']:
            min_x_values_by_column[min_column].append(span['bbox'][0])
            min_y_values_by_row[min_row].append(span['bbox'][1])
            max_x_values_by_column[max_column].append(span['bbox'][2])
            max_y_values_by_row[max_row].append(span['bbox'][3])
    # Rows: x-extent from the first/last column's text, y-extent from own text
    for row_num, row in enumerate(rows):
        if len(min_x_values_by_column[0]) > 0:
            row['bbox'][0] = min(min_x_values_by_column[0])
        if len(min_y_values_by_row[row_num]) > 0:
            row['bbox'][1] = min(min_y_values_by_row[row_num])
        if len(max_x_values_by_column[num_columns-1]) > 0:
            row['bbox'][2] = max(max_x_values_by_column[num_columns-1])
        if len(max_y_values_by_row[row_num]) > 0:
            row['bbox'][3] = max(max_y_values_by_row[row_num])
    # Columns: x-extent from own text, y-extent from the first/last row's text
    for column_num, column in enumerate(columns):
        if len(min_x_values_by_column[column_num]) > 0:
            column['bbox'][0] = min(min_x_values_by_column[column_num])
        if len(min_y_values_by_row[0]) > 0:
            column['bbox'][1] = min(min_y_values_by_row[0])
        if len(max_x_values_by_column[column_num]) > 0:
            column['bbox'][2] = max(max_x_values_by_column[column_num])
        if len(max_y_values_by_row[num_rows-1]) > 0:
            column['bbox'][3] = max(max_y_values_by_row[num_rows-1])
    for cell in cells:
        row_rect = Rect()
        column_rect = Rect()
        for row_num in cell['row_nums']:
            row_rect.include_rect(list(rows[row_num]['bbox']))
        for column_num in cell['column_nums']:
            column_rect.include_rect(list(columns[column_num]['bbox']))
        cell_rect = row_rect.intersect(column_rect)
        # Keep the previous bbox when the tightened one is degenerate
        if cell_rect.get_area() > 0:
            cell['bbox'] = list(cell_rect)

    return cells, confidence_score
846
+
847
+
848
def remove_supercell_overlap(supercell1, supercell2):
    """
    Make two supercells disjoint by shrinking `supercell2` in place.

    `supercell1` is never modified. While the two supercells still share at
    least one row and one column, one row or one column is peeled off
    `supercell2`:
      - if `supercell2` has fewer rows than columns a column is removed
        (that discards fewer grid cells), otherwise a row is removed;
      - the highest-numbered index on that axis is removed if it is shared,
        else the lowest-numbered index if it is shared;
      - if neither end of `supercell2` is shared (the overlap is strictly
        interior), the whole axis of `supercell2` is emptied.
    Nothing is returned.
    """
    shared = {
        axis: set(supercell1[axis]).intersection(set(supercell2[axis]))
        for axis in ('row_numbers', 'column_numbers')
    }

    # Keep trimming until the supercells no longer share a grid cell
    while shared['row_numbers'] and shared['column_numbers']:
        fewer_rows = len(supercell2['row_numbers']) < len(supercell2['column_numbers'])
        axis = 'column_numbers' if fewer_rows else 'row_numbers'

        indices = supercell2[axis]
        overlap = shared[axis]
        lowest = min(indices)
        highest = max(indices)

        if highest in overlap:
            edge = highest
        elif lowest in overlap:
            edge = lowest
        else:
            # Overlap touches neither end: give up the entire axis
            supercell2[axis] = []
            overlap.clear()
            continue

        overlap.remove(edge)
        indices.remove(edge)
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ -e git+https://github.com/mindee/doctr.git#egg=python-doctr[tf]
2
+ streamlit>=0.65.0
3
+ PyMuPDF>=1.16.0,!=1.18.11,!=1.18.12,!=1.19.5
4
+ tf2onnx==1.13.0
5
+ Pillow==9.0.1
6
+ pytesseract==0.3.10
7
+ torch==1.12.0
8
+ torchvision==0.13.0
9
+ numpy==1.21.6
tessdata/eng.traineddata ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8280aed0782fe27257a68ea10fe7ef324ca0f8d85bd2fd145d1c2b560bcb66ba
3
+ size 15400601
weights/structure_wts.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46121ab2f4aba48a7d38624c861658ffeaacd0f305e95efcf66cb017e588b700
3
+ size 14371957