|
import os
from collections import defaultdict
from typing import Dict, List

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from transformers import LayoutLMv3ForTokenClassification
from ultralytics import RTDETR
|
|
|
|
|
# Hub id of the LayoutLMv3 token-classification checkpoint whose logits are
# decoded into a reading order by get_orders() / parse_logits().
_LAYOUT_MODEL_CHECKPOINT = "omarelsayeed/LayoutReader80Small"

# Loaded once at import time and shared by every request.
layout_model = LayoutLMv3ForTokenClassification.from_pretrained(
    _LAYOUT_MODEL_CHECKPOINT
)
|
|
|
MAX_LEN = 100

# Token ids used to build the model input sequence.
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2


def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
    """Build a batch-of-one model input from a list of boxes.

    The token sequence is ``CLS``, one ``UNK`` per box, then ``EOS``. The two
    special tokens are paired with an all-zero box and every position is
    marked as attended.

    Args:
        boxes: One ``[x_min, y_min, x_max, y_max]`` list per layout element.

    Returns:
        A dict with ``"bbox"``, ``"attention_mask"`` and ``"input_ids"``
        tensors, each with a leading batch dimension of 1.
    """
    n_boxes = len(boxes)
    zero_box = [0, 0, 0, 0]

    padded_boxes = [zero_box] + boxes + [zero_box]
    token_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID for _ in range(n_boxes)] + [EOS_TOKEN_ID]
    mask = [1] * (n_boxes + 2)

    return {
        "bbox": torch.tensor([padded_boxes]),
        "attention_mask": torch.tensor([mask]),
        "input_ids": torch.tensor([token_ids]),
    }
|
|
|
def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
    """Decode a logit matrix into one distinct order index per element.

    Row 0 of ``logits`` is skipped and only the first ``length`` columns are
    considered. Every element first claims its highest-scoring position;
    whenever several elements claim the same position, the one with the
    highest logit keeps it and the others fall back to their next-best
    remaining candidate. This repeats until no position is claimed twice.

    Args:
        logits: 2-D tensor of scores (rows = sequence positions,
            columns = candidate order indices).
        length: Number of elements to order.

    Returns:
        A list where entry ``i`` is the order index assigned to element ``i``.
    """
    scores = logits[1: length + 1, :length]

    # Per element: candidate positions sorted worst -> best, so that pop()
    # always yields the best candidate that has not been tried yet.
    candidates = scores.argsort(descending=False).tolist()
    assigned = [row.pop() for row in candidates]

    while True:
        claimants = defaultdict(list)
        for element, position in enumerate(assigned):
            claimants[position].append(element)

        contested = {
            position: elements
            for position, elements in claimants.items()
            if len(elements) > 1
        }
        if not contested:
            return assigned

        for position, elements in contested.items():
            # Strongest claim first; everyone after it must move on.
            ranked = sorted(
                elements,
                key=lambda element: scores[element, position],
                reverse=True,
            )
            for loser in ranked[1:]:
                assigned[loser] = candidates[loser].pop()
|
|
|
def get_orders(image_path, boxes):
    """Predict the reading order of the detected boxes.

    The boxes are normalised to a 0-1000 grid relative to the *actual* size
    of the source image (they are expressed in that image's pixel
    coordinates, the same coordinates used to draw them), fed to
    ``layout_model`` and the resulting logits are decoded with
    ``parse_logits``.

    Args:
        image_path: Path / file object accepted by ``PIL.Image.open`` or an
            already loaded ``PIL.Image.Image``; only its size is used.
        boxes: List of ``[x_min, y_min, x_max, y_max]`` boxes in pixel
            coordinates of the source image.

    Returns:
        A list with one order index per box (empty when ``boxes`` is empty).
    """
    if not boxes:
        # Nothing to order; skip the model call entirely.
        return []

    if isinstance(image_path, Image.Image):
        width, height = image_path.size
    else:
        with Image.open(image_path) as source_image:
            width, height = source_image.size

    normalized = scale_and_normalize_boxes(boxes, old_width=width, old_height=height)
    # Guard against boxes that poke slightly outside the image so that every
    # coordinate stays inside the 0-1000 normalisation range.
    normalized = [[min(max(coord, 0), 1000) for coord in box] for box in normalized]

    inputs = boxes2inputs(normalized)
    inputs = {k: v.to(layout_model.device) for k, v in inputs.items()}
    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        logits = layout_model(**inputs).logits.cpu().squeeze(0)
    return parse_logits(logits, len(normalized))
|
|
|
|
|
# Download the detector repository from the Hugging Face Hub and point at the
# weights file inside the snapshot. os.path.join keeps the path a plain str
# while avoiding hand-built separators.
model_dir = os.path.join(
    snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS"),
    "rtdetr_1024_crops.pt",
)

