rizgiak committed
Commit 4df4988 • 1 Parent(s): 16b0d52

initial commit
AUTHORS.rst ADDED
@@ -0,0 +1,13 @@
+ =======
+ Credits
+ =======
+
+ Development Lead
+ ----------------
+
+ * Full name of the author <Email of the author>
+
+ Contributors
+ ------------
+
+ None yet. Why not be the first?
README.md CHANGED
@@ -1,12 +1,22 @@
  ---
- title: Table To Csv Pipeline
- emoji: 🏃
- colorFrom: purple
- colorTo: blue
  sdk: streamlit
- sdk_version: 1.28.1
  app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

  ---
+ title: Table Extraction
+ emoji: 🚀
+ colorFrom: indigo
+ colorTo: purple
  sdk: streamlit
+ sdk_version: 1.21.0
  app_file: app.py
  pinned: false
  ---
+ # huggingface-space
+
+ Imported from https://huggingface.co/spaces/jurgendn/table-extraction with some adjustments.
+
+
+ Current pipeline:
+
+ Table detection: https://huggingface.co/microsoft/table-transformer-detection
+
+ Table recognition: https://huggingface.co/microsoft/table-transformer-structure-recognition
+
+ OCR: https://github.com/pbcquoc/vietocr
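
For orientation, here is a minimal sketch of how the three pipeline stages chain together outside of Streamlit (the model names come from the links above; the file name `page.png`, the variable names, and the omitted post-processing are illustrative, not the exact app code):

```python
import torch
from PIL import Image
from transformers import DetrImageProcessor, TableTransformerForObjectDetection
from vietocr.tool.config import Cfg
from vietocr.tool.predictor import Predictor

image = Image.open("page.png").convert("RGB")

# 1) Table detection: find table bounding boxes on the page
processor = DetrImageProcessor()
detection_model = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection")
with torch.no_grad():
    outputs = detection_model(**processor(image, return_tensors="pt"))

# 2) Table structure recognition: same pattern on each cropped table,
#    with "microsoft/table-transformer-structure-recognition"

# 3) OCR: read the text inside each cell crop with VietOCR
cfg = Cfg.load_config_from_name("vgg_seq2seq")
cfg["device"] = "cpu"  # as configured in app.py
ocr = Predictor(cfg)
text = ocr.predict(image)  # a single cell image in the real pipeline
```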
app.py ADDED
@@ -0,0 +1,570 @@
+ import asyncio
+ import string
+ from collections import Counter
+ from itertools import count, tee
+
+ import cv2
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import streamlit as st
+ import torch
+ from PIL import Image
+ from transformers import DetrImageProcessor, TableTransformerForObjectDetection
+ from vietocr.tool.config import Cfg
+ from vietocr.tool.predictor import Predictor
+
+ st.set_option('deprecation.showPyplotGlobalUse', False)
+ st.set_page_config(layout='wide')
+ st.title("Table Detection and Table Structure Recognition")
+ st.write(
+     "Implemented by MSFT team: https://github.com/microsoft/table-transformer")
+
+ # VietOCR text recognizer, run on CPU with beam search off for speed
+ # config = Cfg.load_config_from_name('vgg_transformer')
+ config = Cfg.load_config_from_name('vgg_seq2seq')
+ config['cnn']['pretrained'] = False
+ config['device'] = 'cpu'
+ config['predictor']['beamsearch'] = False
+ detector = Predictor(config)
+
+ table_detection_model = TableTransformerForObjectDetection.from_pretrained(
+     "microsoft/table-transformer-detection")
+
+ table_recognition_model = TableTransformerForObjectDetection.from_pretrained(
+     "microsoft/table-transformer-structure-recognition")
+
+
+ def PIL_to_cv(pil_img):
+     return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
+
+
+ def cv_to_PIL(cv_img):
+     return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
+
+
+ async def pytess(cell_pil_img, threshold: float = 0.5):
+     # detector.predict is synchronous, so this async wrapper mainly makes
+     # the per-cell OCR calls easy to collect with asyncio.gather
+     text, prob = detector.predict(cell_pil_img, return_prob=True)
+     if prob < threshold:
+         return ""
+     return text.strip()
+
+
+ def sharpen_image(pil_img):
+     img = PIL_to_cv(pil_img)
+     sharpen_kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
+
+     sharpen = cv2.filter2D(img, -1, sharpen_kernel)
+     pil_img = cv_to_PIL(sharpen)
+     return pil_img
+
+
+ def uniquify(seq, suffs=count(1)):
+     """Make all the items unique by adding a suffix (1, 2, etc).
+     Credit: https://stackoverflow.com/questions/30650474/python-rename-duplicates-in-list-with-progressive-numbers-without-sorting-list
+     `seq` is a mutable sequence of strings.
+     `suffs` is an optional alternative suffix iterable.
+     """
+     not_unique = [k for k, v in Counter(seq).items() if v > 1]
+
+     suff_gens = dict(zip(not_unique, tee(suffs, len(not_unique))))
+     for idx, s in enumerate(seq):
+         try:
+             suffix = str(next(suff_gens[s]))
+         except KeyError:
+             continue
+         else:
+             seq[idx] += suffix
+
+     return seq
+
+
+ def binarizeBlur_image(pil_img):
+     image = PIL_to_cv(pil_img)
+     thresh = cv2.threshold(image, 150, 255, cv2.THRESH_BINARY_INV)[1]
+
+     result = cv2.GaussianBlur(thresh, (5, 5), 0)
+     result = 255 - result
+     return cv_to_PIL(result)
+
+
+ def td_postprocess(pil_img):
+     '''
+     Removes gray background from tables
+     '''
+     img = PIL_to_cv(pil_img)
+
+     hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
+     # keep bright, low-saturation pixels (the gray/white background)
+     mask = cv2.inRange(hsv, (0, 0, 100), (255, 5, 255))
+     # mask of non-black pixels
+     nzmask = cv2.inRange(hsv, (0, 0, 5), (255, 255, 255))
+     nzmask = cv2.erode(nzmask, np.ones((3, 3)))
+     mask = mask & nzmask
+
+     new_img = img.copy()
+     new_img[np.where(mask)] = 255
+
+     return cv_to_PIL(new_img)
+
+
+ # def super_res(pil_img):
+ #     # requires opencv-contrib-python installed without the opencv-python
+ #     sr = dnn_superres.DnnSuperResImpl_create()
+ #     image = PIL_to_cv(pil_img)
+ #     model_path = "./LapSRN_x8.pb"
+ #     model_name = model_path.split('/')[1].split('_')[0].lower()
+ #     model_scale = int(model_path.split('/')[1].split('_')[1].split('.')[0][1])
+
+ #     sr.readModel(model_path)
+ #     sr.setModel(model_name, model_scale)
+ #     final_img = sr.upsample(image)
+ #     final_img = cv_to_PIL(final_img)
+
+ #     return final_img
+
+
+ def table_detector(image, THRESHOLD_PROBA):
+     '''
+     Table detection using DETR (DEtection TRansformer), pre-trained on ~1M tables
+     '''
+     feature_extractor = DetrImageProcessor(do_resize=True,
+                                            size=800,
+                                            max_size=800)
+     encoding = feature_extractor(image, return_tensors="pt")
+
+     with torch.no_grad():
+         outputs = table_detection_model(**encoding)
+
+     # per-query class probabilities, dropping the trailing "no object" class
+     probas = outputs.logits.softmax(-1)[0, :, :-1]
+     keep = probas.max(-1).values > THRESHOLD_PROBA
+
+     # rescale the normalized boxes to the original image size
+     target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
+     postprocessed_outputs = feature_extractor.post_process(
+         outputs, target_sizes)
+     bboxes_scaled = postprocessed_outputs[0]['boxes'][keep]
+
+     return (probas[keep], bboxes_scaled)
+
+
+ def table_struct_recog(image, THRESHOLD_PROBA):
+     '''
+     Table structure recognition using DETR (DEtection TRansformer), pre-trained on ~1M tables
+     '''
+     feature_extractor = DetrImageProcessor(do_resize=True,
+                                            size=1000,
+                                            max_size=1000)
+     encoding = feature_extractor(image, return_tensors="pt")
+
+     with torch.no_grad():
+         outputs = table_recognition_model(**encoding)
+
+     probas = outputs.logits.softmax(-1)[0, :, :-1]
+     keep = probas.max(-1).values > THRESHOLD_PROBA
+
+     target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
+     postprocessed_outputs = feature_extractor.post_process(
+         outputs, target_sizes)
+     bboxes_scaled = postprocessed_outputs[0]['boxes'][keep]
+
+     return (probas[keep], bboxes_scaled)
+
+
+ class TableExtractionPipeline():
+
+     colors = ["red", "blue", "green", "yellow", "orange", "violet"]
+     # colors = ["red", "blue", "green", "red", "red", "red"]
+
+     def add_padding(self,
+                     pil_img,
+                     top,
+                     right,
+                     bottom,
+                     left,
+                     color=(255, 255, 255)):
+         '''
+         Image padding as part of TSR pre-processing to prevent missing table edges
+         '''
+         width, height = pil_img.size
+         new_width = width + right + left
+         new_height = height + top + bottom
+         result = Image.new(pil_img.mode, (new_width, new_height), color)
+         result.paste(pil_img, (left, top))
+         return result
+
+     def plot_results_detection(self, c1, model, pil_img, prob, boxes,
+                                delta_xmin, delta_ymin, delta_xmax, delta_ymax):
+         '''
+         crop_tables and plot_results_detection must apply the same coordinate
+         shifts: this method only plots, while crop_tables updates the coordinates
+         '''
+         plt.imshow(pil_img)
+         ax = plt.gca()
+
+         for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
+             cl = p.argmax()
+             xmin, ymin, xmax, ymax = xmin - delta_xmin, ymin - delta_ymin, xmax + delta_xmax, ymax + delta_ymax
+             ax.add_patch(
+                 plt.Rectangle((xmin, ymin),
+                               xmax - xmin,
+                               ymax - ymin,
+                               fill=False,
+                               color='red',
+                               linewidth=3))
+             text = f'{model.config.id2label[cl.item()]}: {p[cl]:0.2f}'
+             ax.text(xmin - 20,
+                     ymin - 50,
+                     text,
+                     fontsize=10,
+                     bbox=dict(facecolor='yellow', alpha=0.5))
+         plt.axis('off')
+         c1.pyplot()
+
+     def crop_tables(self, pil_img, prob, boxes, delta_xmin, delta_ymin,
+                     delta_xmax, delta_ymax):
+         '''
+         crop_tables and plot_results_detection must apply the same coordinate
+         shifts: plot_results_detection only plots, this method updates the coordinates
+         '''
+         cropped_img_list = []
+
+         for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
+             xmin, ymin, xmax, ymax = xmin - delta_xmin, ymin - delta_ymin, xmax + delta_xmax, ymax + delta_ymax
+             cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
+             cropped_img_list.append(cropped_img)
+
+         return cropped_img_list
+
+     def generate_structure(self, c2, model, pil_img, prob, boxes,
+                            expand_rowcol_bbox_top, expand_rowcol_bbox_bottom):
+         '''
+         Plots the table (Pillow image) with the TSR bounding boxes drawn on top;
+         row/column boxes are expanded by expand_rowcol_bbox_top/bottom pixels
+         '''
+         plt.figure(figsize=(32, 20))
+         plt.imshow(pil_img)
+         ax = plt.gca()
+         rows = {}
+         cols = {}
+         idx = 0
+
+         for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
+             cl = p.argmax()
+             class_text = model.config.id2label[cl.item()]
+             text = f'{class_text}: {p[cl]:0.2f}'
+             if class_text in ('table row', 'table projected row header',
+                               'table column'):
+                 ax.add_patch(
+                     plt.Rectangle((xmin, ymin),
+                                   xmax - xmin,
+                                   ymax - ymin,
+                                   fill=False,
+                                   color=self.colors[cl.item()],
+                                   linewidth=2))
+                 ax.text(xmin - 10,
+                         ymin - 10,
+                         text,
+                         fontsize=5,
+                         bbox=dict(facecolor='yellow', alpha=0.5))
+
+             if class_text == 'table row':
+                 rows['table row.' +
+                      str(idx)] = (xmin, ymin - expand_rowcol_bbox_top, xmax,
+                                   ymax + expand_rowcol_bbox_bottom)
+             if class_text == 'table column':
+                 cols['table column.' +
+                      str(idx)] = (xmin, ymin - expand_rowcol_bbox_top, xmax,
+                                   ymax + expand_rowcol_bbox_bottom)
+
+             idx += 1
+
+         plt.axis('on')
+         c2.pyplot()
+         return rows, cols
+
+     def sort_table_featuresv2(self, rows: dict, cols: dict):
+         # Sometimes the header and the first row overlap; the header bbox
+         # must not contain the first row's bbox, so sort rows by ymin and
+         # columns by xmin
+         rows_ = {
+             table_feature: (xmin, ymin, xmax, ymax)
+             for table_feature, (
+                 xmin, ymin, xmax,
+                 ymax) in sorted(rows.items(), key=lambda tup: tup[1][1])
+         }
+         cols_ = {
+             table_feature: (xmin, ymin, xmax, ymax)
+             for table_feature, (
+                 xmin, ymin, xmax,
+                 ymax) in sorted(cols.items(), key=lambda tup: tup[1][0])
+         }
+
+         return rows_, cols_
+
+     def individual_table_featuresv2(self, pil_img, rows: dict, cols: dict):
+         # attach a cropped image to every row and column bbox
+         for k, v in rows.items():
+             xmin, ymin, xmax, ymax = v
+             cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
+             rows[k] = xmin, ymin, xmax, ymax, cropped_img
+
+         for k, v in cols.items():
+             xmin, ymin, xmax, ymax = v
+             cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
+             cols[k] = xmin, ymin, xmax, ymax, cropped_img
+
+         return rows, cols
+
+     def object_to_cellsv2(self, master_row: dict, cols: dict,
+                           expand_rowcol_bbox_top, expand_rowcol_bbox_bottom,
+                           padd_left):
+         '''Removes redundant bboxes for rows & columns and divides each row into cells using the column boxes
+         Args:
+             master_row: row label -> (xmin, ymin, xmax, ymax, cropped row image)
+             cols: column label -> (xmin, ymin, xmax, ymax, cropped column image)
+         Returns:
+             cells_img (row label -> list of cell images), the number of columns,
+             and the number of data rows
+         '''
+         cells_img = {}
+         row_idx = 0
+         new_cols = cols
+         new_master_row = master_row
+         ## The two loops below (disabled) remove redundant bounding boxes ###
+         # previous_xmax_col = 0
+         # new_cols = {}
+         # for k_col, v_col in cols.items():
+         #     xmin_col, _, xmax_col, _, col_img = v_col
+         #     if (np.isclose(previous_xmax_col, xmax_col, atol=5)) or (xmin_col >= xmax_col):
+         #         print('Found a column with double bbox')
+         #         continue
+         #     previous_xmax_col = xmax_col
+         #     new_cols[k_col] = v_col
+
+         # previous_ymin_row = 0
+         # new_master_row = {}
+         # for k_row, v_row in master_row.items():
+         #     _, ymin_row, _, ymax_row, row_img = v_row
+         #     if (np.isclose(previous_ymin_row, ymin_row, atol=5)) or (ymin_row >= ymax_row):
+         #         print('Found a row with double bbox')
+         #         continue
+         #     previous_ymin_row = ymin_row
+         #     new_master_row[k_row] = v_row
+         ######################################################
+         for k_row, v_row in new_master_row.items():
+             _, _, _, _, row_img = v_row
+             xmax, ymax = row_img.size
+             xa, ya, xb, yb = 0, 0, 0, ymax
+             row_img_list = []
+             for idx, kv in enumerate(new_cols.items()):
+                 k_col, v_col = kv
+                 xmin_col, _, xmax_col, _, col_img = v_col
+                 # shift the column box into the row image's coordinate frame
+                 xmin_col, xmax_col = xmin_col - padd_left - 10, xmax_col - padd_left
+                 xa = xmin_col
+                 xb = xmax_col
+                 if idx == 0:
+                     xa = 0
+                 if idx == len(new_cols) - 1:
+                     xb = xmax
+
+                 row_img_cropped = row_img.crop((xa, ya, xb, yb))
+                 row_img_list.append(row_img_cropped)
+
+             cells_img[k_row + '.' + str(row_idx)] = row_img_list
+             row_idx += 1
+
+         # the first row becomes the header, hence len(new_master_row) - 1 data rows
+         return cells_img, len(new_cols), len(new_master_row) - 1
+
+     def clean_dataframe(self, df):
+         '''
+         Remove irrelevant symbols that appear in the OCR output
+         '''
+         for col in df.columns:
+             # plain string replacement: '[', ']', '{' and '}' are regex
+             # metacharacters, so regex must be disabled here
+             for symbol in ("'", '"', '[', ']', '{', '}'):
+                 df[col] = df[col].str.replace(symbol, '', regex=False)
+         return df
+
+     @st.cache
+     def convert_df(self, df):
+         return df.to_csv().encode('utf-8')
+
+     def create_dataframe(self, c3, cell_ocr_res: list, max_cols: int,
+                          max_rows: int):
+         '''Create dataframe using the list of cell values of the table, and check for a valid header
+         Args:
+             cell_ocr_res: list of strings, each element representing a cell in a table
+             max_cols, max_rows: number of columns and rows
+         Returns:
+             dataframe: final dataframe after all pre-processing
+         '''
+         headers = cell_ocr_res[:max_cols]
+         new_headers = uniquify(headers,
+                                (f' {x!s}' for x in string.ascii_lowercase))
+         counter = 0
+
+         cells_list = cell_ocr_res[max_cols:]
+         df = pd.DataFrame("", index=range(0, max_rows), columns=new_headers)
+
+         cell_idx = 0
+         for nrows in range(max_rows):
+             for ncols in range(max_cols):
+                 df.iat[nrows, ncols] = str(cells_list[cell_idx])
+                 cell_idx += 1
+
+         ## Count headers that are just a uniquify suffix (i.e. the cell was empty)
+         ## The disabled block below would instead promote the first data row to
+         ## header when all headers are empty or their median length is under 6
+         for x, col in zip(string.ascii_lowercase, new_headers):
+             if f' {x!s}' == col:
+                 counter += 1
+         header_char_count = [len(col) for col in new_headers]
+
+         # if (counter == len(new_headers)) or (statistics.median(header_char_count) < 6):
+         #     df.columns = uniquify(df.iloc[0], (f' {x!s}' for x in string.ascii_lowercase))
+         #     df = df.iloc[1:, :]
+
+         df = self.clean_dataframe(df)
+
+         c3.dataframe(df)
+         csv = self.convert_df(df)
+         c3.download_button("Download table",
+                            csv,
+                            "file.csv",
+                            "text/csv",
+                            key='download-csv-' + df.iloc[0, 0])
+
+         return df
+
+     async def start_process(self, image_path, TD_THRESHOLD, TSR_THRESHOLD,
+                             OCR_THRESHOLD, padd_top, padd_left, padd_bottom,
+                             padd_right, delta_xmin, delta_ymin, delta_xmax,
+                             delta_ymax, expand_rowcol_bbox_top,
+                             expand_rowcol_bbox_bottom):
+         '''
+         Initiates the process of generating pandas dataframes from raw pdf-page images
+         '''
+         image = Image.open(image_path).convert("RGB")
+         probas, bboxes_scaled = table_detector(image,
+                                                THRESHOLD_PROBA=TD_THRESHOLD)
+
+         if bboxes_scaled.nelement() == 0:
+             st.write('No table found in the pdf-page image')
+             return ''
+
+         c1, c2, c3 = st.columns((1, 1, 1))
+
+         self.plot_results_detection(c1, table_detection_model, image, probas,
+                                     bboxes_scaled, delta_xmin, delta_ymin,
+                                     delta_xmax, delta_ymax)
+         cropped_img_list = self.crop_tables(image, probas, bboxes_scaled,
+                                             delta_xmin, delta_ymin, delta_xmax,
+                                             delta_ymax)
+
+         for idx, unpadded_table in enumerate(cropped_img_list):
+             table = self.add_padding(unpadded_table, padd_top, padd_right,
+                                      padd_bottom, padd_left)
+             # optional pre-processing variants:
+             # table = super_res(table)
+             # table = binarizeBlur_image(table)
+             # table = sharpen_image(table)
+             # table = td_postprocess(table)
+
+             probas, bboxes_scaled = table_struct_recog(
+                 table, THRESHOLD_PROBA=TSR_THRESHOLD)
+             rows, cols = self.generate_structure(c2, table_recognition_model,
+                                                  table, probas, bboxes_scaled,
+                                                  expand_rowcol_bbox_top,
+                                                  expand_rowcol_bbox_bottom)
+             rows, cols = self.sort_table_featuresv2(rows, cols)
+             master_row, cols = self.individual_table_featuresv2(
+                 table, rows, cols)
+
+             cells_img, max_cols, max_rows = self.object_to_cellsv2(
+                 master_row, cols, expand_rowcol_bbox_top,
+                 expand_rowcol_bbox_bottom, padd_left)
+
+             # collect one OCR coroutine per cell, then run them together
+             sequential_cell_img_list = []
+             for k, img_list in cells_img.items():
+                 for img in img_list:
+                     sequential_cell_img_list.append(
+                         pytess(cell_pil_img=img, threshold=OCR_THRESHOLD))
+
+             cell_ocr_res = await asyncio.gather(*sequential_cell_img_list)
+
+             self.create_dataframe(c3, cell_ocr_res, max_cols, max_rows)
+             st.write(
+                 'Errors in the OCR output are due to either the image quality or the limits of the OCR model'
+             )
+
+
+ if __name__ == "__main__":
+
+     img_name = st.file_uploader("Upload an image with table(s)")
+     st1, st2, st3 = st.columns((1, 1, 1))
+     TD_th = st1.slider('Table detection threshold', 0.0, 1.0, 0.8)
+     TSR_th = st2.slider('Table structure recognition threshold', 0.0, 1.0, 0.7)
+     OCR_th = st3.slider("Text probability threshold", 0.0, 1.0, 0.5)
+
+     st1, st2, st3, st4 = st.columns((1, 1, 1, 1))
+
+     padd_top = st1.slider('Padding top', 0, 200, 90)
+     padd_left = st2.slider('Padding left', 0, 200, 40)
+     padd_right = st3.slider('Padding right', 0, 200, 40)
+     padd_bottom = st4.slider('Padding bottom', 0, 200, 90)
+
+     te = TableExtractionPipeline()
+     if img_name is not None:
+         asyncio.run(
+             te.start_process(img_name,
+                              TD_THRESHOLD=TD_th,
+                              TSR_THRESHOLD=TSR_th,
+                              OCR_THRESHOLD=OCR_th,
+                              padd_top=padd_top,
+                              padd_left=padd_left,
+                              padd_bottom=padd_bottom,
+                              padd_right=padd_right,
+                              delta_xmin=10,  # add offset to the left of the table
+                              delta_ymin=3,  # add offset to the bottom of the table
+                              delta_xmax=10,  # add offset to the right of the table
+                              delta_ymax=3,  # add offset to the top of the table
+                              expand_rowcol_bbox_top=0,
+                              expand_rowcol_bbox_bottom=0))
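
To try the app locally, install `requirements.txt` and launch it with `streamlit run app.py`; the sliders control the detection/TSR/OCR thresholds and the padding added around each detected table before structure recognition.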
components/callbacks.py ADDED
@@ -0,0 +1,4 @@
+ # Define callbacks here
+ from pytorch_lightning.callbacks import EarlyStopping
+
+ early_stopping = EarlyStopping(monitor="loss", min_delta=0, patience=3)
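
The callback is not wired to a trainer anywhere in this commit; the usual pattern would be to pass it as `Trainer(callbacks=[early_stopping])`.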
components/data_module.py ADDED
@@ -0,0 +1,81 @@
+ from typing import Callable, List, Optional, Union
+
+ import torch
+ from pytorch_lightning import LightningDataModule
+ from torch.utils.data import DataLoader, Dataset
+
+
+ class SampleDataset(Dataset):
+
+     def __init__(self,
+                  x: Union[List, torch.Tensor],
+                  y: Union[List, torch.Tensor],
+                  transforms: Optional[Callable] = None) -> None:
+         super(SampleDataset, self).__init__()
+         self.x = x
+         self.y = y
+
+         if transforms is None:
+             # Replace None with some default transforms
+             # If image, could be a Resize and ToTensor
+             self.transforms = lambda x: x
+         else:
+             self.transforms = transforms
+
+     def __len__(self):
+         return len(self.x)
+
+     def __getitem__(self, index: int):
+         x = self.x[index]
+         y = self.y[index]
+
+         x = self.transforms(x)
+         return x, y
+
+
+ class SampleDataModule(LightningDataModule):
+
+     def __init__(self,
+                  x: Union[List, torch.Tensor],
+                  y: Union[List, torch.Tensor],
+                  transforms: Optional[Callable] = None,
+                  val_ratio: float = 0,
+                  batch_size: int = 32) -> None:
+         super(SampleDataModule, self).__init__()
+         assert 0 <= val_ratio < 1
+         assert isinstance(batch_size, int)
+         self.x = x
+         self.y = y
+
+         self.transforms = transforms
+         self.val_ratio = val_ratio
+         self.batch_size = batch_size
+
+         self.setup()
+         self.prepare_data()
+
+     def setup(self, stage: Optional[str] = None) -> None:
+         pass
+
+     def prepare_data(self) -> None:
+         n_samples: int = len(self.x)
+         train_size: int = n_samples - int(n_samples * self.val_ratio)
+
+         self.train_dataset = SampleDataset(x=self.x[:train_size],
+                                            y=self.y[:train_size],
+                                            transforms=self.transforms)
+         # hold out the tail for validation; if val_ratio == 0, fall back to
+         # the last batch so val_dataloader still has data
+         if train_size < n_samples:
+             self.val_dataset = SampleDataset(x=self.x[train_size:],
+                                              y=self.y[train_size:],
+                                              transforms=self.transforms)
+         else:
+             self.val_dataset = SampleDataset(x=self.x[-self.batch_size:],
+                                              y=self.y[-self.batch_size:],
+                                              transforms=self.transforms)
+
+     def train_dataloader(self) -> DataLoader:
+         return DataLoader(dataset=self.train_dataset,
+                           batch_size=self.batch_size)
+
+     def val_dataloader(self) -> DataLoader:
+         return DataLoader(dataset=self.val_dataset, batch_size=self.batch_size)
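
A minimal usage sketch of the module above (random tensors as stand-in data; shapes and names are illustrative):

```python
import torch

from components.data_module import SampleDataModule

x = torch.randn(100, 8)          # 100 samples, 8 features
y = torch.randint(0, 2, (100,))  # binary labels
dm = SampleDataModule(x=x, y=y, val_ratio=0.2, batch_size=16)

batch_x, batch_y = next(iter(dm.train_dataloader()))
print(batch_x.shape)  # torch.Size([16, 8])
```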
config.py ADDED
@@ -0,0 +1,3 @@
+ from dynaconf import Dynaconf
+
+ CFG = Dynaconf(envvar_prefix="DYNACONF", settings_files=["config/config.yaml"])
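
Values from `config/config.yaml` (currently empty) become attributes on `CFG`, and environment variables prefixed with `DYNACONF_` override them. A hypothetical example (the `model_name` key is illustrative, not part of this commit):

```python
# given config/config.yaml containing:
#   model_name: microsoft/table-transformer-detection
from config import CFG

print(CFG.model_name)  # "microsoft/table-transformer-detection"
# DYNACONF_MODEL_NAME=... in the environment would take precedence
```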
config/config.yaml ADDED
File without changes
data/.gitkeep ADDED
File without changes
docker-compose.yml ADDED
@@ -0,0 +1,23 @@
+ version: "3.7"
+
+ services:
+   model_name:
+     build:
+       context: .
+       dockerfile: .docker/Dockerfile
+     container_name: model_name
+     ports:
+       - "8996:8996"
+     env_file:
+       - ./.env
+     volumes:
+       - ./data:/home/working/data:ro
+
+     # This part is used to enable GPU support
+     deploy:
+       resources:
+         reservations:
+           devices:
+             - driver: nvidia
+               count: 1
+               capabilities: [ gpu ]
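
With a `.docker/Dockerfile` and a `.env` file in place, `docker compose up --build` builds the image and exposes the service on port 8996; the `deploy` block reserves one NVIDIA GPU, which requires the NVIDIA container toolkit on the host.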
models/__init__.py ADDED
File without changes
models/base_model/classification.py ADDED
@@ -0,0 +1,63 @@
+ from abc import abstractmethod
+ from typing import Any, Dict, List
+
+ import torch
+ from pytorch_lightning import LightningModule
+ from torch import Tensor
+
+
+ class LightningClassification(LightningModule):
+
+     @abstractmethod
+     def __init__(self, *args, **kwargs) -> None:
+         super(LightningClassification, self).__init__(*args, **kwargs)
+         self.train_batch_output: List[Dict] = []
+         self.validation_batch_output: List[Dict] = []
+         self.log_value_list: List[str] = ['loss', 'f1', 'precision', 'recall']
+
+     @abstractmethod
+     def forward(self, *args, **kwargs) -> Any:
+         pass
+
+     @abstractmethod
+     def configure_optimizers(self):
+         pass
+
+     @abstractmethod
+     def loss(self, input: Tensor, target: Tensor, **kwargs) -> Tensor:
+         pass
+
+     @abstractmethod
+     def training_step(self, batch, batch_idx):
+         pass
+
+     def __average(self, key: str, outputs: List[Dict]) -> Tensor:
+         target_arr = torch.Tensor([val[key] for val in outputs]).float()
+         return target_arr.mean()
+
+     @torch.no_grad()
+     def on_train_epoch_start(self) -> None:
+         self.train_batch_output = []
+
+     @torch.no_grad()
+     def on_train_epoch_end(self) -> None:
+         for key in self.log_value_list:
+             val = self.__average(key=key, outputs=self.train_batch_output)
+             log_name = f"training/{key}"
+             self.log(name=log_name, value=val)
+
+     @abstractmethod
+     @torch.no_grad()
+     def validation_step(self, batch, batch_idx):
+         pass
+
+     @torch.no_grad()
+     def on_validation_epoch_start(self) -> None:
+         self.validation_batch_output = []
+
+     @torch.no_grad()
+     def on_validation_epoch_end(self) -> None:
+         for key in self.log_value_list:
+             val = self.__average(key=key, outputs=self.validation_batch_output)
+             log_name = f"val/{key}"
+             self.log(name=log_name, value=val)
models/base_model/gan.py ADDED
File without changes
models/base_model/regression.py ADDED
@@ -0,0 +1,55 @@
+ from abc import abstractmethod
+ from typing import Any, Dict, List
+
+ import torch
+ from pytorch_lightning import LightningModule
+ from torch import Tensor
+
+
+ class LightningRegression(LightningModule):
+
+     @abstractmethod
+     def __init__(self, *args, **kwargs) -> None:
+         super(LightningRegression, self).__init__(*args, **kwargs)
+         self.train_step_output: List[Dict] = []
+         self.validation_step_output: List[Dict] = []
+         self.log_value_list: List[str] = ['loss', 'mse', 'mape']
+
+     @abstractmethod
+     def forward(self, *args, **kwargs) -> Any:
+         pass
+
+     @abstractmethod
+     def configure_optimizers(self):
+         pass
+
+     @abstractmethod
+     def loss(self, input: Tensor, output: Tensor, **kwargs):
+         return 0
+
+     @abstractmethod
+     def training_step(self, batch, batch_idx):
+         pass
+
+     def __average(self, key: str, outputs: List[Dict]) -> Tensor:
+         target_arr = torch.Tensor([val[key] for val in outputs]).float()
+         return target_arr.mean()
+
+     @torch.no_grad()
+     def on_train_epoch_start(self) -> None:
+         # reset the per-epoch buffer, mirroring LightningClassification
+         self.train_step_output = []
+
+     @torch.no_grad()
+     def on_train_epoch_end(self) -> None:
+         for key in self.log_value_list:
+             val = self.__average(key=key, outputs=self.train_step_output)
+             log_name = f"training/{key}"
+             self.log(name=log_name, value=val)
+
+     @torch.no_grad()
+     @abstractmethod
+     def validation_step(self, batch, batch_idx):
+         pass
+
+     @torch.no_grad()
+     def on_validation_epoch_start(self) -> None:
+         self.validation_step_output = []
+
+     @torch.no_grad()
+     def on_validation_epoch_end(self) -> None:
+         for key in self.log_value_list:
+             val = self.__average(key=key, outputs=self.validation_step_output)
+             log_name = f"val/{key}"
+             self.log(name=log_name, value=val)
models/metrics/classification.py ADDED
@@ -0,0 +1,44 @@
+ from typing import Dict
+
+ import torch
+ from torchmetrics import functional as FM
+
+
+ def classification_metrics(
+         preds: torch.Tensor,
+         target: torch.Tensor,
+         num_classes: int,
+         average: str = 'macro',
+         task: str = 'multiclass') -> Dict[str, torch.Tensor]:
+     """
+     classification_metrics
+     Return several metrics evaluating a classification task
+
+     Parameters
+     ----------
+     preds : torch.Tensor
+         logits or probabilities
+     target : torch.Tensor
+         target labels
+
+     Returns
+     -------
+     Dict[str, torch.Tensor]
+         dictionary with keys 'f1', 'precision' and 'recall'
+     """
+     f1 = FM.f1_score(preds=preds,
+                      target=target,
+                      num_classes=num_classes,
+                      task=task,
+                      average=average)
+     recall = FM.recall(preds=preds,
+                        target=target,
+                        num_classes=num_classes,
+                        task=task,
+                        average=average)
+     precision = FM.precision(preds=preds,
+                              target=target,
+                              num_classes=num_classes,
+                              task=task,
+                              average=average)
+     return dict(f1=f1, precision=precision, recall=recall)
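
A quick sanity check of the helper with dummy class indices (the values are illustrative; probabilities/logits also work as `preds`):

```python
import torch

from models.metrics.classification import classification_metrics

preds = torch.tensor([0, 2, 1, 2])   # predicted class indices
target = torch.tensor([0, 1, 1, 2])
print(classification_metrics(preds=preds, target=target, num_classes=3))
# {'f1': tensor(...), 'precision': tensor(...), 'recall': tensor(...)}
```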
models/metrics/regression.py ADDED
@@ -0,0 +1,28 @@
+ from typing import Dict
+
+ import torch
+ from torchmetrics import functional as FM
+
+
+ def regression_metrics(preds: torch.Tensor,
+                        target: torch.Tensor) -> Dict[str, torch.Tensor]:
+     """
+     regression_metrics
+     Return several metrics evaluating a regression task
+
+     Parameters
+     ----------
+     preds : torch.Tensor
+         predicted values
+     target : torch.Tensor
+         target values
+
+     Returns
+     -------
+     Dict[str, torch.Tensor]
+         dictionary with keys 'mse' and 'mape'
+     """
+     mse: torch.Tensor = FM.mean_squared_error(preds=preds, target=target)
+     mape: torch.Tensor = FM.mean_absolute_percentage_error(preds=preds,
+                                                            target=target)
+     return dict(mse=mse, mape=mape)
models/model_lit.py ADDED
@@ -0,0 +1,50 @@
+ from torch import Tensor, nn, optim
+ from torch.nn import functional as F
+
+ from .base_model.classification import LightningClassification
+ from .metrics.classification import classification_metrics
+ from .modules.sample_torch_module import UselessLayer
+
+
+ class UselessClassification(LightningClassification):
+
+     def __init__(self, n_classes: int, lr: float, **kwargs) -> None:
+         super(UselessClassification, self).__init__()
+         self.save_hyperparameters()
+         self.n_classes = n_classes
+         self.lr = lr
+         self.main = nn.Sequential(UselessLayer(), nn.GELU())
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.main(x)
+
+     def loss(self, input: Tensor, target: Tensor) -> Tensor:
+         return F.mse_loss(input=input, target=target)
+
+     def configure_optimizers(self):
+         optimizer = optim.Adam(params=self.parameters(), lr=self.lr)
+         return optimizer
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+
+         logits = self.forward(x)
+         # compare predictions (not the raw input) against the targets
+         loss = self.loss(input=logits, target=y)
+         metrics = classification_metrics(preds=logits,
+                                          target=y,
+                                          num_classes=self.n_classes)
+
+         self.train_batch_output.append({'loss': loss, **metrics})
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+
+         logits = self.forward(x)
+         loss = self.loss(input=logits, target=y)
+         metrics = classification_metrics(preds=logits,
+                                          target=y,
+                                          num_classes=self.n_classes)
+
+         self.validation_batch_output.append({'loss': loss, **metrics})
+         return loss
models/modules/sample_torch_module.py ADDED
@@ -0,0 +1,12 @@
+ from torch import Tensor, nn
+
+
+ class UselessLayer(nn.Module):
+
+     def __init__(self) -> None:
+         super(UselessLayer, self).__init__()
+         self.seq = nn.Identity()
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.seq(x)
+         return x
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ timm==0.9.2
+ # pip only honors index options on their own line, so point at the CPU wheels once:
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch
+ torchvision
+ torchaudio
+ vietocr==0.3.11
+ streamlit==1.21.0
+ pandas
+ transformers==4.29.1
+ Pillow==9.5.0
tests/test_resource.py ADDED
@@ -0,0 +1,4 @@
+ def test_cuda():
+     from torch.cuda import is_available
+     assert is_available()
utils/.gitkeep ADDED
File without changes