from collections import defaultdict

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, LayoutLMv3ForTokenClassification
from ultralytics import RTDETR

# Reading-order model: a LayoutLMv3 token-classification head fine-tuned for
# Arabic reading-order prediction. Set quant=True to load the 4-bit variant
# (requires bitsandbytes).
quant = False

if quant:
    finetuned_fully = LayoutLMv3ForTokenClassification.from_pretrained(
        "omarelsayeed/ArabicLayoutReader_4bit", load_in_4bit=True
    )
else:
    finetuned_fully = LayoutLMv3ForTokenClassification.from_pretrained(
        "omarelsayeed/YARAB_FOK_ELDE2A"
    )

# LayoutLMv3 processor; OCR is disabled because bounding boxes are supplied directly.
processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

MAX_LEN = 70
# Special-token ids of the LayoutLMv3 tokenizer.
CLS_TOKEN_ID = 0  # <s>
UNK_TOKEN_ID = 3  # <unk>, the placeholder token used for every box
EOS_TOKEN_ID = 2  # </s>


def boxes2inputs(boxes):
    """Wrap normalized [x1, y1, x2, y2] boxes into LayoutLMv3 inputs: one <unk>
    placeholder token per box, framed by <s>/</s> carrying all-zero boxes."""
    bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
    attention_mask = [1] * (len(boxes) + 2)
    return {
        "bbox": torch.tensor([bbox]),
        "attention_mask": torch.tensor([attention_mask]),
        "input_ids": torch.tensor([input_ids]),
    }
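
# A quick sanity check (illustrative, not used by the app): three boxes produce
# a sequence of CLS + 3 placeholders + EOS, i.e. tensors of length 5.
#   ex = boxes2inputs([[10, 10, 50, 20], [10, 30, 50, 40], [10, 50, 50, 60]])
#   ex["input_ids"].shape  -> torch.Size([1, 5])
#   ex["bbox"].shape       -> torch.Size([1, 5, 4])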


def parse_logits(logits: torch.Tensor, length: int):
    """Decode token-classification logits into a reading order.

    Each of the `length` box tokens first claims its highest-scoring position;
    whenever several boxes claim the same position, the highest-scoring one
    keeps it and the rest fall back to their next-best candidates, repeating
    until the assignment is conflict-free.

    :param logits: logits from the model (including the CLS/EOS rows)
    :param length: number of box tokens in the input
    :return: predicted position for each box
    """
    # Drop the CLS/EOS rows and keep only the `length` valid position columns.
    logits = logits[1 : length + 1, :length]
    # Ascending argsort puts each row's best position last, so pop() claims it.
    orders = logits.argsort(descending=False).tolist()
    ret = [o.pop() for o in orders]
    while True:
        # Group box indices by the position they currently claim.
        order_to_idxes = defaultdict(list)
        for idx, order in enumerate(ret):
            order_to_idxes[order].append(idx)

        # Keep only positions claimed by more than one box.
        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
        if not order_to_idxes:
            break

        for order, idxes in order_to_idxes.items():
            # The box with the highest logit keeps the position; every other
            # contender falls back to its next-best remaining candidate.
            idxes_to_logit = {idx: logits[idx, order] for idx in idxes}
            idxes_to_logit = sorted(
                idxes_to_logit.items(), key=lambda x: x[1], reverse=True
            )
            for idx, _ in idxes_to_logit[1:]:
                ret[idx] = orders[idx].pop()
    return ret
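
# Worked example (illustrative): after the CLS/EOS rows are stripped, suppose
# the 2x2 score matrix is [[0.9, 0.1], [0.8, 0.2]]. Both boxes first claim
# position 0; box 0 wins (0.9 > 0.8), box 1 falls back to its next candidate,
# and the function returns [0, 1].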


def prepare_inputs(inputs, model):
    """Move the input tensors to the model's device (and, for floating-point
    tensors, its dtype)."""
    ret = {}
    for k, v in inputs.items():
        v = v.to(model.device)
        if torch.is_floating_point(v):
            v = v.to(model.dtype)
        ret[k] = v
    return ret


def get_orders(image_path, boxes):
    """Predict the reading order of `boxes` (normalized to 0-1000).

    `image_path` is kept for interface symmetry but unused: the reading-order
    model works on box geometry alone.
    """
    inputs = boxes2inputs(boxes)
    inputs = prepare_inputs(inputs, finetuned_fully)
    with torch.no_grad():
        logits = finetuned_fully(**inputs).logits.cpu().squeeze(0)
    return parse_logits(logits, len(boxes))
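
# Example (illustrative): two vertically stacked boxes in 0-1000 coordinates
# typically come back top-to-bottom:
#   get_orders(None, [[100, 100, 900, 200], [100, 300, 900, 400]])  # e.g. [0, 1]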


# Layout-detection model: RT-DETR weights fetched from the Hugging Face Hub.
model_path = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
model = RTDETR(model_path)


def detect_layout(img, conf_threshold, iou_threshold):
    """Detect layout regions with the RT-DETR model, using adjustable
    confidence and IoU thresholds; returns pixel-space boxes and class names."""
    results = model.predict(
        source=img,
        conf=conf_threshold,
        iou=iou_threshold,
        show_labels=True,
        show_conf=True,
        imgsz=1024,
        agnostic_nms=True,
        max_det=34,
        nms=True,
    )[0]
    bboxes = results.boxes.xyxy.cpu().tolist()
    classes = results.boxes.cls.cpu().tolist()
    mapping = {
        0: 'CheckBox',
        1: 'List',
        2: 'P',
        3: 'abandon',
        4: 'figure',
        5: 'gridless_table',
        6: 'handwritten_signature',
        7: 'qr_code',
        8: 'table',
        9: 'title',
    }
    classes = [mapping[int(i)] for i in classes]
    return bboxes, classes
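
# Example (illustrative): on a 1024x1024 page image,
#   bboxes, classes = detect_layout(page, 0.25, 0.45)
# returns parallel lists, e.g. bboxes[0] = [88.5, 132.0, 940.2, 210.7] in xyxy
# pixel coordinates and classes[0] = 'title'.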


def draw_bboxes_on_image(image, bboxes, classes, reading_order):
    """Draw class-colored boxes and "<class>-<order>" labels on a PIL image."""
    class_colors = {
        'CheckBox': 'orange',
        'List': 'blue',
        'P': 'green',
        'abandon': 'purple',
        'figure': 'cyan',
        'gridless_table': 'yellow',
        'handwritten_signature': 'magenta',
        'qr_code': 'red',
        'table': 'brown',
        'title': 'pink',
    }

    draw = ImageDraw.Draw(image)

    # Fall back to PIL's bundled font when Arial is unavailable.
    try:
        font = ImageFont.truetype("arial.ttf", 20)
        title_font = ImageFont.truetype("arial.ttf", 30)
    except IOError:
        font = ImageFont.load_default(size=30)
        title_font = font

    for i in range(len(bboxes)):
        x1, y1, x2, y2 = bboxes[i]
        class_name = classes[i]
        order = reading_order[i]
        color = class_colors[class_name]

        # Titles get a thicker outline and a larger label font.
        if class_name == 'title':
            box_thickness = 4
            label_font = title_font
        else:
            box_thickness = 2
            label_font = font

        draw.rectangle([x1, y1, x2, y2], outline=color, width=box_thickness)

        # Place the "<class>-<order>" label just above the box.
        label = f"{class_name}-{order}"
        bbox = draw.textbbox((x1, y1 - 20), label, font=label_font)
        label_height = bbox[3] - bbox[1]
        draw.text((x1, y1 - label_height), label, fill="black", font=label_font)

    return image


def scale_and_normalize_boxes(bboxes, old_width=1024, old_height=1024,
                              new_width=640, new_height=640,
                              normalize_width=1000, normalize_height=1000):
    """
    Scales and normalizes bounding boxes from original dimensions to new dimensions.

    Args:
        bboxes (list of lists): Bounding boxes in [x_min, y_min, x_max, y_max] format.
        old_width (int or float): Width of the original image.
        old_height (int or float): Height of the original image.
        new_width (int or float): Width of the scaled image.
        new_height (int or float): Height of the scaled image.
        normalize_width (int or float): Width of the normalization range
            (e.g., target resolution width).
        normalize_height (int or float): Height of the normalization range
            (e.g., target resolution height).

    Returns:
        list of lists: Scaled and normalized boxes in [x_min, y_min, x_max, y_max] format.
    """
    scale_x = new_width / old_width
    scale_y = new_height / old_height

    def scale_and_normalize_single(bbox):
        x_min, y_min, x_max, y_max = bbox

        # Rescale from the original image size to the new size...
        x_min *= scale_x
        y_min *= scale_y
        x_max *= scale_x
        y_max *= scale_y

        # ...then map into the 0-1000 range the reading-order model expects.
        x_min = int(normalize_width * (x_min / new_width))
        y_min = int(normalize_height * (y_min / new_height))
        x_max = int(normalize_width * (x_max / new_width))
        y_max = int(normalize_height * (y_max / new_height))

        return [x_min, y_min, x_max, y_max]

    return [scale_and_normalize_single(bbox) for bbox in bboxes]
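
# Worked example (illustrative): with the defaults, [512, 512, 1024, 1024] on a
# 1024x1024 page scales by 640/1024 = 0.625 to [320, 320, 640, 640], then maps
# by 1000/640 to [500, 500, 1000, 1000]. The intermediate new_width/new_height
# cancel out, so this equals normalizing straight from the original size.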


def is_inside(box1, box2):
    """True if box1 lies entirely within box2 (both in xyxy format)."""
    return (box1[0] >= box2[0] and box1[1] >= box2[1]
            and box1[2] <= box2[2] and box1[3] <= box2[3])


def is_overlap(box1, box2):
    """True if box1 and box2 intersect with positive area (xyxy format)."""
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2
    return not (x2 <= x3 or x4 <= x1 or y2 <= y3 or y4 <= y1)
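
# Examples for the two predicates above (illustrative):
#   is_inside([2, 2, 4, 4], [0, 0, 10, 10])   # True: fully contained
#   is_overlap([0, 0, 5, 5], [4, 4, 10, 10])  # True: shares the region (4,4)-(5,5)
#   is_overlap([0, 0, 5, 5], [5, 0, 10, 5])   # False: touching edges don't count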


def remove_overlapping_and_inside_boxes(boxes, classes):
    """Drop boxes contained inside another box; for partial overlaps, keep the
    larger box. Mutates and returns `boxes` and `classes` in place."""
    to_remove = []

    for i, box1 in enumerate(boxes):
        for j, box2 in enumerate(boxes):
            if i != j:
                if is_inside(box1, box2):
                    # box1 is nested inside box2: drop the inner box.
                    to_remove.append(i)
                elif is_inside(box2, box1):
                    to_remove.append(j)
                elif is_overlap(box1, box2):
                    # Partial overlap: drop whichever box has the smaller area.
                    if (box2[2] - box2[0]) * (box2[3] - box2[1]) < (box1[2] - box1[0]) * (box1[3] - box1[1]):
                        to_remove.append(j)
                    else:
                        to_remove.append(i)

    # Delete from the highest index down so earlier indices stay valid.
    for idx in sorted(set(to_remove), reverse=True):
        del boxes[idx]
        del classes[idx]

    return boxes, classes
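
# Example (illustrative): a box nested inside a table is dropped, and of two
# partially overlapping boxes only the larger survives:
#   boxes = [[0, 0, 100, 100], [10, 10, 20, 20], [90, 0, 150, 100]]
#   classes = ['table', 'P', 'P']
#   remove_overlapping_and_inside_boxes(boxes, classes)
#   # -> ([[0, 0, 100, 100]], ['table'])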


def process_r1(r1):
    """Heuristic post-processing of the predicted order: the box ranked last is
    moved to position 1, and the positions in between shift up by one. Orders
    with at most two boxes are returned unchanged (nothing to shift; the
    original `== 2` check would also crash on a single box)."""
    if len(r1) <= 2:
        return r1
    max_index = r1.index(max(r1))
    one_index = r1.index(1)

    # The last-ranked box takes position 1...
    r1[max_index] = 1

    # ...and every position above 1 shifts up by one to make room.
    r1 = [x + 1 if x not in (0, 1) else x for x in r1]
    r1[one_index] += 1
    return r1
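
# Worked example (illustrative): process_r1([0, 1, 2, 3]) -> [0, 2, 3, 1].
# The box ranked last (3) moves to position 1, and the boxes that held
# positions 1 and 2 shift to 2 and 3.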


def full_predictions(image, conf_threshold, iou_threshold):
    """Full pipeline: resize to 1024x1024, detect layout regions, de-duplicate
    boxes, predict the reading order, and draw the annotated result."""
    image = image.resize((1024, 1024))
    bboxes, classes = detect_layout(image, conf_threshold, iou_threshold)
    bboxes, classes = remove_overlapping_and_inside_boxes(bboxes, classes)
    orders = get_orders(image, scale_and_normalize_boxes(bboxes))
    orders = process_r1(orders)
    return draw_bboxes_on_image(image, bboxes, classes, orders)


iface = gr.Interface(
    fn=full_predictions,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU threshold"),
    ],
    outputs=gr.Image(type="pil", label="Result"),
    title="Arabic Document Layout Analysis",
    description="Upload a document image: RT-DETR detects the layout regions and a fine-tuned LayoutLMv3 predicts their reading order.",
    examples=[
        ["kashida.png", 0.2, 0.45],
        ["image.jpg", 0.2, 0.45],
        ["Screenshot 2024-11-06 130230.png", 0.25, 0.45],
    ],
    theme=gr.themes.Default(),
)


if __name__ == "__main__":
    iface.launch()