# Layout detector used by detect_layout().
model = RTDETR(model_dir)
|
|
|
|
|
def detect_layout(img, conf_threshold, iou_threshold):
    """Detect layout regions in an image with the module-level RT-DETR model.

    Args:
        img: Image source forwarded unchanged to ``model.predict``.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold forwarded to ``model.predict``.

    Returns:
        A ``(bboxes, classes)`` tuple: ``bboxes`` is a list of
        ``[x1, y1, x2, y2]`` lists and ``classes`` the matching list of
        class-name strings.

    Raises:
        KeyError: If the model reports a class id outside 0-9.
    """
    results = model.predict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=1024,
        agnostic_nms=True,
        max_det=34,
        nms=True,
    )[0]

    bboxes = results.boxes.xyxy.cpu().tolist()
    class_ids = results.boxes.cls.cpu().tolist()

    mapping = {
        0: 'CheckBox',
        1: 'List',
        2: 'P',
        3: 'abandon',
        4: 'figure',
        5: 'gridless_table',
        6: 'handwritten_signature',
        7: 'qr_code',
        8: 'table',
        9: 'title',
    }
    # The class tensor typically converts to floats (e.g. 0.0); cast so the
    # lookup does not silently rely on float/int hash equality.
    classes = [mapping[int(class_id)] for class_id in class_ids]
    return bboxes, classes
|
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
|
def draw_bboxes_on_image(image_path, bboxes, classes, reading_order):
    """Render labelled, colour-coded boxes on top of an image.

    Each box is outlined in its class colour and tagged with
    ``"<class>-<order>"`` on a semi-transparent black strip placed just above
    the box. ``title`` boxes get a thicker outline and a larger font.

    Args:
        image_path: Path / file object accepted by ``PIL.Image.open`` or an
            already loaded ``PIL.Image.Image``.
        bboxes: List of ``[x1, y1, x2, y2]`` boxes in image pixel coordinates.
        classes: Class name for each box.
        reading_order: Order index for each box.

    Returns:
        A new RGB ``PIL.Image.Image`` with the annotations composited in.
    """
    class_colors = {
        'CheckBox': '#FFA500',
        'List': '#4682B4',
        'P': '#32CD32',
        'abandon': '#8A2BE2',
        'figure': '#00CED1',
        'gridless_table': '#FFD700',
        'handwritten_signature': '#FF69B4',
        'qr_code': '#FF4500',
        'table': '#8B4513',
        'title': '#FF1493',
    }

    # The UI may hand over an in-memory image rather than a path, and
    # Image.open() cannot read from an Image object.
    if isinstance(image_path, Image.Image):
        image = image_path.convert("RGBA")
    else:
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGBA")

    # Draw on a transparent layer so the label background can be translucent.
    overlay = Image.new("RGBA", image.size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(overlay)

    try:
        font = ImageFont.truetype("arial.ttf", 20)
        title_font = ImageFont.truetype("arial.ttf", 30)
    except IOError:
        # Arial is not installed everywhere; fall back to the bundled font.
        font = ImageFont.load_default()
        title_font = font

    padding = 5
    for (x1, y1, x2, y2), class_name, order in zip(bboxes, classes, reading_order):
        color = class_colors.get(class_name, "#FFFFFF")

        if class_name == 'title':
            box_thickness = 6
            label_font = title_font
        else:
            box_thickness = 3
            label_font = font

        draw.rounded_rectangle(
            [(x1, y1), (x2, y2)],
            radius=10,
            outline=color,
            width=box_thickness,
        )

        label = f"{class_name}-{order}"
        # ImageDraw.textsize() was removed in Pillow 10; textbbox() is its
        # documented replacement.
        left, top, right, bottom = draw.textbbox((0, 0), label, font=label_font)
        text_width = right - left
        text_height = bottom - top

        label_bg_coords = [
            x1, y1 - text_height - 2 * padding,
            x1 + text_width + 2 * padding, y1,
        ]
        draw.rectangle(label_bg_coords, fill=(0, 0, 0, 128))

        draw.text(
            (x1 + padding, y1 - text_height - padding),
            label,
            fill="white",
            font=label_font,
        )

    image_with_overlay = Image.alpha_composite(image, overlay)
    return image_with_overlay.convert("RGB")
|
|
|
|
|
|
|
def scale_and_normalize_boxes(bboxes, old_width=1024, old_height=1024, new_width=640, new_height=640, normalize_width=1000, normalize_height=1000):
    """Rescale boxes to a new image size, then map them onto a fixed grid.

    Every coordinate is first multiplied by ``new / old`` and then expressed
    as a fraction of the new size times the normalisation extent; the result
    is truncated to ``int``.

    Args:
        bboxes (list of lists): Boxes in [x_min, y_min, x_max, y_max] format.
        old_width (int or float): Width of the original image.
        old_height (int or float): Height of the original image.
        new_width (int or float): Width of the scaled image.
        new_height (int or float): Height of the scaled image.
        normalize_width (int or float): Width of the normalization range.
        normalize_height (int or float): Height of the normalization range.

    Returns:
        list of lists: Scaled and normalized boxes in
        [x_min, y_min, x_max, y_max] format.
    """
    x_factor = new_width / old_width
    y_factor = new_height / old_height

    def _convert(value, factor, scaled_extent, grid_extent):
        # Same two-step arithmetic for every coordinate: scale, then normalise.
        scaled = value * factor
        return int(grid_extent * (scaled / scaled_extent))

    converted = []
    for bbox in bboxes:
        left, top, right, bottom = bbox
        converted.append([
            _convert(left, x_factor, new_width, normalize_width),
            _convert(top, y_factor, new_height, normalize_height),
            _convert(right, x_factor, new_width, normalize_width),
            _convert(bottom, y_factor, new_height, normalize_height),
        ])
    return converted
|
|
|
|
|
|
|
from PIL import Image, ImageDraw |
|
|
|
def is_inside(box1, box2):
    """Return True when ``box1`` lies entirely within ``box2`` (edges may touch).

    Both boxes are ``[x_min, y_min, x_max, y_max]`` sequences.
    """
    return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]


