# Copyright (c) 2022, Lawrence Livermore National Security, LLC. 
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964

# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception

import json
from pathlib import Path

import numpy as np
import torch
from PIL import Image

from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer

from pdf2image import convert_from_path
from pdfminer.layout import LTAnno, LTChar, LTTextBoxHorizontal, LTTextLineHorizontal

from dit_object_detection.ditod import add_vit_config
import base_utils

import warnings
warnings.filterwarnings("ignore")

dit_path = Path('DiT_Extractor/dit_object_detection')

# Build the detectron2 config for the cascade DiT-base model fine-tuned on PubLayNet.
cfg = get_cfg()
add_vit_config(cfg)
cfg.merge_from_file(str(dit_path / "publaynet_configs/cascade/cascade_dit_base.yaml"))

# Weights are fetched from the remote URL on first use; fall back to CPU when no GPU is available.
cfg.MODEL.WEIGHTS = "https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_cascade.pth"
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

predictor = DefaultPredictor(cfg)

# PubLayNet category names, plus a reverse name -> index lookup.
thing_classes = ["text", "title", "list", "table", "figure"]
thing_map = {name: index for index, name in enumerate(thing_classes)}
md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
md.set(thing_classes=thing_classes)


def get_pdf_image(pdf_file, page):
    """Rasterize a single 1-based PDF page; returns a one-element list of PIL images."""
    image = convert_from_path(pdf_file, dpi=200, first_page=page, last_page=page)
    return image

def get_characters(subelement):
    """Return a list of (bbox, text) tuples for every character in a pdfminer text line."""
    all_chars = []
    if isinstance(subelement, LTTextLineHorizontal):
        for char in subelement:
            if isinstance(char, LTChar):
                all_chars.append((char.bbox, char.get_text()))
            elif isinstance(char, LTAnno) and all_chars:
                # LTAnno carries no bbox (it is an inserted space/newline), so give it
                # a zero-width box flush against the right edge of the previous character.
                bbox = all_chars[-1][0]
                bbox = (bbox[2], bbox[1], bbox[2], bbox[3])
                all_chars.append((bbox, char.get_text()))
    return all_chars
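

# Illustrative sketch (not invoked by this module): how get_characters is typically
# fed from pdfminer's high-level API. The helper name and the pdf_path argument are
# assumptions for demonstration, not part of the original pipeline.
def example_character_dump(pdf_path):
    from pdfminer.high_level import extract_pages
    for page_layout in extract_pages(pdf_path):
        for element in page_layout:
            if isinstance(element, LTTextBoxHorizontal):
                for line in element:
                    # Each entry is ((x0, y0, x1, y1), character_text).
                    print(get_characters(line))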


def get_dit_preds(pdf, score_threshold=0.5):
    """Run DiT layout detection on each page of `pdf`, group the page's words into
    the detected text blocks, write the result to <pdf stem>.json, and return one
    visualization image per page."""
    page_count = base_utils.get_pdf_page_count(pdf)

    # Page sizes are needed later to map image-pixel boxes back to PDF coordinates.
    page_sizes = base_utils.get_page_sizes(pdf)

    sections = {}
    viz_images = []
    page_words = base_utils.get_pdf_words(pdf)
    for page in range(1, page_count + 1):
        # The predictor expects a numpy array, so rasterize the page and convert.
        image = get_pdf_image(pdf, page)
        image = np.array(image[0])
        # Get prediction
        output = predictor(image)["instances"]
        output = output.to('cpu')
        
        # Visualize predictions
        v = Visualizer(image[:, :, ::-1],
                       md,
                       scale=1.0,
                       instance_mode=ColorMode.SEGMENTATION)
        result = v.draw_instance_predictions(output)
        result_image = result.get_image()[:, :, ::-1]
        viz_img = Image.fromarray(result_image)
        viz_images.append(viz_img)
        
        # Words from pdfminer are already in page (PDF) coordinates.
        words = page_words[page - 1]

        # Prediction boxes are in rendered-image pixels; compute the per-axis scale
        # back to PDF points. output.image_size is (height, width), so swap it to
        # (width, height) to match the PDF dimensions.
        pdf_dimensions = page_sizes[page - 1][2:]
        pdf_image_size = (output.image_size[1], output.image_size[0])

        scale = np.array(pdf_dimensions) / np.array(pdf_image_size)
        scale_box = np.hstack((scale, scale))

        block_id = 0
        sections[page - 1] = []
        for box_t, clazz, score in zip(output.get('pred_boxes'), output.get('pred_classes'), output.get('scores')):

            if score < score_threshold:
                continue
            # Only text blocks are grouped into sections; skip other categories.
            if clazz != thing_map['text']:
                continue

            box = box_t.numpy()
            # Image coordinates have a top-left origin, PDF a bottom-left one:
            # flip the box along the Y axis before scaling into PDF points.
            box[1] = pdf_image_size[1] - box[1]
            box[3] = pdf_image_size[1] - box[3]
            scaled = box * scale_box
            # The flip swapped the Y extents; reorder to (x0, y0, x1, y1).
            scaled = [scaled[0], scaled[3], scaled[2], scaled[1]]
            
            out = {}

            # Collect every word overlapping this detection box into one content
            # block, merging bounding boxes and concatenating text as we go.
            for word in words.copy():
                if base_utils.partial_overlaps(word[0:4], scaled):
                    if out == {}:
                        block_id += 1
                        out['coord'] = word[0:4]
                        out['subelements'] = []
                        out['type'] = 'content_block'
                        out['id'] = block_id
                        out['text'] = ''

                    out['coord'] = base_utils.union(out['coord'], word[0:4])
                    out['text'] = out['text'] + word[4].get_text()

                    characters = get_characters(word[4])
                    out['subelements'].append(characters)
                    # Remove the word so it cannot be claimed by a later box.
                    words.remove(word)

            if len(out) != 0:
                sections[page - 1].append(out)
                
    # Write the grouped blocks to <pdf stem>.json in the current working directory.
    out_name = Path(pdf).stem + ".json"
    with open(out_name, 'w', encoding='utf8') as json_out:
        json.dump(sections, json_out, ensure_ascii=False, indent=4)

    return viz_images
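

# Minimal usage sketch (assumes a local "example.pdf"; the file and output names
# below are illustrative, not part of the module).
if __name__ == "__main__":
    pages = get_dit_preds("example.pdf", score_threshold=0.5)
    # get_dit_preds has written example.json; save the per-page visualizations too.
    for page_number, viz in enumerate(pages, start=1):
        viz.save(f"example_page_{page_number}.png")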