# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception

import json
import warnings
from pathlib import Path

import cv2
import numpy as np
import torch
from detectron2.config import CfgNode as CN
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer
from pdf2image import convert_from_path
from pdfminer.layout import LTAnno, LTChar, LTTextBoxHorizontal, LTTextLineHorizontal
from PIL import Image
from tokenizers.pre_tokenizers import Whitespace

import base_utils
from dit_object_detection.ditod import add_vit_config

warnings.filterwarnings("ignore")

# Build the DiT (Document Image Transformer) predictor once at import time.
dit_path = Path('DiT_Extractor/dit_object_detection')

cfg = get_cfg()
add_vit_config(cfg)
cfg.merge_from_file(dit_path / "publaynet_configs/cascade/cascade_dit_base.yaml")
cfg.MODEL.WEIGHTS = "https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_cascade.pth"
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

predictor = DefaultPredictor(cfg)

# PubLayNet category names, plus a reverse map from name to class index.
thing_classes = ["text", "title", "list", "table", "figure"]
thing_map = dict(map(reversed, enumerate(thing_classes)))
md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
md.set(thing_classes=thing_classes)


def get_pdf_image(pdf_file, page):
    """Render a single PDF page at 200 DPI; returns a one-element list of PIL images."""
    image = convert_from_path(pdf_file, dpi=200, first_page=page, last_page=page)
    return image


def get_characters(subelement):
    """Collect (bbox, text) pairs for each character in a pdfminer text line."""
    all_chars = []
    if isinstance(subelement, LTTextLineHorizontal):
        for char in subelement:
            if isinstance(char, LTChar):
                all_chars.append((char.bbox, char.get_text()))
            if isinstance(char, LTAnno):
                # LTAnno has no bbox (it is an inserted space/newline), so make
                # a zero-width slice at the right edge of the previous character.
                if not all_chars:
                    # Guard: an LTAnno before any LTChar has no predecessor to
                    # anchor to, so skip it rather than index an empty list.
                    continue
                bbox = all_chars[-1][0]
                bbox = (bbox[2], bbox[1], bbox[2], bbox[3])
                all_chars.append((bbox, char.get_text()))
    return all_chars


def get_dit_preds(pdf, score_threshold=0.5):
    """Run DiT layout detection on every page of `pdf`.

    Writes the detected text blocks (with word and character subelements) to a
    JSON file in the current working directory and returns one annotated PIL
    image per page.
    """
    page_count = base_utils.get_pdf_page_count(pdf)
    page_sizes = base_utils.get_page_sizes(pdf)
    page_words = base_utils.get_pdf_words(pdf)
    sections = {}
    viz_images = []

    for page in range(1, page_count + 1):
        # Predictor input is a numpy array of the rendered page image.
        image = get_pdf_image(pdf, page)
        image = np.array(image[0])

        # Get predictions for this page.
        output = predictor(image)["instances"]
        output = output.to('cpu')

        # Visualize predictions.
        v = Visualizer(image[:, :, ::-1], md, scale=1.0,
                       instance_mode=ColorMode.SEGMENTATION)
        result = v.draw_instance_predictions(output)
        result_image = result.get_image()[:, :, ::-1]
        viz_images.append(Image.fromarray(result_image))

        words = page_words[page - 1]

        # Convert from image coordinates to PDF page coordinates.
        pdf_dimensions = page_sizes[page - 1][2:]
        # detectron2 reports image_size as (height, width); swap to (width, height).
        pdf_image_size = (output.image_size[1], output.image_size[0])
        scale = np.array(pdf_dimensions) / np.array(pdf_image_size)
        scale_box = np.hstack((scale, scale))

        block_id = 0
        sections[page - 1] = []
        for box_t, clazz, score in zip(output.get('pred_boxes'),
                                       output.get('pred_classes'),
                                       output.get('scores')):
            if score < score_threshold:
                continue
            if clazz != thing_map['text']:
                continue

            box = box_t.numpy()
            # Flip along the Y axis: the image origin is top-left, while the
            # PDF origin is bottom-left.
            box[1] = pdf_image_size[1] - box[1]
            box[3] = pdf_image_size[1] - box[3]
            # Scale into PDF units, then reorder so the box is
            # (x0, y0, x1, y1) with y0 < y1 in PDF coordinates.
            scaled = box * scale_box
            scaled = [scaled[0], scaled[3], scaled[2], scaled[1]]

            # Gather every word (already in PDF page coordinates) that
            # overlaps this detected text block.
            out = {}
            for word in words.copy():
                if base_utils.partial_overlaps(word[0:4], scaled):
                    if out == {}:
                        block_id += 1
                        out['coord'] = word[0:4]
                        out['subelements'] = []
                        out['type'] = 'content_block'
                        out['id'] = block_id
                        out['text'] = ''
                    out['coord'] = base_utils.union(out['coord'], word[0:4])
                    out['text'] = out['text'] + word[4].get_text()
                    characters = get_characters(word[4])
                    out['subelements'].append(characters)
                    words.remove(word)
            if len(out) != 0:
                sections[page - 1].append(out)

    # Write the final annotation.
    out_name = Path(pdf).stem + ".json"
    with open(out_name, 'w', encoding='utf8') as json_out:
        json.dump(sections, json_out, ensure_ascii=False, indent=4)

    return viz_images
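# Example usage (a minimal sketch of running this module directly). The file
# name "sample.pdf" is hypothetical; any text-based PDF should work, and the
# DiT weights URL above must be reachable on the first run.
if __name__ == "__main__":
    previews = get_dit_preds("sample.pdf", score_threshold=0.5)
    # get_dit_preds writes "sample.json" to the working directory and returns
    # one PIL image per page with the detected layout regions drawn on it.
    for page_num, img in enumerate(previews, start=1):
        img.save(f"sample_page{page_num}.png")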