def is_overlap(box1, box2):
    """Return True when the two boxes share a region of non-zero area.

    Boxes that merely touch along an edge or corner do not overlap.
    """
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2
    return not (x2 <= x3 or x4 <= x1 or y2 <= y3 or y4 <= y1)


def remove_overlapping_and_inside_boxes(boxes, classes):
    """Drop boxes that are nested in, or overlap with, another box.

    For every pair of boxes: a box contained in the other is dropped; for a
    partial overlap the box with the smaller area is dropped. Ties (identical
    boxes, or overlapping boxes of equal area) keep the earlier box, so a
    tied pair is never removed in its entirety.

    ``boxes`` and ``classes`` are modified in place and also returned.

    Args:
        boxes: List of ``[x_min, y_min, x_max, y_max]`` boxes.
        classes: List of class labels, parallel to ``boxes``.

    Returns:
        The same ``(boxes, classes)`` list objects after filtering.
    """
    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    to_remove = set()
    # Visit each unordered pair once so that a tie cannot eliminate both boxes.
    for i, box1 in enumerate(boxes):
        for j in range(i + 1, len(boxes)):
            box2 = boxes[j]
            if is_inside(box2, box1):
                # Also true for identical boxes: the later duplicate goes.
                to_remove.add(j)
            elif is_inside(box1, box2):
                to_remove.add(i)
            elif is_overlap(box1, box2):
                if _area(box2) <= _area(box1):
                    to_remove.add(j)
                else:
                    to_remove.add(i)

    # Delete from the back so earlier indices stay valid.
    for idx in sorted(to_remove, reverse=True):
        del boxes[idx]
        del classes[idx]

    return boxes, classes
|
|
|
|
|
def full_predictions(IMAGE_PATH, conf_threshold, iou_threshold):
    """Run the whole pipeline on one image and return the annotated result.

    Steps: detect layout regions, drop nested/overlapping boxes, predict the
    reading order of the survivors, then draw everything on the image.

    Args:
        IMAGE_PATH: Image input, forwarded as-is to each pipeline stage.
        conf_threshold: Detection confidence threshold.
        iou_threshold: Detection IoU threshold.

    Returns:
        The annotated image produced by ``draw_bboxes_on_image``.
    """
    detected_boxes, detected_classes = detect_layout(IMAGE_PATH, conf_threshold, iou_threshold)
    kept_boxes, kept_classes = remove_overlapping_and_inside_boxes(detected_boxes, detected_classes)
    reading_order = get_orders(IMAGE_PATH, kept_boxes)
    return draw_bboxes_on_image(IMAGE_PATH, kept_boxes, kept_classes, reading_order)
|
|
|
|
|
|
|
# Gradio front end: one image plus two threshold sliders in, annotated image out.
iface = gr.Interface(
    fn=full_predictions,
    inputs=[
        # "filepath" because the pipeline treats its first argument as an
        # image path (it is passed to PIL.Image.open when drawing).
        gr.Image(type="filepath", label="Upload Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU threshold"),
    ],
    outputs=gr.Image(type="pil", label="Result"),
    title="Ultralytics Gradio",
    description=(
        "Upload a document image for inference. An RT-DETR model detects the "
        "layout regions and a LayoutReader model predicts their reading order."
    ),
    examples=[
        ["kashida.png", 0.2, 0.45],
        ["image.jpg", 0.2, 0.45],
        ["Screenshot 2024-11-06 130230.png", 0.25, 0.45],
    ],
    theme=gr.themes.Default(),
)


if __name__ == "__main__":
    iface.launch()