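# DocumentQA / DiT_Extractor / dit_runner.py
# (Repository-viewer residue kept for provenance: uploaded by "Epoching",
#  commit "init" c14d9ad, file size 5.56 kB; viewer links "raw", "history", "blame".)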
# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
import cv2
from pathlib import Path
import torch
import json
from detectron2.config import CfgNode as CN
from detectron2.config import get_cfg
from detectron2.utils.visualizer import ColorMode, Visualizer
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from pdf2image import convert_from_path
from PIL import Image
import numpy as np
from dit_object_detection.ditod import add_vit_config
import base_utils
from pdfminer.layout import LTTextLineHorizontal, LTTextBoxHorizontal, LTAnno, LTChar
from tokenizers.pre_tokenizers import Whitespace
import warnings
warnings.filterwarnings("ignore")

# --- DiT layout detector, built once at import time --------------------------
# The config path is relative to the current working directory.
dit_path = Path('DiT_Extractor') / 'dit_object_detection'
cfg = get_cfg()
add_vit_config(cfg)
cfg.merge_from_file(dit_path / "publaynet_configs" / "cascade" / "cascade_dit_base.yaml")
cfg.MODEL.WEIGHTS = "https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_cascade.pth"
if torch.cuda.is_available():
    cfg.MODEL.DEVICE = "cuda"
else:
    cfg.MODEL.DEVICE = "cpu"
predictor = DefaultPredictor(cfg)

# Label names indexed by predicted class id, plus the name -> id lookup.
thing_classes = ["text", "title", "list", "table", "figure"]
thing_map = {name: class_id for class_id, name in enumerate(thing_classes)}
# Register the label names on the test dataset's metadata so the
# visualizer can print them next to each detection.
md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
md.set(thing_classes=thing_classes)
def get_pdf_image(pdf_file, page, dpi=200):
    """Rasterize a single page of a PDF.

    Args:
        pdf_file: path to the PDF file.
        page: 1-based page number; used as both ``first_page`` and
            ``last_page`` so that only that page is rendered.
        dpi: rendering resolution. Defaults to 200, the value that was
            previously hard-coded.

    Returns:
        The list produced by ``pdf2image.convert_from_path``. For a valid
        page number it holds exactly one PIL image, so callers index ``[0]``.
    """
    return convert_from_path(pdf_file, dpi=dpi, first_page=page, last_page=page)
def get_characters(subelement):
    """Return ``(bbox, text)`` pairs for every character of a horizontal text line.

    ``LTChar`` entries keep their own bounding box. ``LTAnno`` entries have
    no box of their own; pdfminer inserts them for virtual characters such as
    spaces or line breaks. Each one gets a zero-width box placed at the right
    edge of the preceding character. If no character precedes it, the box is
    placed at the line's left edge instead.

    Args:
        subelement: a pdfminer layout object. Anything other than an
            ``LTTextLineHorizontal`` yields an empty list.

    Returns:
        A list of ``((x0, y0, x1, y1), text)`` tuples in line order.
    """
    all_chars = []
    if not isinstance(subelement, LTTextLineHorizontal):
        return all_chars
    for char in subelement:
        if isinstance(char, LTChar):
            all_chars.append((char.bbox, char.get_text()))
        elif isinstance(char, LTAnno):
            if all_chars:
                # Thin vertical slice just after the previous character.
                _, prev_y0, prev_x1, prev_y1 = all_chars[-1][0]
                bbox = (prev_x1, prev_y0, prev_x1, prev_y1)
            else:
                # No preceding character to anchor on (indexing all_chars[-1]
                # would raise IndexError), so anchor at the line's left edge.
                x0, y0, _, y1 = subelement.bbox
                bbox = (x0, y0, x0, y1)
            all_chars.append((bbox, char.get_text()))
    return all_chars
def _render_predictions(image, instances):
    """Draw every detection in *instances* over *image*; return a PIL image.

    NOTE(review): ``image`` is an array made from a pdf2image page (presumably
    RGB). It is channel-reversed before drawing and reversed back afterwards,
    while the predictor receives it unreversed. Confirm this matches the
    channel order expected by ``cfg.INPUT.FORMAT`` and by the Visualizer.
    """
    visualizer = Visualizer(image[:, :, ::-1],
                            md,
                            scale=1.0,
                            instance_mode=ColorMode.SEGMENTATION)
    drawn = visualizer.draw_instance_predictions(instances)
    return Image.fromarray(drawn.get_image()[:, :, ::-1])


def _image_box_to_pdf(box_t, image_height, scale_box):
    """Convert an image-pixel ``(x0, y0, x1, y1)`` box to page coordinates.

    The y values are flipped against the image height, then swapped so that
    the lower y comes first. After that, the box is scaled by
    ``(sx, sy, sx, sy)``. The box tensor itself is not modified.
    """
    x0, y0, x1, y1 = box_t.tolist()
    flipped = np.array([x0, image_height - y1, x1, image_height - y0])
    return (flipped * scale_box).tolist()


def _collect_block(words, region, block_id):
    """Gather the words that partially overlap *region* into one content block.

    Matched words are removed from *words* in place, so later detections on
    the same page cannot claim them again.

    Returns:
        The block dict, or ``{}`` when no word overlaps *region*.
    """
    block = {}
    unmatched = []
    for word in words:
        # Assumes word = (x0, y0, x1, y1, text_line, ...) as produced by
        # base_utils.get_pdf_words — TODO confirm.
        if not base_utils.partial_overlaps(word[0:4], region):
            unmatched.append(word)
            continue
        if not block:
            block = {
                'coord': word[0:4],
                'subelements': [],
                'type': 'content_block',
                'id': block_id,
                'text': '',
            }
        block['coord'] = base_utils.union(block['coord'], word[0:4])
        block['text'] += word[4].get_text()
        block['subelements'].append(get_characters(word[4]))
    # One O(n) rebuild instead of list.remove() per matched word (O(n^2)).
    words[:] = unmatched
    return block


def get_dit_preds(pdf, score_threshold=0.5):
    """Detect text blocks on every page of *pdf* and attach the PDF words inside them.

    Each page is rasterized and run through the module-level DiT
    ``predictor``. Every ``text`` detection scoring at least
    *score_threshold* becomes a ``content_block`` holding the words from
    ``base_utils.get_pdf_words`` that partially overlap it. A word is claimed
    by the first such detection, in predictor output order.

    The blocks are written as JSON to ``<pdf stem>.json`` in the current
    working directory. The layout is ``{page_index: [block, ...]}`` with
    0-based page indices, serialized as string keys by ``json``.

    Args:
        pdf: path to the PDF file.
        score_threshold: detections scoring below this value produce no
            block. They are still drawn in the visualization images.

    Returns:
        One PIL visualization image per page, with all detections drawn.
    """
    page_count = base_utils.get_pdf_page_count(pdf)
    page_sizes = base_utils.get_page_sizes(pdf)
    sections = {}
    viz_images = []
    page_words = base_utils.get_pdf_words(pdf)
    text_class = thing_map['text']

    for page_index in range(page_count):
        # pdf2image page numbers are 1-based.
        image = np.array(get_pdf_image(pdf, page_index + 1)[0])
        output = predictor(image)["instances"].to('cpu')
        viz_images.append(_render_predictions(image, output))

        words = page_words[page_index]
        # output.image_size is (height, width); the page size slice is used as
        # (width, height), so the scale is per-axis page units per pixel.
        image_height, image_width = output.image_size
        pdf_dimensions = page_sizes[page_index][2:]
        scale = np.array(pdf_dimensions) / np.array((image_width, image_height))
        scale_box = np.hstack((scale, scale))

        blocks = []
        block_id = 0
        detections = zip(output.get('pred_boxes'),
                         output.get('pred_classes'),
                         output.get('scores'))
        for box_t, clazz, score in detections:
            if score < score_threshold or clazz != text_class:
                continue
            region = _image_box_to_pdf(box_t, image_height, scale_box)
            block = _collect_block(words, region, block_id + 1)
            if block:
                block_id += 1
                blocks.append(block)
        sections[page_index] = blocks

    # Write final annotation. Use .stem rather than slicing off four
    # characters, which broke for names without a ".pdf" extension.
    out_name = Path(pdf).stem + ".json"
    with open(out_name, 'w', encoding='utf8') as json_out:
        json.dump(sections, json_out, ensure_ascii=False, indent=4)
    return viz_images